converge_knowledge/agentic/
sessions.rs1use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct LearningSession {
16 pub id: Uuid,
18
19 pub goal: String,
21
22 pub turns: Vec<SessionTurn>,
24
25 pub outcome: SessionOutcome,
27
28 pub started_at: DateTime<Utc>,
30
31 pub ended_at: Option<DateTime<Utc>>,
33
34 pub total_reward: f32,
36
37 pub metadata: std::collections::HashMap<String, String>,
39}
40
41impl LearningSession {
42 pub fn new(goal: impl Into<String>) -> Self {
44 Self {
45 id: Uuid::new_v4(),
46 goal: goal.into(),
47 turns: Vec::new(),
48 outcome: SessionOutcome::InProgress,
49 started_at: Utc::now(),
50 ended_at: None,
51 total_reward: 0.0,
52 metadata: std::collections::HashMap::new(),
53 }
54 }
55
56 pub fn add_turn(&mut self, turn: SessionTurn) {
58 self.total_reward += turn.reward.value();
59 self.turns.push(turn);
60 }
61
62 pub fn complete(&mut self, success: bool) {
64 self.ended_at = Some(Utc::now());
65 self.outcome = if success {
66 SessionOutcome::Success
67 } else {
68 SessionOutcome::Failure
69 };
70 }
71
72 pub fn abort(&mut self, reason: impl Into<String>) {
74 self.ended_at = Some(Utc::now());
75 self.outcome = SessionOutcome::Aborted(reason.into());
76 }
77
78 pub fn duration(&self) -> chrono::Duration {
80 let end = self.ended_at.unwrap_or_else(Utc::now);
81 end - self.started_at
82 }
83
84 pub fn to_trajectory(&self) -> Vec<(String, String, f32)> {
86 self.turns
87 .iter()
88 .map(|t| (t.action.clone(), t.observation.clone(), t.reward.value()))
89 .collect()
90 }
91
92 pub fn discounted_return(&self, gamma: f32) -> f32 {
94 let mut total = 0.0;
95 let mut discount = 1.0;
96
97 for turn in self.turns.iter().rev() {
98 total = turn.reward.value() + gamma * total;
99 discount *= gamma;
100 }
101
102 total
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct SessionTurn {
109 pub action: String,
111
112 pub observation: String,
114
115 pub reward: Reward,
117
118 pub timestamp: DateTime<Utc>,
120}
121
122impl SessionTurn {
123 pub fn new(action: impl Into<String>, observation: impl Into<String>, reward: Reward) -> Self {
125 Self {
126 action: action.into(),
127 observation: observation.into(),
128 reward,
129 timestamp: Utc::now(),
130 }
131 }
132}
133
134#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
136pub enum Reward {
137 Positive(f32),
139
140 Negative(f32),
142
143 Neutral,
145
146 Terminal(f32),
148}
149
150impl Reward {
151 pub fn value(&self) -> f32 {
153 match self {
154 Reward::Positive(v) => *v,
155 Reward::Negative(v) => -*v,
156 Reward::Neutral => 0.0,
157 Reward::Terminal(v) => *v,
158 }
159 }
160
161 pub fn is_positive(&self) -> bool {
163 self.value() > 0.0
164 }
165
166 pub fn is_negative(&self) -> bool {
168 self.value() < 0.0
169 }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
174pub enum SessionOutcome {
175 InProgress,
177
178 Success,
180
181 Failure,
183
184 Aborted(String),
186}
187
188impl SessionOutcome {
189 pub fn is_completed(&self) -> bool {
191 matches!(self, SessionOutcome::Success | SessionOutcome::Failure)
192 }
193
194 pub fn is_success(&self) -> bool {
196 matches!(self, SessionOutcome::Success)
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by all floating-point comparisons in these tests.
    const EPS: f32 = 0.01;

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_session_recording() {
        let script = [
            (
                "read_file main.rs",
                "Found suspicious null check on line 42",
                Reward::Positive(0.1),
            ),
            (
                "edit main.rs: remove null check",
                "Compilation error: cannot assign to immutable",
                Reward::Negative(0.2),
            ),
            (
                "edit main.rs: add mut keyword",
                "File saved successfully",
                Reward::Neutral,
            ),
            ("run_tests", "All 15 tests passing", Reward::Terminal(1.0)),
        ];

        let mut session = LearningSession::new("Find and fix the bug");
        for &(action, observation, reward) in &script {
            session.add_turn(SessionTurn::new(action, observation, reward));
        }
        session.complete(true);

        assert_eq!(session.turns.len(), 4);
        assert!(session.outcome.is_success());
        // 0.1 - 0.2 + 0.0 + 1.0
        assert_close(session.total_reward, 0.9);
    }

    #[test]
    fn test_discounted_return() {
        let mut session = LearningSession::new("Test gamma");
        for &(action, observation) in &[("a1", "o1"), ("a2", "o2"), ("a3", "o3")] {
            session.add_turn(SessionTurn::new(action, observation, Reward::Positive(1.0)));
        }

        // 1 + 0.9 + 0.81
        assert_close(session.discounted_return(0.9), 2.71);
        // 1 + 0.5 + 0.25
        assert_close(session.discounted_return(0.5), 1.75);
    }

    #[test]
    fn test_trajectory() {
        let mut session = LearningSession::new("Demo");
        session.add_turn(SessionTurn::new("step1", "result1", Reward::Positive(0.5)));
        session.add_turn(SessionTurn::new("step2", "result2", Reward::Negative(0.1)));

        let trajectory = session.to_trajectory();
        assert_eq!(trajectory.len(), 2);

        let (first_action, _, first_reward) = &trajectory[0];
        assert_eq!(first_action.as_str(), "step1");
        assert_close(*first_reward, 0.5);

        let (_, _, second_reward) = &trajectory[1];
        assert_close(*second_reward, -0.1);
    }
}