hehe_core/stream/mod.rs

1use crate::types::MessageId;
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize)]
5#[serde(tag = "type", rename_all = "snake_case")]
6pub enum StreamChunk {
7    MessageStart {
8        message_id: MessageId,
9    },
10    TextDelta {
11        text: String,
12    },
13    ToolUseStart {
14        id: String,
15        name: String,
16    },
17    ToolUseDelta {
18        id: String,
19        input_delta: String,
20    },
21    ToolUseEnd {
22        id: String,
23    },
24    ContentBlockStart {
25        index: usize,
26    },
27    ContentBlockEnd {
28        index: usize,
29    },
30    MessageEnd {
31        stop_reason: Option<StopReason>,
32    },
33    Usage {
34        input_tokens: u32,
35        output_tokens: u32,
36    },
37    Error {
38        code: String,
39        message: String,
40    },
41    Ping,
42}
43
44#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(rename_all = "snake_case")]
46pub enum StopReason {
47    EndTurn,
48    MaxTokens,
49    StopSequence,
50    ToolUse,
51}
52
53#[derive(Default)]
54struct ToolUseBuilder {
55    id: String,
56    name: String,
57    input_json: String,
58}
59
60#[derive(Default)]
61pub struct StreamAggregator {
62    message_id: Option<MessageId>,
63    text_buffer: String,
64    tool_uses: Vec<ToolUseBuilder>,
65    stop_reason: Option<StopReason>,
66    input_tokens: u32,
67    output_tokens: u32,
68    error: Option<(String, String)>,
69}
70
71impl StreamAggregator {
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    pub fn push(&mut self, chunk: StreamChunk) {
77        match chunk {
78            StreamChunk::MessageStart { message_id } => {
79                self.message_id = Some(message_id);
80            }
81            StreamChunk::TextDelta { text } => {
82                self.text_buffer.push_str(&text);
83            }
84            StreamChunk::ToolUseStart { id, name } => {
85                self.tool_uses.push(ToolUseBuilder {
86                    id,
87                    name,
88                    input_json: String::new(),
89                });
90            }
91            StreamChunk::ToolUseDelta { id, input_delta } => {
92                if let Some(tu) = self.tool_uses.iter_mut().find(|t| t.id == id) {
93                    tu.input_json.push_str(&input_delta);
94                }
95            }
96            StreamChunk::MessageEnd { stop_reason } => {
97                self.stop_reason = stop_reason;
98            }
99            StreamChunk::Usage {
100                input_tokens,
101                output_tokens,
102            } => {
103                self.input_tokens = input_tokens;
104                self.output_tokens = output_tokens;
105            }
106            StreamChunk::Error { code, message } => {
107                self.error = Some((code, message));
108            }
109            _ => {}
110        }
111    }
112
113    pub fn message_id(&self) -> Option<MessageId> {
114        self.message_id
115    }
116
117    pub fn text(&self) -> &str {
118        &self.text_buffer
119    }
120
121    pub fn stop_reason(&self) -> Option<&StopReason> {
122        self.stop_reason.as_ref()
123    }
124
125    pub fn is_complete(&self) -> bool {
126        self.stop_reason.is_some()
127    }
128
129    pub fn has_error(&self) -> bool {
130        self.error.is_some()
131    }
132
133    pub fn error(&self) -> Option<(&str, &str)> {
134        self.error.as_ref().map(|(c, m)| (c.as_str(), m.as_str()))
135    }
136
137    pub fn input_tokens(&self) -> u32 {
138        self.input_tokens
139    }
140
141    pub fn output_tokens(&self) -> u32 {
142        self.output_tokens
143    }
144
145    pub fn total_tokens(&self) -> u32 {
146        self.input_tokens + self.output_tokens
147    }
148
149    pub fn tool_use_count(&self) -> usize {
150        self.tool_uses.len()
151    }
152
153    pub fn has_tool_use(&self) -> bool {
154        !self.tool_uses.is_empty()
155    }
156
157    pub fn clear(&mut self) {
158        *self = Self::default();
159    }
160}