// xai_rust/api/responses.rs
1//! Responses API - the primary endpoint for chat interactions.
2
3use futures_util::Stream;
4use pin_project_lite::pin_project;
5use serde::Serialize;
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10use tokio::time::sleep;
11
12use crate::client::XaiClient;
13use crate::config::RetryPolicy;
14use crate::models::content::ContentPart;
15use crate::models::message::{Message, MessageContent, Role};
16use crate::models::response::{OutputItem, Response, ResponseFormat, StreamChunk, TextContent};
17use crate::models::tool::{Tool, ToolCall, ToolChoice};
18use crate::stream::ResponseStream;
19use crate::{Error, Result};
20
/// Default number of GET attempts before a deferred-response poll times out.
const DEFAULT_DEFERRED_MAX_ATTEMPTS: u32 = 30;
/// Default delay between deferred polls; zero means re-poll immediately.
const DEFAULT_DEFERRED_POLL_INTERVAL: Duration = Duration::ZERO;
/// Default cap on sample/tool-handling rounds in the stateful chat tool loop.
const DEFAULT_STATEFUL_TOOL_LOOP_MAX_ROUNDS: u32 = 8;
24
/// API for creating model responses.
///
/// Obtained via `XaiClient::responses()`; holds its own clone of the client.
#[derive(Debug, Clone)]
pub struct ResponsesApi {
    client: XaiClient,
}
30
31impl ResponsesApi {
32    pub(crate) fn new(client: XaiClient) -> Self {
33        Self { client }
34    }
35
36    /// Create a new response request builder.
37    ///
38    /// # Example
39    ///
40    /// ```rust,no_run
41    /// use xai_rust::{XaiClient, Role};
42    ///
43    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
44    /// let client = XaiClient::from_env()?;
45    ///
46    /// let response = client.responses()
47    ///     .create("grok-4")
48    ///     .message(Role::System, "You are a helpful assistant.")
49    ///     .message(Role::User, "Hello!")
50    ///     .temperature(0.7)
51    ///     .send()
52    ///     .await?;
53    /// # Ok(())
54    /// # }
55    /// ```
56    pub fn create(&self, model: impl Into<String>) -> CreateResponseBuilder {
57        CreateResponseBuilder::new(self.client.clone(), model.into())
58    }
59
60    /// Create a deferred response poller.
61    pub fn deferred(&self, response_id: impl Into<String>) -> DeferredResponsePoller {
62        DeferredResponsePoller::new(self.client.clone(), response_id.into())
63    }
64
65    /// Create a stateful chat handle for multi-turn interactions.
66    ///
67    /// This helper keeps a local message list and lets you append messages over time,
68    /// then sample or stream with the accumulated context.
69    pub fn chat(&self, model: impl Into<String>) -> StatefulChat {
70        StatefulChat::new(self.client.clone(), model.into())
71    }
72
73    /// Get a previous response by ID.
74    pub async fn get(&self, response_id: &str) -> Result<Response> {
75        let id = XaiClient::encode_path(response_id);
76        let url = format!("{}/responses/{}", self.client.base_url(), id);
77
78        let response = self.client.send(self.client.http().get(&url)).await?;
79
80        if !response.status().is_success() {
81            return Err(Error::from_response(response).await);
82        }
83
84        Ok(response.json().await?)
85    }
86
87    /// Delete a response by ID.
88    pub async fn delete(&self, response_id: &str) -> Result<()> {
89        let id = XaiClient::encode_path(response_id);
90        let url = format!("{}/responses/{}", self.client.base_url(), id);
91
92        let response = self.client.send(self.client.http().delete(&url)).await?;
93
94        if !response.status().is_success() {
95            return Err(Error::from_response(response).await);
96        }
97
98        Ok(())
99    }
100
101    /// Poll a response until output is available or attempts are exhausted.
102    pub async fn poll_until_ready(&self, response_id: &str, max_attempts: u32) -> Result<Response> {
103        self.deferred(response_id.to_string())
104            .max_attempts(max_attempts)
105            .wait()
106            .await
107    }
108}
109
/// Helper for polling deferred response completion.
#[derive(Debug, Clone)]
pub struct DeferredResponsePoller {
    client: XaiClient,
    // ID of the deferred response fetched on each attempt.
    response_id: String,
    // Upper bound on GET attempts before `wait` returns `Error::Timeout`.
    max_attempts: u32,
    // First inter-attempt delay; doubled per attempt up to `poll_max_delay`.
    poll_initial_delay: Duration,
    poll_max_delay: Duration,
}
119
120impl DeferredResponsePoller {
121    fn new(client: XaiClient, response_id: String) -> Self {
122        Self {
123            client,
124            response_id,
125            max_attempts: DEFAULT_DEFERRED_MAX_ATTEMPTS,
126            poll_initial_delay: DEFAULT_DEFERRED_POLL_INTERVAL,
127            poll_max_delay: DEFAULT_DEFERRED_POLL_INTERVAL,
128        }
129    }
130
131    /// Set the maximum polling attempts.
132    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
133        self.max_attempts = max_attempts.max(1);
134        self
135    }
136
137    /// Set a fixed polling interval between attempts.
138    pub fn poll_interval(mut self, interval: Duration) -> Self {
139        self.poll_initial_delay = interval;
140        self.poll_max_delay = interval;
141        self
142    }
143
144    /// Set exponential polling backoff bounds between attempts.
145    pub fn poll_backoff(mut self, initial: Duration, max: Duration) -> Self {
146        self.poll_initial_delay = initial;
147        self.poll_max_delay = max.max(initial);
148        self
149    }
150
151    fn poll_delay_for(initial: Duration, max: Duration, attempt: u32) -> Duration {
152        let max_millis = max.as_millis();
153        let initial_millis = initial.as_millis();
154        if max_millis == 0 || initial_millis == 0 {
155            return Duration::ZERO;
156        }
157
158        let factor_shift = attempt.min(16);
159        let factor = 1u128 << factor_shift;
160        let delayed = initial_millis.saturating_mul(factor).min(max_millis);
161        Duration::from_millis(delayed as u64)
162    }
163
164    /// Wait for response output to become available.
165    pub async fn wait(self) -> Result<Response> {
166        let DeferredResponsePoller {
167            client,
168            response_id,
169            max_attempts,
170            poll_initial_delay,
171            poll_max_delay,
172        } = self;
173
174        let api = ResponsesApi::new(client);
175
176        for attempt in 0..max_attempts {
177            let response = api.get(&response_id).await?;
178            if !response.output.is_empty() {
179                return Ok(response);
180            }
181
182            if attempt + 1 < max_attempts {
183                let delay = Self::poll_delay_for(poll_initial_delay, poll_max_delay, attempt);
184                if !delay.is_zero() {
185                    sleep(delay).await;
186                }
187            }
188        }
189
190        Err(Error::Timeout)
191    }
192}
193
/// Stateful chat helper on top of the Responses API.
#[derive(Debug, Clone)]
pub struct StatefulChat {
    client: XaiClient,
    // Model name used for every sample/stream request.
    model: String,
    // Locally accumulated conversation history sent with each request.
    messages: Vec<Message>,
    // Tool calls extracted from the most recently sampled response.
    pending_tool_calls: Vec<ToolCall>,
}
202
impl StatefulChat {
    fn new(client: XaiClient, model: String) -> Self {
        Self {
            client,
            model,
            messages: Vec::new(),
            pending_tool_calls: Vec::new(),
        }
    }

    /// Append a message to the local chat state.
    pub fn append(&mut self, role: Role, content: impl Into<MessageContent>) -> &mut Self {
        self.messages.push(Message::new(role, content));
        self
    }

    /// Append a system message.
    pub fn append_system(&mut self, content: impl Into<String>) -> &mut Self {
        self.append(Role::System, content.into())
    }

    /// Append a user message.
    pub fn append_user(&mut self, content: impl Into<MessageContent>) -> &mut Self {
        self.append(Role::User, content)
    }

    /// Append an assistant message.
    pub fn append_assistant(&mut self, content: impl Into<String>) -> &mut Self {
        self.append(Role::Assistant, content.into())
    }

