use crate::error::{SeqError, SeqResult};
/// Hyper-parameters for [`FullCuttingPlaneSvm::train`].
#[derive(Clone, Copy, Debug)]
pub struct FullCuttingPlaneConfig {
    /// Slack penalty `C`; must be strictly positive. In the master QP it caps the sum of
    /// the dual variables that belong to a single example.
    pub c_reg: f32,
    /// A candidate constraint is added to an example's working set only when its
    /// violation exceeds that example's current slack by more than this; must be > 0.
    pub epsilon: f32,
    /// Maximum number of outer cutting-plane iterations; must be non-zero.
    pub max_iter: usize,
}

impl Default for FullCuttingPlaneConfig {
    /// `c_reg = 1.0`, `epsilon = 1e-3`, `max_iter = 100`.
    fn default() -> Self {
        let (c_reg, epsilon, max_iter) = (1.0, 1e-3, 100);
        Self {
            c_reg,
            epsilon,
            max_iter,
        }
    }
}
/// Outcome of [`FullCuttingPlaneSvm::train`].
#[derive(Clone, Debug)]
pub struct FullCuttingPlaneResult {
    /// Learned weight vector, length `n_features`.
    pub w: Vec<f32>,
    /// Per-example slack from the most recent master-QP solve (all zeros if the QP
    /// never ran), length `n_examples`.
    pub xi: Vec<f32>,
    /// Number of outer iterations actually executed.
    pub iterations: usize,
    /// Total number of constraints held across all working sets at termination.
    pub n_constraints: usize,
    /// `true` when an iteration added no constraint; `false` if `max_iter` ran out first.
    pub converged: bool,
}
/// Linear structural SVM trained with the n-slack cutting-plane method.
///
/// [`FullCuttingPlaneSvm::train`] queries a caller-supplied separation oracle once per
/// example per iteration, keeps a separate working set of `(dpsi, loss)` constraints for
/// every example, and re-solves the restricted problem with
/// [`FullCuttingPlaneSvm::qp_master`] after each round that added a constraint.
pub struct FullCuttingPlaneSvm;

impl FullCuttingPlaneSvm {
    /// Maximum number of projected-gradient iterations inside `qp_master`.
    const INNER_ITERS: usize = 500;
    /// The inner solver stops once the dual objective changes by less than this.
    const INNER_TOL: f32 = 1e-6;

