Skip to main content

verificar/ml/
rich_labeling.rs

1//! Rich Labeling - Beyond binary correctness
2//!
3//! Extracts maximum signal from each oracle invocation with rich multi-task labels.
4//!
5//! # Error Categories
6//!
7//! | Category | Description | Example |
8//! |----------|-------------|---------|
9//! | TypeMismatch | Type system incompatibility | `int` vs `i32` semantics |
10//! | OwnershipViolation | Rust borrow checker errors | Move after borrow |
11//! | LifetimeError | Lifetime annotation issues | Missing lifetime bounds |
12//! | PanicDivergence | Source continues, target panics | Divide by zero |
13//! | OutputMismatch | Different output values | Off-by-one errors |
14//!
15//! # Reference
16//! - VER-053: Rich Labeling - Beyond binary correctness
17
18use serde::{Deserialize, Serialize};
19
20/// Error category taxonomy for transpilation failures
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
22pub enum ErrorCategory {
23    /// Type system incompatibility
24    TypeMismatch,
25    /// Rust ownership/borrow checker errors
26    OwnershipViolation,
27    /// Lifetime annotation issues
28    LifetimeError,
29    /// Source continues, target panics
30    PanicDivergence,
31    /// Different output values
32    OutputMismatch,
33    /// Compilation error (syntax, missing imports)
34    CompilationError,
35    /// Runtime error (not panic)
36    RuntimeError,
37    /// Timeout or resource exhaustion
38    ResourceExhaustion,
39    /// Unknown or uncategorized error
40    Unknown,
41}
42
43impl Default for ErrorCategory {
44    fn default() -> Self {
45        Self::Unknown
46    }
47}
48
49impl ErrorCategory {
50    /// All error categories
51    #[must_use]
52    pub fn all() -> &'static [Self] {
53        &[
54            Self::TypeMismatch,
55            Self::OwnershipViolation,
56            Self::LifetimeError,
57            Self::PanicDivergence,
58            Self::OutputMismatch,
59            Self::CompilationError,
60            Self::RuntimeError,
61            Self::ResourceExhaustion,
62            Self::Unknown,
63        ]
64    }
65
66    /// Severity weight for prioritization (higher = more important to fix)
67    #[must_use]
68    pub fn severity(&self) -> f32 {
69        match self {
70            Self::PanicDivergence => 1.0,    // Critical: silent failures
71            Self::OwnershipViolation => 0.9, // Rust-specific complexity
72            Self::LifetimeError => 0.85,     // Rust-specific complexity
73            Self::TypeMismatch => 0.8,       // Common transpilation issue
74            Self::OutputMismatch => 0.7,     // Semantic error
75            Self::RuntimeError => 0.6,       // Detectable at runtime
76            Self::CompilationError => 0.5,   // Detectable at compile time
77            Self::ResourceExhaustion => 0.3, // Often environment-specific
78            Self::Unknown => 0.2,            // Needs investigation
79        }
80    }
81
82    /// Classify error from error message
83    #[must_use]
84    pub fn classify(error_msg: &str) -> Self {
85        let msg = error_msg.to_lowercase();
86
87        // Ownership/borrow errors
88        if msg.contains("borrow")
89            || msg.contains("move")
90            || msg.contains("cannot borrow")
91            || msg.contains("value borrowed")
92        {
93            return Self::OwnershipViolation;
94        }
95
96        // Lifetime errors
97        if msg.contains("lifetime")
98            || msg.contains("does not live long enough")
99            || msg.contains("'a")
100        {
101            return Self::LifetimeError;
102        }
103
104        // Type errors
105        if msg.contains("type mismatch")
106            || msg.contains("expected type")
107            || msg.contains("mismatched types")
108            || msg.contains("cannot convert")
109        {
110            return Self::TypeMismatch;
111        }
112
113        // Panic/divergence
114        if msg.contains("panic")
115            || msg.contains("unwrap")
116            || msg.contains("assertion failed")
117            || msg.contains("index out of bounds")
118        {
119            return Self::PanicDivergence;
120        }
121
122        // Output mismatch
123        if msg.contains("output")
124            || msg.contains("mismatch")
125            || msg.contains("expected")
126            || msg.contains("actual")
127        {
128            return Self::OutputMismatch;
129        }
130
131        // Compilation errors
132        if msg.contains("cannot find")
133            || msg.contains("unresolved")
134            || msg.contains("syntax error")
135            || msg.contains("parse error")
136        {
137            return Self::CompilationError;
138        }
139
140        // Runtime errors
141        if msg.contains("runtime") || msg.contains("overflow") || msg.contains("division by zero") {
142            return Self::RuntimeError;
143        }
144
145        // Resource exhaustion
146        if msg.contains("timeout")
147            || msg.contains("memory")
148            || msg.contains("stack overflow")
149            || msg.contains("resource")
150        {
151            return Self::ResourceExhaustion;
152        }
153
154        Self::Unknown
155    }
156
157    /// Convert to one-hot encoding (9 categories)
158    #[must_use]
159    pub fn to_one_hot(&self) -> [f32; 9] {
160        let mut one_hot = [0.0f32; 9];
161        one_hot[*self as usize] = 1.0;
162        one_hot
163    }
164
165    /// Create from one-hot encoding
166    #[must_use]
167    pub fn from_one_hot(one_hot: &[f32; 9]) -> Self {
168        one_hot
169            .iter()
170            .enumerate()
171            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
172            .map_or(Self::Unknown, |(i, _)| Self::from_index(i))
173    }
174
175    fn from_index(idx: usize) -> Self {
176        match idx {
177            0 => Self::TypeMismatch,
178            1 => Self::OwnershipViolation,
179            2 => Self::LifetimeError,
180            3 => Self::PanicDivergence,
181            4 => Self::OutputMismatch,
182            5 => Self::CompilationError,
183            6 => Self::RuntimeError,
184            7 => Self::ResourceExhaustion,
185            _ => Self::Unknown,
186        }
187    }
188}
189
190/// Soft labels for gradual correctness
191#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
192pub struct SoftLabels {
193    /// Output similarity (0.0 = completely different, 1.0 = identical)
194    pub output_similarity: f32,
195    /// Runtime ratio (target_time / source_time, 1.0 = same speed)
196    pub runtime_ratio: f32,
197    /// Structural similarity of AST
198    pub structural_similarity: f32,
199    /// Semantic correctness confidence
200    pub semantic_confidence: f32,
201    /// Type safety score
202    pub type_safety: f32,
203}
204
205impl SoftLabels {
206    /// Create new soft labels
207    #[must_use]
208    pub fn new() -> Self {
209        Self::default()
210    }
211
212    /// All labels are valid (in [0, 1] range)
213    #[must_use]
214    pub fn is_valid(&self) -> bool {
215        self.output_similarity >= 0.0
216            && self.output_similarity <= 1.0
217            && self.runtime_ratio >= 0.0
218            && self.structural_similarity >= 0.0
219            && self.structural_similarity <= 1.0
220            && self.semantic_confidence >= 0.0
221            && self.semantic_confidence <= 1.0
222            && self.type_safety >= 0.0
223            && self.type_safety <= 1.0
224    }
225
226    /// Convert to array
227    #[must_use]
228    pub fn to_array(&self) -> [f32; 5] {
229        [
230            self.output_similarity,
231            self.runtime_ratio.min(10.0) / 10.0, // Normalize to [0, 1]
232            self.structural_similarity,
233            self.semantic_confidence,
234            self.type_safety,
235        ]
236    }
237
238    /// Create from array
239    #[must_use]
240    pub fn from_array(arr: [f32; 5]) -> Self {
241        Self {
242            output_similarity: arr[0],
243            runtime_ratio: arr[1] * 10.0, // Denormalize
244            structural_similarity: arr[2],
245            semantic_confidence: arr[3],
246            type_safety: arr[4],
247        }
248    }
249
250    /// Overall correctness score (weighted average)
251    #[must_use]
252    pub fn overall_score(&self) -> f32 {
253        let weights = [0.3, 0.1, 0.2, 0.25, 0.15];
254        let arr = self.to_array();
255
256        let weighted_sum: f32 = arr.iter().zip(&weights).map(|(v, w)| v * w).sum();
257        let total_weight: f32 = weights.iter().sum();
258
259        weighted_sum / total_weight
260    }
261}
262
263/// Builder for soft labels
264#[derive(Debug, Default)]
265pub struct SoftLabelsBuilder {
266    labels: SoftLabels,
267}
268
269impl SoftLabelsBuilder {
270    /// Create new builder
271    #[must_use]
272    pub fn new() -> Self {
273        Self::default()
274    }
275
276    /// Set output similarity
277    #[must_use]
278    pub fn output_similarity(mut self, value: f32) -> Self {
279        self.labels.output_similarity = value.clamp(0.0, 1.0);
280        self
281    }
282
283    /// Set runtime ratio
284    #[must_use]
285    pub fn runtime_ratio(mut self, value: f32) -> Self {
286        self.labels.runtime_ratio = value.max(0.0);
287        self
288    }
289
290    /// Set structural similarity
291    #[must_use]
292    pub fn structural_similarity(mut self, value: f32) -> Self {
293        self.labels.structural_similarity = value.clamp(0.0, 1.0);
294        self
295    }
296
297    /// Set semantic confidence
298    #[must_use]
299    pub fn semantic_confidence(mut self, value: f32) -> Self {
300        self.labels.semantic_confidence = value.clamp(0.0, 1.0);
301        self
302    }
303
304    /// Set type safety
305    #[must_use]
306    pub fn type_safety(mut self, value: f32) -> Self {
307        self.labels.type_safety = value.clamp(0.0, 1.0);
308        self
309    }
310
311    /// Build soft labels
312    #[must_use]
313    pub fn build(self) -> SoftLabels {
314        self.labels
315    }
316}
317
318/// Multi-task label schema
319#[derive(Debug, Clone, Default, Serialize, Deserialize)]
320pub struct RichLabel {
321    /// Binary correctness (ground truth)
322    pub is_correct: bool,
323    /// Error category (if not correct)
324    pub error_category: Option<ErrorCategory>,
325    /// Error message (if not correct)
326    pub error_message: Option<String>,
327    /// Soft labels for gradual correctness
328    pub soft_labels: SoftLabels,
329    /// AST diff summary
330    pub ast_diff: Option<AstDiff>,
331    /// Execution metrics
332    pub execution_metrics: ExecutionMetrics,
333}
334
335impl RichLabel {
336    /// Create for correct sample
337    #[must_use]
338    pub fn correct(soft_labels: SoftLabels) -> Self {
339        Self {
340            is_correct: true,
341            error_category: None,
342            error_message: None,
343            soft_labels,
344            ast_diff: None,
345            execution_metrics: ExecutionMetrics::default(),
346        }
347    }
348
349    /// Create for incorrect sample
350    #[must_use]
351    pub fn incorrect(category: ErrorCategory, message: String, soft_labels: SoftLabels) -> Self {
352        Self {
353            is_correct: false,
354            error_category: Some(category),
355            error_message: Some(message),
356            soft_labels,
357            ast_diff: None,
358            execution_metrics: ExecutionMetrics::default(),
359        }
360    }
361
362    /// Set AST diff
363    #[must_use]
364    pub fn with_ast_diff(mut self, diff: AstDiff) -> Self {
365        self.ast_diff = Some(diff);
366        self
367    }
368
369    /// Set execution metrics
370    #[must_use]
371    pub fn with_metrics(mut self, metrics: ExecutionMetrics) -> Self {
372        self.execution_metrics = metrics;
373        self
374    }
375
376    /// Convert to flat feature vector for ML
377    #[must_use]
378    pub fn to_feature_vector(&self) -> Vec<f32> {
379        let mut features = Vec::with_capacity(20);
380
381        // Binary label
382        features.push(if self.is_correct { 1.0 } else { 0.0 });
383
384        // Error category one-hot (9 values)
385        let one_hot = self
386            .error_category
387            .unwrap_or(ErrorCategory::Unknown)
388            .to_one_hot();
389        features.extend_from_slice(&one_hot);
390
391        // Soft labels (5 values)
392        features.extend_from_slice(&self.soft_labels.to_array());
393
394        // Execution metrics (4 values)
395        features.push(self.execution_metrics.source_time_ms as f32 / 1000.0);
396        features.push(self.execution_metrics.target_time_ms as f32 / 1000.0);
397        features.push(self.execution_metrics.memory_bytes as f32 / 1_000_000.0);
398        features.push(if self.execution_metrics.timeout {
399            1.0
400        } else {
401            0.0
402        });
403
404        features
405    }
406}
407
408/// AST diff summary
409#[derive(Debug, Clone, Default, Serialize, Deserialize)]
410pub struct AstDiff {
411    /// Number of nodes added
412    pub nodes_added: u32,
413    /// Number of nodes removed
414    pub nodes_removed: u32,
415    /// Number of nodes modified
416    pub nodes_modified: u32,
417    /// Structural edit distance
418    pub edit_distance: u32,
419    /// Most common diff type
420    pub primary_change: Option<String>,
421}
422
423impl AstDiff {
424    /// Total number of changes
425    #[must_use]
426    pub fn total_changes(&self) -> u32 {
427        self.nodes_added + self.nodes_removed + self.nodes_modified
428    }
429
430    /// Similarity score (1.0 = identical, 0.0 = completely different)
431    #[must_use]
432    pub fn similarity(&self, total_nodes: u32) -> f32 {
433        if total_nodes == 0 {
434            return 1.0;
435        }
436
437        let changes = self.total_changes();
438        1.0 - (changes as f32 / total_nodes as f32).min(1.0)
439    }
440}
441
442/// Execution metrics
443#[derive(Debug, Clone, Default, Serialize, Deserialize)]
444pub struct ExecutionMetrics {
445    /// Source execution time in milliseconds
446    pub source_time_ms: u64,
447    /// Target execution time in milliseconds
448    pub target_time_ms: u64,
449    /// Memory usage in bytes
450    pub memory_bytes: u64,
451    /// Whether execution timed out
452    pub timeout: bool,
453}
454
455impl ExecutionMetrics {
456    /// Runtime ratio (target / source)
457    #[must_use]
458    pub fn runtime_ratio(&self) -> f32 {
459        if self.source_time_ms == 0 {
460            return 1.0;
461        }
462        self.target_time_ms as f32 / self.source_time_ms as f32
463    }
464}
465
/// Stateless helper that turns raw oracle results into rich labels.
#[derive(Debug, Default)]
pub struct LabelExtractor;
469
470impl LabelExtractor {
471    /// Create new label extractor
472    #[must_use]
473    pub fn new() -> Self {
474        Self
475    }
476
477    /// Extract rich label from oracle result
478    pub fn extract(
479        &self,
480        is_correct: bool,
481        error_msg: Option<&str>,
482        source_output: &str,
483        target_output: &str,
484        source_time_ms: u64,
485        target_time_ms: u64,
486    ) -> RichLabel {
487        let output_similarity = self.compute_output_similarity(source_output, target_output);
488
489        let runtime_ratio = if source_time_ms == 0 {
490            1.0
491        } else {
492            target_time_ms as f32 / source_time_ms as f32
493        };
494
495        let soft_labels = SoftLabelsBuilder::new()
496            .output_similarity(output_similarity)
497            .runtime_ratio(runtime_ratio)
498            .semantic_confidence(if is_correct { 1.0 } else { 0.3 })
499            .type_safety(if is_correct { 1.0 } else { 0.5 })
500            .build();
501
502        let execution_metrics = ExecutionMetrics {
503            source_time_ms,
504            target_time_ms,
505            memory_bytes: 0,
506            timeout: false,
507        };
508
509        if is_correct {
510            RichLabel::correct(soft_labels).with_metrics(execution_metrics)
511        } else {
512            let category = error_msg.map_or(ErrorCategory::Unknown, ErrorCategory::classify);
513            let message = error_msg.unwrap_or("Unknown error").to_string();
514
515            RichLabel::incorrect(category, message, soft_labels).with_metrics(execution_metrics)
516        }
517    }
518
519    fn compute_output_similarity(&self, source: &str, target: &str) -> f32 {
520        if source == target {
521            return 1.0;
522        }
523
524        if source.is_empty() && target.is_empty() {
525            return 1.0;
526        }
527
528        if source.is_empty() || target.is_empty() {
529            return 0.0;
530        }
531
532        // Simple Jaccard similarity on lines
533        let source_lines: std::collections::HashSet<_> = source.lines().collect();
534        let target_lines: std::collections::HashSet<_> = target.lines().collect();
535
536        let intersection = source_lines.intersection(&target_lines).count();
537        let union = source_lines.union(&target_lines).count();
538
539        if union == 0 {
540            1.0
541        } else {
542            intersection as f32 / union as f32
543        }
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with the 0.001 tolerance used throughout this module.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 0.001
    }

    // ========== ErrorCategory Tests ==========

    #[test]
    fn test_error_category_all() {
        assert_eq!(ErrorCategory::all().len(), 9);
    }

    #[test]
    fn test_error_category_default() {
        let fallback = ErrorCategory::default();
        assert_eq!(fallback, ErrorCategory::Unknown);
    }

    #[test]
    fn test_error_category_severity() {
        let panic = ErrorCategory::PanicDivergence.severity();
        let unknown = ErrorCategory::Unknown.severity();
        assert!(panic > unknown);

        let ownership = ErrorCategory::OwnershipViolation.severity();
        let compilation = ErrorCategory::CompilationError.severity();
        assert!(ownership > compilation);
    }

    #[test]
    fn test_error_category_classify_ownership() {
        let borrowed = ErrorCategory::classify("cannot borrow x as mutable");
        assert_eq!(borrowed, ErrorCategory::OwnershipViolation);

        let moved = ErrorCategory::classify("value moved here");
        assert_eq!(moved, ErrorCategory::OwnershipViolation);
    }

    #[test]
    fn test_error_category_classify_lifetime() {
        let category = ErrorCategory::classify("lifetime 'a does not live long enough");
        assert_eq!(category, ErrorCategory::LifetimeError);
    }

    #[test]
    fn test_error_category_classify_type() {
        let category = ErrorCategory::classify("type mismatch: expected i32");
        assert_eq!(category, ErrorCategory::TypeMismatch);
    }

    #[test]
    fn test_error_category_classify_panic() {
        let category = ErrorCategory::classify("thread panicked at index out of bounds");
        assert_eq!(category, ErrorCategory::PanicDivergence);
    }

    #[test]
    fn test_error_category_classify_output() {
        let category = ErrorCategory::classify("output mismatch: expected 5, actual 6");
        assert_eq!(category, ErrorCategory::OutputMismatch);
    }

    #[test]
    fn test_error_category_classify_compilation() {
        let category = ErrorCategory::classify("cannot find value x in scope");
        assert_eq!(category, ErrorCategory::CompilationError);
    }

    #[test]
    fn test_error_category_classify_runtime() {
        let category = ErrorCategory::classify("integer overflow detected");
        assert_eq!(category, ErrorCategory::RuntimeError);
    }

    #[test]
    fn test_error_category_classify_resource() {
        let category = ErrorCategory::classify("execution timeout");
        assert_eq!(category, ErrorCategory::ResourceExhaustion);
    }

    #[test]
    fn test_error_category_classify_unknown() {
        let category = ErrorCategory::classify("some random error");
        assert_eq!(category, ErrorCategory::Unknown);
    }

    #[test]
    fn test_error_category_one_hot() {
        let encoded = ErrorCategory::TypeMismatch.to_one_hot();
        assert_eq!(encoded[0], 1.0);
        assert_eq!(encoded[1], 0.0);
    }

    #[test]
    fn test_error_category_from_one_hot() {
        let encoded = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let decoded = ErrorCategory::from_one_hot(&encoded);
        assert_eq!(decoded, ErrorCategory::OwnershipViolation);
    }

    // ========== SoftLabels Tests ==========

    #[test]
    fn test_soft_labels_default() {
        assert_eq!(SoftLabels::default().output_similarity, 0.0);
    }

    #[test]
    fn test_soft_labels_is_valid() {
        let in_range = SoftLabels {
            output_similarity: 0.8,
            runtime_ratio: 1.2,
            structural_similarity: 0.9,
            semantic_confidence: 0.95,
            type_safety: 1.0,
        };
        assert!(in_range.is_valid());

        let out_of_range = SoftLabels {
            output_similarity: -0.1,
            ..Default::default()
        };
        assert!(!out_of_range.is_valid());
    }

    #[test]
    fn test_soft_labels_to_array() {
        let flattened = SoftLabels {
            output_similarity: 0.8,
            runtime_ratio: 1.5,
            structural_similarity: 0.9,
            semantic_confidence: 0.7,
            type_safety: 1.0,
        }
        .to_array();

        assert_eq!(flattened.len(), 5);
        assert!(close(flattened[0], 0.8));
    }

    #[test]
    fn test_soft_labels_overall_score() {
        let all_ones = SoftLabels {
            output_similarity: 1.0,
            runtime_ratio: 1.0,
            structural_similarity: 1.0,
            semantic_confidence: 1.0,
            type_safety: 1.0,
        };

        // Looser tolerance than `close`: the runtime slot is normalized,
        // so a "perfect" label does not score exactly 1.0.
        assert!((all_ones.overall_score() - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_soft_labels_builder() {
        let built = SoftLabelsBuilder::new()
            .output_similarity(0.9)
            .runtime_ratio(1.1)
            .structural_similarity(0.95)
            .semantic_confidence(0.85)
            .type_safety(1.0)
            .build();

        assert!(close(built.output_similarity, 0.9));
        assert!(close(built.runtime_ratio, 1.1));
    }

    #[test]
    fn test_soft_labels_builder_clamps() {
        let built = SoftLabelsBuilder::new()
            .output_similarity(1.5) // above range -> 1.0
            .semantic_confidence(-0.5) // below range -> 0.0
            .build();

        assert!(close(built.output_similarity, 1.0));
        assert!(close(built.semantic_confidence, 0.0));
    }

    // ========== RichLabel Tests ==========

    #[test]
    fn test_rich_label_correct() {
        let passed = RichLabel::correct(SoftLabels::default());
        assert!(passed.is_correct);
        assert!(passed.error_category.is_none());
    }

    #[test]
    fn test_rich_label_incorrect() {
        let failed = RichLabel::incorrect(
            ErrorCategory::TypeMismatch,
            "Type error".to_string(),
            SoftLabels::default(),
        );
        assert!(!failed.is_correct);
        assert_eq!(failed.error_category, Some(ErrorCategory::TypeMismatch));
    }

    #[test]
    fn test_rich_label_with_ast_diff() {
        let summary = AstDiff {
            nodes_added: 5,
            nodes_removed: 2,
            nodes_modified: 3,
            edit_distance: 10,
            primary_change: Some("FunctionDef".to_string()),
        };

        let annotated = RichLabel::correct(SoftLabels::default()).with_ast_diff(summary);
        assert!(annotated.ast_diff.is_some());
    }

    #[test]
    fn test_rich_label_feature_vector() {
        let all_ones = SoftLabels {
            output_similarity: 1.0,
            runtime_ratio: 1.0,
            structural_similarity: 1.0,
            semantic_confidence: 1.0,
            type_safety: 1.0,
        };

        let vector = RichLabel::correct(all_ones).to_feature_vector();
        assert_eq!(vector.len(), 19); // 1 + 9 + 5 + 4
        assert!(close(vector[0], 1.0)); // is_correct
    }

    // ========== AstDiff Tests ==========

    #[test]
    fn test_ast_diff_total_changes() {
        let summary = AstDiff {
            nodes_added: 5,
            nodes_removed: 3,
            nodes_modified: 2,
            edit_distance: 0,
            primary_change: None,
        };

        assert_eq!(summary.total_changes(), 10);
    }

    #[test]
    fn test_ast_diff_similarity() {
        let summary = AstDiff {
            nodes_added: 2,
            nodes_removed: 0,
            nodes_modified: 0,
            edit_distance: 2,
            primary_change: None,
        };

        assert!(close(summary.similarity(10), 0.8));
    }

    #[test]
    fn test_ast_diff_similarity_empty() {
        let no_changes = AstDiff::default();
        assert!(close(no_changes.similarity(0), 1.0));
    }

    // ========== ExecutionMetrics Tests ==========

    #[test]
    fn test_execution_metrics_runtime_ratio() {
        let timing = ExecutionMetrics {
            source_time_ms: 100,
            target_time_ms: 150,
            memory_bytes: 0,
            timeout: false,
        };

        assert!(close(timing.runtime_ratio(), 1.5));
    }

    #[test]
    fn test_execution_metrics_runtime_ratio_zero() {
        let timing = ExecutionMetrics {
            source_time_ms: 0,
            target_time_ms: 100,
            memory_bytes: 0,
            timeout: false,
        };

        assert!(close(timing.runtime_ratio(), 1.0));
    }

    // ========== LabelExtractor Tests ==========

    #[test]
    fn test_label_extractor_correct() {
        let extracted =
            LabelExtractor::new().extract(true, None, "hello\nworld", "hello\nworld", 100, 100);

        assert!(extracted.is_correct);
        assert!(close(extracted.soft_labels.output_similarity, 1.0));
    }

    #[test]
    fn test_label_extractor_incorrect() {
        let extracted =
            LabelExtractor::new().extract(false, Some("type mismatch error"), "5", "6", 100, 100);

        assert!(!extracted.is_correct);
        assert_eq!(extracted.error_category, Some(ErrorCategory::TypeMismatch));
    }

    #[test]
    fn test_label_extractor_output_similarity() {
        let extractor = LabelExtractor::new();

        // Identical outputs score exactly 1.
        let identical = extractor.extract(true, None, "a\nb\nc", "a\nb\nc", 100, 100);
        assert!(close(identical.soft_labels.output_similarity, 1.0));

        // One differing line lands strictly between 0 and 1.
        let overlapping = extractor.extract(false, None, "a\nb\nc", "a\nb\nd", 100, 100);
        let similarity = overlapping.soft_labels.output_similarity;
        assert!(similarity > 0.0);
        assert!(similarity < 1.0);
    }

    // ========== Debug Tests ==========

    #[test]
    fn test_error_category_debug() {
        let rendered = format!("{:?}", ErrorCategory::TypeMismatch);
        assert!(rendered.contains("TypeMismatch"));
    }

    #[test]
    fn test_soft_labels_debug() {
        let rendered = format!("{:?}", SoftLabels::default());
        assert!(rendered.contains("SoftLabels"));
    }

    #[test]
    fn test_rich_label_debug() {
        let rendered = format!("{:?}", RichLabel::correct(SoftLabels::default()));
        assert!(rendered.contains("RichLabel"));
    }

    #[test]
    fn test_label_extractor_debug() {
        let rendered = format!("{:?}", LabelExtractor::new());
        assert!(rendered.contains("LabelExtractor"));
    }

    // ========== Serialization Tests ==========

    #[test]
    fn test_error_category_serialize() {
        let original = ErrorCategory::OwnershipViolation;
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ErrorCategory = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_soft_labels_serialize() {
        let original = SoftLabelsBuilder::new()
            .output_similarity(0.8)
            .runtime_ratio(1.2)
            .build();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: SoftLabels = serde_json::from_str(&encoded).unwrap();
        assert!(close(original.output_similarity, decoded.output_similarity));
    }

    #[test]
    fn test_rich_label_serialize() {
        let original = RichLabel::incorrect(
            ErrorCategory::TypeMismatch,
            "Error".to_string(),
            SoftLabels::default(),
        );

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RichLabel = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original.is_correct, decoded.is_correct);
        assert_eq!(original.error_category, decoded.error_category);
    }
}
954
/// Property-based tests
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Every category reports a severity inside [0, 1].
        #[test]
        fn prop_severity_bounded(idx in 0usize..9) {
            let weight = ErrorCategory::from_index(idx).severity();
            prop_assert!(weight >= 0.0);
            prop_assert!(weight <= 1.0);
        }

        /// Encoding a category as one-hot and decoding it gives it back.
        #[test]
        fn prop_one_hot_roundtrip(idx in 0usize..9) {
            let category = ErrorCategory::from_index(idx);
            let decoded = ErrorCategory::from_one_hot(&category.to_one_hot());
            prop_assert_eq!(category, decoded);
        }

        /// Labels built from in-range scores pass validation.
        #[test]
        fn prop_soft_labels_structure(
            output_sim in 0.0f32..1.0,
            structural_sim in 0.0f32..1.0,
            semantic_conf in 0.0f32..1.0,
            type_safety in 0.0f32..1.0,
        ) {
            let built = SoftLabelsBuilder::new()
                .output_similarity(output_sim)
                .structural_similarity(structural_sim)
                .semantic_confidence(semantic_conf)
                .type_safety(type_safety)
                .build();

            prop_assert!(built.is_valid());
        }

        /// The overall score stays inside [0, 1] for in-range labels.
        #[test]
        fn prop_overall_score_bounded(
            output_sim in 0.0f32..1.0,
            runtime_ratio in 0.0f32..10.0,
            structural_sim in 0.0f32..1.0,
            semantic_conf in 0.0f32..1.0,
            type_safety in 0.0f32..1.0,
        ) {
            let overall = SoftLabels {
                output_similarity: output_sim,
                runtime_ratio,
                structural_similarity: structural_sim,
                semantic_confidence: semantic_conf,
                type_safety,
            }
            .overall_score();

            prop_assert!(overall >= 0.0);
            prop_assert!(overall <= 1.0);
        }

        /// The feature vector always has 19 entries, whatever the verdict.
        #[test]
        fn prop_feature_vector_length(is_correct: bool) {
            let soft = SoftLabels::default();
            let label = match is_correct {
                true => RichLabel::correct(soft),
                false => RichLabel::incorrect(ErrorCategory::Unknown, "error".to_string(), soft),
            };

            prop_assert_eq!(label.to_feature_vector().len(), 19);
        }
    }
}
1033}