1use serde::{Deserialize, Serialize};
19
20use super::items::ThreadItem;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Usage {
25 pub input_tokens: u64,
26 pub cached_input_tokens: u64,
27 pub output_tokens: u64,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ThreadError {
33 pub message: String,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ThreadStartedEvent {
39 pub thread_id: String,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TurnStartedEvent {}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct TurnCompletedEvent {
49 pub usage: Usage,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct TurnFailedEvent {
55 pub error: ThreadError,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ItemStartedEvent {
61 pub item: ThreadItem,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ItemUpdatedEvent {
67 pub item: ThreadItem,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ItemCompletedEvent {
73 pub item: ThreadItem,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct ThreadErrorEvent {
79 pub message: String,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
87#[serde(tag = "type", rename_all = "snake_case")]
88pub enum ThreadEvent {
89 #[serde(rename = "thread.started")]
90 ThreadStarted(ThreadStartedEvent),
91 #[serde(rename = "turn.started")]
92 TurnStarted(TurnStartedEvent),
93 #[serde(rename = "turn.completed")]
94 TurnCompleted(TurnCompletedEvent),
95 #[serde(rename = "turn.failed")]
96 TurnFailed(TurnFailedEvent),
97 #[serde(rename = "item.started")]
98 ItemStarted(ItemStartedEvent),
99 #[serde(rename = "item.updated")]
100 ItemUpdated(ItemUpdatedEvent),
101 #[serde(rename = "item.completed")]
102 ItemCompleted(ItemCompletedEvent),
103 Error(ThreadErrorEvent),
104}
105
106impl ThreadEvent {
107 pub fn event_type(&self) -> &str {
109 match self {
110 ThreadEvent::ThreadStarted(_) => "thread.started",
111 ThreadEvent::TurnStarted(_) => "turn.started",
112 ThreadEvent::TurnCompleted(_) => "turn.completed",
113 ThreadEvent::TurnFailed(_) => "turn.failed",
114 ThreadEvent::ItemStarted(_) => "item.started",
115 ThreadEvent::ItemUpdated(_) => "item.updated",
116 ThreadEvent::ItemCompleted(_) => "item.completed",
117 ThreadEvent::Error(_) => "error",
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `agent_message` item payload shared by the `item.*` event tests.
    const AGENT_MESSAGE_ITEM: &str =
        r#"{"type":"agent_message","id":"msg_1","text":"Starting..."}"#;

    /// Parses `json` as a [`ThreadEvent`], panicking with the offending input on failure.
    fn parse(json: &str) -> ThreadEvent {
        serde_json::from_str(json).unwrap_or_else(|err| panic!("failed to parse {json}: {err}"))
    }

    /// Builds an `item.*` event document with the given tag wrapping [`AGENT_MESSAGE_ITEM`].
    fn item_event_json(event_type: &str) -> String {
        format!(r#"{{"type":"{event_type}","item":{AGENT_MESSAGE_ITEM}}}"#)
    }

    #[test]
    fn test_deserialize_thread_started() {
        let event = parse(r#"{"type":"thread.started","thread_id":"th_abc123"}"#);
        assert!(matches!(event, ThreadEvent::ThreadStarted(ref e) if e.thread_id == "th_abc123"));
        assert_eq!(event.event_type(), "thread.started");
    }

    #[test]
    fn test_deserialize_turn_started() {
        let event = parse(r#"{"type":"turn.started"}"#);
        assert!(matches!(event, ThreadEvent::TurnStarted(_)));
    }

    #[test]
    fn test_deserialize_turn_completed() {
        let event = parse(
            r#"{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":200}}"#,
        );
        match &event {
            ThreadEvent::TurnCompleted(e) => {
                assert_eq!(e.usage.input_tokens, 100);
                assert_eq!(e.usage.cached_input_tokens, 50);
                assert_eq!(e.usage.output_tokens, 200);
            }
            other => panic!("Expected TurnCompleted, got {other:?}"),
        }
    }

    #[test]
    fn test_deserialize_turn_failed() {
        let event = parse(r#"{"type":"turn.failed","error":{"message":"rate limited"}}"#);
        assert!(
            matches!(event, ThreadEvent::TurnFailed(ref e) if e.error.message == "rate limited")
        );
    }

    #[test]
    fn test_deserialize_item_started() {
        let event = parse(&item_event_json("item.started"));
        assert!(matches!(event, ThreadEvent::ItemStarted(_)));
    }

    #[test]
    fn test_deserialize_item_updated() {
        let event = parse(&item_event_json("item.updated"));
        assert!(matches!(event, ThreadEvent::ItemUpdated(_)));
    }

    #[test]
    fn test_deserialize_item_completed() {
        let event = parse(&item_event_json("item.completed"));
        assert!(matches!(event, ThreadEvent::ItemCompleted(_)));
    }

    #[test]
    fn test_deserialize_error_event() {
        let event = parse(r#"{"type":"error","message":"connection lost"}"#);
        assert!(matches!(event, ThreadEvent::Error(ref e) if e.message == "connection lost"));
    }

    /// `event_type()` is a hand-written copy of the serde tags; make sure the two never drift
    /// apart by round-tripping one sample of every variant through serialization.
    #[test]
    fn test_event_type_matches_serialized_tag() {
        let samples = [
            r#"{"type":"thread.started","thread_id":"th_abc123"}"#.to_owned(),
            r#"{"type":"turn.started"}"#.to_owned(),
            r#"{"type":"turn.completed","usage":{"input_tokens":1,"cached_input_tokens":0,"output_tokens":2}}"#
                .to_owned(),
            r#"{"type":"turn.failed","error":{"message":"rate limited"}}"#.to_owned(),
            item_event_json("item.started"),
            item_event_json("item.updated"),
            item_event_json("item.completed"),
            r#"{"type":"error","message":"connection lost"}"#.to_owned(),
        ];
        for json in &samples {
            let event = parse(json);
            let value = serde_json::to_value(&event).expect("ThreadEvent should serialize");
            assert_eq!(
                value.get("type").and_then(serde_json::Value::as_str),
                Some(event.event_type()),
                "serialized tag disagrees with event_type() for {json}"
            );
        }
    }

    #[test]
    fn test_unknown_event_type_is_rejected() {
        let result = serde_json::from_str::<ThreadEvent>(r#"{"type":"turn.paused"}"#);
        assert!(result.is_err());
    }
}