1use futures_util::Stream;
4use pin_project_lite::pin_project;
5use serde::Serialize;
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10use tokio::time::sleep;
11
12use crate::client::XaiClient;
13use crate::config::RetryPolicy;
14use crate::models::content::ContentPart;
15use crate::models::message::{Message, MessageContent, Role};
16use crate::models::response::{OutputItem, Response, ResponseFormat, StreamChunk, TextContent};
17use crate::models::tool::{Tool, ToolCall, ToolChoice};
18use crate::stream::ResponseStream;
19use crate::{Error, Result};
20
/// Default number of GET attempts `DeferredResponsePoller::wait` makes before
/// giving up with `Error::Timeout`.
const DEFAULT_DEFERRED_MAX_ATTEMPTS: u32 = 30;
/// Default delay between poll attempts. `Duration::ZERO` means "no sleep":
/// `poll_delay_for` short-circuits to zero whenever either bound is zero.
const DEFAULT_DEFERRED_POLL_INTERVAL: Duration = Duration::ZERO;
/// Default cap on sample/tool-result rounds in `StatefulChat::sample_with_tool_loop`.
const DEFAULT_STATEFUL_TOOL_LOOP_MAX_ROUNDS: u32 = 8;
24
/// Entry point for the `/responses` API surface: request builders, deferred
/// response polling, stateful chats, and GET/DELETE of stored responses.
#[derive(Debug, Clone)]
pub struct ResponsesApi {
    // Cloned freely; carries base URL, HTTP client, auth, and retry policy.
    client: XaiClient,
}
30
31impl ResponsesApi {
32 pub(crate) fn new(client: XaiClient) -> Self {
33 Self { client }
34 }
35
36 pub fn create(&self, model: impl Into<String>) -> CreateResponseBuilder {
57 CreateResponseBuilder::new(self.client.clone(), model.into())
58 }
59
60 pub fn deferred(&self, response_id: impl Into<String>) -> DeferredResponsePoller {
62 DeferredResponsePoller::new(self.client.clone(), response_id.into())
63 }
64
65 pub fn chat(&self, model: impl Into<String>) -> StatefulChat {
70 StatefulChat::new(self.client.clone(), model.into())
71 }
72
73 pub async fn get(&self, response_id: &str) -> Result<Response> {
75 let id = XaiClient::encode_path(response_id);
76 let url = format!("{}/responses/{}", self.client.base_url(), id);
77
78 let response = self.client.send(self.client.http().get(&url)).await?;
79
80 if !response.status().is_success() {
81 return Err(Error::from_response(response).await);
82 }
83
84 Ok(response.json().await?)
85 }
86
87 pub async fn delete(&self, response_id: &str) -> Result<()> {
89 let id = XaiClient::encode_path(response_id);
90 let url = format!("{}/responses/{}", self.client.base_url(), id);
91
92 let response = self.client.send(self.client.http().delete(&url)).await?;
93
94 if !response.status().is_success() {
95 return Err(Error::from_response(response).await);
96 }
97
98 Ok(())
99 }
100
101 pub async fn poll_until_ready(&self, response_id: &str, max_attempts: u32) -> Result<Response> {
103 self.deferred(response_id.to_string())
104 .max_attempts(max_attempts)
105 .wait()
106 .await
107 }
108}
109
/// Builder-style poller that repeatedly GETs a deferred response until its
/// `output` is non-empty or the attempt budget runs out.
#[derive(Debug, Clone)]
pub struct DeferredResponsePoller {
    client: XaiClient,
    response_id: String,
    // Total GET attempts; setters clamp this to at least 1.
    max_attempts: u32,
    // First inter-attempt delay; doubled per attempt up to `poll_max_delay`.
    poll_initial_delay: Duration,
    poll_max_delay: Duration,
}
119
120impl DeferredResponsePoller {
121 fn new(client: XaiClient, response_id: String) -> Self {
122 Self {
123 client,
124 response_id,
125 max_attempts: DEFAULT_DEFERRED_MAX_ATTEMPTS,
126 poll_initial_delay: DEFAULT_DEFERRED_POLL_INTERVAL,
127 poll_max_delay: DEFAULT_DEFERRED_POLL_INTERVAL,
128 }
129 }
130
131 pub fn max_attempts(mut self, max_attempts: u32) -> Self {
133 self.max_attempts = max_attempts.max(1);
134 self
135 }
136
137 pub fn poll_interval(mut self, interval: Duration) -> Self {
139 self.poll_initial_delay = interval;
140 self.poll_max_delay = interval;
141 self
142 }
143
144 pub fn poll_backoff(mut self, initial: Duration, max: Duration) -> Self {
146 self.poll_initial_delay = initial;
147 self.poll_max_delay = max.max(initial);
148 self
149 }
150
151 fn poll_delay_for(initial: Duration, max: Duration, attempt: u32) -> Duration {
152 let max_millis = max.as_millis();
153 let initial_millis = initial.as_millis();
154 if max_millis == 0 || initial_millis == 0 {
155 return Duration::ZERO;
156 }
157
158 let factor_shift = attempt.min(16);
159 let factor = 1u128 << factor_shift;
160 let delayed = initial_millis.saturating_mul(factor).min(max_millis);
161 Duration::from_millis(delayed as u64)
162 }
163
164 pub async fn wait(self) -> Result<Response> {
166 let DeferredResponsePoller {
167 client,
168 response_id,
169 max_attempts,
170 poll_initial_delay,
171 poll_max_delay,
172 } = self;
173
174 let api = ResponsesApi::new(client);
175
176 for attempt in 0..max_attempts {
177 let response = api.get(&response_id).await?;
178 if !response.output.is_empty() {
179 return Ok(response);
180 }
181
182 if attempt + 1 < max_attempts {
183 let delay = Self::poll_delay_for(poll_initial_delay, poll_max_delay, attempt);
184 if !delay.is_zero() {
185 sleep(delay).await;
186 }
187 }
188 }
189
190 Err(Error::Timeout)
191 }
192}
193
/// Client-side conversation state: an append-only message history plus the
/// tool calls from the latest sampled response that still await results.
#[derive(Debug, Clone)]
pub struct StatefulChat {
    client: XaiClient,
    model: String,
    messages: Vec<Message>,
    // Replaced by `append_response_semantics`; drained by the tool loop.
    pending_tool_calls: Vec<ToolCall>,
}
202
impl StatefulChat {
    /// Creates an empty chat bound to `client` and `model`.
    fn new(client: XaiClient, model: String) -> Self {
        Self {
            client,
            model,
            messages: Vec::new(),
            pending_tool_calls: Vec::new(),
        }
    }

    /// Appends a message with the given role; returns `&mut self` for chaining.
    pub fn append(&mut self, role: Role, content: impl Into<MessageContent>) -> &mut Self {
        self.messages.push(Message::new(role, content));
        self
    }

    /// Appends a system message.
    pub fn append_system(&mut self, content: impl Into<String>) -> &mut Self {
        self.append(Role::System, content.into())
    }

    /// Appends a user message (plain text or multi-part content).
    pub fn append_user(&mut self, content: impl Into<MessageContent>) -> &mut Self {
        self.append(Role::User, content)
    }

    /// Appends an assistant message.
    pub fn append_assistant(&mut self, content: impl Into<String>) -> &mut Self {
        self.append(Role::Assistant, content.into())
    }

    /// Appends an already-constructed message verbatim.
    pub fn append_message(&mut self, message: Message) -> &mut Self {
        self.messages.push(message);
        self
    }

    /// Appends a tool-result message answering the call with `tool_call_id`.
    pub fn append_tool_result(
        &mut self,
        tool_call_id: impl Into<String>,
        content: impl Into<String>,
    ) -> &mut Self {
        self.append_message(Message::tool(tool_call_id, content))
    }

