Skip to main content

rig_core/
streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::hooks::PromptHook;
14use crate::agent::prompt_request::streaming::StreamingPromptRequest;
15use crate::completion::{
16    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17    Message, Usage,
18};
19use crate::message::{
20    AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult,
21};
22use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
23use futures::stream::{AbortHandle, Abortable};
24use futures::{Stream, StreamExt};
25use serde::{Deserialize, Serialize};
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::atomic::AtomicBool;
29use std::task::{Context, Poll};
30use tokio::sync::watch;
31
32/// Control for pausing and resuming a streaming response
33pub struct PauseControl {
34    pub(crate) paused_tx: watch::Sender<bool>,
35    pub(crate) paused_rx: watch::Receiver<bool>,
36}
37
38impl PauseControl {
39    /// Create a pause controller in the running state.
40    pub fn new() -> Self {
41        let (paused_tx, paused_rx) = watch::channel(false);
42        Self {
43            paused_tx,
44            paused_rx,
45        }
46    }
47
48    /// Pause polling of the public stream until [`PauseControl::resume`] is called.
49    pub fn pause(&self) {
50        let _ = self.paused_tx.send(true);
51    }
52
53    /// Resume polling after a pause.
54    pub fn resume(&self) {
55        let _ = self.paused_tx.send(false);
56    }
57
58    /// Returns whether the stream is currently paused.
59    pub fn is_paused(&self) -> bool {
60        *self.paused_rx.borrow()
61    }
62}
63
64impl Default for PauseControl {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70/// The content of a tool call delta - either the tool name or argument data
71#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
72pub enum ToolCallDeltaContent {
73    /// Tool/function name emitted by the provider.
74    Name(String),
75    /// Partial JSON argument data emitted by the provider.
76    Delta(String),
77}
78
79/// Enum representing a streaming chunk from the model
80#[derive(Debug, Clone)]
81pub enum RawStreamingChoice<R>
82where
83    R: Clone,
84{
85    /// A text chunk from a message response
86    Message(String),
87
88    /// Start a new text content block in the accumulated final choice.
89    ///
90    /// This is an internal provider-normalization event. It is not yielded to
91    /// public stream consumers, but lets providers preserve metadata boundaries
92    /// for final aggregated assistant text blocks.
93    TextStart {
94        /// Provider-specific metadata attached to this text block.
95        additional_params: Option<serde_json::Value>,
96    },
97
98    /// Provider-specific metadata for the current text content block.
99    ///
100    /// This is not yielded to public stream consumers. The metadata is merged
101    /// into the current aggregated [`Text`] block.
102    TextAdditionalParams(serde_json::Value),
103
104    /// A tool call response (in its entirety)
105    ToolCall(RawStreamingToolCall),
106    /// A tool call partial/delta
107    ToolCallDelta {
108        /// Provider-supplied tool call ID.
109        id: String,
110        /// Rig-generated unique identifier for this tool call.
111        internal_call_id: String,
112        content: ToolCallDeltaContent,
113    },
114    /// A reasoning (in its entirety)
115    Reasoning {
116        /// Provider-supplied reasoning block ID, when present.
117        id: Option<String>,
118        /// Complete reasoning content block.
119        content: ReasoningContent,
120    },
121    /// A reasoning partial/delta
122    ReasoningDelta {
123        /// Provider-supplied reasoning block ID, when present.
124        id: Option<String>,
125        /// Partial reasoning text.
126        reasoning: String,
127    },
128
129    /// The final response object, must be yielded if you want the
130    /// `response` field to be populated on the `StreamingCompletionResponse`
131    FinalResponse(R),
132
133    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
134    /// Captured silently into `StreamingCompletionResponse::message_id`.
135    MessageId(String),
136}
137
138/// Describes a streaming tool call response (in its entirety)
139#[derive(Debug, Clone)]
140pub struct RawStreamingToolCall {
141    /// Provider-supplied tool call ID.
142    pub id: String,
143    /// Rig-generated unique identifier for this tool call.
144    pub internal_call_id: String,
145    /// Provider-specific call ID used by some APIs for tool result correlation.
146    pub call_id: Option<String>,
147    /// Tool/function name.
148    pub name: String,
149    /// Parsed tool arguments.
150    pub arguments: serde_json::Value,
151    /// Optional provider signature associated with the tool call.
152    pub signature: Option<String>,
153    /// Additional provider-specific tool call metadata.
154    pub additional_params: Option<serde_json::Value>,
155}
156
157impl RawStreamingToolCall {
158    /// Create an empty tool call accumulator for provider streaming parsers.
159    pub fn empty() -> Self {
160        Self {
161            id: String::new(),
162            internal_call_id: nanoid::nanoid!(),
163            call_id: None,
164            name: String::new(),
165            arguments: serde_json::Value::Null,
166            signature: None,
167            additional_params: None,
168        }
169    }
170
171    /// Create a complete tool call with a generated internal call ID.
172    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
173        Self {
174            id,
175            internal_call_id: nanoid::nanoid!(),
176            call_id: None,
177            name,
178            arguments,
179            signature: None,
180            additional_params: None,
181        }
182    }
183
184    /// Override the generated internal call ID.
185    pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
186        self.internal_call_id = internal_call_id;
187        self
188    }
189
190    /// Attach a provider-specific call ID.
191    pub fn with_call_id(mut self, call_id: String) -> Self {
192        self.call_id = Some(call_id);
193        self
194    }
195
196    /// Attach or clear a provider signature.
197    pub fn with_signature(mut self, signature: Option<String>) -> Self {
198        self.signature = signature;
199        self
200    }
201
202    /// Attach provider-specific metadata.
203    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
204        self.additional_params = additional_params;
205        self
206    }
207}
208
209impl From<RawStreamingToolCall> for ToolCall {
210    fn from(tool_call: RawStreamingToolCall) -> Self {
211        ToolCall {
212            id: tool_call.id,
213            call_id: tool_call.call_id,
214            function: ToolFunction {
215                name: tool_call.name,
216                arguments: tool_call.arguments,
217            },
218            signature: tool_call.signature,
219            additional_params: tool_call.additional_params,
220        }
221    }
222}
223
224#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
225/// Provider stream of raw completion chunks on native targets.
226pub type StreamingResult<R> =
227    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;
228
229#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
230/// Provider stream of raw completion chunks on wasm targets.
231pub type StreamingResult<R> =
232    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
233
234/// The response from a streaming completion request;
235/// message and response are populated at the end of the
236/// `inner` stream.
237pub struct StreamingCompletionResponse<R>
238where
239    R: Clone + Unpin + GetTokenUsage,
240{
241    pub(crate) inner: Abortable<StreamingResult<R>>,
242    pub(crate) abort_handle: AbortHandle,
243    pub(crate) pause_control: PauseControl,
244    assistant_items: Vec<AssistantContent>,
245    text_item_index: Option<usize>,
246    reasoning_item_index: Option<usize>,
247    /// The final aggregated message from the stream
248    /// contains all text and tool calls generated
249    pub choice: OneOrMany<AssistantContent>,
250    /// The final response from the stream, may be `None`
251    /// if the provider didn't yield it during the stream
252    pub response: Option<R>,
253    pub final_response_yielded: AtomicBool,
254    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
255    pub message_id: Option<String>,
256}
257
258impl<R> StreamingCompletionResponse<R>
259where
260    R: Clone + Unpin + GetTokenUsage,
261{
262    /// Wrap a provider stream and initialize aggregation state.
263    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
264        let (abort_handle, abort_registration) = AbortHandle::new_pair();
265        let abortable_stream = Abortable::new(inner, abort_registration);
266        let pause_control = PauseControl::new();
267        Self {
268            inner: abortable_stream,
269            abort_handle,
270            pause_control,
271            assistant_items: vec![],
272            text_item_index: None,
273            reasoning_item_index: None,
274            choice: OneOrMany::one(AssistantContent::text("")),
275            response: None,
276            final_response_yielded: AtomicBool::new(false),
277            message_id: None,
278        }
279    }
280
281    /// Cancel the stream. Cancellation is surfaced as normal stream termination.
282    pub fn cancel(&self) {
283        self.abort_handle.abort();
284    }
285
286    /// Pause stream polling.
287    pub fn pause(&self) {
288        self.pause_control.pause();
289    }
290
291    /// Resume stream polling after a pause.
292    pub fn resume(&self) {
293        self.pause_control.resume();
294    }
295
296    /// Returns whether the stream is currently paused.
297    pub fn is_paused(&self) -> bool {
298        self.pause_control.is_paused()
299    }
300
301    fn append_text_chunk(&mut self, text: &str) {
302        if let Some(index) = self.text_item_index
303            && let Some(AssistantContent::Text(existing_text)) = self.assistant_items.get_mut(index)
304        {
305            existing_text.text.push_str(text);
306            return;
307        }
308
309        self.assistant_items
310            .push(AssistantContent::text(text.to_owned()));
311        self.text_item_index = Some(self.assistant_items.len() - 1);
312    }
313
314    fn append_text_additional_params(&mut self, additional_params: serde_json::Value) {
315        if additional_params.is_null() {
316            return;
317        }
318
319        let index = if let Some(index) = self.text_item_index
320            && matches!(
321                self.assistant_items.get(index),
322                Some(AssistantContent::Text(_))
323            ) {
324            index
325        } else {
326            self.assistant_items.push(AssistantContent::text(""));
327            let index = self.assistant_items.len() - 1;
328            self.text_item_index = Some(index);
329            index
330        };
331
332        let Some(AssistantContent::Text(text)) = self.assistant_items.get_mut(index) else {
333            return;
334        };
335
336        match text.additional_params.as_mut() {
337            Some(existing) => merge_text_additional_params(existing, additional_params),
338            None => text.additional_params = Some(additional_params),
339        }
340    }
341
342    /// Accumulate streaming reasoning delta text into assistant_items.
343    /// Providers that only emit ReasoningDelta (not full Reasoning blocks)
344    /// need this so the aggregated response includes reasoning content.
345    fn append_reasoning_chunk(&mut self, id: &Option<String>, text: &str) {
346        if let Some(index) = self.reasoning_item_index
347            && let Some(AssistantContent::Reasoning(existing)) = self.assistant_items.get_mut(index)
348            && let Some(ReasoningContent::Text {
349                text: existing_text,
350                ..
351            }) = existing.content.last_mut()
352        {
353            existing_text.push_str(text);
354            return;
355        }
356
357        self.assistant_items
358            .push(AssistantContent::Reasoning(Reasoning {
359                id: id.clone(),
360                content: vec![ReasoningContent::Text {
361                    text: text.to_string(),
362                    signature: None,
363                }],
364            }));
365        self.reasoning_item_index = Some(self.assistant_items.len() - 1);
366    }
367}
368
369fn merge_text_additional_params(existing: &mut serde_json::Value, incoming: serde_json::Value) {
370    match (existing, incoming) {
371        (serde_json::Value::Object(existing_map), serde_json::Value::Object(incoming_map)) => {
372            for (key, incoming_value) in incoming_map {
373                match existing_map.get_mut(&key) {
374                    Some(existing_value) => match (existing_value, incoming_value) {
375                        (
376                            serde_json::Value::Array(existing_array),
377                            serde_json::Value::Array(mut incoming_array),
378                        ) => existing_array.append(&mut incoming_array),
379                        (existing_value, incoming_value) => {
380                            merge_text_additional_params(existing_value, incoming_value);
381                        }
382                    },
383                    None => {
384                        existing_map.insert(key, incoming_value);
385                    }
386                }
387            }
388        }
389        (existing, incoming) => {
390            *existing = incoming;
391        }
392    }
393}
394
395impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
396where
397    R: Clone + Unpin + GetTokenUsage,
398{
399    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
400        CompletionResponse {
401            choice: value.choice,
402            usage: Usage::new(), // Usage is not tracked in streaming responses
403            raw_response: value.response,
404            message_id: value.message_id,
405        }
406    }
407}
408
409impl<R> Stream for StreamingCompletionResponse<R>
410where
411    R: Clone + Unpin + GetTokenUsage,
412{
413    type Item = Result<StreamedAssistantContent<R>, CompletionError>;
414
415    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
416        let stream = self.get_mut();
417
418        if stream.is_paused() {
419            cx.waker().wake_by_ref();
420            return Poll::Pending;
421        }
422
423        match Pin::new(&mut stream.inner).poll_next(cx) {
424            Poll::Pending => Poll::Pending,
425            Poll::Ready(None) => {
426                // This is run at the end of the inner stream to collect all tokens into
427                // a single unified `Message`.
428                if stream.assistant_items.is_empty() {
429                    stream.assistant_items.push(AssistantContent::text(""));
430                }
431
432                if let Some(choice) =
433                    OneOrMany::from_iter_optional(std::mem::take(&mut stream.assistant_items))
434                {
435                    stream.choice = choice;
436                }
437
438                Poll::Ready(None)
439            }
440            Poll::Ready(Some(Err(err))) => {
441                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
442                {
443                    return Poll::Ready(None); // Treat cancellation as stream termination
444                }
445                Poll::Ready(Some(Err(err)))
446            }
447            Poll::Ready(Some(Ok(choice))) => match choice {
448                RawStreamingChoice::Message(text) => {
449                    stream.reasoning_item_index = None;
450                    stream.append_text_chunk(&text);
451                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
452                }
453                RawStreamingChoice::TextStart { additional_params } => {
454                    stream.reasoning_item_index = None;
455                    stream.text_item_index = None;
456                    if let Some(additional_params) = additional_params {
457                        stream.append_text_additional_params(additional_params);
458                    }
459                    stream.poll_next_unpin(cx)
460                }
461                RawStreamingChoice::TextAdditionalParams(additional_params) => {
462                    stream.append_text_additional_params(additional_params);
463                    stream.poll_next_unpin(cx)
464                }
465                RawStreamingChoice::ToolCallDelta {
466                    id,
467                    internal_call_id,
468                    content,
469                } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
470                    id,
471                    internal_call_id,
472                    content,
473                }))),
474                RawStreamingChoice::Reasoning { id, content } => {
475                    let reasoning = Reasoning {
476                        id,
477                        content: vec![content],
478                    };
479                    stream.text_item_index = None;
480                    // Full reasoning block supersedes any delta accumulation
481                    stream.reasoning_item_index = None;
482                    stream
483                        .assistant_items
484                        .push(AssistantContent::Reasoning(reasoning.clone()));
485                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(reasoning))))
486                }
487                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
488                    stream.text_item_index = None;
489                    stream.append_reasoning_chunk(&id, &reasoning);
490                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
491                        id,
492                        reasoning,
493                    })))
494                }
495                RawStreamingChoice::ToolCall(raw_tool_call) => {
496                    let internal_call_id = raw_tool_call.internal_call_id.clone();
497                    let tool_call: ToolCall = raw_tool_call.into();
498                    stream.text_item_index = None;
499                    stream.reasoning_item_index = None;
500                    stream
501                        .assistant_items
502                        .push(AssistantContent::ToolCall(tool_call.clone()));
503                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
504                        tool_call,
505                        internal_call_id,
506                    })))
507                }
508                RawStreamingChoice::FinalResponse(response) => {
509                    if stream
510                        .final_response_yielded
511                        .load(std::sync::atomic::Ordering::SeqCst)
512                    {
513                        stream.poll_next_unpin(cx)
514                    } else {
515                        // Set the final response field and return the next item in the stream
516                        stream.response = Some(response.clone());
517                        stream
518                            .final_response_yielded
519                            .store(true, std::sync::atomic::Ordering::SeqCst);
520                        let final_response = StreamedAssistantContent::final_response(response);
521                        Poll::Ready(Some(Ok(final_response)))
522                    }
523                }
524                RawStreamingChoice::MessageId(id) => {
525                    stream.message_id = Some(id);
526                    stream.poll_next_unpin(cx)
527                }
528            },
529        }
530    }
531}
532
533/// Trait for high-level streaming prompt interface.
534///
535/// This trait provides a simple interface for streaming prompts to a completion model.
536/// Implementations can optionally support prompt hooks for observing and controlling
537/// the agent's execution lifecycle.
538pub trait StreamingPrompt<M, R>
539where
540    M: CompletionModel + 'static,
541    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
542    R: Clone + Unpin + GetTokenUsage,
543{
544    /// The hook type used by this streaming prompt implementation.
545    ///
546    /// If your implementation does not need prompt hooks, use `()` as the hook type:
547    ///
548    /// ```ignore
549    /// impl<M, R> StreamingPrompt<M, R> for MyType<M>
550    /// where
551    ///     M: CompletionModel + 'static,
552    ///     // ... other bounds ...
553    /// {
554    ///     type Hook = ();
555    ///
556    ///     fn stream_prompt(&self, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
557    ///         // ...
558    ///     }
559    /// }
560    /// ```
561    type Hook: PromptHook<M>;
562
563    /// Stream a simple prompt to the model
564    fn stream_prompt(
565        &self,
566        prompt: impl Into<Message> + WasmCompatSend,
567    ) -> StreamingPromptRequest<M, Self::Hook>;
568}
569
570/// Trait for high-level streaming chat interface with conversation history.
571///
572/// This trait provides an interface for streaming chat completions with support
573/// for maintaining conversation history. Implementations can optionally support
574/// prompt hooks for observing and controlling the agent's execution lifecycle.
575pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
576where
577    M: CompletionModel + 'static,
578    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
579    R: Clone + Unpin + GetTokenUsage,
580{
581    /// The hook type used by this streaming chat implementation.
582    ///
583    /// If your implementation does not need prompt hooks, use `()` as the hook type:
584    ///
585    /// ```ignore
586    /// impl<M, R> StreamingChat<M, R> for MyType<M>
587    /// where
588    ///     M: CompletionModel + 'static,
589    ///     // ... other bounds ...
590    /// {
591    ///     type Hook = ();
592    ///
593    ///     fn stream_chat(
594    ///         &self,
595    ///         prompt: impl Into<Message>,
596    ///         chat_history: Vec<Message>,
597    ///     ) -> StreamingPromptRequest<M, ()> {
598    ///         // ...
599    ///     }
600    /// }
601    /// ```
602    type Hook: PromptHook<M>;
603
604    /// Stream a chat with history to the model.
605    ///
606    /// The messages returned by the model can be accessed via `FinalResponse::history()`
607    ///
608    /// You are responsible for managing history, a simple linear solution could look like:
609    /// ```ignore
610    ///  let mut history = vec![];
611    ///
612    ///  loop {
613    ///      let prompt = "Create GPT-67, make no mistakes";
614    ///      let mut stream = agent.stream_chat(prompt, &history).await;
615    ///
616    ///      while let Some(msg) = stream.next().await {
617    ///         match msg {
618    ///              Ok(MultiTurnStreamItem::FinalResponse(fin)) => {
619    ///                  history.extend_from_slice(fin.history().unwrap_or_default());
620    ///                  break;
621    ///             }
622    ///             Ok(_other) => { /* Do something with this chunk */ }
623    ///             Err(e) => return Err(e.into()),
624    ///         }
625    ///     }
626    /// }
627    /// ```
628    fn stream_chat<I, T>(
629        &self,
630        prompt: impl Into<Message> + WasmCompatSend,
631        chat_history: I,
632    ) -> StreamingPromptRequest<M, Self::Hook>
633    where
634        I: IntoIterator<Item = T> + WasmCompatSend,
635        T: Into<Message>;
636}
637
638/// Trait for low-level streaming completion interface
639pub trait StreamingCompletion<M: CompletionModel> {
640    /// Generate a streaming completion from a request
641    fn stream_completion<I, T>(
642        &self,
643        prompt: impl Into<Message> + WasmCompatSend,
644        chat_history: I,
645    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
646    where
647        I: IntoIterator<Item = T> + WasmCompatSend,
648        T: Into<Message>;
649}
650
651/// A helper function to stream a completion request to stdout.
652/// Tool call deltas are ignored as tool calls are generally much easier to handle when received in their entirety rather than using deltas.
653pub async fn stream_to_stdout<M>(
654    agent: &'static Agent<M>,
655    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
656) -> Result<(), std::io::Error>
657where
658    M: CompletionModel,
659{
660    let mut is_reasoning = false;
661    print!("Response: ");
662    while let Some(chunk) = stream.next().await {
663        match chunk {
664            Ok(StreamedAssistantContent::Text(text)) => {
665                if is_reasoning {
666                    is_reasoning = false;
667                    println!("\n---\n");
668                }
669                print!("{}", text.text);
670                std::io::Write::flush(&mut std::io::stdout())?;
671            }
672            Ok(StreamedAssistantContent::ToolCall {
673                tool_call,
674                internal_call_id: _,
675            }) => {
676                let res = agent
677                    .tool_server_handle
678                    .call_tool(
679                        &tool_call.function.name,
680                        &tool_call.function.arguments.to_string(),
681                    )
682                    .await
683                    .map_err(|x| std::io::Error::other(x.to_string()))?;
684                println!("\nResult: {res}");
685            }
686            Ok(StreamedAssistantContent::Final(res)) => {
687                if let Ok(json_res) = serde_json::to_string_pretty(&res) {
688                    println!();
689                    tracing::info!("Final result: {json_res}");
690                }
691            }
692            Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
693                if !is_reasoning {
694                    is_reasoning = true;
695                    println!();
696                    println!("Thinking: ");
697                }
698                let reasoning = reasoning.display_text();
699
700                print!("{reasoning}");
701                std::io::Write::flush(&mut std::io::stdout())?;
702            }
703            Err(e) => {
704                if e.to_string().contains("aborted") {
705                    println!("\nStream cancelled.");
706                    break;
707                }
708                eprintln!("Error: {e}");
709                break;
710            }
711            _ => {}
712        }
713    }
714
715    println!(); // New line after streaming completes
716
717    Ok(())
718}
719
720// Test module
721#[cfg(test)]
722mod tests {
723    use std::time::Duration;
724
725    use super::*;
726    use crate::test_utils::MockResponse;
727    use async_stream::stream;
728    use tokio::time::sleep;
729
730    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
731    fn to_stream_result(
732        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
733        + Send
734        + 'static,
735    ) -> StreamingResult<MockResponse> {
736        Box::pin(stream)
737    }
738
739    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
740    fn to_stream_result(
741        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
742        + 'static,
743    ) -> StreamingResult<MockResponse> {
744        Box::pin(stream)
745    }
746
747    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
748        let stream = stream! {
749            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
750            sleep(Duration::from_millis(100)).await;
751            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
752            sleep(Duration::from_millis(100)).await;
753            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
754            sleep(Duration::from_millis(100)).await;
755            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(15)));
756        };
757
758        StreamingCompletionResponse::stream(to_stream_result(stream))
759    }
760
761    fn create_reasoning_stream() -> StreamingCompletionResponse<MockResponse> {
762        let stream = stream! {
763            yield Ok(RawStreamingChoice::Reasoning {
764                id: Some("rs_1".to_string()),
765                content: ReasoningContent::Text {
766                    text: "step one".to_string(),
767                    signature: Some("sig_1".to_string()),
768                },
769            });
770            yield Ok(RawStreamingChoice::Message("final answer".to_string()));
771            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(5)));
772        };
773
774        StreamingCompletionResponse::stream(to_stream_result(stream))
775    }
776
777    fn create_reasoning_only_stream() -> StreamingCompletionResponse<MockResponse> {
778        let stream = stream! {
779            yield Ok(RawStreamingChoice::Reasoning {
780                id: Some("rs_only".to_string()),
781                content: ReasoningContent::Summary("hidden summary".to_string()),
782            });
783            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(2)));
784        };
785
786        StreamingCompletionResponse::stream(to_stream_result(stream))
787    }
788
789    fn create_interleaved_stream() -> StreamingCompletionResponse<MockResponse> {
790        let stream = stream! {
791            yield Ok(RawStreamingChoice::Reasoning {
792                id: Some("rs_interleaved".to_string()),
793                content: ReasoningContent::Text {
794                    text: "chain-of-thought".to_string(),
795                    signature: None,
796                },
797            });
798            yield Ok(RawStreamingChoice::Message("final-text".to_string()));
799            yield Ok(RawStreamingChoice::ToolCall(
800                RawStreamingToolCall::new(
801                    "tool_1".to_string(),
802                    "mock_tool".to_string(),
803                    serde_json::json!({"arg": 1}),
804                ),
805            ));
806            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(3)));
807        };
808
809        StreamingCompletionResponse::stream(to_stream_result(stream))
810    }
811
812    fn create_text_tool_text_stream() -> StreamingCompletionResponse<MockResponse> {
813        let stream = stream! {
814            yield Ok(RawStreamingChoice::Message("first".to_string()));
815            yield Ok(RawStreamingChoice::ToolCall(
816                RawStreamingToolCall::new(
817                    "tool_split".to_string(),
818                    "mock_tool".to_string(),
819                    serde_json::json!({"arg": "x"}),
820                ),
821            ));
822            yield Ok(RawStreamingChoice::Message("second".to_string()));
823            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(3)));
824        };
825
826        StreamingCompletionResponse::stream(to_stream_result(stream))
827    }
828
829    fn create_text_metadata_stream() -> StreamingCompletionResponse<MockResponse> {
830        let stream = stream! {
831            yield Ok(RawStreamingChoice::TextStart {
832                additional_params: None,
833            });
834            yield Ok(RawStreamingChoice::Message("first".to_string()));
835            yield Ok(RawStreamingChoice::TextAdditionalParams(serde_json::json!({
836                "citations": [{
837                    "type": "char_location",
838                    "cited_text": "First citation.",
839                    "document_index": 0,
840                    "start_char_index": 0,
841                    "end_char_index": 15
842                }]
843            })));
844            yield Ok(RawStreamingChoice::TextAdditionalParams(serde_json::json!({
845                "citations": [{
846                    "type": "char_location",
847                    "cited_text": "Second citation.",
848                    "document_index": 0,
849                    "start_char_index": 16,
850                    "end_char_index": 32
851                }]
852            })));
853            yield Ok(RawStreamingChoice::TextStart {
854                additional_params: Some(serde_json::json!({
855                    "block": 2
856                })),
857            });
858            yield Ok(RawStreamingChoice::Message("second".to_string()));
859            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(3)));
860        };
861
862        StreamingCompletionResponse::stream(to_stream_result(stream))
863    }
864
865    #[tokio::test]
866    async fn test_stream_cancellation() {
867        let mut stream = create_mock_stream();
868
869        println!("Response: ");
870        let mut chunk_count = 0;
871        while let Some(chunk) = stream.next().await {
872            match chunk {
873                Ok(StreamedAssistantContent::Text(text)) => {
874                    print!("{}", text.text);
875                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
876                    chunk_count += 1;
877                }
878                Ok(StreamedAssistantContent::ToolCall {
879                    tool_call,
880                    internal_call_id,
881                }) => {
882                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
883                    chunk_count += 1;
884                }
885                Ok(StreamedAssistantContent::ToolCallDelta {
886                    id,
887                    internal_call_id,
888                    content,
889                }) => {
890                    println!(
891                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
892                    );
893                    chunk_count += 1;
894                }
895                Ok(StreamedAssistantContent::Final(res)) => {
896                    println!("\nFinal response: {res:?}");
897                }
898                Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
899                    let reasoning = reasoning.display_text();
900                    print!("{reasoning}");
901                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
902                }
903                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
904                    println!("Reasoning delta: {reasoning}");
905                    chunk_count += 1;
906                }
907                Err(e) => {
908                    eprintln!("Error: {e:?}");
909                    break;
910                }
911            }
912
913            if chunk_count >= 2 {
914                println!("\nCancelling stream...");
915                stream.cancel();
916                println!("Stream cancelled.");
917                break;
918            }
919        }
920
921        let next_chunk = stream.next().await;
922        assert!(
923            next_chunk.is_none(),
924            "Expected no further chunks after cancellation, got {next_chunk:?}"
925        );
926    }
927
928    #[tokio::test]
929    async fn test_stream_pause_resume() {
930        let stream = create_mock_stream();
931
932        // Test pause
933        stream.pause();
934        assert!(stream.is_paused());
935
936        // Test resume
937        stream.resume();
938        assert!(!stream.is_paused());
939    }
940
941    #[tokio::test]
942    async fn test_stream_aggregates_reasoning_content() {
943        let mut stream = create_reasoning_stream();
944        while stream.next().await.is_some() {}
945
946        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
947
948        assert!(choice_items.iter().any(|item| matches!(
949            item,
950            AssistantContent::Reasoning(Reasoning {
951                id: Some(id),
952                content
953            }) if id == "rs_1"
954                && matches!(
955                    content.first(),
956                    Some(ReasoningContent::Text {
957                        text,
958                        signature: Some(signature)
959                    }) if text == "step one" && signature == "sig_1"
960                )
961        )));
962    }
963
964    #[tokio::test]
965    async fn test_stream_reasoning_only_does_not_inject_empty_text() {
966        let mut stream = create_reasoning_only_stream();
967        while stream.next().await.is_some() {}
968
969        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
970        assert_eq!(choice_items.len(), 1);
971        assert!(matches!(
972            choice_items.first(),
973            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_only"
974        ));
975    }
976
977    #[tokio::test]
978    async fn test_stream_aggregates_assistant_items_in_arrival_order() {
979        let mut stream = create_interleaved_stream();
980        while stream.next().await.is_some() {}
981
982        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
983        assert_eq!(choice_items.len(), 3);
984        assert!(matches!(
985            choice_items.first(),
986            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_interleaved"
987        ));
988        assert!(matches!(
989            choice_items.get(1),
990            Some(AssistantContent::Text(Text { text, .. })) if text == "final-text"
991        ));
992        assert!(matches!(
993            choice_items.get(2),
994            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_1"
995        ));
996    }
997
998    #[tokio::test]
999    async fn test_stream_keeps_non_contiguous_text_chunks_split_by_tool_call() {
1000        let mut stream = create_text_tool_text_stream();
1001        while stream.next().await.is_some() {}
1002
1003        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
1004        assert_eq!(choice_items.len(), 3);
1005        assert!(matches!(
1006            choice_items.first(),
1007            Some(AssistantContent::Text(Text { text, .. })) if text == "first"
1008        ));
1009        assert!(matches!(
1010            choice_items.get(1),
1011            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_split"
1012        ));
1013        assert!(matches!(
1014            choice_items.get(2),
1015            Some(AssistantContent::Text(Text { text, .. })) if text == "second"
1016        ));
1017    }
1018
1019    #[tokio::test]
1020    async fn test_stream_preserves_text_additional_params() {
1021        let mut stream = create_text_metadata_stream();
1022        while stream.next().await.is_some() {}
1023
1024        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
1025        assert_eq!(choice_items.len(), 2);
1026
1027        let Some(AssistantContent::Text(Text {
1028            text,
1029            additional_params: Some(additional_params),
1030        })) = choice_items.first()
1031        else {
1032            panic!("expected first text item with metadata");
1033        };
1034        assert_eq!(text, "first");
1035        assert_eq!(
1036            additional_params["citations"]
1037                .as_array()
1038                .expect("citations should be an array")
1039                .len(),
1040            2
1041        );
1042
1043        let Some(AssistantContent::Text(Text {
1044            text,
1045            additional_params: Some(additional_params),
1046        })) = choice_items.get(1)
1047        else {
1048            panic!("expected second text item with metadata");
1049        };
1050        assert_eq!(text, "second");
1051        assert_eq!(additional_params["block"], 2);
1052    }
1053}
1054
/// An item produced while streaming an assistant response: a text delta, a
/// complete or partial tool call, a complete or partial reasoning block, or
/// the final provider response of type `R`.
///
/// The enum is `#[serde(untagged)]`: no variant tag is written when
/// serializing, and serde tries the variants in declaration order when
/// deserializing. Reordering the variants can therefore change which one a
/// given payload deserializes into.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// Text delta emitted by the assistant.
    Text(Text),
    /// Complete tool call emitted by the assistant.
    ToolCall {
        /// The complete tool call.
        tool_call: ToolCall,
        /// Rig-generated unique identifier for this tool call.
        /// Use this to correlate with ToolCallDelta events.
        internal_call_id: String,
    },
    /// Partial tool call data emitted by the assistant.
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        /// The partial tool call payload carried by this delta.
        content: ToolCallDeltaContent,
    },
    /// Complete reasoning block emitted by the assistant.
    Reasoning(Reasoning),
    /// Partial reasoning text emitted by the assistant.
    ReasoningDelta {
        /// Provider-supplied reasoning block ID, when present.
        id: Option<String>,
        /// Partial reasoning text.
        reasoning: String,
    },
    /// Final provider response object, if yielded by the provider stream.
    Final(R),
}
1088
1089impl<R> StreamedAssistantContent<R>
1090where
1091    R: Clone + Unpin,
1092{
1093    /// Create a text stream item.
1094    pub fn text(text: &str) -> Self {
1095        Self::Text(Text::new(text.to_string()))
1096    }
1097
1098    /// Create a final response stream item.
1099    pub fn final_response(res: R) -> Self {
1100        Self::Final(res)
1101    }
1102}
1103
/// Streamed user content. This content is primarily used to represent tool
/// results from tool calls made during a multi-turn/step agent prompt.
///
/// The enum is `#[serde(untagged)]`, so the variant name is not written when
/// serializing; only the variant's fields appear in the output.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// Tool result emitted during a multi-turn streaming agent loop.
    ToolResult {
        /// The result produced for the tool call.
        tool_result: ToolResult,
        /// Rig-generated unique identifier for the tool call this result
        /// belongs to. Use this to correlate with the originating
        /// [`StreamedAssistantContent::ToolCall::internal_call_id`].
        internal_call_id: String,
    },
}
1117
1118impl StreamedUserContent {
1119    /// Create a streamed tool result correlated to an internal tool call ID.
1120    pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
1121        Self::ToolResult {
1122            tool_result,
1123            internal_call_id,
1124        }
1125    }
1126}