use serde::{Deserialize, Serialize};
use super::error::ObserverRecvError;
use super::pause::{PauseInfo, PauseKind};
use super::state::ExecutionStateTag;
use crate::TokenUsage;
/// An event describing the progress of an execution, as yielded to observers
/// through `ObserverHandle::recv` / `ObserverHandle::try_recv`.
///
/// # Serialization
///
/// Serialized with serde as an *internally tagged* enum: the variant name is
/// emitted as a `"kind"` field, converted to `snake_case`, alongside the
/// variant's own fields, e.g. `{"kind":"state_transition","from":…,"to":…,"at":…}`.
/// Renaming a variant or field therefore changes the wire format seen by any
/// consumer of the JSON.
///
/// Every variant carries an `at: i64` timestamp.
/// NOTE(review): the unit/epoch of `at` is not fixed by this type; the tests use
/// values like `1_700_000_000_000`, which suggests Unix milliseconds — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ProgressEvent {
    /// The execution state changed from `from` to `to`.
    StateTransition {
        from: ExecutionStateTag,
        to: ExecutionStateTag,
        at: i64,
    },
    /// A pause was requested; `info` carries the pause details.
    PauseRequested {
        info: PauseInfo,
        at: i64,
    },
    /// A resume was accepted; `payload_kind` is the [`PauseKind`] of the
    /// resume payload (presumably matching the pending pause — confirm).
    ResumeAccepted {
        payload_kind: PauseKind,
        at: i64,
    },
    /// A free-form note with an optional title.
    Note {
        title: Option<String>,
        content: String,
        at: i64,
    },
    /// An LLM call identified by `query_id` began. Presumably paired with a
    /// later `LlmCallEnd` carrying the same `query_id` — verify against emitters.
    LlmCallBegin {
        query_id: String,
        at: i64,
    },
    /// An LLM call identified by `query_id` ended; `usage` is `None` when no
    /// token usage is attached.
    LlmCallEnd {
        query_id: String,
        usage: Option<TokenUsage>,
        at: i64,
    },
    /// A tick labelled with a free-form `phase` string.
    Tick {
        phase: String,
        at: i64,
    },
}
/// The receiving end of a stream of [`ProgressEvent`]s.
///
/// Implementors must be `Send`. The trait is usable as a trait object:
/// `recv` returns a boxed future rather than being an `async fn`, and `close`
/// takes `self: Box<Self>`, which is a dispatchable receiver on
/// `Box<dyn ObserverHandle>`.
pub trait ObserverHandle: Send {
    /// Returns a future resolving to the next event.
    ///
    /// The future is `Send` and holds the mutable borrow of the handle for its
    /// whole lifetime `'a`, so no other method can be called while it is alive.
    ///
    /// # Errors
    ///
    /// Resolves to an [`ObserverRecvError`] when no event can be delivered; the
    /// exact conditions are defined by the implementor and the error type.
    fn recv<'a>(
        &'a mut self,
    ) -> core::pin::Pin<
        Box<
            dyn core::future::Future<Output = Result<ProgressEvent, ObserverRecvError>>
                + Send
                + 'a,
        >,
    >;

    /// Synchronous counterpart of [`recv`](Self::recv).
    ///
    /// NOTE(review): presumably returns immediately with an error when no event
    /// is ready rather than waiting — confirm against implementors.
    ///
    /// # Errors
    ///
    /// Returns an [`ObserverRecvError`] when no event is delivered.
    fn try_recv(&mut self) -> Result<ProgressEvent, ObserverRecvError>;

    /// Consumes the boxed handle; what closing entails is up to the implementor.
    fn close(self: Box<Self>);
}
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent's `PauseInfo`, `PauseKind` and
    // `ExecutionStateTag` imports, so no extra `use` lines are needed here.
    use super::*;

    /// Serializes `event`, deserializes the JSON back into a `ProgressEvent`,
    /// re-serializes it, and asserts the two JSON strings are identical.
    ///
    /// Returns the first serialization so callers can inspect it further.
    fn assert_json_roundtrip(event: &ProgressEvent) -> String {
        let json = serde_json::to_string(event).expect("serialize");
        let roundtripped: ProgressEvent = serde_json::from_str(&json).expect("deserialize");
        let json2 = serde_json::to_string(&roundtripped).expect("re-serialize");
        assert_eq!(json, json2, "JSON round-trip changed the encoding");
        json
    }

    #[test]
    fn progress_event_serde_tagged_kind() {
        let event = ProgressEvent::StateTransition {
            from: ExecutionStateTag::Running,
            to: ExecutionStateTag::Paused,
            at: 1_700_000_000_000,
        };
        let json = assert_json_roundtrip(&event);
        assert!(
            json.contains(r#""kind":"state_transition""#),
            "expected tagged kind in JSON, got: {json}"
        );
    }

    #[test]
    fn all_progress_event_variants_serde() {
        // Each variant paired with the `kind` tag it must serialize to under
        // `#[serde(tag = "kind", rename_all = "snake_case")]`; a renamed variant
        // would silently change the wire format without this check.
        let cases: Vec<(&str, ProgressEvent)> = vec![
            (
                "state_transition",
                ProgressEvent::StateTransition {
                    from: ExecutionStateTag::Running,
                    to: ExecutionStateTag::Done,
                    at: 0,
                },
            ),
            (
                "pause_requested",
                ProgressEvent::PauseRequested {
                    info: PauseInfo {
                        kind: PauseKind::Single,
                        prompts: vec![],
                        paused_at: 0,
                    },
                    at: 0,
                },
            ),
            (
                "resume_accepted",
                ProgressEvent::ResumeAccepted {
                    payload_kind: PauseKind::Batch,
                    at: 0,
                },
            ),
            (
                "note",
                ProgressEvent::Note {
                    title: Some("test".into()),
                    content: "hello".into(),
                    at: 0,
                },
            ),
            (
                "llm_call_begin",
                ProgressEvent::LlmCallBegin {
                    query_id: "q1".into(),
                    at: 0,
                },
            ),
            (
                "llm_call_end",
                ProgressEvent::LlmCallEnd {
                    query_id: "q1".into(),
                    usage: None,
                    at: 0,
                },
            ),
            (
                "tick",
                ProgressEvent::Tick {
                    phase: "running".into(),
                    at: 0,
                },
            ),
        ];
        for (expected_kind, event) in cases {
            let json = assert_json_roundtrip(&event);
            let value: serde_json::Value =
                serde_json::from_str(&json).expect("parse as JSON value");
            assert_eq!(
                value["kind"], expected_kind,
                "unexpected `kind` tag in JSON: {json}"
            );
        }
    }
}