// mnemos/agentic/sessions.rs

//! Learning Sessions - RL Training Data
//!
//! Implements learning session tracking for reinforcement learning:
//! 1. Track agent actions and observations
//! 2. Record rewards (positive/negative/neutral)
//! 3. Build trajectories for training
//! 4. Support offline RL and online learning
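//!
//! # Example
//!
//! A minimal usage sketch; the crate path below is assumed for illustration,
//! so the block is marked `ignore`:
//!
//! ```ignore
//! use mnemos_agentic::sessions::{LearningSession, Reward, SessionTurn};
//!
//! let mut session = LearningSession::new("Summarize the report");
//! session.add_turn(SessionTurn::new(
//!     "read_file report.md",  // action
//!     "Loaded 120 lines",     // observation
//!     Reward::Positive(0.1),  // small shaping reward
//! ));
//! session.complete(true);
//!
//! assert!(session.outcome.is_success());
//! ```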

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

/// A learning session capturing agent interactions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSession {
    /// Unique identifier.
    pub id: Uuid,

    /// Session goal or task description.
    pub goal: String,

    /// Session turns (action-observation-reward tuples).
    pub turns: Vec<SessionTurn>,

    /// Session outcome.
    pub outcome: SessionOutcome,

    /// Session start time.
    pub started_at: DateTime<Utc>,

    /// Session end time.
    pub ended_at: Option<DateTime<Utc>>,

    /// Total reward accumulated.
    pub total_reward: f32,

    /// Session metadata.
    pub metadata: std::collections::HashMap<String, String>,
}

impl LearningSession {
    /// Create a new learning session.
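    ///
    /// The session starts as `SessionOutcome::InProgress` with no turns and
    /// zero total reward.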
    pub fn new(goal: impl Into<String>) -> Self {
        Self {
            id: Uuid::new_v4(),
            goal: goal.into(),
            turns: Vec::new(),
            outcome: SessionOutcome::InProgress,
            started_at: Utc::now(),
            ended_at: None,
            total_reward: 0.0,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Add a turn to the session.
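    ///
    /// The turn's `reward.value()` is added to `total_reward` before the
    /// turn is stored.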
    pub fn add_turn(&mut self, turn: SessionTurn) {
        self.total_reward += turn.reward.value();
        self.turns.push(turn);
    }

    /// Complete the session.
    pub fn complete(&mut self, success: bool) {
        self.ended_at = Some(Utc::now());
        self.outcome = if success {
            SessionOutcome::Success
        } else {
            SessionOutcome::Failure
        };
    }

    /// Abort the session.
    pub fn abort(&mut self, reason: impl Into<String>) {
        self.ended_at = Some(Utc::now());
        self.outcome = SessionOutcome::Aborted(reason.into());
    }

    /// Get session duration.
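    ///
    /// For a session still in progress, the duration is measured up to
    /// `Utc::now()`.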
    pub fn duration(&self) -> chrono::Duration {
        let end = self.ended_at.unwrap_or_else(Utc::now);
        end - self.started_at
    }

    /// Get trajectory for RL training.
    pub fn to_trajectory(&self) -> Vec<(String, String, f32)> {
        self.turns
            .iter()
            .map(|t| (t.action.clone(), t.observation.clone(), t.reward.value()))
            .collect()
    }

    /// Calculate discounted return (for RL).
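    ///
    /// Computes `G = r_0 + γ·r_1 + γ²·r_2 + …` by folding backwards over the
    /// turns (`G_t = r_t + γ·G_{t+1}`, starting from `G_T = 0`), so later
    /// rewards are discounted more heavily.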
    pub fn discounted_return(&self, gamma: f32) -> f32 {
        let mut total = 0.0;

        for turn in self.turns.iter().rev() {
            total = turn.reward.value() + gamma * total;
        }

        total
    }
}

/// A single turn in a learning session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
    /// Action taken by the agent.
    pub action: String,

    /// Observation/result of the action.
    pub observation: String,

    /// Reward received.
    pub reward: Reward,

    /// Turn timestamp.
    pub timestamp: DateTime<Utc>,
}

impl SessionTurn {
    /// Create a new turn.
    pub fn new(action: impl Into<String>, observation: impl Into<String>, reward: Reward) -> Self {
        Self {
            action: action.into(),
            observation: observation.into(),
            reward,
            timestamp: Utc::now(),
        }
    }
}

/// Reward signal.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Reward {
    /// Positive reward.
    Positive(f32),

    /// Negative reward (penalty).
    Negative(f32),

    /// No reward.
    Neutral,

    /// Sparse reward at end of episode.
    Terminal(f32),
}

impl Reward {
    /// Get numeric value.
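    ///
    /// Sign convention: `Positive(v)` and `Terminal(v)` yield `v`,
    /// `Negative(v)` yields `-v` (it stores the penalty magnitude), and
    /// `Neutral` yields `0.0`.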
    pub fn value(&self) -> f32 {
        match self {
            Reward::Positive(v) | Reward::Terminal(v) => *v,
            Reward::Negative(v) => -*v,
            Reward::Neutral => 0.0,
        }
    }

    /// Check if positive.
    pub fn is_positive(&self) -> bool {
        self.value() > 0.0
    }

    /// Check if negative.
    pub fn is_negative(&self) -> bool {
        self.value() < 0.0
    }
}

/// Session outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionOutcome {
    /// Session still in progress.
    InProgress,

    /// Session completed successfully.
    Success,

    /// Session failed.
    Failure,

    /// Session aborted with reason.
    Aborted(String),
}

impl SessionOutcome {
    /// Check if completed (success or failure).
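    ///
    /// `InProgress` and `Aborted` outcomes are not considered completed.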
    pub fn is_completed(&self) -> bool {
        matches!(self, SessionOutcome::Success | SessionOutcome::Failure)
    }

    /// Check if successful.
    pub fn is_success(&self) -> bool {
        matches!(self, SessionOutcome::Success)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Test: Recording a learning session.
    ///
    /// What happens:
    /// 1. Start a session with a goal
    /// 2. Record turns (action → observation → reward)
    /// 3. Track cumulative reward
    /// 4. Complete session with outcome
    #[test]
    fn test_session_recording() {
        let mut session = LearningSession::new("Find and fix the bug");

        // Turn 1: Read code
        session.add_turn(SessionTurn::new(
            "read_file main.rs",
            "Found suspicious null check on line 42",
            Reward::Positive(0.1), // Small reward for progress
        ));

        // Turn 2: Make a wrong change
        session.add_turn(SessionTurn::new(
            "edit main.rs: remove null check",
            "Compilation error: cannot assign to immutable",
            Reward::Negative(0.2), // Penalty for error
        ));

        // Turn 3: Fix the fix
        session.add_turn(SessionTurn::new(
            "edit main.rs: add mut keyword",
            "File saved successfully",
            Reward::Neutral,
        ));

        // Turn 4: Test
        session.add_turn(SessionTurn::new(
            "run_tests",
            "All 15 tests passing",
            Reward::Terminal(1.0), // Big reward for success!
        ));

        session.complete(true);

        assert_eq!(session.turns.len(), 4);
        assert!(session.outcome.is_success());

        // Check total reward: 0.1 - 0.2 + 0 + 1.0 = 0.9
        assert!((session.total_reward - 0.9).abs() < 0.01);
    }

    /// Test: Discounted return calculation.
    ///
    /// What happens:
    /// 1. Record rewards over time
    /// 2. Calculate discounted return with gamma
    /// 3. Later rewards contribute less (temporal discounting)
    #[test]
    fn test_discounted_return() {
        let mut session = LearningSession::new("Test gamma");

        session.add_turn(SessionTurn::new("a1", "o1", Reward::Positive(1.0)));
        session.add_turn(SessionTurn::new("a2", "o2", Reward::Positive(1.0)));
        session.add_turn(SessionTurn::new("a3", "o3", Reward::Positive(1.0)));

        // With gamma=0.9:
        // G = r1 + 0.9*r2 + 0.81*r3
        // G = 1.0 + 0.9 + 0.81 = 2.71
        let g = session.discounted_return(0.9);
        assert!((g - 2.71).abs() < 0.01);

        // With gamma=0.5 (more discounting):
        // G = 1.0 + 0.5 + 0.25 = 1.75
        let g = session.discounted_return(0.5);
        assert!((g - 1.75).abs() < 0.01);
    }

    /// Test: Trajectory extraction for RL.
    ///
    /// What happens:
    /// 1. Session is converted to trajectory
    /// 2. Trajectory is list of (action, observation, reward) tuples
    /// 3. Can be used for offline RL training
    #[test]
    fn test_trajectory() {
        let mut session = LearningSession::new("Demo");

        session.add_turn(SessionTurn::new("step1", "result1", Reward::Positive(0.5)));
        session.add_turn(SessionTurn::new("step2", "result2", Reward::Negative(0.1)));

        let trajectory = session.to_trajectory();

        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory[0].0, "step1");
        assert!((trajectory[0].2 - 0.5).abs() < 0.01);
        assert!((trajectory[1].2 - (-0.1)).abs() < 0.01);
    }
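
    /// Test: Aborting a session.
    ///
    /// An illustrative sketch (not part of the original suite) that exercises
    /// only the API defined above:
    /// 1. Abort a session with a reason
    /// 2. Outcome is Aborted: neither completed nor successful
    /// 3. End time is recorded
    #[test]
    fn test_session_abort() {
        let mut session = LearningSession::new("Abort demo");
        session.add_turn(SessionTurn::new("step1", "result1", Reward::Neutral));

        session.abort("user cancelled");

        assert!(matches!(session.outcome, SessionOutcome::Aborted(_)));
        assert!(!session.outcome.is_completed());
        assert!(!session.outcome.is_success());
        assert!(session.ended_at.is_some());
    }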
}