// verificar/ml/entrenar.rs
1//! Entrenar LLM fine-tuning integration
2//!
3//! Exports verified transpilation tuples for LoRA fine-tuning with entrenar.
4//! See VERIFICAR-090.
5//!
6//! # Knowledge Distillation
7//!
8//! From spec Section 5.4: Multi-teacher distillation via temperature-scaled
9//! KL divergence (Hinton et al. 2015).
10//!
11//! ```text
12//! L_distill = α * KL(softmax(z_s/T) || softmax(z_t/T)) + (1-α) * L_CE
13//! ```
14
15use crate::data::VerifiedTuple;
16use serde::{Deserialize, Serialize};
17use std::path::Path;
18
/// Training example for code-to-code translation.
///
/// One prompt/completion pair derived from a `VerifiedTuple`, in the shape
/// consumed by LLM fine-tuning pipelines (raw code is kept alongside the
/// rendered prompt so exporters can re-template without re-verifying).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeTranslationExample {
    /// Unique identifier
    pub id: String,
    /// Source language
    pub source_language: String,
    /// Target language
    pub target_language: String,
    /// Source code
    pub source_code: String,
    /// Target code (correct translation)
    pub target_code: String,
    /// Prompt for LLM (formatted input, rendered via a `PromptTemplate`)
    pub prompt: String,
    /// Completion for LLM (expected output)
    pub completion: String,
}
37
/// Prompt template for code translation.
///
/// A template is a named user-prompt string with `{source_lang}`,
/// `{target_lang}` and `{source_code}` placeholders, plus an optional
/// system prompt and a flag controlling fenced language tags in completions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
    /// Template name
    pub name: String,
    /// System prompt (optional)
    pub system: Option<String>,
    /// User prompt template with placeholders: {source_lang}, {target_lang}, {source_code}
    pub user_template: String,
    /// Whether to include language tags (``` fences) around the completion
    pub include_lang_tags: bool,
}
50
51impl Default for PromptTemplate {
52    fn default() -> Self {
53        Self::instruction_following()
54    }
55}
56
57impl PromptTemplate {
58    /// Instruction-following style prompt (Alpaca/Vicuna format)
59    #[must_use]
60    pub fn instruction_following() -> Self {
61        Self {
62            name: "instruction".to_string(),
63            system: Some("You are an expert code translator.".to_string()),
64            user_template: "Translate the following {source_lang} code to {target_lang}:\n\n```{source_lang}\n{source_code}\n```".to_string(),
65            include_lang_tags: true,
66        }
67    }
68
69    /// Chat-style prompt (ChatML format)
70    #[must_use]
71    pub fn chat_style() -> Self {
72        Self {
73            name: "chat".to_string(),
74            system: Some("You are a helpful assistant that translates code between programming languages.".to_string()),
75            user_template: "Please convert this {source_lang} code to idiomatic {target_lang}:\n\n{source_code}".to_string(),
76            include_lang_tags: false,
77        }
78    }
79
80    /// Completion-style prompt (minimal, for base models)
81    #[must_use]
82    pub fn completion_style() -> Self {
83        Self {
84            name: "completion".to_string(),
85            system: None,
86            user_template: "# {source_lang}\n{source_code}\n\n# {target_lang}\n".to_string(),
87            include_lang_tags: false,
88        }
89    }
90
91    /// Format a prompt using this template
92    #[must_use]
93    pub fn format(&self, source_lang: &str, target_lang: &str, source_code: &str) -> String {
94        self.user_template
95            .replace("{source_lang}", source_lang)
96            .replace("{target_lang}", target_lang)
97            .replace("{source_code}", source_code)
98    }
99
100    /// Format completion (target code with optional language tag)
101    #[must_use]
102    pub fn format_completion(&self, target_lang: &str, target_code: &str) -> String {
103        if self.include_lang_tags {
104            format!("```{target_lang}\n{target_code}\n```")
105        } else {
106            target_code.to_string()
107        }
108    }
109}
110
/// Export configuration for entrenar.
///
/// Controls the output format, prompt rendering, train/val split and
/// optional cap on the number of exported examples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportConfig {
    /// Output format: "json", "jsonl", "parquet"
    pub format: ExportFormat,
    /// Prompt template to use
    pub template: PromptTemplate,
    /// Train/val split ratio (0.0 to 1.0); fraction of examples going to train
    pub train_ratio: f64,
    /// Random seed for splitting (split is deterministic given the seed)
    pub seed: u64,
    /// Maximum examples to export (None for all)
    pub max_examples: Option<usize>,
}
125
126impl Default for ExportConfig {
127    fn default() -> Self {
128        Self {
129            format: ExportFormat::Jsonl,
130            template: PromptTemplate::default(),
131            train_ratio: 0.9,
132            seed: 42,
133            max_examples: None,
134        }
135    }
136}
137
/// Export format.
///
/// Serialized in lowercase ("json", "jsonl", "parquet") via serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ExportFormat {
    /// Single JSON array
    Json,
    /// JSON Lines (one object per line)
    Jsonl,
    /// Apache Parquet (currently exported as JSONL — see `EntrenarExporter::export`)
    Parquet,
}
149
/// Export statistics returned by `EntrenarExporter::export`.
#[derive(Debug, Clone, Default)]
pub struct ExportStats {
    /// Total examples exported
    pub total: usize,
    /// Training examples
    pub train_count: usize,
    /// Validation examples
    pub val_count: usize,
    /// Average source code length (bytes)
    pub avg_source_len: f64,
    /// Average target code length (bytes)
    pub avg_target_len: f64,
}
164
165// ============================================================================
166// Knowledge Distillation Configuration (Spec Section 5.4)
167// ============================================================================
168
/// Configuration for knowledge distillation training
///
/// Implements multi-teacher distillation via temperature-scaled KL divergence
/// (Hinton et al. 2015). The loss function is:
///
/// ```text
/// L_distill = α * KL(softmax(z_s/T) || softmax(z_t/T)) + (1-α) * L_CE
/// ```
///
/// Where:
/// - `T` is the temperature (higher = softer distributions)
/// - `α` is the balance between distillation and cross-entropy loss
/// - `z_s` and `z_t` are student and teacher logits
///
/// Defaults (see `Default`): T = 3.0, α = 0.7, one teacher.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistillationConfig {
    /// Temperature for softmax (higher = softer probabilities)
    /// Typical values: 1.0-10.0. Default: 3.0
    pub temperature: f32,

    /// Balance between distillation loss and CE loss
    /// α=1.0 means pure distillation, α=0.0 means pure CE
    /// Typical values: 0.5-0.9. Default: 0.7
    pub alpha: f32,

    /// Number of teacher models for ensemble distillation
    pub num_teachers: usize,

    /// Student model configuration
    pub student: StudentConfig,

    /// Training hyperparameters
    pub training: DistillTrainingConfig,

    /// Output directory for distilled model
    pub output_dir: std::path::PathBuf,
}
205
206impl Default for DistillationConfig {
207    fn default() -> Self {
208        Self {
209            temperature: 3.0,
210            alpha: 0.7,
211            num_teachers: 1,
212            student: StudentConfig::default(),
213            training: DistillTrainingConfig::default(),
214            output_dir: std::path::PathBuf::from("distilled_model"),
215        }
216    }
217}
218
/// Student model architecture configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StudentConfig {
    /// Model type identifier
    pub model_type: String,
    /// Hidden dimension size
    pub hidden_size: usize,
    /// Number of transformer layers
    pub num_layers: usize,
    /// Number of output classes (for classification)
    pub num_classes: usize,
}
231
232impl Default for StudentConfig {
233    fn default() -> Self {
234        Self {
235            model_type: "distilled_student".to_string(),
236            hidden_size: 256,
237            num_layers: 4,
238            num_classes: 18, // 18 defect categories from org-intel
239        }
240    }
241}
242
/// Training configuration for distillation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistillTrainingConfig {
    /// Number of training epochs
    pub epochs: usize,
    /// Batch size
    pub batch_size: usize,
    /// Learning rate
    pub learning_rate: f64,
    /// Gradient clipping norm
    pub grad_clip: f32,
    /// Whether to use mixed precision training
    pub fp16: bool,
}
257
258impl Default for DistillTrainingConfig {
259    fn default() -> Self {
260        Self {
261            epochs: 3,
262            batch_size: 32,
263            learning_rate: 1e-4,
264            grad_clip: 1.0,
265            fp16: false,
266        }
267    }
268}
269
/// Result from distillation training.
///
/// Echoes the configuration used (temperature, alpha, student config) so a
/// result is self-describing when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistillationResult {
    /// Final distillation loss
    pub final_loss: f32,
    /// Loss history per epoch
    pub loss_history: Vec<f32>,
    /// Number of teachers used
    pub teacher_count: usize,
    /// Student model configuration
    pub student_config: StudentConfig,
    /// Temperature used
    pub temperature: f32,
    /// Alpha used
    pub alpha: f32,
    /// Training status (e.g. "placeholder")
    pub status: String,
    /// Additional notes
    pub note: String,
}
290
291impl DistillationConfig {
292    /// Create a new distillation config
293    #[must_use]
294    pub fn new() -> Self {
295        Self::default()
296    }
297
298    /// Builder: set temperature
299    #[must_use]
300    pub fn with_temperature(mut self, temperature: f32) -> Self {
301        self.temperature = temperature;
302        self
303    }
304
305    /// Builder: set alpha (distillation weight)
306    #[must_use]
307    pub fn with_alpha(mut self, alpha: f32) -> Self {
308        self.alpha = alpha;
309        self
310    }
311
312    /// Builder: set number of teachers
313    #[must_use]
314    pub fn with_teachers(mut self, num_teachers: usize) -> Self {
315        self.num_teachers = num_teachers;
316        self
317    }
318
319    /// Builder: set student config
320    #[must_use]
321    pub fn with_student(mut self, student: StudentConfig) -> Self {
322        self.student = student;
323        self
324    }
325
326    /// Builder: set training config
327    #[must_use]
328    pub fn with_training(mut self, training: DistillTrainingConfig) -> Self {
329        self.training = training;
330        self
331    }
332
333    /// Builder: set output directory
334    #[must_use]
335    pub fn with_output_dir(mut self, output_dir: impl Into<std::path::PathBuf>) -> Self {
336        self.output_dir = output_dir.into();
337        self
338    }
339
340    /// Generate YAML configuration file for entrenar distillation
341    #[must_use]
342    pub fn to_yaml(&self) -> String {
343        format!(
344            "# Entrenar Distillation Config\n\
345             # Generated by verificar distill\n\
346             \n\
347             model:\n\
348             \x20 type: student\n\
349             \x20 hidden_size: {}\n\
350             \x20 num_layers: {}\n\
351             \n\
352             distillation:\n\
353             \x20 temperature: {}\n\
354             \x20 alpha: {}\n\
355             \x20 num_teachers: {}\n\
356             \n\
357             training:\n\
358             \x20 epochs: {}\n\
359             \x20 batch_size: {}\n\
360             \x20 learning_rate: {:e}\n\
361             \n\
362             data:\n\
363             \x20 teacher_logits: \"/tmp/teacher_logits\"\n\
364             \x20 output_dir: {:?}\n",
365            self.student.hidden_size,
366            self.student.num_layers,
367            self.temperature,
368            self.alpha,
369            self.num_teachers,
370            self.training.epochs,
371            self.training.batch_size,
372            self.training.learning_rate,
373            self.output_dir.display()
374        )
375    }
376
377    /// Run placeholder distillation (simulates training)
378    ///
379    /// Full distillation requires entrenar LLM feature and teacher model weights.
380    /// This returns a placeholder result for testing the pipeline.
381    #[must_use]
382    pub fn run_placeholder(&self) -> DistillationResult {
383        // Simulate decreasing loss over epochs
384        let mut loss_history = Vec::with_capacity(self.training.epochs);
385        let mut loss = 2.6_f32;
386
387        for _ in 0..self.training.epochs {
388            loss *= 0.75; // Simulate 25% improvement per epoch
389            loss_history.push(loss);
390        }
391
392        DistillationResult {
393            final_loss: loss,
394            loss_history,
395            teacher_count: self.num_teachers,
396            student_config: self.student.clone(),
397            temperature: self.temperature,
398            alpha: self.alpha,
399            status: "placeholder".to_string(),
400            note: "Full distillation requires entrenar llm feature and teacher model weights"
401                .to_string(),
402        }
403    }
404
405    /// Validate configuration parameters
406    ///
407    /// # Errors
408    ///
409    /// Returns error if parameters are invalid
410    pub fn validate(&self) -> Result<(), String> {
411        if self.temperature <= 0.0 {
412            return Err("temperature must be positive".to_string());
413        }
414        if !(0.0..=1.0).contains(&self.alpha) {
415            return Err("alpha must be in [0.0, 1.0]".to_string());
416        }
417        if self.num_teachers == 0 {
418            return Err("num_teachers must be at least 1".to_string());
419        }
420        if self.student.hidden_size == 0 {
421            return Err("hidden_size must be positive".to_string());
422        }
423        if self.student.num_layers == 0 {
424            return Err("num_layers must be at least 1".to_string());
425        }
426        if self.training.epochs == 0 {
427            return Err("epochs must be at least 1".to_string());
428        }
429        if self.training.learning_rate <= 0.0 {
430            return Err("learning_rate must be positive".to_string());
431        }
432        Ok(())
433    }
434}
435
/// Exporter for entrenar training data.
///
/// Wraps an [`ExportConfig`] and turns `VerifiedTuple`s into
/// prompt/completion files on disk.
#[derive(Debug)]
pub struct EntrenarExporter {
    // Export settings: format, template, split ratio, seed, example cap.
    config: ExportConfig,
}
441
442impl EntrenarExporter {
443    /// Create a new exporter with configuration
444    #[must_use]
445    pub fn new(config: ExportConfig) -> Self {
446        Self { config }
447    }
448
449    /// Convert verified tuple to training example
450    #[must_use]
451    pub fn to_example(&self, tuple: &VerifiedTuple, id: &str) -> CodeTranslationExample {
452        let source_lang = tuple.source_language.to_string();
453        let target_lang = tuple.target_language.to_string();
454
455        let prompt = self
456            .config
457            .template
458            .format(&source_lang, &target_lang, &tuple.source_code);
459        let completion = self
460            .config
461            .template
462            .format_completion(&target_lang, &tuple.target_code);
463
464        CodeTranslationExample {
465            id: id.to_string(),
466            source_language: source_lang,
467            target_language: target_lang,
468            source_code: tuple.source_code.clone(),
469            target_code: tuple.target_code.clone(),
470            prompt,
471            completion,
472        }
473    }
474
475    /// Export verified tuples to training data
476    ///
477    /// # Errors
478    ///
479    /// Returns error if export fails
480    pub fn export(
481        &self,
482        tuples: &[VerifiedTuple],
483        output_dir: &Path,
484    ) -> std::io::Result<ExportStats> {
485        let examples: Vec<_> = tuples
486            .iter()
487            .take(self.config.max_examples.unwrap_or(usize::MAX))
488            .enumerate()
489            .map(|(i, t)| self.to_example(t, &format!("ex_{i:06}")))
490            .collect();
491
492        let (train, val) = self.split_train_val(&examples);
493
494        let stats = ExportStats {
495            total: examples.len(),
496            train_count: train.len(),
497            val_count: val.len(),
498            avg_source_len: examples.iter().map(|e| e.source_code.len()).sum::<usize>() as f64
499                / examples.len().max(1) as f64,
500            avg_target_len: examples.iter().map(|e| e.target_code.len()).sum::<usize>() as f64
501                / examples.len().max(1) as f64,
502        };
503
504        std::fs::create_dir_all(output_dir)?;
505
506        match self.config.format {
507            ExportFormat::Json => {
508                self.write_json(&train, &output_dir.join("train.json"))?;
509                self.write_json(&val, &output_dir.join("val.json"))?;
510            }
511            ExportFormat::Jsonl => {
512                self.write_jsonl(&train, &output_dir.join("train.jsonl"))?;
513                self.write_jsonl(&val, &output_dir.join("val.jsonl"))?;
514            }
515            ExportFormat::Parquet => {
516                // Parquet export requires additional dependencies
517                // For now, fall back to JSONL
518                self.write_jsonl(&train, &output_dir.join("train.jsonl"))?;
519                self.write_jsonl(&val, &output_dir.join("val.jsonl"))?;
520            }
521        }
522
523        Ok(stats)
524    }
525
526    /// Split examples into train/val sets
527    fn split_train_val(
528        &self,
529        examples: &[CodeTranslationExample],
530    ) -> (Vec<CodeTranslationExample>, Vec<CodeTranslationExample>) {
531        use std::collections::hash_map::DefaultHasher;
532        use std::hash::{Hash, Hasher};
533
534        let mut train = Vec::new();
535        let mut val = Vec::new();
536
537        for (i, example) in examples.iter().enumerate() {
538            let mut hasher = DefaultHasher::new();
539            (self.config.seed, i).hash(&mut hasher);
540            let hash = hasher.finish();
541
542            #[allow(clippy::cast_sign_loss)]
543            let threshold = (self.config.train_ratio * u64::MAX as f64) as u64;
544
545            if hash < threshold {
546                train.push(example.clone());
547            } else {
548                val.push(example.clone());
549            }
550        }
551
552        (train, val)
553    }
554
555    fn write_json(&self, examples: &[CodeTranslationExample], path: &Path) -> std::io::Result<()> {
556        let json = serde_json::to_string_pretty(examples)
557            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
558        std::fs::write(path, json)
559    }
560
561    fn write_jsonl(&self, examples: &[CodeTranslationExample], path: &Path) -> std::io::Result<()> {
562        use std::io::Write;
563        let mut file = std::fs::File::create(path)?;
564        for example in examples {
565            let line = serde_json::to_string(example)
566                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
567            writeln!(file, "{line}")?;
568        }
569        Ok(())
570    }
571}
572
573/// Generate entrenar YAML config for the exported data
574#[must_use]
575pub fn generate_entrenar_config(data_dir: &Path, output_dir: &Path, lora_rank: usize) -> String {
576    format!(
577        r"# Entrenar configuration for verificar training data
578# Generated by verificar v{}
579
580model:
581  path: codellama-7b.gguf  # Replace with your base model
582  layers: [q_proj, k_proj, v_proj, o_proj]
583
584data:
585  train: {}
586  val: {}
587  batch_size: 4
588  seq_len: 2048
589
590optimizer:
591  name: adamw
592  lr: 0.0001
593  weight_decay: 0.01
594
595lora:
596  rank: {}
597  alpha: {}
598  target_modules: [q_proj, v_proj]
599  dropout: 0.05
600
601training:
602  epochs: 3
603  grad_clip: 1.0
604  lr_scheduler: cosine
605  warmup_steps: 100
606  save_interval: 1
607  output_dir: {}
608",
609        env!("CARGO_PKG_VERSION"),
610        data_dir.join("train.jsonl").display(),
611        data_dir.join("val.jsonl").display(),
612        lora_rank,
613        lora_rank * 2, // alpha = 2 * rank is common
614        output_dir.display()
615    )
616}
617
618#[cfg(test)]
619mod tests {
620    use super::*;
621    use crate::Language;
622
623    fn sample_tuple() -> VerifiedTuple {
624        VerifiedTuple {
625            source_language: Language::Python,
626            target_language: Language::Rust,
627            source_code: "def add(a: int, b: int) -> int:\n    return a + b".to_string(),
628            target_code: "fn add(a: i32, b: i32) -> i32 {\n    a + b\n}".to_string(),
629            is_correct: true,
630            execution_time_ms: 10,
631        }
632    }
633
634    #[test]
635    fn test_prompt_template_instruction() {
636        let template = PromptTemplate::instruction_following();
637        let prompt = template.format("Python", "Rust", "x = 1");
638
639        assert!(prompt.contains("Python"));
640        assert!(prompt.contains("Rust"));
641        assert!(prompt.contains("x = 1"));
642        assert!(prompt.contains("```Python"));
643    }
644
645    #[test]
646    fn test_prompt_template_chat() {
647        let template = PromptTemplate::chat_style();
648        let prompt = template.format("Python", "Rust", "x = 1");
649
650        assert!(prompt.contains("Python"));
651        assert!(prompt.contains("idiomatic Rust"));
652        assert!(!prompt.contains("```")); // No code blocks in chat style
653    }
654
655    #[test]
656    fn test_prompt_template_completion() {
657        let template = PromptTemplate::completion_style();
658        let prompt = template.format("Python", "Rust", "x = 1");
659
660        assert!(prompt.contains("# Python"));
661        assert!(prompt.contains("# Rust"));
662    }
663
664    #[test]
665    fn test_format_completion_with_tags() {
666        let template = PromptTemplate::instruction_following();
667        let completion = template.format_completion("Rust", "fn main() {}");
668
669        assert!(completion.contains("```Rust"));
670        assert!(completion.contains("fn main() {}"));
671    }
672
673    #[test]
674    fn test_format_completion_without_tags() {
675        let template = PromptTemplate::completion_style();
676        let completion = template.format_completion("Rust", "fn main() {}");
677
678        assert_eq!(completion, "fn main() {}");
679        assert!(!completion.contains("```"));
680    }
681
682    #[test]
683    fn test_to_example() {
684        let config = ExportConfig::default();
685        let exporter = EntrenarExporter::new(config);
686        let tuple = sample_tuple();
687
688        let example = exporter.to_example(&tuple, "test_001");
689
690        assert_eq!(example.id, "test_001");
691        assert_eq!(example.source_language, "python");
692        assert_eq!(example.target_language, "rust");
693        assert!(example.prompt.contains("def add"));
694        assert!(example.completion.contains("fn add"));
695    }
696
697    #[test]
698    fn test_export_config_default() {
699        let config = ExportConfig::default();
700
701        assert_eq!(config.format, ExportFormat::Jsonl);
702        assert!((config.train_ratio - 0.9).abs() < f64::EPSILON);
703        assert_eq!(config.seed, 42);
704        assert!(config.max_examples.is_none());
705    }
706
707    #[test]
708    fn test_split_train_val_ratio() {
709        let config = ExportConfig {
710            train_ratio: 0.8,
711            ..Default::default()
712        };
713        let exporter = EntrenarExporter::new(config);
714
715        let examples: Vec<_> = (0..1000)
716            .map(|i| CodeTranslationExample {
717                id: format!("ex_{i}"),
718                source_language: "Python".to_string(),
719                target_language: "Rust".to_string(),
720                source_code: format!("x = {i}"),
721                target_code: format!("let x = {i};"),
722                prompt: String::new(),
723                completion: String::new(),
724            })
725            .collect();
726
727        let (train, val) = exporter.split_train_val(&examples);
728
729        // Should be approximately 80/20 split
730        let train_ratio = train.len() as f64 / examples.len() as f64;
731        assert!(train_ratio > 0.7 && train_ratio < 0.9);
732        assert_eq!(train.len() + val.len(), examples.len());
733    }
734
735    #[test]
736    fn test_split_deterministic() {
737        let config = ExportConfig::default();
738        let exporter = EntrenarExporter::new(config);
739
740        let examples: Vec<_> = (0..100)
741            .map(|i| CodeTranslationExample {
742                id: format!("ex_{i}"),
743                source_language: "Python".to_string(),
744                target_language: "Rust".to_string(),
745                source_code: format!("x = {i}"),
746                target_code: format!("let x = {i};"),
747                prompt: String::new(),
748                completion: String::new(),
749            })
750            .collect();
751
752        let (train1, _) = exporter.split_train_val(&examples);
753        let (train2, _) = exporter.split_train_val(&examples);
754
755        assert_eq!(train1.len(), train2.len());
756    }
757
758    #[test]
759    fn test_generate_entrenar_config() {
760        let config =
761            generate_entrenar_config(Path::new("data/train"), Path::new("outputs/model"), 16);
762
763        assert!(config.contains("lora:"));
764        assert!(config.contains("rank: 16"));
765        assert!(config.contains("alpha: 32"));
766        assert!(config.contains("train.jsonl"));
767        assert!(config.contains("val.jsonl"));
768    }
769
770    #[test]
771    fn test_export_format_serde() {
772        let json = serde_json::to_string(&ExportFormat::Jsonl).unwrap();
773        assert_eq!(json, "\"jsonl\"");
774
775        let parsed: ExportFormat = serde_json::from_str("\"parquet\"").unwrap();
776        assert_eq!(parsed, ExportFormat::Parquet);
777    }
778
    // RED PHASE: Tests that require full entrenar integration.
    // These are deliberately `#[ignore]`d placeholders (TDD red phase);
    // they document intended behavior until the integrations land.

    #[test]
    #[ignore = "requires filesystem setup"]
    fn test_export_to_jsonl() {
        let config = ExportConfig::default();
        let exporter = EntrenarExporter::new(config);
        let tuples = vec![sample_tuple()];

        let dir = tempfile::tempdir().unwrap();
        let stats = exporter.export(&tuples, dir.path()).unwrap();

        assert_eq!(stats.total, 1);
        // The single example lands in either split depending on the seeded hash.
        assert!(dir.path().join("train.jsonl").exists() || dir.path().join("val.jsonl").exists());
    }

    #[test]
    #[ignore = "requires entrenar integration"]
    fn test_export_to_parquet() {
        // TODO: Implement Parquet export
        // let config = ExportConfig { format: ExportFormat::Parquet, ..Default::default() };
        // let exporter = EntrenarExporter::new(config);
        // let stats = exporter.export(&tuples, dir.path()).unwrap();
        // assert!(dir.path().join("train.parquet").exists());
        unimplemented!("Parquet export not yet implemented")
    }

    #[test]
    #[ignore = "requires entrenar integration"]
    fn test_load_in_entrenar() {
        // TODO: Verify exported data loads correctly in entrenar
        // let config = entrenar::config::load_config("train_config.yaml").unwrap();
        // assert!(config.data.train.exists());
        unimplemented!("Entrenar integration test not yet implemented")
    }

    #[test]
    #[ignore = "requires LLM evaluation"]
    fn test_prompt_quality() {
        // TODO: Evaluate prompt quality with actual LLM
        // - Measure translation accuracy
        // - Compare different prompt templates
        // - Validate on held-out test set
        unimplemented!("LLM evaluation not yet implemented")
    }
824
825    // ========== DISTILLATION CONFIG TESTS ==========
826
827    #[test]
828    fn test_distillation_config_default() {
829        let config = DistillationConfig::default();
830
831        assert!((config.temperature - 3.0).abs() < f32::EPSILON);
832        assert!((config.alpha - 0.7).abs() < f32::EPSILON);
833        assert_eq!(config.num_teachers, 1);
834        assert_eq!(config.student.hidden_size, 256);
835        assert_eq!(config.student.num_layers, 4);
836        assert_eq!(config.student.num_classes, 18);
837        assert_eq!(config.training.epochs, 3);
838    }
839
840    #[test]
841    fn test_distillation_config_builder() {
842        let config = DistillationConfig::new()
843            .with_temperature(5.0)
844            .with_alpha(0.9)
845            .with_teachers(3)
846            .with_output_dir("/tmp/model");
847
848        assert!((config.temperature - 5.0).abs() < f32::EPSILON);
849        assert!((config.alpha - 0.9).abs() < f32::EPSILON);
850        assert_eq!(config.num_teachers, 3);
851        assert_eq!(config.output_dir.to_str().unwrap(), "/tmp/model");
852    }
853
854    #[test]
855    fn test_distillation_config_with_student() {
856        let student = StudentConfig {
857            model_type: "custom".to_string(),
858            hidden_size: 512,
859            num_layers: 8,
860            num_classes: 10,
861        };
862
863        let config = DistillationConfig::new().with_student(student);
864
865        assert_eq!(config.student.model_type, "custom");
866        assert_eq!(config.student.hidden_size, 512);
867        assert_eq!(config.student.num_layers, 8);
868        assert_eq!(config.student.num_classes, 10);
869    }
870
871    #[test]
872    fn test_distillation_config_with_training() {
873        let training = DistillTrainingConfig {
874            epochs: 10,
875            batch_size: 64,
876            learning_rate: 5e-5,
877            grad_clip: 0.5,
878            fp16: true,
879        };
880
881        let config = DistillationConfig::new().with_training(training);
882
883        assert_eq!(config.training.epochs, 10);
884        assert_eq!(config.training.batch_size, 64);
885        assert!((config.training.learning_rate - 5e-5).abs() < f64::EPSILON);
886        assert!(config.training.fp16);
887    }
888
889    #[test]
890    fn test_distillation_config_to_yaml() {
891        let config = DistillationConfig::default();
892        let yaml = config.to_yaml();
893
894        assert!(yaml.contains("temperature: 3"));
895        assert!(yaml.contains("alpha: 0.7"));
896        assert!(yaml.contains("hidden_size: 256"));
897        assert!(yaml.contains("num_layers: 4"));
898        assert!(yaml.contains("epochs: 3"));
899    }
900
901    #[test]
902    fn test_distillation_config_validate_valid() {
903        let config = DistillationConfig::default();
904        assert!(config.validate().is_ok());
905    }
906
907    #[test]
908    fn test_distillation_config_validate_invalid_temperature() {
909        let config = DistillationConfig::default().with_temperature(0.0);
910        let result = config.validate();
911        assert!(result.is_err());
912        assert!(result.unwrap_err().contains("temperature"));
913    }
914
915    #[test]
916    fn test_distillation_config_validate_invalid_alpha() {
917        let config = DistillationConfig::default().with_alpha(1.5);
918        let result = config.validate();
919        assert!(result.is_err());
920        assert!(result.unwrap_err().contains("alpha"));
921    }
922
923    #[test]
924    fn test_distillation_config_validate_invalid_teachers() {
925        let config = DistillationConfig::default().with_teachers(0);
926        let result = config.validate();
927        assert!(result.is_err());
928        assert!(result.unwrap_err().contains("teachers"));
929    }
930
931    #[test]
932    fn test_distillation_config_validate_invalid_hidden_size() {
933        let mut config = DistillationConfig::default();
934        config.student.hidden_size = 0;
935        let result = config.validate();
936        assert!(result.is_err());
937        assert!(result.unwrap_err().contains("hidden_size"));
938    }
939
940    #[test]
941    fn test_distillation_config_validate_invalid_layers() {
942        let mut config = DistillationConfig::default();
943        config.student.num_layers = 0;
944        let result = config.validate();
945        assert!(result.is_err());
946        assert!(result.unwrap_err().contains("num_layers"));
947    }
948
949    #[test]
950    fn test_distillation_config_validate_invalid_epochs() {
951        let mut config = DistillationConfig::default();
952        config.training.epochs = 0;
953        let result = config.validate();
954        assert!(result.is_err());
955        assert!(result.unwrap_err().contains("epochs"));
956    }
957
958    #[test]
959    fn test_distillation_config_validate_invalid_lr() {
960        let mut config = DistillationConfig::default();
961        config.training.learning_rate = 0.0;
962        let result = config.validate();
963        assert!(result.is_err());
964        assert!(result.unwrap_err().contains("learning_rate"));
965    }
966
967    #[test]
968    fn test_run_placeholder() {
969        let config = DistillationConfig::default();
970        let result = config.run_placeholder();
971
972        assert_eq!(result.teacher_count, 1);
973        assert!((result.temperature - 3.0).abs() < f32::EPSILON);
974        assert!((result.alpha - 0.7).abs() < f32::EPSILON);
975        assert_eq!(result.loss_history.len(), 3); // 3 epochs
976        assert!(result.final_loss < 2.6); // Should decrease
977        assert_eq!(result.status, "placeholder");
978    }
979
980    #[test]
981    fn test_distillation_result_serde() {
982        let result = DistillationResult {
983            final_loss: 0.5,
984            loss_history: vec![1.0, 0.75, 0.5],
985            teacher_count: 2,
986            student_config: StudentConfig::default(),
987            temperature: 3.0,
988            alpha: 0.7,
989            status: "complete".to_string(),
990            note: "test".to_string(),
991        };
992
993        let json = serde_json::to_string(&result).unwrap();
994        let parsed: DistillationResult = serde_json::from_str(&json).unwrap();
995
996        assert!((parsed.final_loss - 0.5).abs() < f32::EPSILON);
997        assert_eq!(parsed.teacher_count, 2);
998        assert_eq!(parsed.loss_history.len(), 3);
999    }
1000
1001    #[test]
1002    fn test_student_config_default() {
1003        let config = StudentConfig::default();
1004
1005        assert_eq!(config.model_type, "distilled_student");
1006        assert_eq!(config.hidden_size, 256);
1007        assert_eq!(config.num_layers, 4);
1008        assert_eq!(config.num_classes, 18);
1009    }
1010
1011    #[test]
1012    fn test_distill_training_config_default() {
1013        let config = DistillTrainingConfig::default();
1014
1015        assert_eq!(config.epochs, 3);
1016        assert_eq!(config.batch_size, 32);
1017        assert!((config.learning_rate - 1e-4).abs() < f64::EPSILON);
1018        assert!((config.grad_clip - 1.0).abs() < f32::EPSILON);
1019        assert!(!config.fp16);
1020    }
1021
1022    #[test]
1023    fn test_distillation_config_debug() {
1024        let config = DistillationConfig::default();
1025        let debug = format!("{config:?}");
1026        assert!(debug.contains("DistillationConfig"));
1027        assert!(debug.contains("temperature"));
1028    }
1029
1030    #[test]
1031    fn test_distillation_config_clone() {
1032        let config = DistillationConfig::default();
1033        let cloned = config.clone();
1034        assert!((cloned.temperature - config.temperature).abs() < f32::EPSILON);
1035        assert_eq!(cloned.num_teachers, config.num_teachers);
1036    }
1037
1038    #[test]
1039    fn test_loss_history_decreasing() {
1040        let config = DistillationConfig::new().with_training(DistillTrainingConfig {
1041            epochs: 5,
1042            ..Default::default()
1043        });
1044        let result = config.run_placeholder();
1045
1046        // Verify loss decreases over epochs
1047        for i in 1..result.loss_history.len() {
1048            assert!(result.loss_history[i] < result.loss_history[i - 1]);
1049        }
1050    }
1051}