Skip to main content

rig/
streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//!
10
11use crate::OneOrMany;
12use crate::agent::Agent;
13use crate::agent::prompt_request::hooks::PromptHook;
14use crate::agent::prompt_request::streaming::StreamingPromptRequest;
15use crate::client::FinalCompletionResponse;
16use crate::completion::{
17    CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage,
18    Message, Usage,
19};
20use crate::message::{
21    AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult,
22};
23use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
24use futures::stream::{AbortHandle, Abortable};
25use futures::{Stream, StreamExt};
26use serde::{Deserialize, Serialize};
27use std::future::Future;
28use std::pin::Pin;
29use std::sync::atomic::AtomicBool;
30use std::task::{Context, Poll};
31use tokio::sync::watch;
32
33/// Control for pausing and resuming a streaming response
34pub struct PauseControl {
35    pub(crate) paused_tx: watch::Sender<bool>,
36    pub(crate) paused_rx: watch::Receiver<bool>,
37}
38
39impl PauseControl {
40    pub fn new() -> Self {
41        let (paused_tx, paused_rx) = watch::channel(false);
42        Self {
43            paused_tx,
44            paused_rx,
45        }
46    }
47
48    pub fn pause(&self) {
49        self.paused_tx.send(true).unwrap();
50    }
51
52    pub fn resume(&self) {
53        self.paused_tx.send(false).unwrap();
54    }
55
56    pub fn is_paused(&self) -> bool {
57        *self.paused_rx.borrow()
58    }
59}
60
61impl Default for PauseControl {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67/// The content of a tool call delta - either the tool name or argument data
68#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
69pub enum ToolCallDeltaContent {
70    Name(String),
71    Delta(String),
72}
73
74/// Enum representing a streaming chunk from the model
75#[derive(Debug, Clone)]
76pub enum RawStreamingChoice<R>
77where
78    R: Clone,
79{
80    /// A text chunk from a message response
81    Message(String),
82
83    /// A tool call response (in its entirety)
84    ToolCall(RawStreamingToolCall),
85    /// A tool call partial/delta
86    ToolCallDelta {
87        /// Provider-supplied tool call ID.
88        id: String,
89        /// Rig-generated unique identifier for this tool call.
90        internal_call_id: String,
91        content: ToolCallDeltaContent,
92    },
93    /// A reasoning (in its entirety)
94    Reasoning {
95        id: Option<String>,
96        content: ReasoningContent,
97    },
98    /// A reasoning partial/delta
99    ReasoningDelta {
100        id: Option<String>,
101        reasoning: String,
102    },
103
104    /// The final response object, must be yielded if you want the
105    /// `response` field to be populated on the `StreamingCompletionResponse`
106    FinalResponse(R),
107
108    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
109    /// Captured silently into `StreamingCompletionResponse::message_id`.
110    MessageId(String),
111}
112
113/// Describes a streaming tool call response (in its entirety)
114#[derive(Debug, Clone)]
115pub struct RawStreamingToolCall {
116    /// Provider-supplied tool call ID.
117    pub id: String,
118    /// Rig-generated unique identifier for this tool call.
119    pub internal_call_id: String,
120    pub call_id: Option<String>,
121    pub name: String,
122    pub arguments: serde_json::Value,
123    pub signature: Option<String>,
124    pub additional_params: Option<serde_json::Value>,
125}
126
127impl RawStreamingToolCall {
128    pub fn empty() -> Self {
129        Self {
130            id: String::new(),
131            internal_call_id: nanoid::nanoid!(),
132            call_id: None,
133            name: String::new(),
134            arguments: serde_json::Value::Null,
135            signature: None,
136            additional_params: None,
137        }
138    }
139
140    pub fn new(id: String, name: String, arguments: serde_json::Value) -> Self {
141        Self {
142            id,
143            internal_call_id: nanoid::nanoid!(),
144            call_id: None,
145            name,
146            arguments,
147            signature: None,
148            additional_params: None,
149        }
150    }
151
152    pub fn with_internal_call_id(mut self, internal_call_id: String) -> Self {
153        self.internal_call_id = internal_call_id;
154        self
155    }
156
157    pub fn with_call_id(mut self, call_id: String) -> Self {
158        self.call_id = Some(call_id);
159        self
160    }
161
162    pub fn with_signature(mut self, signature: Option<String>) -> Self {
163        self.signature = signature;
164        self
165    }
166
167    pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
168        self.additional_params = additional_params;
169        self
170    }
171}
172
173impl From<RawStreamingToolCall> for ToolCall {
174    fn from(tool_call: RawStreamingToolCall) -> Self {
175        ToolCall {
176            id: tool_call.id,
177            call_id: tool_call.call_id,
178            function: ToolFunction {
179                name: tool_call.name,
180                arguments: tool_call.arguments,
181            },
182            signature: tool_call.signature,
183            additional_params: tool_call.additional_params,
184        }
185    }
186}
187
/// A pinned, boxed stream of raw provider chunks (or completion errors).
///
/// This variant is used unless building for `wasm32` with the `wasm` feature; it requires
/// the stream to be `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>> + Send>>;

