entrenar_shell/
commands.rs

1//! Command parsing and execution for the REPL.
2
3use crate::state::{HistoryEntry, LoadedModel, ModelRole, SessionState};
4use entrenar_common::{EntrenarError, Result};
5
/// A parsed command.
///
/// Produced by [`parse`] and run by [`execute`]. The derived `Debug` representation is
/// what [`execute`] stores as the command text in the session history, so renaming a
/// variant or field also changes the recorded history text.
#[derive(Debug, Clone, PartialEq)]
pub enum Command {
    /// Fetch a model from HuggingFace (`fetch` / `download`); `role` comes from the
    /// `--teacher` / `--student` flags. Execution is currently simulated.
    Fetch { model_id: String, role: ModelRole },
    /// Inspect a loaded model (`inspect` / `show`)
    Inspect { target: InspectTarget },
    /// Estimate memory requirements (`memory` / `mem`); a `None` field falls back to the
    /// session's default preference when executed.
    Memory {
        batch_size: Option<u32>,
        seq_len: Option<usize>,
    },
    /// Set configuration values (`set`); `value` is the remaining arguments joined by spaces
    Set { key: String, value: String },
    /// Run distillation (`distill` / `train`); `dry_run` is set by `--dry-run` or `-n`
    Distill { dry_run: bool },
    /// Export model to file (`export` / `save`); execution currently only returns a
    /// confirmation message.
    Export { format: String, path: String },
    /// Show command history (`history` / `hist`)
    History,
    /// Show help (`help` / `?`), optionally for a single topic
    Help { topic: Option<String> },
    /// Clear screen (`clear` / `cls`)
    Clear,
    /// Quit the shell (`quit` / `exit` / `q`)
    Quit,
    /// Unknown command; holds the trimmed input line, or an empty string for blank input
    Unknown { input: String },
}
35
/// Target for inspect command.
///
/// Parsed case-insensitively from the first argument of `inspect`; no argument means `All`.
#[derive(Debug, Clone, PartialEq)]
pub enum InspectTarget {
    /// Inspect layer structure (`layers` / `layer`)
    Layers,
    /// Inspect memory usage (`memory` / `mem`), using the session's default batch/sequence sizes
    Memory,
    /// Inspect all info (`all`, or no argument at all)
    All,
    /// Inspect specific model by name: any other word, as lower-cased by the parser
    Model(String),
}
48
49/// Parse a command string into a Command.
50pub fn parse(input: &str) -> Result<Command> {
51    let input = input.trim();
52    if input.is_empty() {
53        return Ok(Command::Unknown {
54            input: String::new(),
55        });
56    }
57
58    let parts: Vec<&str> = input.split_whitespace().collect();
59    let cmd = parts[0].to_lowercase();
60    let args = &parts[1..];
61
62    match cmd.as_str() {
63        "fetch" | "download" => parse_fetch(args),
64        "inspect" | "show" => parse_inspect(args),
65        "memory" | "mem" => parse_memory(args),
66        "set" => parse_set(args),
67        "distill" | "train" => parse_distill(args),
68        "export" | "save" => parse_export(args),
69        "history" | "hist" => Ok(Command::History),
70        "help" | "?" => parse_help(args),
71        "clear" | "cls" => Ok(Command::Clear),
72        "quit" | "exit" | "q" => Ok(Command::Quit),
73        _ => Ok(Command::Unknown {
74            input: input.to_string(),
75        }),
76    }
77}
78
79fn parse_fetch(args: &[&str]) -> Result<Command> {
80    if args.is_empty() {
81        return Err(EntrenarError::ConfigValue {
82            field: "model_id".into(),
83            message: "No model ID provided".into(),
84            suggestion: "Usage: fetch <model_id> [--teacher|--student]".into(),
85        });
86    }
87
88    let model_id = args[0].to_string();
89    let role = if args.contains(&"--teacher") {
90        ModelRole::Teacher
91    } else if args.contains(&"--student") {
92        ModelRole::Student
93    } else {
94        ModelRole::None
95    };
96
97    Ok(Command::Fetch { model_id, role })
98}
99
100fn parse_inspect(args: &[&str]) -> Result<Command> {
101    let target = if args.is_empty() {
102        InspectTarget::All
103    } else {
104        match args[0].to_lowercase().as_str() {
105            "layers" | "layer" => InspectTarget::Layers,
106            "memory" | "mem" => InspectTarget::Memory,
107            "all" => InspectTarget::All,
108            name => InspectTarget::Model(name.to_string()),
109        }
110    };
111
112    Ok(Command::Inspect { target })
113}
114
115fn parse_memory(args: &[&str]) -> Result<Command> {
116    let mut batch_size = None;
117    let mut seq_len = None;
118
119    let mut i = 0;
120    while i < args.len() {
121        match args[i] {
122            "--batch" | "-b" if i + 1 < args.len() => {
123                batch_size = args[i + 1].parse().ok();
124                i += 2;
125            }
126            "--seq" | "-s" if i + 1 < args.len() => {
127                seq_len = args[i + 1].parse().ok();
128                i += 2;
129            }
130            _ => i += 1,
131        }
132    }
133
134    Ok(Command::Memory {
135        batch_size,
136        seq_len,
137    })
138}
139
140fn parse_set(args: &[&str]) -> Result<Command> {
141    if args.len() < 2 {
142        return Err(EntrenarError::ConfigValue {
143            field: "set".into(),
144            message: "Not enough arguments".into(),
145            suggestion: "Usage: set <key> <value>".into(),
146        });
147    }
148
149    Ok(Command::Set {
150        key: args[0].to_string(),
151        value: args[1..].join(" "),
152    })
153}
154
155fn parse_distill(args: &[&str]) -> Result<Command> {
156    let dry_run = args.contains(&"--dry-run") || args.contains(&"-n");
157    Ok(Command::Distill { dry_run })
158}
159
160fn parse_export(args: &[&str]) -> Result<Command> {
161    if args.len() < 2 {
162        return Err(EntrenarError::ConfigValue {
163            field: "export".into(),
164            message: "Not enough arguments".into(),
165            suggestion: "Usage: export <format> <path>".into(),
166        });
167    }
168
169    Ok(Command::Export {
170        format: args[0].to_string(),
171        path: args[1].to_string(),
172    })
173}
174
175fn parse_help(args: &[&str]) -> Result<Command> {
176    let topic = args.first().map(ToString::to_string);
177    Ok(Command::Help { topic })
178}
179
180/// Execute a command and update state.
181pub fn execute(cmd: &Command, state: &mut SessionState) -> Result<String> {
182    let start = std::time::Instant::now();
183
184    let result = match cmd {
185        Command::Fetch { model_id, role } => execute_fetch(model_id, *role, state),
186        Command::Inspect { target } => execute_inspect(target, state),
187        Command::Memory {
188            batch_size,
189            seq_len,
190        } => execute_memory(*batch_size, *seq_len, state),
191        Command::Set { key, value } => execute_set(key, value, state),
192        Command::Distill { dry_run } => execute_distill(*dry_run, state),
193        Command::Export { format, path } => execute_export(format, path, state),
194        Command::History => execute_history(state),
195        Command::Help { topic } => execute_help(topic.as_deref()),
196        Command::Clear => Ok(String::new()),
197        Command::Quit => Ok("Goodbye!".to_string()),
198        Command::Unknown { input } => {
199            if input.is_empty() {
200                Ok(String::new())
201            } else {
202                Err(EntrenarError::ConfigValue {
203                    field: "command".into(),
204                    message: format!("Unknown command: {input}"),
205                    suggestion: "Type 'help' for available commands".into(),
206                })
207            }
208        }
209    };
210
211    let duration_ms = start.elapsed().as_millis() as u64;
212    let success = result.is_ok();
213
214    // Record in history (except for help/history/clear/quit)
215    if !matches!(
216        cmd,
217        Command::Help { .. }
218            | Command::History
219            | Command::Clear
220            | Command::Quit
221            | Command::Unknown { .. }
222    ) {
223        let cmd_str = format!("{cmd:?}");
224        state.add_to_history(HistoryEntry::new(cmd_str, duration_ms, success));
225        state.record_command(duration_ms, success);
226    }
227
228    result
229}
230
231fn execute_fetch(model_id: &str, role: ModelRole, state: &mut SessionState) -> Result<String> {
232    // Simulate model fetching
233    let model = LoadedModel {
234        id: model_id.to_string(),
235        path: std::path::PathBuf::from(format!("/tmp/models/{}", model_id.replace('/', "_"))),
236        architecture: detect_architecture(model_id),
237        parameters: estimate_params(model_id),
238        layers: estimate_layers(model_id),
239        hidden_dim: 4096,
240        role,
241    };
242
243    let name = if role == ModelRole::Teacher {
244        "teacher"
245    } else if role == ModelRole::Student {
246        "student"
247    } else {
248        model_id.split('/').next_back().unwrap_or(model_id)
249    };
250
251    state.add_model(name.to_string(), model.clone());
252
253    Ok(format!(
254        "✓ Fetched {}\n  Architecture: {}\n  Parameters: {:.1}B\n  Layers: {}",
255        model_id,
256        model.architecture,
257        model.parameters as f64 / 1e9,
258        model.layers
259    ))
260}
261
262fn execute_inspect(target: &InspectTarget, state: &SessionState) -> Result<String> {
263    match target {
264        InspectTarget::All => {
265            if state.loaded_models().is_empty() {
266                return Ok("No models loaded. Use 'fetch <model_id>' to load a model.".to_string());
267            }
268
269            let mut output = String::from("Loaded Models:\n");
270            for (name, model) in state.loaded_models() {
271                output.push_str(&format!(
272                    "  {} ({}): {:.1}B params, {} layers\n",
273                    name,
274                    model.id,
275                    model.parameters as f64 / 1e9,
276                    model.layers
277                ));
278            }
279            Ok(output)
280        }
281        InspectTarget::Layers => {
282            let mut output = String::from("Layer Analysis:\n");
283            for (name, model) in state.loaded_models() {
284                output.push_str(&format!(
285                    "  {}: {} layers, hidden_dim={}\n",
286                    name, model.layers, model.hidden_dim
287                ));
288            }
289            Ok(output)
290        }
291        InspectTarget::Memory => execute_memory(None, None, state),
292        InspectTarget::Model(name) => {
293            if let Some(model) = state.get_model(name) {
294                Ok(format!(
295                    "Model: {}\n  ID: {}\n  Path: {}\n  Architecture: {}\n  Parameters: {:.1}B\n  Layers: {}\n  Hidden Dim: {}",
296                    name, model.id, model.path.display(), model.architecture,
297                    model.parameters as f64 / 1e9, model.layers, model.hidden_dim
298                ))
299            } else {
300                Err(EntrenarError::ModelNotFound { path: name.into() })
301            }
302        }
303    }
304}
305
306fn execute_memory(
307    batch_size: Option<u32>,
308    seq_len: Option<usize>,
309    state: &SessionState,
310) -> Result<String> {
311    let batch = batch_size.unwrap_or(state.preferences().default_batch_size);
312    let seq = seq_len.unwrap_or(state.preferences().default_seq_len);
313
314    let total_params: u64 = state.loaded_models().values().map(|m| m.parameters).sum();
315    let model_mem = total_params * 2; // FP16
316    let activation_mem = u64::from(batch) * (seq as u64) * 4096 * 32 * 2;
317    let total = model_mem + activation_mem;
318
319    Ok(format!(
320        "Memory Estimate (batch={}, seq={}):\n  Model: {:.1} GB\n  Activations: {:.1} GB\n  Total: {:.1} GB",
321        batch, seq,
322        model_mem as f64 / 1e9,
323        activation_mem as f64 / 1e9,
324        total as f64 / 1e9
325    ))
326}
327
328fn execute_set(key: &str, value: &str, state: &mut SessionState) -> Result<String> {
329    match key {
330        "batch_size" | "batch" => {
331            let v: u32 = value.parse().map_err(|_| EntrenarError::ConfigValue {
332                field: "batch_size".into(),
333                message: "Invalid number".into(),
334                suggestion: "Use a positive integer".into(),
335            })?;
336            state.preferences_mut().default_batch_size = v;
337            Ok(format!("Set batch_size = {v}"))
338        }
339        "seq_len" | "seq" => {
340            let v: usize = value.parse().map_err(|_| EntrenarError::ConfigValue {
341                field: "seq_len".into(),
342                message: "Invalid number".into(),
343                suggestion: "Use a positive integer".into(),
344            })?;
345            state.preferences_mut().default_seq_len = v;
346            Ok(format!("Set seq_len = {v}"))
347        }
348        _ => Err(EntrenarError::ConfigValue {
349            field: key.into(),
350            message: "Unknown setting".into(),
351            suggestion: "Available settings: batch_size, seq_len".into(),
352        }),
353    }
354}
355
356fn execute_distill(dry_run: bool, state: &SessionState) -> Result<String> {
357    let teacher = state
358        .loaded_models()
359        .values()
360        .find(|m| m.role == ModelRole::Teacher);
361    let student = state
362        .loaded_models()
363        .values()
364        .find(|m| m.role == ModelRole::Student);
365
366    if teacher.is_none() {
367        return Err(EntrenarError::ConfigValue {
368            field: "teacher".into(),
369            message: "No teacher model loaded".into(),
370            suggestion: "Use 'fetch <model_id> --teacher' to load a teacher model".into(),
371        });
372    }
373
374    if student.is_none() {
375        return Err(EntrenarError::ConfigValue {
376            field: "student".into(),
377            message: "No student model loaded".into(),
378            suggestion: "Use 'fetch <model_id> --student' to load a student model".into(),
379        });
380    }
381
382    if dry_run {
383        Ok(format!(
384            "Dry run configuration:\n  Teacher: {} ({:.1}B)\n  Student: {} ({:.1}B)\n  Ready to train",
385            teacher.unwrap().id, teacher.unwrap().parameters as f64 / 1e9,
386            student.unwrap().id, student.unwrap().parameters as f64 / 1e9
387        ))
388    } else {
389        Ok("Training started... (simulated)".to_string())
390    }
391}
392
393fn execute_export(format: &str, path: &str, _state: &SessionState) -> Result<String> {
394    Ok(format!("Exported to {path} in {format} format"))
395}
396
397fn execute_history(state: &SessionState) -> Result<String> {
398    if state.history().is_empty() {
399        return Ok("No command history.".to_string());
400    }
401
402    let mut output = String::from("Command History:\n");
403    for (i, entry) in state.history().iter().enumerate() {
404        let status = if entry.success { "✓" } else { "✗" };
405        output.push_str(&format!(
406            "  {}. {} {} ({}ms)\n",
407            i + 1,
408            status,
409            entry.command,
410            entry.duration_ms
411        ));
412    }
413    Ok(output)
414}
415
416fn execute_help(topic: Option<&str>) -> Result<String> {
417    match topic {
418        Some("fetch") => Ok(
419            "fetch <model_id> [--teacher|--student]\n  Download a model from HuggingFace"
420                .to_string(),
421        ),
422        Some("inspect") => {
423            Ok("inspect [layers|memory|all|<model>]\n  Inspect loaded models".to_string())
424        }
425        Some("memory") => {
426            Ok("memory [--batch <n>] [--seq <n>]\n  Estimate memory requirements".to_string())
427        }
428        Some("distill") => Ok("distill [--dry-run]\n  Run distillation training".to_string()),
429        _ => Ok("Available commands:
430  fetch <model>      Download model from HuggingFace
431  inspect [target]   Inspect loaded models
432  memory             Estimate memory requirements
433  set <key> <value>  Configure settings
434  distill            Run distillation
435  export <fmt> <path> Export model
436  history            Show command history
437  help [topic]       Show help
438  quit               Exit shell"
439            .to_string()),
440    }
441}
442
/// Guesses a model family from its ID by case-insensitive substring search.
///
/// Families are tried in priority order (`llama`, `bert`, `gpt`, `mistral`) and the first
/// one found in the ID wins; IDs matching none of them are reported as `"unknown"`.
fn detect_architecture(model_id: &str) -> String {
    const FAMILIES: [&str; 4] = ["llama", "bert", "gpt", "mistral"];

    let haystack = model_id.to_lowercase();
    FAMILIES
        .iter()
        .copied()
        .find(|family| haystack.contains(*family))
        .unwrap_or("unknown")
        .to_string()
}
457
/// Estimates a model's parameter count from a size tag in its ID (`70b`, `13b`, `7b`,
/// `1.1b`/`1b`) or, failing that, from a `base` suffix; otherwise assumes 1B parameters.
///
/// Size tags are matched case-insensitively and only as whole numbers, so `27b` is not
/// mistaken for `7b` and `11b` is not mistaken for `1b`.
fn estimate_params(model_id: &str) -> u64 {
    let lower = model_id.to_lowercase();
    if has_size_tag(&lower, "70b") {
        70_000_000_000
    } else if has_size_tag(&lower, "13b") {
        13_000_000_000
    } else if has_size_tag(&lower, "7b") {
        7_000_000_000
    } else if has_size_tag(&lower, "1.1b") || has_size_tag(&lower, "1b") {
        1_100_000_000
    } else if lower.contains("base") {
        350_000_000
    } else {
        1_000_000_000
    }
}

