1use serde::{Deserialize, Serialize};
2
3use super::items::ThreadItem;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Usage {
8 pub input_tokens: u64,
9 pub cached_input_tokens: u64,
10 pub output_tokens: u64,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ThreadError {
16 pub message: String,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ThreadStartedEvent {
22 pub thread_id: String,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct TurnStartedEvent {}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TurnCompletedEvent {
32 pub usage: Usage,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct TurnFailedEvent {
38 pub error: ThreadError,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ItemStartedEvent {
44 pub item: ThreadItem,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ItemUpdatedEvent {
50 pub item: ThreadItem,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ItemCompletedEvent {
56 pub item: ThreadItem,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct ThreadErrorEvent {
62 pub message: String,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67#[serde(tag = "type", rename_all = "snake_case")]
68pub enum ThreadEvent {
69 #[serde(rename = "thread.started")]
70 ThreadStarted(ThreadStartedEvent),
71 #[serde(rename = "turn.started")]
72 TurnStarted(TurnStartedEvent),
73 #[serde(rename = "turn.completed")]
74 TurnCompleted(TurnCompletedEvent),
75 #[serde(rename = "turn.failed")]
76 TurnFailed(TurnFailedEvent),
77 #[serde(rename = "item.started")]
78 ItemStarted(ItemStartedEvent),
79 #[serde(rename = "item.updated")]
80 ItemUpdated(ItemUpdatedEvent),
81 #[serde(rename = "item.completed")]
82 ItemCompleted(ItemCompletedEvent),
83 Error(ThreadErrorEvent),
84}
85
86impl ThreadEvent {
87 pub fn event_type(&self) -> &str {
89 match self {
90 ThreadEvent::ThreadStarted(_) => "thread.started",
91 ThreadEvent::TurnStarted(_) => "turn.started",
92 ThreadEvent::TurnCompleted(_) => "turn.completed",
93 ThreadEvent::TurnFailed(_) => "turn.failed",
94 ThreadEvent::ItemStarted(_) => "item.started",
95 ThreadEvent::ItemUpdated(_) => "item.updated",
96 ThreadEvent::ItemCompleted(_) => "item.completed",
97 ThreadEvent::Error(_) => "error",
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// One wire-format sample per `ThreadEvent` variant, used by the round-trip test.
    const FIXTURES: &[&str] = &[
        r#"{"type":"thread.started","thread_id":"th_abc123"}"#,
        r#"{"type":"turn.started"}"#,
        r#"{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":200}}"#,
        r#"{"type":"turn.failed","error":{"message":"rate limited"}}"#,
        r#"{"type":"item.started","item":{"type":"agent_message","id":"msg_1","text":"Starting..."}}"#,
        r#"{"type":"item.updated","item":{"type":"agent_message","id":"msg_1","text":"Still working..."}}"#,
        r#"{"type":"item.completed","item":{"type":"agent_message","id":"msg_1","text":"Done."}}"#,
        r#"{"type":"error","message":"connection lost"}"#,
    ];

    #[test]
    fn test_deserialize_thread_started() {
        let json = r#"{"type":"thread.started","thread_id":"th_abc123"}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, ThreadEvent::ThreadStarted(ref e) if e.thread_id == "th_abc123"));
        assert_eq!(event.event_type(), "thread.started");
    }

    #[test]
    fn test_deserialize_turn_started() {
        let json = r#"{"type":"turn.started"}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, ThreadEvent::TurnStarted(_)));
        assert_eq!(event.event_type(), "turn.started");
    }

    #[test]
    fn test_deserialize_turn_completed() {
        let json = r#"{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":200}}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        if let ThreadEvent::TurnCompleted(e) = &event {
            assert_eq!(e.usage.input_tokens, 100);
            assert_eq!(e.usage.cached_input_tokens, 50);
            assert_eq!(e.usage.output_tokens, 200);
        } else {
            panic!("Expected TurnCompleted");
        }
        assert_eq!(event.event_type(), "turn.completed");
    }

    #[test]
    fn test_deserialize_turn_failed() {
        let json = r#"{"type":"turn.failed","error":{"message":"rate limited"}}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        assert!(
            matches!(event, ThreadEvent::TurnFailed(ref e) if e.error.message == "rate limited")
        );
        assert_eq!(event.event_type(), "turn.failed");
    }

    #[test]
    fn test_deserialize_item_started() {
        let json = r#"{"type":"item.started","item":{"type":"agent_message","id":"msg_1","text":"Starting..."}}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, ThreadEvent::ItemStarted(_)));
        assert_eq!(event.event_type(), "item.started");
    }

    #[test]
    fn test_deserialize_item_updated() {
        let json = r#"{"type":"item.updated","item":{"type":"agent_message","id":"msg_1","text":"Still working..."}}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, ThreadEvent::ItemUpdated(_)));
        assert_eq!(event.event_type(), "item.updated");
    }

    #[test]
    fn test_deserialize_item_completed() {
        let json = r#"{"type":"item.completed","item":{"type":"agent_message","id":"msg_1","text":"Done."}}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, ThreadEvent::ItemCompleted(_)));
        assert_eq!(event.event_type(), "item.completed");
    }

    #[test]
    fn test_deserialize_error_event() {
        let json = r#"{"type":"error","message":"connection lost"}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, ThreadEvent::Error(ref e) if e.message == "connection lost"));
        assert_eq!(event.event_type(), "error");
    }

    #[test]
    fn test_unknown_event_type_is_rejected() {
        let json = r#"{"type":"turn.exploded"}"#;
        assert!(serde_json::from_str::<ThreadEvent>(json).is_err());
    }

    /// `event_type()` duplicates the serde variant names by hand; this guards
    /// against the two drifting apart. Each event is serialized and its `type`
    /// field is compared with `event_type()`, then the serialized form is re-parsed.
    #[test]
    fn test_event_type_matches_serialized_tag() {
        for &json in FIXTURES {
            let event: ThreadEvent = serde_json::from_str(json)
                .unwrap_or_else(|e| panic!("failed to parse {}: {}", json, e));
            let value = serde_json::to_value(&event).unwrap();
            assert_eq!(value["type"], event.event_type(), "tag mismatch for {}", json);
            let reparsed: ThreadEvent = serde_json::from_value(value).unwrap();
            assert_eq!(reparsed.event_type(), event.event_type());
        }
    }
}