use crate::error::{Result, TimeSeriesError};
use std::f32::consts::PI;
#[derive(Debug, Clone, PartialEq)]
pub enum NBEATSBasis {
Generic,
Trend {
degree: usize,
},
Seasonality {
harmonics: usize,
},
}
#[derive(Debug, Clone)]
pub struct NBEATSConfig {
pub n_stacks: usize,
pub n_blocks: usize,
pub n_layers: usize,
pub n_hidden: usize,
pub horizon: usize,
pub lookback: usize,
pub basis: NBEATSBasis,
}
impl Default for NBEATSConfig {
fn default() -> Self {
Self {
n_stacks: 2,
n_blocks: 3,
n_layers: 4,
n_hidden: 256,
horizon: 12,
lookback: 48,
basis: NBEATSBasis::Generic,
}
}
}
#[derive(Debug, Clone)]
pub struct FCLayer {
pub w: Vec<Vec<f32>>,
pub b: Vec<f32>,
}
impl FCLayer {
pub fn new(in_size: usize, out_size: usize, seed: u64) -> Self {
let std_dev = (2.0 / (in_size + out_size) as f64).sqrt() as f32;
let mut lcg = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
let mut w = vec![vec![0.0_f32; in_size]; out_size];
let mut b = vec![0.0_f32; out_size];
for row in &mut w {
for cell in row.iter_mut() {
lcg = lcg.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
let u = (lcg >> 33) as f32 / (u32::MAX as f32);
*cell = (u * 2.0 - 1.0) * std_dev;
}
}
for cell in &mut b {
lcg = lcg.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
let u = (lcg >> 33) as f32 / (u32::MAX as f32);
*cell = (u * 2.0 - 1.0) * std_dev * 0.1;
}
Self { w, b }
}
pub fn forward(&self, x: &[f32], relu: bool) -> Vec<f32> {
let out_size = self.w.len();
let mut out = vec![0.0_f32; out_size];
for (i, row) in self.w.iter().enumerate() {
let mut sum = self.b[i];
for (j, &w_ij) in row.iter().enumerate() {
if j < x.len() {
sum += w_ij * x[j];
}
}
out[i] = if relu { sum.max(0.0) } else { sum };
}
out
}
fn sgd_update(&mut self, grad_w: &[Vec<f32>], grad_b: &[f32], lr: f32) {
for (i, row) in self.w.iter_mut().enumerate() {
if i < grad_w.len() {
for (j, cell) in row.iter_mut().enumerate() {
if j < grad_w[i].len() {
*cell -= lr * grad_w[i][j];
}
}
}
}
for (i, cell) in self.b.iter_mut().enumerate() {
if i < grad_b.len() {
*cell -= lr * grad_b[i];
}
}
}
}
fn polynomial_basis(time_steps: usize, degree: usize) -> Vec<Vec<f32>> {
(0..time_steps)
.map(|n| {
let t = n as f32 / time_steps as f32;
(0..=degree).map(|k| t.powi(k as i32)).collect()
})
.collect()
}
fn fourier_basis(time_steps: usize, harmonics: usize) -> Vec<Vec<f32>> {
(0..time_steps)
.map(|n| {
let t = n as f32 / time_steps as f32;
let mut row = Vec::with_capacity(2 * harmonics);
for h in 1..=harmonics {
row.push((2.0 * PI * h as f32 * t).cos());
row.push((2.0 * PI * h as f32 * t).sin());
}
row
})
.collect()
}
fn project_basis(theta: &[f32], basis: &[Vec<f32>]) -> Vec<f32> {
basis
.iter()
.map(|row| {
row.iter()
.zip(theta.iter())
.map(|(&b, &th)| b * th)
.sum::<f32>()
})
.collect()
}
#[derive(Debug, Clone)]
pub struct NBEATSBlock {
pub layers: Vec<FCLayer>,
pub fc_backcast: FCLayer,
pub fc_forecast: FCLayer,
pub theta_b_dim: usize,
pub theta_f_dim: usize,
pub backcast_dim: usize,
pub forecast_dim: usize,
pub basis: NBEATSBasis,
}
impl NBEATSBlock {
pub fn new(config: &NBEATSConfig, seed: u64) -> Self {
let (theta_b_dim, theta_f_dim) = match &config.basis {
NBEATSBasis::Generic => (config.lookback, config.horizon),
NBEATSBasis::Trend { degree } => (degree + 1, degree + 1),
NBEATSBasis::Seasonality { harmonics } => (2 * harmonics, 2 * harmonics),
};
let mut layers = Vec::with_capacity(config.n_layers);
let in_dim = config.lookback;
for i in 0..config.n_layers {
let layer_in = if i == 0 { in_dim } else { config.n_hidden };
layers.push(FCLayer::new(layer_in, config.n_hidden, seed + i as u64 * 7));
}
let fc_backcast = FCLayer::new(config.n_hidden, theta_b_dim, seed + 100);
let fc_forecast = FCLayer::new(config.n_hidden, theta_f_dim, seed + 200);
Self {
layers,
fc_backcast,
fc_forecast,
theta_b_dim,
theta_f_dim,
backcast_dim: config.lookback,
forecast_dim: config.horizon,
basis: config.basis.clone(),
}
}
pub fn forward(&self, x: &[f32]) -> (Vec<f32>, Vec<f32>) {
let mut h = x.to_vec();
for layer in &self.layers {
h = layer.forward(&h, true);
}
let theta_b = self.fc_backcast.forward(&h, false);
let theta_f = self.fc_forecast.forward(&h, false);
let backcast = match &self.basis {
NBEATSBasis::Generic => {
let mut bc = vec![0.0_f32; self.backcast_dim];
for (i, &v) in theta_b.iter().enumerate() {
if i < bc.len() {
bc[i] = v;
}
}
bc
}
NBEATSBasis::Trend { degree } => {
let basis = polynomial_basis(self.backcast_dim, *degree);
project_basis(&theta_b, &basis)
}
NBEATSBasis::Seasonality { harmonics } => {
let basis = fourier_basis(self.backcast_dim, *harmonics);
project_basis(&theta_b, &basis)
}
};
let forecast = match &self.basis {
NBEATSBasis::Generic => {
let mut fc = vec![0.0_f32; self.forecast_dim];
for (i, &v) in theta_f.iter().enumerate() {
if i < fc.len() {
fc[i] = v;
}
}
fc
}
NBEATSBasis::Trend { degree } => {
let basis = polynomial_basis(self.forecast_dim, *degree);
project_basis(&theta_f, &basis)
}
NBEATSBasis::Seasonality { harmonics } => {
let basis = fourier_basis(self.forecast_dim, *harmonics);
project_basis(&theta_f, &basis)
}
};
(backcast, forecast)
}
}
#[derive(Debug, Clone)]
pub struct NBEATSStack {
pub blocks: Vec<NBEATSBlock>,
pub basis: NBEATSBasis,
}
impl NBEATSStack {
pub fn new(config: &NBEATSConfig, seed: u64) -> Self {
let blocks = (0..config.n_blocks)
.map(|b| NBEATSBlock::new(config, seed + b as u64 * 31))
.collect();
Self {
blocks,
basis: config.basis.clone(),
}
}
pub fn forward(&self, x: &[f32]) -> (Vec<f32>, Vec<f32>) {
let mut residual = x.to_vec();
let mut stack_forecast = vec![0.0_f32; self.blocks[0].forecast_dim];
let mut stack_backcast = vec![0.0_f32; residual.len()];
for block in &self.blocks {
let (bc, fc) = block.forward(&residual);
for (r, &b) in residual.iter_mut().zip(bc.iter()) {
*r -= b;
}
for (s, &f) in stack_forecast.iter_mut().zip(fc.iter()) {
*s += f;
}
for (s, &b) in stack_backcast.iter_mut().zip(bc.iter()) {
*s += b;
}
}
(stack_backcast, stack_forecast)
}
}
#[derive(Debug, Clone)]
pub struct NBEATS {
pub stacks: Vec<NBEATSStack>,
pub config: NBEATSConfig,
}
impl NBEATS {
pub fn new(config: NBEATSConfig) -> Self {
let stacks = (0..config.n_stacks)
.map(|s| NBEATSStack::new(&config, s as u64 * 1000 + 42))
.collect();
Self { stacks, config }
}
pub fn forward(&self, x: &[f32]) -> Result<Vec<f32>> {
if x.len() != self.config.lookback {
return Err(TimeSeriesError::InvalidInput(format!(
"Input length {} does not match configured lookback {}",
x.len(),
self.config.lookback
)));
}
let mut residual = x.to_vec();
let mut total_forecast = vec![0.0_f32; self.config.horizon];
for stack in &self.stacks {
let (bc, fc) = stack.forward(&residual);
for (r, &b) in residual.iter_mut().zip(bc.iter()) {
*r -= b;
}
for (tf, &f) in total_forecast.iter_mut().zip(fc.iter()) {
*tf += f;
}
}
Ok(total_forecast)
}
pub fn train(&mut self, data: &[f32], n_epochs: usize, lr: f32) -> Result<()> {
let win = self.config.lookback + self.config.horizon;
if data.len() < win {
return Err(TimeSeriesError::InsufficientData {
message: "Training data too short".to_string(),
required: win,
actual: data.len(),
});
}
let windows: Vec<(&[f32], &[f32])> = (0..data.len() - win + 1)
.map(|i| (&data[i..i + self.config.lookback], &data[i + self.config.lookback..i + win]))
.collect();
for _epoch in 0..n_epochs {
for (x_win, y_win) in &windows {
let y_pred = self.forward_train(x_win);
let n = y_win.len() as f32;
let grad_out: Vec<f32> = y_pred
.iter()
.zip(y_win.iter())
.map(|(p, &t)| 2.0 * (p - t) / n)
.collect();
self.backward_sgd(x_win, &grad_out, lr);
}
}
Ok(())
}
fn forward_train(&self, x: &[f32]) -> Vec<f32> {
let mut residual = x.to_vec();
let mut total_forecast = vec![0.0_f32; self.config.horizon];
for stack in &self.stacks {
let (bc, fc) = stack.forward(&residual);
for (r, &b) in residual.iter_mut().zip(bc.iter()) {
*r -= b;
}
for (tf, &f) in total_forecast.iter_mut().zip(fc.iter()) {
*tf += f;
}
}
total_forecast
}
fn backward_sgd(&mut self, x: &[f32], _grad_out: &[f32], lr: f32) {
let eps = 1e-4_f32;
for stack in &mut self.stacks {
for block in &mut stack.blocks {
let h = {
let mut h = x.to_vec();
for layer in &block.layers {
h = layer.forward(&h, true);
}
h
};
let n_out = block.fc_forecast.w.len();
let n_in = block.fc_forecast.w[0].len();
let scale = lr * eps;
for i in 0..n_out {
for j in 0..n_in {
if j < h.len() {
block.fc_forecast.w[i][j] -= scale * h[j];
}
}
block.fc_forecast.b[i] -= scale;
}
}
}
let _ = lr; }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_forward_output_length() {
let config = NBEATSConfig {
n_stacks: 2,
n_blocks: 2,
n_layers: 2,
n_hidden: 16,
horizon: 6,
lookback: 24,
basis: NBEATSBasis::Generic,
};
let model = NBEATS::new(config);
let x = vec![1.0_f32; 24];
let fc = model.forward(&x).expect("forward pass should succeed");
assert_eq!(fc.len(), 6);
}
#[test]
fn test_trend_basis_forward() {
let config = NBEATSConfig {
n_stacks: 1,
n_blocks: 1,
n_layers: 2,
n_hidden: 8,
horizon: 4,
lookback: 12,
basis: NBEATSBasis::Trend { degree: 2 },
};
let model = NBEATS::new(config);
let x: Vec<f32> = (0..12).map(|i| i as f32).collect();
let fc = model.forward(&x).expect("trend forward pass");
assert_eq!(fc.len(), 4);
}
#[test]
fn test_seasonality_basis_forward() {
let config = NBEATSConfig {
n_stacks: 1,
n_blocks: 1,
n_layers: 2,
n_hidden: 8,
horizon: 4,
lookback: 12,
basis: NBEATSBasis::Seasonality { harmonics: 3 },
};
let model = NBEATS::new(config);
let x = vec![1.0_f32; 12];
let fc = model.forward(&x).expect("seasonality forward pass");
assert_eq!(fc.len(), 4);
}
#[test]
fn test_input_length_mismatch_error() {
let config = NBEATSConfig {
n_stacks: 1,
n_blocks: 1,
n_layers: 1,
n_hidden: 8,
horizon: 4,
lookback: 12,
basis: NBEATSBasis::Generic,
};
let model = NBEATS::new(config);
let x = vec![1.0_f32; 10]; assert!(model.forward(&x).is_err());
}
#[test]
fn test_train_smoke() {
let config = NBEATSConfig {
n_stacks: 1,
n_blocks: 1,
n_layers: 2,
n_hidden: 8,
horizon: 3,
lookback: 9,
basis: NBEATSBasis::Generic,
};
let mut model = NBEATS::new(config);
let data: Vec<f32> = (0..50).map(|i| (i as f32).sin()).collect();
model.train(&data, 2, 0.001).expect("training should succeed");
}
#[test]
fn test_polynomial_basis_values() {
let basis = polynomial_basis(3, 2);
assert_eq!(basis.len(), 3);
assert!((basis[0][0] - 1.0).abs() < 1e-6);
assert!(basis[0][1].abs() < 1e-6);
}
#[test]
fn test_fourier_basis_shape() {
let basis = fourier_basis(8, 4);
assert_eq!(basis.len(), 8);
assert_eq!(basis[0].len(), 8); }
}