use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::random::{rngs::StdRng, RngExt, SeedableRng};
/// Loss functions supported by the SDCA solver.
///
/// For classification losses (`Hinge`, `SquaredHinge`, `Logistic`) labels are
/// expected to be `+1.0` / `-1.0`; `SquaredLoss` performs ridge regression on
/// real-valued labels.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum SDCALoss {
    /// Hinge loss `max(0, 1 - y * p)` (SVM-style classification).
    Hinge,
    /// Squared hinge loss `max(0, 1 - y * p)^2` (smooth SVM variant).
    SquaredHinge,
    /// Logistic loss `ln(1 + exp(-y * p))`.
    Logistic,
    /// Squared loss `0.5 * (y - p)^2` (ridge regression).
    SquaredLoss,
}
/// Configuration for the [`sdca`] solver.
#[derive(Debug, Clone)]
pub struct SDCAConfig {
    /// L2 regularization strength; must be positive.
    pub lambda: f64,
    /// Maximum number of passes over the training data.
    pub max_epochs: usize,
    /// Duality-gap threshold below which the solver declares convergence.
    pub tolerance: f64,
    /// Loss function to optimize.
    pub loss: SDCALoss,
    /// Seed for the per-epoch shuffle RNG; `None` falls back to a fixed
    /// seed of 0, so runs are deterministic either way.
    pub seed: Option<u64>,
}
impl Default for SDCAConfig {
    /// Conservative defaults: light regularization, squared loss, and a
    /// fixed seed (42) for reproducibility.
    fn default() -> Self {
        Self {
            lambda: 1e-4,
            max_epochs: 100,
            tolerance: 1e-6,
            loss: SDCALoss::SquaredLoss,
            seed: Some(42),
        }
    }
}
/// Output of the [`sdca`] solver.
#[derive(Debug, Clone)]
pub struct SDCAResult {
    /// Learned primal weight vector `w`.
    pub weights: Vec<f64>,
    /// Final dual variables, one per training sample.
    pub dual_variables: Vec<f64>,
    /// Primal objective `(1/n) * sum_i loss_i + (lambda/2) * ||w||^2`.
    pub primal_objective: f64,
    /// Dual objective value at the final dual variables.
    pub dual_objective: f64,
    /// `primal_objective - dual_objective`; a certificate of suboptimality.
    pub duality_gap: f64,
    /// Number of epochs actually run.
    pub epochs: usize,
    /// Whether the duality gap dropped below the configured tolerance.
    pub converged: bool,
}
/// Inner product of two slices; extra trailing elements of the longer slice
/// are ignored (zip truncates to the shorter length).
fn dot_product(features: &[f64], weights: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (&f, &w) in features.iter().zip(weights.iter()) {
        acc += f * w;
    }
    acc
}
/// Squared Euclidean norm `||v||^2`.
fn squared_norm(v: &[f64]) -> f64 {
    v.iter().fold(0.0, |acc, &x| acc + x * x)
}
/// Primal loss for a single sample given the raw prediction `w . x`.
///
/// The logistic branch guards against overflow: for large positive margins
/// `ln(1 + e^{-m})` is approximated by `e^{-m}`, and for large negative
/// margins by `-m` (both are the dominant terms of the exact expression).
fn primal_loss_single(prediction: f64, label: f64, loss: SDCALoss) -> f64 {
    // Classification margin y * p; unused by the squared-loss branch.
    let margin = label * prediction;
    match loss {
        SDCALoss::Hinge => (1.0 - margin).max(0.0),
        SDCALoss::SquaredHinge => {
            let hinge = (1.0 - margin).max(0.0);
            hinge * hinge
        }
        SDCALoss::Logistic => {
            if margin > 15.0 {
                (-margin).exp()
            } else if margin < -15.0 {
                -margin
            } else {
                (1.0 + (-margin).exp()).ln()
            }
        }
        SDCALoss::SquaredLoss => {
            let residual = label - prediction;
            0.5 * residual * residual
        }
    }
}
/// Convex conjugate `phi*(alpha)` of the per-sample loss, evaluated at a
/// dual variable. Values outside the conjugate's domain map to `+inf`.
fn conjugate_loss(alpha: f64, label: f64, loss: SDCALoss) -> f64 {
    match loss {
        SDCALoss::Hinge => {
            // Domain of the hinge conjugate: alpha * y in [-1, 0].
            let ay = alpha * label;
            if (-1.0..=0.0).contains(&ay) {
                ay
            } else {
                f64::INFINITY
            }
        }
        SDCALoss::SquaredHinge => {
            let ay = alpha * label;
            if ay <= 0.0 {
                ay + 0.25 * alpha * alpha
            } else {
                f64::INFINITY
            }
        }
        SDCALoss::Logistic => {
            // Binary entropy term, finite only for p in [0, 1].
            let p = -alpha * label;
            if p <= 0.0 || p >= 1.0 {
                // The endpoints carry zero entropy; anything else outside
                // (0, 1) is infeasible.
                let on_boundary = (p - 0.0).abs() < 1e-15 || (p - 1.0).abs() < 1e-15;
                if on_boundary {
                    0.0
                } else {
                    f64::INFINITY
                }
            } else {
                p * p.ln() + (1.0 - p) * (1.0 - p).ln()
            }
        }
        SDCALoss::SquaredLoss => 0.5 * alpha * alpha + alpha * label,
    }
}
/// One SDCA coordinate step: maximizes the dual objective in the single
/// variable `alpha_i`, holding all other dual variables fixed, and returns
/// the new value for `alpha_i`.
///
/// * `xi_dot_w` - current primal prediction `x_i . w`.
/// * `xi_norm_sq` - `||x_i||^2` (caller guarantees it is nonzero).
/// * `n` - number of samples; `lambda` - L2 regularization strength.
///
/// `q = ||x_i||^2 / (lambda * n)` is the curvature contribution of the
/// regularizer along this coordinate.
fn compute_dual_update(
    alpha_i: f64,
    xi_dot_w: f64,
    label: f64,
    xi_norm_sq: f64,
    n: f64,
    lambda: f64,
    loss: SDCALoss,
) -> f64 {
    let q = xi_norm_sq / (lambda * n);
    match loss {
        SDCALoss::Hinge => {
            // Exact box-constrained 1-D maximizer: work in s = alpha_i * y,
            // take the unconstrained step, then clip s back to [0, 1].
            let s = alpha_i * label;
            let margin = label * xi_dot_w;
            let delta_s = (1.0 - margin - s) / (q + 1.0);
            let new_s = (s + delta_s).max(0.0).min(1.0);
            new_s * label
        }
        SDCALoss::SquaredHinge => {
            let s = alpha_i * label;
            let margin = label * xi_dot_w;
            // NOTE(review): the extra 0.5 in the denominator presumably
            // reflects the 0.25 * alpha^2 curvature of the squared-hinge
            // conjugate — TODO confirm against the SDCA derivation.
            let delta_s = (1.0 - margin - s) / (q + 1.0 + 0.5);
            // Only the lower bound s >= 0 applies (no upper box constraint).
            let new_s = (s + delta_s).max(0.0);
            new_s * label
        }
        SDCALoss::Logistic => {
            // Approximate step: move s toward the sigmoid-implied target
            // 1 - sigma(margin), then clamp into the open interval (0, 1)
            // where the logistic conjugate is finite.
            let s = (alpha_i * label).max(0.0).min(1.0);
            let margin = label * xi_dot_w;
            let sigmoid = 1.0 / (1.0 + (-margin).exp());
            let target = 1.0 - sigmoid;
            let delta = (target - s) / (q + 1.0);
            let new_s = (s + delta).max(1e-10).min(1.0 - 1e-10);
            new_s * label
        }
        SDCALoss::SquaredLoss => {
            // Exact unconstrained maximizer for the quadratic dual.
            let delta = (label - xi_dot_w - alpha_i) / (q + 1.0);
            alpha_i + delta
        }
    }
}
/// Stochastic Dual Coordinate Ascent (SDCA) solver for L2-regularized
/// empirical risk minimization.
///
/// Minimizes `(1/n) * sum_i loss(w . x_i, y_i) + (lambda/2) * ||w||^2` by
/// ascending the dual one coordinate at a time, visiting samples in a fresh
/// random order each epoch. Terminates when the duality gap drops below
/// `config.tolerance` or after `config.max_epochs` passes.
///
/// # Errors
///
/// Returns `OptimizeError::InvalidInput` when:
/// - `features` is empty,
/// - `labels.len() != features.len()`,
/// - feature vectors have zero or inconsistent dimension,
/// - `config.lambda` is not a positive finite number.
pub fn sdca(
    features: &[Vec<f64>],
    labels: &[f64],
    config: &SDCAConfig,
) -> OptimizeResult<SDCAResult> {
    let n_samples = features.len();
    if n_samples == 0 {
        return Err(OptimizeError::InvalidInput(
            "Must provide at least one training sample".to_string(),
        ));
    }
    if labels.len() != n_samples {
        return Err(OptimizeError::InvalidInput(format!(
            "Number of labels ({}) must match number of feature vectors ({})",
            labels.len(),
            n_samples
        )));
    }
    let n_features = features[0].len();
    if n_features == 0 {
        return Err(OptimizeError::InvalidInput(
            "Feature vectors must have at least one dimension".to_string(),
        ));
    }
    for (i, feat) in features.iter().enumerate() {
        if feat.len() != n_features {
            return Err(OptimizeError::InvalidInput(format!(
                "Feature vector {} has length {} but expected {}",
                i,
                feat.len(),
                n_features
            )));
        }
    }
    // A plain `lambda <= 0.0` check lets NaN through (every comparison with
    // NaN is false), which would silently poison all weights; require a
    // positive *finite* value instead.
    if !(config.lambda.is_finite() && config.lambda > 0.0) {
        return Err(OptimizeError::InvalidInput(format!(
            "Regularization parameter lambda must be a positive finite number, got {}",
            config.lambda
        )));
    }
    let n = n_samples as f64;
    let lambda = config.lambda;
    // Deterministic RNG: an explicit seed wins; unseeded runs fall back to 0
    // so they are reproducible too.
    let mut rng = match config.seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => StdRng::seed_from_u64(0),
    };
    // Per-sample squared norms are loop-invariant; precompute them once.
    let xi_norms_sq: Vec<f64> = features.iter().map(|x| squared_norm(x)).collect();
    let mut alpha = vec![0.0; n_samples];
    let mut weights = vec![0.0; n_features];
    let mut converged = false;
    let mut epochs = 0;
    for epoch in 0..config.max_epochs {
        epochs = epoch + 1;
        // Fisher-Yates shuffle: a fresh uniform visiting order per epoch.
        let mut indices: Vec<usize> = (0..n_samples).collect();
        for i in (1..n_samples).rev() {
            let j = rng.random_range(0..=(i as u64)) as usize;
            indices.swap(i, j);
        }
        for &i in &indices {
            let xi = &features[i];
            let yi = labels[i];
            let xi_norm_sq = xi_norms_sq[i];
            // A (near-)zero feature vector cannot move the dual; skip it to
            // avoid dividing by a vanishing curvature term.
            if xi_norm_sq < 1e-30 {
                continue;
            }
            let xi_dot_w = dot_product(xi, &weights);
            let old_alpha = alpha[i];
            let new_alpha =
                compute_dual_update(old_alpha, xi_dot_w, yi, xi_norm_sq, n, lambda, config.loss);
            let delta_alpha = new_alpha - old_alpha;
            alpha[i] = new_alpha;
            // Maintain the primal-dual link w = (1/(lambda*n)) * sum_i alpha_i * x_i
            // incrementally rather than recomputing it from scratch.
            let scale = delta_alpha / (lambda * n);
            for j in 0..n_features {
                weights[j] += scale * xi[j];
            }
        }
        // The duality gap certifies suboptimality; stop once it is small.
        let primal_obj = compute_primal_objective(features, labels, &weights, lambda, config.loss);
        let dual_obj =
            compute_dual_objective(features, labels, &alpha, &weights, lambda, config.loss);
        let gap = primal_obj - dual_obj;
        if gap.abs() < config.tolerance {
            converged = true;
            break;
        }
    }
    let primal_objective =
        compute_primal_objective(features, labels, &weights, lambda, config.loss);
    let dual_objective =
        compute_dual_objective(features, labels, &alpha, &weights, lambda, config.loss);
    let duality_gap = primal_objective - dual_objective;
    Ok(SDCAResult {
        weights,
        dual_variables: alpha,
        primal_objective,
        dual_objective,
        duality_gap,
        epochs,
        converged,
    })
}
/// Primal objective: mean per-sample loss plus the L2 regularization term
/// `(lambda/2) * ||w||^2`.
fn compute_primal_objective(
    features: &[Vec<f64>],
    labels: &[f64],
    weights: &[f64],
    lambda: f64,
    loss: SDCALoss,
) -> f64 {
    let n = features.len() as f64;
    let mut loss_total = 0.0;
    for (xi, &yi) in features.iter().zip(labels.iter()) {
        let prediction = dot_product(xi, weights);
        loss_total += primal_loss_single(prediction, yi, loss);
    }
    loss_total / n + 0.5 * lambda * squared_norm(weights)
}
/// Dual objective: `-(1/n) * sum_i phi*(-alpha_i)` minus the regularization
/// term evaluated on the implied primal vector `sum_i alpha_i * x_i`.
///
/// `_weights` is accepted only for signature symmetry with the primal
/// computation; the dual is evaluated purely from `alpha`.
fn compute_dual_objective(
    features: &[Vec<f64>],
    labels: &[f64],
    alpha: &[f64],
    _weights: &[f64],
    lambda: f64,
    loss: SDCALoss,
) -> f64 {
    let n = features.len() as f64;
    // Conjugate-loss term, averaged over samples.
    let mut conj_total = 0.0;
    for (&ai, &yi) in alpha.iter().zip(labels.iter()) {
        conj_total += conjugate_loss(-ai, yi, loss);
    }
    let conj_sum = conj_total / n;
    // Accumulate sum_i alpha_i * x_i, the unscaled primal direction.
    let n_features = features[0].len();
    let mut sum_alpha_x = vec![0.0; n_features];
    for (i, xi) in features.iter().enumerate() {
        for (acc, &x) in sum_alpha_x.iter_mut().zip(xi.iter()) {
            *acc += alpha[i] * x;
        }
    }
    let norm_sq = squared_norm(&sum_alpha_x);
    -conj_sum - norm_sq / (2.0 * lambda * n * n)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Linearly separable 2-D data: the hinge-loss solution should leave no
    // point badly misclassified (the margin check is deliberately loose).
    #[test]
    fn test_sdca_hinge_separable() {
        let features = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![2.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
            vec![-2.0, -1.0],
        ];
        let labels = vec![1.0, 1.0, 1.0, -1.0, -1.0, -1.0];
        let config = SDCAConfig {
            lambda: 0.01,
            max_epochs: 200,
            tolerance: 1e-4,
            loss: SDCALoss::Hinge,
            seed: Some(42),
        };
        let result = sdca(&features, &labels, &config);
        assert!(result.is_ok());
        let result = result.expect("SDCA should succeed");
        for (xi, &yi) in features.iter().zip(labels.iter()) {
            let pred = dot_product(xi, &result.weights);
            assert!(
                pred * yi > -0.5,
                "Misclassification: pred={}, label={}",
                pred,
                yi
            );
        }
    }

    // Noise-free targets y = 2*x0 + 3*x1: squared-loss SDCA should recover
    // the generating weights to within the regularization bias.
    #[test]
    fn test_sdca_ridge_regression() {
        let features = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 0.0],
            vec![0.0, 2.0],
            vec![2.0, 1.0],
            vec![1.0, 2.0],
            vec![3.0, 1.0],
        ];
        let labels: Vec<f64> = features
            .iter()
            .map(|x| 2.0 * x[0] + 3.0 * x[1])
            .collect();
        let config = SDCAConfig {
            lambda: 0.001,
            max_epochs: 500,
            tolerance: 1e-6,
            loss: SDCALoss::SquaredLoss,
            seed: Some(42),
        };
        let result = sdca(&features, &labels, &config);
        assert!(result.is_ok());
        let result = result.expect("SDCA should succeed");
        assert!(
            (result.weights[0] - 2.0).abs() < 0.5,
            "w[0]={}, expected ~2.0",
            result.weights[0]
        );
        assert!(
            (result.weights[1] - 3.0).abs() < 0.5,
            "w[1]={}, expected ~3.0",
            result.weights[1]
        );
    }

    // Well-separated clusters: logistic SDCA should classify most points
    // correctly (the approximate logistic update may leave a couple off).
    #[test]
    fn test_sdca_logistic() {
        let features = vec![
            vec![3.0, 3.0],
            vec![4.0, 3.0],
            vec![3.0, 4.0],
            vec![-3.0, -3.0],
            vec![-4.0, -3.0],
            vec![-3.0, -4.0],
        ];
        let labels = vec![1.0, 1.0, 1.0, -1.0, -1.0, -1.0];
        let config = SDCAConfig {
            lambda: 0.01,
            max_epochs: 500,
            tolerance: 1e-6,
            loss: SDCALoss::Logistic,
            seed: Some(42),
        };
        let result = sdca(&features, &labels, &config);
        assert!(result.is_ok());
        let result = result.expect("SDCA should succeed");
        let mut correct = 0;
        for (xi, &yi) in features.iter().zip(labels.iter()) {
            let pred = dot_product(xi, &result.weights);
            if pred * yi > 0.0 {
                correct += 1;
            }
        }
        assert!(
            correct >= 4,
            "Only {}/6 correct classifications",
            correct
        );
    }

    // Smoke test for the squared-hinge path: only checks that the solver
    // runs and produces a finite primal objective.
    #[test]
    fn test_sdca_squared_hinge() {
        let features = vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![-1.0, -1.0],
            vec![-2.0, -2.0],
        ];
        let labels = vec![1.0, 1.0, -1.0, -1.0];
        let config = SDCAConfig {
            lambda: 0.1,
            max_epochs: 200,
            tolerance: 1e-4,
            loss: SDCALoss::SquaredHinge,
            seed: Some(42),
        };
        let result = sdca(&features, &labels, &config);
        assert!(result.is_ok());
        let result = result.expect("SDCA should succeed");
        assert!(
            result.primal_objective.is_finite(),
            "Primal objective is not finite: {}",
            result.primal_objective
        );
    }

    // Empty training set must be rejected, not panic.
    #[test]
    fn test_empty_features() {
        let features: Vec<Vec<f64>> = vec![];
        let labels: Vec<f64> = vec![];
        let config = SDCAConfig::default();
        let result = sdca(&features, &labels, &config);
        assert!(result.is_err());
    }

    // Label count differing from sample count must be rejected.
    #[test]
    fn test_mismatched_dimensions() {
        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let labels = vec![1.0];
        let config = SDCAConfig::default();
        let result = sdca(&features, &labels, &config);
        assert!(result.is_err());
    }

    // Non-positive lambda must be rejected.
    #[test]
    fn test_invalid_lambda() {
        let features = vec![vec![1.0]];
        let labels = vec![1.0];
        let config = SDCAConfig {
            lambda: -1.0,
            ..SDCAConfig::default()
        };
        let result = sdca(&features, &labels, &config);
        assert!(result.is_err());
    }

    // Monotonicity sanity check: with tolerance disabled, running many more
    // epochs should never yield a worse primal objective than a few epochs.
    #[test]
    fn test_primal_decreases_squared_loss() {
        let features = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 1.0],
        ];
        let labels = vec![1.0, 2.0, 3.0, 4.0];
        let config_few = SDCAConfig {
            lambda: 0.01,
            max_epochs: 5,
            tolerance: 0.0,
            loss: SDCALoss::SquaredLoss,
            seed: Some(42),
        };
        let config_many = SDCAConfig {
            lambda: 0.01,
            max_epochs: 200,
            tolerance: 0.0,
            loss: SDCALoss::SquaredLoss,
            seed: Some(42),
        };
        let result_few = sdca(&features, &labels, &config_few);
        let result_many = sdca(&features, &labels, &config_many);
        assert!(result_few.is_ok());
        assert!(result_many.is_ok());
        let r_few = result_few.expect("should succeed");
        let r_many = result_many.expect("should succeed");
        assert!(
            r_many.primal_objective <= r_few.primal_objective + 1e-8,
            "More epochs didn't help: {} vs {}",
            r_many.primal_objective,
            r_few.primal_objective
        );
    }

    // The returned dual variables must be per-sample and actually updated
    // (not left at their zero initialization).
    #[test]
    fn test_dual_variables_populated() {
        let features = vec![vec![1.0], vec![-1.0]];
        let labels = vec![1.0, -1.0];
        let config = SDCAConfig {
            lambda: 0.1,
            max_epochs: 50,
            tolerance: 1e-8,
            loss: SDCALoss::SquaredLoss,
            seed: Some(42),
        };
        let result = sdca(&features, &labels, &config);
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert_eq!(result.dual_variables.len(), 2);
        let any_nonzero = result.dual_variables.iter().any(|&a| a.abs() > 1e-15);
        assert!(any_nonzero, "All dual variables are zero");
    }

    // Feature vectors of differing lengths must be rejected.
    #[test]
    fn test_inconsistent_feature_dims() {
        let features = vec![vec![1.0, 2.0], vec![3.0]];
        let labels = vec![1.0, -1.0];
        let config = SDCAConfig::default();
        let result = sdca(&features, &labels, &config);
        assert!(result.is_err());
    }
}