use serde::{Serialize, Deserialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// A single observation recorded on a [`Trajectory`].
///
/// [`Trajectory::add_observation`] builds these, filling in `id` and
/// `timestamp` itself and storing the caller's `data` and `embedding` as given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Observation {
/// Identifier of this observation; `Trajectory::add_observation` sets it to a
/// freshly generated UUID v4 rendered as a string.
pub id: String,
/// UTC time of recording; `Trajectory::add_observation` sets it to the current time.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Caller-supplied JSON payload, stored unchanged.
pub data: serde_json::Value,
/// Optional caller-supplied `f32` vector, stored unchanged.
/// NOTE(review): presumably a vector embedding of `data`; nothing in this
/// file reads it, so confirm its meaning with the producers/consumers.
pub embedding: Option<Vec<f32>>,
}
/// A single action recorded on a [`Trajectory`].
///
/// [`Trajectory::add_action`] builds these, filling in `id` and `timestamp`
/// itself and storing the remaining fields exactly as the caller supplied them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Action {
/// Identifier of this action; `Trajectory::add_action` sets it to a freshly
/// generated UUID v4 rendered as a string.
pub id: String,
/// UTC time of recording; `Trajectory::add_action` sets it to the current time.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Free-form, caller-supplied label for the kind of action (the tests use `"buy"`).
pub action_type: String,
/// Caller-supplied JSON parameters of the action, stored unchanged.
pub parameters: serde_json::Value,
/// Optional caller-supplied `f64`, stored unchanged.
/// NOTE(review): presumably the outcome the agent expected from this action;
/// this file never compares it with `Trajectory::outcomes` — confirm with callers.
pub predicted_outcome: Option<f64>,
}
/// The observations, actions and numeric outcomes recorded for one agent.
///
/// Created by [`Trajectory::new`]; it counts as complete once `end_time` is set
/// (see [`Trajectory::complete`] and [`Trajectory::is_complete`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trajectory {
/// Identifier of this trajectory; `Trajectory::new` sets it to a freshly
/// generated UUID v4 rendered as a string. `TrajectoryTracker` uses it as the
/// key of its active map.
pub id: String,
/// Identifier of the owning agent, as passed to `Trajectory::new`.
pub agent_id: String,
/// UTC time at which `Trajectory::new` created this value.
pub start_time: chrono::DateTime<chrono::Utc>,
/// `None` until `Trajectory::complete` is called, which stores the current UTC time.
pub end_time: Option<chrono::DateTime<chrono::Utc>>,
/// Observations in the order they were added.
pub observations: Vec<Observation>,
/// Actions in the order they were added.
pub actions: Vec<Action>,
/// Outcome values in the order they were added; `Trajectory::score` returns
/// their arithmetic mean.
pub outcomes: Vec<f64>,
/// Free-form JSON metadata; `Trajectory::new` initialises it to an empty object
/// and nothing else in this file modifies it.
pub metadata: serde_json::Value,
}
impl Trajectory {
    /// Creates an empty, not-yet-completed trajectory owned by `agent_id`.
    ///
    /// The trajectory gets a fresh UUID v4 string as its `id`, the current UTC
    /// time as its `start_time`, and an empty JSON object as its `metadata`.
    pub fn new(agent_id: String) -> Self {
        let id = Uuid::new_v4().to_string();
        let start_time = chrono::Utc::now();
        Self {
            id,
            agent_id,
            start_time,
            end_time: None,
            observations: vec![],
            actions: vec![],
            outcomes: vec![],
            metadata: serde_json::json!({}),
        }
    }

    /// Appends an observation carrying `data` and `embedding`.
    ///
    /// The observation is stamped with a new UUID v4 id and the current UTC time.
    pub fn add_observation(&mut self, data: serde_json::Value, embedding: Option<Vec<f32>>) {
        let observation = Observation {
            id: Uuid::new_v4().to_string(),
            timestamp: chrono::Utc::now(),
            data,
            embedding,
        };
        self.observations.push(observation);
    }

    /// Appends an action of kind `action_type` with the given `parameters` and
    /// optional `predicted_outcome`.
    ///
    /// The action is stamped with a new UUID v4 id and the current UTC time.
    pub fn add_action(
        &mut self,
        action_type: String,
        parameters: serde_json::Value,
        predicted_outcome: Option<f64>,
    ) {
        let action = Action {
            id: Uuid::new_v4().to_string(),
            timestamp: chrono::Utc::now(),
            action_type,
            parameters,
            predicted_outcome,
        };
        self.actions.push(action);
    }

    /// Appends `outcome` to the list of recorded outcomes.
    pub fn add_outcome(&mut self, outcome: f64) {
        self.outcomes.push(outcome);
    }

