Skip to main content

rig_core/test_utils/
streaming.rs

1//! Streaming helpers for [`MockCompletionModel`](super::MockCompletionModel).
2
3use crate::{
4    completion::{CompletionError, GetTokenUsage, Usage},
5    message::ReasoningContent,
6    streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent},
7};
8use serde::{Deserialize, Serialize};
9
10/// Raw mock response used by completion and streaming test utilities.
11#[derive(Clone, Debug, Default, Deserialize, Serialize)]
12pub struct MockResponse {
13    usage: Option<Usage>,
14}
15
16impl MockResponse {
17    /// Create a mock raw response without token usage.
18    pub fn new() -> Self {
19        Self { usage: None }
20    }
21
22    /// Create a mock raw response carrying token usage.
23    pub fn with_usage(usage: Usage) -> Self {
24        Self { usage: Some(usage) }
25    }
26
27    /// Create a mock raw response whose usage has only `total_tokens` set.
28    pub fn with_total_tokens(total_tokens: u64) -> Self {
29        let mut usage = Usage::new();
30        usage.total_tokens = total_tokens;
31        Self::with_usage(usage)
32    }
33}
34
35impl GetTokenUsage for MockResponse {
36    fn token_usage(&self) -> Option<Usage> {
37        self.usage
38    }
39}
40
41/// Scripted streaming event yielded by [`MockCompletionModel`](super::MockCompletionModel).
42#[derive(Clone, Debug)]
43pub enum MockStreamEvent {
44    /// Text chunk.
45    Text(String),
46    /// Start a new text content block with optional provider metadata.
47    TextStart {
48        additional_params: Option<serde_json::Value>,
49    },
50    /// Provider-specific metadata for the current text content block.
51    TextAdditionalParams(serde_json::Value),
52    /// Complete tool call event.
53    ToolCall {
54        id: String,
55        name: String,
56        arguments: serde_json::Value,
57        call_id: Option<String>,
58    },
59    /// Tool call delta event.
60    ToolCallDelta {
61        id: String,
62        internal_call_id: String,
63        content: ToolCallDeltaContent,
64    },
65    /// Complete reasoning event.
66    Reasoning {
67        id: Option<String>,
68        content: ReasoningContent,
69    },
70    /// Reasoning delta event.
71    ReasoningDelta {
72        id: Option<String>,
73        reasoning: String,
74    },
75    /// Provider-assigned message ID.
76    MessageId(String),
77    /// Final raw response carrying optional usage.
78    FinalResponse(MockResponse),
79    /// Stream error.
80    Error(MockError),
81}
82
83use super::completion::MockError;
84
85impl MockStreamEvent {
86    /// Create a text chunk.
87    pub fn text(text: impl Into<String>) -> Self {
88        Self::Text(text.into())
89    }
90
91    /// Start a new text content block.
92    pub fn text_start(additional_params: Option<serde_json::Value>) -> Self {
93        Self::TextStart { additional_params }
94    }
95
96    /// Add provider-specific metadata to the current text content block.
97    pub fn text_additional_params(additional_params: serde_json::Value) -> Self {
98        Self::TextAdditionalParams(additional_params)
99    }
100
101    /// Create a complete tool call event.
102    pub fn tool_call(
103        id: impl Into<String>,
104        name: impl Into<String>,
105        arguments: serde_json::Value,
106    ) -> Self {
107        Self::ToolCall {
108            id: id.into(),
109            name: name.into(),
110            arguments,
111            call_id: None,
112        }
113    }
114
115    /// Attach a provider-specific call ID to a complete tool call event.
116    pub fn with_call_id(mut self, call_id: impl Into<String>) -> Self {
117        if let Self::ToolCall { call_id: id, .. } = &mut self {
118            *id = Some(call_id.into());
119        }
120        self
121    }
122
123    /// Create a tool call name delta.
124    pub fn tool_call_name_delta(
125        id: impl Into<String>,
126        internal_call_id: impl Into<String>,
127        name: impl Into<String>,
128    ) -> Self {
129        Self::ToolCallDelta {
130            id: id.into(),
131            internal_call_id: internal_call_id.into(),
132            content: ToolCallDeltaContent::Name(name.into()),
133        }
134    }
135
136    /// Create a tool call arguments delta.
137    pub fn tool_call_arguments_delta(
138        id: impl Into<String>,
139        internal_call_id: impl Into<String>,
140        arguments: impl Into<String>,
141    ) -> Self {
142        Self::ToolCallDelta {
143            id: id.into(),
144            internal_call_id: internal_call_id.into(),
145            content: ToolCallDeltaContent::Delta(arguments.into()),
146        }
147    }
148
149    /// Create a complete reasoning event.
150    pub fn reasoning(reasoning: impl Into<String>) -> Self {
151        Self::Reasoning {
152            id: None,
153            content: ReasoningContent::Text {
154                text: reasoning.into(),
155                signature: None,
156            },
157        }
158    }
159
160    /// Attach a provider-specific reasoning ID to a complete reasoning event.
161    pub fn with_reasoning_id(mut self, reasoning_id: impl Into<String>) -> Self {
162        if let Self::Reasoning { id, .. } = &mut self {
163            *id = Some(reasoning_id.into());
164        }
165        self
166    }
167
168    /// Create a reasoning delta event.
169    pub fn reasoning_delta(id: Option<impl Into<String>>, reasoning: impl Into<String>) -> Self {
170        Self::ReasoningDelta {
171            id: id.map(Into::into),
172            reasoning: reasoning.into(),
173        }
174    }
175
176    /// Create a provider-assigned message ID event.
177    pub fn message_id(id: impl Into<String>) -> Self {
178        Self::MessageId(id.into())
179    }
180
181    /// Create a final response event with usage.
182    pub fn final_response(usage: Usage) -> Self {
183        Self::FinalResponse(MockResponse::with_usage(usage))
184    }
185
186    /// Create a final response event with default zero usage.
187    pub fn final_response_with_default_usage() -> Self {
188        Self::FinalResponse(MockResponse::with_usage(Usage::new()))
189    }
190
191    /// Create a final response event whose usage has only `total_tokens` set.
192    pub fn final_response_with_total_tokens(total_tokens: u64) -> Self {
193        Self::FinalResponse(MockResponse::with_total_tokens(total_tokens))
194    }
195
196    /// Create a stream error event.
197    pub fn error(message: impl Into<String>) -> Self {
198        Self::Error(MockError::provider(message))
199    }
200
201    pub(crate) fn into_raw_choice(
202        self,
203    ) -> Result<RawStreamingChoice<MockResponse>, CompletionError> {
204        match self {
205            Self::Text(text) => Ok(RawStreamingChoice::Message(text)),
206            Self::TextStart { additional_params } => {
207                Ok(RawStreamingChoice::TextStart { additional_params })
208            }
209            Self::TextAdditionalParams(additional_params) => {
210                Ok(RawStreamingChoice::TextAdditionalParams(additional_params))
211            }
212            Self::ToolCall {
213                id,
214                name,
215                arguments,
216                call_id,
217            } => {
218                let mut tool_call = RawStreamingToolCall::new(id, name, arguments);
219                if let Some(call_id) = call_id {
220                    tool_call = tool_call.with_call_id(call_id);
221                }
222                Ok(RawStreamingChoice::ToolCall(tool_call))
223            }
224            Self::ToolCallDelta {
225                id,
226                internal_call_id,
227                content,
228            } => Ok(RawStreamingChoice::ToolCallDelta {
229                id,
230                internal_call_id,
231                content,
232            }),
233            Self::Reasoning { id, content } => Ok(RawStreamingChoice::Reasoning { id, content }),
234            Self::ReasoningDelta { id, reasoning } => {
235                Ok(RawStreamingChoice::ReasoningDelta { id, reasoning })
236            }
237            Self::MessageId(id) => Ok(RawStreamingChoice::MessageId(id)),
238            Self::FinalResponse(response) => Ok(RawStreamingChoice::FinalResponse(response)),
239            Self::Error(error) => Err(error.into_completion_error()),
240        }
241    }
242}