use crate::types::items::ThreadItem;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ThreadEvent {
#[serde(rename = "thread.started")]
ThreadStarted { thread_id: String },
#[serde(rename = "turn.started")]
TurnStarted,
#[serde(rename = "turn.completed")]
TurnCompleted { usage: Usage },
#[serde(rename = "turn.failed")]
TurnFailed { error: ThreadError },
#[serde(rename = "item.started")]
ItemStarted { item: ThreadItem },
#[serde(rename = "item.updated")]
ItemUpdated { item: ThreadItem },
#[serde(rename = "item.completed")]
ItemCompleted { item: ThreadItem },
#[serde(rename = "exec_approval_request")]
ApprovalRequest(ApprovalRequestEvent),
#[serde(rename = "apply_patch_approval_request")]
PatchApprovalRequest(PatchApprovalRequestEvent),
#[serde(rename = "error")]
Error { message: String },
}
/// Payload of [`ThreadEvent::ApprovalRequest`] (wire tag `exec_approval_request`).
///
/// NOTE(review): the field names suggest a request to approve running `command`
/// in `cwd`; who answers the request is not visible here — confirm with the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalRequestEvent {
    /// Identifier of this request. It is the only field that must be present.
    pub id: String,

    /// Command text; deserializes as an empty string when the field is missing.
    #[serde(default)]
    pub command: String,

    /// Working directory; `None` when the field is missing or `null`.
    #[serde(default)]
    pub cwd: Option<std::path::PathBuf>,
}
/// Payload of [`ThreadEvent::PatchApprovalRequest`] (wire tag `apply_patch_approval_request`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatchApprovalRequestEvent {
    /// Identifier of this request.
    pub id: String,

    /// Proposed changes, with values kept as raw JSON. Deserializes as an empty
    /// map when the field is missing.
    /// NOTE(review): keys are presumably file paths — confirm against the producer.
    #[serde(default)]
    pub changes: std::collections::HashMap<String, serde_json::Value>,
}
/// Error payload carried by [`ThreadEvent::TurnFailed`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadError {
    /// Human-readable description of the failure.
    pub message: String,
}
/// Token counts reported with [`ThreadEvent::TurnCompleted`].
///
/// Every field is optional on the wire. A missing count deserializes as `0`,
/// the value supplied by the derived [`Default`]. `Default` and equality are
/// derived so callers can write `turn.usage.unwrap_or_default()` and compare
/// usages directly.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct Usage {
    /// Number of input tokens reported.
    pub input_tokens: u64,
    /// Number of cached input tokens reported.
    /// NOTE(review): whether these are also counted in `input_tokens` is not
    /// established here — confirm with the producer.
    pub cached_input_tokens: u64,
    /// Number of output tokens reported.
    pub output_tokens: u64,
}
/// Buffered result of a turn: its events together with summary fields.
#[derive(Debug, Clone)]
pub struct Turn {
    /// Events recorded for the turn, in the order they were stored.
    pub events: Vec<ThreadEvent>,

    /// Final response text.
    /// NOTE(review): presumably taken from the last agent message — confirm where
    /// this is populated.
    pub final_response: String,

    /// Token usage, when the turn reported any.
    pub usage: Option<Usage>,
}
/// A turn whose events arrive incrementally. It is exposed as a
/// `futures_core::Stream` of `crate::Result<ThreadEvent>`.
///
/// The producer stream is boxed and pinned, so its concrete type stays private.
/// The producer must be `Send`.
#[must_use = "a StreamedTurn yields its events only when polled"]
pub struct StreamedTurn {
    inner: std::pin::Pin<Box<dyn futures_core::Stream<Item = crate::Result<ThreadEvent>> + Send>>,
}

