// rig_core/streaming.rs
//! This module provides functionality for working with streaming completion models.
//! It provides traits and types for generating streaming completion requests and
//! handling streaming completion responses.
//!
//! The main traits defined in this module are:
//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
//!

11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::hooks::PromptHook;
14use crate::agent::prompt_request::streaming::StreamingPromptRequest;
15use crate::completion::{
16    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
17    Message, Usage,
18};
19use crate::message::{
20    AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult,
21};
22use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
23use futures::stream::{AbortHandle, Abortable};
24use futures::{Stream, StreamExt};
25use serde::{Deserialize, Serialize};
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::atomic::AtomicBool;
29use std::task::{Context, Poll};
30use tokio::sync::watch;
31
32/// Control for pausing and resuming a streaming response
33pub struct PauseControl {
34    pub(crate) paused_tx: watch::Sender<bool>,
35    pub(crate) paused_rx: watch::Receiver<bool>,
36}
37
38impl PauseControl {
39    /// Create a pause controller in the running state.
40    pub fn new() -> Self {
41        let (paused_tx, paused_rx) = watch::channel(false);
42        Self {
43            paused_tx,
44            paused_rx,
45        }
46    }
47
48    /// Pause polling of the public stream until [`PauseControl::resume`] is called.
49    pub fn pause(&self) {
50        let _ = self.paused_tx.send(true);
51    }
52
53    /// Resume polling after a pause.
54    pub fn resume(&self) {
55        let _ = self.paused_tx.send(false);
56    }
57
58    /// Returns whether the stream is currently paused.
59    pub fn is_paused(&self) -> bool {
60        *self.paused_rx.borrow()
61    }
62}
63
64impl Default for PauseControl {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
/// The content of a tool call delta - either the tool name or argument data.
///
/// Providers stream tool calls incrementally; the name typically arrives
/// first, followed by argument JSON fragments (see
/// [`RawStreamingChoice::ToolCallDelta`]).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    /// Tool/function name emitted by the provider.
    Name(String),
    /// Partial JSON argument data emitted by the provider.
    Delta(String),
}
78
/// Enum representing a streaming chunk from the model.
///
/// This is the raw, provider-level chunk type; it is aggregated by
/// [`StreamingCompletionResponse`] into `StreamedAssistantContent` items
/// and a final unified choice.
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response (in its entirety)
    ToolCall(RawStreamingToolCall),
    /// A tool call partial/delta
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        /// Either the tool name or a fragment of the JSON arguments.
        content: ToolCallDeltaContent,
    },
    /// A reasoning (in its entirety)
    Reasoning {
        /// Provider-supplied reasoning block ID, when present.
        id: Option<String>,
        /// Complete reasoning content block.
        content: ReasoningContent,
    },
    /// A reasoning partial/delta
    ReasoningDelta {
        /// Provider-supplied reasoning block ID, when present.
        id: Option<String>,
        /// Partial reasoning text.
        reasoning: String,
    },

    /// The final response object, must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`
    FinalResponse(R),

    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Captured silently into `StreamingCompletionResponse::message_id`.
    MessageId(String),
}
121
/// Describes a streaming tool call response (in its entirety).
///
/// Converted into a [`ToolCall`] via the `From` impl below; note that
/// `internal_call_id` is Rig-internal and is not carried into [`ToolCall`].
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Provider-supplied tool call ID.
    pub id: String,
    /// Rig-generated unique identifier for this tool call.
    pub internal_call_id: String,
    /// Provider-specific call ID used by some APIs for tool result correlation.
    pub call_id: Option<String>,
    /// Tool/function name.
    pub name: String,
    /// Parsed tool arguments.
    pub arguments: serde_json::Value,
    /// Optional provider signature associated with the tool call.
    pub signature: Option<String>,
    /// Additional provider-specific tool call metadata.
    pub additional_params: Option<serde_json::Value>,
}
140
141impl RawStreamingToolCall {
142    /// Create an empty tool call accumulator for provider streaming parsers.
143    pub fn empty() -> Self {
144        Self {
145            id: String::new(),
146            internal_call_id: nanoid::nanoid!(),
147            call_id: None,
148            name: String::new(),
149            arguments: serde_json::Value::Null,
150            signature: None,
151            additional_params: None,
152        }
153    }
154
155    /// Create a complete tool call with a generated internal call ID.
156    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
157        Self {
158            id,
159            internal_call_id: nanoid::nanoid!(),
160            call_id: None,
161            name,
162            arguments,
163            signature: None,
164            additional_params: None,
165        }
166    }
167
168    /// Override the generated internal call ID.
169    pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
170        self.internal_call_id = internal_call_id;
171        self
172    }
173
174    /// Attach a provider-specific call ID.
175    pub fn with_call_id(mut self, call_id: String) -> Self {
176        self.call_id = Some(call_id);
177        self
178    }
179
180    /// Attach or clear a provider signature.
181    pub fn with_signature(mut self, signature: Option<String>) -> Self {
182        self.signature = signature;
183        self
184    }
185
186    /// Attach provider-specific metadata.
187    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
188        self.additional_params = additional_params;
189        self
190    }
191}
192
193impl From<RawStreamingToolCall> for ToolCall {
194    fn from(tool_call: RawStreamingToolCall) -> Self {
195        ToolCall {
196            id: tool_call.id,
197            call_id: tool_call.call_id,
198            function: ToolFunction {
199                name: tool_call.name,
200                arguments: tool_call.arguments,
201            },
202            signature: tool_call.signature,
203            additional_params: tool_call.additional_params,
204        }
205    }
206}
207
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
/// Provider stream of raw completion chunks on native targets.
///
/// `Send` is required here so the stream can cross thread boundaries.
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
/// Provider stream of raw completion chunks on wasm targets.
///
/// Same shape as the native alias minus the `Send` bound, which wasm's
/// single-threaded execution model does not need.
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
217
/// The response from a streaming completion request;
/// message and response are populated at the end of the
/// `inner` stream.
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    // Provider stream wrapped so `cancel()` can abort it mid-flight.
    pub(crate) inner: Abortable<StreamingResult<R>>,
    pub(crate) abort_handle: AbortHandle,
    pub(crate) pause_control: PauseControl,
    // Accumulates text/reasoning/tool-call items in arrival order; folded
    // into `choice` when the inner stream ends.
    assistant_items: Vec<AssistantContent>,
    // Index into `assistant_items` of the text item currently being grown,
    // or `None` if the next text chunk should start a new item.
    text_item_index: Option<usize>,
    // Same as `text_item_index`, but for streamed reasoning deltas.
    reasoning_item_index: Option<usize>,
    /// The final aggregated message from the stream
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream, may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    // Guards against yielding `FinalResponse` to the consumer more than once.
    pub final_response_yielded: AtomicBool,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    pub message_id: Option<String>,
}
241
242impl<R> StreamingCompletionResponse<R>
243where
244    R: Clone + Unpin + GetTokenUsage,
245{
246    /// Wrap a provider stream and initialize aggregation state.
247    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
248        let (abort_handle, abort_registration) = AbortHandle::new_pair();
249        let abortable_stream = Abortable::new(inner, abort_registration);
250        let pause_control = PauseControl::new();
251        Self {
252            inner: abortable_stream,
253            abort_handle,
254            pause_control,
255            assistant_items: vec![],
256            text_item_index: None,
257            reasoning_item_index: None,
258            choice: OneOrMany::one(AssistantContent::text("")),
259            response: None,
260            final_response_yielded: AtomicBool::new(false),
261            message_id: None,
262        }
263    }
264
265    /// Cancel the stream. Cancellation is surfaced as normal stream termination.
266    pub fn cancel(&self) {
267        self.abort_handle.abort();
268    }
269
270    /// Pause stream polling.
271    pub fn pause(&self) {
272        self.pause_control.pause();
273    }
274
275    /// Resume stream polling after a pause.
276    pub fn resume(&self) {
277        self.pause_control.resume();
278    }
279
280    /// Returns whether the stream is currently paused.
281    pub fn is_paused(&self) -> bool {
282        self.pause_control.is_paused()
283    }
284
285    fn append_text_chunk(&mut self, text: &str) {
286        if let Some(index) = self.text_item_index
287            && let Some(AssistantContent::Text(existing_text)) = self.assistant_items.get_mut(index)
288        {
289            existing_text.text.push_str(text);
290            return;
291        }
292
293        self.assistant_items
294            .push(AssistantContent::text(text.to_owned()));
295        self.text_item_index = Some(self.assistant_items.len() - 1);
296    }
297
298    /// Accumulate streaming reasoning delta text into assistant_items.
299    /// Providers that only emit ReasoningDelta (not full Reasoning blocks)
300    /// need this so the aggregated response includes reasoning content.
301    fn append_reasoning_chunk(&mut self, id: &Option<String>, text: &str) {
302        if let Some(index) = self.reasoning_item_index
303            && let Some(AssistantContent::Reasoning(existing)) = self.assistant_items.get_mut(index)
304            && let Some(ReasoningContent::Text {
305                text: existing_text,
306                ..
307            }) = existing.content.last_mut()
308        {
309            existing_text.push_str(text);
310            return;
311        }
312
313        self.assistant_items
314            .push(AssistantContent::Reasoning(Reasoning {
315                id: id.clone(),
316                content: vec![ReasoningContent::Text {
317                    text: text.to_string(),
318                    signature: None,
319                }],
320            }));
321        self.reasoning_item_index = Some(self.assistant_items.len() - 1);
322    }
323}
324
/// Collapse a finished streaming response into a non-streaming
/// [`CompletionResponse`], carrying over the aggregated choice, the raw
/// provider response (if it was yielded), and the message ID.
impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
where
    R: Clone + Unpin + GetTokenUsage,
{
    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
        CompletionResponse {
            choice: value.choice,
            // Usage is not tracked in streaming responses.
            // NOTE(review): `R: GetTokenUsage` suggests usage could be read
            // from `value.response` here instead of defaulting — confirm the
            // trait's API before changing.
            usage: Usage::new(),
            raw_response: value.response,
            message_id: value.message_id,
        }
    }
}
338
/// The public stream adapter: translates raw provider chunks into
/// `StreamedAssistantContent` items while aggregating them into
/// `assistant_items`, so that `choice` holds the unified message once the
/// inner stream is exhausted.
impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        // NOTE(review): while paused we wake ourselves immediately and
        // return Pending, i.e. the task busy-spins on the executor until
        // resume() is called. Awaiting a change on
        // `pause_control.paused_rx` would avoid the spin — confirm before
        // changing, as it alters wakeup behavior.
        if stream.is_paused() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // This is run at the end of the inner stream to collect all tokens into
                // a single unified `Message`.
                if stream.assistant_items.is_empty() {
                    stream.assistant_items.push(AssistantContent::text(""));
                }

                if let Some(choice) =
                    OneOrMany::from_iter_optional(std::mem::take(&mut stream.assistant_items))
                {
                    stream.choice = choice;
                }

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                // NOTE(review): abort detection relies on the provider error's
                // string form containing "aborted" — brittle if providers
                // change their message text.
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); // Treat cancellation as stream termination
                }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    // New text ends any in-progress reasoning accumulation,
                    // but continues the current text item (index untouched).
                    stream.reasoning_item_index = None;
                    stream.append_text_chunk(&text);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }))),
                RawStreamingChoice::Reasoning { id, content } => {
                    let reasoning = Reasoning {
                        id,
                        content: vec![content],
                    };
                    stream.text_item_index = None;
                    // Full reasoning block supersedes any delta accumulation
                    stream.reasoning_item_index = None;
                    stream
                        .assistant_items
                        .push(AssistantContent::Reasoning(reasoning.clone()));
                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(reasoning))))
                }
                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                    stream.text_item_index = None;
                    stream.append_reasoning_chunk(&id, &reasoning);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
                        id,
                        reasoning,
                    })))
                }
                RawStreamingChoice::ToolCall(raw_tool_call) => {
                    // Keep the internal call ID before the conversion drops it.
                    let internal_call_id = raw_tool_call.internal_call_id.clone();
                    let tool_call: ToolCall = raw_tool_call.into();
                    stream.text_item_index = None;
                    stream.reasoning_item_index = None;
                    stream
                        .assistant_items
                        .push(AssistantContent::ToolCall(tool_call.clone()));
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
                        tool_call,
                        internal_call_id,
                    })))
                }
                RawStreamingChoice::FinalResponse(response) => {
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        // Duplicate final response: skip it and poll again
                        // for the next real item.
                        stream.poll_next_unpin(cx)
                    } else {
                        // Set the final response field and return the next item in the stream
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
                RawStreamingChoice::MessageId(id) => {
                    // Captured silently: record the ID and poll for the next
                    // consumer-visible item.
                    stream.message_id = Some(id);
                    stream.poll_next_unpin(cx)
                }
            },
        }
    }
}
450
/// Trait for high-level streaming prompt interface.
///
/// This trait provides a simple interface for streaming prompts to a completion model.
/// Implementations can optionally support prompt hooks for observing and controlling
/// the agent's execution lifecycle.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// The hook type used by this streaming prompt implementation.
    ///
    /// If your implementation does not need prompt hooks, use `()` as the hook type:
    ///
    /// ```ignore
    /// impl<M, R> StreamingPrompt<M, R> for MyType<M>
    /// where
    ///     M: CompletionModel + 'static,
    ///     // ... other bounds ...
    /// {
    ///     type Hook = ();
    ///
    ///     fn stream_prompt(&self, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
    ///         // ...
    ///     }
    /// }
    /// ```
    type Hook: PromptHook<M>;

    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, Self::Hook>;
}
487
/// Trait for high-level streaming chat interface with conversation history.
///
/// This trait provides an interface for streaming chat completions with support
/// for maintaining conversation history. Implementations can optionally support
/// prompt hooks for observing and controlling the agent's execution lifecycle.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// The hook type used by this streaming chat implementation.
    ///
    /// If your implementation does not need prompt hooks, use `()` as the hook type:
    ///
    /// ```ignore
    /// impl<M, R> StreamingChat<M, R> for MyType<M>
    /// where
    ///     M: CompletionModel + 'static,
    ///     // ... other bounds ...
    /// {
    ///     type Hook = ();
    ///
    ///     fn stream_chat(
    ///         &self,
    ///         prompt: impl Into<Message>,
    ///         chat_history: Vec<Message>,
    ///     ) -> StreamingPromptRequest<M, ()> {
    ///         // ...
    ///     }
    /// }
    /// ```
    type Hook: PromptHook<M>;

    /// Stream a chat with history to the model.
    ///
    /// The messages returned by the model can be accessed via `FinalResponse::history()`
    ///
    /// You are responsible for managing history, a simple linear solution could look like:
    /// ```ignore
    ///  let mut history = vec![];
    ///
    ///  loop {
    ///      let prompt = "Create GPT-67, make no mistakes";
    ///      let mut stream = agent.stream_chat(prompt, &history).await;
    ///
    ///      while let Some(msg) = stream.next().await {
    ///         match msg {
    ///              Ok(MultiTurnStreamItem::FinalResponse(fin)) => {
    ///                  history.extend_from_slice(fin.history().unwrap_or_default());
    ///                  break;
    ///             }
    ///             Ok(_other) => { /* Do something with this chunk */ }
    ///             Err(e) => return Err(e.into()),
    ///         }
    ///     }
    /// }
    /// ```
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, Self::Hook>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
555
/// Trait for low-level streaming completion interface.
///
/// Unlike [`StreamingPrompt`]/[`StreamingChat`], this returns a
/// [`CompletionRequestBuilder`] so callers can further customize the request
/// before sending it.
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request
    fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
