1use rand::Rng;
19use rand_distr::{Beta, Distribution};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22
/// Fixed-dimension feature vector representing a piece of source code.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CodeEmbedding {
    /// Raw feature values; length equals `dim`.
    pub features: Vec<f32>,
    /// Dimensionality of the embedding (cached length of `features`).
    pub dim: usize,
}
31
32impl CodeEmbedding {
33 #[must_use]
35 pub fn new(dim: usize) -> Self {
36 Self {
37 features: vec![0.0; dim],
38 dim,
39 }
40 }
41
42 #[must_use]
44 pub fn from_vec(features: Vec<f32>) -> Self {
45 let dim = features.len();
46 Self { features, dim }
47 }
48
49 #[must_use]
51 pub fn norm(&self) -> f32 {
52 self.features.iter().map(|x| x * x).sum::<f32>().sqrt()
53 }
54
55 pub fn normalize(&mut self) {
57 let norm = self.norm();
58 if norm > 0.0 {
59 for x in &mut self.features {
60 *x /= norm;
61 }
62 }
63 }
64
65 #[must_use]
67 pub fn cosine_similarity(&self, other: &Self) -> f32 {
68 if self.dim != other.dim {
69 return 0.0;
70 }
71
72 let dot: f32 = self
73 .features
74 .iter()
75 .zip(&other.features)
76 .map(|(a, b)| a * b)
77 .sum();
78
79 let norm_a = self.norm();
80 let norm_b = other.norm();
81
82 if norm_a > 0.0 && norm_b > 0.0 {
83 dot / (norm_a * norm_b)
84 } else {
85 0.0
86 }
87 }
88
89 #[must_use]
91 pub fn euclidean_distance(&self, other: &Self) -> f32 {
92 if self.dim != other.dim {
93 return f32::MAX;
94 }
95
96 self.features
97 .iter()
98 .zip(&other.features)
99 .map(|(a, b)| (a - b).powi(2))
100 .sum::<f32>()
101 .sqrt()
102 }
103}
104
/// Hash-based embedder mapping source text to a bag-of-features vector
/// built from character n-grams plus whitespace-delimited words.
#[derive(Debug, Clone)]
pub struct CodeEmbedder {
    /// Character n-gram length used for the sliding-window features.
    n: usize,
    /// Output dimensionality; hashes are reduced modulo this value.
    vocab_size: usize,
}
113
114impl Default for CodeEmbedder {
115 fn default() -> Self {
116 Self::new(3, 128)
117 }
118}
119
120impl CodeEmbedder {
121 #[must_use]
123 pub fn new(n: usize, vocab_size: usize) -> Self {
124 Self { n, vocab_size }
125 }
126
127 #[must_use]
129 pub fn embed(&self, code: &str) -> CodeEmbedding {
130 let mut features = vec![0.0f32; self.vocab_size];
131
132 let chars: Vec<char> = code.chars().collect();
134 if chars.len() >= self.n {
135 for window in chars.windows(self.n) {
136 let hash = self.hash_ngram(window);
137 features[hash] += 1.0;
138 }
139 }
140
141 for word in code.split_whitespace() {
143 let hash = self.hash_word(word);
144 features[hash] += 1.0;
145 }
146
147 let mut embedding = CodeEmbedding::from_vec(features);
148 embedding.normalize();
149 embedding
150 }
151
152 fn hash_ngram(&self, chars: &[char]) -> usize {
153 let mut hash = 0usize;
154 for (i, &c) in chars.iter().enumerate() {
155 hash = hash.wrapping_add((c as usize).wrapping_mul(31_usize.wrapping_pow(i as u32)));
156 }
157 hash % self.vocab_size
158 }
159
160 fn hash_word(&self, word: &str) -> usize {
161 let mut hash = 0usize;
162 for (i, c) in word.chars().enumerate() {
163 hash = hash.wrapping_add((c as usize).wrapping_mul(37_usize.wrapping_pow(i as u32)));
164 }
165 hash % self.vocab_size
166 }
167}
168
/// A single k-means cluster: its centroid plus bookkeeping accumulated
/// after points are assigned.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
    /// Index of this cluster within the clustering result.
    pub id: usize,
    /// Centroid embedding of the cluster.
    pub centroid: CodeEmbedding,
    /// Number of points assigned to this cluster.
    pub size: usize,
    /// Sum of member-to-centroid distances; see `avg_intra_distance`.
    pub intra_distance: f32,
}
181
182impl Cluster {
183 #[must_use]
185 pub fn new(id: usize, centroid: CodeEmbedding) -> Self {
186 Self {
187 id,
188 centroid,
189 size: 0,
190 intra_distance: 0.0,
191 }
192 }
193
194 #[must_use]
196 pub fn avg_intra_distance(&self) -> f32 {
197 if self.size > 0 {
198 self.intra_distance / self.size as f32
199 } else {
200 0.0
201 }
202 }
203}
204
/// Output of `KMeansClustering::fit`.
#[derive(Debug, Clone)]
pub struct ClusteringResult {
    /// Final clusters with centroids and membership stats.
    pub clusters: Vec<Cluster>,
    /// For each input embedding, the index of its assigned cluster.
    pub assignments: Vec<usize>,
    /// Mean silhouette-style score over the points (bounded to [-1, 1]).
    pub silhouette_score: f32,
    /// Number of Lloyd iterations actually executed.
    pub iterations: usize,
}
217
/// k-means clusterer with k-means++-style centroid seeding.
#[derive(Debug, Clone)]
pub struct KMeansClustering {
    /// Target number of clusters (capped at the point count in `fit`).
    k: usize,
    /// Maximum number of Lloyd iterations.
    max_iter: usize,
    /// Seed intended to make centroid-initialization randomness
    /// reproducible (set via `with_seed`).
    seed: u64,
}
228
229impl Default for KMeansClustering {
230 fn default() -> Self {
231 Self::new(5)
232 }
233}
234
235impl KMeansClustering {
236 #[must_use]
238 pub fn new(k: usize) -> Self {
239 Self {
240 k,
241 max_iter: 100,
242 seed: 42,
243 }
244 }
245
246 #[must_use]
248 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
249 self.max_iter = max_iter;
250 self
251 }
252
253 #[must_use]
255 pub fn with_seed(mut self, seed: u64) -> Self {
256 self.seed = seed;
257 self
258 }
259
260 pub fn fit(&self, embeddings: &[CodeEmbedding]) -> ClusteringResult {
262 if embeddings.is_empty() {
263 return ClusteringResult {
264 clusters: vec![],
265 assignments: vec![],
266 silhouette_score: 0.0,
267 iterations: 0,
268 };
269 }
270
271 let dim = embeddings[0].dim;
272 let actual_k = self.k.min(embeddings.len());
273
274 let mut rng = rand::rng();
276 let mut centroids = self.init_centroids(embeddings, actual_k, &mut rng);
277
278 let mut assignments = vec![0usize; embeddings.len()];
279 let mut iterations = 0;
280
281 for iter in 0..self.max_iter {
282 iterations = iter + 1;
283
284 let mut changed = false;
286 for (i, emb) in embeddings.iter().enumerate() {
287 let nearest = self.find_nearest_centroid(emb, ¢roids);
288 if assignments[i] != nearest {
289 assignments[i] = nearest;
290 changed = true;
291 }
292 }
293
294 centroids = self.update_centroids(embeddings, &assignments, actual_k, dim);
296
297 if !changed {
298 break;
299 }
300 }
301
302 let mut clusters: Vec<Cluster> = centroids
304 .into_iter()
305 .enumerate()
306 .map(|(id, centroid)| Cluster::new(id, centroid))
307 .collect();
308
309 for (i, &cluster_id) in assignments.iter().enumerate() {
311 if cluster_id < clusters.len() {
312 clusters[cluster_id].size += 1;
313 clusters[cluster_id].intra_distance +=
314 embeddings[i].euclidean_distance(&clusters[cluster_id].centroid);
315 }
316 }
317
318 let silhouette_score = self.calculate_silhouette(embeddings, &assignments, &clusters);
320
321 ClusteringResult {
322 clusters,
323 assignments,
324 silhouette_score,
325 iterations,
326 }
327 }
328
329 fn init_centroids<R: Rng>(
330 &self,
331 embeddings: &[CodeEmbedding],
332 k: usize,
333 rng: &mut R,
334 ) -> Vec<CodeEmbedding> {
335 if embeddings.is_empty() || k == 0 {
336 return vec![];
337 }
338
339 let mut centroids = Vec::with_capacity(k);
340
341 let first_idx = rng.random_range(0..embeddings.len());
343 centroids.push(embeddings[first_idx].clone());
344
345 for _ in 1..k {
347 let distances: Vec<f32> = embeddings
348 .iter()
349 .map(|emb| {
350 centroids
351 .iter()
352 .map(|c| emb.euclidean_distance(c))
353 .fold(f32::MAX, f32::min)
354 .powi(2)
355 })
356 .collect();
357
358 let total: f32 = distances.iter().sum();
359 if total <= 0.0 {
360 break;
361 }
362
363 let threshold = rng.random::<f32>() * total;
364 let mut cumsum = 0.0;
365 for (i, &d) in distances.iter().enumerate() {
366 cumsum += d;
367 if cumsum >= threshold {
368 centroids.push(embeddings[i].clone());
369 break;
370 }
371 }
372 }
373
374 centroids
375 }
376
377 fn find_nearest_centroid(&self, emb: &CodeEmbedding, centroids: &[CodeEmbedding]) -> usize {
378 centroids
379 .iter()
380 .enumerate()
381 .map(|(i, c)| (i, emb.euclidean_distance(c)))
382 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
383 .map_or(0, |(i, _)| i)
384 }
385
386 fn update_centroids(
387 &self,
388 embeddings: &[CodeEmbedding],
389 assignments: &[usize],
390 k: usize,
391 dim: usize,
392 ) -> Vec<CodeEmbedding> {
393 let mut sums: Vec<Vec<f32>> = vec![vec![0.0; dim]; k];
394 let mut counts = vec![0usize; k];
395
396 for (i, &cluster_id) in assignments.iter().enumerate() {
397 if cluster_id < k {
398 counts[cluster_id] += 1;
399 for (j, &val) in embeddings[i].features.iter().enumerate() {
400 if j < dim {
401 sums[cluster_id][j] += val;
402 }
403 }
404 }
405 }
406
407 sums.into_iter()
408 .zip(counts)
409 .map(|(sum, count)| {
410 if count > 0 {
411 let features: Vec<f32> = sum.into_iter().map(|s| s / count as f32).collect();
412 CodeEmbedding::from_vec(features)
413 } else {
414 CodeEmbedding::new(dim)
415 }
416 })
417 .collect()
418 }
419
420 fn calculate_silhouette(
421 &self,
422 embeddings: &[CodeEmbedding],
423 assignments: &[usize],
424 clusters: &[Cluster],
425 ) -> f32 {
426 if embeddings.len() <= 1 || clusters.len() <= 1 {
427 return 0.0;
428 }
429
430 let mut total_score = 0.0;
431 let mut count = 0;
432
433 for (i, emb) in embeddings.iter().enumerate() {
434 let cluster_id = assignments[i];
435 if cluster_id >= clusters.len() {
436 continue;
437 }
438
439 let a = clusters[cluster_id].avg_intra_distance();
441
442 let b = clusters
444 .iter()
445 .filter(|c| c.id != cluster_id)
446 .map(|c| emb.euclidean_distance(&c.centroid))
447 .fold(f32::MAX, f32::min);
448
449 if b < f32::MAX {
450 let max_ab = a.max(b);
451 if max_ab > 0.0 {
452 total_score += (b - a) / max_ab;
453 count += 1;
454 }
455 }
456 }
457
458 if count > 0 {
459 total_score / count as f32
460 } else {
461 0.0
462 }
463 }
464}
465
/// Active learner that clusters code samples and uses Thompson sampling
/// over per-cluster Beta posteriors to decide where to sample next.
#[derive(Debug)]
pub struct ActiveLearner {
    /// Embeds raw code strings into feature vectors.
    embedder: CodeEmbedder,
    /// k-means model used to group the embeddings.
    clustering: KMeansClustering,
    /// Result of the most recent `fit`, if any.
    cluster_result: Option<ClusteringResult>,
    /// Per-cluster count of feedback events that did NOT reveal a bug.
    success_counts: HashMap<usize, f64>,
    /// Per-cluster count of feedback events that revealed a bug.
    failure_counts: HashMap<usize, f64>,
    /// Total number of feedback events recorded via `update_feedback`.
    total_samples: usize,
    /// Epsilon used by `should_explore`, clamped to [0, 1].
    exploration_rate: f64,
}
484
485impl Default for ActiveLearner {
486 fn default() -> Self {
487 Self::new(5)
488 }
489}
490
491impl ActiveLearner {
492 #[must_use]
494 pub fn new(k: usize) -> Self {
495 Self {
496 embedder: CodeEmbedder::default(),
497 clustering: KMeansClustering::new(k),
498 cluster_result: None,
499 success_counts: HashMap::new(),
500 failure_counts: HashMap::new(),
501 total_samples: 0,
502 exploration_rate: 0.1,
503 }
504 }
505
506 #[must_use]
508 pub fn with_embedder(mut self, embedder: CodeEmbedder) -> Self {
509 self.embedder = embedder;
510 self
511 }
512
513 #[must_use]
515 pub fn with_exploration_rate(mut self, rate: f64) -> Self {
516 self.exploration_rate = rate.clamp(0.0, 1.0);
517 self
518 }
519
520 pub fn fit(&mut self, codes: &[&str]) {
522 let embeddings: Vec<CodeEmbedding> = codes.iter().map(|c| self.embedder.embed(c)).collect();
523
524 self.cluster_result = Some(self.clustering.fit(&embeddings));
525 }
526
527 #[must_use]
529 pub fn get_cluster(&self, code: &str) -> Option<usize> {
530 let embedding = self.embedder.embed(code);
531 self.cluster_result.as_ref().map(|result| {
532 result
533 .clusters
534 .iter()
535 .enumerate()
536 .map(|(i, c)| (i, embedding.euclidean_distance(&c.centroid)))
537 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
538 .map_or(0, |(i, _)| i)
539 })
540 }
541
542 pub fn sample_cluster(&self) -> Option<usize> {
551 let result = self.cluster_result.as_ref()?;
552 if result.clusters.is_empty() {
553 return None;
554 }
555
556 let mut rng = rand::rng();
557
558 let scores: Vec<(usize, f64)> = result
560 .clusters
561 .iter()
562 .map(|c| {
563 let alpha = self.failure_counts.get(&c.id).copied().unwrap_or(0.0) + 1.0;
565 let beta = self.success_counts.get(&c.id).copied().unwrap_or(0.0) + 1.0;
566
567 #[allow(clippy::unwrap_used)]
569 let beta_dist =
570 Beta::new(alpha, beta).unwrap_or_else(|_| Beta::new(1.0, 1.0).unwrap());
571 let score = beta_dist.sample(&mut rng);
572
573 (c.id, score)
574 })
575 .collect();
576
577 scores
579 .into_iter()
580 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
581 .map(|(id, _)| id)
582 }
583
584 pub fn select_batch(&self, codes: &[&str], batch_size: usize) -> Vec<usize> {
588 if codes.is_empty() || batch_size == 0 {
589 return vec![];
590 }
591
592 let Some(_result) = &self.cluster_result else {
593 return (0..batch_size.min(codes.len())).collect();
594 };
595
596 let mut rng = rand::rng();
597 let mut selected = Vec::with_capacity(batch_size);
598 let mut remaining: Vec<usize> = (0..codes.len()).collect();
599
600 while selected.len() < batch_size && !remaining.is_empty() {
601 let target_cluster = self.sample_cluster().unwrap_or(0);
603
604 let in_cluster: Vec<usize> = remaining
606 .iter()
607 .filter(|&&i| {
608 self.get_cluster(codes[i])
609 .is_some_and(|c| c == target_cluster)
610 })
611 .copied()
612 .collect();
613
614 if in_cluster.is_empty() {
615 let idx = rng.random_range(0..remaining.len());
617 let sample_idx = remaining.remove(idx);
618 selected.push(sample_idx);
619 } else {
620 let idx = rng.random_range(0..in_cluster.len());
622 let sample_idx = in_cluster[idx];
623 remaining.retain(|&x| x != sample_idx);
624 selected.push(sample_idx);
625 }
626 }
627
628 selected
629 }
630
631 pub fn update_feedback(&mut self, code: &str, revealed_bug: bool) {
637 if let Some(cluster_id) = self.get_cluster(code) {
638 if revealed_bug {
639 *self.failure_counts.entry(cluster_id).or_insert(0.0) += 1.0;
640 } else {
641 *self.success_counts.entry(cluster_id).or_insert(0.0) += 1.0;
642 }
643 }
644 self.total_samples += 1;
645 }
646
647 #[must_use]
649 pub fn silhouette_score(&self) -> f32 {
650 self.cluster_result
651 .as_ref()
652 .map_or(0.0, |r| r.silhouette_score)
653 }
654
655 #[must_use]
657 pub fn cluster_stats(&self) -> Vec<ClusterStats> {
658 self.cluster_result
659 .as_ref()
660 .map(|r| {
661 r.clusters
662 .iter()
663 .map(|c| {
664 let successes = self.success_counts.get(&c.id).copied().unwrap_or(0.0);
665 let failures = self.failure_counts.get(&c.id).copied().unwrap_or(0.0);
666 let total = successes + failures;
667
668 ClusterStats {
669 cluster_id: c.id,
670 size: c.size,
671 bug_rate: if total > 0.0 { failures / total } else { 0.5 },
672 #[allow(clippy::cast_sign_loss)]
673 samples_tried: total.max(0.0) as usize,
674 }
675 })
676 .collect()
677 })
678 .unwrap_or_default()
679 }
680
681 #[must_use]
683 pub fn should_explore(&self) -> bool {
684 let mut rng = rand::rng();
685 rng.random::<f64>() < self.exploration_rate
686 }
687
688 #[must_use]
690 pub fn total_samples(&self) -> usize {
691 self.total_samples
692 }
693}
694
/// Per-cluster summary derived from clustering results plus recorded
/// feedback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    /// Id of the cluster these stats describe.
    pub cluster_id: usize,
    /// Number of points assigned to the cluster during `fit`.
    pub size: usize,
    /// Observed bug rate (failures / total feedback); 0.5 with no feedback.
    pub bug_rate: f64,
    /// Number of feedback samples recorded for this cluster.
    pub samples_tried: usize,
}
707
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a small corpus of varied Python snippets.
    fn sample_codes() -> Vec<&'static str> {
        vec![
            "def add(a, b):\n return a + b",
            "def sub(a, b):\n return a - b",
            "for i in range(10):\n print(i)",
            "while True:\n break",
            "if x > 0:\n return x\nelse:\n return -x",
            "class Foo:\n def __init__(self):\n pass",
            "x = [1, 2, 3]\ny = sum(x)",
            "import os\npath = os.getcwd()",
        ]
    }

    // ----- CodeEmbedding -----

    #[test]
    fn test_code_embedding_new() {
        let emb = CodeEmbedding::new(64);
        assert_eq!(emb.dim, 64);
        assert_eq!(emb.features.len(), 64);
    }

    #[test]
    fn test_code_embedding_from_vec() {
        let features = vec![1.0, 2.0, 3.0];
        let emb = CodeEmbedding::from_vec(features.clone());
        assert_eq!(emb.features, features);
    }

    // 3-4-5 triangle: norm of (3, 4) is 5.
    #[test]
    fn test_code_embedding_norm() {
        let emb = CodeEmbedding::from_vec(vec![3.0, 4.0]);
        assert!((emb.norm() - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_code_embedding_normalize() {
        let mut emb = CodeEmbedding::from_vec(vec![3.0, 4.0]);
        emb.normalize();
        assert!((emb.norm() - 1.0).abs() < 0.001);
    }

    // A vector is maximally similar to itself.
    #[test]
    fn test_code_embedding_cosine_similarity_same() {
        let emb = CodeEmbedding::from_vec(vec![1.0, 2.0, 3.0]);
        assert!((emb.cosine_similarity(&emb) - 1.0).abs() < 0.001);
    }

    // Orthogonal vectors have zero cosine similarity.
    #[test]
    fn test_code_embedding_cosine_similarity_orthogonal() {
        let emb1 = CodeEmbedding::from_vec(vec![1.0, 0.0]);
        let emb2 = CodeEmbedding::from_vec(vec![0.0, 1.0]);
        assert!(emb1.cosine_similarity(&emb2).abs() < 0.001);
    }

    #[test]
    fn test_code_embedding_euclidean_distance() {
        let emb1 = CodeEmbedding::from_vec(vec![0.0, 0.0]);
        let emb2 = CodeEmbedding::from_vec(vec![3.0, 4.0]);
        assert!((emb1.euclidean_distance(&emb2) - 5.0).abs() < 0.001);
    }

    // ----- CodeEmbedder -----

    #[test]
    fn test_code_embedder_default() {
        let embedder = CodeEmbedder::default();
        assert_eq!(embedder.n, 3);
        assert_eq!(embedder.vocab_size, 128);
    }

    #[test]
    fn test_code_embedder_embed() {
        let embedder = CodeEmbedder::default();
        let emb = embedder.embed("def foo(): return 1");
        assert_eq!(emb.dim, 128);
        assert!(emb.norm() > 0.0);
    }

    // Structurally similar snippets should embed closer together than
    // dissimilar ones.
    #[test]
    fn test_code_embedder_similar_code() {
        let embedder = CodeEmbedder::default();
        let emb1 = embedder.embed("def add(a, b): return a + b");
        let emb2 = embedder.embed("def add(x, y): return x + y");
        let emb3 = embedder.embed("class Foo: pass");

        let sim_12 = emb1.cosine_similarity(&emb2);
        let sim_13 = emb1.cosine_similarity(&emb3);
        assert!(sim_12 > sim_13);
    }

    #[test]
    fn test_code_embedder_empty() {
        let embedder = CodeEmbedder::default();
        let emb = embedder.embed("");
        assert_eq!(emb.dim, 128);
    }

    // ----- KMeansClustering -----

    #[test]
    fn test_kmeans_default() {
        let kmeans = KMeansClustering::default();
        assert_eq!(kmeans.k, 5);
    }

    #[test]
    fn test_kmeans_fit_empty() {
        let kmeans = KMeansClustering::new(3);
        let result = kmeans.fit(&[]);
        assert!(result.clusters.is_empty());
        assert!(result.assignments.is_empty());
    }

    #[test]
    fn test_kmeans_fit() {
        let embedder = CodeEmbedder::default();
        let codes = sample_codes();
        let embeddings: Vec<CodeEmbedding> = codes.iter().map(|c| embedder.embed(c)).collect();

        let kmeans = KMeansClustering::new(3).with_seed(42);
        let result = kmeans.fit(&embeddings);

        assert_eq!(result.clusters.len(), 3);
        assert_eq!(result.assignments.len(), codes.len());
    }

    // Silhouette scores are defined on [-1, 1].
    #[test]
    fn test_kmeans_silhouette_bounded() {
        let embedder = CodeEmbedder::default();
        let codes = sample_codes();
        let embeddings: Vec<CodeEmbedding> = codes.iter().map(|c| embedder.embed(c)).collect();

        let kmeans = KMeansClustering::new(3);
        let result = kmeans.fit(&embeddings);

        assert!(result.silhouette_score >= -1.0);
        assert!(result.silhouette_score <= 1.0);
    }

    // ----- ActiveLearner -----

    #[test]
    fn test_active_learner_new() {
        let learner = ActiveLearner::new(5);
        assert_eq!(learner.total_samples(), 0);
    }

    #[test]
    fn test_active_learner_fit() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        assert!(learner.silhouette_score() >= -1.0);
    }

    #[test]
    fn test_active_learner_get_cluster() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        let cluster = learner.get_cluster(codes[0]);
        assert!(cluster.is_some());
    }

    #[test]
    fn test_active_learner_sample_cluster() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        let cluster = learner.sample_cluster();
        assert!(cluster.is_some());
    }

    // Selected batch indices must be unique.
    #[test]
    fn test_active_learner_select_batch() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        let batch = learner.select_batch(&codes, 3);
        assert_eq!(batch.len(), 3);
        let mut sorted = batch.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), batch.len());
    }

    #[test]
    fn test_active_learner_update_feedback() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        learner.update_feedback(codes[0], true);
        learner.update_feedback(codes[1], false);

        assert_eq!(learner.total_samples(), 2);
    }

    #[test]
    fn test_active_learner_cluster_stats() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        // Alternate bug / no-bug feedback across the corpus.
        for (i, &code) in codes.iter().enumerate() {
            learner.update_feedback(code, i % 2 == 0);
        }

        let stats = learner.cluster_stats();
        assert!(!stats.is_empty());
    }

    // With epsilon = 1.0 every call should decide to explore.
    #[test]
    fn test_active_learner_exploration_rate() {
        let learner = ActiveLearner::new(3).with_exploration_rate(1.0);

        let mut explored = 0;
        for _ in 0..100 {
            if learner.should_explore() {
                explored += 1;
            }
        }
        assert_eq!(explored, 100);
    }

    // ----- Debug impls -----

    #[test]
    fn test_code_embedding_debug() {
        let emb = CodeEmbedding::new(4);
        let debug = format!("{emb:?}");
        assert!(debug.contains("CodeEmbedding"));
    }

    #[test]
    fn test_code_embedder_debug() {
        let embedder = CodeEmbedder::default();
        let debug = format!("{embedder:?}");
        assert!(debug.contains("CodeEmbedder"));
    }

    #[test]
    fn test_cluster_debug() {
        let cluster = Cluster::new(0, CodeEmbedding::new(4));
        let debug = format!("{cluster:?}");
        assert!(debug.contains("Cluster"));
    }

    #[test]
    fn test_active_learner_debug() {
        let learner = ActiveLearner::new(3);
        let debug = format!("{learner:?}");
        assert!(debug.contains("ActiveLearner"));
    }
}
983
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Norm is a non-negative quantity for any feature vector.
        #[test]
        fn prop_embedding_norm_nonnegative(features in proptest::collection::vec(-100.0f32..100.0, 1..50)) {
            let emb = CodeEmbedding::from_vec(features);
            prop_assert!(emb.norm() >= 0.0);
        }

        // Cosine similarity stays within [-1, 1] (with a small float slack).
        #[test]
        fn prop_cosine_bounded(
            f1 in proptest::collection::vec(-10.0f32..10.0, 1..20),
            f2 in proptest::collection::vec(-10.0f32..10.0, 1..20),
        ) {
            // Truncate both vectors to a common dimension so the comparison
            // is meaningful rather than short-circuiting on dim mismatch.
            let dim = f1.len().min(f2.len());
            let emb1 = CodeEmbedding::from_vec(f1[..dim].to_vec());
            let emb2 = CodeEmbedding::from_vec(f2[..dim].to_vec());

            let sim = emb1.cosine_similarity(&emb2);
            prop_assert!(sim >= -1.0 - 0.001);
            prop_assert!(sim <= 1.0 + 0.001);
        }

        // Euclidean distance is non-negative for any pair of vectors.
        #[test]
        fn prop_euclidean_nonnegative(
            f1 in proptest::collection::vec(-100.0f32..100.0, 1..20),
            f2 in proptest::collection::vec(-100.0f32..100.0, 1..20),
        ) {
            let dim = f1.len().min(f2.len());
            let emb1 = CodeEmbedding::from_vec(f1[..dim].to_vec());
            let emb2 = CodeEmbedding::from_vec(f2[..dim].to_vec());

            prop_assert!(emb1.euclidean_distance(&emb2) >= 0.0);
        }

        // Normalizing a strictly-positive vector yields (near) unit norm.
        #[test]
        fn prop_normalized_unit_norm(features in proptest::collection::vec(0.1f32..10.0, 1..20)) {
            let mut emb = CodeEmbedding::from_vec(features);
            emb.normalize();

            prop_assert!((emb.norm() - 1.0).abs() < 0.01);
        }

        // Every index returned by select_batch must be in range.
        #[test]
        fn prop_batch_indices_valid(batch_size in 1usize..10) {
            let mut learner = ActiveLearner::new(3);
            let codes: Vec<&str> = vec![
                "x = 1",
                "y = 2",
                "z = 3",
                "def f(): pass",
                "class C: pass",
            ];

            learner.fit(&codes);

            let batch = learner.select_batch(&codes, batch_size);

            for &idx in &batch {
                prop_assert!(idx < codes.len());
            }
        }
    }
}