1use crate::types::MessageId;
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize)]
5#[serde(tag = "type", rename_all = "snake_case")]
6pub enum StreamChunk {
7 MessageStart {
8 message_id: MessageId,
9 },
10 TextDelta {
11 text: String,
12 },
13 ToolUseStart {
14 id: String,
15 name: String,
16 },
17 ToolUseDelta {
18 id: String,
19 input_delta: String,
20 },
21 ToolUseEnd {
22 id: String,
23 },
24 ContentBlockStart {
25 index: usize,
26 },
27 ContentBlockEnd {
28 index: usize,
29 },
30 MessageEnd {
31 stop_reason: Option<StopReason>,
32 },
33 Usage {
34 input_tokens: u32,
35 output_tokens: u32,
36 },
37 Error {
38 code: String,
39 message: String,
40 },
41 Ping,
42}
43
44#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(rename_all = "snake_case")]
46pub enum StopReason {
47 EndTurn,
48 MaxTokens,
49 StopSequence,
50 ToolUse,
51}
52
/// Accumulates a single tool invocation while its chunks stream in.
#[derive(Default)]
struct ToolUseBuilder {
    /// Identifier from `ToolUseStart`; `ToolUseDelta` chunks are matched on it.
    id: String,
    /// Tool name from `ToolUseStart`.
    name: String,
    /// Every matching `input_delta` fragment, concatenated in arrival order.
    input_json: String,
}
59
60#[derive(Default)]
61pub struct StreamAggregator {
62 message_id: Option<MessageId>,
63 text_buffer: String,
64 tool_uses: Vec<ToolUseBuilder>,
65 stop_reason: Option<StopReason>,
66 input_tokens: u32,
67 output_tokens: u32,
68 error: Option<(String, String)>,
69}
70
71impl StreamAggregator {
72 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn push(&mut self, chunk: StreamChunk) {
77 match chunk {
78 StreamChunk::MessageStart { message_id } => {
79 self.message_id = Some(message_id);
80 }
81 StreamChunk::TextDelta { text } => {
82 self.text_buffer.push_str(&text);
83 }
84 StreamChunk::ToolUseStart { id, name } => {
85 self.tool_uses.push(ToolUseBuilder {
86 id,
87 name,
88 input_json: String::new(),
89 });
90 }
91 StreamChunk::ToolUseDelta { id, input_delta } => {
92 if let Some(tu) = self.tool_uses.iter_mut().find(|t| t.id == id) {
93 tu.input_json.push_str(&input_delta);
94 }
95 }
96 StreamChunk::MessageEnd { stop_reason } => {
97 self.stop_reason = stop_reason;
98 }
99 StreamChunk::Usage {
100 input_tokens,
101 output_tokens,
102 } => {
103 self.input_tokens = input_tokens;
104 self.output_tokens = output_tokens;
105 }
106 StreamChunk::Error { code, message } => {
107 self.error = Some((code, message));
108 }
109 _ => {}
110 }
111 }
112
113 pub fn message_id(&self) -> Option<MessageId> {
114 self.message_id
115 }
116
117 pub fn text(&self) -> &str {
118 &self.text_buffer
119 }
120
121 pub fn stop_reason(&self) -> Option<&StopReason> {
122 self.stop_reason.as_ref()
123 }
124
125 pub fn is_complete(&self) -> bool {
126 self.stop_reason.is_some()
127 }
128
129 pub fn has_error(&self) -> bool {
130 self.error.is_some()
131 }
132
133 pub fn error(&self) -> Option<(&str, &str)> {
134 self.error.as_ref().map(|(c, m)| (c.as_str(), m.as_str()))
135 }
136
137 pub fn input_tokens(&self) -> u32 {
138 self.input_tokens
139 }
140
141 pub fn output_tokens(&self) -> u32 {
142 self.output_tokens
143 }
144
145 pub fn total_tokens(&self) -> u32 {
146 self.input_tokens + self.output_tokens
147 }
148
149 pub fn tool_use_count(&self) -> usize {
150 self.tool_uses.len()
151 }
152
153 pub fn has_tool_use(&self) -> bool {
154 !self.tool_uses.is_empty()
155 }
156
157 pub fn clear(&mut self) {
158 *self = Self::default();
159 }
160}