use crate::calibration::CalibrationResult;
use crate::errors::CalibrationError;
use crate::math::levenberg_marquardt;
use crate::raw::RawSvi;
use crate::types::Quote;
/// Minimum number of quotes `validate_quotes` accepts; shorter inputs are
/// rejected with `CalibrationError::TooFewQuotes`.
const MIN_QUOTES: usize = 5;
/// Third argument handed to `levenberg_marquardt` by `refine`.
/// NOTE(review): presumably the solver's convergence tolerance — confirm
/// against the solver's definition.
const LM_TOL: f64 = 1e-14;
/// Fourth argument handed to `levenberg_marquardt` by `refine`.
/// NOTE(review): presumably the solver's iteration cap — confirm against the
/// solver's definition.
const LM_MAX_ITER: usize = 500;
/// Refines a raw-SVI fit to `quotes` with Levenberg–Marquardt, starting from `seed`.
///
/// The model evaluated for each quote is
/// `w(k) = a + b * (rho * (k - m) + sqrt((k - m)^2 + sigma^2))`.
///
/// The solver works on the transformed vector
/// `[a, ln b, atanh rho, m, ln sigma]`; `b` and `sigma` are mapped back through
/// `exp` and `rho` through `tanh`, and the Jacobian rows are chain-ruled
/// accordingly. Only quotes with `weight > 0.0` contribute residuals.
///
/// The reported `rmse` is `sqrt(res.cost / W)` where `W` is the sum of the
/// strictly positive weights (the same quotes that entered the fit), or `0.0`
/// if `W` is not positive.
///
/// # Errors
///
/// * Whatever `validate_quotes` returns for an unusable quote set.
/// * `CalibrationError::Param` if `RawSvi::new` rejects the fitted parameters.
//
// The short, look-alike names (`a`, `b`, `m`, `rho`, `rho_hat`, ...) mirror the
// model's parameter symbols, hence the silenced naming lints.
#[allow(clippy::similar_names, clippy::many_single_char_names)]
pub fn refine(quotes: &[Quote], seed: &RawSvi) -> Result<CalibrationResult, CalibrationError> {
    validate_quotes(quotes)?;

    // Map the seed into the solver's coordinates. The floors and the clamp keep
    // `ln` and `atanh` finite for seeds sitting on the boundary of the domain.
    let start = [
        seed.a,
        seed.b.max(1e-12).ln(),
        atanh(seed.rho.clamp(-0.999_999, 0.999_999)),
        seed.m,
        seed.sigma.max(1e-12).ln(),
    ];

    // For every fitted quote: (residual, weight, d model / d transformed params).
    let residual = |p: &[f64]| -> Vec<(f64, f64, Vec<f64>)> {
        let a = p[0];
        let b_hat = p[1];
        let rho_hat = p[2];
        let m = p[3];
        let sigma_hat = p[4];

        let b = b_hat.exp();
        let rho = rho_hat.tanh();
        let sigma = sigma_hat.exp();

        // Derivatives of the back-transforms: d exp(x)/dx = exp(x),
        // d tanh(x)/dx = 1 - tanh(x)^2.
        let db_dbhat = b;
        let drho_drhohat = 1.0 - rho * rho;
        let dsigma_dshat = sigma;

        quotes
            .iter()
            .filter(|q| q.weight > 0.0)
            .map(|q| {
                let u = q.k - m;
                let r = (u * u + sigma * sigma).sqrt();
                let model = a + b * (rho * u + r);
                let resid = model - q.w;

                // Partials with respect to the natural parameters.
                let dw_da = 1.0;
                let dw_db = rho * u + r;
                let dw_drho = b * u;
                let dw_dm = b * (-rho - u / r);
                let dw_dsigma = b * sigma / r;

                let jac = vec![
                    dw_da,
                    dw_db * db_dbhat,
                    dw_drho * drho_drhohat,
                    dw_dm,
                    dw_dsigma * dsigma_dshat,
                ];
                (resid, q.weight, jac)
            })
            .collect()
    };

    let res = levenberg_marquardt(residual, &start, LM_TOL, LM_MAX_ITER);

    // Undo the transforms on the solver's answer.
    let a = res.params[0];
    let b = res.params[1].exp();
    let rho = res.params[2].tanh();
    let m = res.params[3];
    let sigma = res.params[4].exp();

    // Normalise by the weight that actually entered the fit. Summing every
    // weight would let a NaN or negative weight (ignored by the residuals)
    // corrupt the denominator and silently report an RMSE of 0.
    let total_weight: f64 = quotes
        .iter()
        .filter(|q| q.weight > 0.0)
        .map(|q| q.weight)
        .sum();
    let rmse = if total_weight > 0.0 {
        (res.cost / total_weight).sqrt()
    } else {
        0.0
    };

    let slice = RawSvi::new(a, b, rho, m, sigma).map_err(CalibrationError::Param)?;
    Ok(CalibrationResult::new(slice, rmse))
}
/// Calibrates a raw-SVI slice to `quotes` without a caller-supplied seed.
///
/// A heuristic seed is built and handed to `refine`:
///
/// * `a` is half the `w` of the quote whose `k` is closest to zero,
/// * `b = 0.1`, `rho = -0.1`, `m = 0.0`,
/// * `sigma` is half the spread of `k` across the quotes (the spread is
///   floored at `1e-3`).
///
/// If `RawSvi::new` rejects that seed, an unchecked seed with `sigma = 0.1`
/// is used instead.
///
/// # Errors
///
/// Returns whatever `validate_quotes` or `refine` report.
pub fn calibrate(quotes: &[Quote]) -> Result<CalibrationResult, CalibrationError> {
    validate_quotes(quotes)?;

    // `w` of the quote nearest k = 0; incomparable (NaN) pairs are treated as ties.
    let w_atm = quotes
        .iter()
        .min_by(|x, y| {
            x.k.abs()
                .partial_cmp(&y.k.abs())
                .unwrap_or(core::cmp::Ordering::Equal)
        })
        .map_or(0.04, |q| q.w);

    // One pass for both extremes of k.
    let (lo, hi) = quotes.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi): (f64, f64), q| (lo.min(q.k), hi.max(q.k)),
    );
    let k_span = (hi - lo).max(1e-3);

    // Build the unchecked fallback only when the checked constructor fails.
    let seed = RawSvi::new(0.5 * w_atm, 0.1, -0.1, 0.0, 0.5 * k_span)
        .unwrap_or_else(|_| RawSvi::new_unchecked(0.5 * w_atm, 0.1, -0.1, 0.0, 0.1));

    refine(quotes, &seed)
}
/// Checks that `quotes` is usable for calibration.
///
/// # Errors
///
/// * `CalibrationError::EmptyQuotes` if the slice is empty.
/// * `CalibrationError::TooFewQuotes` if it holds fewer than `MIN_QUOTES` quotes.
/// * `CalibrationError::AllWeightsZero` if no quote has a strictly positive weight.
fn validate_quotes(quotes: &[Quote]) -> Result<(), CalibrationError> {
    if quotes.is_empty() {
        return Err(CalibrationError::EmptyQuotes);
    }
    if quotes.len() < MIN_QUOTES {
        return Err(CalibrationError::TooFewQuotes {
            got: quotes.len(),
            need: MIN_QUOTES,
        });
    }
    // `refine` fits only quotes with `weight > 0.0`. Testing that same
    // predicate (rather than `all(weight <= 0.0)`) also rejects inputs whose
    // weights are NaN, which would otherwise pass and leave nothing to fit.
    if !quotes.iter().any(|q| q.weight > 0.0) {
        return Err(CalibrationError::AllWeightsZero);
    }
    Ok(())
}
/// Inverse hyperbolic tangent of `x`.
///
/// Delegates to the standard library's [`f64::atanh`]. The textbook form
/// `0.5 * ln((1 + x) / (1 - x))` rounds `1 ± x` to `1` for very small `|x|`
/// and therefore collapses to `0` there.
#[inline]
fn atanh(x: f64) -> f64 {
    x.atanh()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds unit-weight quotes lying exactly on `svi` at the log-strikes `ks`.
    fn synthetic(svi: &RawSvi, ks: &[f64]) -> Vec<Quote> {
        ks.iter()
            .map(|&k| Quote::new(k, svi.total_variance(k), 1.0).unwrap())
            .collect()
    }

    #[test]
    fn atanh_inverts_tanh() {
        for &x in &[-0.9, -0.3, 0.0, 0.5, 0.95] {
            assert!((atanh(x).tanh() - x).abs() < 1e-12);
        }
    }

    #[test]
    fn rejects_empty() {
        assert!(matches!(calibrate(&[]), Err(CalibrationError::EmptyQuotes)));
    }

    #[test]
    fn rejects_too_few() {
        let q = Quote::new(0.0, 0.04, 1.0).unwrap();
        assert!(matches!(
            calibrate(&[q, q]),
            Err(CalibrationError::TooFewQuotes { .. })
        ));
    }

    /// Noise-free quotes and a nearby seed: the fit should land on the truth,
    /// including at strikes outside the quoted range.
    #[test]
    fn refine_recovers_from_perturbed_seed() {
        let truth = RawSvi::new(0.04, 0.4, -0.3, 0.05, 0.15).unwrap();
        let ks = [-0.4, -0.25, -0.1, 0.0, 0.1, 0.25, 0.4];
        let quotes = synthetic(&truth, &ks);
        let seed = RawSvi::new(0.05, 0.3, -0.2, 0.0, 0.18).unwrap();
        let fit = refine(&quotes, &seed).unwrap();
        assert!(fit.rmse < 1e-6, "rmse = {}", fit.rmse);
        for &k in &[-0.5, 0.0, 0.5] {
            let err = (fit.slice.total_variance(k) - truth.total_variance(k)).abs();
            assert!(err < 1e-4, "k = {k}, err = {err}");
        }
    }

    /// The heuristic seed alone should be good enough for a loose fit.
    #[test]
    fn calibrate_standalone_fits() {
        let truth = RawSvi::new(0.04, 0.35, -0.25, 0.02, 0.16).unwrap();
        let ks = [-0.4, -0.2, -0.05, 0.05, 0.2, 0.4];
        let quotes = synthetic(&truth, &ks);
        let fit = calibrate(&quotes).unwrap();
        for &k in &[-0.3, 0.0, 0.3] {
            let err = (fit.slice.total_variance(k) - truth.total_variance(k)).abs();
            assert!(err < 1e-2, "k = {k}, err = {err}");
        }
    }

    /// The fitted parameters must stay inside the raw-SVI domain.
    #[test]
    fn refine_preserves_domain() {
        let truth = RawSvi::new(0.03, 0.5, 0.6, -0.05, 0.12).unwrap();
        let ks = [-0.3, -0.15, 0.0, 0.15, 0.3, 0.45];
        let quotes = synthetic(&truth, &ks);
        let seed = RawSvi::new(0.04, 0.3, 0.3, 0.0, 0.2).unwrap();
        let fit = refine(&quotes, &seed).unwrap();
        assert!(fit.slice.validate().is_ok());
        assert!(fit.slice.b >= 0.0);
        assert!(fit.slice.rho.abs() < 1.0);
        assert!(fit.slice.sigma > 0.0);
    }
}