    /// Read-only view of the accumulated conversation history.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Tool calls recorded by the most recent `append_response_semantics`.
    pub fn pending_tool_calls(&self) -> &[ToolCall] {
        &self.pending_tool_calls
    }

    /// Removes and returns the pending tool calls, leaving the list empty.
    pub fn take_pending_tool_calls(&mut self) -> Vec<ToolCall> {
        std::mem::take(&mut self.pending_tool_calls)
    }

    /// Clears both the message history and any pending tool calls.
    pub fn clear(&mut self) -> &mut Self {
        self.messages.clear();
        self.pending_tool_calls.clear();
        self
    }

    /// Borrows the text payload of a content part. Note that refusal text is
    /// treated the same as normal text by the merging below.
    fn text_content_slice(content: &TextContent) -> &str {
        match content {
            TextContent::Text { text } => text,
            TextContent::Refusal { refusal } => refusal,
        }
    }

    /// Concatenates the text of all parts of one output message.
    /// Returns `None` when there is no non-empty text to keep.
    fn merge_output_message_content(content: &[TextContent]) -> Option<String> {
        match content {
            [] => None,
            [single] => {
                // Common case: exactly one part — avoid the length pre-pass.
                let text = Self::text_content_slice(single);
                (!text.is_empty()).then(|| text.to_string())
            }
            _ => {
                // Pre-size the buffer so the merge allocates exactly once.
                let total_len: usize = content
                    .iter()
                    .map(|part| Self::text_content_slice(part).len())
                    .sum();
                if total_len == 0 {
                    return None;
                }

                let mut merged = String::with_capacity(total_len);
                for part in content {
                    merged.push_str(Self::text_content_slice(part));
                }
                Some(merged)
            }
        }
    }

    /// Extracts assistant texts and tool calls from `response`.
    ///
    /// Tool calls are de-duplicated by id across the top-level `tool_calls`
    /// field and the typed output items; top-level entries win because they
    /// are inserted first.
    fn collect_response_semantics(response: &Response) -> (Vec<String>, Vec<ToolCall>) {
        let output_item_count = response.output.len();
        let top_level_call_count = response.tool_calls.as_ref().map_or(0, Vec::len);
        let mut assistant_messages = Vec::with_capacity(output_item_count);
        let mut pending_tool_calls = Vec::with_capacity(top_level_call_count + output_item_count);
        let mut seen_tool_call_ids: HashSet<String> =
            HashSet::with_capacity(top_level_call_count + output_item_count);

        if let Some(calls) = &response.tool_calls {
            for call in calls {
                if seen_tool_call_ids.insert(call.id.clone()) {
                    pending_tool_calls.push(call.clone());
                }
            }
        }

        for item in &response.output {
            match item {
                OutputItem::Message { content, .. } => {
                    if let Some(merged) = Self::merge_output_message_content(content) {
                        assistant_messages.push(merged);
                    }
                }
                OutputItem::FunctionCall { call } => {
                    if seen_tool_call_ids.insert(call.id.clone()) {
                        pending_tool_calls.push(call.clone());
                    }
                }
                // Built-in tool calls carry only an id; synthesize a ToolCall
                // with the matching call_type so callers can observe them.
                OutputItem::CodeInterpreterCall { id, .. } => {
                    if seen_tool_call_ids.insert(id.clone()) {
                        pending_tool_calls.push(ToolCall {
                            id: id.clone(),
                            call_type: Some("code_interpreter".to_string()),
                            function: None,
                        });
                    }
                }
                OutputItem::WebSearchCall { id, .. } => {
                    if seen_tool_call_ids.insert(id.clone()) {
                        pending_tool_calls.push(ToolCall {
                            id: id.clone(),
                            call_type: Some("web_search".to_string()),
                            function: None,
                        });
                    }
                }
                OutputItem::XSearchCall { id, .. } => {
                    if seen_tool_call_ids.insert(id.clone()) {
                        pending_tool_calls.push(ToolCall {
                            id: id.clone(),
                            call_type: Some("x_search".to_string()),
                            function: None,
                        });
                    }
                }
            }
        }

        (assistant_messages, pending_tool_calls)
    }

    /// Sends the current history as a new `/responses` request without
    /// mutating local state.
    pub async fn sample(&self) -> Result<Response> {
        self.client
            .responses()
            .create(self.model.clone())
            .messages(self.messages.clone())
            .send()
            .await
    }

    /// Appends the response's concatenated text as one assistant message;
    /// skipped entirely when the text is empty.
    pub fn append_response_text(&mut self, response: &Response) -> &mut Self {
        let text = response.all_text();
        if !text.is_empty() {
            self.append_assistant(text);
        }
        self
    }

    /// Appends assistant messages from `response` and REPLACES (not extends)
    /// the pending tool-call list with the calls found in it.
    pub fn append_response_semantics(&mut self, response: &Response) -> &mut Self {
        let (assistant_messages, pending_tool_calls) = Self::collect_response_semantics(response);
        for text in assistant_messages {
            self.append_assistant(text);
        }
        self.pending_tool_calls = pending_tool_calls;
        self
    }

    /// `sample` followed by `append_response_semantics`.
    pub async fn sample_and_append(&mut self) -> Result<Response> {
        let response = self.sample().await?;
        self.append_response_semantics(&response);
        Ok(response)
    }

    /// Tool loop with the default round budget
    /// (`DEFAULT_STATEFUL_TOOL_LOOP_MAX_ROUNDS`).
    pub async fn sample_with_tool_loop<H, Fut>(&mut self, handler: H) -> Result<Response>
    where
        H: FnMut(ToolCall) -> Fut,
        Fut: std::future::Future<Output = Result<String>>,
    {
        self.sample_with_tool_handler(DEFAULT_STATEFUL_TOOL_LOOP_MAX_ROUNDS, handler)
            .await
    }

    /// Samples repeatedly, feeding each pending tool call to `handler` and
    /// appending its result, until a response arrives with no tool calls or
    /// `max_rounds` (clamped to >= 1) is exceeded.
    ///
    /// # Errors
    /// Propagates request/handler errors; returns `Error::Config` when the
    /// round budget runs out while tool calls are still pending.
    pub async fn sample_with_tool_handler<H, Fut>(
        &mut self,
        max_rounds: u32,
        mut handler: H,
    ) -> Result<Response>
    where
        H: FnMut(ToolCall) -> Fut,
        Fut: std::future::Future<Output = Result<String>>,
    {
        let rounds = max_rounds.max(1);

        for _ in 0..rounds {
            let response = self.sample_and_append().await?;
            let pending = self.take_pending_tool_calls();
            if pending.is_empty() {
                // No tool work requested: this is the final answer.
                return Ok(response);
            }

            for call in pending {
                let tool_call_id = call.id.clone();
                let tool_result = handler(call).await?;
                self.append_tool_result(tool_call_id, tool_result);
            }
        }

        Err(Error::Config(format!(
            "stateful chat tool loop exceeded max rounds ({rounds})"
        )))
    }

    /// Starts a streaming request over the current history.
    pub async fn stream(&self) -> Result<StatefulChatStream> {
        let stream = self
            .client
            .responses()
            .create(self.model.clone())
            .messages(self.messages.clone())
            .stream()
            .await?;

        Ok(StatefulChatStream::new(stream))
    }
}
460
pin_project! {
    /// Stream wrapper that records every text delta it forwards, so callers
    /// can read the full transcript accumulated so far at any point.
    pub struct StatefulChatStream {
        // Pinned because `ResponseStream` is polled in place.
        #[pin]
        inner: ResponseStream,
        accumulated_text: String,
    }
}
469
/// A stream chunk paired with a snapshot of all text seen so far.
#[derive(Debug, Clone)]
pub struct AccumulatedChunk {
    /// The chunk exactly as received from the stream.
    pub chunk: StreamChunk,
    /// Concatenation of every delta up to and including this chunk.
    pub accumulated_text: String,
}
478
impl StatefulChatStream {
    /// Wraps the raw response stream with an empty text accumulator.
    fn new(inner: ResponseStream) -> Self {
        Self {
            inner,
            accumulated_text: String::new(),
        }
    }

