use std::collections::HashMap;
use candle_core::{DType, Device, Tensor};
use crate::error::{OptimError, Result};
/// Tunable parameters for [`DeterministicPredictor`].
///
/// Build one with [`Default::default`] and refine it with the chained
/// `with_*` methods, e.g.
/// `DeterministicPredictionConfig::default().with_warmup_steps(5)`.
#[derive(Debug, Clone)]
pub struct DeterministicPredictionConfig {
    /// Number of gradient snapshots that must be recorded before the
    /// predictor leaves its warm-up phase and may produce predictions.
    pub warmup_steps: usize,
    /// Nominal length of the per-parameter gradient history. The predictor
    /// retains up to `history_window + 2` snapshots per parameter.
    pub history_window: usize,
    /// Number of consecutive predicted steps after which the predictor
    /// reports that a correction (a real gradient) is needed.
    pub prediction_horizon: usize,
    /// Geometric decay factor. It weights older snapshots in the linear fit,
    /// blends new errors into the stored residual, and fades the residual
    /// contribution as predicted steps accumulate.
    pub history_decay: f32,
    /// Largest absolute residual element tolerated before the predictor
    /// reports that a correction is needed.
    pub residual_threshold: f32,
}

impl Default for DeterministicPredictionConfig {
    /// Returns the default settings: 10 warm-up steps, a history window of 8,
    /// a prediction horizon of 4, decay 0.95 and residual threshold 0.5.
    fn default() -> Self {
        Self {
            warmup_steps: 10,
            history_window: 8,
            prediction_horizon: 4,
            history_decay: 0.95,
            residual_threshold: 0.5,
        }
    }
}

impl DeterministicPredictionConfig {
    /// Returns the configuration with `warmup_steps` replaced by `steps`.
    #[must_use]
    pub const fn with_warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Returns the configuration with `history_window` replaced by `window`.
    #[must_use]
    pub const fn with_history_window(mut self, window: usize) -> Self {
        self.history_window = window;
        self
    }

    /// Returns the configuration with `prediction_horizon` replaced by `horizon`.
    #[must_use]
    pub const fn with_prediction_horizon(mut self, horizon: usize) -> Self {
        self.prediction_horizon = horizon;
        self
    }

    /// Returns the configuration with `history_decay` replaced by `decay`.
    #[must_use]
    pub const fn with_history_decay(mut self, decay: f32) -> Self {
        self.history_decay = decay;
        self
    }

