// converge_knowledge/agentic/sessions.rs
//! Learning Sessions - RL Training Data
//!
//! Implements learning session tracking for reinforcement learning:
//! 1. Track agent actions and observations
//! 2. Record rewards (positive/negative/neutral)
//! 3. Build trajectories for training
//! 4. Support offline RL and online learning

use std::collections::HashMap;

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

13/// A learning session capturing agent interactions.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct LearningSession {
16    /// Unique identifier.
17    pub id: Uuid,
18
19    /// Session goal or task description.
20    pub goal: String,
21
22    /// Session turns (action-observation-reward tuples).
23    pub turns: Vec<SessionTurn>,
24
25    /// Session outcome.
26    pub outcome: SessionOutcome,
27
28    /// Session start time.
29    pub started_at: DateTime<Utc>,
30
31    /// Session end time.
32    pub ended_at: Option<DateTime<Utc>>,
33
34    /// Total reward accumulated.
35    pub total_reward: f32,
36
37    /// Session metadata.
38    pub metadata: std::collections::HashMap<String, String>,
39}
40
41impl LearningSession {
42    /// Create a new learning session.
43    pub fn new(goal: impl Into<String>) -> Self {
44        Self {
45            id: Uuid::new_v4(),
46            goal: goal.into(),
47            turns: Vec::new(),
48            outcome: SessionOutcome::InProgress,
49            started_at: Utc::now(),
50            ended_at: None,
51            total_reward: 0.0,
52            metadata: std::collections::HashMap::new(),
53        }
54    }
55
56    /// Add a turn to the session.
57    pub fn add_turn(&mut self, turn: SessionTurn) {
58        self.total_reward += turn.reward.value();
59        self.turns.push(turn);
60    }
61
62    /// Complete the session.
63    pub fn complete(&mut self, success: bool) {
64        self.ended_at = Some(Utc::now());
65        self.outcome = if success {
66            SessionOutcome::Success
67        } else {
68            SessionOutcome::Failure
69        };
70    }
71
72    /// Abort the session.
73    pub fn abort(&mut self, reason: impl Into<String>) {
74        self.ended_at = Some(Utc::now());
75        self.outcome = SessionOutcome::Aborted(reason.into());
76    }
77
78    /// Get session duration.
79    pub fn duration(&self) -> chrono::Duration {
80        let end = self.ended_at.unwrap_or_else(Utc::now);
81        end - self.started_at
82    }
83
84    /// Get trajectory for RL training.
85    pub fn to_trajectory(&self) -> Vec<(String, String, f32)> {
86        self.turns
87            .iter()
88            .map(|t| (t.action.clone(), t.observation.clone(), t.reward.value()))
89            .collect()
90    }
91
92    /// Calculate discounted return (for RL).
93    pub fn discounted_return(&self, gamma: f32) -> f32 {
94        let mut total = 0.0;
95
96        for turn in self.turns.iter().rev() {
97            total = turn.reward.value() + gamma * total;
98        }
99
100        total
101    }
102}
103
/// A single turn in a learning session.
///
/// Captures one (action, observation, reward) step together with when it
/// happened.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
    /// Action taken by the agent (free-form description).
    pub action: String,

    /// Observation/result of the action.
    pub observation: String,

    /// Reward received for this turn.
    pub reward: Reward,

    /// Turn timestamp (UTC; set to `Utc::now()` by `SessionTurn::new`).
    pub timestamp: DateTime<Utc>,
}
119
120impl SessionTurn {
121    /// Create a new turn.
122    pub fn new(action: impl Into<String>, observation: impl Into<String>, reward: Reward) -> Self {
123        Self {
124            action: action.into(),
125            observation: observation.into(),
126            reward,
127            timestamp: Utc::now(),
128        }
129    }
130}
131
/// Reward signal for a single turn.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Reward {
    /// Positive reward; the stored magnitude is used as-is by `value()`.
    Positive(f32),

    /// Negative reward (penalty); stored as a positive magnitude and
    /// negated by `value()`.
    Negative(f32),

    /// No reward (value 0.0).
    Neutral,

    /// Sparse reward delivered at the end of an episode; used as-is.
    Terminal(f32),
}
147
148impl Reward {
149    /// Get numeric value.
150    pub fn value(&self) -> f32 {
151        match self {
152            Reward::Positive(v) => *v,
153            Reward::Negative(v) => -*v,
154            Reward::Neutral => 0.0,
155            Reward::Terminal(v) => *v,
156        }
157    }
158
159    /// Check if positive.
160    pub fn is_positive(&self) -> bool {
161        self.value() > 0.0
162    }
163
164    /// Check if negative.
165    pub fn is_negative(&self) -> bool {
166        self.value() < 0.0
167    }
168}
169
/// Session outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionOutcome {
    /// Session still in progress (initial state).
    InProgress,

    /// Session completed successfully.
    Success,

    /// Session failed.
    Failure,

    /// Session aborted early, with a human-readable reason.
    Aborted(String),
}
185
186impl SessionOutcome {
187    /// Check if completed (success or failure).
188    pub fn is_completed(&self) -> bool {
189        matches!(self, SessionOutcome::Success | SessionOutcome::Failure)
190    }
191
192    /// Check if successful.
193    pub fn is_success(&self) -> bool {
194        matches!(self, SessionOutcome::Success)
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Test: Recording a learning session.
    ///
    /// What happens:
    /// 1. Start a session with a goal
    /// 2. Record turns (action → observation → reward)
    /// 3. Track cumulative reward
    /// 4. Complete session with outcome
    #[test]
    fn test_session_recording() {
        let mut session = LearningSession::new("Find and fix the bug");

        // Turn 1: Read code
        session.add_turn(SessionTurn::new(
            "read_file main.rs",
            "Found suspicious null check on line 42",
            Reward::Positive(0.1), // Small reward for progress
        ));

        // Turn 2: Make a wrong change
        session.add_turn(SessionTurn::new(
            "edit main.rs: remove null check",
            "Compilation error: cannot assign to immutable",
            Reward::Negative(0.2), // Penalty for error
        ));

        // Turn 3: Fix the fix (Neutral contributes 0 to total_reward)
        session.add_turn(SessionTurn::new(
            "edit main.rs: add mut keyword",
            "File saved successfully",
            Reward::Neutral,
        ));

        // Turn 4: Test
        session.add_turn(SessionTurn::new(
            "run_tests",
            "All 15 tests passing",
            Reward::Terminal(1.0), // Big reward for success!
        ));

        session.complete(true);

        assert_eq!(session.turns.len(), 4);
        assert!(session.outcome.is_success());

        // Check total reward: 0.1 - 0.2 + 0 + 1.0 = 0.9
        // (tolerance because f32 accumulation is inexact)
        assert!((session.total_reward - 0.9).abs() < 0.01);
    }

    /// Test: Discounted return calculation.
    ///
    /// What happens:
    /// 1. Record rewards over time
    /// 2. Calculate discounted return with gamma
    /// 3. Later rewards contribute less (temporal discounting)
    #[test]
    fn test_discounted_return() {
        let mut session = LearningSession::new("Test gamma");

        session.add_turn(SessionTurn::new("a1", "o1", Reward::Positive(1.0)));
        session.add_turn(SessionTurn::new("a2", "o2", Reward::Positive(1.0)));
        session.add_turn(SessionTurn::new("a3", "o3", Reward::Positive(1.0)));

        // With gamma=0.9:
        // G = r1 + 0.9*r2 + 0.81*r3
        // G = 1.0 + 0.9 + 0.81 = 2.71
        let g = session.discounted_return(0.9);
        assert!((g - 2.71).abs() < 0.01);

        // With gamma=0.5 (more aggressive discounting):
        // G = 1.0 + 0.5 + 0.25 = 1.75
        let g = session.discounted_return(0.5);
        assert!((g - 1.75).abs() < 0.01);
    }

    /// Test: Trajectory extraction for RL.
    ///
    /// What happens:
    /// 1. Session is converted to trajectory
    /// 2. Trajectory is list of (action, observation, reward) tuples
    /// 3. Can be used for offline RL training
    #[test]
    fn test_trajectory() {
        let mut session = LearningSession::new("Demo");

        session.add_turn(SessionTurn::new("step1", "result1", Reward::Positive(0.5)));
        session.add_turn(SessionTurn::new("step2", "result2", Reward::Negative(0.1)));

        let trajectory = session.to_trajectory();

        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory[0].0, "step1");
        assert!((trajectory[0].2 - 0.5).abs() < 0.01);
        // Negative(0.1) is negated by Reward::value()
        assert!((trajectory[1].2 - (-0.1)).abs() < 0.01);
    }
}