    /// Text accumulated from all chunks consumed so far.
    pub fn accumulated_text(&self) -> &str {
        &self.accumulated_text
    }

    /// Consumes the stream and returns the accumulated text.
    pub fn into_accumulated_text(self) -> String {
        self.accumulated_text
    }

    /// Yields the next chunk together with a snapshot of the accumulated
    /// text (which already includes this chunk's delta, since `poll_next`
    /// appends before yielding). Returns `None` when the stream ends.
    pub async fn next_with_accumulated(&mut self) -> Option<Result<AccumulatedChunk>> {
        use futures_util::future::poll_fn;

        // NOTE(review): `Pin::new` requires `Self: Unpin`, which holds only
        // if `ResponseStream: Unpin` — presumably true since this compiles;
        // confirm against the `stream` module.
        poll_fn(|cx| Pin::new(&mut *self).poll_next(cx))
            .await
            .map(|result| {
                result.map(|chunk| AccumulatedChunk {
                    chunk,
                    accumulated_text: self.accumulated_text.clone(),
                })
            })
    }
}
511
512impl Stream for StatefulChatStream {
513 type Item = Result<StreamChunk>;
514
515 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
516 let mut this = self.project();
517
518 match this.inner.as_mut().poll_next(cx) {
519 Poll::Ready(Some(Ok(chunk))) => {
520 this.accumulated_text.push_str(chunk.delta());
521 Poll::Ready(Some(Ok(chunk)))
522 }
523 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
524 Poll::Ready(None) => Poll::Ready(None),
525 Poll::Pending => Poll::Pending,
526 }
527 }
528}
529
/// Builder for a single `POST /responses` request (one-shot or streaming).
#[derive(Debug)]
pub struct CreateResponseBuilder {
    client: XaiClient,
    request: CreateResponseRequest,
    // When set, replaces the client-level retry policy for this request only.
    retry_policy_override: Option<RetryPolicy>,
}
537
/// Serialized body for `POST /responses`; unset optional fields are omitted
/// from the JSON payload entirely.
#[derive(Debug, Clone, Serialize)]
struct CreateResponseRequest {
    model: String,
    // Conversation history; the wire field is `input`, not `messages`.
    input: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    // Set to Some(true) by `stream()`, never by the caller directly.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    // Extra payload flags, e.g. "inline_citations" / "verbose_streaming".
    #[serde(skip_serializing_if = "Option::is_none")]
    include: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
}
561
impl CreateResponseBuilder {
    /// Starts an empty request for `model`; all optional fields unset.
    fn new(client: XaiClient, model: String) -> Self {
        Self {
            client,
            request: CreateResponseRequest {
                model,
                input: Vec::new(),
                tools: None,
                tool_choice: None,
                temperature: None,
                top_p: None,
                max_tokens: None,
                stream: None,
                response_format: None,
                include: None,
                store: None,
            },
            retry_policy_override: None,
        }
    }

    /// Lazily clones the client's retry policy the first time a per-request
    /// retry knob is touched, then returns it for mutation.
    fn retry_policy_mut(&mut self) -> &mut RetryPolicy {
        self.retry_policy_override
            .get_or_insert_with(|| self.client.retry_policy())
    }

    /// Appends a message with the given role to the request input.
    pub fn message(mut self, role: Role, content: impl Into<MessageContent>) -> Self {
        self.request.input.push(Message::new(role, content));
        self
    }

    /// Appends a system message.
    pub fn system(self, content: impl Into<String>) -> Self {
        self.message(Role::System, content.into())
    }

    /// Appends a user message (plain text or multi-part content).
    pub fn user(self, content: impl Into<MessageContent>) -> Self {
        self.message(Role::User, content)
    }

    /// Appends an assistant message.
    pub fn assistant(self, content: impl Into<String>) -> Self {
        self.message(Role::Assistant, content.into())
    }

    /// Appends a user message containing text followed by an image URL.
    pub fn user_with_image(
        mut self,
        text: impl Into<String>,
        image_url: impl Into<String>,
    ) -> Self {
        // Order matters: text part first, then the image part.
        let parts = vec![ContentPart::text(text), ContentPart::image_url(image_url)];
        self.request
            .input
            .push(Message::new(Role::User, MessageContent::Parts(parts)));
        self
    }

    /// Appends a batch of pre-built messages.
    pub fn messages(mut self, messages: Vec<Message>) -> Self {
        self.request.input.extend(messages);
        self
    }

    /// Adds one tool definition.
    pub fn tool(mut self, tool: Tool) -> Self {
        self.request.tools.get_or_insert_with(Vec::new).push(tool);
        self
    }

    /// Adds a batch of tool definitions.
    pub fn tools(mut self, tools: Vec<Tool>) -> Self {
        self.request
            .tools
            .get_or_insert_with(Vec::new)
            .extend(tools);
        self
    }

