//! `converge_knowledge/agentic/sessions.rs`: recording of learning sessions and their turns.

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct LearningSession {
16 pub id: Uuid,
18
19 pub goal: String,
21
22 pub turns: Vec<SessionTurn>,
24
25 pub outcome: SessionOutcome,
27
28 pub started_at: DateTime<Utc>,
30
31 pub ended_at: Option<DateTime<Utc>>,
33
34 pub total_reward: f32,
36
37 pub metadata: std::collections::HashMap<String, String>,
39}
40
41impl LearningSession {
42 pub fn new(goal: impl Into<String>) -> Self {
44 Self {
45 id: Uuid::new_v4(),
46 goal: goal.into(),
47 turns: Vec::new(),
48 outcome: SessionOutcome::InProgress,
49 started_at: Utc::now(),
50 ended_at: None,
51 total_reward: 0.0,
52 metadata: std::collections::HashMap::new(),
53 }
54 }
55
56 pub fn add_turn(&mut self, turn: SessionTurn) {
58 self.total_reward += turn.reward.value();
59 self.turns.push(turn);
60 }
61
62 pub fn complete(&mut self, success: bool) {
64 self.ended_at = Some(Utc::now());
65 self.outcome = if success {
66 SessionOutcome::Success
67 } else {
68 SessionOutcome::Failure
69 };
70 }
71
72 pub fn abort(&mut self, reason: impl Into<String>) {
74 self.ended_at = Some(Utc::now());
75 self.outcome = SessionOutcome::Aborted(reason.into());
76 }
77
78 pub fn duration(&self) -> chrono::Duration {
80 let end = self.ended_at.unwrap_or_else(Utc::now);
81 end - self.started_at
82 }
83
84 pub fn to_trajectory(&self) -> Vec<(String, String, f32)> {
86 self.turns
87 .iter()
88 .map(|t| (t.action.clone(), t.observation.clone(), t.reward.value()))
89 .collect()
90 }
91
92 pub fn discounted_return(&self, gamma: f32) -> f32 {
94 let mut total = 0.0;
95
96 for turn in self.turns.iter().rev() {
97 total = turn.reward.value() + gamma * total;
98 }
99
100 total
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct SessionTurn {
107 pub action: String,
109
110 pub observation: String,
112
113 pub reward: Reward,
115
116 pub timestamp: DateTime<Utc>,
118}
119
120impl SessionTurn {
121 pub fn new(action: impl Into<String>, observation: impl Into<String>, reward: Reward) -> Self {
123 Self {
124 action: action.into(),
125 observation: observation.into(),
126 reward,
127 timestamp: Utc::now(),
128 }
129 }
130}
131
132#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
134pub enum Reward {
135 Positive(f32),
137
138 Negative(f32),
140
141 Neutral,
143
144 Terminal(f32),
146}
147
148impl Reward {
149 pub fn value(&self) -> f32 {
151 match self {
152 Reward::Positive(v) => *v,
153 Reward::Negative(v) => -*v,
154 Reward::Neutral => 0.0,
155 Reward::Terminal(v) => *v,
156 }
157 }
158
159 pub fn is_positive(&self) -> bool {
161 self.value() > 0.0
162 }
163
164 pub fn is_negative(&self) -> bool {
166 self.value() < 0.0
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub enum SessionOutcome {
173 InProgress,
175
176 Success,
178
179 Failure,
181
182 Aborted(String),
184}
185
186impl SessionOutcome {
187 pub fn is_completed(&self) -> bool {
189 matches!(self, SessionOutcome::Success | SessionOutcome::Failure)
190 }
191
192 pub fn is_success(&self) -> bool {
194 matches!(self, SessionOutcome::Success)
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Records a four-turn debugging episode and checks the turn count, the
    /// outcome and the running total (0.1 - 0.2 + 0.0 + 1.0 = 0.9).
    #[test]
    fn test_session_recording() {
        let mut session = LearningSession::new("Find and fix the bug");

        session.add_turn(SessionTurn::new(
            "read_file main.rs",
            "Found suspicious null check on line 42",
            Reward::Positive(0.1),
        ));
        session.add_turn(SessionTurn::new(
            "edit main.rs: remove null check",
            "Compilation error: cannot assign to immutable",
            Reward::Negative(0.2),
        ));
        session.add_turn(SessionTurn::new(
            "edit main.rs: add mut keyword",
            "File saved successfully",
            Reward::Neutral,
        ));
        session.add_turn(SessionTurn::new(
            "run_tests",
            "All 15 tests passing",
            Reward::Terminal(1.0),
        ));

        session.complete(true);

        assert_eq!(session.turns.len(), 4);
        assert!(session.outcome.is_success());
        assert!(session.ended_at.is_some());
        assert!((session.total_reward - 0.9).abs() < 0.01);
    }

    /// Three unit rewards give G = 1 + γ + γ².
    #[test]
    fn test_discounted_return() {
        let mut session = LearningSession::new("Test gamma");

        session.add_turn(SessionTurn::new("a1", "o1", Reward::Positive(1.0)));
        session.add_turn(SessionTurn::new("a2", "o2", Reward::Positive(1.0)));
        session.add_turn(SessionTurn::new("a3", "o3", Reward::Positive(1.0)));

        // γ = 0.9: 1 + 0.9 + 0.81 = 2.71
        let g = session.discounted_return(0.9);
        assert!((g - 2.71).abs() < 0.01);

        // γ = 0.5: 1 + 0.5 + 0.25 = 1.75
        let g = session.discounted_return(0.5);
        assert!((g - 1.75).abs() < 0.01);
    }

    /// The trajectory preserves turn order and exposes signed reward values.
    #[test]
    fn test_trajectory() {
        let mut session = LearningSession::new("Demo");

        session.add_turn(SessionTurn::new("step1", "result1", Reward::Positive(0.5)));
        session.add_turn(SessionTurn::new("step2", "result2", Reward::Negative(0.1)));

        let trajectory = session.to_trajectory();

        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory[0].0, "step1");
        assert!((trajectory[0].2 - 0.5).abs() < 0.01);
        assert!((trajectory[1].2 - (-0.1)).abs() < 0.01);
    }

    /// Aborting stores the reason, stamps the end time, and is not "completed".
    #[test]
    fn test_abort_records_reason_and_end_time() {
        let mut session = LearningSession::new("Demo");
        assert!(session.ended_at.is_none());

        session.abort("budget exhausted");

        assert!(session.ended_at.is_some());
        assert!(matches!(
            &session.outcome,
            SessionOutcome::Aborted(reason) if reason == "budget exhausted"
        ));
        assert!(!session.outcome.is_completed());
        assert!(!session.outcome.is_success());
    }

    /// `complete(false)` yields a completed but unsuccessful session.
    #[test]
    fn test_complete_failure() {
        let mut session = LearningSession::new("Demo");

        session.complete(false);

        assert!(matches!(session.outcome, SessionOutcome::Failure));
        assert!(session.outcome.is_completed());
        assert!(!session.outcome.is_success());
        assert!(session.ended_at.is_some());
    }

    /// A fresh session has no turns, zero reward and an empty trajectory.
    #[test]
    fn test_empty_session() {
        let session = LearningSession::new("Nothing yet");

        assert!(session.turns.is_empty());
        assert!(matches!(session.outcome, SessionOutcome::InProgress));
        assert!(session.ended_at.is_none());
        assert_eq!(session.total_reward, 0.0);
        assert_eq!(session.discounted_return(0.9), 0.0);
        assert!(session.to_trajectory().is_empty());
    }
}