    /// Append a pre-built message.
    pub fn append_message(&mut self, message: Message) -> &mut Self {
        self.messages.push(message);
        self
    }

    /// Append a tool result message.
    pub fn append_tool_result(
        &mut self,
        tool_call_id: impl Into<String>,
        content: impl Into<String>,
    ) -> &mut Self {
        self.append_message(Message::tool(tool_call_id, content))
    }

    /// Get the current local message history.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Get pending tool calls extracted from the latest sampled response.
    pub fn pending_tool_calls(&self) -> &[ToolCall] {
        &self.pending_tool_calls
    }

    /// Take and clear pending tool calls.
    pub fn take_pending_tool_calls(&mut self) -> Vec<ToolCall> {
        std::mem::take(&mut self.pending_tool_calls)
    }

    /// Clear the local message history.
    ///
    /// Also drops any pending tool calls, since they refer to the cleared turn.
    pub fn clear(&mut self) -> &mut Self {
        self.messages.clear();
        self.pending_tool_calls.clear();
        self
    }

    // Borrow the displayable text of one output content part.
    fn text_content_slice(content: &TextContent) -> &str {
        match content {
            TextContent::Text { text } => text,
            // Preserve refusal semantics in history by carrying refusal text.
            TextContent::Refusal { refusal } => refusal,
        }
    }

    // Merge the text parts of one output message into a single string.
    // Returns `None` for empty content so empty assistant turns are not
    // recorded. The multi-part branch pre-computes the total length so the
    // merged string is allocated exactly once.
    fn merge_output_message_content(content: &[TextContent]) -> Option<String> {
        match content {
            [] => None,
            [single] => {
                let text = Self::text_content_slice(single);
                (!text.is_empty()).then(|| text.to_string())
            }
            _ => {
                let total_len: usize = content
                    .iter()
                    .map(|part| Self::text_content_slice(part).len())
                    .sum();
                if total_len == 0 {
                    return None;
                }

                let mut merged = String::with_capacity(total_len);
                for part in content {
                    merged.push_str(Self::text_content_slice(part));
                }
                Some(merged)
            }
        }
    }

    // Extract assistant text and tool calls from a response. Tool calls are
    // deduplicated by ID; top-level `tool_calls` entries win over calls
    // embedded in output items because they are inserted first.
    fn collect_response_semantics(response: &Response) -> (Vec<String>, Vec<ToolCall>) {
        let output_item_count = response.output.len();
        let top_level_call_count = response.tool_calls.as_ref().map_or(0, Vec::len);
        let mut assistant_messages = Vec::with_capacity(output_item_count);
        let mut pending_tool_calls = Vec::with_capacity(top_level_call_count + output_item_count);
        let mut seen_tool_call_ids: HashSet<String> =
            HashSet::with_capacity(top_level_call_count + output_item_count);

        // Prefer explicit top-level tool call payloads when present.
        if let Some(calls) = &response.tool_calls {
            for call in calls {
                if seen_tool_call_ids.insert(call.id.clone()) {
                    pending_tool_calls.push(call.clone());
                }
            }
        }

        for item in &response.output {
            match item {
                OutputItem::Message { content, .. } => {
                    if let Some(merged) = Self::merge_output_message_content(content) {
                        assistant_messages.push(merged);
                    }
                }
                OutputItem::FunctionCall { call } => {
                    if seen_tool_call_ids.insert(call.id.clone()) {
                        pending_tool_calls.push(call.clone());
                    }
                }
                // Built-in tool invocations carry only an ID here; synthesize a
                // `ToolCall` with the matching type tag and no function payload.
                OutputItem::CodeInterpreterCall { id, .. } => {
                    if seen_tool_call_ids.insert(id.clone()) {
                        pending_tool_calls.push(ToolCall {
                            id: id.clone(),
                            call_type: Some("code_interpreter".to_string()),
                            function: None,
                        });
                    }
                }
                OutputItem::WebSearchCall { id, .. } => {
                    if seen_tool_call_ids.insert(id.clone()) {
                        pending_tool_calls.push(ToolCall {
                            id: id.clone(),
                            call_type: Some("web_search".to_string()),
                            function: None,
                        });
                    }
                }
                OutputItem::XSearchCall { id, .. } => {
                    if seen_tool_call_ids.insert(id.clone()) {
                        pending_tool_calls.push(ToolCall {
                            id: id.clone(),
                            call_type: Some("x_search".to_string()),
                            function: None,
                        });
                    }
                }
            }
        }

        (assistant_messages, pending_tool_calls)
    }

    /// Sample a response using the current local message history.
    pub async fn sample(&self) -> Result<Response> {
        self.client
            .responses()
            .create(self.model.clone())
            .messages(self.messages.clone())
            .send()
            .await
    }

    /// Append assistant text from a response into local chat history.
    pub fn append_response_text(&mut self, response: &Response) -> &mut Self {
        let text = response.all_text();
        if !text.is_empty() {
            self.append_assistant(text);
        }
        self
    }

    /// Append response semantics to local history and pending tool calls.
    ///
    /// Pending tool calls are replaced (not extended) so they always reflect
    /// only the latest response.
    pub fn append_response_semantics(&mut self, response: &Response) -> &mut Self {
        let (assistant_messages, pending_tool_calls) = Self::collect_response_semantics(response);
        for text in assistant_messages {
            self.append_assistant(text);
        }
        self.pending_tool_calls = pending_tool_calls;
        self
    }

    /// Sample a response and append semantic carryover to local state.
    pub async fn sample_and_append(&mut self) -> Result<Response> {
        let response = self.sample().await?;
        self.append_response_semantics(&response);
        Ok(response)
    }

    /// Run a sampled chat turn with a tool handler loop (default max rounds).
    ///
    /// The loop runs `sample_and_append()` repeatedly. When pending tool calls are
    /// produced, the handler is invoked for each call and tool results are appended
    /// as `tool` messages before the next sampling round.
    pub async fn sample_with_tool_loop<H, Fut>(&mut self, handler: H) -> Result<Response>
    where
        H: FnMut(ToolCall) -> Fut,
        Fut: std::future::Future<Output = Result<String>>,
    {
        self.sample_with_tool_handler(DEFAULT_STATEFUL_TOOL_LOOP_MAX_ROUNDS, handler)
            .await
    }

    /// Run a sampled chat turn with a tool handler loop and explicit max rounds.
    ///
    /// `max_rounds` is clamped to at least one round. Returns the first
    /// response that produced no tool calls, or an error once the round
    /// budget is exhausted.
    pub async fn sample_with_tool_handler<H, Fut>(
        &mut self,
        max_rounds: u32,
        mut handler: H,
    ) -> Result<Response>
    where
        H: FnMut(ToolCall) -> Fut,
        Fut: std::future::Future<Output = Result<String>>,
    {
        let rounds = max_rounds.max(1);

        for _ in 0..rounds {
            let response = self.sample_and_append().await?;
            let pending = self.take_pending_tool_calls();
            // No pending calls means the model produced a final answer.
            if pending.is_empty() {
                return Ok(response);
            }

            // Handler errors abort the loop; calls are processed sequentially.
            for call in pending {
                let tool_call_id = call.id.clone();
                let tool_result = handler(call).await?;
                self.append_tool_result(tool_call_id, tool_result);
            }
        }

        Err(Error::Config(format!(
            "stateful chat tool loop exceeded max rounds ({rounds})"
        )))
    }

