use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A recorded learning episode: a goal, the ordered turns taken toward it,
/// and how the episode ended.
///
/// `total_reward` is a running sum maintained by [`LearningSession::add_turn`];
/// pushing onto `turns` directly bypasses that bookkeeping, so prefer `add_turn`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSession {
/// Unique identifier, generated with `Uuid::new_v4` by [`LearningSession::new`].
pub id: Uuid,
/// Free-text goal supplied to [`LearningSession::new`].
pub goal: String,
/// Turns in the order they were appended via [`LearningSession::add_turn`].
pub turns: Vec<SessionTurn>,
/// `InProgress` until [`LearningSession::complete`] or [`LearningSession::abort`] is called.
pub outcome: SessionOutcome,
/// UTC time at which the session was created.
pub started_at: DateTime<Utc>,
/// UTC time set by `complete`/`abort`; `None` while the session is still running.
pub ended_at: Option<DateTime<Utc>>,
/// Sum of `Reward::value()` over every turn added through `add_turn`.
pub total_reward: f32,
/// Arbitrary string key/value annotations; initialized empty and not otherwise
/// read or written by the methods of this type.
pub metadata: std::collections::HashMap<String, String>,
}
impl LearningSession {
pub fn new(goal: impl Into<String>) -> Self {
Self {
id: Uuid::new_v4(),
goal: goal.into(),
turns: Vec::new(),
outcome: SessionOutcome::InProgress,
started_at: Utc::now(),
ended_at: None,
total_reward: 0.0,
metadata: std::collections::HashMap::new(),
}
}
pub fn add_turn(&mut self, turn: SessionTurn) {
self.total_reward += turn.reward.value();
self.turns.push(turn);
}
pub fn complete(&mut self, success: bool) {
self.ended_at = Some(Utc::now());
self.outcome = if success {
SessionOutcome::Success
} else {
SessionOutcome::Failure
};
}
pub fn abort(&mut self, reason: impl Into<String>) {
self.ended_at = Some(Utc::now());
self.outcome = SessionOutcome::Aborted(reason.into());
}
pub fn duration(&self) -> chrono::Duration {
let end = self.ended_at.unwrap_or_else(Utc::now);
end - self.started_at
}
pub fn to_trajectory(&self) -> Vec<(String, String, f32)> {
self.turns
.iter()
.map(|t| (t.action.clone(), t.observation.clone(), t.reward.value()))
.collect()
}
pub fn discounted_return(&self, gamma: f32) -> f32 {
let mut total = 0.0;
for turn in self.turns.iter().rev() {
total = turn.reward.value() + gamma * total;
}
total
}
}
/// One step of a [`LearningSession`]: the action taken, what was observed as a
/// result, and the reward assigned to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
/// Free-text description of the action taken in this turn.
pub action: String,
/// Free-text observation recorded after the action.
pub observation: String,
/// Feedback for this turn; its `value()` feeds the session's reward totals.
pub reward: Reward,
/// UTC time at which the turn was constructed by [`SessionTurn::new`].
pub timestamp: DateTime<Utc>,
}
impl SessionTurn {
pub fn new(action: impl Into<String>, observation: impl Into<String>, reward: Reward) -> Self {
Self {
action: action.into(),
observation: observation.into(),
reward,
timestamp: Utc::now(),
}
}
}
/// Scalar feedback attached to a [`SessionTurn`]. [`Reward::value`] converts it
/// into a signed `f32`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Reward {
/// A gain. `value()` reports it as a positive number (assuming the payload is a
/// non-negative magnitude; TODO confirm callers).
Positive(f32),
/// A penalty. `value()` reports it as a negative number (assuming the payload
/// is a non-negative magnitude; TODO confirm callers).
Negative(f32),
/// No feedback; `value()` is `0.0`.
Neutral,
/// Presumably the reward for an episode's final step, as used in the tests.
/// `value()` returns the payload unchanged, keeping its own sign.
Terminal(f32),
}
impl Reward {
    /// Returns the signed scalar value of this reward.
    ///
    /// `Positive` and `Negative` carry magnitudes, and the sign always comes from
    /// the variant: `Negative(0.2)` and `Negative(-0.2)` both yield `-0.2`, and
    /// `Positive(-0.5)` yields `0.5`. Without this normalization, a `Negative`
    /// built from a negative number would count as a gain (and vice versa).
    /// `Terminal` is returned unchanged, so it can carry either sign, and
    /// `Neutral` is `0.0`.
    pub fn value(&self) -> f32 {
        match *self {
            Reward::Positive(v) => v.abs(),
            Reward::Negative(v) => -v.abs(),
            Reward::Terminal(v) => v,
            Reward::Neutral => 0.0,
        }
    }

