1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(tag = "type", rename_all = "snake_case")]
6pub enum StreamChunk {
7 Content { text: String },
9 ToolCallStart { id: String, name: String },
11 ToolCallDelta { id: String, arguments: String },
13 ToolCallEnd { id: String },
15 ToolResult {
17 id: String,
18 name: String,
19 output: String,
20 success: bool,
21 },
22 StateTransition { from: Option<String>, to: String },
24 Done {},
26 Error { message: String },
28}
29
30impl StreamChunk {
31 pub fn content(text: impl Into<String>) -> Self {
32 StreamChunk::Content { text: text.into() }
33 }
34
35 pub fn tool_start(id: impl Into<String>, name: impl Into<String>) -> Self {
36 StreamChunk::ToolCallStart {
37 id: id.into(),
38 name: name.into(),
39 }
40 }
41
42 pub fn tool_delta(id: impl Into<String>, arguments: impl Into<String>) -> Self {
43 StreamChunk::ToolCallDelta {
44 id: id.into(),
45 arguments: arguments.into(),
46 }
47 }
48
49 pub fn tool_end(id: impl Into<String>) -> Self {
50 StreamChunk::ToolCallEnd { id: id.into() }
51 }
52
53 pub fn tool_result(
54 id: impl Into<String>,
55 name: impl Into<String>,
56 output: impl Into<String>,
57 success: bool,
58 ) -> Self {
59 StreamChunk::ToolResult {
60 id: id.into(),
61 name: name.into(),
62 output: output.into(),
63 success,
64 }
65 }
66
67 pub fn state_transition(from: Option<String>, to: impl Into<String>) -> Self {
68 StreamChunk::StateTransition {
69 from,
70 to: to.into(),
71 }
72 }
73
74 pub fn error(message: impl Into<String>) -> Self {
75 StreamChunk::Error {
76 message: message.into(),
77 }
78 }
79
80 pub fn is_done(&self) -> bool {
81 matches!(self, StreamChunk::Done {})
82 }
83
84 pub fn is_error(&self) -> bool {
85 matches!(self, StreamChunk::Error { .. })
86 }
87
88 pub fn is_content(&self) -> bool {
89 matches!(self, StreamChunk::Content { .. })
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct StreamingConfig {
96 #[serde(default = "default_true")]
98 pub enabled: bool,
99 #[serde(default = "default_buffer_size")]
101 pub buffer_size: usize,
102 #[serde(default = "default_true")]
104 pub include_tool_events: bool,
105 #[serde(default = "default_true")]
107 pub include_state_events: bool,
108}
109
/// Serde default provider for boolean settings that are on unless the input
/// explicitly turns them off.
fn default_true() -> bool {
    const ON_BY_DEFAULT: bool = true;
    ON_BY_DEFAULT
}
113
/// Serde default provider for `buffer_size`; also used by the `Default` impl
/// of `StreamingConfig`.
fn default_buffer_size() -> usize {
    const DEFAULT_BUFFER_SIZE: usize = 32;
    DEFAULT_BUFFER_SIZE
}
117
118impl Default for StreamingConfig {
119 fn default() -> Self {
120 Self {
121 enabled: true,
122 buffer_size: default_buffer_size(),
123 include_tool_events: true,
124 include_state_events: true,
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a chunk into a `serde_json::Value` so a test can compare the
    /// whole wire shape at once, independent of key order. Substring checks
    /// on the JSON text are avoided because they can be satisfied by the
    /// payload itself (e.g. the message "Test error" contains "error").
    fn to_json(chunk: &StreamChunk) -> serde_json::Value {
        serde_json::to_value(chunk).unwrap()
    }

    #[test]
    fn test_stream_chunk_constructors() {
        let content = StreamChunk::content("Hello");
        assert!(content.is_content());
        assert!(!content.is_done());
        assert!(!content.is_error());

        let tool_start = StreamChunk::tool_start("id1", "calculator");
        assert!(matches!(tool_start, StreamChunk::ToolCallStart { .. }));

        let tool_delta = StreamChunk::tool_delta("id1", r#"{"expr":"1+1"}"#);
        assert!(matches!(tool_delta, StreamChunk::ToolCallDelta { .. }));

        let tool_end = StreamChunk::tool_end("id1");
        assert!(matches!(tool_end, StreamChunk::ToolCallEnd { .. }));

        let done = StreamChunk::Done {};
        assert!(done.is_done());
        assert!(!done.is_content());
        assert!(!done.is_error());

        let error = StreamChunk::error("Something went wrong");
        assert!(error.is_error());
        assert!(!error.is_content());
        assert!(!error.is_done());
    }

    #[test]
    fn test_stream_chunk_serialization() {
        assert_eq!(
            to_json(&StreamChunk::content("Hello")),
            serde_json::json!({ "type": "content", "text": "Hello" })
        );

        assert_eq!(
            to_json(&StreamChunk::tool_start("id1", "calculator")),
            serde_json::json!({
                "type": "tool_call_start",
                "id": "id1",
                "name": "calculator"
            })
        );
    }

    #[test]
    fn test_stream_chunk_tool_call_serialization() {
        assert_eq!(
            to_json(&StreamChunk::tool_delta("id1", "{\"a\":1}")),
            serde_json::json!({
                "type": "tool_call_delta",
                "id": "id1",
                "arguments": "{\"a\":1}"
            })
        );

        assert_eq!(
            to_json(&StreamChunk::tool_end("id1")),
            serde_json::json!({ "type": "tool_call_end", "id": "id1" })
        );
    }

    #[test]
    fn test_streaming_config_defaults() {
        let config = StreamingConfig::default();
        assert!(config.enabled);
        assert_eq!(config.buffer_size, 32);
        assert!(config.include_tool_events);
        assert!(config.include_state_events);
    }

    #[test]
    fn test_streaming_config_deserialization() {
        // `include_state_events` is deliberately omitted to exercise its
        // serde default.
        let yaml = r#"
enabled: true
buffer_size: 64
include_tool_events: false
"#;
        let config: StreamingConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.enabled);
        assert_eq!(config.buffer_size, 64);
        assert!(!config.include_tool_events);
        assert!(config.include_state_events);
    }

    #[test]
    fn test_streaming_config_empty_input_uses_defaults() {
        // With no keys at all, every field must fall back to its serde
        // default, matching `StreamingConfig::default()`.
        let config: StreamingConfig = serde_json::from_str("{}").unwrap();
        let expected = StreamingConfig::default();
        assert_eq!(config.enabled, expected.enabled);
        assert_eq!(config.buffer_size, expected.buffer_size);
        assert_eq!(config.include_tool_events, expected.include_tool_events);
        assert_eq!(config.include_state_events, expected.include_state_events);
    }

    #[test]
    fn test_tool_result_chunk() {
        let result = StreamChunk::tool_result("id1", "calculator", "42", true);
        match result {
            StreamChunk::ToolResult {
                id,
                name,
                output,
                success,
            } => {
                assert_eq!(id, "id1");
                assert_eq!(name, "calculator");
                assert_eq!(output, "42");
                assert!(success);
            }
            _ => panic!("Expected ToolResult"),
        }
    }

    #[test]
    fn test_state_transition_chunk() {
        let transition = StreamChunk::state_transition(Some("greeting".to_string()), "support");
        match transition {
            StreamChunk::StateTransition { from, to } => {
                assert_eq!(from, Some("greeting".to_string()));
                assert_eq!(to, "support");
            }
            _ => panic!("Expected StateTransition"),
        }
    }

    #[test]
    fn test_state_transition_without_origin() {
        let transition = StreamChunk::state_transition(None, "greeting");
        assert_eq!(
            to_json(&transition),
            serde_json::json!({
                "type": "state_transition",
                "from": null,
                "to": "greeting"
            })
        );

        match transition {
            StreamChunk::StateTransition { from, to } => {
                assert_eq!(from, None);
                assert_eq!(to, "greeting");
            }
            _ => panic!("Expected StateTransition"),
        }
    }

    #[test]
    fn test_stream_chunk_done_serialization() {
        assert_eq!(
            to_json(&StreamChunk::Done {}),
            serde_json::json!({ "type": "done" })
        );
    }

    #[test]
    fn test_stream_chunk_error_serialization() {
        assert_eq!(
            to_json(&StreamChunk::error("Test error")),
            serde_json::json!({ "type": "error", "message": "Test error" })
        );
    }

    #[test]
    fn test_stream_chunk_tool_result_serialization() {
        assert_eq!(
            to_json(&StreamChunk::tool_result("id1", "calculator", "42", true)),
            serde_json::json!({
                "type": "tool_result",
                "id": "id1",
                "name": "calculator",
                "output": "42",
                "success": true
            })
        );
    }

    #[test]
    fn test_streaming_config_full_yaml() {
        let yaml = r#"
enabled: false
buffer_size: 128
include_tool_events: false
include_state_events: false
"#;
        let config: StreamingConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(!config.enabled);
        assert_eq!(config.buffer_size, 128);
        assert!(!config.include_tool_events);
        assert!(!config.include_state_events);
    }

    #[test]
    fn test_stream_chunk_deserialization() {
        let json = r#"{"type":"content","text":"Hello"}"#;
        let chunk: StreamChunk = serde_json::from_str(json).unwrap();
        match chunk {
            StreamChunk::Content { text } => assert_eq!(text, "Hello"),
            _ => panic!("Expected Content"),
        }

        let json = r#"{"type":"done"}"#;
        let chunk: StreamChunk = serde_json::from_str(json).unwrap();
        assert!(chunk.is_done());

        let json = r#"{"type":"error","message":"fail"}"#;
        let chunk: StreamChunk = serde_json::from_str(json).unwrap();
        match chunk {
            StreamChunk::Error { message } => assert_eq!(message, "fail"),
            _ => panic!("Expected Error"),
        }
    }

    #[test]
    fn test_stream_chunk_unknown_type_is_rejected() {
        let json = r#"{"type":"bogus"}"#;
        assert!(serde_json::from_str::<StreamChunk>(json).is_err());
    }

    #[test]
    fn test_stream_chunk_tool_events() {
        let start = StreamChunk::tool_start("tool-1", "http");
        let delta = StreamChunk::tool_delta("tool-1", r#"{"url":"test"}"#);
        let end = StreamChunk::tool_end("tool-1");

        match start {
            StreamChunk::ToolCallStart { id, name } => {
                assert_eq!(id, "tool-1");
                assert_eq!(name, "http");
            }
            _ => panic!("Expected ToolCallStart"),
        }

        match delta {
            StreamChunk::ToolCallDelta { id, arguments } => {
                assert_eq!(id, "tool-1");
                assert_eq!(arguments, r#"{"url":"test"}"#);
            }
            _ => panic!("Expected ToolCallDelta"),
        }

        match end {
            StreamChunk::ToolCallEnd { id } => {
                assert_eq!(id, "tool-1");
            }
            _ => panic!("Expected ToolCallEnd"),
        }
    }
}