    /// Stream a response using the current local message history.
    pub async fn stream(&self) -> Result<StatefulChatStream> {
        let stream = self
            .client
            .responses()
            .create(self.model.clone())
            .messages(self.messages.clone())
            .stream()
            .await?;

        Ok(StatefulChatStream::new(stream))
    }
}
460
pin_project! {
    /// Stream wrapper that accumulates text deltas as chunks arrive.
    pub struct StatefulChatStream {
        #[pin]
        inner: ResponseStream,
        // Concatenation of every chunk delta observed so far.
        accumulated_text: String,
    }
}
469
/// A streamed chunk paired with accumulated text so far.
///
/// Produced by [`StatefulChatStream::next_with_accumulated`].
#[derive(Debug, Clone)]
pub struct AccumulatedChunk {
    /// The raw stream chunk from the API.
    pub chunk: StreamChunk,
    /// The full accumulated text after applying this chunk.
    pub accumulated_text: String,
}
478
479impl StatefulChatStream {
480    fn new(inner: ResponseStream) -> Self {
481        Self {
482            inner,
483            accumulated_text: String::new(),
484        }
485    }
486
487    /// Get accumulated text so far.
488    pub fn accumulated_text(&self) -> &str {
489        &self.accumulated_text
490    }
491
492    /// Consume the stream wrapper and return accumulated text.
493    pub fn into_accumulated_text(self) -> String {
494        self.accumulated_text
495    }
496
497    /// Read the next chunk along with an accumulated text snapshot.
498    pub async fn next_with_accumulated(&mut self) -> Option<Result<AccumulatedChunk>> {
499        use futures_util::future::poll_fn;
500
501        poll_fn(|cx| Pin::new(&mut *self).poll_next(cx))
502            .await
503            .map(|result| {
504                result.map(|chunk| AccumulatedChunk {
505                    chunk,
506                    accumulated_text: self.accumulated_text.clone(),
507                })
508            })
509    }
510}
511
512impl Stream for StatefulChatStream {
513    type Item = Result<StreamChunk>;
514
515    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
516        let mut this = self.project();
517
518        match this.inner.as_mut().poll_next(cx) {
519            Poll::Ready(Some(Ok(chunk))) => {
520                this.accumulated_text.push_str(chunk.delta());
521                Poll::Ready(Some(Ok(chunk)))
522            }
523            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
524            Poll::Ready(None) => Poll::Ready(None),
525            Poll::Pending => Poll::Pending,
526        }
527    }
528}
529
/// Builder for creating a response request.
#[derive(Debug)]
pub struct CreateResponseBuilder {
    client: XaiClient,
    // Accumulated request payload serialized as the POST body.
    request: CreateResponseRequest,
    // Per-request retry policy; the client's policy applies when `None`.
    retry_policy_override: Option<RetryPolicy>,
}
537
// JSON body for `POST /responses`; unset optional fields are omitted
// from serialization entirely.
#[derive(Debug, Clone, Serialize)]
struct CreateResponseRequest {
    model: String,
    input: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    include: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
}
561
562impl CreateResponseBuilder {
563    fn new(client: XaiClient, model: String) -> Self {
564        Self {
565            client,
566            request: CreateResponseRequest {
567                model,
568                input: Vec::new(),
569                tools: None,
570                tool_choice: None,
571                temperature: None,
572                top_p: None,
573                max_tokens: None,
574                stream: None,
575                response_format: None,
576                include: None,
577                store: None,
578            },
579            retry_policy_override: None,
580        }
581    }
582
583    fn retry_policy_mut(&mut self) -> &mut RetryPolicy {
584        self.retry_policy_override
585            .get_or_insert_with(|| self.client.retry_policy())
586    }
587
588    /// Add a message to the conversation.
589    pub fn message(mut self, role: Role, content: impl Into<MessageContent>) -> Self {
590        self.request.input.push(Message::new(role, content));
591        self
592    }
593
594    /// Add a system message.
595    pub fn system(self, content: impl Into<String>) -> Self {
596        self.message(Role::System, content.into())
597    }
598
599    /// Add a user message.
600    pub fn user(self, content: impl Into<MessageContent>) -> Self {
601        self.message(Role::User, content)
602    }
603
604    /// Add an assistant message.
605    pub fn assistant(self, content: impl Into<String>) -> Self {
606        self.message(Role::Assistant, content.into())
607    }
608
609    /// Add a user message with an image.
610    pub fn user_with_image(
611        mut self,
612        text: impl Into<String>,
613        image_url: impl Into<String>,
614    ) -> Self {
615        let parts = vec![ContentPart::text(text), ContentPart::image_url(image_url)];
616        self.request
617            .input
618            .push(Message::new(Role::User, MessageContent::Parts(parts)));
619        self
620    }
621
622    /// Add pre-built messages.
623    pub fn messages(mut self, messages: Vec<Message>) -> Self {
624        self.request.input.extend(messages);
625        self
626    }
627
628    /// Add a tool.
629    pub fn tool(mut self, tool: Tool) -> Self {
630        self.request.tools.get_or_insert_with(Vec::new).push(tool);
631        self
632    }
633
634    /// Add multiple tools.
635    pub fn tools(mut self, tools: Vec<Tool>) -> Self {
636        self.request
637            .tools
638            .get_or_insert_with(Vec::new)
639            .extend(tools);
640        self
641    }
642
643    /// Set the tool choice.
644    pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
645        self.request.tool_choice = Some(choice);
646        self
647    }
648
649    /// Set the temperature (0.0 - 2.0).
650    pub fn temperature(mut self, temperature: f32) -> Self {
651        self.request.temperature = Some(temperature.clamp(0.0, 2.0));
652        self
653    }
654
655    /// Set top_p (0.0 - 1.0).
656    pub fn top_p(mut self, top_p: f32) -> Self {
657        self.request.top_p = Some(top_p.clamp(0.0, 1.0));
658        self
659    }
660
661    /// Set the maximum number of tokens to generate.
662    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
663        self.request.max_tokens = Some(max_tokens);
664        self
665    }
666
667    /// Override retry attempts for this request only.
668    pub fn max_retries(mut self, max_retries: u32) -> Self {
669        self.retry_policy_mut().max_retries = max_retries;
670        self
671    }
672
673    /// Disable retries for this request only.
674    pub fn disable_retries(self) -> Self {
675        self.max_retries(0)
676    }
677
678    /// Override retry backoff for this request only.
679    pub fn retry_backoff(mut self, initial: Duration, max: Duration) -> Self {
680        let policy = self.retry_policy_mut();
681        policy.initial_backoff = initial;
682        policy.max_backoff = max.max(initial);
683        self
684    }
685
686    /// Override retry jitter factor (0.0 to 1.0) for this request only.
687    pub fn retry_jitter(mut self, factor: f64) -> Self {
688        self.retry_policy_mut().jitter_factor = factor.clamp(0.0, 1.0);
689        self
690    }
691
692    /// Set the response format.
693    pub fn response_format(mut self, format: ResponseFormat) -> Self {
694        self.request.response_format = Some(format);
695        self
696    }
697
698    /// Request JSON output.
699    pub fn json_output(self) -> Self {
700        self.response_format(ResponseFormat::json_object())
701    }
702
703    /// Include additional fields in the response.
704    pub fn include(mut self, fields: Vec<String>) -> Self {
705        self.request.include = Some(fields);
706        self
707    }
708
709    /// Include inline citations.
710    pub fn with_inline_citations(mut self) -> Self {
711        self.request
712            .include
713            .get_or_insert_with(Vec::new)
714            .push("inline_citations".to_string());
715        self
716    }
717
718    /// Include verbose streaming output.
719    pub fn with_verbose_streaming(mut self) -> Self {
720        self.request
721            .include
722            .get_or_insert_with(Vec::new)
723            .push("verbose_streaming".to_string());
724        self
725    }
726
727    /// Store the response for later retrieval.
728    pub fn store(mut self, store: bool) -> Self {
729        self.request.store = Some(store);
730        self
731    }
732
733    /// Send the request and get a response.
734    pub async fn send(self) -> Result<Response> {
735        let url = format!("{}/responses", self.client.base_url());
736
737        let response = self
738            .client
739            .send_with_retry_policy(
740                self.client.http().post(&url).json(&self.request),
741                self.retry_policy_override,
742            )
743            .await?;
744
745        if !response.status().is_success() {
746            return Err(Error::from_response(response).await);
747        }
748
749        Ok(response.json().await?)
750    }
751
752    /// Send the request and stream the response.
753    pub async fn stream(mut self) -> Result<ResponseStream> {
754        self.request.stream = Some(true);
755
756        let url = format!("{}/responses", self.client.base_url());
757
758        let response = self
759            .client
760            .send_with_retry_policy(
761                self.client.http().post(&url).json(&self.request),
762                self.retry_policy_override,
763            )
764            .await?;
765
766        if !response.status().is_success() {
767            return Err(Error::from_response(response).await);
768        }
769
770        Ok(ResponseStream::new(response.bytes_stream()))
771    }
772}
773
774#[cfg(test)]
775mod tests {
776    use super::*;
777    use bytes::Bytes;
778    use futures_util::{stream, StreamExt};
779    use serde_json::json;
780    use std::sync::{
781        atomic::{AtomicUsize, Ordering},
782        Arc,
783    };
784    use wiremock::{
785        matchers::{body_partial_json, method, path},
786        Mock, MockServer, ResponseTemplate,
787    };
788
    #[test]
    fn user_with_image_puts_text_before_image() {
        // The multimodal helper must emit exactly one user message whose
        // parts are ordered text-first, image-second.
        let client = XaiClient::new("test-key").unwrap();
        let api = ResponsesApi::new(client);
        let builder = api
            .create("grok-4")
            .user_with_image("describe this", "https://example.com/image.jpg");

        assert_eq!(builder.request.input.len(), 1);
        let msg = &builder.request.input[0];
        assert!(matches!(msg.role, Role::User));

        match &msg.content {
            MessageContent::Parts(parts) => {
                assert_eq!(parts.len(), 2);
                assert!(matches!(
                    &parts[0],
                    ContentPart::Text { text } if text == "describe this"
                ));
                assert!(matches!(
                    &parts[1],
                    ContentPart::ImageUrl { image_url } if image_url.url == "https://example.com/image.jpg"
                ));
            }
            _ => panic!("Expected multipart message content"),
        }
    }
816
    #[tokio::test]
    async fn create_builder_retry_override_enables_retries_when_client_retries_disabled() {
        // The mock fails the first POST with 503 and succeeds on the second;
        // the per-request `max_retries(1)` override must enable the retry even
        // though retries are disabled at the client level.
        let server = MockServer::start().await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let responder_count = Arc::clone(&call_count);

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(move |_req: &wiremock::Request| {
                let count = responder_count.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    ResponseTemplate::new(503).set_body_json(json!({
                        "error": {"message": "temporary", "type": "server_error"}
                    }))
                } else {
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_retry",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": "retry worked"}]
                        }],
                        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
                    }))
                }
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .disable_retries()
            .build()
            .unwrap();

        let response = client
            .responses()
            .create("grok-4")
            .user("hello")
            .max_retries(1)
            .retry_backoff(Duration::ZERO, Duration::ZERO)
            .send()
            .await
            .unwrap();

        assert_eq!(response.output_text().as_deref(), Some("retry worked"));
        // Exactly two calls: the failed attempt plus one retry.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
