anthropic_rs/completion/stream.rs

1use core::fmt;
2use serde::{de::Error, Deserialize, Serialize};
3use std::str::FromStr;
4
5use super::message::{MessageResponse, StopReason};
6
7#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
8#[serde(rename_all = "snake_case")]
9pub enum StreamEvent {
10    Ping,
11    MessageStart { message: MessageResponse },
12    MessageDelta(MessageDelta),
13    MessageStop,
14    ContentBlockStart(ContentBlockStart),
15    ContentBlockDelta(ContentBlockDelta),
16    ContentBlockStop(ContentBlockStop),
17}
18
19impl FromStr for StreamEvent {
20    type Err = serde_json::Error;
21
22    fn from_str(s: &str) -> Result<Self, Self::Err> {
23        let value: serde_json::Value = serde_json::from_str(s)?;
24        let event_type = value["type"]
25            .as_str()
26            .ok_or_else(|| serde_json::Error::custom("Missing or invalid 'type' field"))?;
27
28        match event_type {
29            "ping" => Ok(StreamEvent::Ping),
30            "message_start" => {
31                let message: MessageResponse = serde_json::from_value(value["message"].clone())?;
32                Ok(StreamEvent::MessageStart { message })
33            }
34            "content_block_start" => {
35                let message: ContentBlockStart = serde_json::from_value(value)?;
36                Ok(StreamEvent::ContentBlockStart(message))
37            }
38            "content_block_delta" => {
39                let message: ContentBlockDelta = serde_json::from_value(value)?;
40                Ok(StreamEvent::ContentBlockDelta(message))
41            }
42            "content_block_stop" => {
43                let message: ContentBlockStop = serde_json::from_value(value)?;
44                Ok(StreamEvent::ContentBlockStop(message))
45            }
46            "message_delta" => {
47                let message: MessageDelta = serde_json::from_value(value)?;
48                Ok(StreamEvent::MessageDelta(message))
49            }
50            "message_stop" => Ok(StreamEvent::MessageStop),
51            _ => Ok(StreamEvent::MessageStop),
52        }
53    }
54}
55
/// Payload of a `message_delta` stream event, e.g.
/// `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":30}}`.
///
/// It is deserialized from the whole event object. Extra top-level keys such as `"type"`
/// are ignored, because the struct does not set `deny_unknown_fields`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageDelta {
    /// The event's `"delta"` object: stop reason and stop sequence.
    pub delta: MessageDeltaStop,
    /// The event's `"usage"` object: output-token count.
    pub usage: StreamUsageTokens,
}
61
/// The `"delta"` object of a `message_delta` event.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageDeltaStop {
    /// Why generation stopped, e.g. `"end_turn"` deserializes to `StopReason::EndTurn`.
    // NOTE(review): unlike the `stop_reason` on `MessageResponse` (an `Option`), this field
    // is required, so a `message_delta` carrying `"stop_reason": null` will fail to
    // deserialize. Confirm the API never sends null here.
    pub stop_reason: StopReason,
    /// Stop sequence reported by the event, if any; JSON `null` maps to `None`.
    pub stop_sequence: Option<String>,
}
67
/// The `"usage"` object of a `message_delta` event.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct StreamUsageTokens {
    /// Output-token count as reported by the event.
    // NOTE(review): whether this is cumulative or per-delta is not visible here. Confirm
    // against the API docs before summing values across events.
    pub output_tokens: u32,
}
72
/// Payload of a `content_block_start` event, e.g.
/// `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ContentBlockStart {
    /// Index identifying which content block this event refers to.
    pub index: i64,
    /// The content block as announced by this event (for example, an empty text block).
    pub content_block: ContentBlock,
}
78
/// Payload of a `content_block_delta` event, e.g.
/// `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello!"}}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ContentBlockDelta {
    /// Index of the content block this chunk belongs to.
    pub index: i64,
    /// The incremental chunk. For text it has kind `text_delta` and the new text.
    pub delta: ContentBlock,
}
84
/// Payload of a `content_block_stop` event, e.g. `{"type":"content_block_stop","index":0}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ContentBlockStop {
    /// Index of the content block this event refers to.
    pub index: i64,
}
89
/// A content block (or content-block delta) with a `"type"` tag and a `"text"` payload.
///
/// NOTE(review): both fields are required and `ContentBlockKind` only models
/// `text`/`text_delta`. Any block with another `"type"`, or without a `"text"` field,
/// fails to deserialize. TODO confirm whether non-text blocks (e.g. tool use) must be
/// supported.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ContentBlock {
    /// The block's kind, read from and written to the JSON key `"type"`.
    #[serde(rename = "type")]
    pub kind: ContentBlockKind,
    /// The block's text. For a delta, this is only the newly streamed fragment.
    pub text: String,
}
96
/// Value of a content block's `"type"` field.
///
/// Serialized in snake_case: `Text` is `"text"` and `TextDelta` is `"text_delta"`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ContentBlockKind {
    /// `"text"`, e.g. the `content_block` of a `content_block_start` event.
    Text,
    /// `"text_delta"`, e.g. the `delta` of a `content_block_delta` event.
    TextDelta,
}
103
104impl fmt::Display for ContentBlockKind {
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        match self {
107            Self::Text => write!(f, "text"),
108            Self::TextDelta => write!(f, "text_delta"),
109        }
110    }
111}
112
#[cfg(test)]
mod tests {
    use crate::{completion::message::RoleResponse, models::claude::ClaudeModel};

    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn should_deserialize_ping_event() {
        let raw = r#"{"type": "ping"}"#;
        let event: StreamEvent = raw.parse().unwrap();
        assert_eq!(event, StreamEvent::Ping);
    }

