1use crate::data::VerifiedTuple;
16use serde::{Deserialize, Serialize};
17use std::path::Path;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct CodeTranslationExample {
22 pub id: String,
24 pub source_language: String,
26 pub target_language: String,
28 pub source_code: String,
30 pub target_code: String,
32 pub prompt: String,
34 pub completion: String,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct PromptTemplate {
41 pub name: String,
43 pub system: Option<String>,
45 pub user_template: String,
47 pub include_lang_tags: bool,
49}
50
51impl Default for PromptTemplate {
52 fn default() -> Self {
53 Self::instruction_following()
54 }
55}
56
57impl PromptTemplate {
58 #[must_use]
60 pub fn instruction_following() -> Self {
61 Self {
62 name: "instruction".to_string(),
63 system: Some("You are an expert code translator.".to_string()),
64 user_template: "Translate the following {source_lang} code to {target_lang}:\n\n```{source_lang}\n{source_code}\n```".to_string(),
65 include_lang_tags: true,
66 }
67 }
68
69 #[must_use]
71 pub fn chat_style() -> Self {
72 Self {
73 name: "chat".to_string(),
74 system: Some("You are a helpful assistant that translates code between programming languages.".to_string()),
75 user_template: "Please convert this {source_lang} code to idiomatic {target_lang}:\n\n{source_code}".to_string(),
76 include_lang_tags: false,
77 }
78 }
79
80 #[must_use]
82 pub fn completion_style() -> Self {
83 Self {
84 name: "completion".to_string(),
85 system: None,
86 user_template: "# {source_lang}\n{source_code}\n\n# {target_lang}\n".to_string(),
87 include_lang_tags: false,
88 }
89 }
90
91 #[must_use]
93 pub fn format(&self, source_lang: &str, target_lang: &str, source_code: &str) -> String {
94 self.user_template
95 .replace("{source_lang}", source_lang)
96 .replace("{target_lang}", target_lang)
97 .replace("{source_code}", source_code)
98 }
99
100 #[must_use]
102 pub fn format_completion(&self, target_lang: &str, target_code: &str) -> String {
103 if self.include_lang_tags {
104 format!("```{target_lang}\n{target_code}\n```")
105 } else {
106 target_code.to_string()
107 }
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct ExportConfig {
114 pub format: ExportFormat,
116 pub template: PromptTemplate,
118 pub train_ratio: f64,
120 pub seed: u64,
122 pub max_examples: Option<usize>,
124}
125
126impl Default for ExportConfig {
127 fn default() -> Self {
128 Self {
129 format: ExportFormat::Jsonl,
130 template: PromptTemplate::default(),
131 train_ratio: 0.9,
132 seed: 42,
133 max_examples: None,
134 }
135 }
136}
137
138#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
140#[serde(rename_all = "lowercase")]
141pub enum ExportFormat {
142 Json,
144 Jsonl,
146 Parquet,
148}
149
/// Summary numbers describing one export run.
#[derive(Debug, Clone, Default)]
pub struct ExportStats {
    /// Number of examples exported in total.
    pub total: usize,
    /// Number of examples in the training split.
    pub train_count: usize,
    /// Number of examples in the validation split.
    pub val_count: usize,
    /// Mean length of the source code, in bytes.
    pub avg_source_len: f64,
    /// Mean length of the target code, in bytes.
    pub avg_target_len: f64,
}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct DistillationConfig {
184 pub temperature: f32,
187
188 pub alpha: f32,
192
193 pub num_teachers: usize,
195
196 pub student: StudentConfig,
198
199 pub training: DistillTrainingConfig,
201
202 pub output_dir: std::path::PathBuf,
204}
205
206impl Default for DistillationConfig {
207 fn default() -> Self {
208 Self {
209 temperature: 3.0,
210 alpha: 0.7,
211 num_teachers: 1,
212 student: StudentConfig::default(),
213 training: DistillTrainingConfig::default(),
214 output_dir: std::path::PathBuf::from("distilled_model"),
215 }
216 }
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct StudentConfig {
222 pub model_type: String,
224 pub hidden_size: usize,
226 pub num_layers: usize,
228 pub num_classes: usize,
230}
231
232impl Default for StudentConfig {
233 fn default() -> Self {
234 Self {
235 model_type: "distilled_student".to_string(),
236 hidden_size: 256,
237 num_layers: 4,
238 num_classes: 18, }
240 }
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct DistillTrainingConfig {
246 pub epochs: usize,
248 pub batch_size: usize,
250 pub learning_rate: f64,
252 pub grad_clip: f32,
254 pub fp16: bool,
256}
257
258impl Default for DistillTrainingConfig {
259 fn default() -> Self {
260 Self {
261 epochs: 3,
262 batch_size: 32,
263 learning_rate: 1e-4,
264 grad_clip: 1.0,
265 fp16: false,
266 }
267 }
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct DistillationResult {
273 pub final_loss: f32,
275 pub loss_history: Vec<f32>,
277 pub teacher_count: usize,
279 pub student_config: StudentConfig,
281 pub temperature: f32,
283 pub alpha: f32,
285 pub status: String,
287 pub note: String,
289}
290
291impl DistillationConfig {
292 #[must_use]
294 pub fn new() -> Self {
295 Self::default()
296 }
297
298 #[must_use]
300 pub fn with_temperature(mut self, temperature: f32) -> Self {
301 self.temperature = temperature;
302 self
303 }
304
305 #[must_use]
307 pub fn with_alpha(mut self, alpha: f32) -> Self {
308 self.alpha = alpha;
309 self
310 }
311
312 #[must_use]
314 pub fn with_teachers(mut self, num_teachers: usize) -> Self {
315 self.num_teachers = num_teachers;
316 self
317 }
318
319 #[must_use]
321 pub fn with_student(mut self, student: StudentConfig) -> Self {
322 self.student = student;
323 self
324 }
325
326 #[must_use]
328 pub fn with_training(mut self, training: DistillTrainingConfig) -> Self {
329 self.training = training;
330 self
331 }
332
333 #[must_use]
335 pub fn with_output_dir(mut self, output_dir: impl Into<std::path::PathBuf>) -> Self {
336 self.output_dir = output_dir.into();
337 self
338 }
339
340 #[must_use]
342 pub fn to_yaml(&self) -> String {
343 format!(
344 "# Entrenar Distillation Config\n\
345 # Generated by verificar distill\n\
346 \n\
347 model:\n\
348 \x20 type: student\n\
349 \x20 hidden_size: {}\n\
350 \x20 num_layers: {}\n\
351 \n\
352 distillation:\n\
353 \x20 temperature: {}\n\
354 \x20 alpha: {}\n\
355 \x20 num_teachers: {}\n\
356 \n\
357 training:\n\
358 \x20 epochs: {}\n\
359 \x20 batch_size: {}\n\
360 \x20 learning_rate: {:e}\n\
361 \n\
362 data:\n\
363 \x20 teacher_logits: \"/tmp/teacher_logits\"\n\
364 \x20 output_dir: {:?}\n",
365 self.student.hidden_size,
366 self.student.num_layers,
367 self.temperature,
368 self.alpha,
369 self.num_teachers,
370 self.training.epochs,
371 self.training.batch_size,
372 self.training.learning_rate,
373 self.output_dir.display()
374 )
375 }
376
377 #[must_use]
382 pub fn run_placeholder(&self) -> DistillationResult {
383 let mut loss_history = Vec::with_capacity(self.training.epochs);
385 let mut loss = 2.6_f32;
386
387 for _ in 0..self.training.epochs {
388 loss *= 0.75; loss_history.push(loss);
390 }
391
392 DistillationResult {
393 final_loss: loss,
394 loss_history,
395 teacher_count: self.num_teachers,
396 student_config: self.student.clone(),
397 temperature: self.temperature,
398 alpha: self.alpha,
399 status: "placeholder".to_string(),
400 note: "Full distillation requires entrenar llm feature and teacher model weights"
401 .to_string(),
402 }
403 }
404
405 pub fn validate(&self) -> Result<(), String> {
411 if self.temperature <= 0.0 {
412 return Err("temperature must be positive".to_string());
413 }
414 if !(0.0..=1.0).contains(&self.alpha) {
415 return Err("alpha must be in [0.0, 1.0]".to_string());
416 }
417 if self.num_teachers == 0 {
418 return Err("num_teachers must be at least 1".to_string());
419 }
420 if self.student.hidden_size == 0 {
421 return Err("hidden_size must be positive".to_string());
422 }
423 if self.student.num_layers == 0 {
424 return Err("num_layers must be at least 1".to_string());
425 }
426 if self.training.epochs == 0 {
427 return Err("epochs must be at least 1".to_string());
428 }
429 if self.training.learning_rate <= 0.0 {
430 return Err("learning_rate must be positive".to_string());
431 }
432 Ok(())
433 }
434}
435
436#[derive(Debug)]
438pub struct EntrenarExporter {
439 config: ExportConfig,
440}
441
442impl EntrenarExporter {
443 #[must_use]
445 pub fn new(config: ExportConfig) -> Self {
446 Self { config }
447 }
448
449 #[must_use]
451 pub fn to_example(&self, tuple: &VerifiedTuple, id: &str) -> CodeTranslationExample {
452 let source_lang = tuple.source_language.to_string();
453 let target_lang = tuple.target_language.to_string();
454
455 let prompt = self
456 .config
457 .template
458 .format(&source_lang, &target_lang, &tuple.source_code);
459 let completion = self
460 .config
461 .template
462 .format_completion(&target_lang, &tuple.target_code);
463
464 CodeTranslationExample {
465 id: id.to_string(),
466 source_language: source_lang,
467 target_language: target_lang,
468 source_code: tuple.source_code.clone(),
469 target_code: tuple.target_code.clone(),
470 prompt,
471 completion,
472 }
473 }
474
475 pub fn export(
481 &self,
482 tuples: &[VerifiedTuple],
483 output_dir: &Path,
484 ) -> std::io::Result<ExportStats> {
485 let examples: Vec<_> = tuples
486 .iter()
487 .take(self.config.max_examples.unwrap_or(usize::MAX))
488 .enumerate()
489 .map(|(i, t)| self.to_example(t, &format!("ex_{i:06}")))
490 .collect();
491
492 let (train, val) = self.split_train_val(&examples);
493
494 let stats = ExportStats {
495 total: examples.len(),
496 train_count: train.len(),
497 val_count: val.len(),
498 avg_source_len: examples.iter().map(|e| e.source_code.len()).sum::<usize>() as f64
499 / examples.len().max(1) as f64,
500 avg_target_len: examples.iter().map(|e| e.target_code.len()).sum::<usize>() as f64
501 / examples.len().max(1) as f64,
502 };
503
504 std::fs::create_dir_all(output_dir)?;
505
506 match self.config.format {
507 ExportFormat::Json => {
508 self.write_json(&train, &output_dir.join("train.json"))?;
509 self.write_json(&val, &output_dir.join("val.json"))?;
510 }
511 ExportFormat::Jsonl => {
512 self.write_jsonl(&train, &output_dir.join("train.jsonl"))?;
513 self.write_jsonl(&val, &output_dir.join("val.jsonl"))?;
514 }
515 ExportFormat::Parquet => {
516 self.write_jsonl(&train, &output_dir.join("train.jsonl"))?;
519 self.write_jsonl(&val, &output_dir.join("val.jsonl"))?;
520 }
521 }
522
523 Ok(stats)
524 }
525
526 fn split_train_val(
528 &self,
529 examples: &[CodeTranslationExample],
530 ) -> (Vec<CodeTranslationExample>, Vec<CodeTranslationExample>) {
531 use std::collections::hash_map::DefaultHasher;
532 use std::hash::{Hash, Hasher};
533
534 let mut train = Vec::new();
535 let mut val = Vec::new();
536
537 for (i, example) in examples.iter().enumerate() {
538 let mut hasher = DefaultHasher::new();
539 (self.config.seed, i).hash(&mut hasher);
540 let hash = hasher.finish();
541
542 #[allow(clippy::cast_sign_loss)]
543 let threshold = (self.config.train_ratio * u64::MAX as f64) as u64;
544
545 if hash < threshold {
546 train.push(example.clone());
547 } else {
548 val.push(example.clone());
549 }
550 }
551
552 (train, val)
553 }
554
555 fn write_json(&self, examples: &[CodeTranslationExample], path: &Path) -> std::io::Result<()> {
556 let json = serde_json::to_string_pretty(examples)
557 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
558 std::fs::write(path, json)
559 }
560
561 fn write_jsonl(&self, examples: &[CodeTranslationExample], path: &Path) -> std::io::Result<()> {
562 use std::io::Write;
563 let mut file = std::fs::File::create(path)?;
564 for example in examples {
565 let line = serde_json::to_string(example)
566 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
567 writeln!(file, "{line}")?;
568 }
569 Ok(())
570 }
571}
572
573#[must_use]
575pub fn generate_entrenar_config(data_dir: &Path, output_dir: &Path, lora_rank: usize) -> String {
576 format!(
577 r"# Entrenar configuration for verificar training data
578# Generated by verificar v{}
579
580model:
581 path: codellama-7b.gguf # Replace with your base model
582 layers: [q_proj, k_proj, v_proj, o_proj]
583
584data:
585 train: {}
586 val: {}
587 batch_size: 4
588 seq_len: 2048
589
590optimizer:
591 name: adamw
592 lr: 0.0001
593 weight_decay: 0.01
594
595lora:
596 rank: {}
597 alpha: {}
598 target_modules: [q_proj, v_proj]
599 dropout: 0.05
600
601training:
602 epochs: 3
603 grad_clip: 1.0
604 lr_scheduler: cosine
605 warmup_steps: 100
606 save_interval: 1
607 output_dir: {}
608",
609 env!("CARGO_PKG_VERSION"),
610 data_dir.join("train.jsonl").display(),
611 data_dir.join("val.jsonl").display(),
612 lora_rank,
613 lora_rank * 2, output_dir.display()
615 )
616}
617
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Language;

    /// A small verified Python-to-Rust tuple shared by the exporter tests.
    fn sample_tuple() -> VerifiedTuple {
        VerifiedTuple {
            source_language: Language::Python,
            target_language: Language::Rust,
            source_code: "def add(a: int, b: int) -> int:\n return a + b".to_string(),
            target_code: "fn add(a: i32, b: i32) -> i32 {\n a + b\n}".to_string(),
            is_correct: true,
            execution_time_ms: 10,
        }
    }

    /// Builds `count` synthetic examples with ids `ex_0`, `ex_1`, ...
    fn numbered_examples(count: usize) -> Vec<CodeTranslationExample> {
        (0..count)
            .map(|i| CodeTranslationExample {
                id: format!("ex_{i}"),
                source_language: "Python".to_string(),
                target_language: "Rust".to_string(),
                source_code: format!("x = {i}"),
                target_code: format!("let x = {i};"),
                prompt: String::new(),
                completion: String::new(),
            })
            .collect()
    }

    /// Ids of a split, in order; used to compare splits member by member.
    fn ids(split: &[CodeTranslationExample]) -> Vec<&str> {
        split.iter().map(|e| e.id.as_str()).collect()
    }

    #[test]
    fn test_prompt_template_instruction() {
        let template = PromptTemplate::instruction_following();
        let prompt = template.format("Python", "Rust", "x = 1");

        assert!(prompt.contains("Python"));
        assert!(prompt.contains("Rust"));
        assert!(prompt.contains("x = 1"));
        assert!(prompt.contains("```Python"));
    }

    #[test]
    fn test_prompt_template_chat() {
        let template = PromptTemplate::chat_style();
        let prompt = template.format("Python", "Rust", "x = 1");

        assert!(prompt.contains("Python"));
        assert!(prompt.contains("idiomatic Rust"));
        // The chat template does not fence the code.
        assert!(!prompt.contains("```"));
    }

    #[test]
    fn test_prompt_template_completion() {
        let template = PromptTemplate::completion_style();
        let prompt = template.format("Python", "Rust", "x = 1");

        assert!(prompt.contains("# Python"));
        assert!(prompt.contains("# Rust"));
    }

    #[test]
    fn test_format_completion_with_tags() {
        let template = PromptTemplate::instruction_following();
        let completion = template.format_completion("Rust", "fn main() {}");

        assert!(completion.contains("```Rust"));
        assert!(completion.contains("fn main() {}"));
    }

    #[test]
    fn test_format_completion_without_tags() {
        let template = PromptTemplate::completion_style();
        let completion = template.format_completion("Rust", "fn main() {}");

        assert_eq!(completion, "fn main() {}");
        assert!(!completion.contains("```"));
    }

    #[test]
    fn test_to_example() {
        let config = ExportConfig::default();
        let exporter = EntrenarExporter::new(config);
        let tuple = sample_tuple();

        let example = exporter.to_example(&tuple, "test_001");

        assert_eq!(example.id, "test_001");
        assert_eq!(example.source_language, "python");
        assert_eq!(example.target_language, "rust");
        assert!(example.prompt.contains("def add"));
        assert!(example.completion.contains("fn add"));
    }

    #[test]
    fn test_export_config_default() {
        let config = ExportConfig::default();

        assert_eq!(config.format, ExportFormat::Jsonl);
        assert!((config.train_ratio - 0.9).abs() < f64::EPSILON);
        assert_eq!(config.seed, 42);
        assert!(config.max_examples.is_none());
    }

    #[test]
    fn test_split_train_val_ratio() {
        let config = ExportConfig {
            train_ratio: 0.8,
            ..Default::default()
        };
        let exporter = EntrenarExporter::new(config);
        let examples = numbered_examples(1000);

        let (train, val) = exporter.split_train_val(&examples);

        // The split is hash based, so only an approximate ratio is expected.
        let train_ratio = train.len() as f64 / examples.len() as f64;
        assert!(train_ratio > 0.7 && train_ratio < 0.9);
        assert_eq!(train.len() + val.len(), examples.len());
    }

    #[test]
    fn test_split_deterministic() {
        let config = ExportConfig::default();
        let exporter = EntrenarExporter::new(config);
        let examples = numbered_examples(100);

        let (train1, val1) = exporter.split_train_val(&examples);
        let (train2, val2) = exporter.split_train_val(&examples);

        // Compare membership, not just sizes: equal counts alone would not
        // prove that the same examples land in the same split.
        assert_eq!(ids(&train1), ids(&train2));
        assert_eq!(ids(&val1), ids(&val2));
    }

    #[test]
    fn test_generate_entrenar_config() {
        let config =
            generate_entrenar_config(Path::new("data/train"), Path::new("outputs/model"), 16);

        assert!(config.contains("lora:"));
        assert!(config.contains("rank: 16"));
        assert!(config.contains("alpha: 32"));
        assert!(config.contains("train.jsonl"));
        assert!(config.contains("val.jsonl"));
    }

    #[test]
    fn test_export_format_serde() {
        let json = serde_json::to_string(&ExportFormat::Jsonl).unwrap();
        assert_eq!(json, "\"jsonl\"");

        let parsed: ExportFormat = serde_json::from_str("\"parquet\"").unwrap();
        assert_eq!(parsed, ExportFormat::Parquet);
    }

    #[test]
    #[ignore = "requires filesystem setup"]
    fn test_export_to_jsonl() {
        let config = ExportConfig::default();
        let exporter = EntrenarExporter::new(config);
        let tuples = vec![sample_tuple()];

        let dir = tempfile::tempdir().unwrap();
        let stats = exporter.export(&tuples, dir.path()).unwrap();

        assert_eq!(stats.total, 1);
        assert!(dir.path().join("train.jsonl").exists() || dir.path().join("val.jsonl").exists());
    }

    #[test]
    #[ignore = "requires entrenar integration"]
    fn test_export_to_parquet() {
        unimplemented!("Parquet export not yet implemented")
    }

    #[test]
    #[ignore = "requires entrenar integration"]
    fn test_load_in_entrenar() {
        unimplemented!("Entrenar integration test not yet implemented")
    }

    #[test]
    #[ignore = "requires LLM evaluation"]
    fn test_prompt_quality() {
        unimplemented!("LLM evaluation not yet implemented")
    }

    #[test]
    fn test_distillation_config_default() {
        let config = DistillationConfig::default();

        assert!((config.temperature - 3.0).abs() < f32::EPSILON);
        assert!((config.alpha - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.num_teachers, 1);
        assert_eq!(config.student.hidden_size, 256);
        assert_eq!(config.student.num_layers, 4);
        assert_eq!(config.student.num_classes, 18);
        assert_eq!(config.training.epochs, 3);
    }

    #[test]
    fn test_distillation_config_builder() {
        let config = DistillationConfig::new()
            .with_temperature(5.0)
            .with_alpha(0.9)
            .with_teachers(3)
            .with_output_dir("/tmp/model");

        assert!((config.temperature - 5.0).abs() < f32::EPSILON);
        assert!((config.alpha - 0.9).abs() < f32::EPSILON);
        assert_eq!(config.num_teachers, 3);
        assert_eq!(config.output_dir.to_str().unwrap(), "/tmp/model");
    }

    #[test]
    fn test_distillation_config_with_student() {
        let student = StudentConfig {
            model_type: "custom".to_string(),
            hidden_size: 512,
            num_layers: 8,
            num_classes: 10,
        };

        let config = DistillationConfig::new().with_student(student);

        assert_eq!(config.student.model_type, "custom");
        assert_eq!(config.student.hidden_size, 512);
        assert_eq!(config.student.num_layers, 8);
        assert_eq!(config.student.num_classes, 10);
    }

    #[test]
    fn test_distillation_config_with_training() {
        let training = DistillTrainingConfig {
            epochs: 10,
            batch_size: 64,
            learning_rate: 5e-5,
            grad_clip: 0.5,
            fp16: true,
        };

        let config = DistillationConfig::new().with_training(training);

        assert_eq!(config.training.epochs, 10);
        assert_eq!(config.training.batch_size, 64);
        assert!((config.training.learning_rate - 5e-5).abs() < f64::EPSILON);
        assert!(config.training.fp16);
    }

    #[test]
    fn test_distillation_config_to_yaml() {
        let config = DistillationConfig::default();
        let yaml = config.to_yaml();

        assert!(yaml.contains("temperature: 3"));
        assert!(yaml.contains("alpha: 0.7"));
        assert!(yaml.contains("hidden_size: 256"));
        assert!(yaml.contains("num_layers: 4"));
        assert!(yaml.contains("epochs: 3"));
    }

    #[test]
    fn test_distillation_config_validate_valid() {
        let config = DistillationConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_distillation_config_validate_invalid_temperature() {
        let config = DistillationConfig::default().with_temperature(0.0);
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("temperature"));
    }

    #[test]
    fn test_distillation_config_validate_invalid_alpha() {
        let config = DistillationConfig::default().with_alpha(1.5);
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("alpha"));
    }

    #[test]
    fn test_distillation_config_validate_invalid_teachers() {
        let config = DistillationConfig::default().with_teachers(0);
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("teachers"));
    }

    #[test]
    fn test_distillation_config_validate_invalid_hidden_size() {
        let mut config = DistillationConfig::default();
        config.student.hidden_size = 0;
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("hidden_size"));
    }

    #[test]
    fn test_distillation_config_validate_invalid_layers() {
        let mut config = DistillationConfig::default();
        config.student.num_layers = 0;
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("num_layers"));
    }

    #[test]
    fn test_distillation_config_validate_invalid_epochs() {
        let mut config = DistillationConfig::default();
        config.training.epochs = 0;
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("epochs"));
    }

    #[test]
    fn test_distillation_config_validate_invalid_lr() {
        let mut config = DistillationConfig::default();
        config.training.learning_rate = 0.0;
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("learning_rate"));
    }

    #[test]
    fn test_run_placeholder() {
        let config = DistillationConfig::default();
        let result = config.run_placeholder();

        assert_eq!(result.teacher_count, 1);
        assert!((result.temperature - 3.0).abs() < f32::EPSILON);
        assert!((result.alpha - 0.7).abs() < f32::EPSILON);
        // One history entry per default epoch, ending below the starting loss.
        assert_eq!(result.loss_history.len(), 3);
        assert!(result.final_loss < 2.6);
        assert_eq!(result.status, "placeholder");
    }

    #[test]
    fn test_distillation_result_serde() {
        let result = DistillationResult {
            final_loss: 0.5,
            loss_history: vec![1.0, 0.75, 0.5],
            teacher_count: 2,
            student_config: StudentConfig::default(),
            temperature: 3.0,
            alpha: 0.7,
            status: "complete".to_string(),
            note: "test".to_string(),
        };

        let json = serde_json::to_string(&result).unwrap();
        let parsed: DistillationResult = serde_json::from_str(&json).unwrap();

        assert!((parsed.final_loss - 0.5).abs() < f32::EPSILON);
        assert_eq!(parsed.teacher_count, 2);
        assert_eq!(parsed.loss_history.len(), 3);
    }

    #[test]
    fn test_student_config_default() {
        let config = StudentConfig::default();

        assert_eq!(config.model_type, "distilled_student");
        assert_eq!(config.hidden_size, 256);
        assert_eq!(config.num_layers, 4);
        assert_eq!(config.num_classes, 18);
    }

    #[test]
    fn test_distill_training_config_default() {
        let config = DistillTrainingConfig::default();

        assert_eq!(config.epochs, 3);
        assert_eq!(config.batch_size, 32);
        assert!((config.learning_rate - 1e-4).abs() < f64::EPSILON);
        assert!((config.grad_clip - 1.0).abs() < f32::EPSILON);
        assert!(!config.fp16);
    }

    #[test]
    fn test_distillation_config_debug() {
        let config = DistillationConfig::default();
        let debug = format!("{config:?}");
        assert!(debug.contains("DistillationConfig"));
        assert!(debug.contains("temperature"));
    }

    #[test]
    fn test_distillation_config_clone() {
        let config = DistillationConfig::default();
        let cloned = config.clone();
        assert!((cloned.temperature - config.temperature).abs() < f32::EPSILON);
        assert_eq!(cloned.num_teachers, config.num_teachers);
    }

    #[test]
    fn test_loss_history_decreasing() {
        let config = DistillationConfig::new().with_training(DistillTrainingConfig {
            epochs: 5,
            ..Default::default()
        });
        let result = config.run_placeholder();

        // Every recorded loss must be strictly below its predecessor.
        for i in 1..result.loss_history.len() {
            assert!(result.loss_history[i] < result.loss_history[i - 1]);
        }
    }
}