1use core::fmt;
2use serde::{de::Error, Deserialize, Serialize};
3use std::str::FromStr;
4
5use super::message::{MessageResponse, StopReason};
6
7#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
8#[serde(rename_all = "snake_case")]
9pub enum StreamEvent {
10 Ping,
11 MessageStart { message: MessageResponse },
12 MessageDelta(MessageDelta),
13 MessageStop,
14 ContentBlockStart(ContentBlockStart),
15 ContentBlockDelta(ContentBlockDelta),
16 ContentBlockStop(ContentBlockStop),
17}
18
19impl FromStr for StreamEvent {
20 type Err = serde_json::Error;
21
22 fn from_str(s: &str) -> Result<Self, Self::Err> {
23 let value: serde_json::Value = serde_json::from_str(s)?;
24 let event_type = value["type"]
25 .as_str()
26 .ok_or_else(|| serde_json::Error::custom("Missing or invalid 'type' field"))?;
27
28 match event_type {
29 "ping" => Ok(StreamEvent::Ping),
30 "message_start" => {
31 let message: MessageResponse = serde_json::from_value(value["message"].clone())?;
32 Ok(StreamEvent::MessageStart { message })
33 }
34 "content_block_start" => {
35 let message: ContentBlockStart = serde_json::from_value(value)?;
36 Ok(StreamEvent::ContentBlockStart(message))
37 }
38 "content_block_delta" => {
39 let message: ContentBlockDelta = serde_json::from_value(value)?;
40 Ok(StreamEvent::ContentBlockDelta(message))
41 }
42 "content_block_stop" => {
43 let message: ContentBlockStop = serde_json::from_value(value)?;
44 Ok(StreamEvent::ContentBlockStop(message))
45 }
46 "message_delta" => {
47 let message: MessageDelta = serde_json::from_value(value)?;
48 Ok(StreamEvent::MessageDelta(message))
49 }
50 "message_stop" => Ok(StreamEvent::MessageStop),
51 _ => Ok(StreamEvent::MessageStop),
52 }
53 }
54}
55
/// Payload of a `message_delta` stream event.
///
/// The `FromStr` impl for `StreamEvent` deserializes the whole event object into this
/// struct; the event's `"type"` key is not a field here and is ignored by serde.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageDelta {
    /// The event's `delta` object: stop reason and stop sequence.
    pub delta: MessageDeltaStop,
    /// The event's `usage` object: output-token count.
    pub usage: StreamUsageTokens,
}
61
/// The `delta` object inside a `message_delta` event.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageDeltaStop {
    /// Stop reason reported for the message. Not wrapped in `Option`, so the key must be
    /// present; whether a JSON `null` is accepted depends on `StopReason` (defined elsewhere).
    pub stop_reason: StopReason,
    /// Stop sequence, if any; `None` for a JSON `null` or a missing key.
    pub stop_sequence: Option<String>,
}
67
/// The `usage` object inside a `message_delta` event.
///
/// Only `output_tokens` is captured; any other keys in the object are ignored when
/// deserializing.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct StreamUsageTokens {
    /// Output-token count reported by the event.
    pub output_tokens: u32,
}
72
/// Payload of a `content_block_start` stream event.
///
/// The `FromStr` impl for `StreamEvent` deserializes the whole event object into this struct.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ContentBlockStart {
    /// Index of the content block this event refers to (presumably its position in the
    /// message's content list — confirm against the API docs).
    pub index: i64,
    /// The block being started.
    pub content_block: ContentBlock,
}
78
/// Payload of a `content_block_delta` stream event.
///
/// The `FromStr` impl for `StreamEvent` deserializes the whole event object into this struct.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ContentBlockDelta {
    /// Index of the content block this event refers to.
    pub index: i64,
    /// The content carried by this event.
    pub delta: ContentBlock,
}
84
/// Payload of a `content_block_stop` stream event; carries only the block index.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ContentBlockStop {
    /// Index of the content block this event refers to.
    pub index: i64,
}
89
/// A content block (or content-block delta) as it appears in stream events.
///
/// Both fields are required: a block whose `type` is not one of the [`ContentBlockKind`]
/// variants, or which has no `text` key, fails to deserialize.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ContentBlock {
    /// The block's kind, read from and written to the JSON `"type"` key.
    #[serde(rename = "type")]
    pub kind: ContentBlockKind,
    /// The block's text content.
    pub text: String,
}
96
/// Kind of a [`ContentBlock`], (de)serialized in snake_case: `"text"` or `"text_delta"`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ContentBlockKind {
    /// Serialized as `"text"`.
    Text,
    /// Serialized as `"text_delta"`.
    TextDelta,
}
103
104impl fmt::Display for ContentBlockKind {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 Self::Text => write!(f, "text"),
108 Self::TextDelta => write!(f, "text_delta"),
109 }
110 }
111}
112
#[cfg(test)]
mod tests {
    use crate::{completion::message::RoleResponse, models::claude::ClaudeModel};

    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn should_deserialize_ping_event() {
        let raw = r#"{"type": "ping"}"#;
        let event: StreamEvent = raw.parse().unwrap();
        assert_eq!(event, StreamEvent::Ping);
    }