    #[test]
    fn should_deserialize_message_start_event() {
        let raw = r#"{"type":"message_start","message":{"id":"msg_0117mpmR7a2JEj2Z1G4jqjkf","type":"message","role":"assistant","model":"claude-3-5-sonnet-20240620","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":9,"output_tokens":3}}}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::MessageStart { message } = event {
            assert_eq!(message.id, "msg_0117mpmR7a2JEj2Z1G4jqjkf");
            assert_eq!(message.role, RoleResponse::Assistant);
            assert_eq!(message.model, ClaudeModel::Claude35Sonnet);
            // `assert!` rather than `assert_eq!(.., true)` (clippy::bool_assert_comparison).
            assert!(message.content.is_empty());
            assert_eq!(message.stop_reason, None);
            assert_eq!(message.stop_sequence, None);
            assert_eq!(message.usage.input_tokens, 9);
            assert_eq!(message.usage.output_tokens, 3);
        } else {
            panic!("Expected 'message_start' event");
        }
    }

    #[test]
    fn should_deserialize_content_block_start_event() {
        let raw =
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::ContentBlockStart(content) = event {
            assert_eq!(content.index, 0);
            assert_eq!(content.content_block.kind, ContentBlockKind::Text);
            assert_eq!(content.content_block.text, "");
        } else {
            panic!("Expected 'content_block_start' event");
        }
    }

    #[test]
    fn should_deserialize_content_block_delta_event() {
        let raw = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello!"}}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::ContentBlockDelta(content) = event {
            assert_eq!(content.index, 0);
            assert_eq!(content.delta.kind, ContentBlockKind::TextDelta);
            assert_eq!(content.delta.text, "Hello!");
        } else {
            panic!("Expected 'content_block_delta' event");
        }
    }

    #[test]
    fn should_deserialize_content_block_stop_event() {
        let raw = r#"{"type":"content_block_stop","index":0}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::ContentBlockStop(content) = event {
            assert_eq!(content.index, 0);
        } else {
            panic!("Expected 'content_block_stop' event");
        }
    }

    #[test]
    fn should_deserialize_message_delta_event() {
        let raw = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":30}}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::MessageDelta(content) = event {
            assert_eq!(content.delta.stop_reason, StopReason::EndTurn);
            assert_eq!(content.delta.stop_sequence, None);
            assert_eq!(content.usage.output_tokens, 30);
        } else {
            panic!("Expected 'message_delta' event");
        }
    }

    #[test]
    fn should_deserialize_message_stop_event() {
        let raw = r#"{"type":"message_stop"}"#;
        let event: StreamEvent = raw.parse().unwrap();
        assert_eq!(event, StreamEvent::MessageStop);
    }

    #[test]
    fn should_reject_invalid_json() {
        assert!("not json".parse::<StreamEvent>().is_err());
    }

    #[test]
    fn should_reject_event_without_type_field() {
        let raw = r#"{"index":0}"#;
        assert!(raw.parse::<StreamEvent>().is_err());
    }
}