async_claude/messages/
stream_response.rs

1use std::{
2    convert::Infallible,
3    fmt::{self, Display, Formatter},
4    str::FromStr,
5};
6
7use serde::{Deserialize, Serialize};
8
9use super::{response::Response, BaseContentBlock, DeltaContentBlock, StopReason, Usage};
10
11#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
12#[serde(rename_all = "snake_case")]
13pub enum EventName {
14    Unspecified,
15    Error,
16    MessageStart,
17    ContentBlockDelta,
18    ContentBlockStart,
19    Ping,
20    ContentBlockStop,
21    MessageDelta,
22    MessageStop,
23}
24
25impl FromStr for EventName {
26    type Err = Infallible;
27    fn from_str(s: &str) -> Result<Self, Self::Err> {
28        match s {
29            "error" => Ok(EventName::Error),
30            "message_start" => Ok(EventName::MessageStart),
31            "content_block_start" => Ok(EventName::ContentBlockStart),
32            "ping" => Ok(EventName::Ping),
33            "content_block_delta" => Ok(EventName::ContentBlockDelta),
34            "content_block_stop" => Ok(EventName::ContentBlockStop),
35            "message_delta" => Ok(EventName::MessageDelta),
36            "message_stop" => Ok(EventName::MessageStop),
37            _ => Ok(EventName::Unspecified),
38        }
39    }
40}
41
42#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
43#[serde(rename_all = "snake_case", tag = "type")]
44pub enum EventData {
45    Error {
46        error: ErrorData,
47    },
48    MessageStart {
49        message: Response,
50    },
51    ContentBlockStart {
52        index: u32,
53        content_block: BaseContentBlock,
54    },
55    Ping,
56    ContentBlockDelta {
57        index: u32,
58        delta: DeltaContentBlock,
59    },
60    ContentBlockStop {
61        index: u32,
62    },
63    MessageDelta {
64        delta: MessageDelta,
65        usage: Usage,
66    },
67    MessageStop,
68}
69
70#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
71#[serde(rename_all = "snake_case", tag = "type")]
72pub enum ErrorData {
73    OverloadedError { message: String },
74    // Additional error types
75    InternalServerError { message: String },
76    BadRequestError { message: String },
77    UnauthorizedError { message: String },
78}
79
80impl Display for ErrorData {
81    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
82        match self {
83            ErrorData::OverloadedError { message } => write!(f, "OverloadedError: {}", message),
84            ErrorData::InternalServerError { message } => {
85                write!(f, "InternalServerError: {}", message)
86            }
87            ErrorData::BadRequestError { message } => write!(f, "BadRequestError: {}", message),
88            ErrorData::UnauthorizedError { message } => write!(f, "UnauthorizedError: {}", message),
89        }
90    }
91}
92
93#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
94pub struct MessageDelta {
95    pub stop_reason: StopReason,
96    pub stop_sequence: Option<String>,
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{Role, ToolUseContentBlock};

    /// Table-driven check that each sample SSE event parses to the expected
    /// `EventName` (via `FromStr`) and `EventData` (via serde).
    ///
    /// Each row is `(test_name, sse_event_name, json_payload, expected_name, expected_data)`.
    #[test]
    fn serde() {
        let tests = vec![
            (
                "error_overloaded",
                "error",
                r#"{"type": "error", "error": {"type": "overloaded_error", "message": "Overloaded"}}"#,
                EventName::Error,
                EventData::Error {
                    error: ErrorData::OverloadedError {
                        message: "Overloaded".to_string(),
                    },
                },
            ),
            (
                "message_start_empty_content",
                "message_start",
                r#"{"type":"message_start","message":{"id":"msg_019LBLYFJ7fG3fuAqzuRQbyi","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#,
                EventName::MessageStart,
                EventData::MessageStart {
                    message: Response {
                        id: "msg_019LBLYFJ7fG3fuAqzuRQbyi".to_string(),
                        r#type: "message".to_string(),
                        role: Role::Assistant,
                        content: vec![],
                        model: "claude-3-opus-20240229".to_string(),
                        stop_reason: None,
                        stop_sequence: None,
                        usage: Usage {
                            input_tokens: Some(10),
                            output_tokens: 1,
                        },
                    },
                },
            ),
            (
                "content_block_start_empty_text",
                "content_block_start",
                r#"{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}"#,
                EventName::ContentBlockStart,
                EventData::ContentBlockStart {
                    index: 0,
                    content_block: BaseContentBlock::Text {
                        text: "".to_string(),
                    },
                },
            ),
            (
                "ping_event",
                "ping",
                r#"{"type": "ping"}"#,
                EventName::Ping,
                EventData::Ping,
            ),
            (
                "content_block_delta_hello",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 0,
                    delta: DeltaContentBlock::TextDelta {
                        text: "Hello".to_string(),
                    },
                },
            ),
            (
                "content_block_delta_exclamation",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"!"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 0,
                    delta: DeltaContentBlock::TextDelta {
                        text: "!".to_string(),
                    },
                },
            ),
            (
                "content_block_stop_index_0",
                "content_block_stop",
                r#"{"type":"content_block_stop","index":0}"#,
                EventName::ContentBlockStop,
                EventData::ContentBlockStop { index: 0 },
            ),
            (
                "message_delta_end_turn",
                "message_delta",
                r#"{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":12}}"#,
                EventName::MessageDelta,
                EventData::MessageDelta {
                    delta: MessageDelta {
                        stop_reason: StopReason::EndTurn,
                        stop_sequence: None,
                    },
                    usage: Usage {
                        input_tokens: None,
                        output_tokens: 12,
                    },
                },
            ),
            (
                "message_stop_event",
                "message_stop",
                r#"{"type":"message_stop"}"#,
                EventName::MessageStop,
                EventData::MessageStop,
            ),
            // Tool use with streamed partial JSON input, extended thinking
            // (thinking + signature deltas), and the remaining stop reasons.
            (
                "content_block_start_tool_use",
                "content_block_start",
                r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"tu_01AbCdEfGhIjKlMnOpQrStUv","name":"weather_forecast","input":{}}}"#,
                EventName::ContentBlockStart,
                EventData::ContentBlockStart {
                    index: 1,
                    content_block: BaseContentBlock::ToolUse(ToolUseContentBlock {
                        id: "tu_01AbCdEfGhIjKlMnOpQrStUv".to_string(),
                        name: "weather_forecast".to_string(),
                        input: serde_json::json!({}),
                    }),
                },
            ),
            // The two input_json_delta fragments below concatenate to
            // `{"location": "San Francisco"}`; each fragment alone is not valid JSON.
            (
                "content_block_delta_input_json_start",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\": \"San Fra"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 1,
                    delta: DeltaContentBlock::InputJsonDelta {
                        partial_json: "{\"location\": \"San Fra".to_string(),
                    },
                },
            ),
            (
                "content_block_delta_input_json_continuation",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"ncisco\"}"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 1,
                    delta: DeltaContentBlock::InputJsonDelta {
                        partial_json: "ncisco\"}".to_string(),
                    },
                },
            ),
            (
                "content_block_start_thinking",
                "content_block_start",
                r#"{"type":"content_block_start","index":2,"content_block":{"type":"thinking","thinking":"","signature":null}}"#,
                EventName::ContentBlockStart,
                EventData::ContentBlockStart {
                    index: 2,
                    content_block: BaseContentBlock::Thinking {
                        thinking: "".to_string(),
                        signature: None,
                    },
                },
            ),
            (
                "content_block_delta_thinking",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":2,"delta":{"type":"thinking_delta","thinking":"Let me solve this step by step:\n\n1. First break down 27 * 453"}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 2,
                    delta: DeltaContentBlock::ThinkingDelta {
                        thinking: "Let me solve this step by step:\n\n1. First break down 27 * 453"
                            .to_string(),
                    },
                },
            ),
            (
                "content_block_delta_signature",
                "content_block_delta",
                r#"{"type":"content_block_delta","index":2,"delta":{"type":"signature_delta","signature":"EqQBCgIYAhIM1gbcDa9GJwZA2b3hGgxBdjrkzLoky3dl1pkiMOYds..."}}"#,
                EventName::ContentBlockDelta,
                EventData::ContentBlockDelta {
                    index: 2,
                    delta: DeltaContentBlock::SignatureDelta {
                        signature: "EqQBCgIYAhIM1gbcDa9GJwZA2b3hGgxBdjrkzLoky3dl1pkiMOYds..."
                            .to_string(),
                    },
                },
            ),
            (
                "message_delta_max_tokens",
                "message_delta",
                r#"{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"output_tokens":1024}}"#,
                EventName::MessageDelta,
                EventData::MessageDelta {
                    delta: MessageDelta {
                        stop_reason: StopReason::MaxTokens,
                        stop_sequence: None,
                    },
                    usage: Usage {
                        input_tokens: None,
                        output_tokens: 1024,
                    },
                },
            ),
            (
                "message_delta_stop_sequence",
                "message_delta",
                r#"{"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"STOP"},"usage":{"output_tokens":45}}"#,
                EventName::MessageDelta,
                EventData::MessageDelta {
                    delta: MessageDelta {
                        stop_reason: StopReason::StopSequence,
                        stop_sequence: Some("STOP".to_string()),
                    },
                    usage: Usage {
                        input_tokens: None,
                        output_tokens: 45,
                    },
                },
            ),
            (
                "content_block_start_tool_use_get_weather",
                "content_block_start",
                r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}}"#,
                EventName::ContentBlockStart,
                EventData::ContentBlockStart {
                    index: 1,
                    content_block: BaseContentBlock::ToolUse(ToolUseContentBlock {
                        id: "toolu_01T1x1fJ34qAmk2tNTrN7Up6".to_string(),
                        name: "get_weather".to_string(),
                        input: serde_json::json!({}),
                    }),
                },
            ),
        ];
        for (test_name, name, input, event_name, event_data) in tests {
            let got_event_name = EventName::from_str(name).unwrap();
            assert_eq!(
                got_event_name, event_name,
                "test failed for event name: {} ({})",
                name, test_name
            );

            // Name the failing row instead of panicking with a bare `unwrap`.
            let got_event_data: EventData = serde_json::from_str(input).unwrap_or_else(|err| {
                panic!("failed to deserialize event data for {}: {}", test_name, err)
            });
            assert_eq!(
                got_event_data, event_data,
                "test failed for event data: {}",
                test_name
            );
        }
    }

    /// `FromStr` never fails: unrecognised names, including differently-cased
    /// or hyphenated spellings and the empty string, map to `Unspecified`.
    #[test]
    fn unknown_event_name_is_unspecified() {
        for name in ["", "unknown_event", "MESSAGE_START", "message-start"] {
            assert_eq!(
                EventName::from_str(name).unwrap(),
                EventName::Unspecified,
                "unexpected parse for {:?}",
                name
            );
        }
    }
}