rig_core/test_utils/
streaming.rs1use crate::{
4 completion::{CompletionError, GetTokenUsage, Usage},
5 streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent},
6};
7use serde::{Deserialize, Serialize};
8
9#[derive(Clone, Debug, Default, Deserialize, Serialize)]
11pub struct MockResponse {
12 usage: Option<Usage>,
13}
14
15impl MockResponse {
16 pub fn new() -> Self {
18 Self { usage: None }
19 }
20
21 pub fn with_usage(usage: Usage) -> Self {
23 Self { usage: Some(usage) }
24 }
25
26 pub fn with_total_tokens(total_tokens: u64) -> Self {
28 let mut usage = Usage::new();
29 usage.total_tokens = total_tokens;
30 Self::with_usage(usage)
31 }
32}
33
34impl GetTokenUsage for MockResponse {
35 fn token_usage(&self) -> Option<Usage> {
36 self.usage
37 }
38}
39
40#[derive(Clone, Debug)]
42pub enum MockStreamEvent {
43 Text(String),
45 ToolCall {
47 id: String,
48 name: String,
49 arguments: serde_json::Value,
50 call_id: Option<String>,
51 },
52 ToolCallDelta {
54 id: String,
55 internal_call_id: String,
56 content: ToolCallDeltaContent,
57 },
58 MessageId(String),
60 FinalResponse(MockResponse),
62 Error(MockError),
64}
65
66use super::completion::MockError;
67
68impl MockStreamEvent {
69 pub fn text(text: impl Into<String>) -> Self {
71 Self::Text(text.into())
72 }
73
74 pub fn tool_call(
76 id: impl Into<String>,
77 name: impl Into<String>,
78 arguments: serde_json::Value,
79 ) -> Self {
80 Self::ToolCall {
81 id: id.into(),
82 name: name.into(),
83 arguments,
84 call_id: None,
85 }
86 }
87
88 pub fn with_call_id(mut self, call_id: impl Into<String>) -> Self {
90 if let Self::ToolCall { call_id: id, .. } = &mut self {
91 *id = Some(call_id.into());
92 }
93 self
94 }
95
96 pub fn tool_call_name_delta(
98 id: impl Into<String>,
99 internal_call_id: impl Into<String>,
100 name: impl Into<String>,
101 ) -> Self {
102 Self::ToolCallDelta {
103 id: id.into(),
104 internal_call_id: internal_call_id.into(),
105 content: ToolCallDeltaContent::Name(name.into()),
106 }
107 }
108
109 pub fn tool_call_arguments_delta(
111 id: impl Into<String>,
112 internal_call_id: impl Into<String>,
113 arguments: impl Into<String>,
114 ) -> Self {
115 Self::ToolCallDelta {
116 id: id.into(),
117 internal_call_id: internal_call_id.into(),
118 content: ToolCallDeltaContent::Delta(arguments.into()),
119 }
120 }
121
122 pub fn message_id(id: impl Into<String>) -> Self {
124 Self::MessageId(id.into())
125 }
126
127 pub fn final_response(usage: Usage) -> Self {
129 Self::FinalResponse(MockResponse::with_usage(usage))
130 }
131
132 pub fn final_response_with_default_usage() -> Self {
134 Self::FinalResponse(MockResponse::with_usage(Usage::new()))
135 }
136
137 pub fn final_response_with_total_tokens(total_tokens: u64) -> Self {
139 Self::FinalResponse(MockResponse::with_total_tokens(total_tokens))
140 }
141
142 pub fn error(message: impl Into<String>) -> Self {
144 Self::Error(MockError::provider(message))
145 }
146
147 pub(crate) fn into_raw_choice(
148 self,
149 ) -> Result<RawStreamingChoice<MockResponse>, CompletionError> {
150 match self {
151 Self::Text(text) => Ok(RawStreamingChoice::Message(text)),
152 Self::ToolCall {
153 id,
154 name,
155 arguments,
156 call_id,
157 } => {
158 let mut tool_call = RawStreamingToolCall::new(id, name, arguments);
159 if let Some(call_id) = call_id {
160 tool_call = tool_call.with_call_id(call_id);
161 }
162 Ok(RawStreamingChoice::ToolCall(tool_call))
163 }
164 Self::ToolCallDelta {
165 id,
166 internal_call_id,
167 content,
168 } => Ok(RawStreamingChoice::ToolCallDelta {
169 id,
170 internal_call_id,
171 content,
172 }),
173 Self::MessageId(id) => Ok(RawStreamingChoice::MessageId(id)),
174 Self::FinalResponse(response) => Ok(RawStreamingChoice::FinalResponse(response)),
175 Self::Error(error) => Err(error.into_completion_error()),
176 }
177 }
178}