    /// Sets the tool-choice strategy.
    pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
        self.request.tool_choice = Some(choice);
        self
    }

    /// Sets sampling temperature, silently clamped to the API range [0, 2].
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.request.temperature = Some(temperature.clamp(0.0, 2.0));
        self
    }

    /// Sets nucleus-sampling top-p, silently clamped to [0, 1].
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.request.top_p = Some(top_p.clamp(0.0, 1.0));
        self
    }

    /// Caps the number of generated tokens.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.request.max_tokens = Some(max_tokens);
        self
    }

    /// Overrides the retry count for this request only.
    pub fn max_retries(mut self, max_retries: u32) -> Self {
        self.retry_policy_mut().max_retries = max_retries;
        self
    }

    /// Disables retries for this request only.
    pub fn disable_retries(self) -> Self {
        self.max_retries(0)
    }

    /// Overrides the retry backoff range; `max` is raised to at least `initial`.
    pub fn retry_backoff(mut self, initial: Duration, max: Duration) -> Self {
        let policy = self.retry_policy_mut();
        policy.initial_backoff = initial;
        policy.max_backoff = max.max(initial);
        self
    }

    /// Overrides the retry jitter factor, clamped to [0, 1].
    pub fn retry_jitter(mut self, factor: f64) -> Self {
        self.retry_policy_mut().jitter_factor = factor.clamp(0.0, 1.0);
        self
    }

    /// Sets the response format (e.g. JSON mode).
    pub fn response_format(mut self, format: ResponseFormat) -> Self {
        self.request.response_format = Some(format);
        self
    }

    /// Shorthand for `response_format(ResponseFormat::json_object())`.
    pub fn json_output(self) -> Self {
        self.response_format(ResponseFormat::json_object())
    }

    /// Replaces the `include` flag list wholesale.
    pub fn include(mut self, fields: Vec<String>) -> Self {
        self.request.include = Some(fields);
        self
    }

    /// Adds the `inline_citations` include flag.
    pub fn with_inline_citations(mut self) -> Self {
        self.request
            .include
            .get_or_insert_with(Vec::new)
            .push("inline_citations".to_string());
        self
    }

    /// Adds the `verbose_streaming` include flag.
    pub fn with_verbose_streaming(mut self) -> Self {
        self.request
            .include
            .get_or_insert_with(Vec::new)
            .push("verbose_streaming".to_string());
        self
    }

    /// Sets whether the server should persist the response.
    pub fn store(mut self, store: bool) -> Self {
        self.request.store = Some(store);
        self
    }

    /// Sends the request and deserializes the full response.
    ///
    /// # Errors
    /// Returns the transport error or the API error body on non-2xx status.
    pub async fn send(self) -> Result<Response> {
        let url = format!("{}/responses", self.client.base_url());

        // A None override falls back to the client-level retry policy.
        let response = self
            .client
            .send_with_retry_policy(
                self.client.http().post(&url).json(&self.request),
                self.retry_policy_override,
            )
            .await?;

        if !response.status().is_success() {
            return Err(Error::from_response(response).await);
        }

        Ok(response.json().await?)
    }

    /// Sends the request with `stream: true` and returns the SSE byte stream
    /// wrapped as a `ResponseStream`.
    ///
    /// # Errors
    /// Returns the transport error or the API error body on non-2xx status;
    /// stream-level errors surface later from the returned stream.
    pub async fn stream(mut self) -> Result<ResponseStream> {
        self.request.stream = Some(true);

        let url = format!("{}/responses", self.client.base_url());

        let response = self
            .client
            .send_with_retry_policy(
                self.client.http().post(&url).json(&self.request),
                self.retry_policy_override,
            )
            .await?;

        if !response.status().is_success() {
            return Err(Error::from_response(response).await);
        }

        Ok(ResponseStream::new(response.bytes_stream()))
    }
}
773
774#[cfg(test)]
775mod tests {
776 use super::*;
777 use bytes::Bytes;
778 use futures_util::{stream, StreamExt};
779 use serde_json::json;
780 use std::sync::{
781 atomic::{AtomicUsize, Ordering},
782 Arc,
783 };
784 use wiremock::{
785 matchers::{body_partial_json, method, path},
786 Mock, MockServer, ResponseTemplate,
787 };
788
    // The multimodal helper must emit exactly [text, image] parts, in that
    // order, inside a single user message.
    #[test]
    fn user_with_image_puts_text_before_image() {
        let client = XaiClient::new("test-key").unwrap();
        let api = ResponsesApi::new(client);
        let builder = api
            .create("grok-4")
            .user_with_image("describe this", "https://example.com/image.jpg");

        assert_eq!(builder.request.input.len(), 1);
        let msg = &builder.request.input[0];
        assert!(matches!(msg.role, Role::User));

        match &msg.content {
            MessageContent::Parts(parts) => {
                assert_eq!(parts.len(), 2);
                assert!(matches!(
                    &parts[0],
                    ContentPart::Text { text } if text == "describe this"
                ));
                assert!(matches!(
                    &parts[1],
                    ContentPart::ImageUrl { image_url } if image_url.url == "https://example.com/image.jpg"
                ));
            }
            _ => panic!("Expected multipart message content"),
        }
    }
816
    // Builder-level `max_retries(1)` must re-enable retries even when the
    // client itself was built with retries disabled: the first 503 is
    // retried and the second attempt succeeds.
    #[tokio::test]
    async fn create_builder_retry_override_enables_retries_when_client_retries_disabled() {
        let server = MockServer::start().await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let responder_count = Arc::clone(&call_count);

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(move |_req: &wiremock::Request| {
                // Fail only the first call.
                let count = responder_count.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    ResponseTemplate::new(503).set_body_json(json!({
                        "error": {"message": "temporary", "type": "server_error"}
                    }))
                } else {
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_retry",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": "retry worked"}]
                        }],
                        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
                    }))
                }
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .disable_retries()
            .build()
            .unwrap();

        let response = client
            .responses()
            .create("grok-4")
            .user("hello")
            .max_retries(1)
            .retry_backoff(Duration::ZERO, Duration::ZERO)
            .send()
            .await
            .unwrap();

        assert_eq!(response.output_text().as_deref(), Some("retry worked"));
        // One failed attempt plus exactly one retry.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    // Conversely, builder-level `disable_retries()` must suppress the
    // client-level retry policy: a single 503 surfaces immediately with no
    // second request.
    #[tokio::test]
    async fn create_builder_disable_retries_overrides_client_retry_policy() {
        let server = MockServer::start().await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let responder_count = Arc::clone(&call_count);

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(move |_req: &wiremock::Request| {
                responder_count.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(503).set_body_json(json!({
                    "error": {"message": "still unavailable", "type": "server_error"}
                }))
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .max_retries(2)
            .retry_backoff(Duration::ZERO, Duration::ZERO)
            .build()
            .unwrap();

        let err = client
            .responses()
            .create("grok-4")
            .user("hello")
            .disable_retries()
            .send()
            .await
            .unwrap_err();

        assert!(matches!(err, Error::Api { status: 503, .. }));
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
905
    // GET of an unknown response id must surface the API error body as
    // `Error::Api` with the original status and message.
    #[tokio::test]
    async fn responses_api_get_propagates_api_error() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/responses/missing"))
            .respond_with(ResponseTemplate::new(404).set_body_json(json!({
                "error": {"message": "response not found"}
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let err = client.responses().get("missing").await.unwrap_err();
        match err {
            Error::Api {
                status, message, ..
            } => {
                assert_eq!(status, 404);
                assert_eq!(message, "response not found");
            }
            _ => panic!("expected Error::Api"),
        }
    }

    // Same contract for DELETE of an unknown response id.
    #[tokio::test]
    async fn responses_api_delete_propagates_api_error() {
        let server = MockServer::start().await;

        Mock::given(method("DELETE"))
            .and(path("/responses/missing"))
            .respond_with(ResponseTemplate::new(404).set_body_json(json!({
                "error": {"message": "response not found"}
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let err = client.responses().delete("missing").await.unwrap_err();
        match err {
            Error::Api {
                status, message, ..
            } => {
                assert_eq!(status, 404);
                assert_eq!(message, "response not found");
            }
            _ => panic!("expected Error::Api"),
        }
    }

    // And for the create builder's `send()` path: a 503 body becomes
    // `Error::Api` with status and message intact.
    #[tokio::test]
    async fn create_response_builder_send_propagates_api_error() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/responses"))
            .and(body_partial_json(json!({
                "model": "grok-4",
                "input": [{"role": "user", "content": "error"}]
            })))
            .respond_with(ResponseTemplate::new(503).set_body_json(json!({
                "error": {"message": "service unavailable"}
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let err = client
            .responses()
            .create("grok-4")
            .user("error")
            .send()
            .await
            .unwrap_err();

        match err {
            Error::Api {
                status, message, ..
            } => {
                assert_eq!(status, 503);
                assert_eq!(message, "service unavailable");
            }
            _ => panic!("expected Error::Api"),
        }
    }
1006
    // `stream()` must fail fast with `Error::Api` when the initial POST is
    // rejected, rather than yielding an error through the stream.
    #[tokio::test]
    async fn create_response_builder_stream_propagates_api_error() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/responses"))
            .and(body_partial_json(json!({
                "model": "grok-4",
                "input": [{"role": "user", "content": "stream error"}]
            })))
            .respond_with(ResponseTemplate::new(503).set_body_json(json!({
                "error": {"message": "stream unavailable"}
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let err = match client
            .responses()
            .create("grok-4")
            .user("stream error")
            .stream()
            .await
        {
            Ok(_) => panic!("expected stream creation to fail"),
            Err(err) => err,
        };

        match err {
            Error::Api {
                status, message, ..
            } => {
                assert_eq!(status, 503);
                assert_eq!(message, "stream unavailable");
            }
            _ => panic!("expected Error::Api"),
        }
    }

    // Happy path: an SSE body with a delta, a done event carrying the final
    // response, and the [DONE] sentinel parses into two chunks.
    #[tokio::test]
    async fn create_response_builder_stream_returns_stream() {
        let server = MockServer::start().await;

        let payload = concat!(
            "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"done\"}]}}\n\n",
            "data: {\"type\":\"response.done\",\"response\":{\"id\":\"resp_stream\",\"model\":\"grok-4\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"done\"}]}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":1,\"total_tokens\":2}}}\n\n",
            "data: [DONE]\n\n"
        );

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_string(payload))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut stream = client
            .responses()
            .create("grok-4")
            .user("stream response")
            .stream()
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();
        assert!(!first.done);
        assert_eq!(first.delta(), "done");
        let done = stream.next().await.unwrap().unwrap();
        assert!(done.done);
        assert!(done.response.is_some());
        assert_eq!(done.response.unwrap().id, "resp_stream");
    }
1089
    // Full lifecycle: create a (still empty) response, GET it once while
    // empty, poll until output appears on the second GET, then delete.
    // Counters verify the exact number of requests per verb.
    #[tokio::test]
    async fn create_get_poll_delete_roundtrip() {
        let server = MockServer::start().await;
        let post_count = Arc::new(AtomicUsize::new(0));
        let get_count = Arc::new(AtomicUsize::new(0));
        let delete_count = Arc::new(AtomicUsize::new(0));

        let post_count_for_responder = Arc::clone(&post_count);
        Mock::given(method("POST"))
            .and(path("/responses"))
            .and(body_partial_json(json!({
                "model": "grok-4",
                "input": [
                    {"role": "system", "content": "Roundtrip test system"},
                    {"role": "user", "content": "What is 1+1?"}
                ]
            })))
            .respond_with(move |_req: &wiremock::Request| {
                post_count_for_responder.fetch_add(1, Ordering::SeqCst);
                // Deferred-style: created with no output yet.
                ResponseTemplate::new(200).set_body_json(json!({
                    "id": "resp_roundtrip",
                    "model": "grok-4",
                    "output": [],
                    "usage": {"prompt_tokens": 10, "completion_tokens": 0, "total_tokens": 10}
                }))
            })
            .mount(&server)
            .await;

        let get_count_for_responder = Arc::clone(&get_count);
        Mock::given(method("GET"))
            .and(path("/responses/resp_roundtrip"))
            .respond_with(move |_req: &wiremock::Request| {
                // First GET still empty; subsequent GETs carry output so the
                // poller's readiness check (non-empty output) succeeds.
                let count = get_count_for_responder.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_roundtrip",
                        "model": "grok-4",
                        "output": [],
                        "usage": {"prompt_tokens": 10, "completion_tokens": 0, "total_tokens": 10}
                    }))
                } else {
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_roundtrip",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": "hello"}]
                        }],
                        "usage": {"prompt_tokens": 11, "completion_tokens": 1, "total_tokens": 12}
                    }))
                }
            })
            .mount(&server)
            .await;

        let delete_count_for_responder = Arc::clone(&delete_count);
        Mock::given(method("DELETE"))
            .and(path("/responses/resp_roundtrip"))
            .respond_with(move |_req: &wiremock::Request| {
                delete_count_for_responder.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200).set_body_json(json!({"id": "resp_roundtrip"}))
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let create_response = client
            .responses()
            .create("grok-4")
            .system("Roundtrip test system")
            .user("What is 1+1?")
            .send()
            .await
            .unwrap();
        assert_eq!(create_response.id, "resp_roundtrip");
        assert!(create_response.output_text().is_none());

        let get_response = client.responses().get("resp_roundtrip").await.unwrap();
        assert_eq!(get_response.id, "resp_roundtrip");
        assert!(get_response.output_text().is_none());

        // Two attempts allowed; second GET returns the ready payload.
        let polled_response = client
            .responses()
            .poll_until_ready("resp_roundtrip", 2)
            .await
            .unwrap();
        assert_eq!(polled_response.output_text().as_deref(), Some("hello"));

        client.responses().delete("resp_roundtrip").await.unwrap();

        assert_eq!(post_count.load(Ordering::SeqCst), 1);
        // GET #1 from the explicit get(), GET #2 from the poller.
        assert_eq!(get_count.load(Ordering::SeqCst), 2);
        assert_eq!(delete_count.load(Ordering::SeqCst), 1);
    }
1191
    // `sample()` must serialize the full local history (system + user) into
    // the request `input`, verified via a partial-body matcher.
    #[tokio::test]
    async fn stateful_chat_append_and_sample_sends_all_messages() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/responses"))
            .and(body_partial_json(json!({
                "model": "grok-4",
                "input": [
                    {"role": "system", "content": "You are helpful."},
                    {"role": "user", "content": "Hello"}
                ]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "resp_123",
                "model": "grok-4",
                "output": [{
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "text", "text": "Hi there!"}]
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 3,
                    "total_tokens": 13
                }
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut chat = client.responses().chat("grok-4");
        chat.append_system("You are helpful.").append_user("Hello");

        assert_eq!(chat.messages().len(), 2);

        let response = chat.sample().await.unwrap();
        assert_eq!(response.output_text().as_deref(), Some("Hi there!"));
    }

    // `sample_and_append()` must append the assistant reply to local history
    // (2 -> 3 messages) and leave no pending tool calls for a plain reply.
    #[tokio::test]
    async fn stateful_chat_sample_and_append_updates_local_history() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "resp_124",
                "model": "grok-4",
                "output": [{
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "text", "text": "History updated"}]
                }],
                "usage": {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut chat = client.responses().chat("grok-4");
        chat.append_system("You are helpful.").append_user("Hello");

        let response = chat.sample_and_append().await.unwrap();
        assert_eq!(response.output_text().as_deref(), Some("History updated"));
        assert_eq!(chat.messages().len(), 3);
        assert!(chat.pending_tool_calls().is_empty());

        let last = chat.messages().last().unwrap();
        assert!(matches!(last.role, Role::Assistant));
        assert_eq!(last.content.as_text(), Some("History updated"));
    }

    // An empty `output` must not append anything: history stays at 2
    // messages and no tool calls become pending.
    #[tokio::test]
    async fn stateful_chat_sample_and_append_skips_empty_text() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "resp_125",
                "model": "grok-4",
                "output": [],
                "usage": {"prompt_tokens": 10, "completion_tokens": 0, "total_tokens": 10}
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut chat = client.responses().chat("grok-4");
        chat.append_system("You are helpful.").append_user("Hello");

        let response = chat.sample_and_append().await.unwrap();
        assert!(response.output_text().is_none());
        assert_eq!(chat.messages().len(), 2);
        assert!(chat.pending_tool_calls().is_empty());
    }
1304
1305 #[tokio::test]
1306 async fn stateful_chat_append_response_text_appends_non_empty_and_skips_empty() {
1307 let mut chat = StatefulChat::new(XaiClient::new("test-key").unwrap(), "grok-4".to_string());
1308 chat.append_system("You are helpful.").append_user("Hello");
1309
1310 let non_empty = Response {
1311 id: "resp_non_empty".to_string(),
1312 model: "grok-4".to_string(),
1313 output: vec![OutputItem::Message {
1314 role: Role::Assistant,
1315 content: vec![
1316 TextContent::Text {
1317 text: "Hello ".to_string(),
1318 },
1319 TextContent::Text {
1320 text: "again".to_string(),
1321 },
1322 ],
1323 }],
1324 usage: Default::default(),
1325 citations: None,
1326 inline_citations: None,
1327 server_side_tool_usage: None,
1328 tool_calls: None,
1329 system_fingerprint: None,
1330 };
1331
1332 let empty = Response {
1333 id: "resp_empty".to_string(),
1334 model: "grok-4".to_string(),
1335 output: vec![OutputItem::Message {
1336 role: Role::Assistant,
1337 content: vec![TextContent::Refusal {
1338 refusal: "policy block".to_string(),
1339 }],
1340 }],
1341 usage: Default::default(),
1342 citations: None,
1343 inline_citations: None,
1344 server_side_tool_usage: None,
1345 tool_calls: None,
1346 system_fingerprint: None,
1347 };
1348
1349 chat.append_response_text(&non_empty);
1350 chat.append_response_text(&empty);
1351
1352 assert_eq!(chat.messages().len(), 3);
1353 assert_eq!(
1354 chat.messages().last().unwrap().content.as_text(),
1355 Some("Hello again")
1356 );
1357 }
1358
1359 #[tokio::test]
1360 async fn stateful_chat_sample_and_append_carries_refusal_text() {
1361 let server = MockServer::start().await;
1362
1363 Mock::given(method("POST"))
1364 .and(path("/responses"))
1365 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1366 "id": "resp_126",
1367 "model": "grok-4",
1368 "output": [{
1369 "type": "message",
1370 "role": "assistant",
1371 "content": [{"type": "refusal", "refusal": "I can't help with that."}]
1372 }],
1373 "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
1374 })))
1375 .mount(&server)
1376 .await;
1377
1378 let client = XaiClient::builder()
1379 .api_key("test-key")
1380 .base_url(server.uri())
1381 .build()
1382 .unwrap();
1383
1384 let mut chat = client.responses().chat("grok-4");
1385 chat.append_system("You are helpful.")
1386 .append_user("Do something unsafe");
1387
1388 let response = chat.sample_and_append().await.unwrap();
1389 assert!(response.output_text().is_none());
1390 assert_eq!(chat.messages().len(), 3);
1391 assert!(chat.pending_tool_calls().is_empty());
1392 assert_eq!(
1393 chat.messages().last().unwrap().content.as_text(),
1394 Some("I can't help with that.")
1395 );
1396 }
1397
1398 #[tokio::test]
1399 async fn stateful_chat_sample_and_append_captures_tool_calls() {
1400 let server = MockServer::start().await;
1401
1402 Mock::given(method("POST"))
1403 .and(path("/responses"))
1404 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1405 "id": "resp_127",
1406 "model": "grok-4",
1407 "output": [{
1408 "type": "function_call",
1409 "id": "call_weather",
1410 "function": {
1411 "name": "get_weather",
1412 "arguments": "{\"location\":\"Paris\"}"
1413 }
1414 }],
1415 "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
1416 })))
1417 .mount(&server)
1418 .await;
1419
1420 let client = XaiClient::builder()
1421 .api_key("test-key")
1422 .base_url(server.uri())
1423 .build()
1424 .unwrap();
1425
1426 let mut chat = client.responses().chat("grok-4");
1427 chat.append_system("You are helpful.")
1428 .append_user("What's weather in Paris?");
1429
1430 let _ = chat.sample_and_append().await.unwrap();
1431
1432 assert_eq!(chat.messages().len(), 2);
1434 assert_eq!(chat.pending_tool_calls().len(), 1);
1435 assert_eq!(chat.pending_tool_calls()[0].id, "call_weather");
1436 assert_eq!(
1437 chat.pending_tool_calls()[0]
1438 .function
1439 .as_ref()
1440 .map(|f| f.name.as_str()),
1441 Some("get_weather")
1442 );
1443
1444 let pending = chat.take_pending_tool_calls();
1445 assert_eq!(pending.len(), 1);
1446 assert!(chat.pending_tool_calls().is_empty());
1447
1448 chat.append_tool_result("call_weather", r#"{"temperature": 72}"#);
1449 assert_eq!(chat.messages().len(), 3);
1450 assert!(matches!(chat.messages().last().unwrap().role, Role::Tool));
1451 }
1452
    /// Happy-path tool loop: round 1 returns a function call, the handler's
    /// result is fed back, and round 2 returns the final assistant text.
    #[tokio::test]
    async fn stateful_chat_sample_with_tool_handler_resolves_tool_loop() {
        let server = MockServer::start().await;
        // Shared counter lets the responder vary by request number and lets
        // the test assert exactly how many POSTs were made.
        let call_count = Arc::new(AtomicUsize::new(0));
        let responder_count = Arc::clone(&call_count);

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(move |_req: &wiremock::Request| {
                let count = responder_count.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    // First request: the model asks for the weather tool.
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_tool_1",
                        "model": "grok-4",
                        "output": [{
                            "type": "function_call",
                            "id": "call_weather",
                            "function": {
                                "name": "get_weather",
                                "arguments": "{\"location\":\"Paris\"}"
                            }
                        }],
                        "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
                    }))
                } else {
                    // Follow-up request (after the tool result): final answer.
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_tool_2",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": "Weather is 72F"}]
                        }],
                        "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}
                    }))
                }
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut chat = client.responses().chat("grok-4");
        chat.append_system("You are helpful.")
            .append_user("What's the weather in Paris?");

        let response = chat
            .sample_with_tool_handler(3, |call| async move {
                if call.id == "call_weather" {
                    Ok(r#"{"temperature": 72}"#.to_string())
                } else {
                    Ok("{}".to_string())
                }
            })
            .await
            .unwrap();

        assert_eq!(response.output_text().as_deref(), Some("Weather is 72F"));
        // Exactly two round trips: tool-call round plus the final answer.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
        assert!(chat.pending_tool_calls().is_empty());
        // History: system, user, tool result, assistant answer.
        assert_eq!(chat.messages().len(), 4);
        assert!(matches!(chat.messages()[2].role, Role::Tool));
        assert!(matches!(chat.messages()[3].role, Role::Assistant));
    }
