Skip to main content

ai_agents_runtime/
streaming.rs

1use serde::{Deserialize, Serialize};
2
3/// Represents a chunk of streamed response from the agent
4#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(tag = "type", rename_all = "snake_case")]
6pub enum StreamChunk {
7    /// Text content from the LLM
8    Content { text: String },
9    /// A tool call is starting
10    ToolCallStart { id: String, name: String },
11    /// Incremental arguments for a tool call
12    ToolCallDelta { id: String, arguments: String },
13    /// A tool call has completed
14    ToolCallEnd { id: String },
15    /// Tool execution result
16    ToolResult {
17        id: String,
18        name: String,
19        output: String,
20        success: bool,
21    },
22    /// State transition occurred
23    StateTransition { from: Option<String>, to: String },
24    /// Stream has completed
25    Done {},
26    /// An error occurred
27    Error { message: String },
28}
29
30impl StreamChunk {
31    pub fn content(text: impl Into<String>) -> Self {
32        StreamChunk::Content { text: text.into() }
33    }
34
35    pub fn tool_start(id: impl Into<String>, name: impl Into<String>) -> Self {
36        StreamChunk::ToolCallStart {
37            id: id.into(),
38            name: name.into(),
39        }
40    }
41
42    pub fn tool_delta(id: impl Into<String>, arguments: impl Into<String>) -> Self {
43        StreamChunk::ToolCallDelta {
44            id: id.into(),
45            arguments: arguments.into(),
46        }
47    }
48
49    pub fn tool_end(id: impl Into<String>) -> Self {
50        StreamChunk::ToolCallEnd { id: id.into() }
51    }
52
53    pub fn tool_result(
54        id: impl Into<String>,
55        name: impl Into<String>,
56        output: impl Into<String>,
57        success: bool,
58    ) -> Self {
59        StreamChunk::ToolResult {
60            id: id.into(),
61            name: name.into(),
62            output: output.into(),
63            success,
64        }
65    }
66
67    pub fn state_transition(from: Option<String>, to: impl Into<String>) -> Self {
68        StreamChunk::StateTransition {
69            from,
70            to: to.into(),
71        }
72    }
73
74    pub fn error(message: impl Into<String>) -> Self {
75        StreamChunk::Error {
76            message: message.into(),
77        }
78    }
79
80    pub fn is_done(&self) -> bool {
81        matches!(self, StreamChunk::Done {})
82    }
83
84    pub fn is_error(&self) -> bool {
85        matches!(self, StreamChunk::Error { .. })
86    }
87
88    pub fn is_content(&self) -> bool {
89        matches!(self, StreamChunk::Content { .. })
90    }
91}
92
93/// Configuration for streaming behavior
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct StreamingConfig {
96    /// Whether streaming is enabled
97    #[serde(default = "default_true")]
98    pub enabled: bool,
99    /// Buffer size for streaming chunks
100    #[serde(default = "default_buffer_size")]
101    pub buffer_size: usize,
102    /// Include tool call events in the stream
103    #[serde(default = "default_true")]
104    pub include_tool_events: bool,
105    /// Include state transition events in the stream
106    #[serde(default = "default_true")]
107    pub include_state_events: bool,
108}
109
/// Serde default provider for boolean options that are on unless the input says otherwise.
fn default_true() -> bool {
    // Referenced by path from `#[serde(default = "...")]`, which needs a function, not a literal.
    true
}
113
/// Serde default provider for the streaming buffer size.
fn default_buffer_size() -> usize {
    // Named locally so the fallback value is labelled rather than a bare number.
    const DEFAULT_BUFFER_SIZE: usize = 32;
    DEFAULT_BUFFER_SIZE
}
117
118impl Default for StreamingConfig {
119    fn default() -> Self {
120        Self {
121            enabled: true,
122            buffer_size: default_buffer_size(),
123            include_tool_events: true,
124            include_state_events: true,
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Each convenience constructor yields the variant it is named after.
    #[test]
    fn test_stream_chunk_constructors() {
        assert!(StreamChunk::content("Hello").is_content());

        assert!(matches!(
            StreamChunk::tool_start("id1", "calculator"),
            StreamChunk::ToolCallStart { .. }
        ));
        assert!(matches!(
            StreamChunk::tool_delta("id1", r#"{"expr":"1+1"}"#),
            StreamChunk::ToolCallDelta { .. }
        ));
        assert!(matches!(
            StreamChunk::tool_end("id1"),
            StreamChunk::ToolCallEnd { .. }
        ));

        let finished = StreamChunk::Done {};
        assert!(finished.is_done());

        assert!(StreamChunk::error("Something went wrong").is_error());
    }

    /// JSON output mentions the snake_case tag and the payload values.
    #[test]
    fn test_stream_chunk_serialization() {
        let encoded = serde_json::to_string(&StreamChunk::content("Hello")).unwrap();
        for needle in ["content", "Hello"] {
            assert!(encoded.contains(needle));
        }

        let encoded =
            serde_json::to_string(&StreamChunk::tool_start("id1", "calculator")).unwrap();
        for needle in ["tool_call_start", "calculator"] {
            assert!(encoded.contains(needle));
        }
    }

    /// `Default` enables everything and uses a buffer size of 32.
    #[test]
    fn test_streaming_config_defaults() {
        let StreamingConfig {
            enabled,
            buffer_size,
            include_tool_events,
            include_state_events,
        } = StreamingConfig::default();

        assert!(enabled);
        assert_eq!(buffer_size, 32);
        assert!(include_tool_events);
        assert!(include_state_events);
    }

    /// A field missing from the YAML falls back to its serde default.
    #[test]
    fn test_streaming_config_deserialization() {
        let yaml = r#"
enabled: true
buffer_size: 64
include_tool_events: false
"#;
        let parsed = serde_yaml::from_str::<StreamingConfig>(yaml).unwrap();

        assert!(parsed.enabled);
        assert_eq!(parsed.buffer_size, 64);
        assert!(!parsed.include_tool_events);
        // Not present in the YAML above, so the default applies.
        assert!(parsed.include_state_events);
    }

    /// `tool_result` stores all four arguments in the matching fields.
    #[test]
    fn test_tool_result_chunk() {
        let chunk = StreamChunk::tool_result("id1", "calculator", "42", true);

        if let StreamChunk::ToolResult {
            id,
            name,
            output,
            success,
        } = chunk
        {
            assert_eq!(id, "id1");
            assert_eq!(name, "calculator");
            assert_eq!(output, "42");
            assert!(success);
        } else {
            panic!("Expected ToolResult");
        }
    }

    /// `state_transition` keeps the optional source state and the target state.
    #[test]
    fn test_state_transition_chunk() {
        let chunk = StreamChunk::state_transition(Some("greeting".to_string()), "support");

        if let StreamChunk::StateTransition { from, to } = chunk {
            assert_eq!(from, Some("greeting".to_string()));
            assert_eq!(to, "support");
        } else {
            panic!("Expected StateTransition");
        }
    }

    /// The `Done` variant serializes with its snake_case tag.
    #[test]
    fn test_stream_chunk_done_serialization() {
        let encoded = serde_json::to_string(&StreamChunk::Done {}).unwrap();
        assert!(encoded.contains("done"));
    }

    /// The `Error` variant serializes its tag and message.
    #[test]
    fn test_stream_chunk_error_serialization() {
        let encoded = serde_json::to_string(&StreamChunk::error("Test error")).unwrap();
        for needle in ["error", "Test error"] {
            assert!(encoded.contains(needle));
        }
    }

    /// The `ToolResult` variant serializes its tag and every field value.
    #[test]
    fn test_stream_chunk_tool_result_serialization() {
        let chunk = StreamChunk::tool_result("id1", "calculator", "42", true);
        let encoded = serde_json::to_string(&chunk).unwrap();
        for needle in ["tool_result", "calculator", "42", "true"] {
            assert!(encoded.contains(needle));
        }
    }

    /// Every field given explicitly in YAML overrides its default.
    #[test]
    fn test_streaming_config_full_yaml() {
        let yaml = r#"
enabled: false
buffer_size: 128
include_tool_events: false
include_state_events: false
"#;
        let StreamingConfig {
            enabled,
            buffer_size,
            include_tool_events,
            include_state_events,
        } = serde_yaml::from_str(yaml).unwrap();

        assert!(!enabled);
        assert_eq!(buffer_size, 128);
        assert!(!include_tool_events);
        assert!(!include_state_events);
    }

    /// Tagged JSON objects decode into the variant named by `type`.
    #[test]
    fn test_stream_chunk_deserialization() {
        let decoded: StreamChunk =
            serde_json::from_str(r#"{"type":"content","text":"Hello"}"#).unwrap();
        assert!(decoded.is_content());

        let decoded: StreamChunk = serde_json::from_str(r#"{"type":"done"}"#).unwrap();
        assert!(decoded.is_done());

        let decoded: StreamChunk =
            serde_json::from_str(r#"{"type":"error","message":"fail"}"#).unwrap();
        assert!(decoded.is_error());
    }

    /// The three tool-call constructors keep the id and their payload.
    #[test]
    fn test_stream_chunk_tool_events() {
        let opened = StreamChunk::tool_start("tool-1", "http");
        let partial = StreamChunk::tool_delta("tool-1", r#"{"url":"test"}"#);
        let closed = StreamChunk::tool_end("tool-1");

        if let StreamChunk::ToolCallStart { id, name } = opened {
            assert_eq!(id, "tool-1");
            assert_eq!(name, "http");
        } else {
            panic!("Expected ToolCallStart");
        }

        if let StreamChunk::ToolCallDelta { id, arguments } = partial {
            assert_eq!(id, "tool-1");
            assert!(arguments.contains("url"));
        } else {
            panic!("Expected ToolCallDelta");
        }

        if let StreamChunk::ToolCallEnd { id } = closed {
            assert_eq!(id, "tool-1");
        } else {
            panic!("Expected ToolCallEnd");
        }
    }
}