568
569/// A helper function to stream a completion request to stdout.
570/// Tool call deltas are ignored as tool calls are generally much easier to handle when received in their entirety rather than using deltas.
571pub async fn stream_to_stdout<M>(
572    agent: &'static Agent<M>,
573    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
574) -> Result<(), std::io::Error>
575where
576    M: CompletionModel,
577{
578    let mut is_reasoning = false;
579    print!("Response: ");
580    while let Some(chunk) = stream.next().await {
581        match chunk {
582            Ok(StreamedAssistantContent::Text(text)) => {
583                if is_reasoning {
584                    is_reasoning = false;
585                    println!("\n---\n");
586                }
587                print!("{}", text.text);
588                std::io::Write::flush(&mut std::io::stdout())?;
589            }
590            Ok(StreamedAssistantContent::ToolCall {
591                tool_call,
592                internal_call_id: _,
593            }) => {
594                let res = agent
595                    .tool_server_handle
596                    .call_tool(
597                        &tool_call.function.name,
598                        &tool_call.function.arguments.to_string(),
599                    )
600                    .await
601                    .map_err(|x| std::io::Error::other(x.to_string()))?;
602                println!("\nResult: {res}");
603            }
604            Ok(StreamedAssistantContent::Final(res)) => {
605                if let Ok(json_res) = serde_json::to_string_pretty(&res) {
606                    println!();
607                    tracing::info!("Final result: {json_res}");
608                }
609            }
610            Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
611                if !is_reasoning {
612                    is_reasoning = true;
613                    println!();
614                    println!("Thinking: ");
615                }
616                let reasoning = reasoning.display_text();
617
618                print!("{reasoning}");
619                std::io::Write::flush(&mut std::io::stdout())?;
620            }
621            Err(e) => {
622                if e.to_string().contains("aborted") {
623                    println!("\nStream cancelled.");
624                    break;
625                }
626                eprintln!("Error: {e}");
627                break;
628            }
629            _ => {}
630        }
631    }
632
633    println!(); // New line after streaming completes
634
635    Ok(())
636}
637
638// Test module
639#[cfg(test)]
640mod tests {
641    use std::time::Duration;
642
643    use super::*;
644    use crate::test_utils::MockResponse;
645    use async_stream::stream;
646    use tokio::time::sleep;
647
    // Box a test stream into the platform-appropriate `StreamingResult`
    // alias. Two cfg variants mirror the alias itself: the native one
    // requires `Send`, the wasm one does not.
    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + Send
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }
664
    // Build a stream of three text chunks (spaced 100ms apart, so
    // cancellation tests have time to interrupt) followed by a final
    // response carrying 15 total tokens.
    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(15)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }
678
    // Build a stream with one complete (signed) reasoning block followed by
    // a text message, to verify reasoning content survives aggregation.
    fn create_reasoning_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_1".to_string()),
                content: ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig_1".to_string()),
                },
            });
            yield Ok(RawStreamingChoice::Message("final answer".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(5)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }
694
    // Build a stream that yields ONLY reasoning (no text), to verify the
    // aggregator does not pad the choice with an empty text item.
    fn create_reasoning_only_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_only".to_string()),
                content: ReasoningContent::Summary("hidden summary".to_string()),
            });
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(2)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }
706
    // Build a stream interleaving reasoning, text, and a tool call, to
    // verify aggregated items preserve arrival order.
    fn create_interleaved_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_interleaved".to_string()),
                content: ReasoningContent::Text {
                    text: "chain-of-thought".to_string(),
                    signature: None,
                },
            });
            yield Ok(RawStreamingChoice::Message("final-text".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_1".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": 1}),
                ),
            ));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(3)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }
729
    // Build a stream where a tool call splits two text messages, to verify
    // text after a tool call starts a NEW text item rather than merging
    // into the first one.
    fn create_text_tool_text_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("first".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_split".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": "x"}),
                ),
            ));
            yield Ok(RawStreamingChoice::Message("second".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse::with_total_tokens(3)));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }
746
    // Consume a few chunks, cancel mid-stream, and assert the stream then
    // terminates cleanly (yields `None`) instead of erroring or continuing.
    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall {
                    tool_call,
                    internal_call_id,
                }) => {
                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }) => {
                    println!(
                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
                    );
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                    let reasoning = reasoning.display_text();
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            // Cancel after the second counted chunk, while the mock stream
            // (100ms between chunks) still has items pending.
            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        // Cancellation must surface as normal termination, not an error.
        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }
809
    // Verify the pause flag toggles via pause()/resume() and is readable
    // through is_paused(). (Polling behavior while paused is exercised by
    // the Stream impl, not here.)
    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }
822
    // Drain a reasoning+text stream and assert the aggregated `choice`
    // retains the full reasoning block including its ID and signature.
    #[tokio::test]
    async fn test_stream_aggregates_reasoning_content() {
        let mut stream = create_reasoning_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();

        assert!(choice_items.iter().any(|item| matches!(
            item,
            AssistantContent::Reasoning(Reasoning {
                id: Some(id),
                content
            }) if id == "rs_1"
                && matches!(
                    content.first(),
                    Some(ReasoningContent::Text {
                        text,
                        signature: Some(signature)
                    }) if text == "step one" && signature == "sig_1"
                )
        )));
    }