867
    #[tokio::test]
    async fn create_builder_disable_retries_overrides_client_retry_policy() {
        // The mock always returns 503; with the client configured for two
        // retries, the per-request `disable_retries()` must win and allow
        // only a single call.
        let server = MockServer::start().await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let responder_count = Arc::clone(&call_count);

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(move |_req: &wiremock::Request| {
                responder_count.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(503).set_body_json(json!({
                    "error": {"message": "still unavailable", "type": "server_error"}
                }))
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .max_retries(2)
            .retry_backoff(Duration::ZERO, Duration::ZERO)
            .build()
            .unwrap();

        let err = client
            .responses()
            .create("grok-4")
            .user("hello")
            .disable_retries()
            .send()
            .await
            .unwrap_err();

        assert!(matches!(err, Error::Api { status: 503, .. }));
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
905
    #[tokio::test]
    async fn responses_api_get_propagates_api_error() {
        // A 404 on GET /responses/{id} must surface as `Error::Api` carrying
        // the server-supplied status and message.
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/responses/missing"))
            .respond_with(ResponseTemplate::new(404).set_body_json(json!({
                "error": {"message": "response not found"}
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let err = client.responses().get("missing").await.unwrap_err();
        match err {
            Error::Api {
                status, message, ..
            } => {
                assert_eq!(status, 404);
                assert_eq!(message, "response not found");
            }
            _ => panic!("expected Error::Api"),
        }
    }
935
936    #[tokio::test]
937    async fn responses_api_delete_propagates_api_error() {
938        let server = MockServer::start().await;
939
940        Mock::given(method("DELETE"))
941            .and(path("/responses/missing"))
942            .respond_with(ResponseTemplate::new(404).set_body_json(json!({
943                "error": {"message": "response not found"}
944            })))
945            .mount(&server)
946            .await;
947
948        let client = XaiClient::builder()
949            .api_key("test-key")
950            .base_url(server.uri())
951            .build()
952            .unwrap();
953
954        let err = client.responses().delete("missing").await.unwrap_err();
955        match err {
956            Error::Api {
957                status, message, ..
958            } => {
959                assert_eq!(status, 404);
960                assert_eq!(message, "response not found");
961            }
962            _ => panic!("expected Error::Api"),
963        }
964    }
965
966    #[tokio::test]
967    async fn create_response_builder_send_propagates_api_error() {
968        let server = MockServer::start().await;
969
970        Mock::given(method("POST"))
971            .and(path("/responses"))
972            .and(body_partial_json(json!({
973                "model": "grok-4",
974                "input": [{"role": "user", "content": "error"}]
975            })))
976            .respond_with(ResponseTemplate::new(503).set_body_json(json!({
977                "error": {"message": "service unavailable"}
978            })))
979            .mount(&server)
980            .await;
981
982        let client = XaiClient::builder()
983            .api_key("test-key")
984            .base_url(server.uri())
985            .build()
986            .unwrap();
987
988        let err = client
989            .responses()
990            .create("grok-4")
991            .user("error")
992            .send()
993            .await
994            .unwrap_err();
995
996        match err {
997            Error::Api {
998                status, message, ..
999            } => {
1000                assert_eq!(status, 503);
1001                assert_eq!(message, "service unavailable");
1002            }
1003            _ => panic!("expected Error::Api"),
1004        }
1005    }
1006
1007    #[tokio::test]
1008    async fn create_response_builder_stream_propagates_api_error() {
1009        let server = MockServer::start().await;
1010
1011        Mock::given(method("POST"))
1012            .and(path("/responses"))
1013            .and(body_partial_json(json!({
1014                "model": "grok-4",
1015                "input": [{"role": "user", "content": "stream error"}]
1016            })))
1017            .respond_with(ResponseTemplate::new(503).set_body_json(json!({
1018                "error": {"message": "stream unavailable"}
1019            })))
1020            .mount(&server)
1021            .await;
1022
1023        let client = XaiClient::builder()
1024            .api_key("test-key")
1025            .base_url(server.uri())
1026            .build()
1027            .unwrap();
1028
1029        let err = match client
1030            .responses()
1031            .create("grok-4")
1032            .user("stream error")
1033            .stream()
1034            .await
1035        {
1036            Ok(_) => panic!("expected stream creation to fail"),
1037            Err(err) => err,
1038        };
1039
1040        match err {
1041            Error::Api {
1042                status, message, ..
1043            } => {
1044                assert_eq!(status, 503);
1045                assert_eq!(message, "stream unavailable");
1046            }
1047            _ => panic!("expected Error::Api"),
1048        }
1049    }
1050
1051    #[tokio::test]
1052    async fn create_response_builder_stream_returns_stream() {
1053        let server = MockServer::start().await;
1054
1055        let payload = concat!(
1056            "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"done\"}]}}\n\n",
1057            "data: {\"type\":\"response.done\",\"response\":{\"id\":\"resp_stream\",\"model\":\"grok-4\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"done\"}]}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":1,\"total_tokens\":2}}}\n\n",
1058            "data: [DONE]\n\n"
1059        );
1060
1061        Mock::given(method("POST"))
1062            .and(path("/responses"))
1063            .respond_with(ResponseTemplate::new(200).set_body_string(payload))
1064            .mount(&server)
1065            .await;
1066
1067        let client = XaiClient::builder()
1068            .api_key("test-key")
1069            .base_url(server.uri())
1070            .build()
1071            .unwrap();
1072
1073        let mut stream = client
1074            .responses()
1075            .create("grok-4")
1076            .user("stream response")
1077            .stream()
1078            .await
1079            .unwrap();
1080
1081        let first = stream.next().await.unwrap().unwrap();
1082        assert!(!first.done);
1083        assert_eq!(first.delta(), "done");
1084        let done = stream.next().await.unwrap().unwrap();
1085        assert!(done.done);
1086        assert!(done.response.is_some());
1087        assert_eq!(done.response.unwrap().id, "resp_stream");
1088    }
1089
    #[tokio::test]
    async fn create_get_poll_delete_roundtrip() {
        // Full lifecycle against a mock server: POST create, GET fetch, poll
        // until output appears, then DELETE. Each verb's hit count is tracked
        // so the final asserts prove exactly how many HTTP calls were made.
        let server = MockServer::start().await;
        let post_count = Arc::new(AtomicUsize::new(0));
        let get_count = Arc::new(AtomicUsize::new(0));
        let delete_count = Arc::new(AtomicUsize::new(0));

        // POST /responses: always answers with an empty `output`, simulating
        // a request that was accepted but has not yet produced a result.
        let post_count_for_responder = Arc::clone(&post_count);
        Mock::given(method("POST"))
            .and(path("/responses"))
            .and(body_partial_json(json!( {
                "model": "grok-4",
                "input": [
                    {"role": "system", "content": "Roundtrip test system"},
                    {"role": "user", "content": "What is 1+1?"}
                ]
            })))
            .respond_with(move |_req: &wiremock::Request| {
                post_count_for_responder.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200).set_body_json(json!({
                    "id": "resp_roundtrip",
                    "model": "grok-4",
                    "output": [],
                    "usage": {"prompt_tokens": 10, "completion_tokens": 0, "total_tokens": 10}
                }))
            })
            .mount(&server)
            .await;

        // GET /responses/resp_roundtrip: the first GET (the explicit `get()`
        // below) still sees empty output; every later GET returns the
        // finished assistant message, so the poller succeeds on its first
        // attempt.
        let get_count_for_responder = Arc::clone(&get_count);
        Mock::given(method("GET"))
            .and(path("/responses/resp_roundtrip"))
            .respond_with(move |_req: &wiremock::Request| {
                let count = get_count_for_responder.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_roundtrip",
                        "model": "grok-4",
                        "output": [],
                        "usage": {"prompt_tokens": 10, "completion_tokens": 0, "total_tokens": 10}
                    }))
                } else {
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_roundtrip",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": "hello"}]
                        }],
                        "usage": {"prompt_tokens": 11, "completion_tokens": 1, "total_tokens": 12}
                    }))
                }
            })
            .mount(&server)
            .await;

        // DELETE /responses/resp_roundtrip: acknowledges and counts the call.
        let delete_count_for_responder = Arc::clone(&delete_count);
        Mock::given(method("DELETE"))
            .and(path("/responses/resp_roundtrip"))
            .respond_with(move |_req: &wiremock::Request| {
                delete_count_for_responder.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200).set_body_json(json!({"id": "resp_roundtrip"}))
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        // Create: the pending response has an id but no output text yet.
        let create_response = client
            .responses()
            .create("grok-4")
            .system("Roundtrip test system")
            .user("What is 1+1?")
            .send()
            .await
            .unwrap();
        assert_eq!(create_response.id, "resp_roundtrip");
        assert!(create_response.output_text().is_none());

        // First GET: still pending (empty output).
        let get_response = client.responses().get("resp_roundtrip").await.unwrap();
        assert_eq!(get_response.id, "resp_roundtrip");
        assert!(get_response.output_text().is_none());

        // Poll (up to 2 attempts): the second overall GET returns the
        // completed output.
        let polled_response = client
            .responses()
            .poll_until_ready("resp_roundtrip", 2)
            .await
            .unwrap();
        assert_eq!(polled_response.output_text().as_deref(), Some("hello"));

        client.responses().delete("resp_roundtrip").await.unwrap();

        // 1 POST, 2 GETs (explicit get + one poll attempt), 1 DELETE.
        assert_eq!(post_count.load(Ordering::SeqCst), 1);
        assert_eq!(get_count.load(Ordering::SeqCst), 2);
        assert_eq!(delete_count.load(Ordering::SeqCst), 1);
    }
