// nt_memory/reasoningbank/trajectory.rs

use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use uuid::Uuid;

8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Observation {
11 pub id: String,
13
14 pub timestamp: chrono::DateTime<chrono::Utc>,
16
17 pub data: serde_json::Value,
19
20 pub embedding: Option<Vec<f32>>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct Action {
27 pub id: String,
29
30 pub timestamp: chrono::DateTime<chrono::Utc>,
32
33 pub action_type: String,
35
36 pub parameters: serde_json::Value,
38
39 pub predicted_outcome: Option<f64>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct Trajectory {
46 pub id: String,
48
49 pub agent_id: String,
51
52 pub start_time: chrono::DateTime<chrono::Utc>,
54
55 pub end_time: Option<chrono::DateTime<chrono::Utc>>,
57
58 pub observations: Vec<Observation>,
60
61 pub actions: Vec<Action>,
63
64 pub outcomes: Vec<f64>,
66
67 pub metadata: serde_json::Value,
69}
70
71impl Trajectory {
72 pub fn new(agent_id: String) -> Self {
74 Self {
75 id: Uuid::new_v4().to_string(),
76 agent_id,
77 start_time: chrono::Utc::now(),
78 end_time: None,
79 observations: Vec::new(),
80 actions: Vec::new(),
81 outcomes: Vec::new(),
82 metadata: serde_json::json!({}),
83 }
84 }
85
86 pub fn add_observation(&mut self, data: serde_json::Value, embedding: Option<Vec<f32>>) {
88 self.observations.push(Observation {
89 id: Uuid::new_v4().to_string(),
90 timestamp: chrono::Utc::now(),
91 data,
92 embedding,
93 });
94 }
95
96 pub fn add_action(
98 &mut self,
99 action_type: String,
100 parameters: serde_json::Value,
101 predicted_outcome: Option<f64>,
102 ) {
103 self.actions.push(Action {
104 id: Uuid::new_v4().to_string(),
105 timestamp: chrono::Utc::now(),
106 action_type,
107 parameters,
108 predicted_outcome,
109 });
110 }
111
112 pub fn add_outcome(&mut self, outcome: f64) {
114 self.outcomes.push(outcome);
115 }
116
117 pub fn complete(&mut self) {
119 self.end_time = Some(chrono::Utc::now());
120 }
121
122 pub fn is_complete(&self) -> bool {
124 self.end_time.is_some()
125 }
126
127 pub fn score(&self) -> Option<f64> {
129 if self.outcomes.is_empty() {
130 return None;
131 }
132
133 let sum: f64 = self.outcomes.iter().sum();
134 Some(sum / self.outcomes.len() as f64)
135 }
136}
137
138pub struct TrajectoryTracker {
140 active: Arc<RwLock<std::collections::HashMap<String, Trajectory>>>,
142
143 completed: Arc<RwLock<Vec<Trajectory>>>,
145}
146
147impl TrajectoryTracker {
148 pub fn new() -> Self {
150 Self {
151 active: Arc::new(RwLock::new(std::collections::HashMap::new())),
152 completed: Arc::new(RwLock::new(Vec::new())),
153 }
154 }
155
156 pub async fn start(&self, agent_id: String) -> String {
158 let trajectory = Trajectory::new(agent_id);
159 let id = trajectory.id.clone();
160
161 let mut active = self.active.write().await;
162 active.insert(id.clone(), trajectory);
163
164 id
165 }
166
167 pub async fn track(&self, trajectory: Trajectory) -> anyhow::Result<()> {
169 let mut active = self.active.write().await;
170 active.insert(trajectory.id.clone(), trajectory);
171 Ok(())
172 }
173
174 pub async fn get_active(&self, id: &str) -> Option<Trajectory> {
176 let active = self.active.read().await;
177 active.get(id).cloned()
178 }
179
180 pub async fn complete(&self, id: &str) -> anyhow::Result<()> {
182 let mut active = self.active.write().await;
183
184 if let Some(mut trajectory) = active.remove(id) {
185 trajectory.complete();
186
187 let mut completed = self.completed.write().await;
188 completed.push(trajectory);
189 }
190
191 Ok(())
192 }
193
194 pub async fn get_completed(&self, agent_id: Option<&str>) -> Vec<Trajectory> {
196 let completed = self.completed.read().await;
197
198 if let Some(agent_id) = agent_id {
199 completed
200 .iter()
201 .filter(|t| t.agent_id == agent_id)
202 .cloned()
203 .collect()
204 } else {
205 completed.clone()
206 }
207 }
208
209 pub fn count(&self) -> usize {
211 let active = self.active.blocking_read();
213 let completed = self.completed.blocking_read();
214 active.len() + completed.len()
215 }
216
217 pub async fn get_successful(&self) -> Vec<Trajectory> {
219 let completed = self.completed.read().await;
220
221 completed
222 .iter()
223 .filter(|t| t.score().map(|s| s > 0.5).unwrap_or(false))
224 .cloned()
225 .collect()
226 }
227}
228
229impl Default for TrajectoryTracker {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Recording one of each item grows the matching collection; a fresh
    /// trajectory is not yet complete. Purely synchronous, so no runtime needed.
    #[test]
    fn test_trajectory_creation() {
        let mut trajectory = Trajectory::new("agent_1".to_string());

        trajectory.add_observation(serde_json::json!({"price": 100.0}), None);
        trajectory.add_action(
            "buy".to_string(),
            serde_json::json!({"quantity": 10}),
            Some(110.0),
        );
        trajectory.add_outcome(105.0);

        assert_eq!(trajectory.observations.len(), 1);
        assert_eq!(trajectory.actions.len(), 1);
        assert_eq!(trajectory.outcomes.len(), 1);
        assert!(!trajectory.is_complete());
    }

    /// Start -> complete moves the trajectory from `active` to `completed`
    /// and stamps its end time.
    #[tokio::test]
    async fn test_trajectory_tracker() {
        let tracker = TrajectoryTracker::new();

        let id = tracker.start("agent_1".to_string()).await;
        assert!(tracker.get_active(&id).await.is_some());

        tracker.complete(&id).await.unwrap();
        assert!(tracker.get_active(&id).await.is_none());

        let completed = tracker.get_completed(Some("agent_1")).await;
        assert_eq!(completed.len(), 1);
        assert!(completed[0].is_complete());
        assert!(tracker.get_completed(Some("agent_2")).await.is_empty());
    }

    #[test]
    fn test_trajectory_score() {
        let mut trajectory = Trajectory::new("agent_1".to_string());

        trajectory.add_outcome(0.8);
        trajectory.add_outcome(0.6);
        trajectory.add_outcome(0.7);

        let score = trajectory.score().unwrap();
        assert!((score - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_score_is_none_without_outcomes() {
        let trajectory = Trajectory::new("agent_1".to_string());
        assert!(trajectory.score().is_none());
    }

    /// Only completed trajectories with a mean outcome above 0.5 count as
    /// successful; unscored ones are excluded.
    #[tokio::test]
    async fn test_get_successful_filters_by_score() {
        let tracker = TrajectoryTracker::new();

        let mut good = Trajectory::new("agent_1".to_string());
        good.add_outcome(0.9);
        let mut bad = Trajectory::new("agent_1".to_string());
        bad.add_outcome(0.1);
        let unscored = Trajectory::new("agent_1".to_string());
        let good_id = good.id.clone();

        for trajectory in vec![good, bad, unscored] {
            let id = trajectory.id.clone();
            tracker.track(trajectory).await.unwrap();
            tracker.complete(&id).await.unwrap();
        }

        let successful = tracker.get_successful().await;
        assert_eq!(successful.len(), 1);
        assert_eq!(successful[0].id, good_id);
    }

    /// `count` covers both active and completed trajectories and works from
    /// plain synchronous code.
    #[test]
    fn test_count_includes_active_and_completed() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build tokio runtime for test");
        let tracker = TrajectoryTracker::new();

        runtime.block_on(async {
            let first = tracker.start("agent_1".to_string()).await;
            tracker.start("agent_2".to_string()).await;
            tracker.complete(&first).await.unwrap();
        });

        assert_eq!(tracker.count(), 2);
    }
}