    /// Returns the configuration with `residual_threshold` replaced by
    /// `threshold`.
    ///
    /// Provided so that every field can be set through the builder chain.
    #[must_use]
    pub const fn with_residual_threshold(mut self, threshold: f32) -> Self {
        self.residual_threshold = threshold;
        self
    }
}
/// A recorded gradient together with the step at which it was observed.
#[derive(Clone)]
struct GradientSnapshot {
// Value of the predictor's `global_step` when the gradient was recorded.
step: usize,
// Clone of the gradient tensor handed to `record_gradient`.
gradient: Tensor,
}
/// Per-parameter linear model `gradient(t) = baseline + velocity * t`,
/// where `t` counts steps elapsed since `fit_step`.
#[derive(Clone)]
struct LinearGradientModel {
// Fitted gradient value at `t = 0` (the step of the newest snapshot used in the fit).
baseline: Tensor,
// Fitted per-step change of the gradient.
velocity: Tensor,
// Value of the predictor's `global_step` when the model was fitted.
fit_step: usize,
}
/// Predicts future gradients by fitting a decay-weighted linear trend to the
/// recently recorded gradients of each parameter and adding a decayed
/// residual correction.
pub struct DeterministicPredictor {
// Settings supplied at construction.
config: DeterministicPredictionConfig,
// Device on which zero tensors (residuals, degenerate velocities) are allocated.
device: Device,
// Parameter name -> shape, as given to `new`.
shapes: HashMap<String, Vec<usize>>,
// Parameter name -> recent gradient snapshots (trimmed to `history_window + 2`).
history: HashMap<String, Vec<GradientSnapshot>>,
// Parameter name -> most recently fitted linear model.
models: HashMap<String, LinearGradientModel>,
// Parameter name -> running blend of (actual - model) errors; starts as zeros.
residuals: HashMap<String, Tensor>,
// Incremented by every `record_gradient` and `predict_gradient` call.
global_step: usize,
// Predicted steps since the last recorded gradient; `record_gradient` resets it to 0.
steps_since_fit: usize,
// Set once enough history has been recorded; cleared only by `reset`.
warmup_complete: bool,
// Running counters and error statistics.
stats: PredictorStatistics,
}
/// Counters and error statistics maintained by `DeterministicPredictor`.
#[derive(Debug, Clone, Default)]
pub struct PredictorStatistics {
/// Total number of steps, recorded plus predicted.
pub total_steps: usize,
/// Number of steps for which a real gradient was recorded.
pub full_steps: usize,
/// Number of steps for which a gradient was predicted.
pub predicted_steps: usize,
/// Exponential moving average (0.9 old / 0.1 new) of the mean absolute
/// difference between the actual gradient and the linear-model prediction,
/// updated when residuals are updated.
pub mean_abs_error: f32,
/// NOTE(review): not written by any predictor code visible alongside this
/// definition; it keeps its default of 0.0 unless set elsewhere.
pub max_residual: f32,
/// NOTE(review): not written by any predictor code visible alongside this
/// definition; it keeps its default of 0 unless set elsewhere.
pub early_corrections: usize,
}
impl DeterministicPredictor {
pub fn new(
param_shapes: &[(String, Vec<usize>)],
config: DeterministicPredictionConfig,
device: &Device,
) -> Result<Self> {
let mut shapes = HashMap::new();
let mut history = HashMap::new();
let mut residuals = HashMap::new();
for (name, shape) in param_shapes {
shapes.insert(name.clone(), shape.clone());
history.insert(name.clone(), Vec::with_capacity(config.history_window + 4));
residuals.insert(
name.clone(),
Tensor::zeros(shape.as_slice(), DType::F32, device)?,
);
}
Ok(Self {
config,
device: device.clone(),
shapes,
history,
models: HashMap::new(),
residuals,
global_step: 0,
steps_since_fit: 0,
warmup_complete: false,
stats: PredictorStatistics::default(),
})
}
/// Returns `true` while the predictor is still in its warm-up phase, i.e.
/// it is not yet ready to predict.
#[must_use]
pub fn in_warmup(&self) -> bool {
    !self.is_ready()
}
/// Reports whether a real gradient should be computed instead of another
/// prediction.
///
/// Returns `true` when the number of predicted steps since the last
/// recorded gradient has reached `prediction_horizon`, or when the largest
/// absolute element of any stored residual exceeds `residual_threshold`.
///
/// Residuals whose maximum cannot be evaluated (a tensor operation fails)
/// are skipped; this check is best-effort and never returns an error.
#[must_use]
pub fn needs_correction(&self) -> bool {
    if self.steps_since_fit >= self.config.prediction_horizon {
        return true;
    }

    let threshold = self.config.residual_threshold;
    self.residuals.values().any(|residual| {
        residual
            .abs()
            // Flatten first: `max(0)` only reduces one axis, so without this a
            // residual of rank >= 2 would not reduce to a scalar, `to_scalar`
            // would fail, and the threshold would silently never trigger.
            .and_then(|t| t.flatten_all())
            .and_then(|t| t.max(0))
            .and_then(|t| t.to_scalar::<f32>())
            .map_or(false, |max_abs| max_abs > threshold)
    })
}
/// Records the real gradients of the current step and refreshes the models.
///
/// Each gradient whose name was registered at construction is appended to
/// that parameter's history, which is then trimmed to the newest
/// `history_window + 2` snapshots. Gradients with unknown names are ignored.
///
/// * During warm-up the call only accumulates history. Warm-up ends, and
///   the first models are fitted, once `warmup_steps` gradients have been
///   recorded and every parameter's history is as long as `warmup_steps`
///   (or as long as trimming allows, whichever is smaller).
/// * After warm-up the models are refitted on every call. When
///   `is_correction` is `true`, the residuals are first updated against the
///   previous models' predictions.
///
/// The global step is advanced and the predicted-step counter is reset.
///
/// # Errors
///
/// Propagates any failure from updating residuals or fitting the models.
pub fn record_gradient(
    &mut self,
    gradients: &HashMap<String, Tensor>,
    is_correction: bool,
) -> Result<()> {
    let step = self.global_step;
    let max_snapshots = self.config.history_window + 2;

    for (name, grad) in gradients {
        if let Some(hist) = self.history.get_mut(name) {
            hist.push(GradientSnapshot {
                step,
                gradient: grad.clone(),
            });
            if hist.len() > max_snapshots {
                let excess = hist.len() - max_snapshots;
                hist.drain(..excess);
            }
        }
    }

    self.stats.total_steps += 1;
    self.stats.full_steps += 1;

    if self.warmup_complete {
        if is_correction {
            self.update_residuals(gradients)?;
        }
        self.fit_models()?;
    } else {
        let min_history = self.history.values().map(Vec::len).min().unwrap_or(0);
        // History never grows beyond `max_snapshots`, so comparing its length
        // directly against a larger `warmup_steps` could never succeed and the
        // predictor would stay in warm-up forever. Count recorded steps
        // (`full_steps` is reset together with the history) and only require
        // the history to be as full as trimming permits.
        let required_history = self.config.warmup_steps.min(max_snapshots);
        let enough_steps = self.stats.full_steps >= self.config.warmup_steps;
        if enough_steps && min_history >= required_history {
            self.warmup_complete = true;
            self.fit_models()?;
        }
    }

    self.global_step += 1;
    self.steps_since_fit = 0;
    Ok(())
}
/// Predicts the gradient of every modelled parameter for the current step.
///
/// Each prediction is `baseline + velocity * elapsed`, where `elapsed` is
/// the number of steps since the model was fitted, plus the stored
/// residual scaled by `history_decay ^ steps_since_fit`. The global step
/// and the predicted-step counters are advanced afterwards.
///
/// # Errors
///
/// Returns `OptimError::Prediction` while warm-up is incomplete, and
/// propagates any tensor-operation failure.
pub fn predict_gradient(&mut self) -> Result<HashMap<String, Tensor>> {
    if !self.warmup_complete {
        return Err(OptimError::Prediction(
            "Cannot predict during warmup phase".to_string(),
        ));
    }

    // Identical for every parameter, so evaluate it once up front.
    let residual_scale =
        f64::from(self.config.history_decay.powi(self.steps_since_fit as i32));

    let mut predicted = HashMap::with_capacity(self.models.len());
    for (name, model) in &self.models {
        let elapsed = (self.global_step - model.fit_step) as f64;
        let drift = (&model.velocity * elapsed)?;
        let extrapolated = model.baseline.add(&drift)?;

        let gradient = match self.residuals.get(name) {
            Some(residual) => {
                let correction = (residual * residual_scale)?;
                extrapolated.add(&correction)?
            }
            None => extrapolated,
        };
        predicted.insert(name.clone(), gradient);
    }

    self.stats.total_steps += 1;
    self.stats.predicted_steps += 1;
    self.global_step += 1;
    self.steps_since_fit += 1;
    Ok(predicted)
}
/// Blends the current prediction error of each modelled parameter into its
/// stored residual.
///
/// For every gradient in `actual` that has a fitted model, the error is
/// `actual - (baseline + velocity * elapsed)`. The stored residual becomes
/// `decay * residual + (1 - decay) * error`, with `decay = history_decay`.
/// If no residual is stored yet, the error is stored as-is. The running
/// `mean_abs_error` statistic is updated on a best-effort basis: a failure
/// while reducing the error to a scalar leaves the statistic unchanged.
///
/// # Errors
///
/// Propagates any tensor-operation failure from computing the error or the
/// blended residual.
fn update_residuals(&mut self, actual: &HashMap<String, Tensor>) -> Result<()> {
    let decay = f64::from(self.config.history_decay);

    for (name, actual_grad) in actual {
        if let Some(model) = self.models.get(name) {
            let elapsed = (self.global_step - model.fit_step) as f64;
            let drift = (&model.velocity * elapsed)?;
            let predicted = model.baseline.add(&drift)?;
            let error = actual_grad.sub(&predicted)?;

            // Reuse the error tensor for the statistic instead of recomputing
            // `actual - predicted` a second time.
            let mean_err = error
                .abs()
                .and_then(|t| t.mean_all())
                .and_then(|t| t.to_scalar::<f32>());

            let blended = match self.residuals.get(name) {
                Some(existing) => {
                    let kept = (existing * decay)?;
                    let fresh = (&error * (1.0 - decay))?;
                    kept.add(&fresh)?
                }
                None => error,
            };
            self.residuals.insert(name.clone(), blended);

            if let Ok(mean_err) = mean_err {
                self.stats.mean_abs_error =
                    0.9 * self.stats.mean_abs_error + 0.1 * mean_err;
            }
        }
    }
    Ok(())
}
/// Refits the linear gradient model of every parameter that has at least
/// two snapshots in its history.
///
/// The fit is a weighted least-squares line through the snapshots. Time is
/// measured relative to the newest snapshot, and a snapshot that is `age`
/// positions older than the newest one is weighted `history_decay ^ age`.
/// When the normal equations are (near-)singular, the model falls back to
/// the newest gradient with zero velocity.
///
/// # Errors
///
/// Returns `OptimError::Prediction` if a parameter has history but no
/// registered shape, and propagates any tensor-operation failure.
fn fit_models(&mut self) -> Result<()> {
    let decay = self.config.history_decay;

    for (name, snapshots) in &self.history {
        let count = snapshots.len();
        if count < 2 {
            continue;
        }
        let dims = self.shapes.get(name).ok_or_else(|| {
            OptimError::Prediction(format!("Unknown parameter: {name}"))
        })?;

        let newest_step = snapshots.last().map_or(0, |s| s.step);

        // Scalar and tensor accumulators of the weighted normal equations.
        let mut w_sum = 0.0f64;
        let mut wt_sum = 0.0f64;
        let mut wtt_sum = 0.0f64;
        let mut g_acc: Option<Tensor> = None;
        let mut tg_acc: Option<Tensor> = None;

        for (idx, snap) in snapshots.iter().enumerate() {
            let age = (count - 1 - idx) as i32;
            let weight = f64::from(decay.powi(age));
            let offset = (snap.step as i64 - newest_step as i64) as f64;

            w_sum += weight;
            wt_sum += weight * offset;
            wtt_sum += weight * offset * offset;

            let weighted = (&snap.gradient * weight)?;
            let time_weighted = (&snap.gradient * (weight * offset))?;
            g_acc = Some(match g_acc {
                Some(prev) => prev.add(&weighted)?,
                None => weighted,
            });
            tg_acc = Some(match tg_acc {
                Some(prev) => prev.add(&time_weighted)?,
                None => time_weighted,
            });
        }

        let det = w_sum * wtt_sum - wt_sum * wt_sum;
        let (baseline, velocity) = if det.abs() < 1e-10 {
            // Degenerate system: hold the newest gradient constant.
            let latest = snapshots.last().unwrap().gradient.clone();
            let still = Tensor::zeros(dims.as_slice(), DType::F32, &self.device)?;
            (latest, still)
        } else {
            let g_acc = g_acc.ok_or_else(|| {
                OptimError::Prediction("Empty gradient history".to_string())
            })?;
            let tg_acc = tg_acc.ok_or_else(|| {
                OptimError::Prediction("Empty gradient history".to_string())
            })?;
            let inv_det = 1.0 / det;

            // Intercept: (sum_wg * sum_wt2 - sum_wtg * sum_wt) / det.
            let intercept_lhs = (&g_acc * wtt_sum)?;
            let intercept_rhs = (&tg_acc * wt_sum)?;
            let intercept_numer = intercept_lhs.sub(&intercept_rhs)?;
            let intercept = (&intercept_numer * inv_det)?;

            // Slope: (sum_wtg * sum_w - sum_wg * sum_wt) / det.
            let slope_lhs = (&tg_acc * w_sum)?;
            let slope_rhs = (&g_acc * wt_sum)?;
            let slope_numer = slope_lhs.sub(&slope_rhs)?;
            let slope = (&slope_numer * inv_det)?;

            (intercept, slope)
        };

        self.models.insert(
            name.clone(),
            LinearGradientModel {
                baseline,
                velocity,
                fit_step: self.global_step,
            },
        );
    }
    Ok(())
}
/// Returns a reference to the predictor's running statistics.
#[must_use]
pub fn get_stats(&self) -> &PredictorStatistics {
&self.stats
}
/// Returns the predictor to its freshly constructed state.
///
/// Gradient histories are emptied (their entries are kept), fitted models
/// are dropped, every registered parameter's residual is replaced by a zero
/// tensor, the step counters and statistics are zeroed, and warm-up starts
/// over. The configuration and registered shapes are kept.
///
/// # Errors
///
/// Returns an error if a zero tensor cannot be allocated.
pub fn reset(&mut self) -> Result<()> {
    self.history.values_mut().for_each(Vec::clear);
    self.models.clear();

    for (name, dims) in &self.shapes {
        let zeros = Tensor::zeros(dims.as_slice(), DType::F32, &self.device)?;
        self.residuals.insert(name.clone(), zeros);
    }

    self.global_step = 0;
    self.steps_since_fit = 0;
    self.warmup_complete = false;
    self.stats = PredictorStatistics::default();
    Ok(())
}
/// Returns the current global step, i.e. the number of record and predict
/// calls completed since construction or the last reset.
#[must_use]
pub const fn global_step(&self) -> usize {
self.global_step
}
/// Returns `true` once warm-up has completed and predictions are allowed.
#[must_use]
pub const fn is_ready(&self) -> bool {
self.warmup_complete
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameter layout shared by the multi-parameter tests.
    fn create_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            (String::from("layer.weight"), vec![16, 32]),
            (String::from("layer.bias"), vec![16]),
        ]
    }

    /// Builds one gradient per parameter in which every element equals `value`.
    fn uniform_grads(
        shapes: &[(String, Vec<usize>)],
        value: f64,
        device: &Device,
    ) -> HashMap<String, Tensor> {
        shapes
            .iter()
            .map(|(name, dims)| {
                let grad = Tensor::ones(dims.as_slice(), DType::F32, device)
                    .unwrap()
                    .affine(value, 0.0)
                    .unwrap();
                (name.clone(), grad)
            })
            .collect()
    }

    /// Warm-up must end after exactly `warmup_steps` recorded gradients.
    #[test]
    fn test_warmup_phase() {
        let shapes = create_shapes();
        let config = DeterministicPredictionConfig::default().with_warmup_steps(5);
        let mut predictor =
            DeterministicPredictor::new(&shapes, config, &Device::Cpu).unwrap();

        assert!(predictor.in_warmup());
        assert!(!predictor.is_ready());

        for i in 0..5 {
            let grads = uniform_grads(&shapes, i as f64, &Device::Cpu);
            predictor.record_gradient(&grads, false).unwrap();
        }

        assert!(!predictor.in_warmup());
        assert!(predictor.is_ready());
    }

    /// Two predictors fed identical gradients must predict identical values.
    #[test]
    fn test_deterministic_prediction() {
        let device = Device::Cpu;
        let shapes = create_shapes();
        let config = DeterministicPredictionConfig::default()
            .with_warmup_steps(3)
            .with_prediction_horizon(2);

        let mut pred1 =
            DeterministicPredictor::new(&shapes, config.clone(), &device).unwrap();
        let mut pred2 = DeterministicPredictor::new(&shapes, config, &device).unwrap();

        for i in 0..5 {
            let grads = uniform_grads(&shapes, 1.0 + i as f64 * 0.1, &device);
            pred1.record_gradient(&grads, false).unwrap();
            pred2.record_gradient(&grads, false).unwrap();
        }

        let p1 = pred1.predict_gradient().unwrap();
        let p2 = pred2.predict_gradient().unwrap();

        for (name, t1) in &p1 {
            let t2 = p2.get(name).unwrap();
            let abs_diff = t1.sub(t2).unwrap().abs().unwrap();
            let diff: f32 = abs_diff
                .flatten_all()
                .unwrap()
                .max(0)
                .unwrap()
                .to_scalar()
                .unwrap();
            assert!(
                diff < 1e-6,
                "Predictions should be deterministic, got diff={diff}"
            );
        }
    }

    /// A perfectly linear gradient sequence must be extrapolated accurately.
    #[test]
    fn test_linear_fit_quality() {
        let device = Device::Cpu;
        let shapes = vec![(String::from("param"), vec![8])];
        let config = DeterministicPredictionConfig::default()
            .with_warmup_steps(5)
            .with_prediction_horizon(3);
        let mut predictor = DeterministicPredictor::new(&shapes, config, &device).unwrap();

        for t in 0..5 {
            let grads = uniform_grads(&shapes, 1.0 + 0.1 * t as f64, &device);
            predictor.record_gradient(&grads, false).unwrap();
        }

        let predicted = predictor.predict_gradient().unwrap();
        let pred_vals: Vec<f32> = predicted.get("param").unwrap().to_vec1().unwrap();
        for v in &pred_vals {
            assert!(
                (*v - 1.5).abs() < 0.1,
                "Linear prediction should be accurate, got {v}"
            );
        }
    }
}