use crate::features::CommitFeatures;
use anyhow::Result;
/// Deterministic SMOTE (Synthetic Minority Over-sampling Technique) oversampler
/// for `CommitFeatures`: each synthetic minority sample is interpolated between
/// a minority sample and one of its nearest minority-class neighbours.
#[derive(Debug, Clone)]
pub struct Smote {
    /// How many nearest minority-class neighbours are candidates for interpolation.
    k_neighbors: usize,
}
impl Smote {
    /// Creates an oversampler that interpolates towards one of the 5 nearest
    /// minority-class neighbours.
    pub fn new() -> Self {
        Self { k_neighbors: 5 }
    }

    /// Creates an oversampler that considers the `k_neighbors` nearest
    /// minority-class neighbours of each base sample.
    ///
    /// With `k_neighbors == 0` no neighbour is ever available, so every
    /// synthetic sample is a copy of its base sample.
    pub fn with_k(k_neighbors: usize) -> Self {
        Self { k_neighbors }
    }

    /// Returns `features` followed by synthetic samples of `minority_category`.
    ///
    /// Enough synthetic samples are generated for the minority count to reach
    /// `target_ratio` times the number of samples in all other categories
    /// combined. If it already does, `features` is returned unchanged.
    ///
    /// Generation is deterministic. Base samples are visited round-robin.
    /// The chosen neighbour and the interpolation gap both depend only on the
    /// synthetic sample's index. A base sample with no neighbour (a single
    /// minority sample, or `k_neighbors == 0`) is duplicated.
    ///
    /// # Errors
    ///
    /// Fails if `features` contains no sample of `minority_category`.
    pub fn oversample(
        &self,
        features: &[CommitFeatures],
        minority_category: u8,
        target_ratio: f32,
    ) -> Result<Vec<CommitFeatures>> {
        let minority: Vec<&CommitFeatures> = features
            .iter()
            .filter(|f| f.defect_category == minority_category)
            .collect();
        if minority.is_empty() {
            anyhow::bail!("No samples found for category {}", minority_category);
        }
        let majority_count = features.len() - minority.len();
        // Float-to-int `as` saturates, so a negative or NaN ratio yields 0 here.
        let target_minority = (majority_count as f32 * target_ratio) as usize;
        let samples_needed = target_minority.saturating_sub(minority.len());
        if samples_needed == 0 {
            return Ok(features.to_vec());
        }

        let minority_vecs: Vec<Vec<f32>> = minority.iter().map(|f| f.to_vector()).collect();
        // Bases are visited round-robin, so only the first
        // `min(minority_vecs.len(), samples_needed)` of them are ever used.
        // Compute each of their neighbour lists once, not once per synthetic sample.
        let used_bases = minority_vecs.len().min(samples_needed);
        let neighbor_lists: Vec<Vec<usize>> = (0..used_bases)
            .map(|i| self.find_k_nearest(&minority_vecs[i], &minority_vecs, i))
            .collect();

        let mut result = Vec::with_capacity(features.len() + samples_needed);
        result.extend_from_slice(features);
        for sample_idx in 0..samples_needed {
            let base_idx = sample_idx % minority_vecs.len();
            let base = &minority_vecs[base_idx];
            let neighbors = &neighbor_lists[base_idx];
            let synthetic_vec = if neighbors.is_empty() {
                // No neighbour to interpolate towards. Duplicate the base sample
                // rather than evaluating `sample_idx % 0`, which would panic.
                base.clone()
            } else {
                let neighbor = &minority_vecs[neighbors[sample_idx % neighbors.len()]];
                self.interpolate(base, neighbor, sample_idx)
            };
            result.push(self.vector_to_features(&synthetic_vec, minority_category));
        }
        Ok(result)
    }

