1use serde::{Deserialize, Serialize};
19
/// Coarse category describing why a source/target program pair failed.
///
/// Variant declaration order is significant: `ErrorCategory::to_one_hot` uses the
/// discriminant (`*self as usize`) as the one-hot index and `from_index` maps
/// indices back in this same order, so variants must not be reordered without
/// updating both.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ErrorCategory {
    /// Type errors (e.g. "mismatched types", "cannot convert").
    TypeMismatch,
    /// Ownership / borrow-checker errors (e.g. "cannot borrow", moved values).
    OwnershipViolation,
    /// Lifetime errors (e.g. "does not live long enough").
    LifetimeError,
    /// Panics, failed assertions, out-of-bounds indexing.
    PanicDivergence,
    /// Output differences (e.g. "expected … actual …").
    OutputMismatch,
    /// Other compilation failures (unresolved names, syntax/parse errors).
    CompilationError,
    /// Non-panic runtime failures (e.g. overflow, division by zero).
    RuntimeError,
    /// Timeouts, memory or other resource exhaustion.
    ResourceExhaustion,
    /// Unclassified error; also the `Default` value.
    Unknown,
}
42
43impl Default for ErrorCategory {
44 fn default() -> Self {
45 Self::Unknown
46 }
47}
48
49impl ErrorCategory {
50 #[must_use]
52 pub fn all() -> &'static [Self] {
53 &[
54 Self::TypeMismatch,
55 Self::OwnershipViolation,
56 Self::LifetimeError,
57 Self::PanicDivergence,
58 Self::OutputMismatch,
59 Self::CompilationError,
60 Self::RuntimeError,
61 Self::ResourceExhaustion,
62 Self::Unknown,
63 ]
64 }
65
66 #[must_use]
68 pub fn severity(&self) -> f32 {
69 match self {
70 Self::PanicDivergence => 1.0, Self::OwnershipViolation => 0.9, Self::LifetimeError => 0.85, Self::TypeMismatch => 0.8, Self::OutputMismatch => 0.7, Self::RuntimeError => 0.6, Self::CompilationError => 0.5, Self::ResourceExhaustion => 0.3, Self::Unknown => 0.2, }
80 }
81
82 #[must_use]
84 pub fn classify(error_msg: &str) -> Self {
85 let msg = error_msg.to_lowercase();
86
87 if msg.contains("borrow")
89 || msg.contains("move")
90 || msg.contains("cannot borrow")
91 || msg.contains("value borrowed")
92 {
93 return Self::OwnershipViolation;
94 }
95
96 if msg.contains("lifetime")
98 || msg.contains("does not live long enough")
99 || msg.contains("'a")
100 {
101 return Self::LifetimeError;
102 }
103
104 if msg.contains("type mismatch")
106 || msg.contains("expected type")
107 || msg.contains("mismatched types")
108 || msg.contains("cannot convert")
109 {
110 return Self::TypeMismatch;
111 }
112
113 if msg.contains("panic")
115 || msg.contains("unwrap")
116 || msg.contains("assertion failed")
117 || msg.contains("index out of bounds")
118 {
119 return Self::PanicDivergence;
120 }
121
122 if msg.contains("output")
124 || msg.contains("mismatch")
125 || msg.contains("expected")
126 || msg.contains("actual")
127 {
128 return Self::OutputMismatch;
129 }
130
131 if msg.contains("cannot find")
133 || msg.contains("unresolved")
134 || msg.contains("syntax error")
135 || msg.contains("parse error")
136 {
137 return Self::CompilationError;
138 }
139
140 if msg.contains("runtime") || msg.contains("overflow") || msg.contains("division by zero") {
142 return Self::RuntimeError;
143 }
144
145 if msg.contains("timeout")
147 || msg.contains("memory")
148 || msg.contains("stack overflow")
149 || msg.contains("resource")
150 {
151 return Self::ResourceExhaustion;
152 }
153
154 Self::Unknown
155 }
156
157 #[must_use]
159 pub fn to_one_hot(&self) -> [f32; 9] {
160 let mut one_hot = [0.0f32; 9];
161 one_hot[*self as usize] = 1.0;
162 one_hot
163 }
164
165 #[must_use]
167 pub fn from_one_hot(one_hot: &[f32; 9]) -> Self {
168 one_hot
169 .iter()
170 .enumerate()
171 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
172 .map_or(Self::Unknown, |(i, _)| Self::from_index(i))
173 }
174
175 fn from_index(idx: usize) -> Self {
176 match idx {
177 0 => Self::TypeMismatch,
178 1 => Self::OwnershipViolation,
179 2 => Self::LifetimeError,
180 3 => Self::PanicDivergence,
181 4 => Self::OutputMismatch,
182 5 => Self::CompilationError,
183 6 => Self::RuntimeError,
184 7 => Self::ResourceExhaustion,
185 _ => Self::Unknown,
186 }
187 }
188}
189
/// Continuous ("soft") quality signals attached to a label.
///
/// Every field except `runtime_ratio` is expected to lie in `[0, 1]`;
/// `runtime_ratio` only needs to be non-negative (see `SoftLabels::is_valid`).
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct SoftLabels {
    /// Similarity between source and target outputs, in `[0, 1]`.
    pub output_similarity: f32,
    /// Target-to-source runtime ratio; non-negative, unbounded above.
    pub runtime_ratio: f32,
    /// Structural similarity score, in `[0, 1]`.
    pub structural_similarity: f32,
    /// Confidence in semantic equivalence, in `[0, 1]`.
    pub semantic_confidence: f32,
    /// Type-safety score, in `[0, 1]`.
    pub type_safety: f32,
}
204
205impl SoftLabels {
206 #[must_use]
208 pub fn new() -> Self {
209 Self::default()
210 }
211
212 #[must_use]
214 pub fn is_valid(&self) -> bool {
215 self.output_similarity >= 0.0
216 && self.output_similarity <= 1.0
217 && self.runtime_ratio >= 0.0
218 && self.structural_similarity >= 0.0
219 && self.structural_similarity <= 1.0
220 && self.semantic_confidence >= 0.0
221 && self.semantic_confidence <= 1.0
222 && self.type_safety >= 0.0
223 && self.type_safety <= 1.0
224 }
225
226 #[must_use]
228 pub fn to_array(&self) -> [f32; 5] {
229 [
230 self.output_similarity,
231 self.runtime_ratio.min(10.0) / 10.0, self.structural_similarity,
233 self.semantic_confidence,
234 self.type_safety,
235 ]
236 }
237
238 #[must_use]
240 pub fn from_array(arr: [f32; 5]) -> Self {
241 Self {
242 output_similarity: arr[0],
243 runtime_ratio: arr[1] * 10.0, structural_similarity: arr[2],
245 semantic_confidence: arr[3],
246 type_safety: arr[4],
247 }
248 }
249
250 #[must_use]
252 pub fn overall_score(&self) -> f32 {
253 let weights = [0.3, 0.1, 0.2, 0.25, 0.15];
254 let arr = self.to_array();
255
256 let weighted_sum: f32 = arr.iter().zip(&weights).map(|(v, w)| v * w).sum();
257 let total_weight: f32 = weights.iter().sum();
258
259 weighted_sum / total_weight
260 }
261}
262
263#[derive(Debug, Default)]
265pub struct SoftLabelsBuilder {
266 labels: SoftLabels,
267}
268
269impl SoftLabelsBuilder {
270 #[must_use]
272 pub fn new() -> Self {
273 Self::default()
274 }
275
276 #[must_use]
278 pub fn output_similarity(mut self, value: f32) -> Self {
279 self.labels.output_similarity = value.clamp(0.0, 1.0);
280 self
281 }
282
283 #[must_use]
285 pub fn runtime_ratio(mut self, value: f32) -> Self {
286 self.labels.runtime_ratio = value.max(0.0);
287 self
288 }
289
290 #[must_use]
292 pub fn structural_similarity(mut self, value: f32) -> Self {
293 self.labels.structural_similarity = value.clamp(0.0, 1.0);
294 self
295 }
296
297 #[must_use]
299 pub fn semantic_confidence(mut self, value: f32) -> Self {
300 self.labels.semantic_confidence = value.clamp(0.0, 1.0);
301 self
302 }
303
304 #[must_use]
306 pub fn type_safety(mut self, value: f32) -> Self {
307 self.labels.type_safety = value.clamp(0.0, 1.0);
308 self
309 }
310
311 #[must_use]
313 pub fn build(self) -> SoftLabels {
314 self.labels
315 }
316}
317
/// Full label for one source/target program pair: correctness flag, error
/// details, soft labels, an optional AST diff and execution metrics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RichLabel {
    /// Whether the target was judged correct.
    pub is_correct: bool,
    /// Error category; `None` for labels built with `RichLabel::correct`.
    pub error_category: Option<ErrorCategory>,
    /// Error message; `None` for labels built with `RichLabel::correct`.
    pub error_message: Option<String>,
    /// Continuous quality signals.
    pub soft_labels: SoftLabels,
    /// Optional structural diff, attached via `RichLabel::with_ast_diff`.
    pub ast_diff: Option<AstDiff>,
    /// Timing / resource measurements, attached via `RichLabel::with_metrics`.
    pub execution_metrics: ExecutionMetrics,
}
334
335impl RichLabel {
336 #[must_use]
338 pub fn correct(soft_labels: SoftLabels) -> Self {
339 Self {
340 is_correct: true,
341 error_category: None,
342 error_message: None,
343 soft_labels,
344 ast_diff: None,
345 execution_metrics: ExecutionMetrics::default(),
346 }
347 }
348
349 #[must_use]
351 pub fn incorrect(category: ErrorCategory, message: String, soft_labels: SoftLabels) -> Self {
352 Self {
353 is_correct: false,
354 error_category: Some(category),
355 error_message: Some(message),
356 soft_labels,
357 ast_diff: None,
358 execution_metrics: ExecutionMetrics::default(),
359 }
360 }
361
362 #[must_use]
364 pub fn with_ast_diff(mut self, diff: AstDiff) -> Self {
365 self.ast_diff = Some(diff);
366 self
367 }
368
369 #[must_use]
371 pub fn with_metrics(mut self, metrics: ExecutionMetrics) -> Self {
372 self.execution_metrics = metrics;
373 self
374 }
375
376 #[must_use]
378 pub fn to_feature_vector(&self) -> Vec<f32> {
379 let mut features = Vec::with_capacity(20);
380
381 features.push(if self.is_correct { 1.0 } else { 0.0 });
383
384 let one_hot = self
386 .error_category
387 .unwrap_or(ErrorCategory::Unknown)
388 .to_one_hot();
389 features.extend_from_slice(&one_hot);
390
391 features.extend_from_slice(&self.soft_labels.to_array());
393
394 features.push(self.execution_metrics.source_time_ms as f32 / 1000.0);
396 features.push(self.execution_metrics.target_time_ms as f32 / 1000.0);
397 features.push(self.execution_metrics.memory_bytes as f32 / 1_000_000.0);
398 features.push(if self.execution_metrics.timeout {
399 1.0
400 } else {
401 0.0
402 });
403
404 features
405 }
406}
407
/// Summary of node-level differences between two ASTs.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AstDiff {
    /// Number of nodes added.
    pub nodes_added: u32,
    /// Number of nodes removed.
    pub nodes_removed: u32,
    /// Number of nodes modified.
    pub nodes_modified: u32,
    /// Edit distance as reported by whoever built the diff; it is stored
    /// as-is and not used by `AstDiff::total_changes` or `AstDiff::similarity`.
    pub edit_distance: u32,
    /// Optional description of the dominant change (presumably a node kind such
    /// as `"FunctionDef"` — confirm with producers).
    pub primary_change: Option<String>,
}
422
423impl AstDiff {
424 #[must_use]
426 pub fn total_changes(&self) -> u32 {
427 self.nodes_added + self.nodes_removed + self.nodes_modified
428 }
429
430 #[must_use]
432 pub fn similarity(&self, total_nodes: u32) -> f32 {
433 if total_nodes == 0 {
434 return 1.0;
435 }
436
437 let changes = self.total_changes();
438 1.0 - (changes as f32 / total_nodes as f32).min(1.0)
439 }
440}
441
/// Timing and resource measurements for a source/target execution pair.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExecutionMetrics {
    /// Source program execution time in milliseconds.
    pub source_time_ms: u64,
    /// Target program execution time in milliseconds.
    pub target_time_ms: u64,
    /// Memory usage in bytes (`LabelExtractor::extract` records `0`).
    pub memory_bytes: u64,
    /// Whether execution timed out (`LabelExtractor::extract` records `false`).
    pub timeout: bool,
}
454
455impl ExecutionMetrics {
456 #[must_use]
458 pub fn runtime_ratio(&self) -> f32 {
459 if self.source_time_ms == 0 {
460 return 1.0;
461 }
462 self.target_time_ms as f32 / self.source_time_ms as f32
463 }
464}
465
/// Stateless helper that turns raw execution results into [`RichLabel`]s.
#[derive(Debug, Default)]
pub struct LabelExtractor;
469
470impl LabelExtractor {
471 #[must_use]
473 pub fn new() -> Self {
474 Self
475 }
476
477 pub fn extract(
479 &self,
480 is_correct: bool,
481 error_msg: Option<&str>,
482 source_output: &str,
483 target_output: &str,
484 source_time_ms: u64,
485 target_time_ms: u64,
486 ) -> RichLabel {
487 let output_similarity = self.compute_output_similarity(source_output, target_output);
488
489 let runtime_ratio = if source_time_ms == 0 {
490 1.0
491 } else {
492 target_time_ms as f32 / source_time_ms as f32
493 };
494
495 let soft_labels = SoftLabelsBuilder::new()
496 .output_similarity(output_similarity)
497 .runtime_ratio(runtime_ratio)
498 .semantic_confidence(if is_correct { 1.0 } else { 0.3 })
499 .type_safety(if is_correct { 1.0 } else { 0.5 })
500 .build();
501
502 let execution_metrics = ExecutionMetrics {
503 source_time_ms,
504 target_time_ms,
505 memory_bytes: 0,
506 timeout: false,
507 };
508
509 if is_correct {
510 RichLabel::correct(soft_labels).with_metrics(execution_metrics)
511 } else {
512 let category = error_msg.map_or(ErrorCategory::Unknown, ErrorCategory::classify);
513 let message = error_msg.unwrap_or("Unknown error").to_string();
514
515 RichLabel::incorrect(category, message, soft_labels).with_metrics(execution_metrics)
516 }
517 }
518
519 fn compute_output_similarity(&self, source: &str, target: &str) -> f32 {
520 if source == target {
521 return 1.0;
522 }
523
524 if source.is_empty() && target.is_empty() {
525 return 1.0;
526 }
527
528 if source.is_empty() || target.is_empty() {
529 return 0.0;
530 }
531
532 let source_lines: std::collections::HashSet<_> = source.lines().collect();
534 let target_lines: std::collections::HashSet<_> = target.lines().collect();
535
536 let intersection = source_lines.intersection(&target_lines).count();
537 let union = source_lines.union(&target_lines).count();
538
539 if union == 0 {
540 1.0
541 } else {
542 intersection as f32 / union as f32
543 }
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Default tolerance for floating-point comparisons.
    const EPS: f32 = 0.001;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// Asserts that `msg` is classified as `expected`.
    fn assert_classifies(msg: &str, expected: ErrorCategory) {
        assert_eq!(ErrorCategory::classify(msg), expected);
    }

    /// Soft labels with every field set to `1.0`.
    fn all_ones() -> SoftLabels {
        SoftLabels {
            output_similarity: 1.0,
            runtime_ratio: 1.0,
            structural_similarity: 1.0,
            semantic_confidence: 1.0,
            type_safety: 1.0,
        }
    }

    /// Metrics with the given timings and no memory/timeout information.
    fn timed(source_time_ms: u64, target_time_ms: u64) -> ExecutionMetrics {
        ExecutionMetrics { source_time_ms, target_time_ms, memory_bytes: 0, timeout: false }
    }

    /// Serializes `value` to JSON and parses it back.
    fn json_round_trip<T: Serialize + serde::de::DeserializeOwned>(value: &T) -> T {
        let json = serde_json::to_string(value).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    // ---- ErrorCategory -------------------------------------------------

    #[test]
    fn test_error_category_all() {
        assert_eq!(ErrorCategory::all().len(), 9);
    }

    #[test]
    fn test_error_category_default() {
        assert_eq!(ErrorCategory::Unknown, ErrorCategory::default());
    }

    #[test]
    fn test_error_category_severity() {
        let sev = ErrorCategory::severity;
        assert!(sev(&ErrorCategory::PanicDivergence) > sev(&ErrorCategory::Unknown));
        assert!(sev(&ErrorCategory::OwnershipViolation) > sev(&ErrorCategory::CompilationError));
    }

    #[test]
    fn test_error_category_classify_ownership() {
        for msg in ["cannot borrow x as mutable", "value moved here"] {
            assert_classifies(msg, ErrorCategory::OwnershipViolation);
        }
    }

    #[test]
    fn test_error_category_classify_lifetime() {
        assert_classifies("lifetime 'a does not live long enough", ErrorCategory::LifetimeError);
    }

    #[test]
    fn test_error_category_classify_type() {
        assert_classifies("type mismatch: expected i32", ErrorCategory::TypeMismatch);
    }

    #[test]
    fn test_error_category_classify_panic() {
        assert_classifies("thread panicked at index out of bounds", ErrorCategory::PanicDivergence);
    }

    #[test]
    fn test_error_category_classify_output() {
        assert_classifies("output mismatch: expected 5, actual 6", ErrorCategory::OutputMismatch);
    }

    #[test]
    fn test_error_category_classify_compilation() {
        assert_classifies("cannot find value x in scope", ErrorCategory::CompilationError);
    }

    #[test]
    fn test_error_category_classify_runtime() {
        assert_classifies("integer overflow detected", ErrorCategory::RuntimeError);
    }

    #[test]
    fn test_error_category_classify_resource() {
        assert_classifies("execution timeout", ErrorCategory::ResourceExhaustion);
    }

    #[test]
    fn test_error_category_classify_unknown() {
        assert_classifies("some random error", ErrorCategory::Unknown);
    }

    #[test]
    fn test_error_category_one_hot() {
        let encoded = ErrorCategory::TypeMismatch.to_one_hot();
        assert_eq!((encoded[0], encoded[1]), (1.0, 0.0));
    }

    #[test]
    fn test_error_category_from_one_hot() {
        let mut encoded = [0.0; 9];
        encoded[1] = 1.0;
        assert_eq!(ErrorCategory::from_one_hot(&encoded), ErrorCategory::OwnershipViolation);
    }

    // ---- SoftLabels / SoftLabelsBuilder ---------------------------------

    #[test]
    fn test_soft_labels_default() {
        assert_eq!(SoftLabels::default().output_similarity, 0.0);
    }

    #[test]
    fn test_soft_labels_is_valid() {
        let valid = SoftLabels {
            output_similarity: 0.8,
            runtime_ratio: 1.2,
            structural_similarity: 0.9,
            semantic_confidence: 0.95,
            type_safety: 1.0,
        };
        let negative_similarity = SoftLabels { output_similarity: -0.1, ..SoftLabels::default() };

        assert!(valid.is_valid());
        assert!(!negative_similarity.is_valid());
    }

    #[test]
    fn test_soft_labels_to_array() {
        let arr = SoftLabels {
            output_similarity: 0.8,
            runtime_ratio: 1.5,
            structural_similarity: 0.9,
            semantic_confidence: 0.7,
            type_safety: 1.0,
        }
        .to_array();

        assert_eq!(arr.len(), 5);
        assert_close(arr[0], 0.8, EPS);
    }

    #[test]
    fn test_soft_labels_overall_score() {
        // A runtime ratio of 1.0 normalizes to 0.1, so the score sits near 0.91.
        assert_close(all_ones().overall_score(), 1.0, 0.1);
    }

    #[test]
    fn test_soft_labels_builder() {
        let built = SoftLabelsBuilder::new()
            .output_similarity(0.9)
            .runtime_ratio(1.1)
            .structural_similarity(0.95)
            .semantic_confidence(0.85)
            .type_safety(1.0)
            .build();

        assert_close(built.output_similarity, 0.9, EPS);
        assert_close(built.runtime_ratio, 1.1, EPS);
    }

    #[test]
    fn test_soft_labels_builder_clamps() {
        let built = SoftLabelsBuilder::new()
            .output_similarity(1.5)
            .semantic_confidence(-0.5)
            .build();

        assert_close(built.output_similarity, 1.0, EPS);
        assert_close(built.semantic_confidence, 0.0, EPS);
    }

    // ---- RichLabel -------------------------------------------------------

    #[test]
    fn test_rich_label_correct() {
        let label = RichLabel::correct(SoftLabels::default());
        assert!(label.is_correct && label.error_category.is_none());
    }

    #[test]
    fn test_rich_label_incorrect() {
        let label = RichLabel::incorrect(
            ErrorCategory::TypeMismatch,
            String::from("Type error"),
            SoftLabels::default(),
        );

        assert!(!label.is_correct);
        assert_eq!(Some(ErrorCategory::TypeMismatch), label.error_category);
    }

    #[test]
    fn test_rich_label_with_ast_diff() {
        let diff = AstDiff {
            nodes_added: 5,
            nodes_removed: 2,
            nodes_modified: 3,
            edit_distance: 10,
            primary_change: Some(String::from("FunctionDef")),
        };

        let label = RichLabel::correct(SoftLabels::default()).with_ast_diff(diff);
        assert!(label.ast_diff.is_some());
    }

    #[test]
    fn test_rich_label_feature_vector() {
        let features = RichLabel::correct(all_ones()).to_feature_vector();

        // 1 correctness flag + 9 one-hot slots + 5 soft labels + 4 metrics.
        assert_eq!(features.len(), 19);
        assert_close(features[0], 1.0, EPS);
    }

    // ---- AstDiff ---------------------------------------------------------

    #[test]
    fn test_ast_diff_total_changes() {
        let diff = AstDiff { nodes_added: 5, nodes_removed: 3, nodes_modified: 2, ..AstDiff::default() };
        assert_eq!(diff.total_changes(), 10);
    }

    #[test]
    fn test_ast_diff_similarity() {
        let diff = AstDiff { nodes_added: 2, edit_distance: 2, ..AstDiff::default() };
        assert_close(diff.similarity(10), 0.8, EPS);
    }

    #[test]
    fn test_ast_diff_similarity_empty() {
        assert_close(AstDiff::default().similarity(0), 1.0, EPS);
    }

    // ---- ExecutionMetrics ------------------------------------------------

    #[test]
    fn test_execution_metrics_runtime_ratio() {
        assert_close(timed(100, 150).runtime_ratio(), 1.5, EPS);
    }

    #[test]
    fn test_execution_metrics_runtime_ratio_zero() {
        assert_close(timed(0, 100).runtime_ratio(), 1.0, EPS);
    }

    // ---- LabelExtractor --------------------------------------------------

    #[test]
    fn test_label_extractor_correct() {
        let label =
            LabelExtractor::new().extract(true, None, "hello\nworld", "hello\nworld", 100, 100);

        assert!(label.is_correct);
        assert_close(label.soft_labels.output_similarity, 1.0, EPS);
    }

    #[test]
    fn test_label_extractor_incorrect() {
        let label =
            LabelExtractor::new().extract(false, Some("type mismatch error"), "5", "6", 100, 100);

        assert!(!label.is_correct);
        assert_eq!(label.error_category, Some(ErrorCategory::TypeMismatch));
    }

    #[test]
    fn test_label_extractor_output_similarity() {
        let extractor = LabelExtractor::new();
        let similarity = |correct: bool, src: &str, tgt: &str| {
            extractor
                .extract(correct, None, src, tgt, 100, 100)
                .soft_labels
                .output_similarity
        };

        assert_close(similarity(true, "a\nb\nc", "a\nb\nc"), 1.0, EPS);

        let partial = similarity(false, "a\nb\nc", "a\nb\nd");
        assert!(partial > 0.0);
        assert!(partial < 1.0);
    }

    // ---- Debug output ----------------------------------------------------

    #[test]
    fn test_error_category_debug() {
        assert!(format!("{:?}", ErrorCategory::TypeMismatch).contains("TypeMismatch"));
    }

    #[test]
    fn test_soft_labels_debug() {
        assert!(format!("{:?}", SoftLabels::default()).contains("SoftLabels"));
    }

    #[test]
    fn test_rich_label_debug() {
        assert!(format!("{:?}", RichLabel::correct(SoftLabels::default())).contains("RichLabel"));
    }

    #[test]
    fn test_label_extractor_debug() {
        assert!(format!("{:?}", LabelExtractor::new()).contains("LabelExtractor"));
    }

    // ---- Serde round trips -----------------------------------------------

    #[test]
    fn test_error_category_serialize() {
        let category = ErrorCategory::OwnershipViolation;
        assert_eq!(json_round_trip(&category), category);
    }

    #[test]
    fn test_soft_labels_serialize() {
        let labels = SoftLabelsBuilder::new()
            .output_similarity(0.8)
            .runtime_ratio(1.2)
            .build();

        let restored = json_round_trip(&labels);
        assert_close(labels.output_similarity, restored.output_similarity, EPS);
    }

    #[test]
    fn test_rich_label_serialize() {
        let label = RichLabel::incorrect(
            ErrorCategory::TypeMismatch,
            String::from("Error"),
            SoftLabels::default(),
        );

        let restored = json_round_trip(&label);
        assert_eq!(
            (label.is_correct, label.error_category),
            (restored.is_correct, restored.error_category)
        );
    }
}
954
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Every category's severity lies in [0, 1].
        #[test]
        fn prop_severity_bounded(index in 0usize..9) {
            let severity = ErrorCategory::from_index(index).severity();
            prop_assert!((0.0_f32..=1.0).contains(&severity));
        }

        // Encoding then decoding a category is the identity.
        #[test]
        fn prop_one_hot_roundtrip(index in 0usize..9) {
            let category = ErrorCategory::from_index(index);
            prop_assert_eq!(ErrorCategory::from_one_hot(&category.to_one_hot()), category);
        }

        // In-range builder inputs always produce valid labels.
        #[test]
        fn prop_soft_labels_structure(
            output_sim in 0.0f32..1.0,
            structural_sim in 0.0f32..1.0,
            semantic_conf in 0.0f32..1.0,
            type_safety in 0.0f32..1.0,
        ) {
            let built = SoftLabelsBuilder::new()
                .output_similarity(output_sim)
                .structural_similarity(structural_sim)
                .semantic_confidence(semantic_conf)
                .type_safety(type_safety)
                .build();

            prop_assert!(built.is_valid());
        }

        // The overall score of in-range labels stays within [0, 1].
        #[test]
        fn prop_overall_score_bounded(
            output_sim in 0.0f32..1.0,
            runtime_ratio in 0.0f32..10.0,
            structural_sim in 0.0f32..1.0,
            semantic_conf in 0.0f32..1.0,
            type_safety in 0.0f32..1.0,
        ) {
            let score = SoftLabels {
                output_similarity: output_sim,
                runtime_ratio,
                structural_similarity: structural_sim,
                semantic_confidence: semantic_conf,
                type_safety,
            }
            .overall_score();

            prop_assert!((0.0_f32..=1.0).contains(&score));
        }

        // Correct and incorrect labels flatten to the same feature length.
        #[test]
        fn prop_feature_vector_length(is_correct in any::<bool>()) {
            let soft = SoftLabels::default();
            let label = if is_correct {
                RichLabel::correct(soft)
            } else {
                RichLabel::incorrect(ErrorCategory::Unknown, "error".to_string(), soft)
            };

            prop_assert_eq!(label.to_feature_vector().len(), 19);
        }
    }
}