1use serde::{Deserialize, Serialize};
6use std::collections::VecDeque;
7use std::io::{BufRead, Write};
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum StreamMessageType {
13 UserMessage,
14 AssistantMessage,
15 ToolUse,
16 ToolResult,
17 Error,
18 Done,
19 Partial,
20 System,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct StreamMessage {
26 pub r#type: StreamMessageType,
27 pub timestamp: u64,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub session_id: Option<String>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct UserStreamMessage {
35 pub r#type: StreamMessageType,
36 pub timestamp: u64,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub session_id: Option<String>,
39 pub content: String,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub attachments: Option<Vec<Attachment>>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Attachment {
47 pub r#type: AttachmentType,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 pub path: Option<String>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub data: Option<String>,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub mime_type: Option<String>,
54}
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum AttachmentType {
59 File,
60 Image,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct AssistantStreamMessage {
66 pub r#type: StreamMessageType,
67 pub timestamp: u64,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 pub session_id: Option<String>,
70 pub content: String,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 pub model: Option<String>,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pub stop_reason: Option<String>,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct ToolUseStreamMessage {
80 pub r#type: StreamMessageType,
81 pub timestamp: u64,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 pub session_id: Option<String>,
84 pub tool_id: String,
85 pub tool_name: String,
86 pub input: serde_json::Value,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ToolResultStreamMessage {
92 pub r#type: StreamMessageType,
93 pub timestamp: u64,
94 #[serde(skip_serializing_if = "Option::is_none")]
95 pub session_id: Option<String>,
96 pub tool_id: String,
97 pub success: bool,
98 #[serde(skip_serializing_if = "Option::is_none")]
99 pub output: Option<String>,
100 #[serde(skip_serializing_if = "Option::is_none")]
101 pub error: Option<String>,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct PartialStreamMessage {
107 pub r#type: StreamMessageType,
108 pub timestamp: u64,
109 #[serde(skip_serializing_if = "Option::is_none")]
110 pub session_id: Option<String>,
111 pub content: String,
112 pub index: usize,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct ErrorStreamMessage {
118 pub r#type: StreamMessageType,
119 pub timestamp: u64,
120 #[serde(skip_serializing_if = "Option::is_none")]
121 pub session_id: Option<String>,
122 pub code: String,
123 pub message: String,
124 #[serde(skip_serializing_if = "Option::is_none")]
125 pub details: Option<serde_json::Value>,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct DoneStreamMessage {
131 pub r#type: StreamMessageType,
132 pub timestamp: u64,
133 #[serde(skip_serializing_if = "Option::is_none")]
134 pub session_id: Option<String>,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 pub stats: Option<StreamStats>,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct StreamStats {
142 pub input_tokens: usize,
143 pub output_tokens: usize,
144 pub total_cost_usd: f64,
145 pub duration_ms: u64,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct SystemStreamMessage {
151 pub r#type: StreamMessageType,
152 pub timestamp: u64,
153 #[serde(skip_serializing_if = "Option::is_none")]
154 pub session_id: Option<String>,
155 pub event: String,
156 #[serde(skip_serializing_if = "Option::is_none")]
157 pub data: Option<serde_json::Value>,
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
162#[serde(untagged)]
163pub enum AnyStreamMessage {
164 User(UserStreamMessage),
165 Assistant(AssistantStreamMessage),
166 ToolUse(ToolUseStreamMessage),
167 ToolResult(ToolResultStreamMessage),
168 Partial(PartialStreamMessage),
169 Error(ErrorStreamMessage),
170 Done(DoneStreamMessage),
171 System(SystemStreamMessage),
172}
173
174impl AnyStreamMessage {
175 pub fn message_type(&self) -> StreamMessageType {
177 match self {
178 AnyStreamMessage::User(m) => m.r#type,
179 AnyStreamMessage::Assistant(m) => m.r#type,
180 AnyStreamMessage::ToolUse(m) => m.r#type,
181 AnyStreamMessage::ToolResult(m) => m.r#type,
182 AnyStreamMessage::Partial(m) => m.r#type,
183 AnyStreamMessage::Error(m) => m.r#type,
184 AnyStreamMessage::Done(m) => m.r#type,
185 AnyStreamMessage::System(m) => m.r#type,
186 }
187 }
188}
189
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // `as_millis` yields a `u128`. Clamp to `u64::MAX` first so the cast below
        // saturates instead of silently wrapping.
        .map(|elapsed| elapsed.as_millis().min(u128::from(u64::MAX)) as u64)
        .unwrap_or(0)
}
197
198fn generate_session_id() -> String {
200 use std::time::{SystemTime, UNIX_EPOCH};
201 let timestamp = SystemTime::now()
202 .duration_since(UNIX_EPOCH)
203 .map(|d| d.as_millis())
204 .unwrap_or(0);
205 let random: u64 = rand::random();
206 format!("session_{}_{:x}", timestamp, random & 0xFFFFFFFF)
207}
208
209pub struct StreamJsonReader {
211 buffer: VecDeque<AnyStreamMessage>,
212 closed: bool,
213}
214
215impl Default for StreamJsonReader {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
221impl StreamJsonReader {
222 pub fn new() -> Self {
224 Self {
225 buffer: VecDeque::new(),
226 closed: false,
227 }
228 }
229
230 pub fn process_line(&mut self, line: &str) -> Option<AnyStreamMessage> {
232 let trimmed = line.trim();
233 if trimmed.is_empty() {
234 return None;
235 }
236
237 serde_json::from_str::<AnyStreamMessage>(trimmed).ok()
238 }
239
240 pub fn read_from<R: BufRead>(
242 &mut self,
243 reader: &mut R,
244 ) -> std::io::Result<Option<AnyStreamMessage>> {
245 let mut line = String::new();
246 let bytes = reader.read_line(&mut line)?;
247
248 if bytes == 0 {
249 self.closed = true;
250 return Ok(None);
251 }
252
253 Ok(self.process_line(&line))
254 }
255
256 pub fn is_closed(&self) -> bool {
258 self.closed
259 }
260}
261
/// Emits stream-json messages, one JSON object per line, to an underlying writer.
pub struct StreamJsonWriter<W: Write> {
    /// Destination for serialized messages; flushed after every message.
    output: W,
    /// Session id stamped on every message this writer produces.
    session_id: String,
    /// Index assigned to the next `partial` message; starts at 0.
    message_index: usize,
}
268
269impl<W: Write> StreamJsonWriter<W> {
270 pub fn new(output: W, session_id: Option<String>) -> Self {
272 Self {
273 output,
274 session_id: session_id.unwrap_or_else(generate_session_id),
275 message_index: 0,
276 }
277 }
278
279 pub fn write(&mut self, message: &impl Serialize) -> std::io::Result<()> {
281 let json = serde_json::to_string(message)?;
282 writeln!(self.output, "{}", json)?;
283 self.output.flush()
284 }
285
286 pub fn write_user_message(
288 &mut self,
289 content: &str,
290 attachments: Option<Vec<Attachment>>,
291 ) -> std::io::Result<()> {
292 let msg = UserStreamMessage {
293 r#type: StreamMessageType::UserMessage,
294 timestamp: current_timestamp(),
295 session_id: Some(self.session_id.clone()),
296 content: content.to_string(),
297 attachments,
298 };
299 self.write(&msg)
300 }
301
302 pub fn write_assistant_message(
304 &mut self,
305 content: &str,
306 model: Option<&str>,
307 stop_reason: Option<&str>,
308 ) -> std::io::Result<()> {
309 let msg = AssistantStreamMessage {
310 r#type: StreamMessageType::AssistantMessage,
311 timestamp: current_timestamp(),
312 session_id: Some(self.session_id.clone()),
313 content: content.to_string(),
314 model: model.map(|s| s.to_string()),
315 stop_reason: stop_reason.map(|s| s.to_string()),
316 };
317 self.write(&msg)
318 }
319
320 pub fn write_tool_use(
322 &mut self,
323 tool_id: &str,
324 tool_name: &str,
325 input: serde_json::Value,
326 ) -> std::io::Result<()> {
327 let msg = ToolUseStreamMessage {
328 r#type: StreamMessageType::ToolUse,
329 timestamp: current_timestamp(),
330 session_id: Some(self.session_id.clone()),
331 tool_id: tool_id.to_string(),
332 tool_name: tool_name.to_string(),
333 input,
334 };
335 self.write(&msg)
336 }
337
338 pub fn write_tool_result(
340 &mut self,
341 tool_id: &str,
342 success: bool,
343 output: Option<&str>,
344 error: Option<&str>,
345 ) -> std::io::Result<()> {
346 let msg = ToolResultStreamMessage {
347 r#type: StreamMessageType::ToolResult,
348 timestamp: current_timestamp(),
349 session_id: Some(self.session_id.clone()),
350 tool_id: tool_id.to_string(),
351 success,
352 output: output.map(|s| s.to_string()),
353 error: error.map(|s| s.to_string()),
354 };
355 self.write(&msg)
356 }
357
358 pub fn write_partial(&mut self, content: &str) -> std::io::Result<()> {
360 let msg = PartialStreamMessage {
361 r#type: StreamMessageType::Partial,
362 timestamp: current_timestamp(),
363 session_id: Some(self.session_id.clone()),
364 content: content.to_string(),
365 index: self.message_index,
366 };
367 self.message_index += 1;
368 self.write(&msg)
369 }
370
371 pub fn write_error(
373 &mut self,
374 code: &str,
375 message: &str,
376 details: Option<serde_json::Value>,
377 ) -> std::io::Result<()> {
378 let msg = ErrorStreamMessage {
379 r#type: StreamMessageType::Error,
380 timestamp: current_timestamp(),
381 session_id: Some(self.session_id.clone()),
382 code: code.to_string(),
383 message: message.to_string(),
384 details,
385 };
386 self.write(&msg)
387 }
388
389 pub fn write_done(&mut self, stats: Option<StreamStats>) -> std::io::Result<()> {
391 let msg = DoneStreamMessage {
392 r#type: StreamMessageType::Done,
393 timestamp: current_timestamp(),
394 session_id: Some(self.session_id.clone()),
395 stats,
396 };
397 self.write(&msg)
398 }
399
400 pub fn write_system(
402 &mut self,
403 event: &str,
404 data: Option<serde_json::Value>,
405 ) -> std::io::Result<()> {
406 let msg = SystemStreamMessage {
407 r#type: StreamMessageType::System,
408 timestamp: current_timestamp(),
409 session_id: Some(self.session_id.clone()),
410 event: event.to_string(),
411 data,
412 };
413 self.write(&msg)
414 }
415
416 pub fn session_id(&self) -> &str {
418 &self.session_id
419 }
420}
421
422pub struct StreamSession<R: BufRead, W: Write> {
424 reader: StreamJsonReader,
425 writer: StreamJsonWriter<W>,
426 input: R,
427}
428
429impl<R: BufRead, W: Write> StreamSession<R, W> {
430 pub fn new(input: R, output: W) -> Self {
432 Self {
433 reader: StreamJsonReader::new(),
434 writer: StreamJsonWriter::new(output, None),
435 input,
436 }
437 }
438
439 pub fn writer(&mut self) -> &mut StreamJsonWriter<W> {
441 &mut self.writer
442 }
443
444 pub fn read_message(&mut self) -> std::io::Result<Option<AnyStreamMessage>> {
446 self.reader.read_from(&mut self.input)
447 }
448
449 pub fn start(&mut self) -> std::io::Result<()> {
451 self.writer.write_system("session_start", None)
452 }
453
454 pub fn end(&mut self) -> std::io::Result<()> {
456 self.writer.write_done(None)
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Splits writer output into lines, checks it is newline-terminated, and parses
    /// each line as JSON.
    fn json_lines(buffer: &[u8]) -> Vec<serde_json::Value> {
        let text = std::str::from_utf8(buffer).expect("writer output must be UTF-8");
        assert!(text.ends_with('\n'), "every message must end with a newline");
        text.lines()
            .map(|line| serde_json::from_str(line).expect("each line must be valid JSON"))
            .collect()
    }

    #[test]
    fn test_stream_message_type_serialize() {
        let msg_type = StreamMessageType::UserMessage;
        let json = serde_json::to_string(&msg_type).unwrap();
        assert_eq!(json, r#""user_message""#);
    }

    #[test]
    fn test_stream_message_type_round_trip() {
        let all = [
            (StreamMessageType::UserMessage, "user_message"),
            (StreamMessageType::AssistantMessage, "assistant_message"),
            (StreamMessageType::ToolUse, "tool_use"),
            (StreamMessageType::ToolResult, "tool_result"),
            (StreamMessageType::Error, "error"),
            (StreamMessageType::Done, "done"),
            (StreamMessageType::Partial, "partial"),
            (StreamMessageType::System, "system"),
        ];
        for (ty, name) in all.iter() {
            let json = serde_json::to_string(ty).unwrap();
            assert_eq!(json, format!("\"{}\"", name));
            let back: StreamMessageType = serde_json::from_str(&json).unwrap();
            assert_eq!(back, *ty);
        }
    }

    #[test]
    fn test_stream_json_reader_process_line() {
        let mut reader = StreamJsonReader::new();

        let line = r#"{"type":"user_message","timestamp":123,"content":"hello"}"#;
        let msg = reader.process_line(line);

        match msg {
            Some(AnyStreamMessage::User(user)) => {
                assert_eq!(user.r#type, StreamMessageType::UserMessage);
                assert_eq!(user.timestamp, 123);
                assert_eq!(user.content, "hello");
                assert!(user.session_id.is_none());
                assert!(user.attachments.is_none());
            }
            other => panic!("expected a user message, got {:?}", other),
        }
    }

    #[test]
    fn test_stream_json_reader_empty_line() {
        let mut reader = StreamJsonReader::new();
        assert!(reader.process_line("").is_none());
        assert!(reader.process_line(" ").is_none());
    }

    #[test]
    fn test_stream_json_reader_invalid_line() {
        let mut reader = StreamJsonReader::new();
        assert!(reader.process_line("not json").is_none());
        // Missing the required `type` field.
        assert!(reader.process_line(r#"{"timestamp":1}"#).is_none());
    }

    #[test]
    fn test_stream_json_reader_read_from_until_eof() {
        let mut reader = StreamJsonReader::new();
        let mut input = Cursor::new(
            b"{\"type\":\"user_message\",\"timestamp\":1,\"content\":\"hi\"}\n".to_vec(),
        );

        let first = reader.read_from(&mut input).unwrap();
        assert!(matches!(first, Some(AnyStreamMessage::User(_))));
        assert!(!reader.is_closed());

        assert!(reader.read_from(&mut input).unwrap().is_none());
        assert!(reader.is_closed());
    }

    #[test]
    fn test_stream_json_writer_partial() {
        let mut buffer = Vec::new();
        {
            let mut writer = StreamJsonWriter::new(&mut buffer, Some("test_session".to_string()));
            writer.write_partial("Hello").unwrap();
            writer.write_partial("World").unwrap();
        }

        let lines = json_lines(&buffer);
        assert_eq!(lines.len(), 2);
        assert_eq!(lines[0]["type"], "partial");
        assert_eq!(lines[0]["content"], "Hello");
        assert_eq!(lines[0]["index"], 0);
        assert_eq!(lines[0]["session_id"], "test_session");
        assert_eq!(lines[1]["content"], "World");
        assert_eq!(lines[1]["index"], 1);
    }

    #[test]
    fn test_stream_json_writer_error() {
        let mut buffer = Vec::new();
        let session_id = {
            let mut writer = StreamJsonWriter::new(&mut buffer, None);
            writer.write_error("ERR001", "Test error", None).unwrap();
            writer.session_id().to_string()
        };

        assert!(session_id.starts_with("session_"));
        let lines = json_lines(&buffer);
        assert_eq!(lines.len(), 1);
        assert_eq!(lines[0]["type"], "error");
        assert_eq!(lines[0]["code"], "ERR001");
        assert_eq!(lines[0]["message"], "Test error");
        assert_eq!(lines[0]["session_id"], session_id.as_str());
        assert!(lines[0].get("details").is_none());
    }

    #[test]
    fn test_stream_json_writer_done() {
        let mut buffer = Vec::new();
        {
            let mut writer = StreamJsonWriter::new(&mut buffer, None);
            let stats = StreamStats {
                input_tokens: 100,
                output_tokens: 50,
                total_cost_usd: 0.001,
                duration_ms: 1000,
            };
            writer.write_done(Some(stats)).unwrap();
        }

        let lines = json_lines(&buffer);
        assert_eq!(lines.len(), 1);
        assert_eq!(lines[0]["type"], "done");
        assert_eq!(lines[0]["stats"]["input_tokens"], 100);
        assert_eq!(lines[0]["stats"]["output_tokens"], 50);
        assert_eq!(lines[0]["stats"]["duration_ms"], 1000);
    }

    #[test]
    fn test_stream_session() {
        let input = Cursor::new(Vec::new());
        let mut output = Vec::new();

        {
            let mut session = StreamSession::new(input, &mut output);
            session.start().unwrap();
            session.end().unwrap();
            // Empty input: the first read hits end-of-input.
            assert!(session.read_message().unwrap().is_none());
        }

        let lines = json_lines(&output);
        assert_eq!(lines.len(), 2);
        assert_eq!(lines[0]["type"], "system");
        assert_eq!(lines[0]["event"], "session_start");
        assert_eq!(lines[1]["type"], "done");
        assert_eq!(lines[0]["session_id"], lines[1]["session_id"]);
    }

    #[test]
    fn test_any_stream_message_type() {
        let msg = AnyStreamMessage::Error(ErrorStreamMessage {
            r#type: StreamMessageType::Error,
            timestamp: 0,
            session_id: None,
            code: "E1".to_string(),
            message: "test".to_string(),
            details: None,
        });

        assert_eq!(msg.message_type(), StreamMessageType::Error);
    }
}