rig/streaming.rs

//! This module provides functionality for working with streaming completion models.
//! It defines traits and types for generating streaming completion requests and
//! handling streaming completion responses.
//!
//! The main traits defined in this module are:
//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
//!
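//! A minimal consumption sketch (how the [StreamingCompletionResponse] is
//! obtained from a provider is elided; `stream` is assumed to be one):
//!
//! ```ignore
//! use futures::StreamExt;
//!
//! while let Some(chunk) = stream.next().await {
//!     match chunk {
//!         Ok(StreamedAssistantContent::Text(text)) => print!("{}", text.text),
//!         Ok(StreamedAssistantContent::Final(res)) => println!("\nFinal: {res:?}"),
//!         Err(e) => eprintln!("Error: {e}"),
//!         _ => {}
//!     }
//! }
//! ```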

use crate::OneOrMany;
use crate::agent::Agent;
use crate::agent::prompt_request::hooks::PromptHook;
use crate::agent::prompt_request::streaming::StreamingPromptRequest;
use crate::client::FinalCompletionResponse;
use crate::completion::{
    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
    Message, Usage,
};
use crate::message::{
    AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult,
};
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use futures::stream::{AbortHandle, Abortable};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::task::{Context, Poll};
use tokio::sync::watch;

/// Control for pausing and resuming a streaming response
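///
/// A minimal usage sketch of the methods defined below:
///
/// ```ignore
/// let control = PauseControl::new();
/// control.pause();
/// assert!(control.is_paused());
/// control.resume();
/// assert!(!control.is_paused());
/// ```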
pub struct PauseControl {
    pub(crate) paused_tx: watch::Sender<bool>,
    pub(crate) paused_rx: watch::Receiver<bool>,
}

impl PauseControl {
    pub fn new() -> Self {
        let (paused_tx, paused_rx) = watch::channel(false);
        Self {
            paused_tx,
            paused_rx,
        }
    }

    pub fn pause(&self) {
        self.paused_tx.send(true).unwrap();
    }

    pub fn resume(&self) {
        self.paused_tx.send(false).unwrap();
    }

    pub fn is_paused(&self) -> bool {
        *self.paused_rx.borrow()
    }
}

impl Default for PauseControl {
    fn default() -> Self {
        Self::new()
    }
}

/// The content of a tool call delta - either the tool name or argument data
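///
/// A sketch of accumulating deltas into a complete call, assuming hypothetical
/// `String` buffers `name_buf` and `args_buf` keyed by `internal_call_id`:
///
/// ```ignore
/// match content {
///     // The tool's name may arrive in one or more chunks.
///     ToolCallDeltaContent::Name(name) => name_buf.push_str(&name),
///     // Argument JSON arrives as string fragments to be concatenated.
///     ToolCallDeltaContent::Delta(json) => args_buf.push_str(&json),
/// }
/// ```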
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum ToolCallDeltaContent {
    Name(String),
    Delta(String),
}

/// Enum representing a streaming chunk from the model
#[derive(Debug, Clone)]
pub enum RawStreamingChoice<R>
where
    R: Clone,
{
    /// A text chunk from a message response
    Message(String),

    /// A tool call response (in its entirety)
    ToolCall(RawStreamingToolCall),
    /// A tool call partial/delta
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        content: ToolCallDeltaContent,
    },
    /// A reasoning block (in its entirety)
    Reasoning {
        id: Option<String>,
        content: ReasoningContent,
    },
    /// A reasoning partial/delta
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },

    /// The final response object. This must be yielded if you want the
    /// `response` field to be populated on the `StreamingCompletionResponse`.
    FinalResponse(R),

    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Captured silently into `StreamingCompletionResponse::message_id`.
    MessageId(String),
}

/// Describes a streaming tool call response (in its entirety)
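///
/// A construction sketch using the builder-style helpers defined below
/// (all values are placeholders):
///
/// ```ignore
/// let tool_call = RawStreamingToolCall::new(
///     "call_abc".to_string(),
///     "get_weather".to_string(),
///     serde_json::json!({ "city": "Paris" }),
/// )
/// .with_call_id("provider_call_id".to_string())
/// .with_signature(Some("sig".to_string()));
/// ```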
#[derive(Debug, Clone)]
pub struct RawStreamingToolCall {
    /// Provider-supplied tool call ID.
    pub id: String,
    /// Rig-generated unique identifier for this tool call.
    pub internal_call_id: String,
    pub call_id: Option<String>,
    pub name: String,
    pub arguments: serde_json::Value,
    pub signature: Option<String>,
    pub additional_params: Option<serde_json::Value>,
}

impl RawStreamingToolCall {
    pub fn empty() -> Self {
        Self {
            id: String::new(),
            internal_call_id: nanoid::nanoid!(),
            call_id: None,
            name: String::new(),
            arguments: serde_json::Value::Null,
            signature: None,
            additional_params: None,
        }
    }

    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
        Self {
            id,
            internal_call_id: nanoid::nanoid!(),
            call_id: None,
            name,
            arguments,
            signature: None,
            additional_params: None,
        }
    }

    pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
        self.internal_call_id = internal_call_id;
        self
    }

    pub fn with_call_id(mut self, call_id: String) -> Self {
        self.call_id = Some(call_id);
        self
    }

    pub fn with_signature(mut self, signature: Option<String>) -> Self {
        self.signature = signature;
        self
    }

    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }
}

impl From<RawStreamingToolCall> for ToolCall {
    fn from(tool_call: RawStreamingToolCall) -> Self {
        ToolCall {
            id: tool_call.id,
            call_id: tool_call.call_id,
            function: ToolFunction {
                name: tool_call.name,
                arguments: tool_call.arguments,
            },
            signature: tool_call.signature,
            additional_params: tool_call.additional_params,
        }
    }
}

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;

/// The response from a streaming completion request. The aggregated `choice`
/// and final `response` fields are populated once the `inner` stream has been
/// fully consumed.
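///
/// A consumption sketch: drive the stream to completion, then read the
/// aggregated fields (obtaining `stream` from a provider is elided):
///
/// ```ignore
/// use futures::StreamExt;
///
/// while let Some(chunk) = stream.next().await {
///     // handle each `StreamedAssistantContent` chunk as it arrives
/// }
/// // Once the stream is exhausted, the aggregated fields are populated.
/// let message = stream.choice.clone();
/// let raw_response = stream.response.clone();
/// ```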
pub struct StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    pub(crate) inner: Abortable<StreamingResult<R>>,
    pub(crate) abort_handle: AbortHandle,
    pub(crate) pause_control: PauseControl,
    assistant_items: Vec<AssistantContent>,
    text_item_index: Option<usize>,
    reasoning_item_index: Option<usize>,
    /// The final aggregated message from the stream;
    /// contains all text and tool calls generated
    pub choice: OneOrMany<AssistantContent>,
    /// The final response from the stream; may be `None`
    /// if the provider didn't yield it during the stream
    pub response: Option<R>,
    pub final_response_yielded: AtomicBool,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    pub message_id: Option<String>,
}

impl<R> StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        let abortable_stream = Abortable::new(inner, abort_registration);
        let pause_control = PauseControl::new();
        Self {
            inner: abortable_stream,
            abort_handle,
            pause_control,
            assistant_items: vec![],
            text_item_index: None,
            reasoning_item_index: None,
            choice: OneOrMany::one(AssistantContent::text("")),
            response: None,
            final_response_yielded: AtomicBool::new(false),
            message_id: None,
        }
    }

    pub fn cancel(&self) {
        self.abort_handle.abort();
    }

    pub fn pause(&self) {
        self.pause_control.pause();
    }

    pub fn resume(&self) {
        self.pause_control.resume();
    }

    pub fn is_paused(&self) -> bool {
        self.pause_control.is_paused()
    }

    fn append_text_chunk(&mut self, text: &str) {
        if let Some(index) = self.text_item_index
            && let Some(AssistantContent::Text(existing_text)) = self.assistant_items.get_mut(index)
        {
            existing_text.text.push_str(text);
            return;
        }

        self.assistant_items
            .push(AssistantContent::text(text.to_owned()));
        self.text_item_index = Some(self.assistant_items.len() - 1);
    }

    /// Accumulate streaming reasoning delta text into `assistant_items`.
    /// Providers that only emit `ReasoningDelta` (not full `Reasoning` blocks)
    /// need this so the aggregated response includes reasoning content.
    fn append_reasoning_chunk(&mut self, id: &Option<String>, text: &str) {
        if let Some(index) = self.reasoning_item_index
            && let Some(AssistantContent::Reasoning(existing)) = self.assistant_items.get_mut(index)
            && let Some(ReasoningContent::Text {
                text: existing_text,
                ..
            }) = existing.content.last_mut()
        {
            existing_text.push_str(text);
            return;
        }

        self.assistant_items
            .push(AssistantContent::Reasoning(Reasoning {
                id: id.clone(),
                content: vec![ReasoningContent::Text {
                    text: text.to_string(),
                    signature: None,
                }],
            }));
        self.reasoning_item_index = Some(self.assistant_items.len() - 1);
    }
}

impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
where
    R: Clone + Unpin + GetTokenUsage,
{
    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
        CompletionResponse {
            choice: value.choice,
            usage: Usage::new(), // Usage is not tracked in streaming responses
            raw_response: value.response,
            message_id: value.message_id,
        }
    }
}

impl<R> Stream for StreamingCompletionResponse<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        if stream.is_paused() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        match Pin::new(&mut stream.inner).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // This is run at the end of the inner stream to collect all tokens into
                // a single unified `Message`.
                if stream.assistant_items.is_empty() {
                    stream.assistant_items.push(AssistantContent::text(""));
                }

                stream.choice = OneOrMany::many(std::mem::take(&mut stream.assistant_items))
                    .expect("There should be at least one assistant message");

                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
                {
                    return Poll::Ready(None); // Treat cancellation as stream termination
                }
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Some(Ok(choice))) => match choice {
                RawStreamingChoice::Message(text) => {
                    stream.reasoning_item_index = None;
                    stream.append_text_chunk(&text);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
                }
                RawStreamingChoice::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }))),
                RawStreamingChoice::Reasoning { id, content } => {
                    let reasoning = Reasoning {
                        id,
                        content: vec![content],
                    };
                    stream.text_item_index = None;
                    // Full reasoning block supersedes any delta accumulation
                    stream.reasoning_item_index = None;
                    stream
                        .assistant_items
                        .push(AssistantContent::Reasoning(reasoning.clone()));
                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(reasoning))))
                }
                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
                    stream.text_item_index = None;
                    stream.append_reasoning_chunk(&id, &reasoning);
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
                        id,
                        reasoning,
                    })))
                }
                RawStreamingChoice::ToolCall(raw_tool_call) => {
                    let internal_call_id = raw_tool_call.internal_call_id.clone();
                    let tool_call: ToolCall = raw_tool_call.into();
                    stream.text_item_index = None;
                    stream.reasoning_item_index = None;
                    stream
                        .assistant_items
                        .push(AssistantContent::ToolCall(tool_call.clone()));
                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
                        tool_call,
                        internal_call_id,
                    })))
                }
                RawStreamingChoice::FinalResponse(response) => {
                    if stream
                        .final_response_yielded
                        .load(std::sync::atomic::Ordering::SeqCst)
                    {
                        stream.poll_next_unpin(cx)
                    } else {
                        // Set the final response field and return the next item in the stream
                        stream.response = Some(response.clone());
                        stream
                            .final_response_yielded
                            .store(true, std::sync::atomic::Ordering::SeqCst);
                        let final_response = StreamedAssistantContent::final_response(response);
                        Poll::Ready(Some(Ok(final_response)))
                    }
                }
                RawStreamingChoice::MessageId(id) => {
                    stream.message_id = Some(id);
                    stream.poll_next_unpin(cx)
                }
            },
        }
    }
}

/// Trait for high-level streaming prompt interface.
///
/// This trait provides a simple interface for streaming prompts to a completion model.
/// Implementations can optionally support prompt hooks for observing and controlling
/// the agent's execution lifecycle.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// The hook type used by this streaming prompt implementation.
    ///
    /// If your implementation does not need prompt hooks, use `()` as the hook type:
    ///
    /// ```ignore
    /// impl<M, R> StreamingPrompt<M, R> for MyType<M>
    /// where
    ///     M: CompletionModel + 'static,
    ///     // ... other bounds ...
    /// {
    ///     type Hook = ();
    ///
    ///     fn stream_prompt(&self, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
    ///         // ...
    ///     }
    /// }
    /// ```
    type Hook: PromptHook<M>;

    /// Stream a simple prompt to the model
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, Self::Hook>;
}

/// Trait for high-level streaming chat interface with conversation history.
///
/// This trait provides an interface for streaming chat completions with support
/// for maintaining conversation history. Implementations can optionally support
/// prompt hooks for observing and controlling the agent's execution lifecycle.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// The hook type used by this streaming chat implementation.
    ///
    /// If your implementation does not need prompt hooks, use `()` as the hook type:
    ///
    /// ```ignore
    /// impl<M, R> StreamingChat<M, R> for MyType<M>
    /// where
    ///     M: CompletionModel + 'static,
    ///     // ... other bounds ...
    /// {
    ///     type Hook = ();
    ///
    ///     fn stream_chat(
    ///         &self,
    ///         prompt: impl Into<Message>,
    ///         chat_history: Vec<Message>,
    ///     ) -> StreamingPromptRequest<M, ()> {
    ///         // ...
    ///     }
    /// }
    /// ```
    type Hook: PromptHook<M>;

    /// Stream a chat with history to the model
    ///
    /// The updated history (including the new prompt and response) is returned
    /// in [`FinalResponse::history()`](crate::agent::prompt_request::streaming::FinalResponse::history).
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, Self::Hook>;
}

/// Trait for low-level streaming completion interface
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request
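    ///
    /// A hedged usage sketch; the exact streaming entry point on the returned
    /// builder may differ, see [CompletionRequestBuilder]:
    ///
    /// ```ignore
    /// let builder = model.stream_completion("Hello", vec![]).await?;
    /// let stream = builder.stream().await?;
    /// ```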
    fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
}

pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
    pub(crate) inner: StreamingResult<R>,
}

fn map_raw_streaming_choice<R>(
    chunk: RawStreamingChoice<R>,
) -> RawStreamingChoice<FinalCompletionResponse>
where
    R: Clone + Unpin + GetTokenUsage,
{
    match chunk {
        RawStreamingChoice::FinalResponse(res) => {
            RawStreamingChoice::FinalResponse(FinalCompletionResponse {
                usage: res.token_usage(),
            })
        }
        RawStreamingChoice::Message(m) => RawStreamingChoice::Message(m),
        RawStreamingChoice::ToolCallDelta {
            id,
            internal_call_id,
            content,
        } => RawStreamingChoice::ToolCallDelta {
            id,
            internal_call_id,
            content,
        },
        RawStreamingChoice::Reasoning { id, content } => {
            RawStreamingChoice::Reasoning { id, content }
        }
        RawStreamingChoice::ReasoningDelta { id, reasoning } => {
            RawStreamingChoice::ReasoningDelta { id, reasoning }
        }
        RawStreamingChoice::ToolCall(tool_call) => RawStreamingChoice::ToolCall(tool_call),
        RawStreamingChoice::MessageId(id) => RawStreamingChoice::MessageId(id),
    }
}

impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
    type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        match stream.inner.as_mut().poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(item)) => Poll::Ready(Some(item.map(map_raw_streaming_choice::<R>))),
        }
    }
}

/// A helper function to stream a completion request to stdout.
/// Tool call deltas are ignored, since tool calls are generally easier to
/// handle when received in their entirety rather than as deltas.
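///
/// A usage sketch; since the helper takes `&'static Agent<M>`, the agent is
/// leaked here for illustration, and obtaining `stream` is elided:
///
/// ```ignore
/// let agent: &'static Agent<_> = Box::leak(Box::new(agent));
/// stream_to_stdout(agent, &mut stream).await?;
/// ```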
pub async fn stream_to_stdout<M>(
    agent: &'static Agent<M>,
    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
) -> Result<(), std::io::Error>
where
    M: CompletionModel,
{
    let mut is_reasoning = false;
    print!("Response: ");
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(StreamedAssistantContent::Text(text)) => {
                if is_reasoning {
                    is_reasoning = false;
                    println!("\n---\n");
                }
                print!("{}", text.text);
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id: _,
            }) => {
                let res = agent
                    .tool_server_handle
                    .call_tool(
                        &tool_call.function.name,
                        &tool_call.function.arguments.to_string(),
                    )
                    .await
                    .map_err(|x| std::io::Error::other(x.to_string()))?;
                println!("\nResult: {res}");
            }
            Ok(StreamedAssistantContent::Final(res)) => {
                let json_res = serde_json::to_string_pretty(&res).unwrap();
                println!();
                tracing::info!("Final result: {json_res}");
            }
            Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                if !is_reasoning {
                    is_reasoning = true;
                    println!();
                    println!("Thinking: ");
                }
                let reasoning = reasoning.display_text();

                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Err(e) => {
                if e.to_string().contains("aborted") {
                    println!("\nStream cancelled.");
                    break;
                }
                eprintln!("Error: {e}");
                break;
            }
            _ => {}
        }
    }

    println!(); // New line after streaming completes

    Ok(())
}

// Test module
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use async_stream::stream;
    use tokio::time::sleep;

    #[derive(Debug, Clone)]
    pub struct MockResponse {
        #[allow(dead_code)]
        token_count: u32,
    }

    impl GetTokenUsage for MockResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            let mut usage = Usage::new();
            usage.total_tokens = 15;
            Some(usage)
        }
    }

    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + Send
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
    fn to_stream_result(
        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
        + 'static,
    ) -> StreamingResult<MockResponse> {
        Box::pin(stream)
    }

    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
            sleep(Duration::from_millis(100)).await;
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    fn create_reasoning_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_1".to_string()),
                content: ReasoningContent::Text {
                    text: "step one".to_string(),
                    signature: Some("sig_1".to_string()),
                },
            });
            yield Ok(RawStreamingChoice::Message("final answer".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 5 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    fn create_reasoning_only_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_only".to_string()),
                content: ReasoningContent::Summary("hidden summary".to_string()),
            });
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 2 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    fn create_interleaved_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Reasoning {
                id: Some("rs_interleaved".to_string()),
                content: ReasoningContent::Text {
                    text: "chain-of-thought".to_string(),
                    signature: None,
                },
            });
            yield Ok(RawStreamingChoice::Message("final-text".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_1".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": 1}),
                ),
            ));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    fn create_text_tool_text_stream() -> StreamingCompletionResponse<MockResponse> {
        let stream = stream! {
            yield Ok(RawStreamingChoice::Message("first".to_string()));
            yield Ok(RawStreamingChoice::ToolCall(
                RawStreamingToolCall::new(
                    "tool_split".to_string(),
                    "mock_tool".to_string(),
                    serde_json::json!({"arg": "x"}),
                ),
            ));
            yield Ok(RawStreamingChoice::Message("second".to_string()));
            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
        };

        StreamingCompletionResponse::stream(to_stream_result(stream))
    }

    #[tokio::test]
    async fn test_stream_cancellation() {
        let mut stream = create_mock_stream();

        println!("Response: ");
        let mut chunk_count = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(StreamedAssistantContent::Text(text)) => {
                    print!("{}", text.text);
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCall {
                    tool_call,
                    internal_call_id,
                }) => {
                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::ToolCallDelta {
                    id,
                    internal_call_id,
                    content,
                }) => {
                    println!(
                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
                    );
                    chunk_count += 1;
                }
                Ok(StreamedAssistantContent::Final(res)) => {
                    println!("\nFinal response: {res:?}");
                }
                Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                    let reasoning = reasoning.display_text();
                    print!("{reasoning}");
                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
                }
                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
                    println!("Reasoning delta: {reasoning}");
                    chunk_count += 1;
                }
                Err(e) => {
                    eprintln!("Error: {e:?}");
                    break;
                }
            }

            if chunk_count >= 2 {
                println!("\nCancelling stream...");
                stream.cancel();
                println!("Stream cancelled.");
                break;
            }
        }

        let next_chunk = stream.next().await;
        assert!(
            next_chunk.is_none(),
            "Expected no further chunks after cancellation, got {next_chunk:?}"
        );
    }

    #[tokio::test]
    async fn test_stream_pause_resume() {
        let stream = create_mock_stream();

        // Test pause
        stream.pause();
        assert!(stream.is_paused());

        // Test resume
        stream.resume();
        assert!(!stream.is_paused());
    }

    #[tokio::test]
    async fn test_stream_aggregates_reasoning_content() {
        let mut stream = create_reasoning_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();

        assert!(choice_items.iter().any(|item| matches!(
            item,
            AssistantContent::Reasoning(Reasoning {
                id: Some(id),
                content
            }) if id == "rs_1"
                && matches!(
                    content.first(),
                    Some(ReasoningContent::Text {
                        text,
                        signature: Some(signature)
                    }) if text == "step one" && signature == "sig_1"
                )
        )));
    }

    #[tokio::test]
    async fn test_stream_reasoning_only_does_not_inject_empty_text() {
        let mut stream = create_reasoning_only_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 1);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_only"
        ));
    }

    #[tokio::test]
    async fn test_stream_aggregates_assistant_items_in_arrival_order() {
        let mut stream = create_interleaved_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_interleaved"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::Text(Text { text })) if text == "final-text"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_1"
        ));
    }

    #[tokio::test]
    async fn test_stream_keeps_non_contiguous_text_chunks_split_by_tool_call() {
        let mut stream = create_text_tool_text_stream();
        while stream.next().await.is_some() {}

        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
        assert_eq!(choice_items.len(), 3);
        assert!(matches!(
            choice_items.first(),
            Some(AssistantContent::Text(Text { text })) if text == "first"
        ));
        assert!(matches!(
            choice_items.get(1),
            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_split"
        ));
        assert!(matches!(
            choice_items.get(2),
            Some(AssistantContent::Text(Text { text })) if text == "second"
        ));
    }
}

/// Describes a chunk from a streamed provider response: text, reasoning
/// (whole or as deltas), a tool call (whole or as deltas), or the final
/// usage response.
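///
/// A matching sketch over the variants defined below:
///
/// ```ignore
/// match chunk {
///     StreamedAssistantContent::Text(text) => print!("{}", text.text),
///     StreamedAssistantContent::ToolCall { tool_call, .. } => {
///         println!("Tool call: {}", tool_call.function.name);
///     }
///     StreamedAssistantContent::Final(res) => println!("Final: {res:?}"),
///     _ => {}
/// }
/// ```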
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    Text(Text),
    ToolCall {
        tool_call: ToolCall,
        /// Rig-generated unique identifier for this tool call.
        /// Use this to correlate with ToolCallDelta events.
        internal_call_id: String,
    },
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        internal_call_id: String,
        content: ToolCallDeltaContent,
    },
    Reasoning(Reasoning),
    ReasoningDelta {
        id: Option<String>,
        reasoning: String,
    },
    Final(R),
}

impl<R> StreamedAssistantContent<R>
where
    R: Clone + Unpin,
{
    pub fn text(text: &str) -> Self {
        Self::Text(Text {
            text: text.to_string(),
        })
    }

    pub fn final_response(res: R) -> Self {
        Self::Final(res)
    }
}

/// Streamed user content. This content is primarily used to represent tool results from tool calls made during a multi-turn/step agent prompt.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    ToolResult {
        tool_result: ToolResult,
        /// Rig-generated unique identifier for the tool call this result
        /// belongs to. Use this to correlate with the originating
        /// [`StreamedAssistantContent::ToolCall::internal_call_id`].
        internal_call_id: String,
    },
}

impl StreamedUserContent {
    pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
        Self::ToolResult {
            tool_result,
            internal_call_id,
        }
    }
}