    /// Solves the restricted (master) QP over the current working sets.
    ///
    /// `constraints[i]` holds the `(dpsi, loss)` pairs collected for example `i`. The
    /// routine maximises the dual objective
    ///
    /// ```text
    /// Σ_k α_k·loss_k − ½‖Σ_k α_k·dpsi_k‖²
    ///     s.t.  α_k ≥ 0,   Σ_{k ∈ example i} α_k ≤ c_reg   for every example i
    /// ```
    ///
    /// by projected gradient ascent. This is the Lagrange dual of
    /// `min ½‖w‖² + c_reg·Σ_i ξ_i  s.t.  ⟨w, dpsi⟩ ≥ loss − ξ_i,  ξ_i ≥ 0`.
    ///
    /// The step size is `1/L`, where `L` is an upper bound on the largest eigenvalue of
    /// the constraint Gram matrix, so every step is an ascent step even when several
    /// constraints point in the same direction. Each example's block of dual variables
    /// is projected exactly (in the Euclidean sense) onto its feasible set.
    ///
    /// Returns `(w, xi)` with `w = Σ_k α_k·dpsi_k` and
    /// `xi[i] = max(0, max over constraints[i] of loss − ⟨w, dpsi⟩)`.
    ///
    /// # Errors
    /// * `SeqError::InvalidParameter` if `c_reg` is not positive or is NaN.
    /// * `SeqError::ShapeMismatch` if any `dpsi` length differs from `n_features`.
    pub fn qp_master(
        constraints: &[Vec<(Vec<f32>, f32)>],
        n_features: usize,
        c_reg: f32,
    ) -> SeqResult<(Vec<f32>, Vec<f32>)> {
        if c_reg <= 0.0 || c_reg.is_nan() {
            return Err(SeqError::InvalidParameter {
                name: "c_reg".to_string(),
                value: c_reg as f64,
            });
        }
        if let Some((bad, _)) = constraints
            .iter()
            .flatten()
            .find(|(dpsi, _)| dpsi.len() != n_features)
        {
            return Err(SeqError::ShapeMismatch {
                expected: n_features,
                got: bad.len(),
            });
        }
        let n_examples = constraints.len();
        if n_examples == 0 {
            return Ok((vec![0.0; n_features], Vec::new()));
        }

        // Flatten the working sets. Example i owns the contiguous run of `sizes[i]`
        // entries that follows the runs of examples 0..i.
        let dpsi: Vec<&[f32]> = constraints
            .iter()
            .flatten()
            .map(|(d, _)| d.as_slice())
            .collect();
        let loss: Vec<f32> = constraints.iter().flatten().map(|(_, l)| *l).collect();
        let sizes: Vec<usize> = constraints.iter().map(Vec::len).collect();
        if dpsi.is_empty() {
            return Ok((vec![0.0; n_features], vec![0.0; n_examples]));
        }

        let lipschitz = Self::gram_spectral_bound(&dpsi, n_features);
        let step = if lipschitz > 0.0 { 1.0 / lipschitz } else { 1.0 };

        let mut alpha = vec![0.0_f32; dpsi.len()];
        // Invariant: `w == combine(alpha)` at the top of every iteration.
        let mut w = vec![0.0_f32; n_features];
        let mut prev_obj = f32::NEG_INFINITY;
        for _ in 0..Self::INNER_ITERS {
            // Gradient of the dual w.r.t. α_k is loss_k − ⟨w, dpsi_k⟩; w stays fixed
            // while all coordinates are stepped.
            for ((a, d), l) in alpha.iter_mut().zip(&dpsi).zip(&loss) {
                *a += step * (l - Self::dot(&w, d));
            }
            Self::project_per_example(&mut alpha, &sizes, c_reg);
            w = Self::combine(&alpha, &dpsi, n_features);
            let obj = Self::dot(&alpha, &loss) - 0.5 * Self::dot(&w, &w);
            if (obj - prev_obj).abs() < Self::INNER_TOL {
                break;
            }
            prev_obj = obj;
        }

        // Slack of example i is its largest remaining violation, floored at zero.
        let xi: Vec<f32> = constraints
            .iter()
            .map(|ws| {
                ws.iter()
                    .map(|(d, l)| l - Self::dot(&w, d))
                    .fold(0.0_f32, f32::max)
            })
            .collect();
        Ok((w, xi))
    }