    /// Returns `true` when [`Reward::value`] is strictly greater than zero.
    pub fn is_positive(&self) -> bool {
        self.value() > 0.0
    }

    /// Returns `true` when [`Reward::value`] is strictly less than zero.
    pub fn is_negative(&self) -> bool {
        self.value() < 0.0
    }
}
/// How a [`LearningSession`] ended, or `InProgress` if it has not ended yet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionOutcome {
/// Initial state assigned by `LearningSession::new`.
InProgress,
/// Set by `LearningSession::complete(true)`.
Success,
/// Set by `LearningSession::complete(false)`.
Failure,
/// Set by `LearningSession::abort`; carries the supplied reason.
Aborted(String),
}
impl SessionOutcome {
pub fn is_completed(&self) -> bool {
matches!(self, SessionOutcome::Success | SessionOutcome::Failure)
}
pub fn is_success(&self) -> bool {
matches!(self, SessionOutcome::Success)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating-point comparisons in this module.
    const EPS: f32 = 0.01;

    /// Asserts that `actual` lies strictly within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < EPS);
    }

    #[test]
    fn test_session_recording() {
        let mut session = LearningSession::new("Find and fix the bug");
        let script = [
            (
                "read_file main.rs",
                "Found suspicious null check on line 42",
                Reward::Positive(0.1),
            ),
            (
                "edit main.rs: remove null check",
                "Compilation error: cannot assign to immutable",
                Reward::Negative(0.2),
            ),
            (
                "edit main.rs: add mut keyword",
                "File saved successfully",
                Reward::Neutral,
            ),
            ("run_tests", "All 15 tests passing", Reward::Terminal(1.0)),
        ];
        for (action, observation, reward) in script {
            session.add_turn(SessionTurn::new(action, observation, reward));
        }
        session.complete(true);

        assert_eq!(session.turns.len(), 4);
        assert!(session.outcome.is_success());
        // 0.1 - 0.2 + 0.0 + 1.0
        assert_close(session.total_reward, 0.9);
    }

    #[test]
    fn test_discounted_return() {
        let mut session = LearningSession::new("Test gamma");
        for (action, observation) in [("a1", "o1"), ("a2", "o2"), ("a3", "o3")] {
            session.add_turn(SessionTurn::new(action, observation, Reward::Positive(1.0)));
        }

        // 1 + 0.9 * (1 + 0.9 * 1) = 2.71
        assert_close(session.discounted_return(0.9), 2.71);
        // 1 + 0.5 * (1 + 0.5 * 1) = 1.75
        assert_close(session.discounted_return(0.5), 1.75);
    }

    #[test]
    fn test_trajectory() {
        let mut session = LearningSession::new("Demo");
        session.add_turn(SessionTurn::new("step1", "result1", Reward::Positive(0.5)));
        session.add_turn(SessionTurn::new("step2", "result2", Reward::Negative(0.1)));

        let trajectory = session.to_trajectory();
        assert_eq!(trajectory.len(), 2);

        let (first_action, _, first_reward) = &trajectory[0];
        assert_eq!(first_action, "step1");
        assert_close(*first_reward, 0.5);

        let (_, _, second_reward) = &trajectory[1];
        assert_close(*second_reward, -0.1);
    }
}