impl std::fmt::Debug for StreamedTurn {
    // The boxed stream carries no `Debug` bound, so only the type name is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamedTurn").finish_non_exhaustive()
    }
}
impl StreamedTurn {
    /// Wraps `stream` by boxing and pinning it, so every producer is exposed
    /// through the single `StreamedTurn` type.
    pub(crate) fn new<S>(stream: S) -> Self
    where
        S: futures_core::Stream<Item = crate::Result<ThreadEvent>> + Send + 'static,
    {
        let inner = Box::pin(stream);
        Self { inner }
    }
}
/// Every stream operation is delegated to the boxed inner stream.
impl futures_core::Stream for StreamedTurn {
    type Item = crate::Result<ThreadEvent>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `StreamedTurn` is `Unpin` because its only field is a `Pin<Box<_>>`.
        // The field can therefore be reached through the pinned reference and
        // re-pinned with `as_mut`.
        self.inner.as_mut().poll_next(cx)
    }

    /// Forwards the inner stream's bounds instead of the trait default
    /// `(0, None)`. Consumers that inspect `size_hint` then see the producer's
    /// estimate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        futures_core::Stream::size_hint(&*self.inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `event` to JSON and parses it back, panicking if either step fails.
    fn round_trip(event: &ThreadEvent) -> ThreadEvent {
        let json = serde_json::to_string(event).expect("ThreadEvent should serialize");
        serde_json::from_str(&json).expect("serialized ThreadEvent should parse back")
    }

    #[test]
    fn thread_started_round_trip() {
        let event = ThreadEvent::ThreadStarted {
            thread_id: "test-123".into(),
        };
        let ThreadEvent::ThreadStarted { thread_id } = round_trip(&event) else {
            panic!("wrong variant");
        };
        assert_eq!(thread_id, "test-123");
    }

    #[test]
    fn turn_started_serializes_as_tag_only() {
        let json = serde_json::to_string(&ThreadEvent::TurnStarted).unwrap();
        assert_eq!(json, r#"{"type":"turn.started"}"#);
        assert!(matches!(
            round_trip(&ThreadEvent::TurnStarted),
            ThreadEvent::TurnStarted
        ));
    }

    #[test]
    fn turn_completed_round_trip() {
        let json = r#"{"type":"turn.completed","usage":{"input_tokens":100,"output_tokens":50}}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        let ThreadEvent::TurnCompleted { usage } = round_trip(&event) else {
            panic!("wrong variant");
        };
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        // Absent from the input, so it must have defaulted to zero.
        assert_eq!(usage.cached_input_tokens, 0);
    }

    #[test]
    fn turn_failed_carries_error_message() {
        let json = r#"{"type":"turn.failed","error":{"message":"rate limited"}}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        let ThreadEvent::TurnFailed { error } = round_trip(&event) else {
            panic!("wrong variant");
        };
        assert_eq!(error.message, "rate limited");
    }

    #[test]
    fn item_started_agent_message() {
        let json =
            r#"{"type":"item.started","item":{"type":"agent_message","id":"msg-1","text":""}}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        let ThreadEvent::ItemStarted { item } = event else {
            panic!("wrong variant");
        };
        assert_eq!(item.id(), "msg-1");
    }

    #[test]
    fn approval_request_round_trip() {
        let json =
            r#"{"type":"exec_approval_request","id":"ap-1","command":"rm -rf /","cwd":null}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        let ThreadEvent::ApprovalRequest(req) = round_trip(&event) else {
            panic!("wrong variant");
        };
        assert_eq!(req.id, "ap-1");
        assert_eq!(req.command, "rm -rf /");
        assert!(req.cwd.is_none());
    }

    #[test]
    fn approval_request_missing_fields_default() {
        let json = r#"{"type":"exec_approval_request","id":"ap-2"}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        let ThreadEvent::ApprovalRequest(req) = event else {
            panic!("wrong variant");
        };
        assert_eq!(req.id, "ap-2");
        assert_eq!(req.command, "");
        assert!(req.cwd.is_none());
    }

    #[test]
    fn patch_approval_request_missing_changes_default_to_empty() {
        let json = r#"{"type":"apply_patch_approval_request","id":"patch-1"}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        let ThreadEvent::PatchApprovalRequest(req) = round_trip(&event) else {
            panic!("wrong variant");
        };
        assert_eq!(req.id, "patch-1");
        assert!(req.changes.is_empty());
    }

    #[test]
    fn error_event() {
        let json = r#"{"type":"error","message":"something broke"}"#;
        let event: ThreadEvent = serde_json::from_str(json).unwrap();
        let ThreadEvent::Error { message } = round_trip(&event) else {
            panic!("wrong variant");
        };
        assert_eq!(message, "something broke");
    }

    #[test]
    fn unknown_event_type_is_rejected() {
        let parsed = serde_json::from_str::<ThreadEvent>(r#"{"type":"turn.paused"}"#);
        assert!(parsed.is_err());
    }
}