/// A pinned, boxed stream of raw provider chunks (or completion errors).
///
/// `wasm32` + `wasm` feature variant: identical except that it drops the `Send` bound.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
195
196/// The response from a streaming completion request;
197/// message and response are populated at the end of the
198/// `inner` stream.
199pub struct StreamingCompletionResponse<R>
200where
201    R: Clone + Unpin + GetTokenUsage,
202{
203    pub(crate) inner: Abortable<StreamingResult<R>>,
204    pub(crate) abort_handle: AbortHandle,
205    pub(crate) pause_control: PauseControl,
206    assistant_items: Vec<AssistantContent>,
207    text_item_index: Option<usize>,
208    reasoning_item_index: Option<usize>,
209    /// The final aggregated message from the stream
210    /// contains all text and tool calls generated
211    pub choice: OneOrMany<AssistantContent>,
212    /// The final response from the stream, may be `None`
213    /// if the provider didn't yield it during the stream
214    pub response: Option<R>,
215    pub final_response_yielded: AtomicBool,
216    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
217    pub message_id: Option<String>,
218}
219
220impl<R> StreamingCompletionResponse<R>
221where
222    R: Clone + Unpin + GetTokenUsage,
223{
224    pub fn stream(inner: StreamingResult<R>) -> StreamingCompletionResponse<R> {
225        let (abort_handle, abort_registration) = AbortHandle::new_pair();
226        let abortable_stream = Abortable::new(inner, abort_registration);
227        let pause_control = PauseControl::new();
228        Self {
229            inner: abortable_stream,
230            abort_handle,
231            pause_control,
232            assistant_items: vec![],
233            text_item_index: None,
234            reasoning_item_index: None,
235            choice: OneOrMany::one(AssistantContent::text("")),
236            response: None,
237            final_response_yielded: AtomicBool::new(false),
238            message_id: None,
239        }
240    }
241
242    pub fn cancel(&self) {
243        self.abort_handle.abort();
244    }
245
246    pub fn pause(&self) {
247        self.pause_control.pause();
248    }
249
250    pub fn resume(&self) {
251        self.pause_control.resume();
252    }
253
254    pub fn is_paused(&self) -> bool {
255        self.pause_control.is_paused()
256    }
257
258    fn append_text_chunk(&mut self, text: &str) {
259        if let Some(index) = self.text_item_index
260            && let Some(AssistantContent::Text(existing_text)) = self.assistant_items.get_mut(index)
261        {
262            existing_text.text.push_str(text);
263            return;
264        }
265
266        self.assistant_items
267            .push(AssistantContent::text(text.to_owned()));
268        self.text_item_index = Some(self.assistant_items.len() - 1);
269    }
270
271    /// Accumulate streaming reasoning delta text into assistant_items.
272    /// Providers that only emit ReasoningDelta (not full Reasoning blocks)
273    /// need this so the aggregated response includes reasoning content.
274    fn append_reasoning_chunk(&mut self, id: &Option<String>, text: &str) {
275        if let Some(index) = self.reasoning_item_index
276            && let Some(AssistantContent::Reasoning(existing)) = self.assistant_items.get_mut(index)
277            && let Some(ReasoningContent::Text {
278                text: existing_text,
279                ..
280            }) = existing.content.last_mut()
281        {
282            existing_text.push_str(text);
283            return;
284        }
285
286        self.assistant_items
287            .push(AssistantContent::Reasoning(Reasoning {
288                id: id.clone(),
289                content: vec![ReasoningContent::Text {
290                    text: text.to_string(),
291                    signature: None,
292                }],
293            }));
294        self.reasoning_item_index = Some(self.assistant_items.len() - 1);
295    }
296}
297
298impl<R> From<StreamingCompletionResponse<R>> for CompletionResponse<Option<R>>
299where
300    R: Clone + Unpin + GetTokenUsage,
301{
302    fn from(value: StreamingCompletionResponse<R>) -> CompletionResponse<Option<R>> {
303        CompletionResponse {
304            choice: value.choice,
305            usage: Usage::new(), // Usage is not tracked in streaming responses
306            raw_response: value.response,
307            message_id: value.message_id,
308        }
309    }
310}
311
312impl<R> Stream for StreamingCompletionResponse<R>
313where
314    R: Clone + Unpin + GetTokenUsage,
315{
316    type Item = Result<StreamedAssistantContent<R>, CompletionError>;
317
318    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
319        let stream = self.get_mut();
320
321        if stream.is_paused() {
322            cx.waker().wake_by_ref();
323            return Poll::Pending;
324        }
325
326        match Pin::new(&mut stream.inner).poll_next(cx) {
327            Poll::Pending => Poll::Pending,
328            Poll::Ready(None) => {
329                // This is run at the end of the inner stream to collect all tokens into
330                // a single unified `Message`.
331                if stream.assistant_items.is_empty() {
332                    stream.assistant_items.push(AssistantContent::text(""));
333                }
334
335                stream.choice = OneOrMany::many(std::mem::take(&mut stream.assistant_items))
336                    .expect("There should be at least one assistant message");
337
338                Poll::Ready(None)
339            }
340            Poll::Ready(Some(Err(err))) => {
341                if matches!(err, CompletionError::ProviderError(ref e) if e.to_string().contains("aborted"))
342                {
343                    return Poll::Ready(None); // Treat cancellation as stream termination
344                }
345                Poll::Ready(Some(Err(err)))
346            }
347            Poll::Ready(Some(Ok(choice))) => match choice {
348                RawStreamingChoice::Message(text) => {
349                    stream.reasoning_item_index = None;
350                    stream.append_text_chunk(&text);
351                    Poll::Ready(Some(Ok(StreamedAssistantContent::text(&text))))
352                }
353                RawStreamingChoice::ToolCallDelta {
354                    id,
355                    internal_call_id,
356                    content,
357                } => Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCallDelta {
358                    id,
359                    internal_call_id,
360                    content,
361                }))),
362                RawStreamingChoice::Reasoning { id, content } => {
363                    let reasoning = Reasoning {
364                        id,
365                        content: vec![content],
366                    };
367                    stream.text_item_index = None;
368                    // Full reasoning block supersedes any delta accumulation
369                    stream.reasoning_item_index = None;
370                    stream
371                        .assistant_items
372                        .push(AssistantContent::Reasoning(reasoning.clone()));
373                    Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(reasoning))))
374                }
375                RawStreamingChoice::ReasoningDelta { id, reasoning } => {
376                    stream.text_item_index = None;
377                    stream.append_reasoning_chunk(&id, &reasoning);
378                    Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
379                        id,
380                        reasoning,
381                    })))
382                }
383                RawStreamingChoice::ToolCall(raw_tool_call) => {
384                    let internal_call_id = raw_tool_call.internal_call_id.clone();
385                    let tool_call: ToolCall = raw_tool_call.into();
386                    stream.text_item_index = None;
387                    stream.reasoning_item_index = None;
388                    stream
389                        .assistant_items
390                        .push(AssistantContent::ToolCall(tool_call.clone()));
391                    Poll::Ready(Some(Ok(StreamedAssistantContent::ToolCall {
392                        tool_call,
393                        internal_call_id,
394                    })))
395                }
396                RawStreamingChoice::FinalResponse(response) => {
397                    if stream
398                        .final_response_yielded
399                        .load(std::sync::atomic::Ordering::SeqCst)
400                    {
401                        stream.poll_next_unpin(cx)
402                    } else {
403                        // Set the final response field and return the next item in the stream
404                        stream.response = Some(response.clone());
405                        stream
406                            .final_response_yielded
407                            .store(true, std::sync::atomic::Ordering::SeqCst);
408                        let final_response = StreamedAssistantContent::final_response(response);
409                        Poll::Ready(Some(Ok(final_response)))
410                    }
411                }
412                RawStreamingChoice::MessageId(id) => {
413                    stream.message_id = Some(id);
414                    stream.poll_next_unpin(cx)
415                }
416            },
417        }
418    }
419}
420
/// Trait for high-level streaming prompt interface.
///
/// This trait provides a simple interface for streaming prompts to a completion model.
/// Implementations can optionally support prompt hooks for observing and controlling
/// the agent's execution lifecycle.
///
/// `R` is only constrained here (`Clone + Unpin + GetTokenUsage`); it does not appear in
/// the trait's method signatures.
pub trait StreamingPrompt<M, R>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// The hook type used by this streaming prompt implementation.
    ///
    /// If your implementation does not need prompt hooks, use `()` as the hook type:
    ///
    /// ```ignore
    /// impl<M, R> StreamingPrompt<M, R> for MyType<M>
    /// where
    ///     M: CompletionModel + 'static,
    ///     // ... other bounds ...
    /// {
    ///     type Hook = ();
    ///
    ///     fn stream_prompt(
    ///         &self,
    ///         prompt: impl Into<Message> + WasmCompatSend,
    ///     ) -> StreamingPromptRequest<M, ()> {
    ///         // ...
    ///     }
    /// }
    /// ```
    type Hook: PromptHook<M>;

    /// Stream a simple one-shot prompt (no chat history argument) to the model.
    ///
    /// Returns a [`StreamingPromptRequest`] using this implementation's hook type.
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, Self::Hook>;
}
457
/// Trait for high-level streaming chat interface with conversation history.
///
/// This trait provides an interface for streaming chat completions with support
/// for maintaining conversation history. Implementations can optionally support
/// prompt hooks for observing and controlling the agent's execution lifecycle.
pub trait StreamingChat<M, R>: WasmCompatSend + WasmCompatSync
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    R: Clone + Unpin + GetTokenUsage,
{
    /// The hook type used by this streaming chat implementation.
    ///
    /// If your implementation does not need prompt hooks, use `()` as the hook type:
    ///
    /// ```ignore
    /// impl<M, R> StreamingChat<M, R> for MyType<M>
    /// where
    ///     M: CompletionModel + 'static,
    ///     // ... other bounds ...
    /// {
    ///     type Hook = ();
    ///
    ///     fn stream_chat<I, T>(
    ///         &self,
    ///         prompt: impl Into<Message> + WasmCompatSend,
    ///         chat_history: I,
    ///     ) -> StreamingPromptRequest<M, ()>
    ///     where
    ///         I: IntoIterator<Item = T> + WasmCompatSend,
    ///         T: Into<Message>,
    ///     {
    ///         // ...
    ///     }
    /// }
    /// ```
    type Hook: PromptHook<M>;

    /// Stream a chat with history to the model.
    ///
    /// `chat_history` may be any iterable whose items convert into [`Message`].
    /// The messages returned by the model can be accessed via `FinalResponse::history()`.
    ///
    /// You are responsible for managing history, a simple linear solution could look like:
    /// ```ignore
    /// let mut history = vec![];
    ///
    /// loop {
    ///     let prompt = "Create GPT-67, make no mistakes";
    ///     let mut stream = agent.stream_chat(prompt, &history).await;
    ///
    ///     while let Some(msg) = stream.next().await {
    ///         match msg {
    ///             Ok(MultiTurnStreamItem::FinalResponse(fin)) => {
    ///                 history.extend_from_slice(fin.history().unwrap_or_default());
    ///                 break;
    ///             }
    ///             Ok(_other) => { /* Do something with this chunk */ }
    ///             Err(e) => return Err(e.into()),
    ///         }
    ///     }
    /// }
    /// ```
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, Self::Hook>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
525
/// Trait for low-level streaming completion interface.
///
/// Unlike [`StreamingPrompt`] and [`StreamingChat`], this does not return a streaming
/// request directly; it yields a [`CompletionRequestBuilder`].
pub trait StreamingCompletion<M: CompletionModel> {
    /// Generate a streaming completion from a request.
    ///
    /// Resolves to a [`CompletionRequestBuilder`] built from `prompt` and `chat_history`, or
    /// a [`CompletionError`]. NOTE(review): presumably the caller finalizes and streams the
    /// builder; confirm against implementors.
    fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
538
539pub(crate) struct StreamingResultDyn<R: Clone + Unpin + GetTokenUsage> {
540    pub(crate) inner: StreamingResult<R>,
541}
542
543fn map_raw_streaming_choice<R>(
544    chunk: RawStreamingChoice<R>,
545) -> RawStreamingChoice<FinalCompletionResponse>
546where
547    R: Clone + Unpin + GetTokenUsage,
548{
549    match chunk {
550        RawStreamingChoice::FinalResponse(res) => {
551            RawStreamingChoice::FinalResponse(FinalCompletionResponse {
552                usage: res.token_usage(),
553            })
554        }
555        RawStreamingChoice::Message(m) => RawStreamingChoice::Message(m),
556        RawStreamingChoice::ToolCallDelta {
557            id,
558            internal_call_id,
559            content,
560        } => RawStreamingChoice::ToolCallDelta {
561            id,
562            internal_call_id,
563            content,
564        },
565        RawStreamingChoice::Reasoning { id, content } => {
566            RawStreamingChoice::Reasoning { id, content }
567        }
568        RawStreamingChoice::ReasoningDelta { id, reasoning } => {
569            RawStreamingChoice::ReasoningDelta { id, reasoning }
570        }
571        RawStreamingChoice::ToolCall(tool_call) => RawStreamingChoice::ToolCall(tool_call),
572        RawStreamingChoice::MessageId(id) => RawStreamingChoice::MessageId(id),
573    }
574}
575
576impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
577    type Item = Result<RawStreamingChoice<FinalCompletionResponse>, CompletionError>;
578
579    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
580        let stream = self.get_mut();
581
582        match stream.inner.as_mut().poll_next(cx) {
583            Poll::Pending => Poll::Pending,
584            Poll::Ready(None) => Poll::Ready(None),
585            Poll::Ready(Some(item)) => Poll::Ready(Some(item.map(map_raw_streaming_choice::<R>))),
586        }
587    }
588}
589
590/// A helper function to stream a completion request to stdout.
591/// Tool call deltas are ignored as tool calls are generally much easier to handle when received in their entirety rather than using deltas.
592pub async fn stream_to_stdout<M>(
593    agent: &'static Agent<M>,
594    stream: &mut StreamingCompletionResponse<M::StreamingResponse>,
595) -> Result<(), std::io::Error>
596where
597    M: CompletionModel,
598{
599    let mut is_reasoning = false;
600    print!("Response: ");
601    while let Some(chunk) = stream.next().await {
602        match chunk {
603            Ok(StreamedAssistantContent::Text(text)) => {
604                if is_reasoning {
605                    is_reasoning = false;
606                    println!("\n---\n");
607                }
608                print!("{}", text.text);
609                std::io::Write::flush(&mut std::io::stdout())?;
610            }
611            Ok(StreamedAssistantContent::ToolCall {
612                tool_call,
613                internal_call_id: _,
614            }) => {
615                let res = agent
616                    .tool_server_handle
617                    .call_tool(
618                        &tool_call.function.name,
619                        &tool_call.function.arguments.to_string(),
620                    )
621                    .await
622                    .map_err(|x| std::io::Error::other(x.to_string()))?;
623                println!("\nResult: {res}");
624            }
625            Ok(StreamedAssistantContent::Final(res)) => {
626                let json_res = serde_json::to_string_pretty(&res).unwrap();
627                println!();
628                tracing::info!("Final result: {json_res}");
629            }
630            Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
631                if !is_reasoning {
632                    is_reasoning = true;
633                    println!();
634                    println!("Thinking: ");
635                }
636                let reasoning = reasoning.display_text();
637
638                print!("{reasoning}");
639                std::io::Write::flush(&mut std::io::stdout())?;
640            }
641            Err(e) => {
642                if e.to_string().contains("aborted") {
643                    println!("\nStream cancelled.");
644                    break;
645                }
646                eprintln!("Error: {e}");
647                break;
648            }
649            _ => {}
650        }
651    }
652
653    println!(); // New line after streaming completes
654
655    Ok(())
656}
657
658// Test module
659#[cfg(test)]
660mod tests {
661    use std::time::Duration;
662
663    use super::*;
664    use async_stream::stream;
665    use tokio::time::sleep;
666
667    #[derive(Debug, Clone)]
668    pub struct MockResponse {
669        #[allow(dead_code)]
670        token_count: u32,
671    }
672
673    impl GetTokenUsage for MockResponse {
674        fn token_usage(&self) -> Option<crate::completion::Usage> {
675            let mut usage = Usage::new();
676            usage.total_tokens = 15;
677            Some(usage)
678        }
679    }
680
681    #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
682    fn to_stream_result(
683        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
684        + Send
685        + 'static,
686    ) -> StreamingResult<MockResponse> {
687        Box::pin(stream)
688    }
689
690    #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
691    fn to_stream_result(
692        stream: impl futures::Stream<Item = Result<RawStreamingChoice<MockResponse>, CompletionError>>
693        + 'static,
694    ) -> StreamingResult<MockResponse> {
695        Box::pin(stream)
696    }
697
698    fn create_mock_stream() -> StreamingCompletionResponse<MockResponse> {
699        let stream = stream! {
700            yield Ok(RawStreamingChoice::Message("hello 1".to_string()));
701            sleep(Duration::from_millis(100)).await;
702            yield Ok(RawStreamingChoice::Message("hello 2".to_string()));
703            sleep(Duration::from_millis(100)).await;
704            yield Ok(RawStreamingChoice::Message("hello 3".to_string()));
705            sleep(Duration::from_millis(100)).await;
706            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 15 }));
707        };
708
709        StreamingCompletionResponse::stream(to_stream_result(stream))
710    }
711
712    fn create_reasoning_stream() -> StreamingCompletionResponse<MockResponse> {
713        let stream = stream! {
714            yield Ok(RawStreamingChoice::Reasoning {
715                id: Some("rs_1".to_string()),
716                content: ReasoningContent::Text {
717                    text: "step one".to_string(),
718                    signature: Some("sig_1".to_string()),
719                },
720            });
721            yield Ok(RawStreamingChoice::Message("final answer".to_string()));
722            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 5 }));
723        };
724
725        StreamingCompletionResponse::stream(to_stream_result(stream))
726    }
727
728    fn create_reasoning_only_stream() -> StreamingCompletionResponse<MockResponse> {
729        let stream = stream! {
730            yield Ok(RawStreamingChoice::Reasoning {
731                id: Some("rs_only".to_string()),
732                content: ReasoningContent::Summary("hidden summary".to_string()),
733            });
734            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 2 }));
735        };
736
737        StreamingCompletionResponse::stream(to_stream_result(stream))
738    }
739
740    fn create_interleaved_stream() -> StreamingCompletionResponse<MockResponse> {
741        let stream = stream! {
742            yield Ok(RawStreamingChoice::Reasoning {
743                id: Some("rs_interleaved".to_string()),
744                content: ReasoningContent::Text {
745                    text: "chain-of-thought".to_string(),
746                    signature: None,
747                },
748            });
749            yield Ok(RawStreamingChoice::Message("final-text".to_string()));
750            yield Ok(RawStreamingChoice::ToolCall(
751                RawStreamingToolCall::new(
752                    "tool_1".to_string(),
753                    "mock_tool".to_string(),
754                    serde_json::json!({"arg": 1}),
755                ),
756            ));
757            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
758        };
759
760        StreamingCompletionResponse::stream(to_stream_result(stream))
761    }
762
763    fn create_text_tool_text_stream() -> StreamingCompletionResponse<MockResponse> {
764        let stream = stream! {
765            yield Ok(RawStreamingChoice::Message("first".to_string()));
766            yield Ok(RawStreamingChoice::ToolCall(
767                RawStreamingToolCall::new(
768                    "tool_split".to_string(),
769                    "mock_tool".to_string(),
770                    serde_json::json!({"arg": "x"}),
771                ),
772            ));
773            yield Ok(RawStreamingChoice::Message("second".to_string()));
774            yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 3 }));
775        };
776
777        StreamingCompletionResponse::stream(to_stream_result(stream))
778    }
779
780    #[tokio::test]
781    async fn test_stream_cancellation() {
782        let mut stream = create_mock_stream();
783
784        println!("Response: ");
785        let mut chunk_count = 0;
786        while let Some(chunk) = stream.next().await {
787            match chunk {
788                Ok(StreamedAssistantContent::Text(text)) => {
789                    print!("{}", text.text);
790                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
791                    chunk_count += 1;
792                }
793                Ok(StreamedAssistantContent::ToolCall {
794                    tool_call,
795                    internal_call_id,
796                }) => {
797                    println!("\nTool Call: {tool_call:?}, internal_call_id={internal_call_id:?}");
798                    chunk_count += 1;
799                }
800                Ok(StreamedAssistantContent::ToolCallDelta {
801                    id,
802                    internal_call_id,
803                    content,
804                }) => {
805                    println!(
806                        "\nTool Call delta: id={id:?}, internal_call_id={internal_call_id:?}, content={content:?}"
807                    );
808                    chunk_count += 1;
809                }
810                Ok(StreamedAssistantContent::Final(res)) => {
811                    println!("\nFinal response: {res:?}");
812                }
813                Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
814                    let reasoning = reasoning.display_text();
815                    print!("{reasoning}");
816                    std::io::Write::flush(&mut std::io::stdout()).unwrap();
817                }
818                Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
819                    println!("Reasoning delta: {reasoning}");
820                    chunk_count += 1;
821                }
822                Err(e) => {
823                    eprintln!("Error: {e:?}");
824                    break;
825                }
826            }
827
828            if chunk_count >= 2 {
829                println!("\nCancelling stream...");
830                stream.cancel();
831                println!("Stream cancelled.");
832                break;
833            }
834        }
835
836        let next_chunk = stream.next().await;
837        assert!(
838            next_chunk.is_none(),
839            "Expected no further chunks after cancellation, got {next_chunk:?}"
840        );
841    }
842
843    #[tokio::test]
844    async fn test_stream_pause_resume() {
845        let stream = create_mock_stream();
846
847        // Test pause
848        stream.pause();
849        assert!(stream.is_paused());
850
851        // Test resume
852        stream.resume();
853        assert!(!stream.is_paused());
854    }
855
856    #[tokio::test]
857    async fn test_stream_aggregates_reasoning_content() {
858        let mut stream = create_reasoning_stream();
859        while stream.next().await.is_some() {}
860
861        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
862
863        assert!(choice_items.iter().any(|item| matches!(
864            item,
865            AssistantContent::Reasoning(Reasoning {
866                id: Some(id),
867                content
868            }) if id == "rs_1"
869                && matches!(
870                    content.first(),
871                    Some(ReasoningContent::Text {
872                        text,
873                        signature: Some(signature)
874                    }) if text == "step one" && signature == "sig_1"
875                )
876        )));
877    }
878
879    #[tokio::test]
880    async fn test_stream_reasoning_only_does_not_inject_empty_text() {
881        let mut stream = create_reasoning_only_stream();
882        while stream.next().await.is_some() {}
883
884        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
885        assert_eq!(choice_items.len(), 1);
886        assert!(matches!(
887            choice_items.first(),
888            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_only"
889        ));
890    }
891
892    #[tokio::test]
893    async fn test_stream_aggregates_assistant_items_in_arrival_order() {
894        let mut stream = create_interleaved_stream();
895        while stream.next().await.is_some() {}
896
897        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
898        assert_eq!(choice_items.len(), 3);
899        assert!(matches!(
900            choice_items.first(),
901            Some(AssistantContent::Reasoning(Reasoning { id: Some(id), .. })) if id == "rs_interleaved"
902        ));
903        assert!(matches!(
904            choice_items.get(1),
905            Some(AssistantContent::Text(Text { text })) if text == "final-text"
906        ));
907        assert!(matches!(
908            choice_items.get(2),
909            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_1"
910        ));
911    }
912
913    #[tokio::test]
914    async fn test_stream_keeps_non_contiguous_text_chunks_split_by_tool_call() {
915        let mut stream = create_text_tool_text_stream();
916        while stream.next().await.is_some() {}
917
918        let choice_items: Vec<AssistantContent> = stream.choice.clone().into_iter().collect();
919        assert_eq!(choice_items.len(), 3);
920        assert!(matches!(
921            choice_items.first(),
922            Some(AssistantContent::Text(Text { text })) if text == "first"
923        ));
924        assert!(matches!(
925            choice_items.get(1),
926            Some(AssistantContent::ToolCall(ToolCall { id, .. })) if id == "tool_split"
927        ));
928        assert!(matches!(
929            choice_items.get(2),
930            Some(AssistantContent::Text(Text { text })) if text == "second"
931        ));
932    }
933}
934
/// A single item yielded while streaming an assistant response from a provider.
///
/// Besides plain text, an item can be a tool call, an incremental tool-call
/// delta, a reasoning block, an incremental reasoning delta, or the provider's
/// final response payload (`R`, e.g. the final usage response).
///
/// This enum is `#[serde(untagged)]`: no variant tag is written when
/// serializing, and deserialization tries the variants in declaration order,
/// keeping the first one whose shape matches. Variant order is therefore part
/// of the wire behavior, so do not reorder variants casually.
///
/// NOTE(review): serde ignores unknown fields by default, so a JSON object that
/// carries a `text` field may deserialize as [`Self::Text`] even when it was
/// serialized from a later variant (e.g. `Final`). Confirm that round-tripping
/// through serde is not relied upon.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedAssistantContent<R> {
    /// A chunk of assistant text.
    Text(Text),
    /// A tool call requested by the model.
    ToolCall {
        /// The tool call itself.
        tool_call: ToolCall,
        /// Rig-generated unique identifier for this tool call.
        /// Use this to correlate with ToolCallDelta events.
        internal_call_id: String,
    },
    /// An incremental fragment of a tool call that is still being streamed.
    ToolCallDelta {
        /// Provider-supplied tool call ID.
        id: String,
        /// Rig-generated unique identifier for this tool call.
        /// Use it to correlate this delta with the corresponding `ToolCall` event.
        internal_call_id: String,
        /// The payload carried by this delta.
        content: ToolCallDeltaContent,
    },
    /// A reasoning item emitted by the model.
    Reasoning(Reasoning),
    /// An incremental piece of reasoning text.
    ReasoningDelta {
        /// Optional identifier sent with the delta (presumably the id of the
        /// reasoning item it extends; confirm per provider).
        id: Option<String>,
        /// The reasoning text carried by this delta.
        reasoning: String,
    },
    /// The provider-specific final response payload.
    Final(R),
}
960
961impl<R> StreamedAssistantContent<R>
962where
963    R: Clone + Unpin,
964{
965    pub fn text(text: &str) -> Self {
966        Self::Text(Text {
967            text: text.to_string(),
968        })
969    }
970
971    pub fn final_response(res: R) -> Self {
972        Self::Final(res)
973    }
974}
975
/// Streamed user content.
///
/// This content is primarily used to represent tool results from tool calls
/// made during a multi-turn (multi-step) agent prompt.
///
/// The enum is `#[serde(untagged)]`, so only the variant's fields
/// (`tool_result`, `internal_call_id`) are serialized, not the variant name.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum StreamedUserContent {
    /// The result of a tool call.
    ToolResult {
        /// The tool's result.
        tool_result: ToolResult,
        /// Rig-generated unique identifier for the tool call this result
        /// belongs to. Use this to correlate with the originating
        /// [`StreamedAssistantContent::ToolCall::internal_call_id`].
        internal_call_id: String,
    },
}
988
989impl StreamedUserContent {
990    pub fn tool_result(tool_result: ToolResult, internal_call_id: String) -> Self {
991        Self::ToolResult {
992            tool_result,
993            internal_call_id,
994        }
995    }
996}