use crate::error::{Result, TimeSeriesError};
/// ADI/CV² demand-pattern category, as returned by `classify_demand`.
///
/// Each variant corresponds to one quadrant of the (ADI, CV²) plane, split at
/// `ADI_THRESHOLD` and `CV2_THRESHOLD` (comparisons are inclusive: a value equal
/// to the threshold counts as "at/above").
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DemandType {
    /// ADI below and CV² below their thresholds: frequent demand of stable size.
    Smooth,
    /// ADI below, CV² at/above: frequent demand of variable size.
    Erratic,
    /// ADI at/above, CV² below: infrequent demand of stable size.
    Intermittent,
    /// ADI at/above and CV² at/above: infrequent demand of variable size.
    Lumpy,
}

impl std::fmt::Display for DemandType {
    /// Writes the bare variant name (e.g. `Smooth`); width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Smooth => "Smooth",
            Self::Erratic => "Erratic",
            Self::Intermittent => "Intermittent",
            Self::Lumpy => "Lumpy",
        };
        f.write_str(label)
    }
}
/// Average-demand-interval cut-off: `classify_demand` treats a series with
/// `ADI >= ADI_THRESHOLD` as infrequent (Intermittent or Lumpy).
pub const ADI_THRESHOLD: f64 = 1.32;

/// Squared coefficient-of-variation cut-off: `classify_demand` treats a series
/// with `CV² >= CV2_THRESHOLD` as having variable demand sizes (Erratic or Lumpy).
pub const CV2_THRESHOLD: f64 = 0.49;
/// Average Demand Interval: number of periods divided by the number of periods
/// with strictly positive demand.
///
/// Returns [`f64::INFINITY`] when no period has positive demand, which also
/// covers the empty series. Values that are not `> 0.0` (zeros, negatives,
/// NaN) count as "no demand".
pub fn compute_adi(series: &[f64]) -> f64 {
    let demand_periods = series.iter().filter(|v| **v > 0.0).count();
    match demand_periods {
        // An empty series also lands here, so a single guard covers both cases.
        0 => f64::INFINITY,
        d => series.len() as f64 / d as f64,
    }
}
/// Squared coefficient of variation (CV²) of the positive demand sizes.
///
/// Only strictly positive observations are used, so this measures the
/// variability of demand *sizes*, not of demand occurrence. The variance uses
/// the sample divisor `k - 1`, where `k` is the number of positive values.
///
/// Returns `0.0` when fewer than two positive values are present.
pub fn compute_cv2(series: &[f64]) -> f64 {
    let sizes: Vec<f64> = series.iter().copied().filter(|&v| v > 0.0).collect();
    let k = sizes.len();
    if k < 2 {
        return 0.0;
    }
    let mean = sizes.iter().sum::<f64>() / k as f64;
    // CV² is scale-invariant, so the mean must not be compared against an
    // absolute tolerance such as `f64::EPSILON`. That comparison made demand
    // recorded in very small units (values around 1e-17 and below) report zero
    // variability. The mean of strictly positive values is itself positive, so
    // this guard only defends the division below against a zero mean.
    if mean <= 0.0 {
        return 0.0;
    }
    let var = sizes.iter().map(|&d| (d - mean).powi(2)).sum::<f64>() / (k - 1) as f64;
    var / (mean * mean)
}
/// Classifies a demand history into one of the four [`DemandType`] quadrants.
///
/// The series is *infrequent* when `compute_adi(series) >= ADI_THRESHOLD` and has
/// *variable sizes* when `compute_cv2(series) >= CV2_THRESHOLD`. An empty or
/// all-zero series has infinite ADI and zero CV², so it is reported as
/// `Intermittent`.
pub fn classify_demand(series: &[f64]) -> DemandType {
    let infrequent = compute_adi(series) >= ADI_THRESHOLD;
    let variable_size = compute_cv2(series) >= CV2_THRESHOLD;
    if infrequent {
        if variable_size {
            DemandType::Lumpy
        } else {
            DemandType::Intermittent
        }
    } else if variable_size {
        DemandType::Erratic
    } else {
        DemandType::Smooth
    }
}
/// Fitted Croston-family model (used for both plain Croston and SBA).
///
/// The forecast is the flat rate `correction * z / p`.
#[derive(Debug, Clone)]
pub struct CrostonModel {
    /// Exponentially smoothed size of the positive demands.
    pub z: f64,
    /// Exponentially smoothed interval, in periods, between positive demands.
    pub p: f64,
    /// Smoothing constant used during fitting.
    pub alpha: f64,
    /// Interval-counter state carried over from the fitting recursion; it grows
    /// with each period after the most recent positive demand.
    pub q: f64,
    /// Multiplicative forecast correction: `1.0` for plain Croston,
    /// `1 - alpha / 2` for SBA.
    correction: f64,
}