1191
1192    #[tokio::test]
1193    async fn stateful_chat_append_and_sample_sends_all_messages() {
1194        let server = MockServer::start().await;
1195
1196        Mock::given(method("POST"))
1197            .and(path("/responses"))
1198            .and(body_partial_json(json!({
1199                "model": "grok-4",
1200                "input": [
1201                    {"role": "system", "content": "You are helpful."},
1202                    {"role": "user", "content": "Hello"}
1203                ]
1204            })))
1205            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1206                "id": "resp_123",
1207                "model": "grok-4",
1208                "output": [{
1209                    "type": "message",
1210                    "role": "assistant",
1211                    "content": [{"type": "text", "text": "Hi there!"}]
1212                }],
1213                "usage": {
1214                    "prompt_tokens": 10,
1215                    "completion_tokens": 3,
1216                    "total_tokens": 13
1217                }
1218            })))
1219            .mount(&server)
1220            .await;
1221
1222        let client = XaiClient::builder()
1223            .api_key("test-key")
1224            .base_url(server.uri())
1225            .build()
1226            .unwrap();
1227
1228        let mut chat = client.responses().chat("grok-4");
1229        chat.append_system("You are helpful.").append_user("Hello");
1230
1231        assert_eq!(chat.messages().len(), 2);
1232
1233        let response = chat.sample().await.unwrap();
1234        assert_eq!(response.output_text().as_deref(), Some("Hi there!"));
1235    }
1236
1237    #[tokio::test]
1238    async fn stateful_chat_sample_and_append_updates_local_history() {
1239        let server = MockServer::start().await;
1240
1241        Mock::given(method("POST"))
1242            .and(path("/responses"))
1243            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1244                "id": "resp_124",
1245                "model": "grok-4",
1246                "output": [{
1247                    "type": "message",
1248                    "role": "assistant",
1249                    "content": [{"type": "text", "text": "History updated"}]
1250                }],
1251                "usage": {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}
1252            })))
1253            .mount(&server)
1254            .await;
1255
1256        let client = XaiClient::builder()
1257            .api_key("test-key")
1258            .base_url(server.uri())
1259            .build()
1260            .unwrap();
1261
1262        let mut chat = client.responses().chat("grok-4");
1263        chat.append_system("You are helpful.").append_user("Hello");
1264
1265        let response = chat.sample_and_append().await.unwrap();
1266        assert_eq!(response.output_text().as_deref(), Some("History updated"));
1267        assert_eq!(chat.messages().len(), 3);
1268        assert!(chat.pending_tool_calls().is_empty());
1269
1270        let last = chat.messages().last().unwrap();
1271        assert!(matches!(last.role, Role::Assistant));
1272        assert_eq!(last.content.as_text(), Some("History updated"));
1273    }
1274
1275    #[tokio::test]
1276    async fn stateful_chat_sample_and_append_skips_empty_text() {
1277        let server = MockServer::start().await;
1278
1279        Mock::given(method("POST"))
1280            .and(path("/responses"))
1281            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1282                "id": "resp_125",
1283                "model": "grok-4",
1284                "output": [],
1285                "usage": {"prompt_tokens": 10, "completion_tokens": 0, "total_tokens": 10}
1286            })))
1287            .mount(&server)
1288            .await;
1289
1290        let client = XaiClient::builder()
1291            .api_key("test-key")
1292            .base_url(server.uri())
1293            .build()
1294            .unwrap();
1295
1296        let mut chat = client.responses().chat("grok-4");
1297        chat.append_system("You are helpful.").append_user("Hello");
1298
1299        let response = chat.sample_and_append().await.unwrap();
1300        assert!(response.output_text().is_none());
1301        assert_eq!(chat.messages().len(), 2);
1302        assert!(chat.pending_tool_calls().is_empty());
1303    }
1304
    #[tokio::test]
    async fn stateful_chat_append_response_text_appends_non_empty_and_skips_empty() {
        // NOTE(review): no `.await` in this body — a plain #[test] would
        // likely suffice; confirm XaiClient::new needs no async runtime.
        let mut chat = StatefulChat::new(XaiClient::new("test-key").unwrap(), "grok-4".to_string());
        chat.append_system("You are helpful.").append_user("Hello");

        // A message with two text parts: append_response_text should join
        // them into a single assistant turn ("Hello " + "again").
        let non_empty = Response {
            id: "resp_non_empty".to_string(),
            model: "grok-4".to_string(),
            output: vec![OutputItem::Message {
                role: Role::Assistant,
                content: vec![
                    TextContent::Text {
                        text: "Hello ".to_string(),
                    },
                    TextContent::Text {
                        text: "again".to_string(),
                    },
                ],
            }],
            usage: Default::default(),
            citations: None,
            inline_citations: None,
            server_side_tool_usage: None,
            tool_calls: None,
            system_fingerprint: None,
        };

        // A refusal-only message: yields no plain text, so this append is
        // expected to be a no-op on the history.
        let empty = Response {
            id: "resp_empty".to_string(),
            model: "grok-4".to_string(),
            output: vec![OutputItem::Message {
                role: Role::Assistant,
                content: vec![TextContent::Refusal {
                    refusal: "policy block".to_string(),
                }],
            }],
            usage: Default::default(),
            citations: None,
            inline_citations: None,
            server_side_tool_usage: None,
            tool_calls: None,
            system_fingerprint: None,
        };

        chat.append_response_text(&non_empty);
        chat.append_response_text(&empty);

        // Only the non-empty response added a turn: system + user + 1.
        assert_eq!(chat.messages().len(), 3);
        assert_eq!(
            chat.messages().last().unwrap().content.as_text(),
            Some("Hello again")
        );
    }