    /// Marks the trajectory as finished by setting `end_time` to the current UTC time.
    ///
    /// Calling this again overwrites the previously stored end time.
    pub fn complete(&mut self) {
        let finished_at = chrono::Utc::now();
        self.end_time = Some(finished_at);
    }

    /// Returns `true` once an end time has been recorded.
    pub fn is_complete(&self) -> bool {
        self.end_time.is_some()
    }

    /// Returns the arithmetic mean of the recorded outcomes, or `None` when no
    /// outcome has been recorded.
    pub fn score(&self) -> Option<f64> {
        match self.outcomes.len() {
            0 => None,
            count => {
                let total: f64 = self.outcomes.iter().sum();
                Some(total / count as f64)
            }
        }
    }
}
/// Holds trajectories in two collections: those still in progress and those
/// that have been completed.
///
/// Each collection sits behind its own `Arc<tokio::sync::RwLock<..>>`.
pub struct TrajectoryTracker {
// In-progress trajectories keyed by `Trajectory::id`; entries are inserted by
// `start`/`track` and removed by `complete`.
active: Arc<RwLock<std::collections::HashMap<String, Trajectory>>>,
// Finished trajectories, in the order `complete` appended them.
completed: Arc<RwLock<Vec<Trajectory>>>,
}
impl TrajectoryTracker {
    /// Mean-score cutoff used by [`TrajectoryTracker::get_successful`]: a
    /// completed trajectory is "successful" when its score is strictly greater
    /// than this value.
    pub const DEFAULT_SUCCESS_THRESHOLD: f64 = 0.5;

