stochastic-rs-ai 2.0.0-rc.1

AI / neural surrogates for volatility models.
Documentation
//! Neural-network surrogate for the 1-factor Bergomi model (flat forward variance).
//!
//! The architecture and scaling conventions follow:
//! - 3 hidden ELU layers (width 30), linear output (88 nodes).
//! - Parameter scaling to `[-1, 1]` with fixed lower/upper bounds.
//! - Output (IV surface) standardization with train-set mean/std.
//!
//! Source:
//! - https://github.com/amuguruza/NN-StochVol-Calibrations
//! - `1Factor/Flat Forward Variance/NN1Factor.ipynb`

use std::path::Path;

use anyhow::Result;
use candle_core::Device;
use ndarray::Array2;

use super::common::StochVolModelSpec;
use super::common::StochVolNn;
use super::common::TrainConfig;
use super::common::TrainReport;

pub const MODEL_ID: &str = "one_factor";
pub const INPUT_DIM: usize = 4;
pub const OUTPUT_DIM: usize = 88;
pub const DEFAULT_HIDDEN_DIM: usize = 30;
pub const PARAM_LB: [f32; INPUT_DIM] = [0.01, 0.5, 0.0, -1.0];
pub const PARAM_UB: [f32; INPUT_DIM] = [0.16, 4.0, 10.0, 0.0];

pub struct OneFactorNn {
  inner: StochVolNn,
}

impl OneFactorNn {
  pub fn new(device: &Device) -> Result<Self> {
    Self::with_hidden(device, DEFAULT_HIDDEN_DIM)
  }

  pub fn with_hidden(device: &Device, hidden_dim: usize) -> Result<Self> {
    let spec = StochVolModelSpec::new(
      MODEL_ID,
      INPUT_DIM,
      OUTPUT_DIM,
      hidden_dim,
      PARAM_LB.to_vec(),
      PARAM_UB.to_vec(),
    )?;
    Ok(Self {
      inner: StochVolNn::new(spec, device)?,
    })
  }

  pub fn train(
    &mut self,
    params: &Array2<f32>,
    surfaces: &Array2<f32>,
    config: &TrainConfig,
  ) -> Result<TrainReport> {
    self.inner.train(params, surfaces, config)
  }

  /// Predict an implied-volatility surface for the given parameter vector.
  ///
  /// Returns a flat `Vec<f32>` of length [`OUTPUT_DIM`] in row-major order
  /// (maturity-major). Pass the result to
  /// `stochastic_rs_quant::vol_surface::ImpliedVolSurface::from_flat_iv_grid`
  /// to consume the prediction with the rest of the pricing toolchain.
  pub fn predict_surface(&self, params: &[f32; INPUT_DIM]) -> Result<Vec<f32>> {
    self.inner.predict_surface(params)
  }

  pub fn predict_surfaces(&self, params: &Array2<f32>) -> Result<Array2<f32>> {
    self.inner.predict_surfaces(params)
  }

  /// Bridge to `stochastic-rs-quant`: predict and package as
  /// [`ImpliedVolSurface`](stochastic_rs_quant::vol_surface::ImpliedVolSurface).
  /// Available with the `quant` cargo feature.
  #[cfg(feature = "quant")]
  pub fn predict_implied_vol_surface(
    &self,
    params: &[f32; INPUT_DIM],
    strikes: Vec<f64>,
    maturities: Vec<f64>,
    forwards: Vec<f64>,
  ) -> Result<stochastic_rs_quant::vol_surface::ImpliedVolSurface> {
    self
      .inner
      .predict_implied_vol_surface(params, strikes, maturities, forwards)
  }

  pub fn save<P: AsRef<Path>>(&self, dir: P) -> Result<()> {
    self.inner.save(dir)
  }

  pub fn load<P: AsRef<Path>>(dir: P, device: &Device) -> Result<Self> {
    Ok(Self {
      inner: StochVolNn::load(MODEL_ID, dir, device)?,
    })
  }
}

