1use crate::state::{HistoryEntry, LoadedModel, ModelRole, SessionState};
4use entrenar_common::{EntrenarError, Result};
5
/// A parsed shell command, produced by [`parse`] and consumed by [`execute`].
#[derive(Debug, Clone, PartialEq)]
pub enum Command {
    /// Download a model and register it in the session under `role`.
    Fetch { model_id: String, role: ModelRole },
    /// Inspect loaded models, layers, or memory usage.
    Inspect { target: InspectTarget },
    /// Estimate memory requirements; `None` fields fall back to session preferences.
    Memory {
        batch_size: Option<u32>,
        seq_len: Option<usize>,
    },
    /// Update a session preference (`set <key> <value>`).
    Set { key: String, value: String },
    /// Run distillation; with `dry_run` only the configuration is validated.
    Distill { dry_run: bool },
    /// Export a model to `path` in the named `format`.
    Export { format: String, path: String },
    /// Show the command history.
    History,
    /// Show help, optionally for a single topic.
    Help { topic: Option<String> },
    /// Clear the screen.
    Clear,
    /// Exit the shell.
    Quit,
    /// Unrecognized input (empty input parses to this with an empty string).
    Unknown { input: String },
}
35
/// What an `inspect` command should look at.
#[derive(Debug, Clone, PartialEq)]
pub enum InspectTarget {
    /// Per-model layer counts and hidden dimensions.
    Layers,
    /// Memory estimate using the session's default batch/seq settings.
    Memory,
    /// Summary of every loaded model (the default when no argument is given).
    All,
    /// A single model, looked up by its (lowercased) name.
    Model(String),
}
48
49pub fn parse(input: &str) -> Result<Command> {
51 let input = input.trim();
52 if input.is_empty() {
53 return Ok(Command::Unknown {
54 input: String::new(),
55 });
56 }
57
58 let parts: Vec<&str> = input.split_whitespace().collect();
59 let cmd = parts[0].to_lowercase();
60 let args = &parts[1..];
61
62 match cmd.as_str() {
63 "fetch" | "download" => parse_fetch(args),
64 "inspect" | "show" => parse_inspect(args),
65 "memory" | "mem" => parse_memory(args),
66 "set" => parse_set(args),
67 "distill" | "train" => parse_distill(args),
68 "export" | "save" => parse_export(args),
69 "history" | "hist" => Ok(Command::History),
70 "help" | "?" => parse_help(args),
71 "clear" | "cls" => Ok(Command::Clear),
72 "quit" | "exit" | "q" => Ok(Command::Quit),
73 _ => Ok(Command::Unknown {
74 input: input.to_string(),
75 }),
76 }
77}
78
79fn parse_fetch(args: &[&str]) -> Result<Command> {
80 if args.is_empty() {
81 return Err(EntrenarError::ConfigValue {
82 field: "model_id".into(),
83 message: "No model ID provided".into(),
84 suggestion: "Usage: fetch <model_id> [--teacher|--student]".into(),
85 });
86 }
87
88 let model_id = args[0].to_string();
89 let role = if args.contains(&"--teacher") {
90 ModelRole::Teacher
91 } else if args.contains(&"--student") {
92 ModelRole::Student
93 } else {
94 ModelRole::None
95 };
96
97 Ok(Command::Fetch { model_id, role })
98}
99
100fn parse_inspect(args: &[&str]) -> Result<Command> {
101 let target = if args.is_empty() {
102 InspectTarget::All
103 } else {
104 match args[0].to_lowercase().as_str() {
105 "layers" | "layer" => InspectTarget::Layers,
106 "memory" | "mem" => InspectTarget::Memory,
107 "all" => InspectTarget::All,
108 name => InspectTarget::Model(name.to_string()),
109 }
110 };
111
112 Ok(Command::Inspect { target })
113}
114
115fn parse_memory(args: &[&str]) -> Result<Command> {
116 let mut batch_size = None;
117 let mut seq_len = None;
118
119 let mut i = 0;
120 while i < args.len() {
121 match args[i] {
122 "--batch" | "-b" if i + 1 < args.len() => {
123 batch_size = args[i + 1].parse().ok();
124 i += 2;
125 }
126 "--seq" | "-s" if i + 1 < args.len() => {
127 seq_len = args[i + 1].parse().ok();
128 i += 2;
129 }
130 _ => i += 1,
131 }
132 }
133
134 Ok(Command::Memory {
135 batch_size,
136 seq_len,
137 })
138}
139
140fn parse_set(args: &[&str]) -> Result<Command> {
141 if args.len() < 2 {
142 return Err(EntrenarError::ConfigValue {
143 field: "set".into(),
144 message: "Not enough arguments".into(),
145 suggestion: "Usage: set <key> <value>".into(),
146 });
147 }
148
149 Ok(Command::Set {
150 key: args[0].to_string(),
151 value: args[1..].join(" "),
152 })
153}
154
155fn parse_distill(args: &[&str]) -> Result<Command> {
156 let dry_run = args.contains(&"--dry-run") || args.contains(&"-n");
157 Ok(Command::Distill { dry_run })
158}
159
160fn parse_export(args: &[&str]) -> Result<Command> {
161 if args.len() < 2 {
162 return Err(EntrenarError::ConfigValue {
163 field: "export".into(),
164 message: "Not enough arguments".into(),
165 suggestion: "Usage: export <format> <path>".into(),
166 });
167 }
168
169 Ok(Command::Export {
170 format: args[0].to_string(),
171 path: args[1].to_string(),
172 })
173}
174
175fn parse_help(args: &[&str]) -> Result<Command> {
176 let topic = args.first().map(ToString::to_string);
177 Ok(Command::Help { topic })
178}
179
/// Execute a parsed command against the session and return its textual output.
///
/// Wall-clock duration and success/failure are recorded into the session
/// history, but only for substantive commands — meta commands (help, history,
/// clear, quit, unknown) are excluded so they don't pollute the history.
///
/// # Errors
/// Propagates whatever the per-command executor returns; `Unknown` with
/// non-empty input is itself an error.
pub fn execute(cmd: &Command, state: &mut SessionState) -> Result<String> {
    let start = std::time::Instant::now();

    let result = match cmd {
        Command::Fetch { model_id, role } => execute_fetch(model_id, *role, state),
        Command::Inspect { target } => execute_inspect(target, state),
        Command::Memory {
            batch_size,
            seq_len,
        } => execute_memory(*batch_size, *seq_len, state),
        Command::Set { key, value } => execute_set(key, value, state),
        Command::Distill { dry_run } => execute_distill(*dry_run, state),
        Command::Export { format, path } => execute_export(format, path, state),
        Command::History => execute_history(state),
        Command::Help { topic } => execute_help(topic.as_deref()),
        Command::Clear => Ok(String::new()),
        Command::Quit => Ok("Goodbye!".to_string()),
        Command::Unknown { input } => {
            // Empty input (blank line) is a silent no-op, not an error.
            if input.is_empty() {
                Ok(String::new())
            } else {
                Err(EntrenarError::ConfigValue {
                    field: "command".into(),
                    message: format!("Unknown command: {input}"),
                    suggestion: "Type 'help' for available commands".into(),
                })
            }
        }
    };

    // Timing covers only the executor call above, not parsing.
    let duration_ms = start.elapsed().as_millis() as u64;
    let success = result.is_ok();

    if !matches!(
        cmd,
        Command::Help { .. }
            | Command::History
            | Command::Clear
            | Command::Quit
            | Command::Unknown { .. }
    ) {
        // The Debug rendering of the command is what appears in history.
        let cmd_str = format!("{cmd:?}");
        state.add_to_history(HistoryEntry::new(cmd_str, duration_ms, success));
        state.record_command(duration_ms, success);
    }

    result
}
230
231fn execute_fetch(model_id: &str, role: ModelRole, state: &mut SessionState) -> Result<String> {
232 let model = LoadedModel {
234 id: model_id.to_string(),
235 path: std::path::PathBuf::from(format!("/tmp/models/{}", model_id.replace('/', "_"))),
236 architecture: detect_architecture(model_id),
237 parameters: estimate_params(model_id),
238 layers: estimate_layers(model_id),
239 hidden_dim: 4096,
240 role,
241 };
242
243 let name = if role == ModelRole::Teacher {
244 "teacher"
245 } else if role == ModelRole::Student {
246 "student"
247 } else {
248 model_id.split('/').next_back().unwrap_or(model_id)
249 };
250
251 state.add_model(name.to_string(), model.clone());
252
253 Ok(format!(
254 "✓ Fetched {}\n Architecture: {}\n Parameters: {:.1}B\n Layers: {}",
255 model_id,
256 model.architecture,
257 model.parameters as f64 / 1e9,
258 model.layers
259 ))
260}
261
262fn execute_inspect(target: &InspectTarget, state: &SessionState) -> Result<String> {
263 match target {
264 InspectTarget::All => {
265 if state.loaded_models().is_empty() {
266 return Ok("No models loaded. Use 'fetch <model_id>' to load a model.".to_string());
267 }
268
269 let mut output = String::from("Loaded Models:\n");
270 for (name, model) in state.loaded_models() {
271 output.push_str(&format!(
272 " {} ({}): {:.1}B params, {} layers\n",
273 name,
274 model.id,
275 model.parameters as f64 / 1e9,
276 model.layers
277 ));
278 }
279 Ok(output)
280 }
281 InspectTarget::Layers => {
282 let mut output = String::from("Layer Analysis:\n");
283 for (name, model) in state.loaded_models() {
284 output.push_str(&format!(
285 " {}: {} layers, hidden_dim={}\n",
286 name, model.layers, model.hidden_dim
287 ));
288 }
289 Ok(output)
290 }
291 InspectTarget::Memory => execute_memory(None, None, state),
292 InspectTarget::Model(name) => {
293 if let Some(model) = state.get_model(name) {
294 Ok(format!(
295 "Model: {}\n ID: {}\n Path: {}\n Architecture: {}\n Parameters: {:.1}B\n Layers: {}\n Hidden Dim: {}",
296 name, model.id, model.path.display(), model.architecture,
297 model.parameters as f64 / 1e9, model.layers, model.hidden_dim
298 ))
299 } else {
300 Err(EntrenarError::ModelNotFound { path: name.into() })
301 }
302 }
303 }
304}
305
306fn execute_memory(
307 batch_size: Option<u32>,
308 seq_len: Option<usize>,
309 state: &SessionState,
310) -> Result<String> {
311 let batch = batch_size.unwrap_or(state.preferences().default_batch_size);
312 let seq = seq_len.unwrap_or(state.preferences().default_seq_len);
313
314 let total_params: u64 = state.loaded_models().values().map(|m| m.parameters).sum();
315 let model_mem = total_params * 2; let activation_mem = u64::from(batch) * (seq as u64) * 4096 * 32 * 2;
317 let total = model_mem + activation_mem;
318
319 Ok(format!(
320 "Memory Estimate (batch={}, seq={}):\n Model: {:.1} GB\n Activations: {:.1} GB\n Total: {:.1} GB",
321 batch, seq,
322 model_mem as f64 / 1e9,
323 activation_mem as f64 / 1e9,
324 total as f64 / 1e9
325 ))
326}
327
328fn execute_set(key: &str, value: &str, state: &mut SessionState) -> Result<String> {
329 match key {
330 "batch_size" | "batch" => {
331 let v: u32 = value.parse().map_err(|_| EntrenarError::ConfigValue {
332 field: "batch_size".into(),
333 message: "Invalid number".into(),
334 suggestion: "Use a positive integer".into(),
335 })?;
336 state.preferences_mut().default_batch_size = v;
337 Ok(format!("Set batch_size = {v}"))
338 }
339 "seq_len" | "seq" => {
340 let v: usize = value.parse().map_err(|_| EntrenarError::ConfigValue {
341 field: "seq_len".into(),
342 message: "Invalid number".into(),
343 suggestion: "Use a positive integer".into(),
344 })?;
345 state.preferences_mut().default_seq_len = v;
346 Ok(format!("Set seq_len = {v}"))
347 }
348 _ => Err(EntrenarError::ConfigValue {
349 field: key.into(),
350 message: "Unknown setting".into(),
351 suggestion: "Available settings: batch_size, seq_len".into(),
352 }),
353 }
354}
355
356fn execute_distill(dry_run: bool, state: &SessionState) -> Result<String> {
357 let teacher = state
358 .loaded_models()
359 .values()
360 .find(|m| m.role == ModelRole::Teacher);
361 let student = state
362 .loaded_models()
363 .values()
364 .find(|m| m.role == ModelRole::Student);
365
366 if teacher.is_none() {
367 return Err(EntrenarError::ConfigValue {
368 field: "teacher".into(),
369 message: "No teacher model loaded".into(),
370 suggestion: "Use 'fetch <model_id> --teacher' to load a teacher model".into(),
371 });
372 }
373
374 if student.is_none() {
375 return Err(EntrenarError::ConfigValue {
376 field: "student".into(),
377 message: "No student model loaded".into(),
378 suggestion: "Use 'fetch <model_id> --student' to load a student model".into(),
379 });
380 }
381
382 if dry_run {
383 Ok(format!(
384 "Dry run configuration:\n Teacher: {} ({:.1}B)\n Student: {} ({:.1}B)\n Ready to train",
385 teacher.unwrap().id, teacher.unwrap().parameters as f64 / 1e9,
386 student.unwrap().id, student.unwrap().parameters as f64 / 1e9
387 ))
388 } else {
389 Ok("Training started... (simulated)".to_string())
390 }
391}
392
393fn execute_export(format: &str, path: &str, _state: &SessionState) -> Result<String> {
394 Ok(format!("Exported to {path} in {format} format"))
395}
396
397fn execute_history(state: &SessionState) -> Result<String> {
398 if state.history().is_empty() {
399 return Ok("No command history.".to_string());
400 }
401
402 let mut output = String::from("Command History:\n");
403 for (i, entry) in state.history().iter().enumerate() {
404 let status = if entry.success { "✓" } else { "✗" };
405 output.push_str(&format!(
406 " {}. {} {} ({}ms)\n",
407 i + 1,
408 status,
409 entry.command,
410 entry.duration_ms
411 ));
412 }
413 Ok(output)
414}
415
/// Return help text: a per-command page for known topics, otherwise the
/// general command listing (including for unknown topics).
fn execute_help(topic: Option<&str>) -> Result<String> {
    match topic {
        Some("fetch") => Ok(
            "fetch <model_id> [--teacher|--student]\n  Download a model from HuggingFace"
                .to_string(),
        ),
        Some("inspect") => {
            Ok("inspect [layers|memory|all|<model>]\n  Inspect loaded models".to_string())
        }
        Some("memory") => {
            Ok("memory [--batch <n>] [--seq <n>]\n  Estimate memory requirements".to_string())
        }
        Some("distill") => Ok("distill [--dry-run]\n  Run distillation training".to_string()),
        // Any other topic (or none) falls through to the general listing.
        _ => Ok("Available commands:
  fetch <model>        Download model from HuggingFace
  inspect [target]     Inspect loaded models
  memory               Estimate memory requirements
  set <key> <value>    Configure settings
  distill              Run distillation
  export <fmt> <path>  Export model
  history              Show command history
  help [topic]         Show help
  quit                 Exit shell"
            .to_string()),
    }
}
442
/// Guess the model architecture from substrings of the model id
/// (case-insensitive); unrecognized ids map to "unknown".
fn detect_architecture(model_id: &str) -> String {
    let lower = model_id.to_lowercase();
    // Checked in order; first matching marker wins.
    let known = [
        ("llama", "llama"),
        ("bert", "bert"),
        ("gpt", "gpt"),
        ("mistral", "mistral"),
    ];
    known
        .iter()
        .find(|(marker, _)| lower.contains(*marker))
        .map(|(_, arch)| (*arch).to_string())
        .unwrap_or_else(|| "unknown".to_string())
}
457
/// Estimate parameter count from size markers in the model id
/// (case-insensitive); ids with no marker default to 1B.
fn estimate_params(model_id: &str) -> u64 {
    let lower = model_id.to_lowercase();
    // Ordered by specificity; first matching marker wins ("13b" before "1b",
    // "1.1b" and "1b" share the same estimate).
    let sizes: [(&str, u64); 6] = [
        ("70b", 70_000_000_000),
        ("13b", 13_000_000_000),
        ("7b", 7_000_000_000),
        ("1.1b", 1_100_000_000),
        ("1b", 1_100_000_000),
        ("base", 350_000_000),
    ];
    sizes
        .iter()
        .find(|(marker, _)| lower.contains(*marker))
        .map_or(1_000_000_000, |(_, params)| *params)
}
474
/// Estimate layer count from size markers in the model id
/// (case-insensitive); ids with no marker default to 24 layers.
fn estimate_layers(model_id: &str) -> u32 {
    let lower = model_id.to_lowercase();
    // Ordered by specificity; first matching marker wins.
    let tiers: [(&str, u32); 4] = [("70b", 80), ("13b", 40), ("7b", 32), ("base", 12)];
    tiers
        .iter()
        .find(|(marker, _)| lower.contains(*marker))
        .map_or(24, |(_, layers)| *layers)
}
489
#[cfg(test)]
mod tests {
    //! Unit tests for the shell: parsing, execution, and the model-metadata
    //! heuristics. Grouped roughly as parser tests, executor tests, and
    //! heuristic tests.
    use super::*;