impl CrostonModel {
    /// Flat `h`-step-ahead forecast: the point forecast repeated `h` times
    /// (empty when `h == 0`).
    pub fn forecast(&self, h: usize) -> Vec<f64> {
        vec![self.point_forecast(); h]
    }

    /// Per-period demand-rate forecast, `correction * z / p`.
    pub fn point_forecast(&self) -> f64 {
        self.correction * self.z / self.p
    }
}
/// Croston's method for intermittent demand, without bias correction.
pub struct Croston;

impl Croston {
    /// Fits a [`CrostonModel`] to `demand_series`.
    ///
    /// `alpha` must lie in `(0, 1]`; `None` selects the default chosen by
    /// `resolve_alpha` (0.1).
    ///
    /// # Errors
    ///
    /// Returns an error if the series fails validation (e.g. it is empty,
    /// contains negative values, or has no positive demand) or if `alpha` is
    /// outside `(0, 1]`.
    pub fn fit(demand_series: &[f64], alpha: Option<f64>) -> Result<CrostonModel> {
        // Plain Croston leaves the z/p ratio unscaled.
        const NO_CORRECTION: f64 = 1.0;

        validate_demand_series(demand_series)?;
        let alpha = resolve_alpha(alpha)?;
        let (z, p, q) = croston_recursion(demand_series, alpha, NO_CORRECTION)?;
        let model = CrostonModel {
            z,
            p,
            alpha,
            q,
            correction: NO_CORRECTION,
        };
        Ok(model)
    }
}
#[derive(Debug, Clone)]
pub struct SbaModel {
pub inner: CrostonModel,
}
impl SbaModel {
pub fn forecast(&self, h: usize) -> Vec<f64> {
self.inner.forecast(h)
}
pub fn point_forecast(&self) -> f64 {
self.inner.point_forecast()
}
}
/// SBA (Syntetos–Boylan Approximation): Croston's recursion with the forecast
/// scaled by `1 - alpha / 2`.
pub struct Sba;

impl Sba {
    /// Fits an [`SbaModel`] to `demand_series`.
    ///
    /// `alpha` must lie in `(0, 1]`; `None` selects the default chosen by
    /// `resolve_alpha` (0.1).
    ///
    /// # Errors
    ///
    /// Returns an error if the series fails validation (e.g. it is empty,
    /// contains negative values, or has no positive demand) or if `alpha` is
    /// outside `(0, 1]`.
    pub fn fit(demand_series: &[f64], alpha: Option<f64>) -> Result<SbaModel> {
        validate_demand_series(demand_series)?;
        let alpha = resolve_alpha(alpha)?;
        // SBA correction factor; it shrinks the Croston ratio z / p.
        let correction = 1.0 - alpha / 2.0;
        let (z, p, q) = croston_recursion(demand_series, alpha, correction)?;
        let inner = CrostonModel {
            z,
            p,
            alpha,
            q,
            correction,
        };
        Ok(SbaModel { inner })
    }
}
/// Fitted TSB model; the forecast is `demand_prob * demand_size` for every step.
#[derive(Debug, Clone)]
pub struct TsbModel {
    /// Smoothed probability that a period has positive demand.
    pub demand_prob: f64,
    /// Smoothed size of the positive demands.
    pub demand_size: f64,
    /// Smoothing constant applied to the demand probability.
    pub alpha: f64,
    /// Smoothing constant applied to the demand size.
    pub beta: f64,
}