    /// Returns the indices into `all` of the (up to) `k_neighbors` vectors
    /// closest to `target` by Euclidean distance, nearest first, skipping
    /// position `exclude_idx`.
    fn find_k_nearest(&self, target: &[f32], all: &[Vec<f32>], exclude_idx: usize) -> Vec<usize> {
        let mut distances: Vec<(usize, f32)> = all
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != exclude_idx)
            .map(|(i, v)| (i, self.euclidean_distance(target, v)))
            .collect();
        // `total_cmp` gives NaN distances (from NaN features) a defined place in
        // the order instead of panicking as `partial_cmp(..).unwrap()` would.
        distances.sort_by(|a, b| a.1.total_cmp(&b.1));
        distances.truncate(self.k_neighbors);
        distances.into_iter().map(|(i, _)| i).collect()
    }

    /// Euclidean distance over the common prefix of `a` and `b`.
    fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Returns `base + gap * (neighbor - base)` element-wise.
    /// `gap` lies in `[0.0, 0.99]` and depends only on `seed`.
    fn interpolate(&self, base: &[f32], neighbor: &[f32], seed: usize) -> Vec<f32> {
        // Reducing `seed` first gives exactly `(seed * 17 + 42) % 100`
        // without risking overflow for very large seeds.
        let gap = (((seed % 100) * 17 + 42) % 100) as f32 / 100.0;
        base.iter()
            .zip(neighbor)
            .map(|(b, n)| b + gap * (n - b))
            .collect()
    }

    /// Converts a feature vector back into `CommitFeatures` tagged with `category`.
    ///
    /// The index layout presumably mirrors `CommitFeatures::to_vector`; confirm
    /// against that method. Index `0` is ignored in favour of `category`.
    /// Indices `1..=7` are files changed, lines added, lines deleted,
    /// complexity delta, timestamp, hour of day and day of week. Indices
    /// `8..=13` are optional; a missing value takes the default written below.
    /// Counts are clamped to be non-negative, and hour/day are capped at 23/6.
    ///
    /// NOTE(review): the timestamp arrives as `f32`. Around 1.7e9 that type
    /// only resolves to 128-second steps.
    ///
    /// # Panics
    ///
    /// Panics if `vec` has fewer than 8 elements.
    fn vector_to_features(&self, vec: &[f32], category: u8) -> CommitFeatures {
        CommitFeatures {
            defect_category: category,
            files_changed: vec[1].max(0.0),
            lines_added: vec[2].max(0.0),
            lines_deleted: vec[3].max(0.0),
            complexity_delta: vec[4],
            timestamp: vec[5] as f64,
            // `as u8` saturates (negatives and NaN become 0); `min` caps the range.
            hour_of_day: (vec[6] as u8).min(23),
            day_of_week: (vec[7] as u8).min(6),
            error_code_class: if vec.len() > 8 { vec[8] as u8 } else { 4 },
            has_suggestion: if vec.len() > 9 { vec[9] as u8 } else { 0 },
            suggestion_applicability: if vec.len() > 10 { vec[10] as u8 } else { 0 },
            clippy_lint_count: if vec.len() > 11 { vec[11] as u8 } else { 0 },
            span_line_delta: if vec.len() > 12 { vec[12] } else { 0.0 },
            diagnostic_confidence: if vec.len() > 13 { vec[13] } else { 0.0 },
        }
    }
}
impl Default for Smote {
fn default() -> Self {
Self::new()
}
}
/// Focal loss (Lin et al., 2017) for imbalanced defect classification.
/// It is cross-entropy scaled by `(1 - p_t)^gamma`, which down-weights
/// well-classified examples, and by a per-class weight `alpha`.
#[derive(Debug, Clone)]
pub struct FocalLoss {
    /// Focusing parameter: larger values down-weight easy examples more.
    gamma: f32,
    /// Per-class weights indexed by class id.
    alpha: Vec<f32>,
}
impl FocalLoss {
    /// Number of category slots in the default `alpha` table and in
    /// `compute_weights`. Larger category ids share the last slot.
    const NUM_CLASSES: usize = 10;

    /// Creates a focal loss with `gamma = 2.0` and weight `1.0` for each of
    /// the 10 categories.
    pub fn new() -> Self {
        Self {
            gamma: 2.0,
            alpha: vec![1.0; Self::NUM_CLASSES],
        }
    }

    /// Creates a focal loss with focusing parameter `gamma` and per-class
    /// weights `alpha`. A class id past the end of `alpha` is weighted `1.0`.
    pub fn with_params(gamma: f32, alpha: Vec<f32>) -> Self {
        Self { gamma, alpha }
    }

    /// Inverse-frequency class weights, one per category slot (10 in total).
    ///
    /// A slot with `count > 0` gets `total / (k * count)`, where `k` is the
    /// number of slots present. This makes `weight * count` equal to
    /// `total / k` for every present slot. Absent slots get `0.0`. Category
    /// ids of 9 and above all count towards slot 9.
    pub fn compute_weights(features: &[CommitFeatures]) -> Vec<f32> {
        let mut counts = [0usize; Self::NUM_CLASSES];
        for f in features {
            let idx = (f.defect_category as usize).min(Self::NUM_CLASSES - 1);
            counts[idx] += 1;
        }
        let total = features.len() as f32;
        let present = counts.iter().filter(|&&c| c > 0).count() as f32;
        counts
            .iter()
            .map(|&c| if c > 0 { total / (present * c as f32) } else { 0.0 })
            .collect()
    }