1521
1522 #[tokio::test]
1523 async fn stateful_chat_sample_with_tool_handler_errors_when_rounds_exhausted() {
1524 let server = MockServer::start().await;
1525
1526 Mock::given(method("POST"))
1527 .and(path("/responses"))
1528 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1529 "id": "resp_tool_loop",
1530 "model": "grok-4",
1531 "output": [{
1532 "type": "function_call",
1533 "id": "call_loop",
1534 "function": {
1535 "name": "looping_tool",
1536 "arguments": "{}"
1537 }
1538 }],
1539 "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
1540 })))
1541 .mount(&server)
1542 .await;
1543
1544 let client = XaiClient::builder()
1545 .api_key("test-key")
1546 .base_url(server.uri())
1547 .build()
1548 .unwrap();
1549
1550 let mut chat = client.responses().chat("grok-4");
1551 chat.append_system("You are helpful.")
1552 .append_user("Trigger loop");
1553
1554 let err = chat
1555 .sample_with_tool_handler(1, |_call| async move { Ok("{}".to_string()) })
1556 .await
1557 .unwrap_err();
1558
1559 match err {
1560 Error::Config(message) => {
1561 assert!(message.contains("stateful chat tool loop exceeded max rounds (1)"))
1562 }
1563 _ => panic!("expected Error::Config"),
1564 }
1565 }
1566
1567 #[tokio::test]
1568 async fn stateful_chat_sample_with_tool_handler_propagates_handler_error() {
1569 let server = MockServer::start().await;
1570
1571 Mock::given(method("POST"))
1572 .and(path("/responses"))
1573 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1574 "id": "resp_tool_error",
1575 "model": "grok-4",
1576 "output": [{
1577 "type": "function_call",
1578 "id": "call_fail",
1579 "function": {
1580 "name": "failing_tool",
1581 "arguments": "{}"
1582 }
1583 }],
1584 "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
1585 })))
1586 .mount(&server)
1587 .await;
1588
1589 let client = XaiClient::builder()
1590 .api_key("test-key")
1591 .base_url(server.uri())
1592 .build()
1593 .unwrap();
1594
1595 let mut chat = client.responses().chat("grok-4");
1596 chat.append_system("You are helpful.")
1597 .append_user("Trigger tool error");
1598
1599 let err = chat
1600 .sample_with_tool_handler(3, |_call| async move {
1601 Err(Error::Config("tool handler failed".to_string()))
1602 })
1603 .await
1604 .unwrap_err();
1605
1606 match err {
1607 Error::Config(message) => assert!(message.contains("tool handler failed")),
1608 _ => panic!("expected Error::Config"),
1609 }
1610 }
1611
1612 #[tokio::test]
1613 async fn stateful_chat_stream_accumulates_deltas() {
1614 let payload = concat!(
1615 "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"Hel\"}]}}\n\n",
1616 "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"lo\"}]}}\n\n",
1617 "data: [DONE]\n\n"
1618 );
1619
1620 let chunks: Vec<std::result::Result<Bytes, reqwest::Error>> =
1621 vec![Ok(Bytes::from(payload.to_string()))];
1622 let raw_stream = stream::iter(chunks);
1623 let response_stream = ResponseStream::new(raw_stream);
1624 let mut stream = StatefulChatStream::new(response_stream);
1625
1626 let first = stream.next().await.unwrap().unwrap();
1627 assert_eq!(first.delta(), "Hel");
1628 assert_eq!(stream.accumulated_text(), "Hel");
1629
1630 let second = stream.next().await.unwrap().unwrap();
1631 assert_eq!(second.delta(), "lo");
1632 assert_eq!(stream.accumulated_text(), "Hello");
1633
1634 let done = stream.next().await.unwrap().unwrap();
1635 assert!(done.done);
1636 assert_eq!(stream.accumulated_text(), "Hello");
1637 }
1638
1639 #[tokio::test]
1640 async fn stateful_chat_stream_next_with_accumulated_returns_snapshot() {
1641 let payload = concat!(
1642 "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"Hel\"}]}}\n\n",
1643 "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"lo\"}]}}\n\n",
1644 "data: [DONE]\n\n"
1645 );
1646
1647 let chunks: Vec<std::result::Result<Bytes, reqwest::Error>> =
1648 vec![Ok(Bytes::from(payload.to_string()))];
1649 let raw_stream = stream::iter(chunks);
1650 let response_stream = ResponseStream::new(raw_stream);
1651 let mut stream = StatefulChatStream::new(response_stream);
1652
1653 let first = stream.next_with_accumulated().await.unwrap().unwrap();
1654 assert_eq!(first.chunk.delta(), "Hel");
1655 assert_eq!(first.accumulated_text, "Hel");
1656
1657 let second = stream.next_with_accumulated().await.unwrap().unwrap();
1658 assert_eq!(second.chunk.delta(), "lo");
1659 assert_eq!(second.accumulated_text, "Hello");
1660
1661 let done = stream.next_with_accumulated().await.unwrap().unwrap();
1662 assert!(done.chunk.done);
1663 assert_eq!(done.accumulated_text, "Hello");
1664 }
1665
    /// The poller keeps GETting the response until `output` is non-empty, then
    /// returns it. First poll: still pending; second poll: ready.
    #[tokio::test]
    async fn deferred_poller_returns_when_response_has_output() {
        let server = MockServer::start().await;
        // Counts GETs so the responder can switch behavior and the test can
        // assert exactly two polls happened.
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_for_responder = Arc::clone(&call_count);

        Mock::given(method("GET"))
            .and(path("/responses/resp_deferred"))
            .respond_with(move |_req: &wiremock::Request| {
                let current = call_count_for_responder.fetch_add(1, Ordering::SeqCst);
                if current == 0 {
                    // Pending: empty output signals "not ready yet".
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_deferred",
                        "model": "grok-4",
                        "output": [],
                        "usage": {"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}
                    }))
                } else {
                    // Ready: output present, so `wait()` should stop polling.
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_deferred",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": "ready"}]
                        }],
                        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
                    }))
                }
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let response = client
            .responses()
            .deferred("resp_deferred")
            .max_attempts(3)
            .wait()
            .await
            .unwrap();

        assert_eq!(response.output_text().as_deref(), Some("ready"));
        // Resolved on the second poll — within the 3-attempt budget.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
