
rig/streaming.rs

//! This module provides functionality for working with streaming completion models.
//! It defines traits and types for generating streaming completion requests and
//! handling streaming completion responses.
//!
//! The main traits defined in this module are:
//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
//!
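//! A sketch of the high-level flow with an [`Agent`] (the prompt text is
//! illustrative; see [`stream_to_stdout`] for a complete consumer and
//! [`StreamingChat::stream_chat`] for a history-managing loop):
//!
//! ```ignore
//! let mut stream = agent.stream_prompt("Tell me a joke").await;
//!
//! while let Some(chunk) = stream.next().await {
//!     match chunk {
//!         Ok(chunk) => { /* handle text, tool calls, reasoning, the final response */ }
//!         Err(e) => return Err(e.into()),
//!     }
//! }
//! ```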

use crate::OneOrMany;
use crate::agent::Agent;
use crate::agent::prompt_request::hooks::PromptHook;
use crate::agent::prompt_request::streaming::StreamingPromptRequest;
use crate::completion::{
    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
    Message, Usage,
};
use crate::message::{
    AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult,
};
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use futures::stream::{AbortHandle, Abortable};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::task::{Context, Poll};
use tokio::sync::watch;

/// Control for pausing and resuming a streaming response
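///
/// A minimal sketch of direct usage (in practice the control is driven through
/// [`StreamingCompletionResponse::pause`] and [`StreamingCompletionResponse::resume`]):
///
/// ```ignore
/// let control = PauseControl::new();
/// control.pause();
/// assert!(control.is_paused());
/// control.resume();
/// assert!(!control.is_paused());
/// ```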
pub struct PauseControl {
    pub(crate) paused_tx: watch::Sender<bool>,
    pub(crate) paused_rx: watch::Receiver<bool>,
}

impl PauseControl {
    pub fn new() -> Self {
        let (paused_tx, paused_rx) = watch::channel(false);
        Self {
            paused_tx,
            paused_rx,
        }
    }

    pub fn pause(&self) {
        let _ = self.paused_tx.send(true);
    }

    pub fn resume(&self) {
        let _ = self.paused_tx.send(false);
    }

    pub fn is_paused(&self) -> bool {
        *self.paused_rx.borrow()
    }
}

impl Default for PauseControl {
    fn default() -> Self {
        Self::new()
    }
}

/// The content of a tool call delta - either the tool name or argument data
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    Name(String),
    Delta(String),
}

/// Enum representing a streaming chunk from the model
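///
/// A provider implementation typically yields a sequence of these chunks and ends
/// with [`RawStreamingChoice::FinalResponse`]. A minimal sketch, using the
/// `async_stream` crate as this module's tests do (`MyResponse` is an illustrative
/// provider response type):
///
/// ```ignore
/// let stream = async_stream::stream! {
///     yield Ok(RawStreamingChoice::Message("Hello, ".to_string()));
///     yield Ok(RawStreamingChoice::Message("world!".to_string()));
///     yield Ok(RawStreamingChoice::FinalResponse(MyResponse::default()));
/// };
/// ```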
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response (in its entirety)
    ToolCall(RawStreamingToolCall),
    /// A tool call partial/delta
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        content: ToolCallDeltaContent,
    },
    /// A reasoning block (in its entirety)
    Reasoning {
        id: Option<String>,
        content: ReasoningContent,
    },
    /// A partial reasoning delta
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },

    /// The final response object. This must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`.
    FinalResponse(R),

    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Captured silently into `StreamingCompletionResponse::message_id`.
    MessageId(String),
}

/// Describes a streaming tool call response (in its entirety)
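///
/// A construction sketch using the builder-style helpers below (the IDs, tool
/// name, and arguments are illustrative):
///
/// ```ignore
/// let tool_call = RawStreamingToolCall::new(
///     "tool_abc".to_string(),
///     "get_weather".to_string(),
///     serde_json::json!({ "city": "Paris" }),
/// )
/// .with_call_id("call_123".to_string());
/// ```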
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Provider-supplied tool call ID.
    pub id: String,
    /// Rig-generated unique identifier for this tool call.
    pub internal_call_id: String,
    pub call_id: Option<String>,
    pub name: String,
    pub arguments: serde_json::Value,
    pub signature: Option<String>,
    pub additional_params: Option<serde_json::Value>,
}

impl RawStreamingToolCall {
    pub fn empty() -> Self {
        Self {
            id: String::new(),
            internal_call_id: nanoid::nanoid!(),
            call_id: None,
            name: String::new(),
            arguments: serde_json::Value::Null,
            signature: None,
            additional_params: None,
        }
    }

    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
        Self {
            id,
            internal_call_id: nanoid::nanoid!(),
            call_id: None,
            name,
            arguments,
            signature: None,
            additional_params: None,
        }
    }

    pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
        self.internal_call_id = internal_call_id;
        self
    }

    pub fn with_call_id(mut self, call_id: String) -> Self {
        self.call_id = Some(call_id);
        self
    }

    pub fn with_signature(mut self, signature: Option<String>) -> Self {
        self.signature = signature;
        self
    }

    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }
}

impl From<RawStreamingToolCall> for ToolCall {
    fn from(tool_call: RawStreamingToolCall) -> Self {
        ToolCall {
            id: tool_call.id,
            call_id: tool_call.call_id,
            function: ToolFunction {
                name: tool_call.name,
                arguments: tool_call.arguments,
            },
            signature: tool_call.signature,
            additional_params: tool_call.additional_params,
        }
    }
}

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;

/// The response from a streaming completion request;
/// the aggregated `choice` and final `response` are populated
/// once the `inner` stream has been fully consumed.
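///
/// A consuming sketch (obtaining `inner_stream` is assumed to happen elsewhere;
/// the aggregation behaviour mirrors this module's tests):
///
/// ```ignore
/// let mut response = StreamingCompletionResponse::stream(inner_stream);
///
/// // The response can also be paused, resumed, or cancelled mid-stream:
/// // response.pause(); response.resume(); response.cancel();
///
/// while let Some(chunk) = response.next().await {
///     match chunk {
///         Ok(StreamedAssistantContent::Text(text)) => print!("{}", text.text),
///         Ok(_other) => { /* tool calls, reasoning, deltas, final response */ }
///         Err(e) => { eprintln!("Error: {e}"); break; }
///     }
/// }
///
/// // Once the stream is exhausted, the aggregated content and the final
/// // provider response are available:
/// let choice = response.choice.clone();
/// let raw = response.response.clone();
/// ```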
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    pub(crate) inner: Abortable<StreamingResult<R>>,
    pub(crate) abort_handle: AbortHandle,
    pub(crate) pause_control: PauseControl,
    assistant_items: Vec<AssistantContent>,
    text_item_index: Option<usize>,
    reasoning_item_index: Option<usize>,
    /// The final aggregated message from the stream;
    /// contains all text, reasoning, and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream; may be `None`
    /// if the provider didn't yield one during the stream
    pub response: Option<R>,
    pub final_response_yielded: AtomicBool,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    pub message_id: Option<String>,
}

impl<R> StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        let abortable_stream = Abortable::new(inner, abort_registration);
        let pause_control = PauseControl::new();
        Self {
            inner: abortable_stream,
            abort_handle,
            pause_control,
            assistant_items: vec![],
            text_item_index: None,
            reasoning_item_index: None,
            choice: OneOrMany::one(AssistantContent::text("")),
            response: None,
            final_response_yielded: AtomicBool::new(false),
            message_id: None,
        }
    }

    pub fn cancel(&self) {
        self.abort_handle.abort();
    }

    pub fn pause(&self) {
        self.pause_control.pause();
    }

    pub fn resume(&self) {
        self.pause_control.resume();
    }

    pub fn is_paused(&self) -> bool {
        self.pause_control.is_paused()
    }

    fn append_text_chunk(&mut self, text: &str) {
        if let Some(index) = self.text_item_index
            && let Some(AssistantContent::Text(existing_text)) = self.assistant_items.get_mut(index)
        {
            existing_text.text.push_str(text);
            return;
        }

        self.assistant_items
            .push(AssistantContent::text(text.to_owned()));
        self.text_item_index = Some(self.assistant_items.len() - 1);
    }

    /// Accumulate streaming reasoning delta text into assistant_items.
    /// Providers that only emit ReasoningDelta (not full Reasoning blocks)
    /// need this so the aggregated response includes reasoning content.
    fn append_reasoning_chunk(&mut self, id: &Option<String>, text: &str) {
        if let Some(index) = self.reasoning_item_index
            && let Some(AssistantContent::Reasoning(existing)) = self.assistant_items.get_mut(index)
            && let Some(ReasoningContent::Text {
                text: existing_text,
                ..
            }) = existing.content.last_mut()
        {
            existing_text.push_str(text);
            return;
        }

        self.assistant_items
            .push(AssistantContent::Reasoning(Reasoning {
                id: id.clone(),
                content: vec![ReasoningContent::Text {
                    text: text.to_string(),
                    signature: None,
                }],
            }));
        self.reasoning_item_index = Some(self.assistant_items.len() - 1);
    }
}

impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
where
    R: Clone + Unpin + GetTokenUsage,
{
    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
        CompletionResponse {
            choice: value.choice,
            usage: Usage::new(), // Usage is not tracked in streaming responses
            raw_response: value.response,
            message_id: value.message_id,
        }
    }
}

impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        if stream.is_paused() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // This is run at the end of the inner stream to collect all tokens into
                // a single unified `Message`.
                if stream.assistant_items.is_empty() {
                    stream.assistant_items.push(AssistantContent::text(""));
                }

                if let Some(choice) =
                    OneOrMany::from_iter_optional(std::mem::take(&mut stream.assistant_items))
                {
                    stream.choice = choice;
                }

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); // Treat cancellation as stream termination
                }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    stream.reasoning_item_index = None;
                    stream.append_text_chunk(&text);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }))),
                RawStreamingChoice::Reasoning { id, content } => {
                    let reasoning = Reasoning {
                        id,
                        content: vec![content],
                    };
                    stream.text_item_index = None;
                    // Full reasoning block supersedes any delta accumulation
                    stream.reasoning_item_index = None;
                    stream
                        .assistant_items
                        .push(AssistantContent::Reasoning(reasoning.clone()));
                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(reasoning))))
                }
                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                    stream.text_item_index = None;
                    stream.append_reasoning_chunk(&id, &reasoning);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
                        id,
                        reasoning,
                    })))
                }
                RawStreamingChoice::ToolCall(raw_tool_call) => {
                    let internal_call_id = raw_tool_call.internal_call_id.clone();
                    let tool_call: ToolCall = raw_tool_call.into();
                    stream.text_item_index = None;
                    stream.reasoning_item_index = None;
                    stream
                        .assistant_items
                        .push(AssistantContent::ToolCall(tool_call.clone()));
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
                        tool_call,
                        internal_call_id,
                    })))
                }
                RawStreamingChoice::FinalResponse(response) => {
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        stream.poll_next_unpin(cx)
                    } else {
                        // Set the final response field and yield the final response to the caller
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
                RawStreamingChoice::MessageId(id) => {
                    stream.message_id = Some(id);
                    stream.poll_next_unpin(cx)
                }
            },
        }
    }
}

/// Trait for high-level streaming prompt interface.
///
/// This trait provides a simple interface for streaming prompts to a completion model.
/// Implementations can optionally support prompt hooks for observing and controlling
/// the agent's execution lifecycle.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// The hook type used by this streaming prompt implementation.
    ///
    /// If your implementation does not need prompt hooks, use `()` as the hook type:
    ///
    /// ```ignore
    /// impl<M, R> StreamingPrompt<M, R> for MyType<M>
    /// where
    ///     M: CompletionModel + 'static,
    ///     // ... other bounds ...
    /// {
    ///     type Hook = ();
    ///
    ///     fn stream_prompt(&self, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
    ///         // ...
    ///     }
    /// }
    /// ```
    type Hook: PromptHook<M>;

    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, Self::Hook>;
}

/// Trait for high-level streaming chat interface with conversation history.
///
/// This trait provides an interface for streaming chat completions with support
/// for maintaining conversation history. Implementations can optionally support
/// prompt hooks for observing and controlling the agent's execution lifecycle.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// The hook type used by this streaming chat implementation.
    ///
    /// If your implementation does not need prompt hooks, use `()` as the hook type:
    ///
    /// ```ignore
    /// impl<M, R> StreamingChat<M, R> for MyType<M>
    /// where
    ///     M: CompletionModel + 'static,
    ///     // ... other bounds ...
    /// {
    ///     type Hook = ();
    ///
    ///     fn stream_chat(
    ///         &self,
    ///         prompt: impl Into<Message>,
    ///         chat_history: Vec<Message>,
    ///     ) -> StreamingPromptRequest<M, ()> {
    ///         // ...
    ///     }
    /// }
    /// ```
    type Hook: PromptHook<M>;

    /// Stream a chat with history to the model.
    ///
    /// The messages returned by the model can be accessed via `FinalResponse::history()`.
    ///
    /// You are responsible for managing history; a simple linear solution could look like:
    /// ```ignore
    /// let mut history = vec![];
    ///
    /// loop {
    ///     let prompt = "Create GPT-67, make no mistakes";
    ///     let mut stream = agent.stream_chat(prompt, &history).await;
    ///
    ///     while let Some(msg) = stream.next().await {
    ///         match msg {
    ///             Ok(MultiTurnStreamItem::FinalResponse(fin)) => {
    ///                 history.extend_from_slice(fin.history().unwrap_or_default());
    ///                 break;
    ///             }
    ///             Ok(_other) => { /* Do something with this chunk */ }
    ///             Err(e) => return Err(e.into()),
    ///         }
    ///     }
    /// }
    /// ```
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, Self::Hook>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}

/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request
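    ///
    /// A minimal usage sketch, assuming the caller further customizes the returned
    /// [`CompletionRequestBuilder`] before sending it as a streaming request
    /// (the prompt text and parameter values are illustrative):
    ///
    /// ```ignore
    /// let builder = target.stream_completion("Hello world!", vec![]).await?;
    /// let builder = builder.temperature(0.7);
    /// // ... send the built request as a streaming completion ...
    /// ```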
    fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}

/// A helper function to stream a completion request to stdout.
/// Tool call deltas are ignored, as tool calls are generally much easier to handle
/// when received in their entirety rather than as deltas.
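///
/// A usage sketch (obtaining the agent and the streaming response is assumed to
/// happen elsewhere):
///
/// ```ignore
/// stream_to_stdout(agent, &mut stream).await?;
/// ```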
pub async fn stream_to_stdout<M>(
    agent: &'static Agent<M>,
    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error>
where
    M: CompletionModel,
{
    let mut is_reasoning = false;
    print!("Response: ");
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(StreamedAssistantContent::Text(text)) => {
                if is_reasoning {
                    is_reasoning = false;
                    println!("\n---\n");
                }
                print!("{}", text.text);
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            }) => {
                let res = agent
                    .tool_server_handle
                    .call_tool(
                        &tool_call.function.name,
                        &tool_call.function.arguments.to_string(),
                    )
                    .await
                    .map_err(|x| std::io::Error::other(x.to_string()))?;
                println!("\nResult: {res}");
            }
            Ok(StreamedAssistantContent::Final(res)) => {
                if let Ok(json_res) = serde_json::to_string_pretty(&res) {
                    println!();
                    tracing::info!("Final result: {json_res}");
                }
            }
            Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                if !is_reasoning {
                    is_reasoning = true;
                    println!();
                    println!("Thinking: ");
                }
                let reasoning = reasoning.display_text();

                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Err(e) => {
                if e.to_string().contains("aborted") {
                    println!("\nStream cancelled.");
                    break;
                }
                eprintln!("Error: {e}");
                break;
            }
            _ => {}
        }
    }

    println!(); // New line after streaming completes

    Ok(())
}

// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + Send
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    fn create_reasoning_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_1".to_string()),
                content: ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig_1".to_string()),
                },
            });
            yield Ok(RawStreamingChoice::Message("final answer".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 5 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    fn create_reasoning_only_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_only".to_string()),
                content: ReasoningContent::Summary("hidden summary".to_string()),
            });
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 2 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    fn create_interleaved_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_interleaved".to_string()),
                content: ReasoningContent::Text {
                    text: "chain-of-thought".to_string(),
                    signature: None,
                },
            });
            yield Ok(RawStreamingChoice::Message("final-text".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_1".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": 1}),
                ),
            ));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    fn create_text_tool_text_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("first".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_split".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": "x"}),
                ),
            ));
            yield Ok(RawStreamingChoice::Message("second".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall {
                    tool_call,
                    internal_call_id,
                }) => {
                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }) => {
                    println!(
                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
                    );
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                    let reasoning = reasoning.display_text();
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }

    #[tokio::test]
    async fn test_stream_aggregates_reasoning_content() {
        let mut stream = create_reasoning_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();

        assert!(choice_items.iter().any(|item| matches!(
            item,
            AssistantContent::Reasoning(Reasoning {
                id: Some(id),
                content
            }) if id == "rs_1"
                && matches!(
                    content.first(),
                    Some(ReasoningContent::Text {
                        text,
                        signature: Some(signature)
                    }) if text == "step one" && signature == "sig_1"
                )
        )));
    }

    #[tokio::test]
    async fn test_stream_reasoning_only_does_not_inject_empty_text() {
        let mut stream = create_reasoning_only_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 1);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_only"
        ));
    }

    #[tokio::test]
    async fn test_stream_aggregates_assistant_items_in_arrival_order() {
        let mut stream = create_interleaved_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_interleaved"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::Text(Text { text })) if text == "final-text"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_1"
        ));
    }

    #[tokio::test]
    async fn test_stream_keeps_non_contiguous_text_chunks_split_by_tool_call() {
        let mut stream = create_text_tool_text_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Text(Text { text })) if text == "first"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_split"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::Text(Text { text })) if text == "second"
        ));
    }
}

/// Content yielded by a streamed provider response: text, a tool call (complete
/// or as a delta), reasoning (complete or as a delta), or the final response.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    Text(Text),
    ToolCall {
        tool_call: ToolCall,
        /// Rig-generated unique identifier for this tool call.
        /// Use this to correlate with ToolCallDelta events.
        internal_call_id: String,
    },
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        content: ToolCallDeltaContent,
    },
    Reasoning(Reasoning),
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },
    Final(R),
}

impl<R> StreamedAssistantContent<R>
where
    R: Clone + Unpin,
{
    pub fn text(text: &str) -> Self {
        Self::Text(Text {
            text: text.to_string(),
        })
    }

    pub fn final_response(res: R) -> Self {
        Self::Final(res)
    }
}

/// Streamed user content. This content is primarily used to represent tool results from tool calls made during a multi-turn/step agent prompt.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    ToolResult {
        tool_result: ToolResult,
        /// Rig-generated unique identifier for the tool call this result
        /// belongs to. Use this to correlate with the originating
        /// [`StreamedAssistantContent::ToolCall::internal_call_id`].
        internal_call_id: String,
    },
}

impl StreamedUserContent {
    pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
        Self::ToolResult {
            tool_result,
            internal_call_id,
        }
    }
}