1use std::{
2 convert::Infallible,
3 fmt::{self, Display, Formatter},
4 str::FromStr,
5};
6
7use serde::{Deserialize, Serialize};
8
9use super::{response::Response, BaseContentBlock, DeltaContentBlock, StopReason, Usage};
10
11#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
12#[serde(rename_all = "snake_case")]
13pub enum EventName {
14 Unspecified,
15 Error,
16 MessageStart,
17 ContentBlockDelta,
18 ContentBlockStart,
19 Ping,
20 ContentBlockStop,
21 MessageDelta,
22 MessageStop,
23}
24
25impl FromStr for EventName {
26 type Err = Infallible;
27 fn from_str(s: &str) -> Result<Self, Self::Err> {
28 match s {
29 "error" => Ok(EventName::Error),
30 "message_start" => Ok(EventName::MessageStart),
31 "content_block_start" => Ok(EventName::ContentBlockStart),
32 "ping" => Ok(EventName::Ping),
33 "content_block_delta" => Ok(EventName::ContentBlockDelta),
34 "content_block_stop" => Ok(EventName::ContentBlockStop),
35 "message_delta" => Ok(EventName::MessageDelta),
36 "message_stop" => Ok(EventName::MessageStop),
37 _ => Ok(EventName::Unspecified),
38 }
39 }
40}
41
42#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
43#[serde(rename_all = "snake_case", tag = "type")]
44pub enum EventData {
45 Error {
46 error: ErrorData,
47 },
48 MessageStart {
49 message: Response,
50 },
51 ContentBlockStart {
52 index: u32,
53 content_block: BaseContentBlock,
54 },
55 Ping,
56 ContentBlockDelta {
57 index: u32,
58 delta: DeltaContentBlock,
59 },
60 ContentBlockStop {
61 index: u32,
62 },
63 MessageDelta {
64 delta: MessageDelta,
65 usage: Usage,
66 },
67 MessageStop,
68}
69
70#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
71#[serde(rename_all = "snake_case", tag = "type")]
72pub enum ErrorData {
73 OverloadedError { message: String },
74 InternalServerError { message: String },
76 BadRequestError { message: String },
77 UnauthorizedError { message: String },
78}
79
80impl Display for ErrorData {
81 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
82 match self {
83 ErrorData::OverloadedError { message } => write!(f, "OverloadedError: {}", message),
84 ErrorData::InternalServerError { message } => {
85 write!(f, "InternalServerError: {}", message)
86 }
87 ErrorData::BadRequestError { message } => write!(f, "BadRequestError: {}", message),
88 ErrorData::UnauthorizedError { message } => write!(f, "UnauthorizedError: {}", message),
89 }
90 }
91}
92
93#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
94pub struct MessageDelta {
95 pub stop_reason: StopReason,
96 pub stop_sequence: Option<String>,
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{Role, ToolUseContentBlock};

    /// Decodes `(event name, JSON payload)` pairs as they appear on the wire and
    /// checks both the parsed [`EventName`] and the deserialized [`EventData`].
    #[test]
    fn serde() {
        let tests = vec![
            (
                "error_overloaded",
                "error",
                r#"{"type": "error", "error": {"type": "overloaded_error", "message": "Overloaded"}}"#,
                EventName::Error,
                EventData::Error {
                    error: ErrorData::OverloadedError {
                        message: "Overloaded".to_string(),
                    },
                },
            ),
            (
                "message_start_empty_content",
                "message_start",
                r#"{"type":"message_start","message":{"id":"msg_019LBLYFJ7fG3fuAqzuRQbyi","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#,
                EventName::MessageStart,
                EventData::MessageStart {
                    message: Response {
                        id: "msg_019LBLYFJ7fG3fuAqzuRQbyi".to_string(),
                        r#type: "message".to_string(),
                        role: Role::Assistant,
                        content: vec![],
                        model: "claude-3-opus-20240229".to_string(),
                        stop_reason: None,
                        stop_sequence: None,
                        usage: Usage {
                            input_tokens: Some(10),
                            output_tokens: 1,
                        },
                    },
                },
            ),
            (
                "content_block_start_empty_text",
                "content_block_start",
                r#"{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}"#,
                EventName::ContentBlockStart,
                EventData::ContentBlockStart {
                    index: 0,
                    content_block: BaseContentBlock::Text {
                        text: "".to_string(),
                    },
                },
            ),
            (
                "ping_event",
                "ping",
                r#"{"type": "ping"}"#,
                EventName::Ping,
                EventData::Ping,
            ),
            (
                "content_block_delta_hello",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 0,
                    delta: DeltaContentBlock::TextDelta {
                        text: "Hello".to_string(),
                    },
                },
            ),
            (
                "content_block_delta_exclamation",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"!"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 0,
                    delta: DeltaContentBlock::TextDelta {
                        text: "!".to_string(),
                    },
                },
            ),
            (
                "content_block_stop_index_0",
                "content_block_stop",
                r#"{"type":"content_block_stop","index":0}"#,
                EventName::ContentBlockStop,
                EventData::ContentBlockStop { index: 0 },
            ),
            (
                "message_delta_end_turn",
                "message_delta",
                r#"{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":12}}"#,
                EventName::MessageDelta,
                EventData::MessageDelta {
                    delta: MessageDelta {
                        stop_reason: StopReason::EndTurn,
                        stop_sequence: None,
                    },
                    usage: Usage {
                        input_tokens: None,
                        output_tokens: 12,
                    },
                },
            ),
            (
                "message_stop_event",
                "message_stop",
                r#"{"type":"message_stop"}"#,
                EventName::MessageStop,
                EventData::MessageStop,
            ),
            (
                "content_block_start_tool_use",
                "content_block_start",
                r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"tu_01AbCdEfGhIjKlMnOpQrStUv","name":"weather_forecast","input":{}}}"#,
                EventName::ContentBlockStart,
                EventData::ContentBlockStart {
                    index: 1,
                    content_block: BaseContentBlock::ToolUse(ToolUseContentBlock {
                        id: "tu_01AbCdEfGhIjKlMnOpQrStUv".to_string(),
                        name: "weather_forecast".to_string(),
                        input: serde_json::json!({}),
                    }),
                },
            ),
            (
                "content_block_delta_input_json_start",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\": \"San Fra\"}"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 1,
                    delta: DeltaContentBlock::InputJsonDelta {
                        partial_json: "{\"location\": \"San Fra\"}".to_string(),
                    },
                },
            ),
            (
                "content_block_delta_input_json_continuation",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"ncisco\"}"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 1,
                    delta: DeltaContentBlock::InputJsonDelta {
                        partial_json: "ncisco\"}".to_string(),
                    },
                },
            ),
            (
                "content_block_start_thinking",
                "content_block_start",
                r#"{"type":"content_block_start","index":2,"content_block":{"type":"thinking","thinking":"","signature":null}}"#,
                EventName::ContentBlockStart,
                EventData::ContentBlockStart {
                    index: 2,
                    content_block: BaseContentBlock::Thinking {
                        thinking: "".to_string(),
                        signature: None,
                    },
                },
            ),
            (
                "content_block_delta_thinking",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":2,"delta":{"type":"thinking_delta","thinking":"Let me solve this step by step:\n\n1. First break down 27 * 453"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 2,
                    delta: DeltaContentBlock::ThinkingDelta {
                        thinking: "Let me solve this step by step:\n\n1. First break down 27 * 453"
                            .to_string(),
                    },
                },
            ),
            (
                "content_block_delta_signature",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":2,"delta":{"type":"signature_delta","signature":"EqQBCgIYAhIM1gbcDa9GJwZA2b3hGgxBdjrkzLoky3dl1pkiMOYds..."}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 2,
                    delta: DeltaContentBlock::SignatureDelta {
                        signature: "EqQBCgIYAhIM1gbcDa9GJwZA2b3hGgxBdjrkzLoky3dl1pkiMOYds..."
                            .to_string(),
                    },
                },
            ),
            (
                "message_delta_max_tokens",
                "message_delta",
                r#"{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"output_tokens":1024}}"#,
                EventName::MessageDelta,
                EventData::MessageDelta {
                    delta: MessageDelta {
                        stop_reason: StopReason::MaxTokens,
                        stop_sequence: None,
                    },
                    usage: Usage {
                        input_tokens: None,
                        output_tokens: 1024,
                    },
                },
            ),
            (
                "message_delta_stop_sequence",
                "message_delta",
                r#"{"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"STOP"},"usage":{"output_tokens":45}}"#,
                EventName::MessageDelta,
                EventData::MessageDelta {
                    delta: MessageDelta {
                        stop_reason: StopReason::StopSequence,
                        stop_sequence: Some("STOP".to_string()),
                    },
                    usage: Usage {
                        input_tokens: None,
                        output_tokens: 45,
                    },
                },
            ),
            (
                // This payload is a `tool_use` start block (not a tool result);
                // the case name reflects that.
                "content_block_start_tool_use_get_weather",
                "content_block_start",
                r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}}"#,
                EventName::ContentBlockStart,
                EventData::ContentBlockStart {
                    index: 1,
                    content_block: BaseContentBlock::ToolUse(ToolUseContentBlock {
                        id: "toolu_01T1x1fJ34qAmk2tNTrN7Up6".to_string(),
                        name: "get_weather".to_string(),
                        input: serde_json::json!({}),
                    }),
                },
            ),
        ];
        for (test_name, name, input, event_name, event_data) in tests {
            let got_event_name = EventName::from_str(name).unwrap();
            assert_eq!(
                got_event_name, event_name,
                "test failed for event name: {} ({})",
                name, test_name
            );

            // Name the failing case instead of panicking with a bare `unwrap`.
            let got_event_data: EventData = serde_json::from_str(input).unwrap_or_else(|e| {
                panic!("failed to deserialize event data for {}: {}", test_name, e)
            });
            assert_eq!(
                got_event_data, event_data,
                "test failed for event data: {}",
                test_name
            );
        }
    }

    /// Unrecognised event names (including wrong case or separators) fall back
    /// to `EventName::Unspecified` instead of returning an error.
    #[test]
    fn unknown_event_name_is_unspecified() {
        for name in ["", "unknown_event", "MessageStart", "message-start"] {
            assert_eq!(
                EventName::from_str(name).unwrap(),
                EventName::Unspecified,
                "unexpected parse for event name: {:?}",
                name
            );
        }
    }

    /// `ErrorData` displays as `"<VariantName>: <message>"`.
    #[test]
    fn error_data_display() {
        let cases = vec![
            (
                ErrorData::OverloadedError {
                    message: "Overloaded".to_string(),
                },
                "OverloadedError: Overloaded",
            ),
            (
                ErrorData::InternalServerError {
                    message: "boom".to_string(),
                },
                "InternalServerError: boom",
            ),
            (
                ErrorData::BadRequestError {
                    message: "bad".to_string(),
                },
                "BadRequestError: bad",
            ),
            (
                ErrorData::UnauthorizedError {
                    message: "nope".to_string(),
                },
                "UnauthorizedError: nope",
            ),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }
}