845
846    #[tokio::test]
847    async fn test_stream_reasoning_only_does_not_inject_empty_text() {
848        let mut stream = create_reasoning_only_stream();
849        while stream.next().await.is_some() {}
850
851        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
852        assert_eq!(choice_items.len(), 1);
853        assert!(matches!(
854            choice_items.first(),
855            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_only"
856        ));
857    }
858
859    #[tokio::test]
860    async fn test_stream_aggregates_assistant_items_in_arrival_order() {
861        let mut stream = create_interleaved_stream();
862        while stream.next().await.is_some() {}
863
864        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
865        assert_eq!(choice_items.len(), 3);
866        assert!(matches!(
867            choice_items.first(),
868            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_interleaved"
869        ));
870        assert!(matches!(
871            choice_items.get(1),
872            Some(AssistantContent::Text(Text { text })) if text == "final-text"
873        ));
874        assert!(matches!(
875            choice_items.get(2),
876            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_1"
877        ));
878    }
879
880    #[tokio::test]
881    async fn test_stream_keeps_non_contiguous_text_chunks_split_by_tool_call() {
882        let mut stream = create_text_tool_text_stream();
883        while stream.next().await.is_some() {}
884
885        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
886        assert_eq!(choice_items.len(), 3);
887        assert!(matches!(
888            choice_items.first(),
889            Some(AssistantContent::Text(Text { text })) if text == "first"
890        ));
891        assert!(matches!(
892            choice_items.get(1),
893            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_split"
894        ));
895        assert!(matches!(
896            choice_items.get(2),
897            Some(AssistantContent::Text(Text { text })) if text == "second"
898        ));
899    }
900}
901
/// Describes responses from a streamed provider response which is either text, a tool call or a final usage response.
///
/// `R` is the provider-specific final response type carried by the
/// [`StreamedAssistantContent::Final`] variant.
///
/// Serialized with `#[serde(untagged)]`: variants are distinguished purely by
/// their field shape, so no variant-name tag appears in the wire format.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// Text delta emitted by the assistant.
    Text(Text),
    /// Complete tool call emitted by the assistant.
    ToolCall {
        tool_call: ToolCall,
        /// Rig-generated unique identifier for this tool call.
        /// Use this to correlate with ToolCallDelta events.
        internal_call_id: String,
    },
    /// Partial tool call data emitted by the assistant.
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        content: ToolCallDeltaContent,
    },
    /// Complete reasoning block emitted by the assistant.
    Reasoning(Reasoning),
    /// Partial reasoning text emitted by the assistant.
    ReasoningDelta {
        /// Provider-supplied reasoning block ID, when present.
        id: Option<String>,
        /// Partial reasoning text.
        reasoning: String,
    },
    /// Final provider response object, if yielded by the provider stream.
    Final(R),
}
935
936impl<R> StreamedAssistantContent<R>
937where
938    R: Clone + Unpin,
939{
940    /// Create a text stream item.
941    pub fn text(text: &str) -> Self {
942        Self::Text(Text {
943            text: text.to_string(),
944        })
945    }
946
947    /// Create a final response stream item.
948    pub fn final_response(res: R) -> Self {
949        Self::Final(res)
950    }
951}
952
/// Streamed user content. This content is primarily used to represent tool results from tool calls made during a multi-turn/step agent prompt.
///
/// Serialized with `#[serde(untagged)]`, so the variant name does not appear
/// in the wire format.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// Tool result emitted during a multi-turn streaming agent loop.
    ToolResult {
        tool_result: ToolResult,
        /// Rig-generated unique identifier for the tool call this result
        /// belongs to. Use this to correlate with the originating
        /// [`StreamedAssistantContent::ToolCall::internal_call_id`].
        internal_call_id: String,
    },
}
966
967impl StreamedUserContent {
968    /// Create a streamed tool result correlated to an internal tool call ID.
969    pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
970        Self::ToolResult {
971            tool_result,
972            internal_call_id,
973        }
974    }
975}