Skip to main content

converge_knowledge/agentic/
sessions.rs

//! Learning Sessions - RL Training Data
//!
//! Implements learning session tracking for reinforcement learning:
//! 1. Track agent actions and observations
//! 2. Record rewards (positive/negative/neutral)
//! 3. Build trajectories for training
//! 4. Support offline RL and online learning

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

13/// A learning session capturing agent interactions.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct LearningSession {
16    /// Unique identifier.
17    pub id: Uuid,
18
19    /// Session goal or task description.
20    pub goal: String,
21
22    /// Session turns (action-observation-reward tuples).
23    pub turns: Vec<SessionTurn>,
24
25    /// Session outcome.
26    pub outcome: SessionOutcome,
27
28    /// Session start time.
29    pub started_at: DateTime<Utc>,
30
31    /// Session end time.
32    pub ended_at: Option<DateTime<Utc>>,
33
34    /// Total reward accumulated.
35    pub total_reward: f32,
36
37    /// Session metadata.
38    pub metadata: std::collections::HashMap<String, String>,
39}
40
41impl LearningSession {
42    /// Create a new learning session.
43    pub fn new(goal: impl Into<String>) -> Self {
44        Self {
45            id: Uuid::new_v4(),
46            goal: goal.into(),
47            turns: Vec::new(),
48            outcome: SessionOutcome::InProgress,
49            started_at: Utc::now(),
50            ended_at: None,
51            total_reward: 0.0,
52            metadata: std::collections::HashMap::new(),
53        }
54    }
55
56    /// Add a turn to the session.
57    pub fn add_turn(&mut self, turn: SessionTurn) {
58        self.total_reward += turn.reward.value();
59        self.turns.push(turn);
60    }
61
62    /// Complete the session.
63    pub fn complete(&mut self, success: bool) {
64        self.ended_at = Some(Utc::now());
65        self.outcome = if success {
66            SessionOutcome::Success
67        } else {
68            SessionOutcome::Failure
69        };
70    }
71
72    /// Abort the session.
73    pub fn abort(&mut self, reason: impl Into<String>) {
74        self.ended_at = Some(Utc::now());
75        self.outcome = SessionOutcome::Aborted(reason.into());
76    }
77
78    /// Get session duration.
79    pub fn duration(&self) -> chrono::Duration {
80        let end = self.ended_at.unwrap_or_else(Utc::now);
81        end - self.started_at
82    }
83
84    /// Get trajectory for RL training.
85    pub fn to_trajectory(&self) -> Vec<(String, String, f32)> {
86        self.turns
87            .iter()
88            .map(|t| (t.action.clone(), t.observation.clone(), t.reward.value()))
89            .collect()
90    }
91
92    /// Calculate discounted return (for RL).
93    pub fn discounted_return(&self, gamma: f32) -> f32 {
94        let mut total = 0.0;
95        let mut discount = 1.0;
96
97        for turn in self.turns.iter().rev() {
98            total = turn.reward.value() + gamma * total;
99            discount *= gamma;
100        }
101
102        total
103    }
104}
105
106/// A single turn in a learning session.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct SessionTurn {
109    /// Action taken by the agent.
110    pub action: String,
111
112    /// Observation/result of the action.
113    pub observation: String,
114
115    /// Reward received.
116    pub reward: Reward,
117
118    /// Turn timestamp.
119    pub timestamp: DateTime<Utc>,
120}
121
122impl SessionTurn {
123    /// Create a new turn.
124    pub fn new(action: impl Into<String>, observation: impl Into<String>, reward: Reward) -> Self {
125        Self {
126            action: action.into(),
127            observation: observation.into(),
128            reward,
129            timestamp: Utc::now(),
130        }
131    }
132}
133
134/// Reward signal.
135#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
136pub enum Reward {
137    /// Positive reward.
138    Positive(f32),
139
140    /// Negative reward (penalty).
141    Negative(f32),
142
143    /// No reward.
144    Neutral,
145
146    /// Sparse reward at end of episode.
147    Terminal(f32),
148}
149
150impl Reward {
151    /// Get numeric value.
152    pub fn value(&self) -> f32 {
153        match self {
154            Reward::Positive(v) => *v,
155            Reward::Negative(v) => -*v,
156            Reward::Neutral => 0.0,
157            Reward::Terminal(v) => *v,
158        }
159    }
160
161    /// Check if positive.
162    pub fn is_positive(&self) -> bool {
163        self.value() > 0.0
164    }
165
166    /// Check if negative.
167    pub fn is_negative(&self) -> bool {
168        self.value() < 0.0
169    }
170}
171
172/// Session outcome.
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub enum SessionOutcome {
175    /// Session still in progress.
176    InProgress,
177
178    /// Session completed successfully.
179    Success,
180
181    /// Session failed.
182    Failure,
183
184    /// Session aborted with reason.
185    Aborted(String),
186}
187
188impl SessionOutcome {
189    /// Check if completed (success or failure).
190    pub fn is_completed(&self) -> bool {
191        matches!(self, SessionOutcome::Success | SessionOutcome::Failure)
192    }
193
194    /// Check if successful.
195    pub fn is_success(&self) -> bool {
196        matches!(self, SessionOutcome::Success)
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison helper: true when `a` and `b` differ by < 0.01.
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < 0.01
    }

    /// Test: Recording a learning session.
    ///
    /// Walks a session through four turns (read → bad edit → fix → test),
    /// completes it, and checks turn count, outcome, and total reward.
    #[test]
    fn test_session_recording() {
        let mut session = LearningSession::new("Find and fix the bug");

        // Turn 1: Read code — small reward for progress.
        session.add_turn(SessionTurn::new(
            "read_file main.rs",
            "Found suspicious null check on line 42",
            Reward::Positive(0.1),
        ));

        // Turn 2: Make a wrong change — penalized.
        session.add_turn(SessionTurn::new(
            "edit main.rs: remove null check",
            "Compilation error: cannot assign to immutable",
            Reward::Negative(0.2),
        ));

        // Turn 3: Fix the fix — no reward either way.
        session.add_turn(SessionTurn::new(
            "edit main.rs: add mut keyword",
            "File saved successfully",
            Reward::Neutral,
        ));

        // Turn 4: Tests pass — big terminal reward.
        session.add_turn(SessionTurn::new(
            "run_tests",
            "All 15 tests passing",
            Reward::Terminal(1.0),
        ));

        session.complete(true);

        assert_eq!(session.turns.len(), 4);
        assert!(session.outcome.is_success());

        // Total reward: 0.1 - 0.2 + 0.0 + 1.0 = 0.9
        assert!(approx(session.total_reward, 0.9));
    }

    /// Test: Discounted return calculation.
    ///
    /// Three unit rewards; later rewards contribute less under
    /// temporal discounting by gamma^t.
    #[test]
    fn test_discounted_return() {
        let mut session = LearningSession::new("Test gamma");

        for (action, obs) in [("a1", "o1"), ("a2", "o2"), ("a3", "o3")] {
            session.add_turn(SessionTurn::new(action, obs, Reward::Positive(1.0)));
        }

        // gamma = 0.9: G = 1.0 + 0.9 + 0.81 = 2.71
        assert!(approx(session.discounted_return(0.9), 2.71));

        // gamma = 0.5 (heavier discounting): G = 1.0 + 0.5 + 0.25 = 1.75
        assert!(approx(session.discounted_return(0.5), 1.75));
    }

    /// Test: Trajectory extraction for RL.
    ///
    /// A session converts to a list of (action, observation, reward)
    /// tuples suitable for offline training.
    #[test]
    fn test_trajectory() {
        let mut session = LearningSession::new("Demo");

        session.add_turn(SessionTurn::new("step1", "result1", Reward::Positive(0.5)));
        session.add_turn(SessionTurn::new("step2", "result2", Reward::Negative(0.1)));

        let trajectory = session.to_trajectory();

        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory[0].0, "step1");
        assert!(approx(trajectory[0].2, 0.5));
        assert!(approx(trajectory[1].2, -0.1));
    }
}