Skip to main content

rig_core/test_utils/
streaming.rs

1//! Streaming helpers for [`MockCompletionModel`](super::MockCompletionModel).
2
3use crate::{
4    completion::{CompletionError, GetTokenUsage, Usage},
5    message::ReasoningContent,
6    streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent},
7};
8use serde::{Deserialize, Serialize};
9
10/// Raw mock response used by completion and streaming test utilities.
11#[derive(Clone, Debug, Default, Deserialize, Serialize)]
12pub struct MockResponse {
13    usage: Usage,
14}
15
16impl MockResponse {
17    /// Create a mock raw response without token usage (zero-valued usage,
18    /// [`Usage`]'s documented sentinel for missing provider metrics).
19    pub fn new() -> Self {
20        Self {
21            usage: Usage::new(),
22        }
23    }
24
25    /// Create a mock raw response carrying token usage.
26    pub fn with_usage(usage: Usage) -> Self {
27        Self { usage }
28    }
29
30    /// Create a mock raw response whose usage has only `total_tokens` set.
31    pub fn with_total_tokens(total_tokens: u64) -> Self {
32        let mut usage = Usage::new();
33        usage.total_tokens = total_tokens;
34        Self::with_usage(usage)
35    }
36}
37
38impl GetTokenUsage for MockResponse {
39    fn token_usage(&self) -> Usage {
40        self.usage
41    }
42}
43
44/// Scripted streaming event yielded by [`MockCompletionModel`](super::MockCompletionModel).
45#[derive(Clone, Debug)]
46pub enum MockStreamEvent {
47    /// Text chunk.
48    Text(String),
49    /// Start a new text content block with optional provider metadata.
50    TextStart {
51        additional_params: Option<serde_json::Value>,
52    },
53    /// Provider-specific metadata for the current text content block.
54    TextAdditionalParams(serde_json::Value),
55    /// Complete tool call event.
56    ToolCall {
57        id: String,
58        name: String,
59        arguments: serde_json::Value,
60        call_id: Option<String>,
61    },
62    /// Tool call delta event.
63    ToolCallDelta {
64        id: String,
65        internal_call_id: String,
66        content: ToolCallDeltaContent,
67    },
68    /// Complete reasoning event.
69    Reasoning {
70        id: Option<String>,
71        content: ReasoningContent,
72    },
73    /// Reasoning delta event.
74    ReasoningDelta {
75        id: Option<String>,
76        reasoning: String,
77    },
78    /// Provider-assigned message ID.
79    MessageId(String),
80    /// Final raw response carrying optional usage.
81    FinalResponse(MockResponse),
82    /// Stream error.
83    Error(MockError),
84}
85
86use super::completion::MockError;
87
88impl MockStreamEvent {
89    /// Create a text chunk.
90    pub fn text(text: impl Into<String>) -> Self {
91        Self::Text(text.into())
92    }
93
94    /// Start a new text content block.
95    pub fn text_start(additional_params: Option<serde_json::Value>) -> Self {
96        Self::TextStart { additional_params }
97    }
98
99    /// Add provider-specific metadata to the current text content block.
100    pub fn text_additional_params(additional_params: serde_json::Value) -> Self {
101        Self::TextAdditionalParams(additional_params)
102    }
103
104    /// Create a complete tool call event.
105    pub fn tool_call(
106        id: impl Into<String>,
107        name: impl Into<String>,
108        arguments: serde_json::Value,
109    ) -> Self {
110        Self::ToolCall {
111            id: id.into(),
112            name: name.into(),
113            arguments,
114            call_id: None,
115        }
116    }
117
118    /// Attach a provider-specific call ID to a complete tool call event.
119    pub fn with_call_id(mut self, call_id: impl Into<String>) -> Self {
120        if let Self::ToolCall { call_id: id, .. } = &mut self {
121            *id = Some(call_id.into());
122        }
123        self
124    }
125
126    /// Create a tool call name delta.
127    pub fn tool_call_name_delta(
128        id: impl Into<String>,
129        internal_call_id: impl Into<String>,
130        name: impl Into<String>,
131    ) -> Self {
132        Self::ToolCallDelta {
133            id: id.into(),
134            internal_call_id: internal_call_id.into(),
135            content: ToolCallDeltaContent::Name(name.into()),
136        }
137    }
138
139    /// Create a tool call arguments delta.
140    pub fn tool_call_arguments_delta(
141        id: impl Into<String>,
142        internal_call_id: impl Into<String>,
143        arguments: impl Into<String>,
144    ) -> Self {
145        Self::ToolCallDelta {
146            id: id.into(),
147            internal_call_id: internal_call_id.into(),
148            content: ToolCallDeltaContent::Delta(arguments.into()),
149        }
150    }
151
152    /// Create a complete reasoning event.
153    pub fn reasoning(reasoning: impl Into<String>) -> Self {
154        Self::Reasoning {
155            id: None,
156            content: ReasoningContent::Text {
157                text: reasoning.into(),
158                signature: None,
159            },
160        }
161    }
162
163    /// Attach a provider-specific reasoning ID to a complete reasoning event.
164    pub fn with_reasoning_id(mut self, reasoning_id: impl Into<String>) -> Self {
165        if let Self::Reasoning { id, .. } = &mut self {
166            *id = Some(reasoning_id.into());
167        }
168        self
169    }
170
171    /// Create a reasoning delta event.
172    pub fn reasoning_delta(id: Option<impl Into<String>>, reasoning: impl Into<String>) -> Self {
173        Self::ReasoningDelta {
174            id: id.map(Into::into),
175            reasoning: reasoning.into(),
176        }
177    }
178
179    /// Create a provider-assigned message ID event.
180    pub fn message_id(id: impl Into<String>) -> Self {
181        Self::MessageId(id.into())
182    }
183
184    /// Create a final response event with usage.
185    pub fn final_response(usage: Usage) -> Self {
186        Self::FinalResponse(MockResponse::with_usage(usage))
187    }
188
189    /// Create a final response event with default zero usage.
190    pub fn final_response_with_default_usage() -> Self {
191        Self::FinalResponse(MockResponse::with_usage(Usage::new()))
192    }
193
194    /// Create a final response event whose usage has only `total_tokens` set.
195    pub fn final_response_with_total_tokens(total_tokens: u64) -> Self {
196        Self::FinalResponse(MockResponse::with_total_tokens(total_tokens))
197    }
198
199    /// Create a stream error event.
200    pub fn error(message: impl Into<String>) -> Self {
201        Self::Error(MockError::provider(message))
202    }
203
204    pub(crate) fn into_raw_choice(
205        self,
206    ) -> Result<RawStreamingChoice<MockResponse>, CompletionError> {
207        match self {
208            Self::Text(text) => Ok(RawStreamingChoice::Message(text)),
209            Self::TextStart { additional_params } => {
210                Ok(RawStreamingChoice::TextStart { additional_params })
211            }
212            Self::TextAdditionalParams(additional_params) => {
213                Ok(RawStreamingChoice::TextAdditionalParams(additional_params))
214            }
215            Self::ToolCall {
216                id,
217                name,
218                arguments,
219                call_id,
220            } => {
221                let mut tool_call = RawStreamingToolCall::new(id, name, arguments);
222                if let Some(call_id) = call_id {
223                    tool_call = tool_call.with_call_id(call_id);
224                }
225                Ok(RawStreamingChoice::ToolCall(tool_call))
226            }
227            Self::ToolCallDelta {
228                id,
229                internal_call_id,
230                content,
231            } => Ok(RawStreamingChoice::ToolCallDelta {
232                id,
233                internal_call_id,
234                content,
235            }),
236            Self::Reasoning { id, content } => Ok(RawStreamingChoice::Reasoning { id, content }),
237            Self::ReasoningDelta { id, reasoning } => {
238                Ok(RawStreamingChoice::ReasoningDelta { id, reasoning })
239            }
240            Self::MessageId(id) => Ok(RawStreamingChoice::MessageId(id)),
241            Self::FinalResponse(response) => Ok(RawStreamingChoice::FinalResponse(response)),
242            Self::Error(error) => Err(error.into_completion_error()),
243        }
244    }
245}