entrenar_shell/
state.rs

1//! Session state management for the interactive shell.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::PathBuf;
6
7/// Session state that persists across commands.
8#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
9pub struct SessionState {
10    /// Currently loaded models
11    models: HashMap<String, LoadedModel>,
12    /// Command history
13    history: Vec<HistoryEntry>,
14    /// User preferences
15    preferences: Preferences,
16    /// Session metrics
17    metrics: SessionMetrics,
18}
19
20impl SessionState {
21    /// Create a new empty session state.
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    /// Get loaded models.
27    pub fn loaded_models(&self) -> &HashMap<String, LoadedModel> {
28        &self.models
29    }
30
31    /// Get command history.
32    pub fn history(&self) -> &[HistoryEntry] {
33        &self.history
34    }
35
36    /// Add a model to the session.
37    pub fn add_model(&mut self, name: String, model: LoadedModel) {
38        self.models.insert(name, model);
39    }
40
41    /// Remove a model from the session.
42    pub fn remove_model(&mut self, name: &str) -> Option<LoadedModel> {
43        self.models.remove(name)
44    }
45
46    /// Get a model by name.
47    pub fn get_model(&self, name: &str) -> Option<&LoadedModel> {
48        self.models.get(name)
49    }
50
51    /// Add a command to history.
52    pub fn add_to_history(&mut self, entry: HistoryEntry) {
53        self.history.push(entry);
54    }
55
56    /// Get mutable preferences.
57    pub fn preferences_mut(&mut self) -> &mut Preferences {
58        &mut self.preferences
59    }
60
61    /// Get preferences.
62    pub fn preferences(&self) -> &Preferences {
63        &self.preferences
64    }
65
66    /// Get session metrics.
67    pub fn metrics(&self) -> &SessionMetrics {
68        &self.metrics
69    }
70
71    /// Update metrics after a command.
72    pub fn record_command(&mut self, duration_ms: u64, success: bool) {
73        self.metrics.total_commands += 1;
74        if success {
75            self.metrics.successful_commands += 1;
76        }
77        self.metrics.total_duration_ms += duration_ms;
78    }
79
80    /// Save state to a file.
81    pub fn save(&self, path: &PathBuf) -> std::io::Result<()> {
82        let json = serde_json::to_string_pretty(self)?;
83        std::fs::write(path, json)
84    }
85
86    /// Load state from a file.
87    pub fn load(path: &PathBuf) -> std::io::Result<Self> {
88        let json = std::fs::read_to_string(path)?;
89        serde_json::from_str(&json)
90            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
91    }
92}
93
94/// A loaded model in the session.
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
96pub struct LoadedModel {
97    /// Model identifier (HuggingFace ID or path)
98    pub id: String,
99    /// Local path to cached model
100    pub path: PathBuf,
101    /// Model architecture
102    pub architecture: String,
103    /// Number of parameters
104    pub parameters: u64,
105    /// Number of layers
106    pub layers: u32,
107    /// Hidden dimension
108    pub hidden_dim: u32,
109    /// Role in session (teacher/student)
110    pub role: ModelRole,
111}
112
113/// Role of a model in the distillation session.
114#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
115#[derive(Default)]
116pub enum ModelRole {
117    /// Teacher model (knowledge source)
118    Teacher,
119    /// Student model (learning target)
120    Student,
121    /// No specific role assigned
122    #[default]
123    None,
124}
125
126
127/// A history entry for a command.
128#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
129pub struct HistoryEntry {
130    /// The command string
131    pub command: String,
132    /// Execution timestamp (Unix seconds)
133    pub timestamp: u64,
134    /// Duration in milliseconds
135    pub duration_ms: u64,
136    /// Whether the command succeeded
137    pub success: bool,
138}
139
140impl HistoryEntry {
141    /// Create a new history entry.
142    pub fn new(command: impl Into<String>, duration_ms: u64, success: bool) -> Self {
143        Self {
144            command: command.into(),
145            timestamp: std::time::SystemTime::now()
146                .duration_since(std::time::UNIX_EPOCH)
147                .map(|d| d.as_secs())
148                .unwrap_or(0),
149            duration_ms,
150            success,
151        }
152    }
153}
154
155/// User preferences.
156#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
157pub struct Preferences {
158    /// Default output format
159    pub output_format: String,
160    /// Whether to show progress bars
161    pub show_progress: bool,
162    /// Whether to save history automatically
163    pub auto_save_history: bool,
164    /// Default batch size for operations
165    pub default_batch_size: u32,
166    /// Default sequence length
167    pub default_seq_len: usize,
168}
169
170impl Default for Preferences {
171    fn default() -> Self {
172        Self {
173            output_format: "table".to_string(),
174            show_progress: true,
175            auto_save_history: true,
176            default_batch_size: 32,
177            default_seq_len: 512,
178        }
179    }
180}
181
182/// Session-level metrics.
183#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
184pub struct SessionMetrics {
185    /// Total commands executed
186    pub total_commands: u64,
187    /// Successful commands
188    pub successful_commands: u64,
189    /// Total duration in milliseconds
190    pub total_duration_ms: u64,
191}
192
193impl SessionMetrics {
194    /// Get success rate as a percentage.
195    pub fn success_rate(&self) -> f64 {
196        if self.total_commands == 0 {
197            100.0
198        } else {
199            (self.successful_commands as f64 / self.total_commands as f64) * 100.0
200        }
201    }
202
203    /// Get average command duration in milliseconds.
204    pub fn avg_duration_ms(&self) -> f64 {
205        if self.total_commands == 0 {
206            0.0
207        } else {
208            self.total_duration_ms as f64 / self.total_commands as f64
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 7B-parameter llama-style model fixture with the given role.
    fn sample_model(role: ModelRole) -> LoadedModel {
        LoadedModel {
            id: "test/model".to_string(),
            path: PathBuf::from("/tmp/model"),
            architecture: "llama".to_string(),
            parameters: 7_000_000_000,
            layers: 32,
            hidden_dim: 4096,
            role,
        }
    }

    #[test]
    fn test_session_state_model_management() {
        let mut state = SessionState::new();
        let model = sample_model(ModelRole::Teacher);

        state.add_model("teacher".to_string(), model.clone());
        assert_eq!(state.loaded_models().len(), 1);
        assert_eq!(state.get_model("teacher"), Some(&model));

        // Removing hands back exactly the model that was stored.
        assert_eq!(state.remove_model("teacher"), Some(model));
        assert!(state.get_model("teacher").is_none());
        assert!(state.loaded_models().is_empty());
    }

    #[test]
    fn test_session_state_history() {
        let mut state = SessionState::new();

        state.add_to_history(HistoryEntry::new("fetch model", 100, true));
        state.add_to_history(HistoryEntry::new("inspect layers", 50, true));

        assert_eq!(state.history().len(), 2);
        assert_eq!(state.history()[0].command, "fetch model");
        assert_eq!(state.history()[1].command, "inspect layers");
    }

    #[test]
    fn test_session_metrics() {
        let mut state = SessionState::new();

        state.record_command(100, true);
        state.record_command(200, true);
        state.record_command(150, false);

        let metrics = state.metrics();
        assert_eq!(metrics.total_commands, 3);
        assert_eq!(metrics.successful_commands, 2);
        assert_eq!(metrics.total_duration_ms, 450);
        // 2 of 3 succeeded: compare against the exact ratio instead of a
        // rounded literal with a loose tolerance.
        assert!((metrics.success_rate() - 200.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_session_state_serialization_roundtrip() {
        let mut state = SessionState::new();
        state.add_model("student".to_string(), sample_model(ModelRole::Student));
        state.add_to_history(HistoryEntry::new("test", 100, true));
        state.record_command(100, true);
        state.preferences_mut().default_batch_size = 64;

        let json = serde_json::to_string(&state).unwrap();
        let restored: SessionState = serde_json::from_str(&json).unwrap();

        assert_eq!(state, restored);
    }

    #[test]
    fn test_model_role_default() {
        assert_eq!(ModelRole::default(), ModelRole::None);
    }

    #[test]
    fn test_preferences_default_values() {
        let prefs = Preferences::default();
        assert_eq!(prefs.output_format, "table");
        assert!(prefs.show_progress);
        assert_eq!(prefs.default_batch_size, 32);
    }

    #[test]
    fn test_session_metrics_success_rate_zero() {
        let metrics = SessionMetrics::default();
        assert_eq!(metrics.success_rate(), 100.0);
    }

    #[test]
    fn test_session_metrics_avg_duration_zero() {
        let metrics = SessionMetrics::default();
        assert_eq!(metrics.avg_duration_ms(), 0.0);
    }

    #[test]
    fn test_session_metrics_avg_duration() {
        let mut state = SessionState::new();
        state.record_command(100, true);
        state.record_command(200, true);
        assert_eq!(state.metrics().avg_duration_ms(), 150.0);
    }

    #[test]
    fn test_history_entry_new() {
        let entry = HistoryEntry::new("test command", 50, true);
        assert_eq!(entry.command, "test command");
        assert_eq!(entry.duration_ms, 50);
        assert!(entry.success);
        assert!(entry.timestamp > 0);
    }

    #[test]
    fn test_loaded_model_equality() {
        let model1 = sample_model(ModelRole::None);
        let model2 = model1.clone();
        assert_eq!(model1, model2);
    }

    #[test]
    fn test_model_role_equality() {
        assert_eq!(ModelRole::Teacher, ModelRole::Teacher);
        assert_ne!(ModelRole::Teacher, ModelRole::Student);
        assert_ne!(ModelRole::Student, ModelRole::None);
    }

    #[test]
    fn test_session_state_save_load() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("state.json");

        let mut state = SessionState::new();
        state.add_model("teacher".to_string(), sample_model(ModelRole::Teacher));
        state.add_to_history(HistoryEntry::new("test", 100, true));
        state.record_command(100, true);
        state.preferences_mut().default_batch_size = 128;

        state.save(&state_path).unwrap();
        let loaded = SessionState::load(&state_path).unwrap();

        assert_eq!(state, loaded);
    }

    #[test]
    fn test_session_state_load_invalid_json() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let mut file = NamedTempFile::new().unwrap();
        file.write_all(b"not valid json").unwrap();
        file.flush().unwrap();

        let err = SessionState::load(&file.path().to_path_buf()).unwrap_err();
        // Parse failures surface as InvalidData, distinct from read failures.
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_session_state_load_missing_file() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("does_not_exist.json");

        let err = SessionState::load(&missing).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }

    #[test]
    fn test_preferences_all_fields() {
        let prefs = Preferences::default();
        assert_eq!(prefs.output_format, "table");
        assert!(prefs.show_progress);
        assert!(prefs.auto_save_history);
        assert_eq!(prefs.default_batch_size, 32);
        assert_eq!(prefs.default_seq_len, 512);
    }

    #[test]
    fn test_session_state_remove_nonexistent() {
        let mut state = SessionState::new();
        let result = state.remove_model("nonexistent");
        assert!(result.is_none());
    }

    #[test]
    fn test_session_state_get_nonexistent() {
        let state = SessionState::new();
        assert!(state.get_model("nonexistent").is_none());
    }

    #[test]
    fn test_session_metrics_fields() {
        let metrics = SessionMetrics {
            total_commands: 10,
            successful_commands: 8,
            total_duration_ms: 1000,
        };
        assert_eq!(metrics.success_rate(), 80.0);
        assert_eq!(metrics.avg_duration_ms(), 100.0);
    }
}