    /// Focal loss `-alpha_t * (1 - p_t)^gamma * ln(p_t)` for one example.
    ///
    /// `prob` is used as `p_t`, the probability assigned to `class`. It is
    /// clamped to `[1e-7, 1 - 1e-7]` so the logarithm stays finite. A class
    /// with no `alpha` entry is weighted `1.0`.
    pub fn loss(&self, prob: f32, class: u8) -> f32 {
        let p_t = prob.clamp(1e-7, 1.0 - 1e-7);
        let alpha_t = self.alpha.get(class as usize).copied().unwrap_or(1.0);
        -alpha_t * (1.0 - p_t).powf(self.gamma) * p_t.ln()
    }

    /// Mean of `loss(probs[i], classes[i])` over the batch.
    ///
    /// If the slices differ in length, only their common prefix is scored,
    /// and the mean is taken over that many pairs. An empty batch yields
    /// `0.0` rather than NaN.
    pub fn batch_loss(&self, probs: &[f32], classes: &[u8]) -> f32 {
        let pairs = probs.len().min(classes.len());
        if pairs == 0 {
            return 0.0;
        }
        let total: f32 = probs
            .iter()
            .zip(classes)
            .map(|(&p, &c)| self.loss(p, c))
            .sum();
        total / pairs as f32
    }
}
impl Default for FocalLoss {
fn default() -> Self {
Self::new()
}
}
/// Precision-recall metrics for binary defect predictions (stateless).
#[derive(Debug, Clone, Copy)]
pub struct AuprcMetric;
impl AuprcMetric {
    /// Area under the precision-recall curve (average precision).
    ///
    /// `labels` are binary: `1` is positive and any other value is negative.
    /// `predictions` are the corresponding scores, where higher means more
    /// likely positive. Samples are ranked by descending score, and samples
    /// with tied scores form a single threshold, so the result does not depend
    /// on input order. The area is
    /// `sum_k precision_k * (recall_k - recall_{k-1})` over those thresholds.
    ///
    /// # Errors
    ///
    /// Fails if the slices differ in length, are empty, any prediction is NaN,
    /// or no label is positive.
    pub fn compute(predictions: &[f32], labels: &[u8]) -> Result<f32> {
        if predictions.len() != labels.len() {
            anyhow::bail!("Predictions and labels must have same length");
        }
        if predictions.is_empty() {
            anyhow::bail!("Empty predictions");
        }
        let pairs = Self::ranked(predictions, labels)?;
        let total_positives = labels.iter().filter(|&&l| l == 1).count() as f32;
        if total_positives == 0.0 {
            anyhow::bail!("No positive samples in labels");
        }
        let mut auprc = 0.0;
        let mut prev_recall = 0.0;
        for (precision, recall) in Self::curve_points(&pairs, total_positives) {
            auprc += precision * (recall - prev_recall);
            prev_recall = recall;
        }
        Ok(auprc)
    }

    /// Precision at the highest-score threshold whose recall reaches `target_recall`.
    ///
    /// Thresholds are formed exactly as in [`AuprcMetric::compute`]. The final
    /// threshold always has recall `1.0`, so `0.0` is returned only when
    /// `target_recall` exceeds `1.0` or is NaN.
    ///
    /// # Errors
    ///
    /// Fails if the slices differ in length, any prediction is NaN, or no label
    /// is positive (which includes empty input).
    pub fn precision_at_recall(
        predictions: &[f32],
        labels: &[u8],
        target_recall: f32,
    ) -> Result<f32> {
        if predictions.len() != labels.len() {
            anyhow::bail!("Predictions and labels must have same length");
        }
        let pairs = Self::ranked(predictions, labels)?;
        let total_positives = labels.iter().filter(|&&l| l == 1).count() as f32;
        if total_positives == 0.0 {
            anyhow::bail!("No positive samples");
        }
        let precision = Self::curve_points(&pairs, total_positives)
            .into_iter()
            .find(|&(_, recall)| recall >= target_recall)
            .map_or(0.0, |(precision, _)| precision);
        Ok(precision)
    }

    /// Pairs each prediction with its label, sorted by descending score.
    ///
    /// Callers check beforehand that the slices have equal length.
    fn ranked(predictions: &[f32], labels: &[u8]) -> Result<Vec<(f32, u8)>> {
        if predictions.iter().any(|p| p.is_nan()) {
            anyhow::bail!("Predictions must not contain NaN");
        }
        let mut pairs: Vec<(f32, u8)> = predictions
            .iter()
            .copied()
            .zip(labels.iter().copied())
            .collect();
        // No NaN remains, so `total_cmp` matches numeric order. `-0.0` and
        // `0.0` stay adjacent, and they compare equal when ties are grouped.
        pairs.sort_by(|a, b| b.0.total_cmp(&a.0));
        Ok(pairs)
    }