    /// Trains the model with the n-slack cutting-plane loop.
    ///
    /// Each outer iteration calls `separation_oracle(&w, i)` for every example
    /// `i in 0..n_examples`. The oracle returns a candidate constraint `(dpsi, loss)`,
    /// which is added to example `i`'s working set when
    /// `loss − ⟨w, dpsi⟩ > xi[i] + cfg.epsilon`. If an iteration adds nothing, the loop
    /// stops with `converged = true`. Otherwise `(w, xi)` is recomputed by
    /// [`FullCuttingPlaneSvm::qp_master`].
    ///
    /// # Errors
    /// * `SeqError::InvalidParameter` if `n_examples`, `n_features` or `cfg.max_iter` is
    ///   zero, or if `cfg.c_reg` / `cfg.epsilon` is not positive or is NaN.
    /// * `SeqError::ShapeMismatch` if the oracle returns a `dpsi` whose length is not
    ///   `n_features`.
    /// * Any error returned by the oracle, propagated unchanged.
    pub fn train<O>(
        n_examples: usize,
        n_features: usize,
        separation_oracle: O,
        cfg: &FullCuttingPlaneConfig,
    ) -> SeqResult<FullCuttingPlaneResult>
    where
        O: Fn(&[f32], usize) -> SeqResult<(Vec<f32>, f32)>,
    {
        if n_examples == 0 {
            return Err(SeqError::InvalidParameter {
                name: "n_examples".to_string(),
                value: 0.0,
            });
        }
        if n_features == 0 {
            return Err(SeqError::InvalidParameter {
                name: "n_features".to_string(),
                value: 0.0,
            });
        }
        if cfg.c_reg <= 0.0 || cfg.c_reg.is_nan() {
            return Err(SeqError::InvalidParameter {
                name: "c_reg".to_string(),
                value: cfg.c_reg as f64,
            });
        }
        if cfg.epsilon <= 0.0 || cfg.epsilon.is_nan() {
            return Err(SeqError::InvalidParameter {
                name: "epsilon".to_string(),
                value: cfg.epsilon as f64,
            });
        }
        if cfg.max_iter == 0 {
            return Err(SeqError::InvalidParameter {
                name: "max_iter".to_string(),
                value: 0.0,
            });
        }

        let mut working_set: Vec<Vec<(Vec<f32>, f32)>> = vec![Vec::new(); n_examples];
        let mut w = vec![0.0_f32; n_features];
        let mut xi = vec![0.0_f32; n_examples];
        let mut iterations = 0;
        let mut converged = false;
        for it in 0..cfg.max_iter {
            iterations = it + 1;
            let mut added = false;
            for (i, ws) in working_set.iter_mut().enumerate() {
                let (dpsi, ell) = separation_oracle(&w, i)?;
                if dpsi.len() != n_features {
                    return Err(SeqError::ShapeMismatch {
                        expected: n_features,
                        got: dpsi.len(),
                    });
                }
                // Keep the constraint only if it beats the current slack by > epsilon.
                if ell - Self::dot(&w, &dpsi) > xi[i] + cfg.epsilon {
                    ws.push((dpsi, ell));
                    added = true;
                }
            }
            if !added {
                converged = true;
                break;
            }
            let (w_new, xi_new) = Self::qp_master(&working_set, n_features, cfg.c_reg)?;
            w = w_new;
            xi = xi_new;
        }

        let n_constraints = working_set.iter().map(Vec::len).sum();
        Ok(FullCuttingPlaneResult {
            w,
            xi,
            iterations,
            n_constraints,
            converged,
        })
    }

    /// Inner product of two slices (extra elements of the longer one are ignored).
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Dense `Σ_k alpha[k]·dpsi[k]` of length `n_features`.
    fn combine(alpha: &[f32], dpsi: &[&[f32]], n_features: usize) -> Vec<f32> {
        let mut w = vec![0.0_f32; n_features];
        for (&a, d) in alpha.iter().zip(dpsi) {
            if a == 0.0 {
                continue;
            }
            for (wf, &p) in w.iter_mut().zip(d.iter()) {
                *wf += a * p;
            }
        }
        w
    }

    /// Upper bound on λ_max of the Gram matrix `K = D·Dᵀ`, where the rows of `D` are
    /// the `dpsi` vectors.
    ///
    /// Both `trace(K) = Σ‖dpsi_k‖²` and `‖D‖₁·‖D‖_∞` (largest column abs-sum times
    /// largest row abs-sum) bound `‖D‖₂² = λ_max(K)`. The smaller of the two is returned.
    fn gram_spectral_bound(dpsi: &[&[f32]], n_features: usize) -> f32 {
        let trace: f32 = dpsi.iter().map(|d| Self::dot(d, d)).sum();
        let max_row_abs = dpsi
            .iter()
            .map(|d| d.iter().map(|v| v.abs()).sum::<f32>())
            .fold(0.0_f32, f32::max);
        let mut col_abs = vec![0.0_f32; n_features];
        for d in dpsi {
            for (c, v) in col_abs.iter_mut().zip(d.iter()) {
                *c += v.abs();
            }
        }
        let max_col_abs = col_abs.into_iter().fold(0.0_f32, f32::max);
        trace.min(max_row_abs * max_col_abs)
    }

    /// Projects each example's contiguous block of `alpha` onto `{x ≥ 0, Σx ≤ cap}`.
    fn project_per_example(alpha: &mut [f32], sizes: &[usize], cap: f32) {
        let mut start = 0;
        for &len in sizes {
            Self::project_capped_simplex(&mut alpha[start..start + len], cap);
            start += len;
        }
    }