1716
1717 #[tokio::test]
1718 async fn deferred_poller_times_out_when_output_never_arrives() {
1719 let server = MockServer::start().await;
1720
1721 Mock::given(method("GET"))
1722 .and(path("/responses/resp_timeout"))
1723 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1724 "id": "resp_timeout",
1725 "model": "grok-4",
1726 "output": [],
1727 "usage": {"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}
1728 })))
1729 .mount(&server)
1730 .await;
1731
1732 let client = XaiClient::builder()
1733 .api_key("test-key")
1734 .base_url(server.uri())
1735 .build()
1736 .unwrap();
1737
1738 let err = client
1739 .responses()
1740 .deferred("resp_timeout")
1741 .max_attempts(2)
1742 .wait()
1743 .await
1744 .unwrap_err();
1745
1746 assert!(matches!(err, Error::Timeout));
1747 }
1748
1749 #[test]
1750 fn deferred_poller_poll_delay_returns_zero_when_interval_is_zero() {
1751 assert_eq!(
1752 DeferredResponsePoller::poll_delay_for(Duration::ZERO, Duration::from_millis(200), 0),
1753 Duration::ZERO
1754 );
1755 assert_eq!(
1756 DeferredResponsePoller::poll_delay_for(Duration::from_millis(200), Duration::ZERO, 0),
1757 Duration::ZERO
1758 );
1759 }
1760
    /// Passing `0` rounds is clamped to at least one sampling round: the chat
    /// still performs exactly one POST, and since that round returns plain
    /// text (no tool call), the handler is never invoked.
    #[tokio::test]
    async fn stateful_chat_sample_with_tool_handler_minimum_round_is_one() {
        let server = MockServer::start().await;
        // Counts POSTs; the responder panics if a second request arrives.
        let call_count = Arc::new(AtomicUsize::new(0));
        let count_for_responses = Arc::clone(&call_count);
        // Counts handler invocations; expected to stay at zero.
        let handler_invocations = Arc::new(AtomicUsize::new(0));
        let handler_invocations_for_handler = Arc::clone(&handler_invocations);

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(move |_req: &wiremock::Request| {
                let current = count_for_responses.fetch_add(1, Ordering::SeqCst);
                if current == 0 {
                    ResponseTemplate::new(200).set_body_json(json!( {
                        "id": "resp_tool_1",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{ "type": "text", "text": "Recovered" }]
                        }],
                        "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}
                    }))
                } else {
                    // A second round would mean `0` was not clamped to one.
                    panic!("unexpected second request")
                }
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut chat = client.responses().chat("grok-4");
        chat.append_system("You are helpful.").append_user("Hello");

        let response = chat
            .sample_with_tool_handler(0, move |_call| {
                let handler_invocations_for_handler = Arc::clone(&handler_invocations_for_handler);
                async move {
                    handler_invocations_for_handler.fetch_add(1, Ordering::SeqCst);
                    Ok(r#"{"value":42}"#.to_string())
                }
            })
            .await
            .unwrap();

        assert_eq!(response.output_text().as_deref(), Some("Recovered"));
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        // The single round returned text, so the handler never ran.
        assert_eq!(handler_invocations.load(Ordering::SeqCst), 0);
    }
