// verificar/ml/active_learner.rs
1//! Active Learning with Thompson Sampling
2//!
3//! Implements active learning for dynamic sampling strategy adjustment
4//! based on oracle feedback using Thompson Sampling on code clusters.
5//!
6//! # Architecture
7//!
8//! ```text
9//! Code → Embedding → Clustering → Thompson Sampling → Sample Selection
10//!                                        ↑
11//!                                  Oracle Feedback
12//! ```
13//!
14//! # Reference
15//! - VER-052: Active Learning - Thompson Sampling exploration
16//! - Spieker et al. (2017): "Reinforcement Learning for Automatic Test Case Prioritization"
17
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Beta, Distribution};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Code embedding for clustering.
///
/// A dense vector of hashed character n-gram and word counts;
/// [`CodeEmbedder::embed`] L2-normalizes it before returning.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CodeEmbedding {
    /// Feature vector (n-gram counts, normalized)
    pub features: Vec<f32>,
    /// Dimensionality (kept equal to `features.len()` by the constructors)
    pub dim: usize,
}
31
32impl CodeEmbedding {
33    /// Create empty embedding
34    #[must_use]
35    pub fn new(dim: usize) -> Self {
36        Self {
37            features: vec![0.0; dim],
38            dim,
39        }
40    }
41
42    /// Create from feature vector
43    #[must_use]
44    pub fn from_vec(features: Vec<f32>) -> Self {
45        let dim = features.len();
46        Self { features, dim }
47    }
48
49    /// L2 norm of embedding
50    #[must_use]
51    pub fn norm(&self) -> f32 {
52        self.features.iter().map(|x| x * x).sum::<f32>().sqrt()
53    }
54
55    /// Normalize to unit vector
56    pub fn normalize(&mut self) {
57        let norm = self.norm();
58        if norm > 0.0 {
59            for x in &mut self.features {
60                *x /= norm;
61            }
62        }
63    }
64
65    /// Cosine similarity with another embedding
66    #[must_use]
67    pub fn cosine_similarity(&self, other: &Self) -> f32 {
68        if self.dim != other.dim {
69            return 0.0;
70        }
71
72        let dot: f32 = self
73            .features
74            .iter()
75            .zip(&other.features)
76            .map(|(a, b)| a * b)
77            .sum();
78
79        let norm_a = self.norm();
80        let norm_b = other.norm();
81
82        if norm_a > 0.0 && norm_b > 0.0 {
83            dot / (norm_a * norm_b)
84        } else {
85            0.0
86        }
87    }
88
89    /// Euclidean distance to another embedding
90    #[must_use]
91    pub fn euclidean_distance(&self, other: &Self) -> f32 {
92        if self.dim != other.dim {
93            return f32::MAX;
94        }
95
96        self.features
97            .iter()
98            .zip(&other.features)
99            .map(|(a, b)| (a - b).powi(2))
100            .sum::<f32>()
101            .sqrt()
102    }
103}
104
/// Simple code embedder using hashed character n-gram and word features.
#[derive(Debug, Clone)]
pub struct CodeEmbedder {
    /// N-gram size (characters per sliding window)
    n: usize,
    /// Vocabulary size (number of hash buckets in the feature vector)
    vocab_size: usize,
}
113
114impl Default for CodeEmbedder {
115    fn default() -> Self {
116        Self::new(3, 128)
117    }
118}
119
120impl CodeEmbedder {
121    /// Create embedder with n-gram size and vocabulary size
122    #[must_use]
123    pub fn new(n: usize, vocab_size: usize) -> Self {
124        Self { n, vocab_size }
125    }
126
127    /// Embed code string to vector
128    #[must_use]
129    pub fn embed(&self, code: &str) -> CodeEmbedding {
130        let mut features = vec![0.0f32; self.vocab_size];
131
132        // Extract character n-grams
133        let chars: Vec<char> = code.chars().collect();
134        if chars.len() >= self.n {
135            for window in chars.windows(self.n) {
136                let hash = self.hash_ngram(window);
137                features[hash] += 1.0;
138            }
139        }
140
141        // Also add word unigrams
142        for word in code.split_whitespace() {
143            let hash = self.hash_word(word);
144            features[hash] += 1.0;
145        }
146
147        let mut embedding = CodeEmbedding::from_vec(features);
148        embedding.normalize();
149        embedding
150    }
151
152    fn hash_ngram(&self, chars: &[char]) -> usize {
153        let mut hash = 0usize;
154        for (i, &c) in chars.iter().enumerate() {
155            hash = hash.wrapping_add((c as usize).wrapping_mul(31_usize.wrapping_pow(i as u32)));
156        }
157        hash % self.vocab_size
158    }
159
160    fn hash_word(&self, word: &str) -> usize {
161        let mut hash = 0usize;
162        for (i, c) in word.chars().enumerate() {
163            hash = hash.wrapping_add((c as usize).wrapping_mul(37_usize.wrapping_pow(i as u32)));
164        }
165        hash % self.vocab_size
166    }
167}
168
/// A single K-means cluster: centroid plus membership statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
    /// Cluster ID
    pub id: usize,
    /// Centroid
    pub centroid: CodeEmbedding,
    /// Number of samples in cluster
    pub size: usize,
    /// Sum of member distances to centroid (for silhouette calculation)
    pub intra_distance: f32,
}
181
182impl Cluster {
183    /// Create new cluster
184    #[must_use]
185    pub fn new(id: usize, centroid: CodeEmbedding) -> Self {
186        Self {
187            id,
188            centroid,
189            size: 0,
190            intra_distance: 0.0,
191        }
192    }
193
194    /// Average intra-cluster distance
195    #[must_use]
196    pub fn avg_intra_distance(&self) -> f32 {
197        if self.size > 0 {
198            self.intra_distance / self.size as f32
199        } else {
200            0.0
201        }
202    }
203}
204
/// K-means clustering result
#[derive(Debug, Clone)]
pub struct ClusteringResult {
    /// Clusters
    pub clusters: Vec<Cluster>,
    /// Assignment of each sample to a cluster (parallel to the input slice)
    pub assignments: Vec<usize>,
    /// Silhouette score (-1 to 1, higher = better separated)
    pub silhouette_score: f32,
    /// Number of Lloyd iterations actually run
    pub iterations: usize,
}
217
/// Simple K-means clustering (Lloyd's algorithm, k-means++ initialization)
#[derive(Debug, Clone)]
pub struct KMeansClustering {
    /// Number of clusters
    k: usize,
    /// Max iterations
    max_iter: usize,
    /// Random seed (see `with_seed`)
    seed: u64,
}
228
229impl Default for KMeansClustering {
230    fn default() -> Self {
231        Self::new(5)
232    }
233}
234
235impl KMeansClustering {
236    /// Create with k clusters
237    #[must_use]
238    pub fn new(k: usize) -> Self {
239        Self {
240            k,
241            max_iter: 100,
242            seed: 42,
243        }
244    }
245
246    /// Set max iterations
247    #[must_use]
248    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
249        self.max_iter = max_iter;
250        self
251    }
252
253    /// Set random seed
254    #[must_use]
255    pub fn with_seed(mut self, seed: u64) -> Self {
256        self.seed = seed;
257        self
258    }
259
260    /// Fit clusters to embeddings
261    pub fn fit(&self, embeddings: &[CodeEmbedding]) -> ClusteringResult {
262        if embeddings.is_empty() {
263            return ClusteringResult {
264                clusters: vec![],
265                assignments: vec![],
266                silhouette_score: 0.0,
267                iterations: 0,
268            };
269        }
270
271        let dim = embeddings[0].dim;
272        let actual_k = self.k.min(embeddings.len());
273
274        // Initialize centroids (k-means++ style)
275        let mut rng = rand::rng();
276        let mut centroids = self.init_centroids(embeddings, actual_k, &mut rng);
277
278        let mut assignments = vec![0usize; embeddings.len()];
279        let mut iterations = 0;
280
281        for iter in 0..self.max_iter {
282            iterations = iter + 1;
283
284            // Assign samples to nearest centroid
285            let mut changed = false;
286            for (i, emb) in embeddings.iter().enumerate() {
287                let nearest = self.find_nearest_centroid(emb, &centroids);
288                if assignments[i] != nearest {
289                    assignments[i] = nearest;
290                    changed = true;
291                }
292            }
293
294            // Update centroids
295            centroids = self.update_centroids(embeddings, &assignments, actual_k, dim);
296
297            if !changed {
298                break;
299            }
300        }
301
302        // Build cluster objects
303        let mut clusters: Vec<Cluster> = centroids
304            .into_iter()
305            .enumerate()
306            .map(|(id, centroid)| Cluster::new(id, centroid))
307            .collect();
308
309        // Calculate cluster sizes and intra-distances
310        for (i, &cluster_id) in assignments.iter().enumerate() {
311            if cluster_id < clusters.len() {
312                clusters[cluster_id].size += 1;
313                clusters[cluster_id].intra_distance +=
314                    embeddings[i].euclidean_distance(&clusters[cluster_id].centroid);
315            }
316        }
317
318        // Calculate silhouette score
319        let silhouette_score = self.calculate_silhouette(embeddings, &assignments, &clusters);
320
321        ClusteringResult {
322            clusters,
323            assignments,
324            silhouette_score,
325            iterations,
326        }
327    }
328
329    fn init_centroids<R: Rng>(
330        &self,
331        embeddings: &[CodeEmbedding],
332        k: usize,
333        rng: &mut R,
334    ) -> Vec<CodeEmbedding> {
335        if embeddings.is_empty() || k == 0 {
336            return vec![];
337        }
338
339        let mut centroids = Vec::with_capacity(k);
340
341        // First centroid: random
342        let first_idx = rng.random_range(0..embeddings.len());
343        centroids.push(embeddings[first_idx].clone());
344
345        // K-means++: choose remaining centroids proportional to squared distance
346        for _ in 1..k {
347            let distances: Vec<f32> = embeddings
348                .iter()
349                .map(|emb| {
350                    centroids
351                        .iter()
352                        .map(|c| emb.euclidean_distance(c))
353                        .fold(f32::MAX, f32::min)
354                        .powi(2)
355                })
356                .collect();
357
358            let total: f32 = distances.iter().sum();
359            if total <= 0.0 {
360                break;
361            }
362
363            let threshold = rng.random::<f32>() * total;
364            let mut cumsum = 0.0;
365            for (i, &d) in distances.iter().enumerate() {
366                cumsum += d;
367                if cumsum >= threshold {
368                    centroids.push(embeddings[i].clone());
369                    break;
370                }
371            }
372        }
373
374        centroids
375    }
376
377    fn find_nearest_centroid(&self, emb: &CodeEmbedding, centroids: &[CodeEmbedding]) -> usize {
378        centroids
379            .iter()
380            .enumerate()
381            .map(|(i, c)| (i, emb.euclidean_distance(c)))
382            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
383            .map_or(0, |(i, _)| i)
384    }
385
386    fn update_centroids(
387        &self,
388        embeddings: &[CodeEmbedding],
389        assignments: &[usize],
390        k: usize,
391        dim: usize,
392    ) -> Vec<CodeEmbedding> {
393        let mut sums: Vec<Vec<f32>> = vec![vec![0.0; dim]; k];
394        let mut counts = vec![0usize; k];
395
396        for (i, &cluster_id) in assignments.iter().enumerate() {
397            if cluster_id < k {
398                counts[cluster_id] += 1;
399                for (j, &val) in embeddings[i].features.iter().enumerate() {
400                    if j < dim {
401                        sums[cluster_id][j] += val;
402                    }
403                }
404            }
405        }
406
407        sums.into_iter()
408            .zip(counts)
409            .map(|(sum, count)| {
410                if count > 0 {
411                    let features: Vec<f32> = sum.into_iter().map(|s| s / count as f32).collect();
412                    CodeEmbedding::from_vec(features)
413                } else {
414                    CodeEmbedding::new(dim)
415                }
416            })
417            .collect()
418    }
419
420    fn calculate_silhouette(
421        &self,
422        embeddings: &[CodeEmbedding],
423        assignments: &[usize],
424        clusters: &[Cluster],
425    ) -> f32 {
426        if embeddings.len() <= 1 || clusters.len() <= 1 {
427            return 0.0;
428        }
429
430        let mut total_score = 0.0;
431        let mut count = 0;
432
433        for (i, emb) in embeddings.iter().enumerate() {
434            let cluster_id = assignments[i];
435            if cluster_id >= clusters.len() {
436                continue;
437            }
438
439            // a(i): average distance to points in same cluster
440            let a = clusters[cluster_id].avg_intra_distance();
441
442            // b(i): minimum average distance to points in other clusters
443            let b = clusters
444                .iter()
445                .filter(|c| c.id != cluster_id)
446                .map(|c| emb.euclidean_distance(&c.centroid))
447                .fold(f32::MAX, f32::min);
448
449            if b < f32::MAX {
450                let max_ab = a.max(b);
451                if max_ab > 0.0 {
452                    total_score += (b - a) / max_ab;
453                    count += 1;
454                }
455            }
456        }
457
458        if count > 0 {
459            total_score / count as f32
460        } else {
461            0.0
462        }
463    }
464}
465
/// Active learner using Thompson Sampling on clusters
#[derive(Debug)]
pub struct ActiveLearner {
    /// Code embedder
    embedder: CodeEmbedder,
    /// Clustering algorithm
    clustering: KMeansClustering,
    /// Current clustering result
    cluster_result: Option<ClusteringResult>,
    /// Clean (no-bug) verifications per cluster; used as `beta` in
    /// `sample_cluster`'s Beta distribution
    success_counts: HashMap<usize, f64>,
    /// Bug-revealing verifications per cluster; used as `alpha` in
    /// `sample_cluster`'s Beta distribution
    failure_counts: HashMap<usize, f64>,
    /// Total samples
    total_samples: usize,
    /// Exploration rate
    exploration_rate: f64,
}
484
485impl Default for ActiveLearner {
486    fn default() -> Self {
487        Self::new(5)
488    }
489}
490
491impl ActiveLearner {
492    /// Create active learner with k clusters
493    #[must_use]
494    pub fn new(k: usize) -> Self {
495        Self {
496            embedder: CodeEmbedder::default(),
497            clustering: KMeansClustering::new(k),
498            cluster_result: None,
499            success_counts: HashMap::new(),
500            failure_counts: HashMap::new(),
501            total_samples: 0,
502            exploration_rate: 0.1,
503        }
504    }
505
506    /// Create with custom embedder
507    #[must_use]
508    pub fn with_embedder(mut self, embedder: CodeEmbedder) -> Self {
509        self.embedder = embedder;
510        self
511    }
512
513    /// Set exploration rate
514    #[must_use]
515    pub fn with_exploration_rate(mut self, rate: f64) -> Self {
516        self.exploration_rate = rate.clamp(0.0, 1.0);
517        self
518    }
519
520    /// Fit clusters on code samples
521    pub fn fit(&mut self, codes: &[&str]) {
522        let embeddings: Vec<CodeEmbedding> = codes.iter().map(|c| self.embedder.embed(c)).collect();
523
524        self.cluster_result = Some(self.clustering.fit(&embeddings));
525    }
526
527    /// Get cluster for a code sample
528    #[must_use]
529    pub fn get_cluster(&self, code: &str) -> Option<usize> {
530        let embedding = self.embedder.embed(code);
531        self.cluster_result.as_ref().map(|result| {
532            result
533                .clusters
534                .iter()
535                .enumerate()
536                .map(|(i, c)| (i, embedding.euclidean_distance(&c.centroid)))
537                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
538                .map_or(0, |(i, _)| i)
539        })
540    }
541
542    /// Sample cluster using Thompson Sampling
543    ///
544    /// Returns cluster ID with high expected value (exploration vs exploitation)
545    ///
546    /// # Panics
547    ///
548    /// Panics if Beta distribution creation fails with default parameters (1.0, 1.0),
549    /// which should never happen for valid alpha/beta values.
550    pub fn sample_cluster(&self) -> Option<usize> {
551        let result = self.cluster_result.as_ref()?;
552        if result.clusters.is_empty() {
553            return None;
554        }
555
556        let mut rng = rand::rng();
557
558        // Sample from Beta distribution for each cluster
559        let scores: Vec<(usize, f64)> = result
560            .clusters
561            .iter()
562            .map(|c| {
563                // Get counts with prior (Beta(1,1) = uniform)
564                let alpha = self.failure_counts.get(&c.id).copied().unwrap_or(0.0) + 1.0;
565                let beta = self.success_counts.get(&c.id).copied().unwrap_or(0.0) + 1.0;
566
567                // Sample from Beta distribution
568                #[allow(clippy::unwrap_used)]
569                let beta_dist =
570                    Beta::new(alpha, beta).unwrap_or_else(|_| Beta::new(1.0, 1.0).unwrap());
571                let score = beta_dist.sample(&mut rng);
572
573                (c.id, score)
574            })
575            .collect();
576
577        // Return cluster with highest sampled score
578        scores
579            .into_iter()
580            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
581            .map(|(id, _)| id)
582    }
583
584    /// Select samples for next batch using Thompson Sampling
585    ///
586    /// Returns indices of codes to sample next
587    pub fn select_batch(&self, codes: &[&str], batch_size: usize) -> Vec<usize> {
588        if codes.is_empty() || batch_size == 0 {
589            return vec![];
590        }
591
592        let Some(_result) = &self.cluster_result else {
593            return (0..batch_size.min(codes.len())).collect();
594        };
595
596        let mut rng = rand::rng();
597        let mut selected = Vec::with_capacity(batch_size);
598        let mut remaining: Vec<usize> = (0..codes.len()).collect();
599
600        while selected.len() < batch_size && !remaining.is_empty() {
601            // Sample cluster using Thompson Sampling
602            let target_cluster = self.sample_cluster().unwrap_or(0);
603
604            // Find samples in target cluster
605            let in_cluster: Vec<usize> = remaining
606                .iter()
607                .filter(|&&i| {
608                    self.get_cluster(codes[i])
609                        .is_some_and(|c| c == target_cluster)
610                })
611                .copied()
612                .collect();
613
614            if in_cluster.is_empty() {
615                // Fallback: random selection
616                let idx = rng.random_range(0..remaining.len());
617                let sample_idx = remaining.remove(idx);
618                selected.push(sample_idx);
619            } else {
620                // Select from target cluster
621                let idx = rng.random_range(0..in_cluster.len());
622                let sample_idx = in_cluster[idx];
623                remaining.retain(|&x| x != sample_idx);
624                selected.push(sample_idx);
625            }
626        }
627
628        selected
629    }
630
631    /// Update with oracle feedback
632    ///
633    /// # Arguments
634    /// * `code` - The code that was verified
635    /// * `revealed_bug` - True if verification revealed a bug
636    pub fn update_feedback(&mut self, code: &str, revealed_bug: bool) {
637        if let Some(cluster_id) = self.get_cluster(code) {
638            if revealed_bug {
639                *self.failure_counts.entry(cluster_id).or_insert(0.0) += 1.0;
640            } else {
641                *self.success_counts.entry(cluster_id).or_insert(0.0) += 1.0;
642            }
643        }
644        self.total_samples += 1;
645    }
646
647    /// Get current silhouette score
648    #[must_use]
649    pub fn silhouette_score(&self) -> f32 {
650        self.cluster_result
651            .as_ref()
652            .map_or(0.0, |r| r.silhouette_score)
653    }
654
655    /// Get cluster statistics
656    #[must_use]
657    pub fn cluster_stats(&self) -> Vec<ClusterStats> {
658        self.cluster_result
659            .as_ref()
660            .map(|r| {
661                r.clusters
662                    .iter()
663                    .map(|c| {
664                        let successes = self.success_counts.get(&c.id).copied().unwrap_or(0.0);
665                        let failures = self.failure_counts.get(&c.id).copied().unwrap_or(0.0);
666                        let total = successes + failures;
667
668                        ClusterStats {
669                            cluster_id: c.id,
670                            size: c.size,
671                            bug_rate: if total > 0.0 { failures / total } else { 0.5 },
672                            #[allow(clippy::cast_sign_loss)]
673                            samples_tried: total.max(0.0) as usize,
674                        }
675                    })
676                    .collect()
677            })
678            .unwrap_or_default()
679    }
680
681    /// Check if exploration should be prioritized
682    #[must_use]
683    pub fn should_explore(&self) -> bool {
684        let mut rng = rand::rng();
685        rng.random::<f64>() < self.exploration_rate
686    }
687
688    /// Get total samples processed
689    #[must_use]
690    pub fn total_samples(&self) -> usize {
691        self.total_samples
692    }
693}
694
/// Statistics for a cluster
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStats {
    /// Cluster ID
    pub cluster_id: usize,
    /// Number of samples in cluster
    pub size: usize,
    /// Bug revelation rate (0-1); 0.5 when no feedback has been recorded yet
    pub bug_rate: f64,
    /// Number of samples tried from this cluster
    pub samples_tried: usize,
}
707
#[cfg(test)]
mod tests {
    use super::*;