    // --- parser: happy paths ---

    #[test]
    fn test_parse_fetch() {
        let cmd = parse("fetch meta-llama/Llama-2-7b --teacher").unwrap();
        assert!(matches!(
            cmd,
            Command::Fetch {
                role: ModelRole::Teacher,
                ..
            }
        ));

        let cmd = parse("fetch TinyLlama/TinyLlama-1.1B --student").unwrap();
        assert!(matches!(
            cmd,
            Command::Fetch {
                role: ModelRole::Student,
                ..
            }
        ));
    }

    #[test]
    fn test_parse_inspect() {
        // No argument defaults to inspecting everything.
        assert!(matches!(
            parse("inspect").unwrap(),
            Command::Inspect {
                target: InspectTarget::All
            }
        ));
        assert!(matches!(
            parse("inspect layers").unwrap(),
            Command::Inspect {
                target: InspectTarget::Layers
            }
        ));
    }

    #[test]
    fn test_parse_memory() {
        let cmd = parse("memory --batch 64 --seq 1024").unwrap();
        if let Command::Memory {
            batch_size,
            seq_len,
        } = cmd
        {
            assert_eq!(batch_size, Some(64));
            assert_eq!(seq_len, Some(1024));
        } else {
            panic!("Expected Memory command");
        }
    }

