nt_memory/reasoningbank/
trajectory.rs

1//! Trajectory tracking for agent decision paths
2
3use serde::{Serialize, Deserialize};
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use uuid::Uuid;
7
8/// Agent observation
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Observation {
11    /// Observation ID
12    pub id: String,
13
14    /// Timestamp
15    pub timestamp: chrono::DateTime<chrono::Utc>,
16
17    /// Observation data (JSON)
18    pub data: serde_json::Value,
19
20    /// Embedding vector
21    pub embedding: Option<Vec<f32>>,
22}
23
24/// Agent action
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct Action {
27    /// Action ID
28    pub id: String,
29
30    /// Timestamp
31    pub timestamp: chrono::DateTime<chrono::Utc>,
32
33    /// Action type
34    pub action_type: String,
35
36    /// Action parameters
37    pub parameters: serde_json::Value,
38
39    /// Predicted outcome
40    pub predicted_outcome: Option<f64>,
41}
42
43/// Complete agent trajectory
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct Trajectory {
46    /// Trajectory ID
47    pub id: String,
48
49    /// Agent ID
50    pub agent_id: String,
51
52    /// Start timestamp
53    pub start_time: chrono::DateTime<chrono::Utc>,
54
55    /// End timestamp
56    pub end_time: Option<chrono::DateTime<chrono::Utc>>,
57
58    /// Observations
59    pub observations: Vec<Observation>,
60
61    /// Actions taken
62    pub actions: Vec<Action>,
63
64    /// Actual outcomes
65    pub outcomes: Vec<f64>,
66
67    /// Metadata
68    pub metadata: serde_json::Value,
69}
70
71impl Trajectory {
72    /// Create new trajectory
73    pub fn new(agent_id: String) -> Self {
74        Self {
75            id: Uuid::new_v4().to_string(),
76            agent_id,
77            start_time: chrono::Utc::now(),
78            end_time: None,
79            observations: Vec::new(),
80            actions: Vec::new(),
81            outcomes: Vec::new(),
82            metadata: serde_json::json!({}),
83        }
84    }
85
86    /// Add observation
87    pub fn add_observation(&mut self, data: serde_json::Value, embedding: Option<Vec<f32>>) {
88        self.observations.push(Observation {
89            id: Uuid::new_v4().to_string(),
90            timestamp: chrono::Utc::now(),
91            data,
92            embedding,
93        });
94    }
95
96    /// Add action
97    pub fn add_action(
98        &mut self,
99        action_type: String,
100        parameters: serde_json::Value,
101        predicted_outcome: Option<f64>,
102    ) {
103        self.actions.push(Action {
104            id: Uuid::new_v4().to_string(),
105            timestamp: chrono::Utc::now(),
106            action_type,
107            parameters,
108            predicted_outcome,
109        });
110    }
111
112    /// Add outcome
113    pub fn add_outcome(&mut self, outcome: f64) {
114        self.outcomes.push(outcome);
115    }
116
117    /// Complete trajectory
118    pub fn complete(&mut self) {
119        self.end_time = Some(chrono::Utc::now());
120    }
121
122    /// Check if trajectory is complete
123    pub fn is_complete(&self) -> bool {
124        self.end_time.is_some()
125    }
126
127    /// Calculate trajectory score
128    pub fn score(&self) -> Option<f64> {
129        if self.outcomes.is_empty() {
130            return None;
131        }
132
133        let sum: f64 = self.outcomes.iter().sum();
134        Some(sum / self.outcomes.len() as f64)
135    }
136}
137
138/// Trajectory tracker
139pub struct TrajectoryTracker {
140    /// Active trajectories
141    active: Arc<RwLock<std::collections::HashMap<String, Trajectory>>>,
142
143    /// Completed trajectories
144    completed: Arc<RwLock<Vec<Trajectory>>>,
145}
146
147impl TrajectoryTracker {
148    /// Create new trajectory tracker
149    pub fn new() -> Self {
150        Self {
151            active: Arc::new(RwLock::new(std::collections::HashMap::new())),
152            completed: Arc::new(RwLock::new(Vec::new())),
153        }
154    }
155
156    /// Start new trajectory
157    pub async fn start(&self, agent_id: String) -> String {
158        let trajectory = Trajectory::new(agent_id);
159        let id = trajectory.id.clone();
160
161        let mut active = self.active.write().await;
162        active.insert(id.clone(), trajectory);
163
164        id
165    }
166
167    /// Track trajectory
168    pub async fn track(&self, trajectory: Trajectory) -> anyhow::Result<()> {
169        let mut active = self.active.write().await;
170        active.insert(trajectory.id.clone(), trajectory);
171        Ok(())
172    }
173
174    /// Get active trajectory
175    pub async fn get_active(&self, id: &str) -> Option<Trajectory> {
176        let active = self.active.read().await;
177        active.get(id).cloned()
178    }
179
180    /// Complete trajectory
181    pub async fn complete(&self, id: &str) -> anyhow::Result<()> {
182        let mut active = self.active.write().await;
183
184        if let Some(mut trajectory) = active.remove(id) {
185            trajectory.complete();
186
187            let mut completed = self.completed.write().await;
188            completed.push(trajectory);
189        }
190
191        Ok(())
192    }
193
194    /// Get completed trajectories
195    pub async fn get_completed(&self, agent_id: Option<&str>) -> Vec<Trajectory> {
196        let completed = self.completed.read().await;
197
198        if let Some(agent_id) = agent_id {
199            completed
200                .iter()
201                .filter(|t| t.agent_id == agent_id)
202                .cloned()
203                .collect()
204        } else {
205            completed.clone()
206        }
207    }
208
209    /// Count trajectories
210    pub fn count(&self) -> usize {
211        // Blocking read for stats
212        let active = self.active.blocking_read();
213        let completed = self.completed.blocking_read();
214        active.len() + completed.len()
215    }
216
217    /// Get successful trajectories (score > 0.5)
218    pub async fn get_successful(&self) -> Vec<Trajectory> {
219        let completed = self.completed.read().await;
220
221        completed
222            .iter()
223            .filter(|t| t.score().map(|s| s > 0.5).unwrap_or(false))
224            .cloned()
225            .collect()
226    }
227}
228
229impl Default for TrajectoryTracker {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    // Purely synchronous: no async runtime is needed.
    #[test]
    fn test_trajectory_creation() {
        let mut trajectory = Trajectory::new("agent_1".to_string());

        trajectory.add_observation(serde_json::json!({"price": 100.0}), None);
        trajectory.add_action(
            "buy".to_string(),
            serde_json::json!({"quantity": 10}),
            Some(110.0),
        );
        trajectory.add_outcome(105.0);

        assert_eq!(trajectory.observations.len(), 1);
        assert_eq!(trajectory.actions.len(), 1);
        assert_eq!(trajectory.outcomes.len(), 1);
        // Recording steps does not close the trajectory.
        assert!(!trajectory.is_complete());
    }

    #[tokio::test]
    async fn test_trajectory_tracker() {
        let tracker = TrajectoryTracker::new();

        // Start trajectory
        let id = tracker.start("agent_1".to_string()).await;

        // Get active
        let trajectory = tracker.get_active(&id).await;
        assert!(trajectory.is_some());

        // Complete
        tracker.complete(&id).await.unwrap();

        // No longer active
        assert!(tracker.get_active(&id).await.is_none());

        // Should be in completed, with an end time set
        let completed = tracker.get_completed(Some("agent_1")).await;
        assert_eq!(completed.len(), 1);
        assert!(completed[0].is_complete());

        // Filtering by a different agent yields nothing
        assert!(tracker.get_completed(Some("agent_2")).await.is_empty());

        // Completing an unknown ID is a no-op
        tracker.complete("missing").await.unwrap();
        assert_eq!(tracker.get_completed(None).await.len(), 1);
    }

    #[test]
    fn test_trajectory_score() {
        let mut trajectory = Trajectory::new("agent_1".to_string());

        trajectory.add_outcome(0.8);
        trajectory.add_outcome(0.6);
        trajectory.add_outcome(0.7);

        let score = trajectory.score().unwrap();
        assert!((score - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_score_without_outcomes_is_none() {
        let trajectory = Trajectory::new("agent_1".to_string());
        assert!(trajectory.score().is_none());
    }

    #[tokio::test]
    async fn test_get_successful_filters_by_score() {
        let tracker = TrajectoryTracker::new();

        let mut good = Trajectory::new("agent_1".to_string());
        good.add_outcome(0.9);
        let good_id = good.id.clone();

        let mut bad = Trajectory::new("agent_1".to_string());
        bad.add_outcome(0.1);
        let bad_id = bad.id.clone();

        // No outcomes -> no score -> never counted as successful
        let empty = Trajectory::new("agent_2".to_string());
        let empty_id = empty.id.clone();

        for trajectory in [good, bad, empty] {
            tracker.track(trajectory).await.unwrap();
        }
        for id in [&good_id, &bad_id, &empty_id] {
            tracker.complete(id).await.unwrap();
        }

        let successful = tracker.get_successful().await;
        assert_eq!(successful.len(), 1);
        assert_eq!(successful[0].id, good_id);
        assert_eq!(tracker.get_completed(None).await.len(), 3);
    }

    #[test]
    fn test_count_includes_active_and_completed() {
        // Drive the async API on a local runtime; `count` is called outside it.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build test runtime");
        let tracker = TrajectoryTracker::new();
        assert_eq!(tracker.count(), 0);

        let first = runtime.block_on(tracker.start("agent_1".to_string()));
        runtime.block_on(tracker.start("agent_2".to_string()));
        runtime.block_on(tracker.complete(&first)).unwrap();

        // One completed + one still active
        assert_eq!(tracker.count(), 2);
    }
}