1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::PathBuf;
6
7#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
9pub struct SessionState {
10 models: HashMap<String, LoadedModel>,
12 history: Vec<HistoryEntry>,
14 preferences: Preferences,
16 metrics: SessionMetrics,
18}
19
20impl SessionState {
21 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn loaded_models(&self) -> &HashMap<String, LoadedModel> {
28 &self.models
29 }
30
31 pub fn history(&self) -> &[HistoryEntry] {
33 &self.history
34 }
35
36 pub fn add_model(&mut self, name: String, model: LoadedModel) {
38 self.models.insert(name, model);
39 }
40
41 pub fn remove_model(&mut self, name: &str) -> Option<LoadedModel> {
43 self.models.remove(name)
44 }
45
46 pub fn get_model(&self, name: &str) -> Option<&LoadedModel> {
48 self.models.get(name)
49 }
50
51 pub fn add_to_history(&mut self, entry: HistoryEntry) {
53 self.history.push(entry);
54 }
55
56 pub fn preferences_mut(&mut self) -> &mut Preferences {
58 &mut self.preferences
59 }
60
61 pub fn preferences(&self) -> &Preferences {
63 &self.preferences
64 }
65
66 pub fn metrics(&self) -> &SessionMetrics {
68 &self.metrics
69 }
70
71 pub fn record_command(&mut self, duration_ms: u64, success: bool) {
73 self.metrics.total_commands += 1;
74 if success {
75 self.metrics.successful_commands += 1;
76 }
77 self.metrics.total_duration_ms += duration_ms;
78 }
79
80 pub fn save(&self, path: &PathBuf) -> std::io::Result<()> {
82 let json = serde_json::to_string_pretty(self)?;
83 std::fs::write(path, json)
84 }
85
86 pub fn load(path: &PathBuf) -> std::io::Result<Self> {
88 let json = std::fs::read_to_string(path)?;
89 serde_json::from_str(&json)
90 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
91 }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
96pub struct LoadedModel {
97 pub id: String,
99 pub path: PathBuf,
101 pub architecture: String,
103 pub parameters: u64,
105 pub layers: u32,
107 pub hidden_dim: u32,
109 pub role: ModelRole,
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
115#[derive(Default)]
116pub enum ModelRole {
117 Teacher,
119 Student,
121 #[default]
123 None,
124}
125
126
127#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
129pub struct HistoryEntry {
130 pub command: String,
132 pub timestamp: u64,
134 pub duration_ms: u64,
136 pub success: bool,
138}
139
140impl HistoryEntry {
141 pub fn new(command: impl Into<String>, duration_ms: u64, success: bool) -> Self {
143 Self {
144 command: command.into(),
145 timestamp: std::time::SystemTime::now()
146 .duration_since(std::time::UNIX_EPOCH)
147 .map(|d| d.as_secs())
148 .unwrap_or(0),
149 duration_ms,
150 success,
151 }
152 }
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
157pub struct Preferences {
158 pub output_format: String,
160 pub show_progress: bool,
162 pub auto_save_history: bool,
164 pub default_batch_size: u32,
166 pub default_seq_len: usize,
168}
169
170impl Default for Preferences {
171 fn default() -> Self {
172 Self {
173 output_format: "table".to_string(),
174 show_progress: true,
175 auto_save_history: true,
176 default_batch_size: 32,
177 default_seq_len: 512,
178 }
179 }
180}
181
182#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
184pub struct SessionMetrics {
185 pub total_commands: u64,
187 pub successful_commands: u64,
189 pub total_duration_ms: u64,
191}
192
193impl SessionMetrics {
194 pub fn success_rate(&self) -> f64 {
196 if self.total_commands == 0 {
197 100.0
198 } else {
199 (self.successful_commands as f64 / self.total_commands as f64) * 100.0
200 }
201 }
202
203 pub fn avg_duration_ms(&self) -> f64 {
205 if self.total_commands == 0 {
206 0.0
207 } else {
208 self.total_duration_ms as f64 / self.total_commands as f64
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Test fixture: a llama-style 7B model with the given id and role.
    fn sample_model(id: &str, role: ModelRole) -> LoadedModel {
        LoadedModel {
            id: id.to_string(),
            path: PathBuf::from("/tmp/model"),
            architecture: "llama".to_string(),
            parameters: 7_000_000_000,
            layers: 32,
            hidden_dim: 4096,
            role,
        }
    }

    #[test]
    fn test_session_state_model_management() {
        let mut state = SessionState::new();
        let model = sample_model("test/model", ModelRole::Teacher);

        state.add_model("teacher".to_string(), model.clone());
        assert_eq!(state.loaded_models().len(), 1);
        assert_eq!(state.get_model("teacher"), Some(&model));

        // Re-registering under the same name replaces instead of duplicating.
        state.add_model("teacher".to_string(), model.clone());
        assert_eq!(state.loaded_models().len(), 1);

        // Removal hands back exactly the stored model and empties the map.
        assert_eq!(state.remove_model("teacher"), Some(model));
        assert!(state.get_model("teacher").is_none());
        assert!(state.loaded_models().is_empty());
    }

    #[test]
    fn test_session_state_history() {
        let mut state = SessionState::new();

        state.add_to_history(HistoryEntry::new("fetch model", 100, true));
        state.add_to_history(HistoryEntry::new("inspect layers", 50, true));

        let commands: Vec<&str> = state.history().iter().map(|e| e.command.as_str()).collect();
        assert_eq!(commands, ["fetch model", "inspect layers"]);
    }

    #[test]
    fn test_session_metrics() {
        let mut state = SessionState::new();

        state.record_command(100, true);
        state.record_command(200, true);
        state.record_command(150, false);

        let metrics = state.metrics();
        assert_eq!(metrics.total_commands, 3);
        assert_eq!(metrics.successful_commands, 2);
        assert_eq!(metrics.total_duration_ms, 450);
        // 2 of 3 succeeded. A tolerance of 1.0 would also accept e.g. 66.0,
        // so pin the value to floating-point precision instead.
        assert!((metrics.success_rate() - 200.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_session_state_serialization_roundtrip() {
        let mut state = SessionState::new();
        state.add_model("student".to_string(), sample_model("s", ModelRole::Student));
        state.add_to_history(HistoryEntry::new("test", 100, true));
        state.preferences_mut().default_batch_size = 64;
        state.record_command(100, true);

        let json = serde_json::to_string(&state).unwrap();
        let restored: SessionState = serde_json::from_str(&json).unwrap();

        assert_eq!(state, restored);
    }

    #[test]
    fn test_model_role_default() {
        assert_eq!(ModelRole::default(), ModelRole::None);
    }

    #[test]
    fn test_preferences_default_values() {
        let prefs = Preferences::default();
        assert_eq!(prefs.output_format, "table");
        assert!(prefs.show_progress);
        assert_eq!(prefs.default_batch_size, 32);
    }

    #[test]
    fn test_session_metrics_success_rate_zero() {
        let metrics = SessionMetrics::default();
        assert_eq!(metrics.success_rate(), 100.0);
    }

    #[test]
    fn test_session_metrics_avg_duration_zero() {
        let metrics = SessionMetrics::default();
        assert_eq!(metrics.avg_duration_ms(), 0.0);
    }

    #[test]
    fn test_session_metrics_avg_duration() {
        let mut state = SessionState::new();
        state.record_command(100, true);
        state.record_command(200, true);
        assert_eq!(state.metrics().avg_duration_ms(), 150.0);
    }

    #[test]
    fn test_history_entry_new() {
        let entry = HistoryEntry::new("test command", 50, true);
        assert_eq!(entry.command, "test command");
        assert_eq!(entry.duration_ms, 50);
        assert!(entry.success);
        assert!(entry.timestamp > 0);
    }

    #[test]
    fn test_loaded_model_equality() {
        let model1 = sample_model("test", ModelRole::None);
        let model2 = model1.clone();
        assert_eq!(model1, model2);

        let other_role = LoadedModel { role: ModelRole::Teacher, ..model1.clone() };
        assert_ne!(model1, other_role);
    }

    #[test]
    fn test_model_role_equality() {
        assert_eq!(ModelRole::Teacher, ModelRole::Teacher);
        assert_ne!(ModelRole::Teacher, ModelRole::Student);
        assert_ne!(ModelRole::Student, ModelRole::None);
    }

    #[test]
    fn test_session_state_save_load() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("state.json");

        let mut state = SessionState::new();
        state.add_to_history(HistoryEntry::new("test", 100, true));
        state.preferences_mut().default_batch_size = 128;

        state.save(&state_path).unwrap();
        let loaded = SessionState::load(&state_path).unwrap();

        assert_eq!(state, loaded);
        assert_eq!(loaded.preferences().default_batch_size, 128);
    }

    #[test]
    fn test_session_state_load_invalid_json() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let mut file = NamedTempFile::new().unwrap();
        file.write_all(b"not valid json").unwrap();

        let err = SessionState::load(&file.path().to_path_buf()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_session_state_load_missing_file() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let err = SessionState::load(&temp_dir.path().join("missing.json")).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }

    #[test]
    fn test_preferences_all_fields() {
        let prefs = Preferences::default();
        assert_eq!(prefs.output_format, "table");
        assert!(prefs.show_progress);
        assert!(prefs.auto_save_history);
        assert_eq!(prefs.default_batch_size, 32);
        assert_eq!(prefs.default_seq_len, 512);
    }

    #[test]
    fn test_session_state_remove_nonexistent() {
        let mut state = SessionState::new();
        assert!(state.remove_model("nonexistent").is_none());
    }

    #[test]
    fn test_session_state_get_nonexistent() {
        let state = SessionState::new();
        assert!(state.get_model("nonexistent").is_none());
    }

    #[test]
    fn test_session_metrics_fields() {
        let metrics = SessionMetrics {
            total_commands: 10,
            successful_commands: 8,
            total_duration_ms: 1000,
        };
        assert_eq!(metrics.success_rate(), 80.0);
        assert_eq!(metrics.avg_duration_ms(), 100.0);
    }
}
401}