    /// Small fixture corpus of Python snippets covering several distinct
    /// constructs (functions, loops, conditionals, classes, imports).
    fn sample_codes() -> Vec<&'static str> {
        vec![
            "def add(a, b):\n    return a + b",
            "def sub(a, b):\n    return a - b",
            "for i in range(10):\n    print(i)",
            "while True:\n    break",
            "if x > 0:\n    return x\nelse:\n    return -x",
            "class Foo:\n    def __init__(self):\n        pass",
            "x = [1, 2, 3]\ny = sum(x)",
            "import os\npath = os.getcwd()",
        ]
    }

    // ========== CodeEmbedding Tests ==========

    #[test]
    fn test_code_embedding_new() {
        let emb = CodeEmbedding::new(64);
        assert_eq!(emb.dim, 64);
        assert_eq!(emb.features.len(), 64);
    }

    #[test]
    fn test_code_embedding_from_vec() {
        let features = vec![1.0, 2.0, 3.0];
        let emb = CodeEmbedding::from_vec(features.clone());
        assert_eq!(emb.features, features);
    }

    #[test]
    fn test_code_embedding_norm() {
        // 3-4-5 right triangle: norm of (3, 4) is exactly 5.
        let emb = CodeEmbedding::from_vec(vec![3.0, 4.0]);
        assert!((emb.norm() - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_code_embedding_normalize() {
        let mut emb = CodeEmbedding::from_vec(vec![3.0, 4.0]);
        emb.normalize();
        assert!((emb.norm() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_code_embedding_cosine_similarity_same() {
        // A vector is perfectly similar to itself.
        let emb = CodeEmbedding::from_vec(vec![1.0, 2.0, 3.0]);
        assert!((emb.cosine_similarity(&emb) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_code_embedding_cosine_similarity_orthogonal() {
        // Orthogonal axes have zero cosine similarity.
        let emb1 = CodeEmbedding::from_vec(vec![1.0, 0.0]);
        let emb2 = CodeEmbedding::from_vec(vec![0.0, 1.0]);
        assert!(emb1.cosine_similarity(&emb2).abs() < 0.001);
    }

    #[test]
    fn test_code_embedding_euclidean_distance() {
        let emb1 = CodeEmbedding::from_vec(vec![0.0, 0.0]);
        let emb2 = CodeEmbedding::from_vec(vec![3.0, 4.0]);
        assert!((emb1.euclidean_distance(&emb2) - 5.0).abs() < 0.001);
    }

    // ========== CodeEmbedder Tests ==========

    #[test]
    fn test_code_embedder_default() {
        // Defaults documented on `Default for CodeEmbedder`.
        let embedder = CodeEmbedder::default();
        assert_eq!(embedder.n, 3);
        assert_eq!(embedder.vocab_size, 128);
    }

    #[test]
    fn test_code_embedder_embed() {
        let embedder = CodeEmbedder::default();
        let emb = embedder.embed("def foo(): return 1");
        assert_eq!(emb.dim, 128);
        assert!(emb.norm() > 0.0);
    }

    #[test]
    fn test_code_embedder_similar_code() {
        let embedder = CodeEmbedder::default();
        let emb1 = embedder.embed("def add(a, b): return a + b");
        let emb2 = embedder.embed("def add(x, y): return x + y");
        let emb3 = embedder.embed("class Foo: pass");

        // Similar functions should be more similar than different constructs
        let sim_12 = emb1.cosine_similarity(&emb2);
        let sim_13 = emb1.cosine_similarity(&emb3);
        assert!(sim_12 > sim_13);
    }

    #[test]
    fn test_code_embedder_empty() {
        // Empty input yields a zero vector of the right dimensionality.
        let embedder = CodeEmbedder::default();
        let emb = embedder.embed("");
        assert_eq!(emb.dim, 128);
    }

    // ========== KMeansClustering Tests ==========

    #[test]
    fn test_kmeans_default() {
        let kmeans = KMeansClustering::default();
        assert_eq!(kmeans.k, 5);
    }

    #[test]
    fn test_kmeans_fit_empty() {
        // Fitting no data must not panic and returns an empty result.
        let kmeans = KMeansClustering::new(3);
        let result = kmeans.fit(&[]);
        assert!(result.clusters.is_empty());
        assert!(result.assignments.is_empty());
    }

    #[test]
    fn test_kmeans_fit() {
        let embedder = CodeEmbedder::default();
        let codes = sample_codes();
        let embeddings: Vec<CodeEmbedding> = codes.iter().map(|c| embedder.embed(c)).collect();

        let kmeans = KMeansClustering::new(3).with_seed(42);
        let result = kmeans.fit(&embeddings);

        assert_eq!(result.clusters.len(), 3);
        assert_eq!(result.assignments.len(), codes.len());
    }

    #[test]
    fn test_kmeans_silhouette_bounded() {
        let embedder = CodeEmbedder::default();
        let codes = sample_codes();
        let embeddings: Vec<CodeEmbedding> = codes.iter().map(|c| embedder.embed(c)).collect();

        let kmeans = KMeansClustering::new(3);
        let result = kmeans.fit(&embeddings);

        // Silhouette should be in [-1, 1]
        assert!(result.silhouette_score >= -1.0);
        assert!(result.silhouette_score <= 1.0);
    }

    // ========== ActiveLearner Tests ==========

    #[test]
    fn test_active_learner_new() {
        let learner = ActiveLearner::new(5);
        assert_eq!(learner.total_samples(), 0);
    }

    #[test]
    fn test_active_learner_fit() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        assert!(learner.silhouette_score() >= -1.0);
    }

    #[test]
    fn test_active_learner_get_cluster() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        let cluster = learner.get_cluster(codes[0]);
        assert!(cluster.is_some());
    }

    #[test]
    fn test_active_learner_sample_cluster() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        let cluster = learner.sample_cluster();
        assert!(cluster.is_some());
    }

    #[test]
    fn test_active_learner_select_batch() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        let batch = learner.select_batch(&codes, 3);
        assert_eq!(batch.len(), 3);
        // All indices should be unique
        let mut sorted = batch.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), batch.len());
    }

    #[test]
    fn test_active_learner_update_feedback() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        learner.update_feedback(codes[0], true);
        learner.update_feedback(codes[1], false);

        assert_eq!(learner.total_samples(), 2);
    }

    #[test]
    fn test_active_learner_cluster_stats() {
        let mut learner = ActiveLearner::new(3);
        let codes = sample_codes();

        learner.fit(&codes);

        // Add some feedback
        for (i, &code) in codes.iter().enumerate() {
            learner.update_feedback(code, i % 2 == 0);
        }

        let stats = learner.cluster_stats();
        assert!(!stats.is_empty());
    }

    #[test]
    fn test_active_learner_exploration_rate() {
        let learner = ActiveLearner::new(3).with_exploration_rate(1.0);

        // With rate=1.0, should always explore
        let mut explored = 0;
        for _ in 0..100 {
            if learner.should_explore() {
                explored += 1;
            }
        }
        assert_eq!(explored, 100);
    }

    // ========== Debug Tests ==========

    #[test]
    fn test_code_embedding_debug() {
        let emb = CodeEmbedding::new(4);
        let debug = format!("{emb:?}");
        assert!(debug.contains("CodeEmbedding"));
    }

    #[test]
    fn test_code_embedder_debug() {
        let embedder = CodeEmbedder::default();
        let debug = format!("{embedder:?}");
        assert!(debug.contains("CodeEmbedder"));
    }

    #[test]
    fn test_cluster_debug() {
        let cluster = Cluster::new(0, CodeEmbedding::new(4));
        let debug = format!("{cluster:?}");
        assert!(debug.contains("Cluster"));
    }

    #[test]
    fn test_active_learner_debug() {
        let learner = ActiveLearner::new(3);
        let debug = format!("{learner:?}");
        assert!(debug.contains("ActiveLearner"));
    }
}
983
/// Property-based tests
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Embedding norm is non-negative
        #[test]
        fn prop_embedding_norm_nonnegative(features in proptest::collection::vec(-100.0f32..100.0, 1..50)) {
            let emb = CodeEmbedding::from_vec(features);
            prop_assert!(emb.norm() >= 0.0);
        }

        /// Cosine similarity is bounded [-1, 1]
        #[test]
        fn prop_cosine_bounded(
            f1 in proptest::collection::vec(-10.0f32..10.0, 1..20),
            f2 in proptest::collection::vec(-10.0f32..10.0, 1..20),
        ) {
            // Truncate both vectors to the shorter length so the
            // dimensions match (mismatched dims short-circuit to 0.0).
            let dim = f1.len().min(f2.len());
            let emb1 = CodeEmbedding::from_vec(f1[..dim].to_vec());
            let emb2 = CodeEmbedding::from_vec(f2[..dim].to_vec());

            let sim = emb1.cosine_similarity(&emb2);
            // Small epsilon tolerates floating-point rounding.
            prop_assert!(sim >= -1.0 - 0.001);
            prop_assert!(sim <= 1.0 + 0.001);
        }

        /// Euclidean distance is non-negative
        #[test]
        fn prop_euclidean_nonnegative(
            f1 in proptest::collection::vec(-100.0f32..100.0, 1..20),
            f2 in proptest::collection::vec(-100.0f32..100.0, 1..20),
        ) {
            let dim = f1.len().min(f2.len());
            let emb1 = CodeEmbedding::from_vec(f1[..dim].to_vec());
            let emb2 = CodeEmbedding::from_vec(f2[..dim].to_vec());

            prop_assert!(emb1.euclidean_distance(&emb2) >= 0.0);
        }

        /// Normalized vectors have unit norm
        #[test]
        fn prop_normalized_unit_norm(features in proptest::collection::vec(0.1f32..10.0, 1..20)) {
            // Strictly positive inputs guarantee a nonzero starting norm.
            let mut emb = CodeEmbedding::from_vec(features);
            emb.normalize();

            // Should be close to 1.0 (allow small floating point error)
            prop_assert!((emb.norm() - 1.0).abs() < 0.01);
        }

        /// Batch selection returns valid indices
        #[test]
        fn prop_batch_indices_valid(batch_size in 1usize..10) {
            let mut learner = ActiveLearner::new(3);
            let codes: Vec<&str> = vec![
                "x = 1",
                "y = 2",
                "z = 3",
                "def f(): pass",
                "class C: pass",
            ];

            learner.fit(&codes);

            let batch = learner.select_batch(&codes, batch_size);

            // Every returned index must address a real code sample.
            for &idx in &batch {
                prop_assert!(idx < codes.len());
            }
        }
    }
}