#[cfg(test)]
mod tests {
  use std::fs;
  use std::path::Path;

  use super::*;
  use crate::volatility::common::load_trainset_gzip_npy;
  use crate::volatility::common::rmse_1d;
  use crate::volatility::common::synthetic_surface_dataset;
  use crate::volatility::common::write_surface_fit_plot_html;

  #[test]
  fn train_save_load_roundtrip() -> Result<()> {
    let device = Device::Cpu;
    let (params, surfaces) = synthetic_surface_dataset(&PARAM_LB, &PARAM_UB, 192, OUTPUT_DIM, 11);
    let mut model = OneFactorNn::new(&device)?;
    let cfg = TrainConfig {
      test_ratio: 0.2,
      batch_size: 32,
      epochs: 20,
      learning_rate: 1e-3,
      random_seed: 1234,
      shuffle: true,
    };
    let report = model.train(&params, &surfaces, &cfg)?;
    assert_eq!(report.epochs.len(), cfg.epochs);
    assert!(report.epochs.last().unwrap().train_rmse.is_finite());

    let save_dir = std::env::temp_dir().join(format!(
      "stochastic_rs_one_factor_nn_{}_{}",
      std::process::id(),
      2002_u64
    ));
    if save_dir.exists() {
      let _ = fs::remove_dir_all(&save_dir);
    }
    model.save(&save_dir)?;
    let loaded = OneFactorNn::load(&save_dir, &device)?;

    let sample = [
      params[[0, 0]],
      params[[0, 1]],
      params[[0, 2]],
      params[[0, 3]],
    ];
    let p1 = model.predict_surface(&sample)?;
    let p2 = loaded.predict_surface(&sample)?;
    let max_diff = p1
      .iter()
      .zip(p2.iter())
      .map(|(a, b)| (a - b).abs())
      .fold(0.0_f32, f32::max);
    assert!(max_diff < 1e-4);

    let _ = fs::remove_dir_all(&save_dir);
    Ok(())
  }

  #[test]
  fn real_trainset_fit_plot() -> Result<()> {
    let trainset_path = Path::new("src/ai/volatility/Bergomi1FactorTrainSet.txt.gz");
    if !trainset_path.exists() {
      return Ok(());
    }

    let device = Device::Cpu;
    let (params, surfaces) =
      load_trainset_gzip_npy(trainset_path, INPUT_DIM, OUTPUT_DIM, Some(12_000))?;

    let mut model = OneFactorNn::new(&device)?;
    let cfg = TrainConfig {
      test_ratio: 0.15,
      batch_size: 64,
      epochs: 30,
      learning_rate: 1e-3,
      random_seed: 42,
      shuffle: true,
    };
    let report = model.train(&params, &surfaces, &cfg)?;
    let sample_idx = surfaces.nrows() / 3;
    let sample = [
      params[[sample_idx, 0]],
      params[[sample_idx, 1]],
      params[[sample_idx, 2]],
      params[[sample_idx, 3]],
    ];
    let pred = model.predict_surface(&sample)?;
    let actual = surfaces.row(sample_idx).to_vec();
    let fit_rmse = rmse_1d(&actual, &pred)?;

    let out = Path::new("target/nn_fit_plots/one_factor_fit.html");
    let strikes = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5];
    let maturities = [0.1, 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.0];
    write_surface_fit_plot_html(
      out,
      &format!(
        "1Factor Bergomi NN Fit - sample {} - RMSE {:.5}",
        sample_idx, fit_rmse
      ),
      &strikes,
      &maturities,
      &actual,
      &pred,
    )?;
    println!(
      "1Factor fit plot written to {} (sample_rmse={:.6}, final_val_rmse={:.6})",
      out.display(),
      fit_rmse,
      report.epochs.last().map(|e| e.val_rmse).unwrap_or(f32::NAN)
    );
    assert!(out.exists());
    Ok(())
  }
}