1use std::sync::atomic::{AtomicU64, Ordering};
48use std::sync::Arc;
49
50use async_trait::async_trait;
51use serde::{Deserialize, Serialize};
52use tokio::sync::Mutex;
53
54use crate::event::AgentEvent;
55use crate::plugin::{Plugin, PluginCapabilities};
56use crate::types::{AgentMessage, RunIdentity};
57
/// Schema version stamped on every [`TrajectoryRecord`] written by this module.
///
/// Records that were serialized before the `schema_version` field existed
/// deserialize with the version returned by [`pre_versioning_schema`].
pub const TRAJECTORY_SCHEMA_VERSION: u32 = 1;

/// Serde default for `TrajectoryRecord::schema_version`.
///
/// Yields `0`, the version assigned to legacy records that lack the field.
fn pre_versioning_schema() -> u32 {
    u32::MIN
}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct TrajectoryRecord {
85 #[serde(default = "pre_versioning_schema")]
89 pub schema_version: u32,
90 pub seq: u64,
91 #[serde(default, skip_serializing_if = "Option::is_none")]
92 pub run_id: Option<String>,
93 #[serde(default, skip_serializing_if = "Option::is_none")]
94 pub parent_run_id: Option<String>,
95 #[serde(default)]
96 pub depth: usize,
97 pub recorded_at_unix_ms: u64,
99 pub payload: TrajectoryPayload,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
107#[serde(tag = "kind", rename_all = "snake_case")]
108pub enum TrajectoryPayload {
109 RunStarted {
110 identity: RunIdentity,
111 },
112 RunEnded {
113 outcome: String,
114 new_messages: Vec<AgentMessage>,
115 },
116 TurnStarted,
117 TurnEnded {
118 assistant: AgentMessage,
119 tool_results: Vec<AgentMessage>,
120 },
121 MessageAppended {
124 message: AgentMessage,
125 },
126 ToolStarted {
127 tool_call_id: String,
128 tool_name: String,
129 args: serde_json::Value,
130 },
131 ToolEnded {
132 tool_call_id: String,
133 tool_name: String,
134 result: crate::tool::ToolResult,
135 is_error: bool,
136 },
137 ProviderRequestPrepared {
140 iteration: usize,
141 model_id: Option<String>,
142 system_prompt_chars: usize,
143 message_count: usize,
144 tool_count: usize,
145 tools: Vec<String>,
146 },
147 ContextTransformApplied {
152 iteration: usize,
153 plugin: String,
154 before_count: usize,
155 after_count: usize,
156 },
157 ToolGateApplied {
159 iteration: usize,
160 plugin: String,
161 #[serde(default, skip_serializing_if = "Option::is_none")]
162 allow: Option<Vec<String>>,
163 },
164 ToolGateConflictResolved {
168 iteration: usize,
169 plugins: Vec<String>,
170 #[serde(default, skip_serializing_if = "Option::is_none")]
171 chosen_plugin: Option<String>,
172 allow: Vec<String>,
173 reason: String,
174 },
175 OutputTokensEscalation {
177 attempt: u8,
178 prev_cap: u32,
179 new_cap: u32,
180 },
181}
182
183#[derive(Debug, thiserror::Error)]
185pub enum TrajectoryError {
186 #[error("trajectory sink rejected record: {0}")]
187 Rejected(String),
188 #[error("trajectory sink i/o failure: {0}")]
189 Io(String),
190}
191
192#[async_trait]
203pub trait TrajectorySink: Send + Sync {
204 async fn record(&self, record: TrajectoryRecord) -> Result<(), TrajectoryError>;
205}
206
207#[derive(Debug, Default)]
212pub struct InMemoryTrajectorySink {
213 records: Mutex<Vec<TrajectoryRecord>>,
214}
215
216impl InMemoryTrajectorySink {
217 pub fn new() -> Self {
218 Self::default()
219 }
220
221 pub async fn snapshot(&self) -> Vec<TrajectoryRecord> {
222 self.records.lock().await.clone()
223 }
224
225 pub async fn len(&self) -> usize {
226 self.records.lock().await.len()
227 }
228
229 pub async fn is_empty(&self) -> bool {
230 self.records.lock().await.is_empty()
231 }
232}
233
234#[async_trait]
235impl TrajectorySink for InMemoryTrajectorySink {
236 async fn record(&self, record: TrajectoryRecord) -> Result<(), TrajectoryError> {
237 self.records.lock().await.push(record);
238 Ok(())
239 }
240}
241
242pub struct TrajectoryRecorder {
254 sink: Arc<dyn TrajectorySink>,
255 seq: AtomicU64,
256 identity: Mutex<Option<RunIdentity>>,
257}
258
259impl TrajectoryRecorder {
260 pub fn new(sink: Arc<dyn TrajectorySink>) -> Self {
261 Self {
262 sink,
263 seq: AtomicU64::new(0),
264 identity: Mutex::new(None),
265 }
266 }
267
268 async fn record(&self, payload: TrajectoryPayload) {
269 let seq = self.seq.fetch_add(1, Ordering::SeqCst);
270 let identity = self.identity.lock().await.clone();
271 let recorded_at_unix_ms = std::time::SystemTime::now()
272 .duration_since(std::time::UNIX_EPOCH)
273 .map(|d| d.as_millis() as u64)
274 .unwrap_or(0);
275 let record = TrajectoryRecord {
276 schema_version: TRAJECTORY_SCHEMA_VERSION,
277 seq,
278 run_id: identity.as_ref().map(|i| i.run_id.clone()),
279 parent_run_id: identity.as_ref().and_then(|i| i.parent_run_id.clone()),
280 depth: identity.as_ref().map(|i| i.depth).unwrap_or(0),
281 recorded_at_unix_ms,
282 payload,
283 };
284 if let Err(e) = self.sink.record(record).await {
285 tracing::warn!(error = %e, "trajectory sink rejected record; continuing");
286 }
287 }
288}
289
290impl Plugin for TrajectoryRecorder {
291 fn name(&self) -> &'static str {
292 "trajectory_recorder"
293 }
294
295 fn capabilities(&self) -> PluginCapabilities {
296 PluginCapabilities::event_observer()
297 }
298}
299
300#[async_trait]
301impl crate::plugin::EventObserver for TrajectoryRecorder {
302 async fn on_event(&self, event: &AgentEvent) {
303 match event {
304 AgentEvent::AgentStart => {
305 self.seq.store(0, Ordering::SeqCst);
309 *self.identity.lock().await = None;
310 }
311 AgentEvent::RunIdentified { identity } => {
312 *self.identity.lock().await = Some(identity.clone());
313 self.record(TrajectoryPayload::RunStarted {
314 identity: identity.clone(),
315 })
316 .await;
317 }
318 AgentEvent::AgentEnd { messages } => {
319 self.record(TrajectoryPayload::RunEnded {
320 outcome: "ended".to_string(),
321 new_messages: messages.clone(),
322 })
323 .await;
324 }
325 AgentEvent::TurnStart => {
326 self.record(TrajectoryPayload::TurnStarted).await;
327 }
328 AgentEvent::TurnEnd {
329 message,
330 tool_results,
331 } => {
332 self.record(TrajectoryPayload::TurnEnded {
333 assistant: message.clone(),
334 tool_results: tool_results.clone(),
335 })
336 .await;
337 }
338 AgentEvent::MessageEnd { message } => {
339 self.record(TrajectoryPayload::MessageAppended {
340 message: message.clone(),
341 })
342 .await;
343 }
344 AgentEvent::ToolExecutionStart {
345 tool_call_id,
346 tool_name,
347 args,
348 } => {
349 self.record(TrajectoryPayload::ToolStarted {
350 tool_call_id: tool_call_id.clone(),
351 tool_name: tool_name.clone(),
352 args: args.clone(),
353 })
354 .await;
355 }
356 AgentEvent::ToolExecutionEnd {
357 tool_call_id,
358 tool_name,
359 result,
360 is_error,
361 } => {
362 self.record(TrajectoryPayload::ToolEnded {
363 tool_call_id: tool_call_id.clone(),
364 tool_name: tool_name.clone(),
365 result: result.clone(),
366 is_error: *is_error,
367 })
368 .await;
369 }
370 AgentEvent::ProviderRequestPrepared {
371 iteration,
372 model_id,
373 system_prompt,
374 messages,
375 tools,
376 ..
377 } => {
378 self.record(TrajectoryPayload::ProviderRequestPrepared {
379 iteration: *iteration,
380 model_id: model_id.clone(),
381 system_prompt_chars: system_prompt.chars().count(),
382 message_count: messages.len(),
383 tool_count: tools.len(),
384 tools: tools.iter().map(|t| t.name.clone()).collect(),
385 })
386 .await;
387 }
388 AgentEvent::ContextTransformApplied {
389 iteration,
390 plugin,
391 before,
392 after,
393 } => {
394 self.record(TrajectoryPayload::ContextTransformApplied {
395 iteration: *iteration,
396 plugin: (*plugin).to_string(),
397 before_count: before.len(),
398 after_count: after.len(),
399 })
400 .await;
401 }
402 AgentEvent::ToolGateApplied {
403 iteration,
404 plugin,
405 allow,
406 } => {
407 self.record(TrajectoryPayload::ToolGateApplied {
408 iteration: *iteration,
409 plugin: (*plugin).to_string(),
410 allow: allow.clone(),
411 })
412 .await;
413 }
414 AgentEvent::ToolGateConflictResolved {
415 iteration,
416 plugins,
417 chosen_plugin,
418 allow,
419 reason,
420 } => {
421 self.record(TrajectoryPayload::ToolGateConflictResolved {
422 iteration: *iteration,
423 plugins: plugins.clone(),
424 chosen_plugin: chosen_plugin.clone(),
425 allow: allow.clone(),
426 reason: reason.clone(),
427 })
428 .await;
429 }
430 AgentEvent::OutputTokensEscalation {
431 attempt,
432 prev_cap,
433 new_cap,
434 } => {
435 self.record(TrajectoryPayload::OutputTokensEscalation {
436 attempt: *attempt,
437 prev_cap: *prev_cap,
438 new_cap: *new_cap,
439 })
440 .await;
441 }
442 AgentEvent::MessageStart { .. }
443 | AgentEvent::MessageUpdate { .. }
444 | AgentEvent::ToolExecutionUpdate { .. } => {
445 }
450 }
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugin::EventObserver;
    use crate::types::{AssistantContent, StopReason};

    /// A full run produces records in event order, numbered from 0, all
    /// tagged with the announced run id and the current schema version.
    #[tokio::test]
    async fn recorder_writes_ordered_records_with_run_id() {
        let sink = Arc::new(InMemoryTrajectorySink::new());
        let recorder = TrajectoryRecorder::new(sink.clone());

        recorder.on_event(&AgentEvent::AgentStart).await;
        let identity = RunIdentity::root().with_conversation_id("conv-1");
        recorder
            .on_event(&AgentEvent::RunIdentified {
                identity: identity.clone(),
            })
            .await;
        recorder.on_event(&AgentEvent::TurnStart).await;
        recorder
            .on_event(&AgentEvent::TurnEnd {
                message: AgentMessage::Assistant {
                    content: AssistantContent { blocks: Vec::new() },
                    stop_reason: StopReason::EndTurn,
                    error_message: None,
                    timestamp: None,
                    usage: None,
                },
                tool_results: Vec::new(),
            })
            .await;
        recorder
            .on_event(&AgentEvent::AgentEnd {
                messages: Vec::new(),
            })
            .await;

        let records = sink.snapshot().await;
        // AgentStart itself is not recorded.
        assert_eq!(records.len(), 4);
        assert!(matches!(
            records[0].payload,
            TrajectoryPayload::RunStarted { .. }
        ));
        assert!(matches!(records[1].payload, TrajectoryPayload::TurnStarted));
        assert!(matches!(
            records[2].payload,
            TrajectoryPayload::TurnEnded { .. }
        ));
        assert!(matches!(
            records[3].payload,
            TrajectoryPayload::RunEnded { .. }
        ));

        for (i, r) in records.iter().enumerate() {
            assert_eq!(r.seq, i as u64);
            assert_eq!(r.run_id.as_deref(), Some(identity.run_id.as_str()));
            assert_eq!(
                r.schema_version, TRAJECTORY_SCHEMA_VERSION,
                "new records carry the current schema version"
            );
        }
    }

    /// A second AgentStart restarts numbering at 0 and drops the previous
    /// run's identity until a new RunIdentified arrives.
    #[tokio::test]
    async fn agent_start_resets_sequence_and_identity() {
        let sink = Arc::new(InMemoryTrajectorySink::new());
        let recorder = TrajectoryRecorder::new(sink.clone());

        // First run: identify it and emit a turn so `seq` advances past 0.
        recorder.on_event(&AgentEvent::AgentStart).await;
        recorder
            .on_event(&AgentEvent::RunIdentified {
                identity: RunIdentity::root().with_conversation_id("conv-1"),
            })
            .await;
        recorder.on_event(&AgentEvent::TurnStart).await;

        // Second run starts and emits a turn without announcing an identity.
        recorder.on_event(&AgentEvent::AgentStart).await;
        recorder.on_event(&AgentEvent::TurnStart).await;

        let records = sink.snapshot().await;
        assert_eq!(records.len(), 3);
        assert_eq!(records[1].seq, 1);
        assert!(records[1].run_id.is_some());

        let last = &records[2];
        assert!(matches!(last.payload, TrajectoryPayload::TurnStarted));
        assert_eq!(last.seq, 0, "AgentStart restarts the sequence");
        assert!(last.run_id.is_none(), "AgentStart clears the run identity");
        assert!(last.parent_run_id.is_none());
        assert_eq!(last.depth, 0);
    }

    /// Records written before versioning existed deserialize as version 0.
    #[test]
    fn record_missing_schema_version_deserializes_as_pre_versioning() {
        let json = serde_json::json!({
            "seq": 7,
            "recorded_at_unix_ms": 123,
            "payload": { "kind": "turn_started" }
        });
        let record: TrajectoryRecord =
            serde_json::from_value(json).expect("legacy record deserializes");
        assert_eq!(record.schema_version, 0);
        assert_eq!(record.seq, 7);
    }

    /// The schema version is serialized and survives a JSON round trip.
    #[test]
    fn record_round_trips_with_schema_version() {
        let record = TrajectoryRecord {
            schema_version: TRAJECTORY_SCHEMA_VERSION,
            seq: 1,
            run_id: Some("r1".into()),
            parent_run_id: None,
            depth: 0,
            recorded_at_unix_ms: 1,
            payload: TrajectoryPayload::TurnStarted,
        };
        let json = serde_json::to_value(&record).expect("serialize");
        assert_eq!(
            json["schema_version"],
            serde_json::json!(TRAJECTORY_SCHEMA_VERSION)
        );
        let back: TrajectoryRecord = serde_json::from_value(json).expect("deserialize");
        assert_eq!(back.schema_version, TRAJECTORY_SCHEMA_VERSION);
    }

    /// Streaming-only events never reach the sink.
    #[tokio::test]
    async fn recorder_skips_streaming_only_events() {
        let sink = Arc::new(InMemoryTrajectorySink::new());
        let recorder = TrajectoryRecorder::new(sink.clone());

        let msg = AgentMessage::User {
            content: crate::types::UserContent::Text("hi".into()),
            timestamp: None,
        };
        recorder
            .on_event(&AgentEvent::MessageStart {
                message: msg.clone(),
            })
            .await;
        recorder
            .on_event(&AgentEvent::ToolExecutionUpdate {
                tool_call_id: "1".into(),
                tool_name: "shell".into(),
                partial: crate::tool::ToolResult::text("partial"),
            })
            .await;

        assert!(sink.is_empty().await);
    }
}