Skip to main content

aster/streaming/
stream_io.rs

1//! Stream JSON I/O
2//!
3//! Provides streaming JSON input/output for CLI communication.
4
5use serde::{Deserialize, Serialize};
6use std::collections::VecDeque;
7use std::io::{BufRead, Write};
8
9/// Stream message types
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum StreamMessageType {
13    UserMessage,
14    AssistantMessage,
15    ToolUse,
16    ToolResult,
17    Error,
18    Done,
19    Partial,
20    System,
21}
22
23/// Base stream message
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct StreamMessage {
26    pub r#type: StreamMessageType,
27    pub timestamp: u64,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub session_id: Option<String>,
30}
31
32/// User message
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct UserStreamMessage {
35    pub r#type: StreamMessageType,
36    pub timestamp: u64,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub session_id: Option<String>,
39    pub content: String,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub attachments: Option<Vec<Attachment>>,
42}
43
44/// Attachment for user messages
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Attachment {
47    pub r#type: AttachmentType,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub path: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub data: Option<String>,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub mime_type: Option<String>,
54}
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
57#[serde(rename_all = "snake_case")]
58pub enum AttachmentType {
59    File,
60    Image,
61}
62
63/// Assistant message
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct AssistantStreamMessage {
66    pub r#type: StreamMessageType,
67    pub timestamp: u64,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub session_id: Option<String>,
70    pub content: String,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub model: Option<String>,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub stop_reason: Option<String>,
75}
76
77/// Tool use message
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct ToolUseStreamMessage {
80    pub r#type: StreamMessageType,
81    pub timestamp: u64,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub session_id: Option<String>,
84    pub tool_id: String,
85    pub tool_name: String,
86    pub input: serde_json::Value,
87}
88
89/// Tool result message
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ToolResultStreamMessage {
92    pub r#type: StreamMessageType,
93    pub timestamp: u64,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub session_id: Option<String>,
96    pub tool_id: String,
97    pub success: bool,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub output: Option<String>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub error: Option<String>,
102}
103
104/// Partial message (streaming output)
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct PartialStreamMessage {
107    pub r#type: StreamMessageType,
108    pub timestamp: u64,
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub session_id: Option<String>,
111    pub content: String,
112    pub index: usize,
113}
114
115/// Error message
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct ErrorStreamMessage {
118    pub r#type: StreamMessageType,
119    pub timestamp: u64,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub session_id: Option<String>,
122    pub code: String,
123    pub message: String,
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub details: Option<serde_json::Value>,
126}
127
128/// Done message
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct DoneStreamMessage {
131    pub r#type: StreamMessageType,
132    pub timestamp: u64,
133    #[serde(skip_serializing_if = "Option::is_none")]
134    pub session_id: Option<String>,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub stats: Option<StreamStats>,
137}
138
139/// Stream statistics
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct StreamStats {
142    pub input_tokens: usize,
143    pub output_tokens: usize,
144    pub total_cost_usd: f64,
145    pub duration_ms: u64,
146}
147
148/// System message
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct SystemStreamMessage {
151    pub r#type: StreamMessageType,
152    pub timestamp: u64,
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub session_id: Option<String>,
155    pub event: String,
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub data: Option<serde_json::Value>,
158}
159
160/// Union type for all stream messages
161#[derive(Debug, Clone, Serialize, Deserialize)]
162#[serde(untagged)]
163pub enum AnyStreamMessage {
164    User(UserStreamMessage),
165    Assistant(AssistantStreamMessage),
166    ToolUse(ToolUseStreamMessage),
167    ToolResult(ToolResultStreamMessage),
168    Partial(PartialStreamMessage),
169    Error(ErrorStreamMessage),
170    Done(DoneStreamMessage),
171    System(SystemStreamMessage),
172}
173
174impl AnyStreamMessage {
175    /// Get message type
176    pub fn message_type(&self) -> StreamMessageType {
177        match self {
178            AnyStreamMessage::User(m) => m.r#type,
179            AnyStreamMessage::Assistant(m) => m.r#type,
180            AnyStreamMessage::ToolUse(m) => m.r#type,
181            AnyStreamMessage::ToolResult(m) => m.r#type,
182            AnyStreamMessage::Partial(m) => m.r#type,
183            AnyStreamMessage::Error(m) => m.r#type,
184            AnyStreamMessage::Done(m) => m.r#type,
185            AnyStreamMessage::System(m) => m.r#type,
186        }
187    }
188}
189
/// Milliseconds elapsed since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis() as u64)
}
197
198/// Generate a session ID
199fn generate_session_id() -> String {
200    use std::time::{SystemTime, UNIX_EPOCH};
201    let timestamp = SystemTime::now()
202        .duration_since(UNIX_EPOCH)
203        .map(|d| d.as_millis())
204        .unwrap_or(0);
205    let random: u64 = rand::random();
206    format!("session_{}_{:x}", timestamp, random & 0xFFFFFFFF)
207}
208
209/// Stream JSON reader
210pub struct StreamJsonReader {
211    buffer: VecDeque<AnyStreamMessage>,
212    closed: bool,
213}
214
215impl Default for StreamJsonReader {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
221impl StreamJsonReader {
222    /// Create a new reader
223    pub fn new() -> Self {
224        Self {
225            buffer: VecDeque::new(),
226            closed: false,
227        }
228    }
229
230    /// Process a line of JSON
231    pub fn process_line(&mut self, line: &str) -> Option<AnyStreamMessage> {
232        let trimmed = line.trim();
233        if trimmed.is_empty() {
234            return None;
235        }
236
237        serde_json::from_str::<AnyStreamMessage>(trimmed).ok()
238    }
239
240    /// Read from a BufRead source
241    pub fn read_from<R: BufRead>(
242        &mut self,
243        reader: &mut R,
244    ) -> std::io::Result<Option<AnyStreamMessage>> {
245        let mut line = String::new();
246        let bytes = reader.read_line(&mut line)?;
247
248        if bytes == 0 {
249            self.closed = true;
250            return Ok(None);
251        }
252
253        Ok(self.process_line(&line))
254    }
255
256    /// Check if closed
257    pub fn is_closed(&self) -> bool {
258        self.closed
259    }
260}
261
262/// Stream JSON writer
263pub struct StreamJsonWriter<W: Write> {
264    output: W,
265    session_id: String,
266    message_index: usize,
267}
268
269impl<W: Write> StreamJsonWriter<W> {
270    /// Create a new writer
271    pub fn new(output: W, session_id: Option<String>) -> Self {
272        Self {
273            output,
274            session_id: session_id.unwrap_or_else(generate_session_id),
275            message_index: 0,
276        }
277    }
278
279    /// Write a raw message
280    pub fn write(&mut self, message: &impl Serialize) -> std::io::Result<()> {
281        let json = serde_json::to_string(message)?;
282        writeln!(self.output, "{}", json)?;
283        self.output.flush()
284    }
285
286    /// Write user message
287    pub fn write_user_message(
288        &mut self,
289        content: &str,
290        attachments: Option<Vec<Attachment>>,
291    ) -> std::io::Result<()> {
292        let msg = UserStreamMessage {
293            r#type: StreamMessageType::UserMessage,
294            timestamp: current_timestamp(),
295            session_id: Some(self.session_id.clone()),
296            content: content.to_string(),
297            attachments,
298        };
299        self.write(&msg)
300    }
301
302    /// Write assistant message
303    pub fn write_assistant_message(
304        &mut self,
305        content: &str,
306        model: Option<&str>,
307        stop_reason: Option<&str>,
308    ) -> std::io::Result<()> {
309        let msg = AssistantStreamMessage {
310            r#type: StreamMessageType::AssistantMessage,
311            timestamp: current_timestamp(),
312            session_id: Some(self.session_id.clone()),
313            content: content.to_string(),
314            model: model.map(|s| s.to_string()),
315            stop_reason: stop_reason.map(|s| s.to_string()),
316        };
317        self.write(&msg)
318    }
319
320    /// Write tool use
321    pub fn write_tool_use(
322        &mut self,
323        tool_id: &str,
324        tool_name: &str,
325        input: serde_json::Value,
326    ) -> std::io::Result<()> {
327        let msg = ToolUseStreamMessage {
328            r#type: StreamMessageType::ToolUse,
329            timestamp: current_timestamp(),
330            session_id: Some(self.session_id.clone()),
331            tool_id: tool_id.to_string(),
332            tool_name: tool_name.to_string(),
333            input,
334        };
335        self.write(&msg)
336    }
337
338    /// Write tool result
339    pub fn write_tool_result(
340        &mut self,
341        tool_id: &str,
342        success: bool,
343        output: Option<&str>,
344        error: Option<&str>,
345    ) -> std::io::Result<()> {
346        let msg = ToolResultStreamMessage {
347            r#type: StreamMessageType::ToolResult,
348            timestamp: current_timestamp(),
349            session_id: Some(self.session_id.clone()),
350            tool_id: tool_id.to_string(),
351            success,
352            output: output.map(|s| s.to_string()),
353            error: error.map(|s| s.to_string()),
354        };
355        self.write(&msg)
356    }
357
358    /// Write partial message
359    pub fn write_partial(&mut self, content: &str) -> std::io::Result<()> {
360        let msg = PartialStreamMessage {
361            r#type: StreamMessageType::Partial,
362            timestamp: current_timestamp(),
363            session_id: Some(self.session_id.clone()),
364            content: content.to_string(),
365            index: self.message_index,
366        };
367        self.message_index += 1;
368        self.write(&msg)
369    }
370
371    /// Write error
372    pub fn write_error(
373        &mut self,
374        code: &str,
375        message: &str,
376        details: Option<serde_json::Value>,
377    ) -> std::io::Result<()> {
378        let msg = ErrorStreamMessage {
379            r#type: StreamMessageType::Error,
380            timestamp: current_timestamp(),
381            session_id: Some(self.session_id.clone()),
382            code: code.to_string(),
383            message: message.to_string(),
384            details,
385        };
386        self.write(&msg)
387    }
388
389    /// Write done
390    pub fn write_done(&mut self, stats: Option<StreamStats>) -> std::io::Result<()> {
391        let msg = DoneStreamMessage {
392            r#type: StreamMessageType::Done,
393            timestamp: current_timestamp(),
394            session_id: Some(self.session_id.clone()),
395            stats,
396        };
397        self.write(&msg)
398    }
399
400    /// Write system event
401    pub fn write_system(
402        &mut self,
403        event: &str,
404        data: Option<serde_json::Value>,
405    ) -> std::io::Result<()> {
406        let msg = SystemStreamMessage {
407            r#type: StreamMessageType::System,
408            timestamp: current_timestamp(),
409            session_id: Some(self.session_id.clone()),
410            event: event.to_string(),
411            data,
412        };
413        self.write(&msg)
414    }
415
416    /// Get session ID
417    pub fn session_id(&self) -> &str {
418        &self.session_id
419    }
420}
421
422/// Stream session handler
423pub struct StreamSession<R: BufRead, W: Write> {
424    reader: StreamJsonReader,
425    writer: StreamJsonWriter<W>,
426    input: R,
427}
428
429impl<R: BufRead, W: Write> StreamSession<R, W> {
430    /// Create a new stream session
431    pub fn new(input: R, output: W) -> Self {
432        Self {
433            reader: StreamJsonReader::new(),
434            writer: StreamJsonWriter::new(output, None),
435            input,
436        }
437    }
438
439    /// Get the writer
440    pub fn writer(&mut self) -> &mut StreamJsonWriter<W> {
441        &mut self.writer
442    }
443
444    /// Read next message
445    pub fn read_message(&mut self) -> std::io::Result<Option<AnyStreamMessage>> {
446        self.reader.read_from(&mut self.input)
447    }
448
449    /// Start session
450    pub fn start(&mut self) -> std::io::Result<()> {
451        self.writer.write_system("session_start", None)
452    }
453
454    /// End session
455    pub fn end(&mut self) -> std::io::Result<()> {
456        self.writer.write_done(None)
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Decode everything a writer produced as UTF-8 text.
    fn as_text(bytes: Vec<u8>) -> String {
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn test_stream_message_type_serialize() {
        let encoded = serde_json::to_string(&StreamMessageType::UserMessage).unwrap();
        assert_eq!(encoded, "\"user_message\"");
    }

    #[test]
    fn test_stream_json_reader_process_line() {
        let mut reader = StreamJsonReader::new();
        let parsed =
            reader.process_line(r#"{"type":"user_message","timestamp":123,"content":"hello"}"#);
        assert!(parsed.is_some());
    }

    #[test]
    fn test_stream_json_reader_empty_line() {
        let mut reader = StreamJsonReader::new();
        for &blank in &["", "   "] {
            assert!(reader.process_line(blank).is_none());
        }
    }

    #[test]
    fn test_stream_json_writer_partial() {
        let mut sink = Vec::new();
        StreamJsonWriter::new(&mut sink, Some("test_session".to_string()))
            .write_partial("Hello")
            .unwrap();

        let text = as_text(sink);
        assert!(text.contains("partial"));
        assert!(text.contains("Hello"));
    }

    #[test]
    fn test_stream_json_writer_error() {
        let mut sink = Vec::new();
        StreamJsonWriter::new(&mut sink, None)
            .write_error("ERR001", "Test error", None)
            .unwrap();

        let text = as_text(sink);
        assert!(text.contains("error"));
        assert!(text.contains("ERR001"));
    }

    #[test]
    fn test_stream_json_writer_done() {
        let mut sink = Vec::new();
        let stats = StreamStats {
            input_tokens: 100,
            output_tokens: 50,
            total_cost_usd: 0.001,
            duration_ms: 1000,
        };
        StreamJsonWriter::new(&mut sink, None)
            .write_done(Some(stats))
            .unwrap();

        let text = as_text(sink);
        assert!(text.contains("done"));
        assert!(text.contains("100"));
    }

    #[test]
    fn test_stream_session() {
        let mut sink = Vec::new();
        {
            let mut session = StreamSession::new(Cursor::new(Vec::<u8>::new()), &mut sink);
            session.start().unwrap();
            session.end().unwrap();
        }

        let text = as_text(sink);
        assert!(text.contains("session_start"));
        assert!(text.contains("done"));
    }

    #[test]
    fn test_any_stream_message_type() {
        let wrapped = AnyStreamMessage::Error(ErrorStreamMessage {
            r#type: StreamMessageType::Error,
            timestamp: 0,
            session_id: None,
            code: "E1".to_string(),
            message: "test".to_string(),
            details: None,
        });

        assert_eq!(wrapped.message_type(), StreamMessageType::Error);
    }
}