    /// Creates a tracker with no active and no completed trajectories.
    pub fn new() -> Self {
        Self {
            active: Arc::new(RwLock::new(std::collections::HashMap::new())),
            completed: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Creates a new trajectory for `agent_id`, registers it as active and
    /// returns its id.
    pub async fn start(&self, agent_id: String) -> String {
        let trajectory = Trajectory::new(agent_id);
        let id = trajectory.id.clone();
        self.active.write().await.insert(id.clone(), trajectory);
        id
    }

    /// Registers an externally built `trajectory` as active, keyed by its id.
    ///
    /// An active trajectory that already uses the same id is replaced.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the established signature.
    pub async fn track(&self, trajectory: Trajectory) -> anyhow::Result<()> {
        self.active
            .write()
            .await
            .insert(trajectory.id.clone(), trajectory);
        Ok(())
    }

    /// Returns a clone of the active trajectory with the given `id`, if any.
    pub async fn get_active(&self, id: &str) -> Option<Trajectory> {
        self.active.read().await.get(id).cloned()
    }

    /// Moves the active trajectory `id` to the completed list, stamping its end time.
    ///
    /// When `id` is not an active trajectory (unknown, or already completed)
    /// this is a no-op that still returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the established signature.
    pub async fn complete(&self, id: &str) -> anyhow::Result<()> {
        // The `active` guard is deliberately kept while `completed` is locked so
        // the move is atomic for anyone reading both collections. Every method
        // that takes both locks does so in this same order (active, then
        // completed), which rules out lock-order inversion.
        let mut active = self.active.write().await;
        if let Some(mut trajectory) = active.remove(id) {
            trajectory.complete();
            self.completed.write().await.push(trajectory);
        }
        Ok(())
    }

    /// Returns clones of the completed trajectories, restricted to `agent_id`
    /// when one is given and unfiltered otherwise.
    pub async fn get_completed(&self, agent_id: Option<&str>) -> Vec<Trajectory> {
        let completed = self.completed.read().await;
        match agent_id {
            Some(agent_id) => completed
                .iter()
                .filter(|t| t.agent_id == agent_id)
                .cloned()
                .collect(),
            None => completed.clone(),
        }
    }

    /// Returns the total number of tracked trajectories (active + completed),
    /// blocking the current thread while the locks are acquired.
    ///
    /// Intended for synchronous callers only; async code should use
    /// [`TrajectoryTracker::count_async`].
    ///
    /// # Panics
    ///
    /// Panics when called from within an asynchronous execution context,
    /// because it relies on `tokio::sync::RwLock::blocking_read`.
    pub fn count(&self) -> usize {
        let active = self.active.blocking_read();
        let completed = self.completed.blocking_read();
        active.len() + completed.len()
    }

    /// Async counterpart of [`TrajectoryTracker::count`]: returns the total
    /// number of tracked trajectories without blocking the executor thread.
    pub async fn count_async(&self) -> usize {
        // Same lock order as `complete` (active, then completed), so a trajectory
        // being moved between the collections is counted exactly once.
        let active = self.active.read().await;
        let completed = self.completed.read().await;
        active.len() + completed.len()
    }

    /// Returns clones of the completed trajectories whose score is strictly
    /// greater than [`TrajectoryTracker::DEFAULT_SUCCESS_THRESHOLD`].
    ///
    /// Trajectories without any recorded outcome have no score and are excluded.
    pub async fn get_successful(&self) -> Vec<Trajectory> {
        self.get_successful_with_threshold(Self::DEFAULT_SUCCESS_THRESHOLD)
            .await
    }

    /// Returns clones of the completed trajectories whose score is strictly
    /// greater than `threshold`.
    ///
    /// Trajectories without any recorded outcome have no score and are excluded.
    pub async fn get_successful_with_threshold(&self, threshold: f64) -> Vec<Trajectory> {
        let completed = self.completed.read().await;
        completed
            .iter()
            .filter(|t| matches!(t.score(), Some(score) if score > threshold))
            .cloned()
            .collect()
    }
}
impl Default for TrajectoryTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Purely synchronous: no runtime is needed to exercise `Trajectory` itself.
    #[test]
    fn test_trajectory_creation() {
        let mut trajectory = Trajectory::new("agent_1".to_string());
        trajectory.add_observation(serde_json::json!({"price": 100.0}), None);
        trajectory.add_action(
            "buy".to_string(),
            serde_json::json!({"quantity": 10}),
            Some(110.0),
        );
        trajectory.add_outcome(105.0);

        assert_eq!(trajectory.agent_id, "agent_1");
        assert_eq!(trajectory.observations.len(), 1);
        assert_eq!(trajectory.actions.len(), 1);
        assert_eq!(trajectory.outcomes.len(), 1);
        assert!(!trajectory.is_complete());
    }

    #[test]
    fn test_trajectory_completion() {
        let mut trajectory = Trajectory::new("agent_1".to_string());
        assert!(!trajectory.is_complete());
        assert!(trajectory.end_time.is_none());

        trajectory.complete();
        assert!(trajectory.is_complete());
        assert!(trajectory.end_time.is_some());
    }

    #[tokio::test]
    async fn test_trajectory_tracker() {
        let tracker = TrajectoryTracker::new();
        let id = tracker.start("agent_1".to_string()).await;

        let trajectory = tracker.get_active(&id).await;
        assert!(trajectory.is_some());

        tracker.complete(&id).await.unwrap();

        // Completing moves the trajectory out of the active set...
        assert!(tracker.get_active(&id).await.is_none());

        // ...and into the completed list, with its end time set.
        let completed = tracker.get_completed(Some("agent_1")).await;
        assert_eq!(completed.len(), 1);
        assert_eq!(completed[0].id, id);
        assert!(completed[0].is_complete());

        // The agent filter excludes other agents; `None` returns everything.
        assert!(tracker.get_completed(Some("agent_2")).await.is_empty());
        assert_eq!(tracker.get_completed(None).await.len(), 1);
    }

    #[tokio::test]
    async fn test_get_successful_filters_by_score() {
        let tracker = TrajectoryTracker::new();

        let mut good = Trajectory::new("agent_1".to_string());
        good.add_outcome(0.9);
        let good_id = good.id.clone();

        let mut bad = Trajectory::new("agent_1".to_string());
        bad.add_outcome(0.1);
        let bad_id = bad.id.clone();

        // A score of exactly 0.5 is not strictly greater than the cutoff.
        let mut borderline = Trajectory::new("agent_1".to_string());
        borderline.add_outcome(0.5);
        let borderline_id = borderline.id.clone();

        // No outcomes means no score, which never counts as successful.
        let unscored = Trajectory::new("agent_1".to_string());
        let unscored_id = unscored.id.clone();

        for trajectory in [good, bad, borderline, unscored] {
            tracker.track(trajectory).await.unwrap();
        }
        for id in [&good_id, &bad_id, &borderline_id, &unscored_id] {
            tracker.complete(id).await.unwrap();
        }

        let successful = tracker.get_successful().await;
        assert_eq!(successful.len(), 1);
        assert_eq!(successful[0].id, good_id);
    }

    // `count` blocks on the locks, so it is exercised from a plain (non-async) test.
    #[test]
    fn test_count_on_empty_tracker() {
        let tracker = TrajectoryTracker::new();
        assert_eq!(tracker.count(), 0);
    }

    #[test]
    fn test_trajectory_score() {
        let mut trajectory = Trajectory::new("agent_1".to_string());
        trajectory.add_outcome(0.8);
        trajectory.add_outcome(0.6);
        trajectory.add_outcome(0.7);

        let score = trajectory.score().unwrap();
        assert!((score - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_trajectory_score_without_outcomes() {
        let trajectory = Trajectory::new("agent_1".to_string());
        assert!(trajectory.score().is_none());
    }
}