/// Estimates a model's layer count from the same size tags as [`estimate_params`]
/// (`70b` → 80, `13b` → 40, `7b` → 32, `base` → 12); anything else is assumed to have 24.
fn estimate_layers(model_id: &str) -> u32 {
    let lower = model_id.to_lowercase();
    if has_size_tag(&lower, "70b") {
        80
    } else if has_size_tag(&lower, "13b") {
        40
    } else if has_size_tag(&lower, "7b") {
        32
    } else if lower.contains("base") {
        12
    } else {
        24
    }
}

/// Returns true when `tag` (e.g. `"7b"`) occurs in `haystack` without a digit or `.`
/// immediately before it, i.e. not as the tail of a longer number such as `27b` or `1.7b`.
fn has_size_tag(haystack: &str, tag: &str) -> bool {
    haystack.match_indices(tag).any(|(start, _)| {
        // `start` comes from `match_indices`, so it is always a char boundary.
        !matches!(
            haystack[..start].chars().next_back(),
            Some(c) if c.is_ascii_digit() || c == '.'
        )
    })
}
489
// Unit tests for command parsing (`parse` and its `parse_*` helpers) and for the
// `execute_*` handlers, which are called directly against a fresh `SessionState`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_fetch() {
        let cmd = parse("fetch meta-llama/Llama-2-7b --teacher").unwrap();
        assert!(matches!(
            cmd,
            Command::Fetch {
                role: ModelRole::Teacher,
                ..
            }
        ));

        let cmd = parse("fetch TinyLlama/TinyLlama-1.1B --student").unwrap();
        assert!(matches!(
            cmd,
            Command::Fetch {
                role: ModelRole::Student,
                ..
            }
        ));
    }

    #[test]
    fn test_parse_inspect() {
        assert!(matches!(
            parse("inspect").unwrap(),
            Command::Inspect {
                target: InspectTarget::All
            }
        ));
        assert!(matches!(
            parse("inspect layers").unwrap(),
            Command::Inspect {
                target: InspectTarget::Layers
            }
        ));
    }

    #[test]
    fn test_parse_memory() {
        let cmd = parse("memory --batch 64 --seq 1024").unwrap();
        if let Command::Memory {
            batch_size,
            seq_len,
        } = cmd
        {
            assert_eq!(batch_size, Some(64));
            assert_eq!(seq_len, Some(1024));
        } else {
            panic!("Expected Memory command");
        }
    }

    #[test]
    fn test_parse_quit_variants() {
        assert!(matches!(parse("quit").unwrap(), Command::Quit));
        assert!(matches!(parse("exit").unwrap(), Command::Quit));
        assert!(matches!(parse("q").unwrap(), Command::Quit));
    }

    #[test]
    fn test_execute_fetch() {
        let mut state = SessionState::new();
        let result = execute_fetch("meta-llama/Llama-2-7b", ModelRole::Teacher, &mut state);

        assert!(result.is_ok());
        // Role-tagged models are registered under the role name, not the model ID.
        assert!(state.get_model("teacher").is_some());
    }

    #[test]
    fn test_execute_set() {
        let mut state = SessionState::new();

        execute_set("batch_size", "64", &mut state).unwrap();
        assert_eq!(state.preferences().default_batch_size, 64);

        execute_set("seq_len", "1024", &mut state).unwrap();
        assert_eq!(state.preferences().default_seq_len, 1024);
    }

    #[test]
    fn test_unknown_command() {
        let cmd = parse("foobar").unwrap();
        assert!(matches!(cmd, Command::Unknown { .. }));
    }

    #[test]
    fn test_parse_fetch_missing_model() {
        let result = parse("fetch");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_set_not_enough_args() {
        let result = parse("set batch_size");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_export_not_enough_args() {
        let result = parse("export safetensors");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_export_valid() {
        let cmd = parse("export safetensors /tmp/model.st").unwrap();
        if let Command::Export { format, path } = cmd {
            assert_eq!(format, "safetensors");
            assert_eq!(path, "/tmp/model.st");
        } else {
            panic!("Expected Export command");
        }
    }

    #[test]
    fn test_parse_help_with_topic() {
        let cmd = parse("help fetch").unwrap();
        if let Command::Help { topic } = cmd {
            assert_eq!(topic, Some("fetch".to_string()));
        } else {
            panic!("Expected Help command");
        }
    }

    #[test]
    fn test_parse_distill_dry_run() {
        let cmd = parse("distill --dry-run").unwrap();
        if let Command::Distill { dry_run } = cmd {
            assert!(dry_run);
        } else {
            panic!("Expected Distill command");
        }
    }

    #[test]
    fn test_parse_distill_short_flag() {
        let cmd = parse("distill -n").unwrap();
        if let Command::Distill { dry_run } = cmd {
            assert!(dry_run);
        } else {
            panic!("Expected Distill command");
        }
    }

    #[test]
    fn test_parse_inspect_model() {
        let cmd = parse("inspect teacher").unwrap();
        if let Command::Inspect { target } = cmd {
            assert_eq!(target, InspectTarget::Model("teacher".to_string()));
        } else {
            panic!("Expected Inspect command");
        }
    }

    #[test]
    fn test_parse_inspect_memory() {
        let cmd = parse("inspect memory").unwrap();
        assert!(matches!(
            cmd,
            Command::Inspect {
                target: InspectTarget::Memory
            }
        ));
    }

    #[test]
    fn test_parse_command_aliases() {
        // download = fetch
        assert!(matches!(
            parse("download model").unwrap(),
            Command::Fetch { .. }
        ));
        // show = inspect
        assert!(matches!(
            parse("show layers").unwrap(),
            Command::Inspect { .. }
        ));
        // mem = memory
        assert!(matches!(parse("mem").unwrap(), Command::Memory { .. }));
        // train = distill
        assert!(matches!(parse("train").unwrap(), Command::Distill { .. }));
        // save = export (needs args)
        assert!(matches!(
            parse("save gguf /tmp/out").unwrap(),
            Command::Export { .. }
        ));
        // cls = clear
        assert!(matches!(parse("cls").unwrap(), Command::Clear));
        // ? = help
        assert!(matches!(parse("?").unwrap(), Command::Help { .. }));
        // hist = history
        assert!(matches!(parse("hist").unwrap(), Command::History));
    }

    #[test]
    fn test_execute_inspect_no_models() {
        let state = SessionState::new();
        let result = execute_inspect(&InspectTarget::All, &state);
        assert!(result.unwrap().contains("No models loaded"));
    }

    #[test]
    fn test_execute_inspect_layers() {
        let mut state = SessionState::new();
        let model = LoadedModel {
            id: "test".to_string(),
            path: std::path::PathBuf::from("/tmp"),
            architecture: "llama".to_string(),
            parameters: 7_000_000_000,
            layers: 32,
            hidden_dim: 4096,
            role: ModelRole::None,
        };
        state.add_model("test".to_string(), model);

        let result = execute_inspect(&InspectTarget::Layers, &state).unwrap();
        assert!(result.contains("32 layers"));
    }

    #[test]
    fn test_execute_inspect_model_not_found() {
        let state = SessionState::new();
        let result = execute_inspect(&InspectTarget::Model("unknown".to_string()), &state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_history_empty() {
        let state = SessionState::new();
        let result = execute_history(&state).unwrap();
        assert!(result.contains("No command history"));
    }

    #[test]
    fn test_execute_help_topics() {
        let fetch_help = execute_help(Some("fetch")).unwrap();
        assert!(fetch_help.contains("Download"));

        let inspect_help = execute_help(Some("inspect")).unwrap();
        assert!(inspect_help.contains("Inspect"));

        let memory_help = execute_help(Some("memory")).unwrap();
        assert!(memory_help.contains("memory"));

        let distill_help = execute_help(Some("distill")).unwrap();
        assert!(distill_help.contains("distill"));

        let general_help = execute_help(None).unwrap();
        assert!(general_help.contains("Available commands"));
    }

    #[test]
    fn test_detect_architecture_variants() {
        assert_eq!(detect_architecture("meta-llama/Llama-2-7b"), "llama");
        assert_eq!(detect_architecture("bert-base-uncased"), "bert");
        assert_eq!(detect_architecture("openai-gpt"), "gpt");
        assert_eq!(detect_architecture("mistralai/Mistral-7B"), "mistral");
        assert_eq!(detect_architecture("custom-model"), "unknown");
    }

    #[test]
    fn test_estimate_params_variants() {
        assert_eq!(estimate_params("model-70b"), 70_000_000_000);
        assert_eq!(estimate_params("model-13b"), 13_000_000_000);
        assert_eq!(estimate_params("model-7b"), 7_000_000_000);
        assert_eq!(estimate_params("model-1.1b"), 1_100_000_000);
        assert_eq!(estimate_params("bert-base"), 350_000_000);
    }

    #[test]
    fn test_estimate_layers_variants() {
        assert_eq!(estimate_layers("model-70b"), 80);
        assert_eq!(estimate_layers("model-13b"), 40);
        assert_eq!(estimate_layers("model-7b"), 32);
        assert_eq!(estimate_layers("bert-base"), 12);
    }

    #[test]
    fn test_execute_set_invalid_number() {
        let mut state = SessionState::new();
        let result = execute_set("batch_size", "not_a_number", &mut state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_set_unknown_key() {
        let mut state = SessionState::new();
        let result = execute_set("unknown_setting", "value", &mut state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_distill_no_teacher() {
        let state = SessionState::new();
        let result = execute_distill(true, &state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_distill_no_student() {
        let mut state = SessionState::new();
        let model = LoadedModel {
            id: "teacher".to_string(),
            path: std::path::PathBuf::from("/tmp"),
            architecture: "llama".to_string(),
            parameters: 7_000_000_000,
            layers: 32,
            hidden_dim: 4096,
            role: ModelRole::Teacher,
        };
        state.add_model("teacher".to_string(), model);

        let result = execute_distill(true, &state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_distill_success() {
        let mut state = SessionState::new();

        let teacher = LoadedModel {
            id: "teacher/model".to_string(),
            path: std::path::PathBuf::from("/tmp/t"),
            architecture: "llama".to_string(),
            parameters: 7_000_000_000,
            layers: 32,
            hidden_dim: 4096,
            role: ModelRole::Teacher,
        };
        state.add_model("teacher".to_string(), teacher);

        let student = LoadedModel {
            id: "student/model".to_string(),
            path: std::path::PathBuf::from("/tmp/s"),
            architecture: "llama".to_string(),
            parameters: 1_000_000_000,
            layers: 12,
            hidden_dim: 2048,
            role: ModelRole::Student,
        };
        state.add_model("student".to_string(), student);

        let result = execute_distill(true, &state).unwrap();
        assert!(result.contains("Dry run"));
    }

    #[test]
    fn test_execute_export() {
        let state = SessionState::new();
        let result = execute_export("safetensors", "/tmp/model.st", &state).unwrap();
        assert!(result.contains("Exported"));
    }

    #[test]
    fn test_parse_empty_input() {
        // Blank input parses to `Unknown` (with an empty `input`) rather than an error.
        let cmd = parse("").unwrap();
        assert!(matches!(cmd, Command::Unknown { .. }));
    }

    #[test]
    fn test_execute_memory_with_args() {
        let state = SessionState::new();
        let result = execute_memory(Some(64), Some(1024), &state).unwrap();
        assert!(result.contains("batch=64"));
        assert!(result.contains("seq=1024"));
    }

    #[test]
    fn test_command_enum_equality() {
        assert_eq!(Command::Quit, Command::Quit);
        assert_eq!(Command::Clear, Command::Clear);
        assert_eq!(Command::History, Command::History);
    }
}