1use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use uuid::Uuid;
12
/// A recorded learning session: a goal, the turns taken while pursuing it,
/// and the bookkeeping (timestamps, outcome, accumulated reward) around them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningSession {
    /// Identifier of the session; `LearningSession::new` fills it with `Uuid::new_v4()`.
    pub id: Uuid,

    /// Goal text passed to `LearningSession::new`.
    pub goal: String,

    /// Turns in the order they were appended by `add_turn`.
    pub turns: Vec<SessionTurn>,

    /// Current outcome; starts as `SessionOutcome::InProgress` and is
    /// overwritten by `complete` or `abort`.
    pub outcome: SessionOutcome,

    /// UTC time captured when the session was constructed.
    pub started_at: DateTime<Utc>,

    /// UTC time captured by `complete` or `abort`; `None` until one of them runs.
    pub ended_at: Option<DateTime<Utc>>,

    /// Running sum of `Reward::value()` over every turn passed to `add_turn`.
    pub total_reward: f32,

    /// String key/value pairs attached to the session. `new` starts with an
    /// empty map; none of the methods on this type read or modify it.
    pub metadata: std::collections::HashMap<String, String>,
}
40
41impl LearningSession {
42 pub fn new(goal: impl Into<String>) -> Self {
44 Self {
45 id: Uuid::new_v4(),
46 goal: goal.into(),
47 turns: Vec::new(),
48 outcome: SessionOutcome::InProgress,
49 started_at: Utc::now(),
50 ended_at: None,
51 total_reward: 0.0,
52 metadata: std::collections::HashMap::new(),
53 }
54 }
55
56 pub fn add_turn(&mut self, turn: SessionTurn) {
58 self.total_reward += turn.reward.value();
59 self.turns.push(turn);
60 }
61
62 pub fn complete(&mut self, success: bool) {
64 self.ended_at = Some(Utc::now());
65 self.outcome = if success {
66 SessionOutcome::Success
67 } else {
68 SessionOutcome::Failure
69 };
70 }
71
72 pub fn abort(&mut self, reason: impl Into<String>) {
74 self.ended_at = Some(Utc::now());
75 self.outcome = SessionOutcome::Aborted(reason.into());
76 }
77
78 pub fn duration(&self) -> chrono::Duration {
80 let end = self.ended_at.unwrap_or_else(Utc::now);
81 end - self.started_at
82 }
83
84 pub fn to_trajectory(&self) -> Vec<(String, String, f32)> {
86 self.turns
87 .iter()
88 .map(|t| (t.action.clone(), t.observation.clone(), t.reward.value()))
89 .collect()
90 }
91
92 pub fn discounted_return(&self, gamma: f32) -> f32 {
94 let mut total = 0.0;
95
96 for turn in self.turns.iter().rev() {
97 total = turn.reward.value() + gamma * total;
98 }
99
100 total
101 }
102}
103
/// One recorded step of a session: an action, the observation that followed,
/// the reward assigned to it, and when it was recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTurn {
    /// Text describing the action (e.g. `"run_tests"` in this file's tests).
    pub action: String,

    /// Text describing what was observed after the action.
    pub observation: String,

    /// Reward attached to this turn; `LearningSession::add_turn` adds its
    /// `value()` to the session total.
    pub reward: Reward,

    /// UTC time captured by `SessionTurn::new` when the turn was built.
    pub timestamp: DateTime<Utc>,
}
119
120impl SessionTurn {
121 pub fn new(action: impl Into<String>, observation: impl Into<String>, reward: Reward) -> Self {
123 Self {
124 action: action.into(),
125 observation: observation.into(),
126 reward,
127 timestamp: Utc::now(),
128 }
129 }
130}
131
/// Reward signal attached to a single turn.
///
/// `Reward::value` converts any variant into one signed `f32`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Reward {
    /// Reward whose payload is returned unchanged by `value()`.
    Positive(f32),

    /// Penalty; `value()` returns the payload negated, so the payload is the
    /// magnitude of the penalty.
    Negative(f32),

    /// No reward; `value()` returns `0.0`.
    Neutral,

    /// Reward whose payload is returned unchanged by `value()`.
    /// NOTE(review): presumably marks the final reward of a session; confirm with callers.
    Terminal(f32),
}
147
148impl Reward {
149 pub fn value(&self) -> f32 {
151 match self {
152 Reward::Positive(v) | Reward::Terminal(v) => *v,
153 Reward::Negative(v) => -*v,
154 Reward::Neutral => 0.0,
155 }
156 }
157
158 pub fn is_positive(&self) -> bool {
160 self.value() > 0.0
161 }
162
163 pub fn is_negative(&self) -> bool {
165 self.value() < 0.0
166 }
167}
168
/// Final (or current) state of a `LearningSession`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionOutcome {
    /// Outcome assigned by `LearningSession::new`; the session has not been ended.
    InProgress,

    /// Set by `LearningSession::complete(true)`.
    Success,

    /// Set by `LearningSession::complete(false)`.
    Failure,

    /// Set by `LearningSession::abort`; carries the reason passed to it.
    Aborted(String),
}
184
185impl SessionOutcome {
186 pub fn is_completed(&self) -> bool {
188 matches!(self, SessionOutcome::Success | SessionOutcome::Failure)
189 }
190
191 pub fn is_success(&self) -> bool {
193 matches!(self, SessionOutcome::Success)
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used by every floating-point comparison in this module.
    const TOLERANCE: f32 = 0.01;

    /// Returns `true` when `actual` is within `TOLERANCE` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Records a four-turn bug-fix session and checks the turn count, the
    /// outcome, and the accumulated reward.
    #[test]
    fn test_session_recording() {
        let mut session = LearningSession::new("Find and fix the bug");

        let script = [
            (
                "read_file main.rs",
                "Found suspicious null check on line 42",
                Reward::Positive(0.1),
            ),
            (
                "edit main.rs: remove null check",
                "Compilation error: cannot assign to immutable",
                Reward::Negative(0.2),
            ),
            (
                "edit main.rs: add mut keyword",
                "File saved successfully",
                Reward::Neutral,
            ),
            ("run_tests", "All 15 tests passing", Reward::Terminal(1.0)),
        ];
        for &(action, observation, reward) in script.iter() {
            session.add_turn(SessionTurn::new(action, observation, reward));
        }

        session.complete(true);

        assert_eq!(session.turns.len(), 4);
        assert!(session.outcome.is_success());

        // 0.1 - 0.2 + 0.0 + 1.0 = 0.9
        assert!(close(session.total_reward, 0.9));
    }

    /// Checks the discounted return of three unit rewards for two gammas.
    #[test]
    fn test_discounted_return() {
        let mut session = LearningSession::new("Test gamma");

        let steps = [("a1", "o1"), ("a2", "o2"), ("a3", "o3")];
        for &(action, observation) in steps.iter() {
            session.add_turn(SessionTurn::new(action, observation, Reward::Positive(1.0)));
        }

        // 1 + 0.9 * (1 + 0.9 * 1) = 2.71
        let g = session.discounted_return(0.9);
        assert!(close(g, 2.71));

        // 1 + 0.5 * (1 + 0.5 * 1) = 1.75
        let g = session.discounted_return(0.5);
        assert!(close(g, 1.75));
    }

    /// Checks that `to_trajectory` preserves order, actions and signed rewards.
    #[test]
    fn test_trajectory() {
        let mut session = LearningSession::new("Demo");

        session.add_turn(SessionTurn::new("step1", "result1", Reward::Positive(0.5)));
        session.add_turn(SessionTurn::new("step2", "result2", Reward::Negative(0.1)));

        let trajectory = session.to_trajectory();
        let rewards: Vec<f32> = trajectory.iter().map(|step| step.2).collect();

        assert_eq!(trajectory.len(), 2);
        assert_eq!(trajectory[0].0, "step1");
        assert!(close(rewards[0], 0.5));
        assert!(close(rewards[1], -0.1));
    }
}