impl TsbModel {
    /// Flat `h`-step-ahead forecast (empty when `h == 0`).
    pub fn forecast(&self, h: usize) -> Vec<f64> {
        let level = self.point_forecast();
        (0..h).map(|_| level).collect()
    }

    /// Expected demand per period: probability of demand times its smoothed size.
    pub fn point_forecast(&self) -> f64 {
        self.demand_prob * self.demand_size
    }
}
/// TSB (Teunter–Syntetos–Babai) method for intermittent demand.
///
/// Unlike Croston, the probability of demand is smoothed in *every* period, so
/// the forecast decays toward zero during long runs without demand.
pub struct Tsb;

impl Tsb {
    /// Fits a [`TsbModel`] to `demand_series`.
    ///
    /// * `alpha` smooths the demand probability. It must lie in `(0, 1]`; `None`
    ///   selects the default chosen by `resolve_alpha` (0.1).
    /// * `beta` smooths the demand size. It must lie in `(0, 1]`; `None` reuses the
    ///   resolved `alpha`.
    ///
    /// The initial state comes from the first third of the series (at least one
    /// period); see `initial_state`. The recursion then runs over the whole
    /// series, including that initialisation window.
    ///
    /// # Errors
    ///
    /// Returns an error if the series fails validation (e.g. it is empty,
    /// contains negative values, or has no positive demand), if `alpha` is
    /// invalid, or [`TimeSeriesError::InvalidParameter`] if `beta` is outside
    /// `(0, 1]`.
    pub fn fit(demand_series: &[f64], alpha: Option<f64>, beta: Option<f64>) -> Result<TsbModel> {
        validate_demand_series(demand_series)?;
        let alpha = resolve_alpha(alpha)?;
        let beta = Self::resolve_beta(beta, alpha)?;

        let (mut p, mut z) = Self::initial_state(demand_series);
        for &y in demand_series {
            if y > 0.0 {
                p = (1.0 - alpha) * p + alpha;
                z = (1.0 - beta) * z + beta * y;
            } else {
                // The demand size is left untouched in zero periods; only the
                // probability decays.
                p = (1.0 - alpha) * p;
            }
        }

        Ok(TsbModel {
            demand_prob: p,
            demand_size: z,
            alpha,
            beta,
        })
    }

    /// Validates an explicit `beta`, or falls back to `alpha` when `None`.
    fn resolve_beta(beta: Option<f64>, alpha: f64) -> Result<f64> {
        match beta {
            None => Ok(alpha),
            // Written as a positive range test so NaN is rejected.
            Some(b) if 0.0 < b && b <= 1.0 => Ok(b),
            Some(_) => Err(TimeSeriesError::InvalidParameter {
                name: "beta".to_string(),
                message: "beta must be in (0, 1]".to_string(),
            }),
        }
    }