1358
1359    #[tokio::test]
1360    async fn stateful_chat_sample_and_append_carries_refusal_text() {
1361        let server = MockServer::start().await;
1362
1363        Mock::given(method("POST"))
1364            .and(path("/responses"))
1365            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1366                "id": "resp_126",
1367                "model": "grok-4",
1368                "output": [{
1369                    "type": "message",
1370                    "role": "assistant",
1371                    "content": [{"type": "refusal", "refusal": "I can't help with that."}]
1372                }],
1373                "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
1374            })))
1375            .mount(&server)
1376            .await;
1377
1378        let client = XaiClient::builder()
1379            .api_key("test-key")
1380            .base_url(server.uri())
1381            .build()
1382            .unwrap();
1383
1384        let mut chat = client.responses().chat("grok-4");
1385        chat.append_system("You are helpful.")
1386            .append_user("Do something unsafe");
1387
1388        let response = chat.sample_and_append().await.unwrap();
1389        assert!(response.output_text().is_none());
1390        assert_eq!(chat.messages().len(), 3);
1391        assert!(chat.pending_tool_calls().is_empty());
1392        assert_eq!(
1393            chat.messages().last().unwrap().content.as_text(),
1394            Some("I can't help with that.")
1395        );
1396    }
1397
1398    #[tokio::test]
1399    async fn stateful_chat_sample_and_append_captures_tool_calls() {
1400        let server = MockServer::start().await;
1401
1402        Mock::given(method("POST"))
1403            .and(path("/responses"))
1404            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1405                "id": "resp_127",
1406                "model": "grok-4",
1407                "output": [{
1408                    "type": "function_call",
1409                    "id": "call_weather",
1410                    "function": {
1411                        "name": "get_weather",
1412                        "arguments": "{\"location\":\"Paris\"}"
1413                    }
1414                }],
1415                "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
1416            })))
1417            .mount(&server)
1418            .await;
1419
1420        let client = XaiClient::builder()
1421            .api_key("test-key")
1422            .base_url(server.uri())
1423            .build()
1424            .unwrap();
1425
1426        let mut chat = client.responses().chat("grok-4");
1427        chat.append_system("You are helpful.")
1428            .append_user("What's weather in Paris?");
1429
1430        let _ = chat.sample_and_append().await.unwrap();
1431
1432        // No assistant text message to append, only a pending tool call to resolve.
1433        assert_eq!(chat.messages().len(), 2);
1434        assert_eq!(chat.pending_tool_calls().len(), 1);
1435        assert_eq!(chat.pending_tool_calls()[0].id, "call_weather");
1436        assert_eq!(
1437            chat.pending_tool_calls()[0]
1438                .function
1439                .as_ref()
1440                .map(|f| f.name.as_str()),
1441            Some("get_weather")
1442        );
1443
1444        let pending = chat.take_pending_tool_calls();
1445        assert_eq!(pending.len(), 1);
1446        assert!(chat.pending_tool_calls().is_empty());
1447
1448        chat.append_tool_result("call_weather", r#"{"temperature": 72}"#);
1449        assert_eq!(chat.messages().len(), 3);
1450        assert!(matches!(chat.messages().last().unwrap().role, Role::Tool));
1451    }
1452
    #[tokio::test]
    async fn stateful_chat_sample_with_tool_handler_resolves_tool_loop() {
        // Happy-path tool loop: the first POST returns a function_call, the
        // handler supplies a result, and the second POST returns the final
        // text reply. The counter proves exactly two round-trips occurred.
        let server = MockServer::start().await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let responder_count = Arc::clone(&call_count);

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(move |_req: &wiremock::Request| {
                let count = responder_count.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    // Round 1: ask for the get_weather tool.
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_tool_1",
                        "model": "grok-4",
                        "output": [{
                            "type": "function_call",
                            "id": "call_weather",
                            "function": {
                                "name": "get_weather",
                                "arguments": "{\"location\":\"Paris\"}"
                            }
                        }],
                        "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
                    }))
                } else {
                    // Round 2: final assistant text after the tool result.
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_tool_2",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": "Weather is 72F"}]
                        }],
                        "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}
                    }))
                }
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let mut chat = client.responses().chat("grok-4");
        chat.append_system("You are helpful.")
            .append_user("What's the weather in Paris?");

        // The handler resolves the expected call id; anything else gets an
        // empty JSON object so the loop can still make progress.
        let response = chat
            .sample_with_tool_handler(3, |call| async move {
                if call.id == "call_weather" {
                    Ok(r#"{"temperature": 72}"#.to_string())
                } else {
                    Ok("{}".to_string())
                }
            })
            .await
            .unwrap();

        assert_eq!(response.output_text().as_deref(), Some("Weather is 72F"));
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
        assert!(chat.pending_tool_calls().is_empty());
        // History: system, user, tool result, final assistant reply.
        assert_eq!(chat.messages().len(), 4);
        assert!(matches!(chat.messages()[2].role, Role::Tool));
        assert!(matches!(chat.messages()[3].role, Role::Assistant));
    }
