Skip to main content

cortexai_agents/
trajectory.rs

//! Agent Trajectory Recording
//!
//! Structured logging for agent operations to enable debugging,
//! metrics collection, and execution replay.
//!
//! # Example
//!
//! ```rust,ignore
//! use cortex::trajectory::{TrajectoryRecorder, StepType};
//!
//! let mut recorder = TrajectoryRecorder::new("my-agent", "task-123");
//!
//! // Record LLM call
//! recorder.start_step();
//! // ... LLM call happens ...
//! recorder.record_llm_call("gpt-4", Some(100), Some(50));
//!
//! // Record tool call
//! recorder.start_step();
//! // ... tool call happens ...
//! recorder.record_tool_call("search", r#"{"query": "rust"}"#, true, None);
//!
//! // Complete the trajectory
//! let trajectory = recorder.complete("Task completed successfully");
//! println!("Total duration: {}ms", trajectory.total_duration_ms);
//! ```
27
28use serde::{Deserialize, Serialize};
29use std::time::Instant;
30
31/// Type of step in a trajectory
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
33#[serde(rename_all = "snake_case")]
34pub enum StepType {
35    /// Task started
36    TaskStart,
37    /// LLM/Model call
38    LlmCall,
39    /// Tool invocation
40    ToolCall,
41    /// Tool result received
42    ToolResult,
43    /// Handoff to another agent
44    Handoff,
45    /// Planning step
46    Planning,
47    /// Memory operation
48    Memory,
49    /// Task completed successfully
50    TaskComplete,
51    /// Task failed
52    TaskFailed,
53}
54
55/// A single step in an agent's execution trajectory
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct TrajectoryStep {
58    /// Type of this step
59    pub step_type: StepType,
60    /// Agent that performed this step
61    pub agent_name: String,
62    /// Duration of this step in milliseconds
63    pub duration_ms: u64,
64    /// Additional details (JSON)
65    pub details: serde_json::Value,
66    /// Whether this step succeeded
67    pub success: bool,
68    /// Error message if failed
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub error: Option<String>,
71    /// Timestamp of this step
72    pub timestamp: i64,
73}
74
75impl TrajectoryStep {
76    /// Create a new trajectory step
77    pub fn new(
78        step_type: StepType,
79        agent_name: impl Into<String>,
80        duration_ms: u64,
81        details: serde_json::Value,
82        success: bool,
83    ) -> Self {
84        Self {
85            step_type,
86            agent_name: agent_name.into(),
87            duration_ms,
88            details,
89            success,
90            error: None,
91            timestamp: chrono::Utc::now().timestamp_millis(),
92        }
93    }
94
95    /// Add error message
96    pub fn with_error(mut self, error: impl Into<String>) -> Self {
97        self.error = Some(error.into());
98        self
99    }
100}
101
102/// Recorder for tracking agent execution trajectory
103#[derive(Debug)]
104pub struct TrajectoryRecorder {
105    agent_name: String,
106    task_id: String,
107    steps: Vec<TrajectoryStep>,
108    start_time: Instant,
109    current_step_start: Option<Instant>,
110    metadata: serde_json::Value,
111}
112
113impl TrajectoryRecorder {
114    /// Create a new trajectory recorder
115    pub fn new(agent_name: impl Into<String>, task_id: impl Into<String>) -> Self {
116        let agent = agent_name.into();
117        let task = task_id.into();
118
119        tracing::info!(
120            target: "trajectory",
121            agent = %agent,
122            task_id = %task,
123            "Task started"
124        );
125
126        Self {
127            agent_name: agent,
128            task_id: task,
129            steps: Vec::new(),
130            start_time: Instant::now(),
131            current_step_start: None,
132            metadata: serde_json::json!({}),
133        }
134    }
135
136    /// Set metadata for the trajectory
137    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
138        self.metadata = metadata;
139        self
140    }
141
142    /// Start timing a new step
143    pub fn start_step(&mut self) {
144        self.current_step_start = Some(Instant::now());
145    }
146
147    /// Get elapsed time since step started
148    fn step_duration(&self) -> u64 {
149        self.current_step_start
150            .map(|s| s.elapsed().as_millis() as u64)
151            .unwrap_or(0)
152    }
153
154    /// Record an LLM call
155    pub fn record_llm_call(
156        &mut self,
157        model: &str,
158        tokens_in: Option<u32>,
159        tokens_out: Option<u32>,
160    ) {
161        let duration = self.step_duration();
162
163        let step = TrajectoryStep::new(
164            StepType::LlmCall,
165            &self.agent_name,
166            duration,
167            serde_json::json!({
168                "model": model,
169                "tokens_in": tokens_in,
170                "tokens_out": tokens_out,
171            }),
172            true,
173        );
174
175        tracing::debug!(
176            target: "trajectory",
177            agent = %self.agent_name,
178            model = %model,
179            duration_ms = %duration,
180            tokens_in = ?tokens_in,
181            tokens_out = ?tokens_out,
182            "LLM call completed"
183        );
184
185        self.steps.push(step);
186        self.current_step_start = None;
187    }
188
189    /// Record an LLM call failure
190    pub fn record_llm_failure(&mut self, model: &str, error: &str) {
191        let duration = self.step_duration();
192
193        let step = TrajectoryStep::new(
194            StepType::LlmCall,
195            &self.agent_name,
196            duration,
197            serde_json::json!({
198                "model": model,
199            }),
200            false,
201        )
202        .with_error(error);
203
204        tracing::warn!(
205            target: "trajectory",
206            agent = %self.agent_name,
207            model = %model,
208            error = %error,
209            "LLM call failed"
210        );
211
212        self.steps.push(step);
213        self.current_step_start = None;
214    }
215
216    /// Record a tool call
217    pub fn record_tool_call(
218        &mut self,
219        tool_name: &str,
220        arguments: &str,
221        success: bool,
222        error: Option<&str>,
223    ) {
224        let duration = self.step_duration();
225
226        // Truncate arguments for logging
227        let args_preview = if arguments.len() > 200 {
228            format!("{}...", &arguments[..200])
229        } else {
230            arguments.to_string()
231        };
232
233        let mut step = TrajectoryStep::new(
234            StepType::ToolCall,
235            &self.agent_name,
236            duration,
237            serde_json::json!({
238                "tool": tool_name,
239                "arguments_preview": args_preview,
240            }),
241            success,
242        );
243
244        if let Some(err) = error {
245            step = step.with_error(err);
246        }
247
248        if success {
249            tracing::debug!(
250                target: "trajectory",
251                agent = %self.agent_name,
252                tool = %tool_name,
253                duration_ms = %duration,
254                "Tool call succeeded"
255            );
256        } else {
257            tracing::warn!(
258                target: "trajectory",
259                agent = %self.agent_name,
260                tool = %tool_name,
261                error = ?error,
262                "Tool call failed"
263            );
264        }
265
266        self.steps.push(step);
267        self.current_step_start = None;
268    }
269
270    /// Record a handoff to another agent
271    pub fn record_handoff(&mut self, target_agent: &str, reason: &str) {
272        let step = TrajectoryStep::new(
273            StepType::Handoff,
274            &self.agent_name,
275            0,
276            serde_json::json!({
277                "target_agent": target_agent,
278                "reason": reason,
279            }),
280            true,
281        );
282
283        tracing::info!(
284            target: "trajectory",
285            agent = %self.agent_name,
286            target = %target_agent,
287            reason = %reason,
288            "Agent handoff"
289        );
290
291        self.steps.push(step);
292    }
293
294    /// Record a planning step
295    pub fn record_planning(&mut self, plan_steps: usize, goal: &str) {
296        let duration = self.step_duration();
297
298        let step = TrajectoryStep::new(
299            StepType::Planning,
300            &self.agent_name,
301            duration,
302            serde_json::json!({
303                "plan_steps": plan_steps,
304                "goal": goal,
305            }),
306            true,
307        );
308
309        tracing::debug!(
310            target: "trajectory",
311            agent = %self.agent_name,
312            plan_steps = %plan_steps,
313            "Plan created"
314        );
315
316        self.steps.push(step);
317        self.current_step_start = None;
318    }
319
320    /// Record a memory operation
321    pub fn record_memory(&mut self, operation: &str, key: &str, success: bool) {
322        let duration = self.step_duration();
323
324        let step = TrajectoryStep::new(
325            StepType::Memory,
326            &self.agent_name,
327            duration,
328            serde_json::json!({
329                "operation": operation,
330                "key": key,
331            }),
332            success,
333        );
334
335        self.steps.push(step);
336        self.current_step_start = None;
337    }
338
339    /// Record a custom step
340    pub fn record_custom(
341        &mut self,
342        step_type: StepType,
343        details: serde_json::Value,
344        success: bool,
345        error: Option<&str>,
346    ) {
347        let duration = self.step_duration();
348
349        let mut step = TrajectoryStep::new(step_type, &self.agent_name, duration, details, success);
350
351        if let Some(err) = error {
352            step = step.with_error(err);
353        }
354
355        self.steps.push(step);
356        self.current_step_start = None;
357    }
358
359    /// Complete the task successfully
360    pub fn complete(mut self, result_preview: &str) -> Trajectory {
361        let total_duration = self.start_time.elapsed();
362
363        // Truncate result for storage
364        let preview = if result_preview.len() > 500 {
365            format!("{}...", &result_preview[..500])
366        } else {
367            result_preview.to_string()
368        };
369
370        self.steps.push(TrajectoryStep::new(
371            StepType::TaskComplete,
372            &self.agent_name,
373            total_duration.as_millis() as u64,
374            serde_json::json!({
375                "result_preview": preview,
376            }),
377            true,
378        ));
379
380        let trajectory = Trajectory {
381            agent_name: self.agent_name.clone(),
382            task_id: self.task_id.clone(),
383            total_duration_ms: total_duration.as_millis() as u64,
384            steps: self.steps,
385            success: true,
386            metadata: self.metadata,
387        };
388
389        tracing::info!(
390            target: "trajectory",
391            agent = %trajectory.agent_name,
392            task_id = %trajectory.task_id,
393            duration_ms = %trajectory.total_duration_ms,
394            steps = %trajectory.steps.len(),
395            "Task completed successfully"
396        );
397
398        trajectory
399    }
400
401    /// Fail the task
402    pub fn fail(mut self, error: &str) -> Trajectory {
403        let total_duration = self.start_time.elapsed();
404
405        self.steps.push(
406            TrajectoryStep::new(
407                StepType::TaskFailed,
408                &self.agent_name,
409                total_duration.as_millis() as u64,
410                serde_json::json!({}),
411                false,
412            )
413            .with_error(error),
414        );
415
416        let trajectory = Trajectory {
417            agent_name: self.agent_name.clone(),
418            task_id: self.task_id.clone(),
419            total_duration_ms: total_duration.as_millis() as u64,
420            steps: self.steps,
421            success: false,
422            metadata: self.metadata,
423        };
424
425        tracing::error!(
426            target: "trajectory",
427            agent = %trajectory.agent_name,
428            task_id = %trajectory.task_id,
429            error = %error,
430            "Task failed"
431        );
432
433        trajectory
434    }
435
436    /// Get current step count
437    pub fn step_count(&self) -> usize {
438        self.steps.len()
439    }
440
441    /// Get elapsed time since start
442    pub fn elapsed_ms(&self) -> u64 {
443        self.start_time.elapsed().as_millis() as u64
444    }
445}
446
447/// Complete trajectory for a task execution
448#[derive(Debug, Clone, Serialize, Deserialize)]
449pub struct Trajectory {
450    /// Agent that executed the task
451    pub agent_name: String,
452    /// Task identifier
453    pub task_id: String,
454    /// Total execution time in milliseconds
455    pub total_duration_ms: u64,
456    /// All steps in the execution
457    pub steps: Vec<TrajectoryStep>,
458    /// Whether the task succeeded
459    pub success: bool,
460    /// Additional metadata
461    pub metadata: serde_json::Value,
462}
463
464impl Trajectory {
465    /// Get count of tool calls in trajectory
466    pub fn tool_call_count(&self) -> usize {
467        self.steps
468            .iter()
469            .filter(|s| s.step_type == StepType::ToolCall)
470            .count()
471    }
472
473    /// Get count of LLM calls in trajectory
474    pub fn llm_call_count(&self) -> usize {
475        self.steps
476            .iter()
477            .filter(|s| s.step_type == StepType::LlmCall)
478            .count()
479    }
480
481    /// Get count of failed steps
482    pub fn failed_step_count(&self) -> usize {
483        self.steps.iter().filter(|s| !s.success).count()
484    }
485
486    /// Get total tokens used (if tracked)
487    pub fn total_tokens(&self) -> (u32, u32) {
488        let mut tokens_in = 0u32;
489        let mut tokens_out = 0u32;
490
491        for step in &self.steps {
492            if step.step_type == StepType::LlmCall {
493                if let Some(t) = step.details.get("tokens_in").and_then(|v| v.as_u64()) {
494                    tokens_in += t as u32;
495                }
496                if let Some(t) = step.details.get("tokens_out").and_then(|v| v.as_u64()) {
497                    tokens_out += t as u32;
498                }
499            }
500        }
501
502        (tokens_in, tokens_out)
503    }
504
505    /// Get all tool names used
506    pub fn tools_used(&self) -> Vec<String> {
507        self.steps
508            .iter()
509            .filter(|s| s.step_type == StepType::ToolCall)
510            .filter_map(|s| s.details.get("tool").and_then(|v| v.as_str()))
511            .map(|s| s.to_string())
512            .collect()
513    }
514
515    /// Get steps by type
516    pub fn steps_by_type(&self, step_type: StepType) -> Vec<&TrajectoryStep> {
517        self.steps
518            .iter()
519            .filter(|s| s.step_type == step_type)
520            .collect()
521    }
522
523    /// Log the full trajectory as structured JSON
524    pub fn log_json(&self) {
525        if let Ok(json) = serde_json::to_string(self) {
526            tracing::info!(target: "trajectory_json", "{}", json);
527        }
528    }
529
530    /// Convert to pretty JSON
531    pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
532        serde_json::to_string_pretty(self)
533    }
534
535    /// Get summary statistics
536    pub fn summary(&self) -> TrajectorySummary {
537        let (tokens_in, tokens_out) = self.total_tokens();
538
539        TrajectorySummary {
540            agent_name: self.agent_name.clone(),
541            task_id: self.task_id.clone(),
542            success: self.success,
543            total_duration_ms: self.total_duration_ms,
544            step_count: self.steps.len(),
545            llm_calls: self.llm_call_count(),
546            tool_calls: self.tool_call_count(),
547            failed_steps: self.failed_step_count(),
548            tokens_in,
549            tokens_out,
550        }
551    }
552}
553
554/// Summary of a trajectory
555#[derive(Debug, Clone, Serialize, Deserialize)]
556pub struct TrajectorySummary {
557    pub agent_name: String,
558    pub task_id: String,
559    pub success: bool,
560    pub total_duration_ms: u64,
561    pub step_count: usize,
562    pub llm_calls: usize,
563    pub tool_calls: usize,
564    pub failed_steps: usize,
565    pub tokens_in: u32,
566    pub tokens_out: u32,
567}
568
569/// Storage for trajectories
570#[derive(Default)]
571pub struct TrajectoryStore {
572    trajectories: std::sync::RwLock<Vec<Trajectory>>,
573    max_size: usize,
574}
575
576impl TrajectoryStore {
577    /// Create a new trajectory store
578    pub fn new(max_size: usize) -> Self {
579        Self {
580            trajectories: std::sync::RwLock::new(Vec::new()),
581            max_size,
582        }
583    }
584
585    /// Store a trajectory
586    pub fn store(&self, trajectory: Trajectory) {
587        let mut trajectories = self.trajectories.write().unwrap();
588
589        if trajectories.len() >= self.max_size {
590            trajectories.remove(0);
591        }
592
593        trajectories.push(trajectory);
594    }
595
596    /// Get all trajectories
597    pub fn all(&self) -> Vec<Trajectory> {
598        self.trajectories.read().unwrap().clone()
599    }
600
601    /// Get trajectories by agent
602    pub fn by_agent(&self, agent_name: &str) -> Vec<Trajectory> {
603        self.trajectories
604            .read()
605            .unwrap()
606            .iter()
607            .filter(|t| t.agent_name == agent_name)
608            .cloned()
609            .collect()
610    }
611
612    /// Get trajectory by task ID
613    pub fn by_task_id(&self, task_id: &str) -> Option<Trajectory> {
614        self.trajectories
615            .read()
616            .unwrap()
617            .iter()
618            .find(|t| t.task_id == task_id)
619            .cloned()
620    }
621
622    /// Get failed trajectories
623    pub fn failed(&self) -> Vec<Trajectory> {
624        self.trajectories
625            .read()
626            .unwrap()
627            .iter()
628            .filter(|t| !t.success)
629            .cloned()
630            .collect()
631    }
632
633    /// Get count
634    pub fn len(&self) -> usize {
635        self.trajectories.read().unwrap().len()
636    }
637
638    /// Check if empty
639    pub fn is_empty(&self) -> bool {
640        self.trajectories.read().unwrap().is_empty()
641    }
642
643    /// Clear all trajectories
644    pub fn clear(&self) {
645        self.trajectories.write().unwrap().clear();
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Records one successful tool call with `{}` arguments per name, each
    /// preceded by its own `start_step`.
    fn record_tools(recorder: &mut TrajectoryRecorder, names: &[&str]) {
        for &name in names {
            recorder.start_step();
            recorder.record_tool_call(name, "{}", true, None);
        }
    }

    #[test]
    fn test_trajectory_recorder_basic() {
        let mut rec = TrajectoryRecorder::new("test-agent", "task-1");

        rec.start_step();
        std::thread::sleep(Duration::from_millis(10));
        rec.record_llm_call("gpt-4", Some(100), Some(50));

        rec.start_step();
        rec.record_tool_call("search", r#"{"query": "test"}"#, true, None);

        let run = rec.complete("Done!");

        assert!(run.success);
        assert_eq!(run.agent_name, "test-agent");
        assert_eq!(run.task_id, "task-1");
        assert_eq!(run.llm_call_count(), 1);
        assert_eq!(run.tool_call_count(), 1);
        assert!(run.total_duration_ms >= 10);
    }

    #[test]
    fn test_trajectory_failure() {
        let run = TrajectoryRecorder::new("test-agent", "task-2").fail("Something went wrong");

        assert!(!run.success);
        assert_eq!(run.failed_step_count(), 1);
    }

    #[test]
    fn test_trajectory_tokens() {
        let mut rec = TrajectoryRecorder::new("test-agent", "task-3");

        for &(tin, tout) in &[(100, 50), (200, 100)] {
            rec.start_step();
            rec.record_llm_call("gpt-4", Some(tin), Some(tout));
        }

        let totals = rec.complete("Done").total_tokens();
        assert_eq!(totals.0, 300);
        assert_eq!(totals.1, 150);
    }

    #[test]
    fn test_trajectory_tools_used() {
        let mut rec = TrajectoryRecorder::new("test-agent", "task-4");
        record_tools(&mut rec, &["search", "calculator", "search"]);

        let tools = rec.complete("Done").tools_used();
        assert_eq!(tools.len(), 3);
        assert!(tools.iter().any(|t| t == "search"));
        assert!(tools.iter().any(|t| t == "calculator"));
    }

    #[test]
    fn test_trajectory_handoff() {
        let mut rec = TrajectoryRecorder::new("agent-a", "task-5");
        rec.record_handoff("agent-b", "Better suited for this task");

        let run = rec.complete("Handed off");
        assert_eq!(run.steps_by_type(StepType::Handoff).len(), 1);
    }

    #[test]
    fn test_trajectory_summary() {
        let mut rec = TrajectoryRecorder::new("test-agent", "task-6");

        rec.start_step();
        rec.record_llm_call("gpt-4", Some(100), Some(50));

        record_tools(&mut rec, &["search"]);

        rec.start_step();
        rec.record_tool_call("failed_tool", "{}", false, Some("Error"));

        let summary = rec.complete("Done").summary();

        assert_eq!(summary.llm_calls, 1);
        assert_eq!(summary.tool_calls, 2);
        assert_eq!(summary.failed_steps, 1);
        assert_eq!(summary.tokens_in, 100);
        assert_eq!(summary.tokens_out, 50);
    }

    #[test]
    fn test_trajectory_store() {
        let store = TrajectoryStore::new(10);

        store.store(TrajectoryRecorder::new("agent-1", "task-1").complete("Done 1"));
        store.store(TrajectoryRecorder::new("agent-2", "task-2").complete("Done 2"));

        assert_eq!(store.len(), 2);
        assert_eq!(store.by_agent("agent-1").len(), 1);
        assert!(store.by_task_id("task-1").is_some());
    }

    #[test]
    fn test_trajectory_store_max_size() {
        let store = TrajectoryStore::new(2);

        (0..5).for_each(|i| {
            let rec = TrajectoryRecorder::new("agent", format!("task-{}", i));
            store.store(rec.complete("Done"));
        });

        assert_eq!(store.len(), 2);
        // The three oldest tasks must have been evicted...
        for evicted in ["task-0", "task-1", "task-2"].iter() {
            assert!(store.by_task_id(evicted).is_none());
        }
        // ...while the two newest are still retained.
        for kept in ["task-3", "task-4"].iter() {
            assert!(store.by_task_id(kept).is_some());
        }
    }
}