use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use crate::event::AgentEvent;
use crate::plugin::{Plugin, PluginCapabilities};
use crate::types::{AgentMessage, RunIdentity};
/// Schema version stamped on every [`TrajectoryRecord`] created by this module.
pub const TRAJECTORY_SCHEMA_VERSION: u32 = 1;

/// Version assigned to records that were serialized before the
/// `schema_version` field existed.
const PRE_VERSIONING_SCHEMA_VERSION: u32 = 0;

/// Serde default for `TrajectoryRecord::schema_version`: input that lacks the
/// field is treated as pre-versioning data.
fn pre_versioning_schema() -> u32 {
    PRE_VERSIONING_SCHEMA_VERSION
}
/// One entry in a run's trajectory: a sequence-numbered, timestamped
/// [`TrajectoryPayload`] stamped with the run identity known when it was
/// captured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryRecord {
/// Schema version of the record. Input without this field deserializes as
/// `0` via `pre_versioning_schema`.
#[serde(default = "pre_versioning_schema")]
pub schema_version: u32,
/// Position of the record within its run, as assigned by `TrajectoryRecorder`.
pub seq: u64,
/// Id of the run this record belongs to; `None` when no identity was known.
/// Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub run_id: Option<String>,
/// Id of the parent run, if any. Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parent_run_id: Option<String>,
/// Nesting depth taken from the run identity; defaults to `0` when absent.
#[serde(default)]
pub depth: usize,
/// Capture time in milliseconds since the Unix epoch.
pub recorded_at_unix_ms: u64,
/// The captured event, serialized as a nested object tagged by `kind`.
pub payload: TrajectoryPayload,
}
/// The event captured by a [`TrajectoryRecord`].
///
/// Serialized as an internally tagged object: the `kind` field holds the
/// snake_case variant name (e.g. `{"kind": "turn_started"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum TrajectoryPayload {
/// The run's identity became known (recorded for `AgentEvent::RunIdentified`).
RunStarted {
identity: RunIdentity,
},
/// The agent finished (recorded for `AgentEvent::AgentEnd`). `TrajectoryRecorder`
/// currently always writes `"ended"` as the outcome.
RunEnded {
outcome: String,
new_messages: Vec<AgentMessage>,
},
/// A turn began.
TurnStarted,
/// A turn completed with the assistant message and the tool results it produced.
TurnEnded {
assistant: AgentMessage,
tool_results: Vec<AgentMessage>,
},
/// A message was finalized (recorded for `AgentEvent::MessageEnd`).
MessageAppended {
message: AgentMessage,
},
/// A tool call began executing with the given JSON arguments.
ToolStarted {
tool_call_id: String,
tool_name: String,
args: serde_json::Value,
},
/// A tool call finished; `is_error` mirrors the flag on the originating event.
ToolEnded {
tool_call_id: String,
tool_name: String,
result: crate::tool::ToolResult,
is_error: bool,
},
/// Summary of a prepared provider request. Only sizes and tool names are
/// kept. `system_prompt_chars` counts Unicode scalar values (`chars()`), not bytes.
ProviderRequestPrepared {
iteration: usize,
model_id: Option<String>,
system_prompt_chars: usize,
message_count: usize,
tool_count: usize,
tools: Vec<String>,
},
/// A plugin transformed the context; the message counts before and after are recorded.
ContextTransformApplied {
iteration: usize,
plugin: String,
before_count: usize,
after_count: usize,
},
/// A plugin applied a tool gate. `allow` is omitted from the output when `None`.
ToolGateApplied {
iteration: usize,
plugin: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
allow: Option<Vec<String>>,
},
/// Several plugins' tool gates conflicted and were resolved. `chosen_plugin`
/// is omitted from the output when `None`.
ToolGateConflictResolved {
iteration: usize,
plugins: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
chosen_plugin: Option<String>,
allow: Vec<String>,
reason: String,
},
/// Mirrors `AgentEvent::OutputTokensEscalation`: the attempt number and the
/// previous/new caps as reported by that event.
OutputTokensEscalation {
attempt: u8,
prev_cap: u32,
new_cap: u32,
},
}
/// Failure reported by a [`TrajectorySink`] when it cannot store a record.
#[derive(Debug, thiserror::Error)]
pub enum TrajectoryError {
/// The sink refused the record; the string is included in the `Display` message.
#[error("trajectory sink rejected record: {0}")]
Rejected(String),
/// The sink hit an I/O failure; the string is included in the `Display` message.
#[error("trajectory sink i/o failure: {0}")]
Io(String),
}
/// Destination for [`TrajectoryRecord`]s produced by `TrajectoryRecorder`.
///
/// The recorder holds sinks as `Arc<dyn TrajectorySink>`, hence the
/// `Send + Sync` bound.
#[async_trait]
pub trait TrajectorySink: Send + Sync {
/// Stores a single record.
///
/// # Errors
///
/// Returns a [`TrajectoryError`] when the record cannot be stored.
/// `TrajectoryRecorder` logs such errors with a warning and continues; it
/// neither retries nor aborts the run.
async fn record(&self, record: TrajectoryRecord) -> Result<(), TrajectoryError>;
}
/// Trajectory sink that keeps every record in memory, in the order `record`
/// was called. Contents can be read back with `snapshot`.
#[derive(Debug, Default)]
pub struct InMemoryTrajectorySink {
// Guarded by an async (tokio) mutex; each accessor locks it for the duration of the call.
records: Mutex<Vec<TrajectoryRecord>>,
}
impl InMemoryTrajectorySink {
pub fn new() -> Self {
Self::default()
}
pub async fn snapshot(&self) -> Vec<TrajectoryRecord> {
self.records.lock().await.clone()
}
pub async fn len(&self) -> usize {
self.records.lock().await.len()
}
pub async fn is_empty(&self) -> bool {
self.records.lock().await.is_empty()
}
}
#[async_trait]
impl TrajectorySink for InMemoryTrajectorySink {
    /// Appends `record` to the in-memory buffer. This sink never fails.
    async fn record(&self, record: TrajectoryRecord) -> Result<(), TrajectoryError> {
        let mut buffered = self.records.lock().await;
        buffered.push(record);
        Ok(())
    }
}
/// Event observer that turns agent lifecycle events into
/// [`TrajectoryRecord`]s and forwards them to a [`TrajectorySink`].
pub struct TrajectoryRecorder {
/// Destination for every record produced.
sink: Arc<dyn TrajectorySink>,
/// Next sequence number to hand out; reset to `0` on `AgentStart`.
seq: AtomicU64,
/// Identity of the current run, stamped onto each record. `None` until a
/// `RunIdentified` event arrives, and cleared again on `AgentStart`.
identity: Mutex<Option<RunIdentity>>,
}
impl TrajectoryRecorder {
pub fn new(sink: Arc<dyn TrajectorySink>) -> Self {
Self {
sink,
seq: AtomicU64::new(0),
identity: Mutex::new(None),
}
}
async fn record(&self, payload: TrajectoryPayload) {
let seq = self.seq.fetch_add(1, Ordering::SeqCst);
let identity = self.identity.lock().await.clone();
let recorded_at_unix_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let record = TrajectoryRecord {
schema_version: TRAJECTORY_SCHEMA_VERSION,
seq,
run_id: identity.as_ref().map(|i| i.run_id.clone()),
parent_run_id: identity.as_ref().and_then(|i| i.parent_run_id.clone()),
depth: identity.as_ref().map(|i| i.depth).unwrap_or(0),
recorded_at_unix_ms,
payload,
};
if let Err(e) = self.sink.record(record).await {
tracing::warn!(error = %e, "trajectory sink rejected record; continuing");
}
}
}
/// Plugin registration for the recorder.
impl Plugin for TrajectoryRecorder {
/// Plugin name reported to the host: `"trajectory_recorder"`.
fn name(&self) -> &'static str {
"trajectory_recorder"
}
/// Advertises the capabilities returned by `PluginCapabilities::event_observer()`.
fn capabilities(&self) -> PluginCapabilities {
PluginCapabilities::event_observer()
}
}
#[async_trait]
impl crate::plugin::EventObserver for TrajectoryRecorder {
    /// Translates an [`AgentEvent`] into at most one trajectory record.
    ///
    /// * `AgentStart` records nothing. It restarts sequence numbering at `0`
    ///   and clears the stored run identity.
    /// * `RunIdentified` stores the identity before recording, so the
    ///   `RunStarted` record, and every later one, carries the run id.
    /// * Streaming-only events (`MessageStart`, `MessageUpdate`,
    ///   `ToolExecutionUpdate`) are ignored.
    /// * Every other event maps to exactly one [`TrajectoryPayload`].
    async fn on_event(&self, event: &AgentEvent) {
        let payload = match event {
            AgentEvent::AgentStart => {
                self.seq.store(0, Ordering::SeqCst);
                *self.identity.lock().await = None;
                return;
            }
            AgentEvent::RunIdentified { identity } => {
                // The guard is dropped at the end of this statement, before `record` locks again.
                *self.identity.lock().await = Some(identity.clone());
                TrajectoryPayload::RunStarted {
                    identity: identity.clone(),
                }
            }
            AgentEvent::AgentEnd { messages } => TrajectoryPayload::RunEnded {
                outcome: "ended".to_string(),
                new_messages: messages.clone(),
            },
            AgentEvent::TurnStart => TrajectoryPayload::TurnStarted,
            AgentEvent::TurnEnd {
                message,
                tool_results,
            } => TrajectoryPayload::TurnEnded {
                assistant: message.clone(),
                tool_results: tool_results.clone(),
            },
            AgentEvent::MessageEnd { message } => TrajectoryPayload::MessageAppended {
                message: message.clone(),
            },
            AgentEvent::ToolExecutionStart {
                tool_call_id,
                tool_name,
                args,
            } => TrajectoryPayload::ToolStarted {
                tool_call_id: tool_call_id.clone(),
                tool_name: tool_name.clone(),
                args: args.clone(),
            },
            AgentEvent::ToolExecutionEnd {
                tool_call_id,
                tool_name,
                result,
                is_error,
            } => TrajectoryPayload::ToolEnded {
                tool_call_id: tool_call_id.clone(),
                tool_name: tool_name.clone(),
                result: result.clone(),
                is_error: *is_error,
            },
            AgentEvent::ProviderRequestPrepared {
                iteration,
                model_id,
                system_prompt,
                messages,
                tools,
                ..
            } => {
                // Keep only sizes and tool names; the prompt and messages are not copied.
                TrajectoryPayload::ProviderRequestPrepared {
                    iteration: *iteration,
                    model_id: model_id.clone(),
                    system_prompt_chars: system_prompt.chars().count(),
                    message_count: messages.len(),
                    tool_count: tools.len(),
                    tools: tools.iter().map(|t| t.name.clone()).collect(),
                }
            }
            AgentEvent::ContextTransformApplied {
                iteration,
                plugin,
                before,
                after,
            } => TrajectoryPayload::ContextTransformApplied {
                iteration: *iteration,
                plugin: (*plugin).to_string(),
                before_count: before.len(),
                after_count: after.len(),
            },
            AgentEvent::ToolGateApplied {
                iteration,
                plugin,
                allow,
            } => TrajectoryPayload::ToolGateApplied {
                iteration: *iteration,
                plugin: (*plugin).to_string(),
                allow: allow.clone(),
            },
            AgentEvent::ToolGateConflictResolved {
                iteration,
                plugins,
                chosen_plugin,
                allow,
                reason,
            } => TrajectoryPayload::ToolGateConflictResolved {
                iteration: *iteration,
                plugins: plugins.clone(),
                chosen_plugin: chosen_plugin.clone(),
                allow: allow.clone(),
                reason: reason.clone(),
            },
            AgentEvent::OutputTokensEscalation {
                attempt,
                prev_cap,
                new_cap,
            } => TrajectoryPayload::OutputTokensEscalation {
                attempt: *attempt,
                prev_cap: *prev_cap,
                new_cap: *new_cap,
            },
            AgentEvent::MessageStart { .. }
            | AgentEvent::MessageUpdate { .. }
            | AgentEvent::ToolExecutionUpdate { .. } => return,
        };
        self.record(payload).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugin::EventObserver;
    use crate::types::{AssistantContent, StopReason};

    /// Sink that fails every write. It is used to check that sink errors
    /// never propagate out of the recorder.
    struct FailingSink;

    #[async_trait::async_trait]
    impl TrajectorySink for FailingSink {
        async fn record(&self, _record: TrajectoryRecord) -> Result<(), TrajectoryError> {
            Err(TrajectoryError::Io("disk full".into()))
        }
    }

    #[tokio::test]
    async fn recorder_writes_ordered_records_with_run_id() {
        let sink = Arc::new(InMemoryTrajectorySink::new());
        let recorder = TrajectoryRecorder::new(sink.clone());
        recorder.on_event(&AgentEvent::AgentStart).await;
        let identity = RunIdentity::root().with_conversation_id("conv-1");
        recorder
            .on_event(&AgentEvent::RunIdentified {
                identity: identity.clone(),
            })
            .await;
        recorder.on_event(&AgentEvent::TurnStart).await;
        recorder
            .on_event(&AgentEvent::TurnEnd {
                message: AgentMessage::Assistant {
                    content: AssistantContent { blocks: Vec::new() },
                    stop_reason: StopReason::EndTurn,
                    error_message: None,
                    timestamp: None,
                    usage: None,
                },
                tool_results: Vec::new(),
            })
            .await;
        recorder
            .on_event(&AgentEvent::AgentEnd {
                messages: Vec::new(),
            })
            .await;
        let records = sink.snapshot().await;
        assert_eq!(records.len(), 4);
        assert!(matches!(
            records[0].payload,
            TrajectoryPayload::RunStarted { .. }
        ));
        assert!(matches!(records[1].payload, TrajectoryPayload::TurnStarted));
        assert!(matches!(
            records[2].payload,
            TrajectoryPayload::TurnEnded { .. }
        ));
        assert!(matches!(
            records[3].payload,
            TrajectoryPayload::RunEnded { .. }
        ));
        for (i, r) in records.iter().enumerate() {
            assert_eq!(r.seq, i as u64);
            assert_eq!(r.run_id.as_deref(), Some(identity.run_id.as_str()));
            assert_eq!(
                r.schema_version, TRAJECTORY_SCHEMA_VERSION,
                "new records carry the current schema version"
            );
        }
    }

    #[test]
    fn record_missing_schema_version_deserializes_as_pre_versioning() {
        let json = serde_json::json!({
            "seq": 7,
            "recorded_at_unix_ms": 123,
            "payload": { "kind": "turn_started" }
        });
        let record: TrajectoryRecord =
            serde_json::from_value(json).expect("legacy record deserializes");
        assert_eq!(record.schema_version, 0);
        assert_eq!(record.seq, 7);
    }

    #[test]
    fn record_round_trips_with_schema_version() {
        let record = TrajectoryRecord {
            schema_version: TRAJECTORY_SCHEMA_VERSION,
            seq: 1,
            run_id: Some("r1".into()),
            parent_run_id: None,
            depth: 0,
            recorded_at_unix_ms: 1,
            payload: TrajectoryPayload::TurnStarted,
        };
        let json = serde_json::to_value(&record).expect("serialize");
        assert_eq!(
            json["schema_version"],
            serde_json::json!(TRAJECTORY_SCHEMA_VERSION)
        );
        let back: TrajectoryRecord = serde_json::from_value(json).expect("deserialize");
        assert_eq!(back.schema_version, TRAJECTORY_SCHEMA_VERSION);
    }

    #[tokio::test]
    async fn recorder_skips_streaming_only_events() {
        let sink = Arc::new(InMemoryTrajectorySink::new());
        let recorder = TrajectoryRecorder::new(sink.clone());
        let msg = AgentMessage::User {
            content: crate::types::UserContent::Text("hi".into()),
            timestamp: None,
        };
        recorder
            .on_event(&AgentEvent::MessageStart {
                message: msg.clone(),
            })
            .await;
        recorder
            .on_event(&AgentEvent::ToolExecutionUpdate {
                tool_call_id: "1".into(),
                tool_name: "shell".into(),
                partial: crate::tool::ToolResult::text("partial"),
            })
            .await;
        assert!(sink.is_empty().await);
    }

    #[tokio::test]
    async fn agent_start_resets_sequence_and_identity() {
        let sink = Arc::new(InMemoryTrajectorySink::new());
        let recorder = TrajectoryRecorder::new(sink.clone());
        recorder
            .on_event(&AgentEvent::RunIdentified {
                identity: RunIdentity::root(),
            })
            .await;
        recorder.on_event(&AgentEvent::TurnStart).await;
        // A new run: numbering restarts and the previous identity is dropped.
        recorder.on_event(&AgentEvent::AgentStart).await;
        recorder.on_event(&AgentEvent::TurnStart).await;

        let records = sink.snapshot().await;
        assert_eq!(records.len(), 3, "AgentStart itself writes no record");
        let seqs: Vec<u64> = records.iter().map(|r| r.seq).collect();
        assert_eq!(seqs, vec![0, 1, 0]);
        assert!(records[0].run_id.is_some());
        assert!(
            records[2].run_id.is_none(),
            "identity is cleared by AgentStart"
        );
        assert!(records[2].parent_run_id.is_none());
        assert_eq!(records[2].depth, 0);
    }

    #[tokio::test]
    async fn sink_failures_are_swallowed_and_sequence_still_advances() {
        let recorder = TrajectoryRecorder::new(Arc::new(FailingSink));
        recorder.on_event(&AgentEvent::TurnStart).await;
        recorder.on_event(&AgentEvent::TurnStart).await;
        assert_eq!(recorder.seq.load(Ordering::SeqCst), 2);
    }
}