1815
1816 #[tokio::test]
1817 async fn deferred_poller_wait_uses_poll_delay_between_attempts() {
1818 let server = MockServer::start().await;
1819 let request_count = Arc::new(AtomicUsize::new(0));
1820 let count_for_handler = Arc::clone(&request_count);
1821
1822 Mock::given(method("GET"))
1823 .and(path("/responses/resp_poll_delay"))
1824 .respond_with(move |_req: &wiremock::Request| {
1825 let current = count_for_handler.fetch_add(1, Ordering::SeqCst);
1826 if current == 0 {
1827 ResponseTemplate::new(200).set_body_json(json!({
1828 "id": "resp_poll_delay",
1829 "model": "grok-4",
1830 "output": [],
1831 "usage": {"prompt_tokens": 4, "completion_tokens": 0, "total_tokens": 4}
1832 }))
1833 } else {
1834 ResponseTemplate::new(200).set_body_json(json!({
1835 "id": "resp_poll_delay",
1836 "model": "grok-4",
1837 "output": [{
1838 "type": "message",
1839 "role": "assistant",
1840 "content": [{"type": "text", "text": "ready"}]
1841 }],
1842 "usage": {"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5}
1843 }))
1844 }
1845 })
1846 .mount(&server)
1847 .await;
1848
1849 let client = XaiClient::builder()
1850 .api_key("test-key")
1851 .base_url(server.uri())
1852 .build()
1853 .unwrap();
1854
1855 let started_at = std::time::Instant::now();
1856 let response = client
1857 .responses()
1858 .deferred("resp_poll_delay")
1859 .poll_interval(Duration::from_millis(80))
1860 .max_attempts(2)
1861 .wait()
1862 .await
1863 .unwrap();
1864 let elapsed = started_at.elapsed();
1865
1866 assert_eq!(response.output_text().as_deref(), Some("ready"));
1867 assert_eq!(request_count.load(Ordering::SeqCst), 2);
1868 assert!(
1869 elapsed >= Duration::from_millis(60),
1870 "expected exponential backoff delay to be observed"
1871 );
1872 assert!(
1873 elapsed < Duration::from_millis(500),
1874 "expected delay test to stay bounded"
1875 );
1876 }
1877
    /// Unit checks for the two pure helpers on `StatefulChat`:
    /// `merge_output_message_content` (concatenating message parts) and
    /// `collect_response_semantics` (splitting a `Response` into assistant
    /// text and pending tool calls).
    #[test]
    fn stateful_chat_merge_output_message_content_and_collect_response_semantics() {
        // Text and refusal parts are both merged into the joined string.
        let text_output = vec![
            TextContent::Text {
                text: "Hello ".to_string(),
            },
            TextContent::Refusal {
                refusal: "Blocked".to_string(),
            },
        ];
        assert_eq!(
            StatefulChat::merge_output_message_content(&text_output),
            Some("Hello Blocked".to_string())
        );
        // An empty content list merges to None, not to an empty string.
        assert_eq!(StatefulChat::merge_output_message_content(&[]), None);

        // This id appears both as an output FunctionCall and in the
        // response-level `tool_calls` field.
        let shared_tool_call_id = "shared-call".to_string();
        let response = Response {
            id: "resp_semantics".to_string(),
            model: "grok-4".to_string(),
            output: vec![
                OutputItem::Message {
                    role: Role::Assistant,
                    content: vec![
                        TextContent::Text {
                            text: "Part 1 ".to_string(),
                        },
                        TextContent::Text {
                            text: "Part 2".to_string(),
                        },
                    ],
                },
                OutputItem::FunctionCall {
                    call: ToolCall {
                        id: shared_tool_call_id.clone(),
                        call_type: Some("function".to_string()),
                        function: None,
                    },
                },
                OutputItem::FunctionCall {
                    call: ToolCall {
                        id: "function_call".to_string(),
                        call_type: Some("function".to_string()),
                        function: None,
                    },
                },
                OutputItem::CodeInterpreterCall {
                    id: "ci_call".to_string(),
                    code: None,
                    outputs: None,
                },
                OutputItem::WebSearchCall {
                    id: "web_call".to_string(),
                    results: None,
                },
                OutputItem::XSearchCall {
                    id: "x_call".to_string(),
                    results: None,
                },
            ],
            usage: Default::default(),
            citations: None,
            inline_citations: None,
            server_side_tool_usage: None,
            tool_calls: Some(vec![ToolCall {
                id: shared_tool_call_id.clone(),
                call_type: Some("function".to_string()),
                function: None,
            }]),
            system_fingerprint: None,
        };

        let (assistant_messages, pending_tool_calls) =
            StatefulChat::collect_response_semantics(&response);
        // The message's two text parts merge into a single assistant entry.
        assert_eq!(assistant_messages, vec!["Part 1 Part 2".to_string()]);
        // Five distinct calls: the duplicate id from `tool_calls` is not
        // double-counted, and server-side call items get synthesized
        // `call_type` values.
        assert_eq!(pending_tool_calls.len(), 5);
        assert_eq!(pending_tool_calls[0].id, shared_tool_call_id);
        assert_eq!(pending_tool_calls[1].id, "function_call");
        assert_eq!(pending_tool_calls[1].call_type.as_deref(), Some("function"));
        assert_eq!(pending_tool_calls[2].id, "ci_call");
        assert_eq!(
            pending_tool_calls[2].call_type.as_deref(),
            Some("code_interpreter")
        );
        assert_eq!(pending_tool_calls[3].id, "web_call");
        assert_eq!(
            pending_tool_calls[3].call_type.as_deref(),
            Some("web_search")
        );
        assert_eq!(pending_tool_calls[4].id, "x_call");
        assert_eq!(pending_tool_calls[4].call_type.as_deref(), Some("x_search"));
    }
1970
1971 #[test]
1972 fn deferred_poller_poll_interval_sets_fixed_delay() {
1973 let poller = DeferredResponsePoller::new(
1974 XaiClient::new("test-key").unwrap(),
1975 "resp_123".to_string(),
1976 )
1977 .poll_interval(Duration::from_millis(250));
1978
1979 assert_eq!(poller.poll_initial_delay, Duration::from_millis(250));
1980 assert_eq!(poller.poll_max_delay, Duration::from_millis(250));
1981 assert_eq!(
1982 DeferredResponsePoller::poll_delay_for(
1983 poller.poll_initial_delay,
1984 poller.poll_max_delay,
1985 0
1986 ),
1987 Duration::from_millis(250)
1988 );
1989 assert_eq!(
1990 DeferredResponsePoller::poll_delay_for(
1991 poller.poll_initial_delay,
1992 poller.poll_max_delay,
1993 5
1994 ),
1995 Duration::from_millis(250)
1996 );
1997 }
1998
    /// `poll_backoff(initial, max)`: the delay doubles each attempt and is
    /// clamped at `max`.
    #[test]
    fn deferred_poller_poll_backoff_is_exponential_and_capped() {
        let poller = DeferredResponsePoller::new(
            XaiClient::new("test-key").unwrap(),
            "resp_123".to_string(),
        )
        .poll_backoff(Duration::from_millis(100), Duration::from_millis(300));

        // Attempt 0: the initial delay.
        assert_eq!(
            DeferredResponsePoller::poll_delay_for(
                poller.poll_initial_delay,
                poller.poll_max_delay,
                0
            ),
            Duration::from_millis(100)
        );
        // Attempt 1: doubled.
        assert_eq!(
            DeferredResponsePoller::poll_delay_for(
                poller.poll_initial_delay,
                poller.poll_max_delay,
                1
            ),
            Duration::from_millis(200)
        );
        // Attempt 2: the exponential value (400ms) is capped at 300ms.
        assert_eq!(
            DeferredResponsePoller::poll_delay_for(
                poller.poll_initial_delay,
                poller.poll_max_delay,
                2
            ),
            Duration::from_millis(300)
        );
        // Attempt 3: remains pinned at the cap.
        assert_eq!(
            DeferredResponsePoller::poll_delay_for(
                poller.poll_initial_delay,
                poller.poll_max_delay,
                3
            ),
            Duration::from_millis(300)
        );
2039 }
2040}