1521
1522    #[tokio::test]
1523    async fn stateful_chat_sample_with_tool_handler_errors_when_rounds_exhausted() {
1524        let server = MockServer::start().await;
1525
1526        Mock::given(method("POST"))
1527            .and(path("/responses"))
1528            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1529                "id": "resp_tool_loop",
1530                "model": "grok-4",
1531                "output": [{
1532                    "type": "function_call",
1533                    "id": "call_loop",
1534                    "function": {
1535                        "name": "looping_tool",
1536                        "arguments": "{}"
1537                    }
1538                }],
1539                "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
1540            })))
1541            .mount(&server)
1542            .await;
1543
1544        let client = XaiClient::builder()
1545            .api_key("test-key")
1546            .base_url(server.uri())
1547            .build()
1548            .unwrap();
1549
1550        let mut chat = client.responses().chat("grok-4");
1551        chat.append_system("You are helpful.")
1552            .append_user("Trigger loop");
1553
1554        let err = chat
1555            .sample_with_tool_handler(1, |_call| async move { Ok("{}".to_string()) })
1556            .await
1557            .unwrap_err();
1558
1559        match err {
1560            Error::Config(message) => {
1561                assert!(message.contains("stateful chat tool loop exceeded max rounds (1)"))
1562            }
1563            _ => panic!("expected Error::Config"),
1564        }
1565    }
1566
1567    #[tokio::test]
1568    async fn stateful_chat_sample_with_tool_handler_propagates_handler_error() {
1569        let server = MockServer::start().await;
1570
1571        Mock::given(method("POST"))
1572            .and(path("/responses"))
1573            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1574                "id": "resp_tool_error",
1575                "model": "grok-4",
1576                "output": [{
1577                    "type": "function_call",
1578                    "id": "call_fail",
1579                    "function": {
1580                        "name": "failing_tool",
1581                        "arguments": "{}"
1582                    }
1583                }],
1584                "usage": {"prompt_tokens": 10, "completion_tokens": 1, "total_tokens": 11}
1585            })))
1586            .mount(&server)
1587            .await;
1588
1589        let client = XaiClient::builder()
1590            .api_key("test-key")
1591            .base_url(server.uri())
1592            .build()
1593            .unwrap();
1594
1595        let mut chat = client.responses().chat("grok-4");
1596        chat.append_system("You are helpful.")
1597            .append_user("Trigger tool error");
1598
1599        let err = chat
1600            .sample_with_tool_handler(3, |_call| async move {
1601                Err(Error::Config("tool handler failed".to_string()))
1602            })
1603            .await
1604            .unwrap_err();
1605
1606        match err {
1607            Error::Config(message) => assert!(message.contains("tool handler failed")),
1608            _ => panic!("expected Error::Config"),
1609        }
1610    }
1611
1612    #[tokio::test]
1613    async fn stateful_chat_stream_accumulates_deltas() {
1614        let payload = concat!(
1615            "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"Hel\"}]}}\n\n",
1616            "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"lo\"}]}}\n\n",
1617            "data: [DONE]\n\n"
1618        );
1619
1620        let chunks: Vec<std::result::Result<Bytes, reqwest::Error>> =
1621            vec![Ok(Bytes::from(payload.to_string()))];
1622        let raw_stream = stream::iter(chunks);
1623        let response_stream = ResponseStream::new(raw_stream);
1624        let mut stream = StatefulChatStream::new(response_stream);
1625
1626        let first = stream.next().await.unwrap().unwrap();
1627        assert_eq!(first.delta(), "Hel");
1628        assert_eq!(stream.accumulated_text(), "Hel");
1629
1630        let second = stream.next().await.unwrap().unwrap();
1631        assert_eq!(second.delta(), "lo");
1632        assert_eq!(stream.accumulated_text(), "Hello");
1633
1634        let done = stream.next().await.unwrap().unwrap();
1635        assert!(done.done);
1636        assert_eq!(stream.accumulated_text(), "Hello");
1637    }
1638
1639    #[tokio::test]
1640    async fn stateful_chat_stream_next_with_accumulated_returns_snapshot() {
1641        let payload = concat!(
1642            "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"Hel\"}]}}\n\n",
1643            "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"lo\"}]}}\n\n",
1644            "data: [DONE]\n\n"
1645        );
1646
1647        let chunks: Vec<std::result::Result<Bytes, reqwest::Error>> =
1648            vec![Ok(Bytes::from(payload.to_string()))];
1649        let raw_stream = stream::iter(chunks);
1650        let response_stream = ResponseStream::new(raw_stream);
1651        let mut stream = StatefulChatStream::new(response_stream);
1652
1653        let first = stream.next_with_accumulated().await.unwrap().unwrap();
1654        assert_eq!(first.chunk.delta(), "Hel");
1655        assert_eq!(first.accumulated_text, "Hel");
1656
1657        let second = stream.next_with_accumulated().await.unwrap().unwrap();
1658        assert_eq!(second.chunk.delta(), "lo");
1659        assert_eq!(second.accumulated_text, "Hello");
1660
1661        let done = stream.next_with_accumulated().await.unwrap().unwrap();
1662        assert!(done.chunk.done);
1663        assert_eq!(done.accumulated_text, "Hello");
1664    }
1665
    #[tokio::test]
    async fn deferred_poller_returns_when_response_has_output() {
        // The first GET returns an empty `output` (still pending); the second
        // returns the completed message. The poller should stop as soon as
        // output appears — two GETs total, well under max_attempts(3).
        let server = MockServer::start().await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_for_responder = Arc::clone(&call_count);

        Mock::given(method("GET"))
            .and(path("/responses/resp_deferred"))
            .respond_with(move |_req: &wiremock::Request| {
                let current = call_count_for_responder.fetch_add(1, Ordering::SeqCst);
                if current == 0 {
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_deferred",
                        "model": "grok-4",
                        "output": [],
                        "usage": {"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}
                    }))
                } else {
                    ResponseTemplate::new(200).set_body_json(json!({
                        "id": "resp_deferred",
                        "model": "grok-4",
                        "output": [{
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": "ready"}]
                        }],
                        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
                    }))
                }
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let response = client
            .responses()
            .deferred("resp_deferred")
            .max_attempts(3)
            .wait()
            .await
            .unwrap();

        assert_eq!(response.output_text().as_deref(), Some("ready"));
        // Exactly two polls: one pending, one ready.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
1716
1717    #[tokio::test]
1718    async fn deferred_poller_times_out_when_output_never_arrives() {
1719        let server = MockServer::start().await;
1720
1721        Mock::given(method("GET"))
1722            .and(path("/responses/resp_timeout"))
1723            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1724                "id": "resp_timeout",
1725                "model": "grok-4",
1726                "output": [],
1727                "usage": {"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}
1728            })))
1729            .mount(&server)
1730            .await;
1731
1732        let client = XaiClient::builder()
1733            .api_key("test-key")
1734            .base_url(server.uri())
1735            .build()
1736            .unwrap();
1737
1738        let err = client
1739            .responses()
1740            .deferred("resp_timeout")
1741            .max_attempts(2)
1742            .wait()
1743            .await
1744            .unwrap_err();
1745
1746        assert!(matches!(err, Error::Timeout));
1747    }
1748
1749    #[test]
1750    fn deferred_poller_poll_delay_returns_zero_when_interval_is_zero() {
1751        assert_eq!(
1752            DeferredResponsePoller::poll_delay_for(Duration::ZERO, Duration::from_millis(200), 0),
1753            Duration::ZERO
1754        );
1755        assert_eq!(
1756            DeferredResponsePoller::poll_delay_for(Duration::from_millis(200), Duration::ZERO, 0),
1757            Duration::ZERO
1758        );
1759    }
1760
1761    #[tokio::test]
1762    async fn stateful_chat_sample_with_tool_handler_minimum_round_is_one() {
1763        let server = MockServer::start().await;
1764        let call_count = Arc::new(AtomicUsize::new(0));
1765        let count_for_responses = Arc::clone(&call_count);
1766        let handler_invocations = Arc::new(AtomicUsize::new(0));
1767        let handler_invocations_for_handler = Arc::clone(&handler_invocations);
1768
1769        Mock::given(method("POST"))
1770            .and(path("/responses"))
1771            .respond_with(move |_req: &wiremock::Request| {
1772                let current = count_for_responses.fetch_add(1, Ordering::SeqCst);
1773                if current == 0 {
1774                    ResponseTemplate::new(200).set_body_json(json!( {
1775                        "id": "resp_tool_1",
1776                        "model": "grok-4",
1777                        "output": [{
1778                            "type": "message",
1779                            "role": "assistant",
1780                            "content": [{ "type": "text", "text": "Recovered" }]
1781                        }],
1782                        "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}
1783                    }))
1784                } else {
1785                    panic!("unexpected second request")
1786                }
1787            })
1788            .mount(&server)
1789            .await;
1790
1791        let client = XaiClient::builder()
1792            .api_key("test-key")
1793            .base_url(server.uri())
1794            .build()
1795            .unwrap();
1796
1797        let mut chat = client.responses().chat("grok-4");
1798        chat.append_system("You are helpful.").append_user("Hello");
1799
1800        let response = chat
1801            .sample_with_tool_handler(0, move |_call| {
1802                let handler_invocations_for_handler = Arc::clone(&handler_invocations_for_handler);
1803                async move {
1804                    handler_invocations_for_handler.fetch_add(1, Ordering::SeqCst);
1805                    Ok(r#"{"value":42}"#.to_string())
1806                }
1807            })
1808            .await
1809            .unwrap();
1810
1811        assert_eq!(response.output_text().as_deref(), Some("Recovered"));
1812        assert_eq!(call_count.load(Ordering::SeqCst), 1);
1813        assert_eq!(handler_invocations.load(Ordering::SeqCst), 0);
1814    }
1815
1816    #[tokio::test]
1817    async fn deferred_poller_wait_uses_poll_delay_between_attempts() {
1818        let server = MockServer::start().await;
1819        let request_count = Arc::new(AtomicUsize::new(0));
1820        let count_for_handler = Arc::clone(&request_count);
1821
1822        Mock::given(method("GET"))
1823            .and(path("/responses/resp_poll_delay"))
1824            .respond_with(move |_req: &wiremock::Request| {
1825                let current = count_for_handler.fetch_add(1, Ordering::SeqCst);
1826                if current == 0 {
1827                    ResponseTemplate::new(200).set_body_json(json!({
1828                        "id": "resp_poll_delay",
1829                        "model": "grok-4",
1830                        "output": [],
1831                        "usage": {"prompt_tokens": 4, "completion_tokens": 0, "total_tokens": 4}
1832                    }))
1833                } else {
1834                    ResponseTemplate::new(200).set_body_json(json!({
1835                        "id": "resp_poll_delay",
1836                        "model": "grok-4",
1837                        "output": [{
1838                            "type": "message",
1839                            "role": "assistant",
1840                            "content": [{"type": "text", "text": "ready"}]
1841                        }],
1842                        "usage": {"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5}
1843                    }))
1844                }
1845            })
1846            .mount(&server)
1847            .await;
1848
1849        let client = XaiClient::builder()
1850            .api_key("test-key")
1851            .base_url(server.uri())
1852            .build()
1853            .unwrap();
1854
1855        let started_at = std::time::Instant::now();
1856        let response = client
1857            .responses()
1858            .deferred("resp_poll_delay")
1859            .poll_interval(Duration::from_millis(80))
1860            .max_attempts(2)
1861            .wait()
1862            .await
1863            .unwrap();
1864        let elapsed = started_at.elapsed();
1865
1866        assert_eq!(response.output_text().as_deref(), Some("ready"));
1867        assert_eq!(request_count.load(Ordering::SeqCst), 2);
1868        assert!(
1869            elapsed >= Duration::from_millis(60),
1870            "expected exponential backoff delay to be observed"
1871        );
1872        assert!(
1873            elapsed < Duration::from_millis(500),
1874            "expected delay test to stay bounded"
1875        );
1876    }
1877
1878    #[test]
1879    fn stateful_chat_merge_output_message_content_and_collect_response_semantics() {
1880        let text_output = vec![
1881            TextContent::Text {
1882                text: "Hello ".to_string(),
1883            },
1884            TextContent::Refusal {
1885                refusal: "Blocked".to_string(),
1886            },
1887        ];
1888        assert_eq!(
1889            StatefulChat::merge_output_message_content(&text_output),
1890            Some("Hello Blocked".to_string())
1891        );
1892        assert_eq!(StatefulChat::merge_output_message_content(&[]), None);
1893
1894        let shared_tool_call_id = "shared-call".to_string();
1895        let response = Response {
1896            id: "resp_semantics".to_string(),
1897            model: "grok-4".to_string(),
1898            output: vec![
1899                OutputItem::Message {
1900                    role: Role::Assistant,
1901                    content: vec![
1902                        TextContent::Text {
1903                            text: "Part 1 ".to_string(),
1904                        },
1905                        TextContent::Text {
1906                            text: "Part 2".to_string(),
1907                        },
1908                    ],
1909                },
1910                OutputItem::FunctionCall {
1911                    call: ToolCall {
1912                        id: shared_tool_call_id.clone(),
1913                        call_type: Some("function".to_string()),
1914                        function: None,
1915                    },
1916                },
1917                OutputItem::FunctionCall {
1918                    call: ToolCall {
1919                        id: "function_call".to_string(),
1920                        call_type: Some("function".to_string()),
1921                        function: None,
1922                    },
1923                },
1924                OutputItem::CodeInterpreterCall {
1925                    id: "ci_call".to_string(),
1926                    code: None,
1927                    outputs: None,
1928                },
1929                OutputItem::WebSearchCall {
1930                    id: "web_call".to_string(),
1931                    results: None,
1932                },
1933                OutputItem::XSearchCall {
1934                    id: "x_call".to_string(),
1935                    results: None,
1936                },
1937            ],
1938            usage: Default::default(),
1939            citations: None,
1940            inline_citations: None,
1941            server_side_tool_usage: None,
1942            tool_calls: Some(vec![ToolCall {
1943                id: shared_tool_call_id.clone(),
1944                call_type: Some("function".to_string()),
1945                function: None,
1946            }]),
1947            system_fingerprint: None,
1948        };
1949
1950        let (assistant_messages, pending_tool_calls) =
1951            StatefulChat::collect_response_semantics(&response);
1952        assert_eq!(assistant_messages, vec!["Part 1 Part 2".to_string()]);
1953        assert_eq!(pending_tool_calls.len(), 5);
1954        assert_eq!(pending_tool_calls[0].id, shared_tool_call_id);
1955        assert_eq!(pending_tool_calls[1].id, "function_call");
1956        assert_eq!(pending_tool_calls[1].call_type.as_deref(), Some("function"));
1957        assert_eq!(pending_tool_calls[2].id, "ci_call");
1958        assert_eq!(
1959            pending_tool_calls[2].call_type.as_deref(),
1960            Some("code_interpreter")
1961        );
1962        assert_eq!(pending_tool_calls[3].id, "web_call");
1963        assert_eq!(
1964            pending_tool_calls[3].call_type.as_deref(),
1965            Some("web_search")
1966        );
1967        assert_eq!(pending_tool_calls[4].id, "x_call");
1968        assert_eq!(pending_tool_calls[4].call_type.as_deref(), Some("x_search"));
1969    }
1970
1971    #[test]
1972    fn deferred_poller_poll_interval_sets_fixed_delay() {
1973        let poller = DeferredResponsePoller::new(
1974            XaiClient::new("test-key").unwrap(),
1975            "resp_123".to_string(),
1976        )
1977        .poll_interval(Duration::from_millis(250));
1978
1979        assert_eq!(poller.poll_initial_delay, Duration::from_millis(250));
1980        assert_eq!(poller.poll_max_delay, Duration::from_millis(250));
1981        assert_eq!(
1982            DeferredResponsePoller::poll_delay_for(
1983                poller.poll_initial_delay,
1984                poller.poll_max_delay,
1985                0
1986            ),
1987            Duration::from_millis(250)
1988        );
1989        assert_eq!(
1990            DeferredResponsePoller::poll_delay_for(
1991                poller.poll_initial_delay,
1992                poller.poll_max_delay,
1993                5
1994            ),
1995            Duration::from_millis(250)
1996        );
1997    }
1998
1999    #[test]
2000    fn deferred_poller_poll_backoff_is_exponential_and_capped() {
2001        let poller = DeferredResponsePoller::new(
2002            XaiClient::new("test-key").unwrap(),
2003            "resp_123".to_string(),
2004        )
2005        .poll_backoff(Duration::from_millis(100), Duration::from_millis(300));
2006
2007        assert_eq!(
2008            DeferredResponsePoller::poll_delay_for(
2009                poller.poll_initial_delay,
2010                poller.poll_max_delay,
2011                0
2012            ),
2013            Duration::from_millis(100)
2014        );
2015        assert_eq!(
2016            DeferredResponsePoller::poll_delay_for(
2017                poller.poll_initial_delay,
2018                poller.poll_max_delay,
2019                1
2020            ),
2021            Duration::from_millis(200)
2022        );
2023        assert_eq!(
2024            DeferredResponsePoller::poll_delay_for(
2025                poller.poll_initial_delay,
2026                poller.poll_max_delay,
2027                2
2028            ),
2029            Duration::from_millis(300)
2030        );
2031        assert_eq!(
2032            DeferredResponsePoller::poll_delay_for(
2033                poller.poll_initial_delay,
2034                poller.poll_max_delay,
2035                3
2036            ),
2037            Duration::from_millis(300)
2038        );
2039    }
2040}