    /// Euclidean projection of `v` onto `{x : x ≥ 0, Σx ≤ cap}`, in place.
    ///
    /// Negative entries are clipped first. If the clipped sum still exceeds `cap`, the
    /// result is the projection onto `{x ≥ 0, Σx = cap}`, using the sort-and-threshold
    /// method of Duchi et al. (2008). Clipping first does not change that projection,
    /// because the threshold is positive in this case.
    fn project_capped_simplex(v: &mut [f32], cap: f32) {
        for x in v.iter_mut() {
            *x = (*x).max(0.0);
        }
        let total: f32 = v.iter().sum();
        if total <= cap {
            return;
        }
        let mut sorted = v.to_vec();
        sorted.sort_unstable_by(|a, b| b.total_cmp(a));
        let mut prefix = 0.0_f32;
        let mut theta = 0.0_f32;
        for (j, &u) in sorted.iter().enumerate() {
            prefix += u;
            let candidate = (prefix - cap) / (j + 1) as f32;
            // The condition holds for a prefix of the sorted order and then fails.
            if u > candidate {
                theta = candidate;
            } else {
                break;
            }
        }
        for x in v.iter_mut() {
            *x = (*x - theta).max(0.0);
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the cutting-plane trainer and its master-QP solver.
    use super::*;
    use std::cell::Cell;

    /// Small, fast configuration shared by several tests.
    fn default_cfg() -> FullCuttingPlaneConfig {
        FullCuttingPlaneConfig {
            c_reg: 1.0,
            epsilon: 1e-3,
            max_iter: 50,
        }
    }

    #[test]
    fn train_zero_oracle_returns_zero_weights() {
        let oracle =
            |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![0.0; 3], 0.0)) };
        let res = FullCuttingPlaneSvm::train(2, 3, oracle, &default_cfg()).expect("ok");
        for v in &res.w {
            assert!(v.abs() < 1e-7);
        }
        for v in &res.xi {
            assert!(v.abs() < 1e-7);
        }
        assert!(res.converged);
        assert_eq!(res.n_constraints, 0);
        assert_eq!(res.iterations, 1);
    }

    #[test]
    fn qp_master_single_constraint_closed_form() {
        let constraints: Vec<Vec<(Vec<f32>, f32)>> = vec![vec![(vec![1.0, 0.0], 1.0)]];
        let (w, xi) = FullCuttingPlaneSvm::qp_master(&constraints, 2, 0.5).expect("ok");
        assert!((w[0] - 0.5).abs() < 1e-3);
        assert!(w[1].abs() < 1e-3);
        assert!((xi[0] - 0.5).abs() < 1e-3);
    }

    #[test]
    fn qp_master_single_constraint_within_cap() {
        let constraints: Vec<Vec<(Vec<f32>, f32)>> = vec![vec![(vec![1.0, 0.0], 1.0)]];
        let (w, xi) = FullCuttingPlaneSvm::qp_master(&constraints, 2, 10.0).expect("ok");
        assert!((w[0] - 1.0).abs() < 1e-3);
        assert!(w[1].abs() < 1e-3);
        assert!(xi[0].abs() < 1e-3);
    }

    #[test]
    fn train_synthetic_separable_problem_converges() {
        let calls: Cell<usize> = Cell::new(0);
        // Example 0 is separated along e0 and example 1 along e1, both with loss 1.
        let oracle = |_w: &[f32], i: usize| -> SeqResult<(Vec<f32>, f32)> {
            calls.set(calls.get() + 1);
            let dpsi = if i == 0 {
                vec![1.0_f32, 0.0]
            } else {
                vec![0.0_f32, 1.0]
            };
            Ok((dpsi, 1.0))
        };
        let cfg = FullCuttingPlaneConfig {
            c_reg: 10.0,
            epsilon: 1e-3,
            max_iter: 10,
        };
        let res = FullCuttingPlaneSvm::train(2, 2, oracle, &cfg).expect("ok");
        assert!(res.iterations >= 1);
        assert!(res.converged);
        assert!((res.w[0] - 1.0).abs() < 0.1);
        assert!((res.w[1] - 1.0).abs() < 0.1);
        assert!(calls.get() >= 2);
    }

    #[test]
    fn train_constraints_grow_with_iterations() {
        let counter: Cell<usize> = Cell::new(0);
        // The loss grows with every call, so each round keeps finding new violations.
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> {
            let c = counter.get();
            counter.set(c + 1);
            Ok((vec![1.0_f32, 0.0], 10.0 + c as f32))
        };
        let cfg = FullCuttingPlaneConfig {
            c_reg: 1.0,
            epsilon: 1e-3,
            max_iter: 3,
        };
        let res = FullCuttingPlaneSvm::train(2, 2, oracle, &cfg).expect("ok");
        assert!(res.n_constraints >= 2);
    }

    #[test]
    fn train_converged_flag_true_when_no_new_violation() {
        let calls: Cell<usize> = Cell::new(0);
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> {
            let c = calls.get();
            calls.set(c + 1);
            Ok((vec![1.0_f32, 0.0], 1.0))
        };
        let cfg = FullCuttingPlaneConfig {
            c_reg: 10.0,
            epsilon: 1e-3,
            max_iter: 10,
        };
        let res = FullCuttingPlaneSvm::train(1, 2, oracle, &cfg).expect("ok");
        assert!(res.converged);
    }

    #[test]
    fn train_is_deterministic() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> {
            Ok((vec![0.5_f32, 0.5], 0.5))
        };
        let cfg = default_cfg();
        let r1 = FullCuttingPlaneSvm::train(2, 2, oracle, &cfg).expect("ok");
        let r2 = FullCuttingPlaneSvm::train(2, 2, oracle, &cfg).expect("ok");
        assert_eq!(r1.iterations, r2.iterations);
        assert_eq!(r1.n_constraints, r2.n_constraints);
        for i in 0..2 {
            assert!((r1.w[i] - r2.w[i]).abs() < 1e-7);
        }
    }

    #[test]
    fn err_n_examples_zero() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![0.0], 0.0)) };
        let cfg = default_cfg();
        let r = FullCuttingPlaneSvm::train(0, 1, oracle, &cfg);
        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
    }

    #[test]
    fn err_n_features_zero() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![], 0.0)) };
        let cfg = default_cfg();
        let r = FullCuttingPlaneSvm::train(1, 0, oracle, &cfg);
        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
    }

    #[test]
    fn err_c_reg_non_positive() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![0.0], 0.0)) };
        let mut cfg = default_cfg();
        cfg.c_reg = 0.0;
        let r = FullCuttingPlaneSvm::train(1, 1, oracle, &cfg);
        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
    }

    #[test]
    fn err_c_reg_nan() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![0.0], 0.0)) };
        let mut cfg = default_cfg();
        cfg.c_reg = f32::NAN;
        let r = FullCuttingPlaneSvm::train(1, 1, oracle, &cfg);
        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
    }

    #[test]
    fn err_epsilon_non_positive() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![0.0], 0.0)) };
        let mut cfg = default_cfg();
        cfg.epsilon = 0.0;
        let r = FullCuttingPlaneSvm::train(1, 1, oracle, &cfg);
        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
    }

    #[test]
    fn err_max_iter_zero() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![0.0], 0.0)) };
        let mut cfg = default_cfg();
        cfg.max_iter = 0;
        let r = FullCuttingPlaneSvm::train(1, 1, oracle, &cfg);
        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
    }

    #[test]
    fn single_example_single_constraint() {
        let oracle =
            |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![1.0_f32], 1.0)) };
        let cfg = FullCuttingPlaneConfig {
            c_reg: 10.0,
            epsilon: 1e-3,
            max_iter: 10,
        };
        let res = FullCuttingPlaneSvm::train(1, 1, oracle, &cfg).expect("ok");
        assert!((res.w[0] - 1.0).abs() < 0.05);
    }

    #[test]
    fn multiple_examples_accumulate_per_example_constraints() {
        let oracle = |_w: &[f32], i: usize| -> SeqResult<(Vec<f32>, f32)> {
            let mut dpsi = vec![0.0_f32; 3];
            if i < 3 {
                dpsi[i] = 1.0;
            }
            Ok((dpsi, 1.0))
        };
        let cfg = FullCuttingPlaneConfig {
            c_reg: 10.0,
            epsilon: 1e-3,
            max_iter: 5,
        };
        let res = FullCuttingPlaneSvm::train(3, 3, oracle, &cfg).expect("ok");
        assert!(res.n_constraints >= 3);
    }

    #[test]
    fn xi_is_non_negative() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> {
            Ok((vec![1.0_f32, 0.0], 0.5))
        };
        let cfg = default_cfg();
        let res = FullCuttingPlaneSvm::train(2, 2, oracle, &cfg).expect("ok");
        for v in &res.xi {
            assert!(*v >= 0.0);
        }
    }

    #[test]
    fn xi_length_equals_n_examples() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> {
            Ok((vec![0.0_f32, 0.0], 0.0))
        };
        let cfg = default_cfg();
        let res = FullCuttingPlaneSvm::train(5, 2, oracle, &cfg).expect("ok");
        assert_eq!(res.xi.len(), 5);
        assert_eq!(res.w.len(), 2);
    }

    #[test]
    fn convergence_within_max_iter() {
        let oracle = |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> {
            Ok((vec![1.0_f32, 0.0], 1.0))
        };
        let cfg = FullCuttingPlaneConfig {
            c_reg: 10.0,
            epsilon: 1e-3,
            max_iter: 100,
        };
        let res = FullCuttingPlaneSvm::train(2, 2, oracle, &cfg).expect("ok");
        assert!(res.iterations <= cfg.max_iter);
    }

    #[test]
    fn qp_master_empty_constraints() {
        let constraints: Vec<Vec<(Vec<f32>, f32)>> = vec![Vec::new(); 3];
        let (w, xi) = FullCuttingPlaneSvm::qp_master(&constraints, 4, 1.0).expect("ok");
        assert_eq!(w.len(), 4);
        assert_eq!(xi.len(), 3);
        for v in &w {
            assert!(v.abs() < 1e-9);
        }
        for v in &xi {
            assert!(v.abs() < 1e-9);
        }
    }

    #[test]
    fn qp_master_no_examples() {
        let constraints: Vec<Vec<(Vec<f32>, f32)>> = Vec::new();
        let (w, xi) = FullCuttingPlaneSvm::qp_master(&constraints, 3, 1.0).expect("ok");
        assert_eq!(w, vec![0.0; 3]);
        assert!(xi.is_empty());
    }

    #[test]
    fn qp_master_shape_mismatch_errors() {
        let constraints: Vec<Vec<(Vec<f32>, f32)>> = vec![vec![(vec![1.0, 0.0], 1.0)]];
        let r = FullCuttingPlaneSvm::qp_master(&constraints, 3, 1.0);
        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
    }

    #[test]
    fn qp_master_c_reg_non_positive_errors() {
        let constraints: Vec<Vec<(Vec<f32>, f32)>> = vec![Vec::new()];
        let r = FullCuttingPlaneSvm::qp_master(&constraints, 2, -1.0);
        assert!(matches!(r, Err(SeqError::InvalidParameter { .. })));
    }

    #[test]
    fn oracle_shape_mismatch_propagates() {
        let oracle =
            |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Ok((vec![1.0_f32], 1.0)) };
        let cfg = default_cfg();
        let r = FullCuttingPlaneSvm::train(1, 2, oracle, &cfg);
        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
    }

    #[test]
    fn oracle_error_propagates() {
        let oracle =
            |_w: &[f32], _i: usize| -> SeqResult<(Vec<f32>, f32)> { Err(SeqError::EmptyInput) };
        let cfg = default_cfg();
        let r = FullCuttingPlaneSvm::train(1, 2, oracle, &cfg);
        assert!(matches!(r, Err(SeqError::EmptyInput)));
    }
}