    #[test]
    fn test_parse_quit_variants() {
        assert!(matches!(parse("quit").unwrap(), Command::Quit));
        assert!(matches!(parse("exit").unwrap(), Command::Quit));
        assert!(matches!(parse("q").unwrap(), Command::Quit));
    }

    // --- executors ---

    #[test]
    fn test_execute_fetch() {
        let mut state = SessionState::new();
        let result = execute_fetch("meta-llama/Llama-2-7b", ModelRole::Teacher, &mut state);

        assert!(result.is_ok());
        // Teacher-role fetches register under the fixed "teacher" key.
        assert!(state.get_model("teacher").is_some());
    }

    #[test]
    fn test_execute_set() {
        let mut state = SessionState::new();

        execute_set("batch_size", "64", &mut state).unwrap();
        assert_eq!(state.preferences().default_batch_size, 64);

        execute_set("seq_len", "1024", &mut state).unwrap();
        assert_eq!(state.preferences().default_seq_len, 1024);
    }

    // --- parser: error and edge cases ---

    #[test]
    fn test_unknown_command() {
        let cmd = parse("foobar").unwrap();
        assert!(matches!(cmd, Command::Unknown { .. }));
    }

    #[test]
    fn test_parse_fetch_missing_model() {
        let result = parse("fetch");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_set_not_enough_args() {
        let result = parse("set batch_size");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_export_not_enough_args() {
        let result = parse("export safetensors");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_export_valid() {
        let cmd = parse("export safetensors /tmp/model.st").unwrap();
        if let Command::Export { format, path } = cmd {
            assert_eq!(format, "safetensors");
            assert_eq!(path, "/tmp/model.st");
        } else {
            panic!("Expected Export command");
        }
    }

    #[test]
    fn test_parse_help_with_topic() {
        let cmd = parse("help fetch").unwrap();
        if let Command::Help { topic } = cmd {
            assert_eq!(topic, Some("fetch".to_string()));
        } else {
            panic!("Expected Help command");
        }
    }

    #[test]
    fn test_parse_distill_dry_run() {
        let cmd = parse("distill --dry-run").unwrap();
        if let Command::Distill { dry_run } = cmd {
            assert!(dry_run);
        } else {
            panic!("Expected Distill command");
        }
    }

    #[test]
    fn test_parse_distill_short_flag() {
        let cmd = parse("distill -n").unwrap();
        if let Command::Distill { dry_run } = cmd {
            assert!(dry_run);
        } else {
            panic!("Expected Distill command");
        }
    }

    #[test]
    fn test_parse_inspect_model() {
        // An unrecognized inspect argument is treated as a model name.
        let cmd = parse("inspect teacher").unwrap();
        if let Command::Inspect { target } = cmd {
            assert_eq!(target, InspectTarget::Model("teacher".to_string()));
        } else {
            panic!("Expected Inspect command");
        }
    }

    #[test]
    fn test_parse_inspect_memory() {
        let cmd = parse("inspect memory").unwrap();
        assert!(matches!(
            cmd,
            Command::Inspect {
                target: InspectTarget::Memory
            }
        ));
    }

    #[test]
    fn test_parse_command_aliases() {
        // Each command word has at least one alias; spot-check them all.
        assert!(matches!(
            parse("download model").unwrap(),
            Command::Fetch { .. }
        ));
        assert!(matches!(
            parse("show layers").unwrap(),
            Command::Inspect { .. }
        ));
        assert!(matches!(parse("mem").unwrap(), Command::Memory { .. }));
        assert!(matches!(parse("train").unwrap(), Command::Distill { .. }));
        assert!(matches!(
            parse("save gguf /tmp/out").unwrap(),
            Command::Export { .. }
        ));
        assert!(matches!(parse("cls").unwrap(), Command::Clear));
        assert!(matches!(parse("?").unwrap(), Command::Help { .. }));
        assert!(matches!(parse("hist").unwrap(), Command::History));
    }

    #[test]
    fn test_execute_inspect_no_models() {
        let state = SessionState::new();
        let result = execute_inspect(&InspectTarget::All, &state);
        assert!(result.unwrap().contains("No models loaded"));
    }

    #[test]
    fn test_execute_inspect_layers() {
        let mut state = SessionState::new();
        let model = LoadedModel {
            id: "test".to_string(),
            path: std::path::PathBuf::from("/tmp"),
            architecture: "llama".to_string(),
            parameters: 7_000_000_000,
            layers: 32,
            hidden_dim: 4096,
            role: ModelRole::None,
        };
        state.add_model("test".to_string(), model);

        let result = execute_inspect(&InspectTarget::Layers, &state).unwrap();
        assert!(result.contains("32 layers"));
    }

    #[test]
    fn test_execute_inspect_model_not_found() {
        let state = SessionState::new();
        let result = execute_inspect(&InspectTarget::Model("unknown".to_string()), &state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_history_empty() {
        let state = SessionState::new();
        let result = execute_history(&state).unwrap();
        assert!(result.contains("No command history"));
    }

    #[test]
    fn test_execute_help_topics() {
        let fetch_help = execute_help(Some("fetch")).unwrap();
        assert!(fetch_help.contains("Download"));

        let inspect_help = execute_help(Some("inspect")).unwrap();
        assert!(inspect_help.contains("Inspect"));

        let memory_help = execute_help(Some("memory")).unwrap();
        assert!(memory_help.contains("memory"));

        let distill_help = execute_help(Some("distill")).unwrap();
        assert!(distill_help.contains("distill"));

        let general_help = execute_help(None).unwrap();
        assert!(general_help.contains("Available commands"));
    }

    // --- model-metadata heuristics ---

    #[test]
    fn test_detect_architecture_variants() {
        assert_eq!(detect_architecture("meta-llama/Llama-2-7b"), "llama");
        assert_eq!(detect_architecture("bert-base-uncased"), "bert");
        assert_eq!(detect_architecture("openai-gpt"), "gpt");
        assert_eq!(detect_architecture("mistralai/Mistral-7B"), "mistral");
        assert_eq!(detect_architecture("custom-model"), "unknown");
    }

    #[test]
    fn test_estimate_params_variants() {
        assert_eq!(estimate_params("model-70b"), 70_000_000_000);
        assert_eq!(estimate_params("model-13b"), 13_000_000_000);
        assert_eq!(estimate_params("model-7b"), 7_000_000_000);
        assert_eq!(estimate_params("model-1.1b"), 1_100_000_000);
        assert_eq!(estimate_params("bert-base"), 350_000_000);
    }

    #[test]
    fn test_estimate_layers_variants() {
        assert_eq!(estimate_layers("model-70b"), 80);
        assert_eq!(estimate_layers("model-13b"), 40);
        assert_eq!(estimate_layers("model-7b"), 32);
        assert_eq!(estimate_layers("bert-base"), 12);
    }

    // --- executor error cases ---

    #[test]
    fn test_execute_set_invalid_number() {
        let mut state = SessionState::new();
        let result = execute_set("batch_size", "not_a_number", &mut state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_set_unknown_key() {
        let mut state = SessionState::new();
        let result = execute_set("unknown_setting", "value", &mut state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_distill_no_teacher() {
        let state = SessionState::new();
        let result = execute_distill(true, &state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_distill_no_student() {
        let mut state = SessionState::new();
        let model = LoadedModel {
            id: "teacher".to_string(),
            path: std::path::PathBuf::from("/tmp"),
            architecture: "llama".to_string(),
            parameters: 7_000_000_000,
            layers: 32,
            hidden_dim: 4096,
            role: ModelRole::Teacher,
        };
        state.add_model("teacher".to_string(), model);

        let result = execute_distill(true, &state);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_distill_success() {
        let mut state = SessionState::new();

        let teacher = LoadedModel {
            id: "teacher/model".to_string(),
            path: std::path::PathBuf::from("/tmp/t"),
            architecture: "llama".to_string(),
            parameters: 7_000_000_000,
            layers: 32,
            hidden_dim: 4096,
            role: ModelRole::Teacher,
        };
        state.add_model("teacher".to_string(), teacher);

        let student = LoadedModel {
            id: "student/model".to_string(),
            path: std::path::PathBuf::from("/tmp/s"),
            architecture: "llama".to_string(),
            parameters: 1_000_000_000,
            layers: 12,
            hidden_dim: 2048,
            role: ModelRole::Student,
        };
        state.add_model("student".to_string(), student);

        let result = execute_distill(true, &state).unwrap();
        assert!(result.contains("Dry run"));
    }

    #[test]
    fn test_execute_export() {
        let state = SessionState::new();
        let result = execute_export("safetensors", "/tmp/model.st", &state).unwrap();
        assert!(result.contains("Exported"));
    }

    #[test]
    fn test_parse_empty_input() {
        // Blank input parses to Unknown (and executes as a silent no-op).
        let cmd = parse("").unwrap();
        assert!(matches!(cmd, Command::Unknown { .. }));
    }

    #[test]
    fn test_execute_memory_with_args() {
        let state = SessionState::new();
        let result = execute_memory(Some(64), Some(1024), &state).unwrap();
        assert!(result.contains("batch=64"));
        assert!(result.contains("seq=1024"));
    }

    #[test]
    fn test_command_enum_equality() {
        assert_eq!(Command::Quit, Command::Quit);
        assert_eq!(Command::Clear, Command::Clear);
        assert_eq!(Command::History, Command::History);
    }
}