    /// Initial `(probability, size)` estimated from the first third of `series`.
    ///
    /// The probability is the fraction of positive periods in that window,
    /// floored at `0.01`. The size is the first positive value in the window,
    /// falling back to the first positive value anywhere in the series.
    ///
    /// Panics if `series` is empty; `fit` validates before calling this.
    fn initial_state(series: &[f64]) -> (f64, f64) {
        let window = &series[..(series.len() / 3).max(1)];
        let size = window
            .iter()
            .copied()
            .find(|&v| v > 0.0)
            .or_else(|| series.iter().copied().find(|&v| v > 0.0))
            // Unreachable after validation (some value is positive); kept as a
            // defensive default.
            .unwrap_or(1.0);
        let prob = window.iter().filter(|&&v| v > 0.0).count() as f64 / window.len() as f64;
        (prob.max(0.01), size)
    }
}
/// Checks that `series` can be used as an intermittent-demand history.
///
/// # Errors
///
/// * [`TimeSeriesError::InsufficientData`] if `series` is empty.
/// * [`TimeSeriesError::InvalidInput`] if any value is NaN or infinite, if any
///   value is negative, or if no value is strictly positive (the fitting
///   recursions need at least one demand to initialise from).
fn validate_demand_series(series: &[f64]) -> Result<()> {
    if series.is_empty() {
        return Err(TimeSeriesError::InsufficientData {
            message: "Demand series must not be empty".to_string(),
            required: 1,
            actual: 0,
        });
    }
    // NaN passes a `v < 0.0` test and would then be treated as "no demand" by
    // every `v > 0.0` filter downstream. +inf would poison the smoothed
    // estimates. Reject both explicitly.
    if series.iter().any(|v| !v.is_finite()) {
        return Err(TimeSeriesError::InvalidInput(
            "Demand series must contain only finite values".to_string(),
        ));
    }
    if series.iter().any(|&v| v < 0.0) {
        return Err(TimeSeriesError::InvalidInput(
            "Demand series must be non-negative".to_string(),
        ));
    }
    if !series.iter().any(|&v| v > 0.0) {
        return Err(TimeSeriesError::InvalidInput(
            "Demand series must contain at least one non-zero observation".to_string(),
        ));
    }
    Ok(())
}
/// Resolves the smoothing constant `alpha`.
///
/// `None` selects the default of `0.1`; an explicit value must lie in `(0, 1]`.
///
/// # Errors
///
/// [`TimeSeriesError::InvalidParameter`] (named `"alpha"`) for any explicit
/// value outside `(0, 1]`, including NaN.
fn resolve_alpha(alpha: Option<f64>) -> Result<f64> {
    const DEFAULT_ALPHA: f64 = 0.1;
    match alpha {
        None => Ok(DEFAULT_ALPHA),
        // NaN fails `a > 0.0` and falls through to the error arm.
        Some(a) if a > 0.0 && a <= 1.0 => Ok(a),
        Some(_) => Err(TimeSeriesError::InvalidParameter {
            name: "alpha".to_string(),
            message: "alpha must be in (0, 1]".to_string(),
        }),
    }
}
/// Runs Croston's exponential-smoothing recursion over `series`.
///
/// Initialisation: `z` is the first positive demand, and `p` is the number of
/// periods up to and including it (`first_idx + 1`). Each later positive
/// observation `y` then updates
///
/// * `z <- (1 - alpha) * z + alpha * y` (demand size), and
/// * `p <- (1 - alpha) * p + alpha * q` (inter-demand interval),
///
/// where `q` is the number of periods since the previous positive demand.
///
/// `_correction` is currently unused here; the models apply their correction
/// factor at forecast time.
///
/// Returns `(z, p, q)`, where the final `q` is the number of periods elapsed
/// since the last positive demand (`0.0` if the series ends on a demand).
///
/// # Errors
///
/// [`TimeSeriesError::InvalidInput`] if `series` has no strictly positive value.
fn croston_recursion(series: &[f64], alpha: f64, _correction: f64) -> Result<(f64, f64, f64)> {
    let (first_idx, first_demand) = series
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, v)| v > 0.0)
        .ok_or_else(|| {
            TimeSeriesError::InvalidInput(
                "No non-zero demand found in series; initialisation impossible".to_string(),
            )
        })?;
    let mut z = first_demand;
    let mut p = (first_idx + 1) as f64;
    // Periods elapsed since the last demand. A demand was observed at
    // `first_idx`, so the counter starts at zero. Starting it at 1 over-counted
    // the first inter-demand interval by one period.
    let mut q = 0.0_f64;
    for &y in &series[first_idx + 1..] {
        q += 1.0;
        if y > 0.0 {
            z = (1.0 - alpha) * z + alpha * y;
            p = (1.0 - alpha) * p + alpha * q;
            q = 0.0;
        }
    }
    Ok((z, p, q))
}
/// Descriptive statistics and ADI/CV² classification of a demand series.
#[derive(Debug, Clone)]
pub struct DemandSummary {
    /// Total number of periods in the series.
    pub n_periods: usize,
    /// Number of periods with strictly positive demand.
    pub n_demand_periods: usize,
    /// Average demand interval, from `compute_adi`.
    pub adi: f64,
    /// Squared coefficient of variation of the positive demands, from `compute_cv2`.
    pub cv2: f64,
    /// Category reported by `classify_demand`.
    pub demand_type: DemandType,
    /// Mean of the positive demands (`0.0` when there are none).
    pub mean_demand: f64,
    /// Sum of the positive demands.
    pub total_demand: f64,
}