    /// Returns one `(precision, recall)` point per distinct score in `pairs`,
    /// which must be sorted by descending score.
    ///
    /// Each point is measured after every sample at or above that score has
    /// been predicted positive. Because tied scores become one threshold, the
    /// curve does not depend on the input order of tied samples.
    fn curve_points(pairs: &[(f32, u8)], total_positives: f32) -> Vec<(f32, f32)> {
        let mut points = Vec::new();
        let mut true_positives = 0.0_f32;
        let mut false_positives = 0.0_f32;
        for (i, &(score, label)) in pairs.iter().enumerate() {
            if label == 1 {
                true_positives += 1.0;
            } else {
                false_positives += 1.0;
            }
            // Emit a point only after the last sample with this score is counted.
            let threshold_continues = matches!(pairs.get(i + 1), Some(next) if next.0 == score);
            if !threshold_continues {
                let precision = true_positives / (true_positives + false_positives);
                points.push((precision, true_positives / total_positives));
            }
        }
        points
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a feature row in `category`. Only `files_changed` varies between
    /// rows; all other fields are fixed values.
    fn make_feature(category: u8, files: u32) -> CommitFeatures {
        CommitFeatures {
            day_of_week: 1,
            hour_of_day: 10,
            timestamp: 1_700_000_000.0,
            complexity_delta: 0.0,
            lines_deleted: 50.0,
            lines_added: 100.0,
            files_changed: files as f32,
            defect_category: category,
            ..Default::default()
        }
    }

    #[test]
    fn test_smote_creation() {
        assert_eq!(Smote::new().k_neighbors, 5);
        assert_eq!(Smote::with_k(3).k_neighbors, 3);
    }

    #[test]
    fn test_smote_oversample() {
        let features: Vec<CommitFeatures> = (0..90)
            .map(|files| make_feature(0, files))
            .chain((100..110).map(|files| make_feature(1, files)))
            .collect();
        let balanced = Smote::new().oversample(&features, 1, 0.5).unwrap();
        assert!(balanced.len() > features.len());
        let minority = balanced.iter().filter(|f| f.defect_category == 1).count();
        assert!(minority > 10);
    }

    #[test]
    fn test_smote_no_minority() {
        let features = [make_feature(0, 1), make_feature(0, 2)];
        assert!(Smote::new().oversample(&features, 1, 0.5).is_err());
    }

    #[test]
    fn test_focal_loss_creation() {
        let fl = FocalLoss::new();
        assert_eq!((fl.gamma, fl.alpha.len()), (2.0, 10));
    }

    #[test]
    fn test_focal_loss_compute_weights() {
        let features: Vec<CommitFeatures> = [(0, 1), (0, 2), (0, 3), (1, 4)]
            .iter()
            .map(|&(category, files)| make_feature(category, files))
            .collect();
        let weights = FocalLoss::compute_weights(&features);
        assert!(weights[0] < weights[1]);
    }

    #[test]
    fn test_focal_loss_value() {
        let fl = FocalLoss::new();
        let confident = fl.loss(0.9, 0);
        let unsure = fl.loss(0.3, 0);
        assert!(confident < unsure);
    }

    #[test]
    fn test_auprc_perfect() {
        let predictions: [f32; 4] = [0.9, 0.8, 0.2, 0.1];
        let labels: [u8; 4] = [1, 1, 0, 0];
        let auprc = AuprcMetric::compute(&predictions, &labels).unwrap();
        assert!((auprc - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_auprc_range() {
        let predictions: [f32; 10] = [0.9, 0.7, 0.5, 0.3, 0.2, 0.15, 0.1, 0.05, 0.02, 0.01];
        let mut labels = [0u8; 10];
        labels[0] = 1;
        let auprc = AuprcMetric::compute(&predictions, &labels).unwrap();
        assert!((auprc - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_precision_at_recall() {
        let predictions: [f32; 4] = [0.9, 0.7, 0.5, 0.3];
        let labels: [u8; 4] = [1, 1, 0, 0];
        for target in [0.5, 1.0] {
            let precision =
                AuprcMetric::precision_at_recall(&predictions, &labels, target).unwrap();
            assert!((precision - 1.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_smote_default() {
        assert_eq!(Smote::default().k_neighbors, 5);
    }

    #[test]
    fn test_smote_no_samples_needed() {
        let features: Vec<CommitFeatures> = [(0, 1), (0, 2), (1, 10), (1, 11)]
            .iter()
            .map(|&(category, files)| make_feature(category, files))
            .collect();
        let out = Smote::new().oversample(&features, 1, 0.5).unwrap();
        assert_eq!(out.len(), features.len());
    }

    #[test]
    fn test_smote_vector_to_features_clamping() {
        let raw: [f32; 8] = [0.0, -5.0, -10.0, -1.0, 0.5, 1_700_000_000.0, 25.0, 8.0];
        let feature = Smote::new().vector_to_features(&raw, 3);
        assert_eq!(feature.files_changed, 0.0);
        assert_eq!(feature.lines_added, 0.0);
        assert_eq!(feature.lines_deleted, 0.0);
        assert_eq!(feature.hour_of_day, 23);
        assert_eq!(feature.day_of_week, 6);
    }

    #[test]
    fn test_focal_loss_with_params() {
        let mut alpha = vec![0.5_f32; 10];
        alpha[..3].copy_from_slice(&[2.0, 1.5, 1.0]);
        let fl = FocalLoss::with_params(3.0, alpha.clone());
        assert_eq!(fl.gamma, 3.0);
        assert_eq!(fl.alpha, alpha);
    }

    #[test]
    fn test_focal_loss_default() {
        let fl = FocalLoss::default();
        assert_eq!(fl.gamma, 2.0);
        assert_eq!(fl.alpha.len(), 10);
    }

    #[test]
    fn test_focal_loss_batch() {
        let fl = FocalLoss::new();
        let probs: [f32; 4] = [0.9, 0.8, 0.7, 0.6];
        let classes: [u8; 4] = [0, 0, 1, 1];
        let batch = fl.batch_loss(&probs, &classes);
        assert!(batch > 0.0);
        let expected = probs
            .iter()
            .zip(&classes)
            .map(|(&p, &c)| fl.loss(p, c))
            .sum::<f32>()
            / 4.0;
        assert!((batch - expected).abs() < 0.001);
    }

    #[test]
    fn test_focal_loss_compute_weights_edge_cases() {
        let features: Vec<CommitFeatures> = (1..=3).map(|files| make_feature(0, files)).collect();
        let weights = FocalLoss::compute_weights(&features);
        assert!(weights[0] > 0.0);
        assert_eq!(weights[1], 0.0);
    }

    #[test]
    fn test_auprc_length_mismatch() {
        let predictions: [f32; 3] = [0.9, 0.8, 0.7];
        let labels: [u8; 2] = [1, 0];
        let err = AuprcMetric::compute(&predictions, &labels).unwrap_err();
        assert!(err.to_string().contains("same length"));
    }

    #[test]
    fn test_auprc_empty() {
        let predictions: [f32; 0] = [];
        let labels: [u8; 0] = [];
        let err = AuprcMetric::compute(&predictions, &labels).unwrap_err();
        assert!(err.to_string().contains("Empty predictions"));
    }

    #[test]
    fn test_auprc_no_positives() {
        let predictions: [f32; 3] = [0.9, 0.8, 0.7];
        let labels = [0u8; 3];
        let err = AuprcMetric::compute(&predictions, &labels).unwrap_err();
        assert!(err.to_string().contains("No positive samples"));
    }

    #[test]
    fn test_precision_at_recall_length_mismatch() {
        let predictions: [f32; 2] = [0.9, 0.8];
        let labels: [u8; 3] = [1, 0, 0];
        assert!(AuprcMetric::precision_at_recall(&predictions, &labels, 0.5).is_err());
    }

    #[test]
    fn test_precision_at_recall_no_positives() {
        let predictions: [f32; 3] = [0.9, 0.8, 0.7];
        let labels = [0u8; 3];
        let err = AuprcMetric::precision_at_recall(&predictions, &labels, 0.5).unwrap_err();
        assert!(err.to_string().contains("No positive samples"));
    }

    #[test]
    fn test_precision_at_recall_target_not_reached() {
        let predictions: [f32; 3] = [0.9, 0.7, 0.3];
        let labels: [u8; 3] = [0, 0, 1];
        let precision = AuprcMetric::precision_at_recall(&predictions, &labels, 1.0).unwrap();
        assert!((precision - 1.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_smote_interpolate_deterministic() {
        let smote = Smote::new();
        let base: [f32; 3] = [1.0, 2.0, 3.0];
        let neighbor: [f32; 3] = [2.0, 4.0, 6.0];
        let first = smote.interpolate(&base, &neighbor, 0);
        assert_eq!(first, smote.interpolate(&base, &neighbor, 0));
        assert_ne!(first, smote.interpolate(&base, &neighbor, 1));
    }
}