    #[test]
    fn should_deserialize_message_start_event() {
        let raw = r#"{"type":"message_start","message":{"id":"msg_0117mpmR7a2JEj2Z1G4jqjkf","type":"message","role":"assistant","model":"claude-3-5-sonnet-20240620","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":9,"output_tokens":3}}}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::MessageStart { message } = event {
            assert_eq!(message.id, "msg_0117mpmR7a2JEj2Z1G4jqjkf");
            assert_eq!(message.role, RoleResponse::Assistant);
            assert_eq!(message.model, ClaudeModel::Claude35Sonnet);
            assert!(message.content.is_empty());
            assert_eq!(message.stop_reason, None);
            assert_eq!(message.stop_sequence, None);
            assert_eq!(message.usage.input_tokens, 9);
            assert_eq!(message.usage.output_tokens, 3);
        } else {
            panic!("Expected 'message_start' event");
        }
    }

    #[test]
    fn should_deserialize_content_block_start_event() {
        let raw =
            r#"{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::ContentBlockStart(content) = event {
            assert_eq!(content.index, 0);
            assert_eq!(content.content_block.kind, ContentBlockKind::Text);
            assert_eq!(content.content_block.text, "");
        } else {
            panic!("Expected 'content_block_start' event");
        }
    }

    #[test]
    fn should_deserialize_content_block_delta_event() {
        let raw = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello!"}}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::ContentBlockDelta(content) = event {
            assert_eq!(content.index, 0);
            assert_eq!(content.delta.kind, ContentBlockKind::TextDelta);
            assert_eq!(content.delta.text, "Hello!");
        } else {
            panic!("Expected 'content_block_delta' event");
        }
    }

    #[test]
    fn should_deserialize_content_block_stop_event() {
        let raw = r#"{"type":"content_block_stop","index":0}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::ContentBlockStop(content) = event {
            assert_eq!(content.index, 0);
        } else {
            panic!("Expected 'content_block_stop' event");
        }
    }

    #[test]
    fn should_deserialize_message_delta_event() {
        let raw = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":30}}"#;
        let event: StreamEvent = raw.parse().unwrap();

        if let StreamEvent::MessageDelta(content) = event {
            assert_eq!(content.delta.stop_reason, StopReason::EndTurn);
            assert_eq!(content.delta.stop_sequence, None);
            assert_eq!(content.usage.output_tokens, 30);
        } else {
            panic!("Expected 'message_delta' event");
        }
    }

    #[test]
    fn should_deserialize_message_stop_event() {
        let raw = r#"{"type":"message_stop"}"#;
        let event: StreamEvent = raw.parse().unwrap();
        assert_eq!(event, StreamEvent::MessageStop);
    }

    #[test]
    fn should_reject_invalid_json() {
        assert!("not json".parse::<StreamEvent>().is_err());
    }

    #[test]
    fn should_reject_event_without_type_field() {
        assert!(r#"{"index":0}"#.parse::<StreamEvent>().is_err());
    }

    #[test]
    fn should_reject_event_with_non_string_type_field() {
        assert!(r#"{"type":5}"#.parse::<StreamEvent>().is_err());
    }
}