impl DemandSummary {
    /// Summarises `series`.
    ///
    /// Unlike the model `fit` functions, this does not reject all-zero or
    /// negative inputs; only strictly positive values contribute to the counts
    /// and totals.
    ///
    /// # Errors
    ///
    /// [`TimeSeriesError::InsufficientData`] if `series` is empty.
    pub fn compute(series: &[f64]) -> Result<Self> {
        if series.is_empty() {
            return Err(TimeSeriesError::InsufficientData {
                message: "Cannot compute summary of empty series".to_string(),
                required: 1,
                actual: 0,
            });
        }

        let positive = || series.iter().copied().filter(|&v| v > 0.0);
        let n_demand_periods = positive().count();
        let total_demand: f64 = positive().sum();
        let mean_demand = match n_demand_periods {
            0 => 0.0,
            k => total_demand / k as f64,
        };

        Ok(Self {
            n_periods: series.len(),
            n_demand_periods,
            adi: compute_adi(series),
            cv2: compute_cv2(series),
            demand_type: classify_demand(series),
            mean_demand,
            total_demand,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_smooth() {
        let series = vec![3.0, 4.0, 3.5, 4.0, 3.8, 4.2, 3.9, 4.1];
        assert_eq!(classify_demand(&series), DemandType::Smooth);
    }

    #[test]
    fn test_classify_intermittent() {
        // ADI = 10 / 4 = 2.5 (>= 1.32) and every demand has the same size (CV² = 0),
        // so this must be Intermittent. Accepting Smooth here would hide a regression.
        let series = vec![5.0, 0.0, 0.0, 5.0, 0.0, 0.0, 5.0, 0.0, 0.0, 5.0];
        assert_eq!(classify_demand(&series), DemandType::Intermittent);
    }

    #[test]
    fn test_classify_erratic() {
        // Demand every period (ADI = 1), but sizes alternate 1 / 10 (CV² ≈ 0.80).
        let series = vec![1.0, 10.0, 1.0, 10.0, 1.0, 10.0];
        assert_eq!(classify_demand(&series), DemandType::Erratic);
    }

    #[test]
    fn test_classify_lumpy() {
        let mut series = vec![0.0_f64; 10];
        series[0] = 1.0;
        series[4] = 100.0;
        series[8] = 2.0;
        let dt = classify_demand(&series);
        assert_eq!(dt, DemandType::Lumpy);
    }

    #[test]
    fn test_demand_type_display() {
        assert_eq!(DemandType::Smooth.to_string(), "Smooth");
        assert_eq!(DemandType::Erratic.to_string(), "Erratic");
        assert_eq!(DemandType::Intermittent.to_string(), "Intermittent");
        assert_eq!(DemandType::Lumpy.to_string(), "Lumpy");
    }

    #[test]
    fn test_compute_adi_all_nonzero() {
        let series = vec![1.0, 2.0, 3.0, 4.0];
        assert!((compute_adi(&series) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_compute_adi_half_nonzero() {
        let series = vec![1.0, 0.0, 1.0, 0.0];
        assert!((compute_adi(&series) - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_compute_cv2_constant() {
        let series = vec![5.0, 0.0, 5.0, 0.0, 5.0];
        assert!(compute_cv2(&series) < 1e-9);
    }

    #[test]
    fn test_croston_fit_basic() {
        let demand = vec![0.0, 5.0, 0.0, 0.0, 3.0, 0.0, 7.0, 0.0, 4.0];
        let model = Croston::fit(&demand, Some(0.1)).expect("failed to create model");
        assert!(model.point_forecast() > 0.0);
    }

    #[test]
    fn test_croston_forecast_constant() {
        let demand = vec![0.0, 4.0, 0.0, 4.0, 0.0, 4.0];
        let model = Croston::fit(&demand, Some(0.2)).expect("failed to create model");
        let fc = model.forecast(5);
        assert_eq!(fc.len(), 5);
        for &f in &fc {
            assert!((f - fc[0]).abs() < 1e-9);
        }
    }

    #[test]
    fn test_croston_invalid_alpha() {
        let demand = vec![0.0, 1.0, 0.0];
        assert!(Croston::fit(&demand, Some(1.5)).is_err());
        assert!(Croston::fit(&demand, Some(0.0)).is_err());
    }

    #[test]
    fn test_croston_all_zero_error() {
        let demand = vec![0.0, 0.0, 0.0];
        assert!(Croston::fit(&demand, None).is_err());
    }

    #[test]
    fn test_croston_negative_demand_error() {
        let demand = vec![1.0, -1.0, 0.0];
        assert!(Croston::fit(&demand, None).is_err());
    }

    #[test]
    fn test_sba_correction() {
        let demand = vec![0.0, 5.0, 0.0, 5.0, 0.0, 5.0, 0.0, 5.0];
        let alpha = 0.2;
        let croston = Croston::fit(&demand, Some(alpha)).expect("failed to create croston");
        let sba = Sba::fit(&demand, Some(alpha)).expect("failed to create sba");
        let expected = (1.0 - alpha / 2.0) * croston.point_forecast();
        assert!(
            (sba.point_forecast() - expected).abs() < 1e-9,
            "SBA = {}, expected = {}",
            sba.point_forecast(),
            expected
        );
    }

    #[test]
    fn test_sba_lower_than_croston() {
        let demand = vec![0.0, 3.0, 0.0, 0.0, 6.0, 0.0, 4.0];
        let alpha = 0.15;
        let c = Croston::fit(&demand, Some(alpha)).expect("failed to create c").point_forecast();
        let s = Sba::fit(&demand, Some(alpha)).expect("failed to create s").point_forecast();
        assert!(s <= c + 1e-9, "SBA ({s}) should be ≤ Croston ({c})");
    }

    #[test]
    fn test_tsb_fit_basic() {
        let demand = vec![0.0, 5.0, 0.0, 0.0, 3.0, 0.0, 7.0, 0.0, 4.0];
        let model = Tsb::fit(&demand, Some(0.1), None).expect("failed to create model");
        assert!(model.demand_prob > 0.0);
        assert!(model.demand_prob <= 1.0);
        assert!(model.demand_size > 0.0);
        assert!(model.point_forecast() > 0.0);
    }

    #[test]
    fn test_tsb_probability_decays_on_zeros() {
        let mut demand = vec![5.0];
        demand.extend(vec![0.0_f64; 50]);
        let model = Tsb::fit(&demand, Some(0.3), Some(0.2)).expect("failed to create model");
        assert!(
            model.demand_prob < 0.1,
            "Probability should decay after long zero run, got {}",
            model.demand_prob
        );
    }

    #[test]
    fn test_tsb_invalid_beta() {
        let demand = vec![0.0, 1.0, 0.0];
        assert!(Tsb::fit(&demand, None, Some(0.0)).is_err());
        assert!(Tsb::fit(&demand, None, Some(1.5)).is_err());
    }

    #[test]
    fn test_tsb_forecast_length() {
        let demand = vec![0.0, 2.0, 0.0, 3.0, 0.0];
        let model = Tsb::fit(&demand, None, None).expect("failed to create model");
        let fc = model.forecast(10);
        assert_eq!(fc.len(), 10);
    }

    #[test]
    fn test_demand_summary() {
        let demand = vec![0.0, 5.0, 0.0, 0.0, 10.0, 0.0, 5.0];
        let summary = DemandSummary::compute(&demand).expect("failed to create summary");
        assert_eq!(summary.n_periods, 7);
        assert_eq!(summary.n_demand_periods, 3);
        assert!((summary.total_demand - 20.0).abs() < 1e-9);
        assert!((summary.mean_demand - 20.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_demand_summary_all_zero() {
        let summary =
            DemandSummary::compute(&[0.0, 0.0, 0.0]).expect("summary of all-zero series");
        assert_eq!(summary.n_periods, 3);
        assert_eq!(summary.n_demand_periods, 0);
        assert_eq!(summary.total_demand, 0.0);
        assert_eq!(summary.mean_demand, 0.0);
        assert!(summary.adi.is_infinite());
        assert_eq!(summary.demand_type, DemandType::Intermittent);
    }

    #[test]
    fn test_demand_summary_empty() {
        assert!(DemandSummary::compute(&[]).is_err());
    }
}