Skip to main content

rig_core/test_utils/
streaming.rs

1//! Streaming helpers for [`MockCompletionModel`](super::MockCompletionModel).
2
3use crate::{
4    completion::{CompletionError, GetTokenUsage, Usage},
5    streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent},
6};
7use serde::{Deserialize, Serialize};
8
9/// Raw mock response used by completion and streaming test utilities.
10#[derive(Clone, Debug, Default, Deserialize, Serialize)]
11pub struct MockResponse {
12    usage: Option<Usage>,
13}
14
15impl MockResponse {
16    /// Create a mock raw response without token usage.
17    pub fn new() -> Self {
18        Self { usage: None }
19    }
20
21    /// Create a mock raw response carrying token usage.
22    pub fn with_usage(usage: Usage) -> Self {
23        Self { usage: Some(usage) }
24    }
25
26    /// Create a mock raw response whose usage has only `total_tokens` set.
27    pub fn with_total_tokens(total_tokens: u64) -> Self {
28        let mut usage = Usage::new();
29        usage.total_tokens = total_tokens;
30        Self::with_usage(usage)
31    }
32}
33
34impl GetTokenUsage for MockResponse {
35    fn token_usage(&self) -> Option<Usage> {
36        self.usage
37    }
38}
39
/// Scripted streaming event yielded by [`MockCompletionModel`](super::MockCompletionModel).
///
/// Every variant except `Error` is converted one-to-one into a raw streaming
/// choice; `Error` is instead surfaced as an `Err(CompletionError)`.
#[derive(Clone, Debug)]
pub enum MockStreamEvent {
    /// Text chunk; converted into `RawStreamingChoice::Message`.
    Text(String),
    /// Complete tool call event; converted into `RawStreamingChoice::ToolCall`.
    ToolCall {
        /// Tool call identifier passed to `RawStreamingToolCall::new`.
        id: String,
        /// Tool name passed to `RawStreamingToolCall::new`.
        name: String,
        /// Tool arguments as a JSON value.
        arguments: serde_json::Value,
        /// Optional provider-specific call ID, attached to the raw tool call
        /// only when present (see `MockStreamEvent::with_call_id`).
        call_id: Option<String>,
    },
    /// Tool call delta event; its fields are forwarded unchanged into
    /// `RawStreamingChoice::ToolCallDelta`.
    ToolCallDelta {
        /// Identifier of the tool call this delta belongs to.
        id: String,
        /// Internal call identifier, forwarded as-is.
        internal_call_id: String,
        /// Delta payload, e.g. a tool name (`ToolCallDeltaContent::Name`) or an
        /// arguments fragment (`ToolCallDeltaContent::Delta`).
        content: ToolCallDeltaContent,
    },
    /// Provider-assigned message ID; converted into `RawStreamingChoice::MessageId`.
    MessageId(String),
    /// Final raw response carrying optional usage; converted into
    /// `RawStreamingChoice::FinalResponse`.
    FinalResponse(MockResponse),
    /// Stream error; surfaces as `Err(CompletionError)` rather than as a stream item.
    Error(MockError),
}
65
66use super::completion::MockError;
67
68impl MockStreamEvent {
69    /// Create a text chunk.
70    pub fn text(text: impl Into<String>) -> Self {
71        Self::Text(text.into())
72    }
73
74    /// Create a complete tool call event.
75    pub fn tool_call(
76        id: impl Into<String>,
77        name: impl Into<String>,
78        arguments: serde_json::Value,
79    ) -> Self {
80        Self::ToolCall {
81            id: id.into(),
82            name: name.into(),
83            arguments,
84            call_id: None,
85        }
86    }
87
88    /// Attach a provider-specific call ID to a complete tool call event.
89    pub fn with_call_id(mut self, call_id: impl Into<String>) -> Self {
90        if let Self::ToolCall { call_id: id, .. } = &mut self {
91            *id = Some(call_id.into());
92        }
93        self
94    }
95
96    /// Create a tool call name delta.
97    pub fn tool_call_name_delta(
98        id: impl Into<String>,
99        internal_call_id: impl Into<String>,
100        name: impl Into<String>,
101    ) -> Self {
102        Self::ToolCallDelta {
103            id: id.into(),
104            internal_call_id: internal_call_id.into(),
105            content: ToolCallDeltaContent::Name(name.into()),
106        }
107    }
108
109    /// Create a tool call arguments delta.
110    pub fn tool_call_arguments_delta(
111        id: impl Into<String>,
112        internal_call_id: impl Into<String>,
113        arguments: impl Into<String>,
114    ) -> Self {
115        Self::ToolCallDelta {
116            id: id.into(),
117            internal_call_id: internal_call_id.into(),
118            content: ToolCallDeltaContent::Delta(arguments.into()),
119        }
120    }
121
122    /// Create a provider-assigned message ID event.
123    pub fn message_id(id: impl Into<String>) -> Self {
124        Self::MessageId(id.into())
125    }
126
127    /// Create a final response event with usage.
128    pub fn final_response(usage: Usage) -> Self {
129        Self::FinalResponse(MockResponse::with_usage(usage))
130    }
131
132    /// Create a final response event with default zero usage.
133    pub fn final_response_with_default_usage() -> Self {
134        Self::FinalResponse(MockResponse::with_usage(Usage::new()))
135    }
136
137    /// Create a final response event whose usage has only `total_tokens` set.
138    pub fn final_response_with_total_tokens(total_tokens: u64) -> Self {
139        Self::FinalResponse(MockResponse::with_total_tokens(total_tokens))
140    }
141
142    /// Create a stream error event.
143    pub fn error(message: impl Into<String>) -> Self {
144        Self::Error(MockError::provider(message))
145    }
146
147    pub(crate) fn into_raw_choice(
148        self,
149    ) -> Result<RawStreamingChoice<MockResponse>, CompletionError> {
150        match self {
151            Self::Text(text) => Ok(RawStreamingChoice::Message(text)),
152            Self::ToolCall {
153                id,
154                name,
155                arguments,
156                call_id,
157            } => {
158                let mut tool_call = RawStreamingToolCall::new(id, name, arguments);
159                if let Some(call_id) = call_id {
160                    tool_call = tool_call.with_call_id(call_id);
161                }
162                Ok(RawStreamingChoice::ToolCall(tool_call))
163            }
164            Self::ToolCallDelta {
165                id,
166                internal_call_id,
167                content,
168            } => Ok(RawStreamingChoice::ToolCallDelta {
169                id,
170                internal_call_id,
171                content,
172            }),
173            Self::MessageId(id) => Ok(RawStreamingChoice::MessageId(id)),
174            Self::FinalResponse(response) => Ok(RawStreamingChoice::FinalResponse(response)),
175            Self::Error(error) => Err(error.into_completion_error()),
176        }
177    }
178}