use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::nn_utils;
use crate::error::{Result, TimeSeriesError};
/// Hyper-parameters for an N-HiTS model.
#[derive(Debug, Clone)]
pub struct NHiTSConfig {
    /// Lookback length: number of rows expected in the `forecast` input.
    pub seq_len: usize,
    /// Forecast horizon: number of rows produced by `forecast`.
    pub pred_len: usize,
    /// Number of channels (columns); each channel gets its own independent stacks.
    pub n_channels: usize,
    /// Number of stacks per channel; must equal `n_pool_kernel_size.len()`.
    pub n_stacks: usize,
    /// Number of blocks chained inside every stack.
    pub n_blocks: usize,
    /// Width of the hidden layers of every block.
    pub hidden_size: usize,
    /// Max-pool kernel size for each stack (also its forecast down-sampling ratio);
    /// a value of 0 behaves like 1.
    pub n_pool_kernel_size: Vec<usize>,
    /// Base seed for deterministic weight initialisation.
    pub seed: u32,
}

impl Default for NHiTSConfig {
    /// 96-step lookback, 24-step horizon, a single channel, three single-block stacks of
    /// width 256 with pooling kernels 16 / 8 / 1, and seed 42.
    fn default() -> Self {
        let kernels = vec![16, 8, 1];
        Self {
            seed: 42,
            hidden_size: 256,
            n_blocks: 1,
            n_channels: 1,
            pred_len: 24,
            seq_len: 96,
            // One stack per pooling kernel.
            n_stacks: kernels.len(),
            n_pool_kernel_size: kernels,
        }
    }
}
/// Non-overlapping 1-D max pooling with window and stride `kernel_size`.
///
/// The result has `ceil(signal.len() / kernel_size)` entries; the final window may be shorter
/// than `kernel_size`. A `kernel_size` of 0 or 1 returns an unchanged copy of `signal`.
/// NaN values never win the `>` comparison, so they are skipped; a window holding only NaNs
/// yields negative infinity.
fn max_pool_1d<F: Float>(signal: &Array1<F>, kernel_size: usize) -> Array1<F> {
    if kernel_size < 2 {
        return signal.clone();
    }
    let len = signal.len();
    let windows = (len + kernel_size - 1) / kernel_size;
    Array1::from_shape_fn(windows, |w| {
        let lo = w * kernel_size;
        let hi = len.min(lo + kernel_size);
        (lo..hi).fold(F::neg_infinity(), |best, idx| {
            let candidate = signal[idx];
            if candidate > best {
                candidate
            } else {
                best
            }
        })
    })
}
/// Linearly resamples `signal` to `target_len` points, with both endpoints aligned to the
/// source's first and last samples.
///
/// Degenerate cases:
/// * an empty source or a zero target yields `target_len` zeros;
/// * equal lengths return a copy;
/// * a single-sample source is broadcast to every output position.
fn interpolate_1d<F: Float + FromPrimitive>(signal: &Array1<F>, target_len: usize) -> Array1<F> {
    let src_len = signal.len();
    match (src_len, target_len) {
        (0, _) | (_, 0) => return Array1::zeros(target_len),
        _ if src_len == target_len => return signal.clone(),
        (1, _) => return Array1::from_elem(target_len, signal[0]),
        _ => {}
    }

    let last = src_len - 1;
    // Source-index distance between consecutive output samples.
    let step = F::from(last as f64 / (target_len - 1).max(1) as f64).unwrap_or_else(|| F::one());
    // Float position -> clamped source index.
    let to_index = |x: F| (x.to_f64().unwrap_or(0.0) as usize).min(last);

    Array1::from_shape_fn(target_len, |i| {
        let pos = F::from(i as f64).unwrap_or_else(|| F::zero()) * step;
        let floor = pos.floor();
        let weight = pos - floor;
        let left = signal[to_index(floor)];
        let right = signal[to_index(pos.ceil())];
        left * (F::one() - weight) + right * weight
    })
}
/// One N-HiTS block: two dense+ReLU trunk layers followed by separate backcast and
/// forecast dense heads (see `forward`).
#[derive(Debug)]
pub struct NHiTSBlock<F: Float + Debug> {
/// Expected length of the (max-pooled) input vector passed to `forward`.
pooled_size: usize,
/// Backcast length; used to size `w_backcast` / `b_backcast`.
seq_len: usize,
/// Number of forecast coefficients; used to size `w_forecast` / `b_forecast`.
forecast_size: usize,
/// First trunk layer weights and bias (input side sized by `pooled_size`).
w1: Array2<F>,
b1: Array1<F>,
/// Second trunk layer weights and bias (hidden -> hidden).
w2: Array2<F>,
b2: Array1<F>,
/// Backcast head weights and bias.
w_backcast: Array2<F>,
b_backcast: Array1<F>,
/// Forecast head weights and bias.
w_forecast: Array2<F>,
b_forecast: Array1<F>,
}
impl<F: Float + FromPrimitive + Debug> NHiTSBlock<F> {
    /// Creates a block whose weight matrices come from `nn_utils::xavier_matrix` and whose
    /// biases come from `nn_utils::zero_bias`.
    ///
    /// Each matrix gets its own derived seed: `seed` for the first trunk layer, and `seed + 100`,
    /// `+ 200`, `+ 300` (wrapping) for the second trunk layer, the backcast head and the forecast
    /// head respectively.
    pub fn new(
        pooled_size: usize,
        seq_len: usize,
        forecast_size: usize,
        hidden_size: usize,
        seed: u32,
    ) -> Self {
        let trunk2_seed = seed.wrapping_add(100);
        let backcast_seed = seed.wrapping_add(200);
        let forecast_seed = seed.wrapping_add(300);
        Self {
            pooled_size,
            seq_len,
            forecast_size,
            w1: nn_utils::xavier_matrix(hidden_size, pooled_size, seed),
            b1: nn_utils::zero_bias(hidden_size),
            w2: nn_utils::xavier_matrix(hidden_size, hidden_size, trunk2_seed),
            b2: nn_utils::zero_bias(hidden_size),
            w_backcast: nn_utils::xavier_matrix(seq_len, hidden_size, backcast_seed),
            b_backcast: nn_utils::zero_bias(seq_len),
            w_forecast: nn_utils::xavier_matrix(forecast_size, hidden_size, forecast_seed),
            b_forecast: nn_utils::zero_bias(forecast_size),
        }
    }

    /// Runs the block on a pooled input and returns `(backcast, forecast)`.
    ///
    /// `pooled` is expected to have `pooled_size` entries.
    pub fn forward(&self, pooled: &Array1<F>) -> (Array1<F>, Array1<F>) {
        // Shared trunk: dense -> ReLU -> dense -> ReLU.
        let hidden = nn_utils::relu_1d(&nn_utils::dense_forward_vec(pooled, &self.w1, &self.b1));
        let hidden = nn_utils::relu_1d(&nn_utils::dense_forward_vec(&hidden, &self.w2, &self.b2));
        // Two linear heads read the same trunk output.
        (
            nn_utils::dense_forward_vec(&hidden, &self.w_backcast, &self.b_backcast),
            nn_utils::dense_forward_vec(&hidden, &self.w_forecast, &self.b_forecast),
        )
    }
}
/// A stack of N-HiTS blocks that share one max-pool kernel size.
#[derive(Debug)]
pub struct NHiTSStack<F: Float + Debug> {
/// Max-pool window applied to each block's input (clamped to at least 1 by `new`).
pool_kernel_size: usize,
/// Lookback length of the residual handled by this stack.
seq_len: usize,
/// Forecast horizon of the stack output.
pred_len: usize,
/// Number of forecast coefficients each block emits before interpolation to `pred_len`.
forecast_size: usize,
/// Blocks, applied in order by `forward`.
blocks: Vec<NHiTSBlock<F>>,
}
impl<F: Float + FromPrimitive + Debug> NHiTSStack<F> {
    /// Creates a stack of `n_blocks` blocks sharing one pooling kernel size.
    ///
    /// * `seq_len` – lookback length; every block's backcast has this length.
    /// * `pred_len` – forecast horizon of the stack output.
    /// * `n_blocks` – number of blocks chained inside the stack.
    /// * `hidden_size` – width of each block's hidden layers.
    /// * `pool_kernel_size` – max-pool window for the block input. It is also the ratio by which
    ///   each block's forecast is down-sampled before interpolation. `0` is treated as `1`.
    /// * `seed` – base seed for weight initialisation.
    pub fn new(
        seq_len: usize,
        pred_len: usize,
        n_blocks: usize,
        hidden_size: usize,
        pool_kernel_size: usize,
        seed: u32,
    ) -> Self {
        let ks = pool_kernel_size.max(1);
        // ceil(seq_len / ks): length of the max-pooled block input.
        let pooled_size = (seq_len + ks - 1) / ks;
        // The kernel size doubles as the expressiveness ratio: each block emits
        // ceil(pred_len / ks) coefficients (at least one), later interpolated to `pred_len`.
        let forecast_size = ((pred_len + ks - 1) / ks).max(1);
        let blocks = (0..n_blocks)
            .map(|i| {
                NHiTSBlock::new(
                    pooled_size,
                    seq_len,
                    forecast_size,
                    hidden_size,
                    Self::block_seed(seed, i),
                )
            })
            .collect();
        Self {
            pool_kernel_size: ks,
            seq_len,
            pred_len,
            forecast_size,
            blocks,
        }
    }

    /// Derives the initialisation seed of block `index` from the stack seed.
    ///
    /// Block 0 keeps the stack seed. Later blocks XOR in a large odd multiple of the index
    /// rather than adding `index * 1000`. The model offsets consecutive stacks' seeds by exactly
    /// 1000, so the additive scheme gave block 1 of stack `s` the same seed — and therefore
    /// identical same-shaped weight matrices — as block 0 of stack `s + 1`. Wrapping arithmetic
    /// also rules out debug-mode overflow panics.
    fn block_seed(seed: u32, index: usize) -> u32 {
        seed ^ (index as u32).wrapping_mul(0x9E37_79B9)
    }

    /// Runs every block of the stack on the incoming residual.
    ///
    /// Blocks are chained doubly-residually. Each block:
    /// 1. max-pools the residual left by the previous block;
    /// 2. has its backcast subtracted from that residual;
    /// 3. has its down-sampled forecast linearly interpolated to `pred_len` and added to the
    ///    stack forecast.
    ///
    /// Returns `(remaining_residual, stack_forecast)`. `residual` is expected to have length
    /// `seq_len`; if it does not, only the overlapping prefix is updated.
    pub fn forward(&self, residual: &Array1<F>) -> (Array1<F>, Array1<F>) {
        let mut current_residual = residual.clone();
        let mut stack_forecast = Array1::zeros(self.pred_len);
        for block in &self.blocks {
            // Pool the *current* residual (not the stack input) so each block models only
            // what the earlier blocks left unexplained.
            let pooled = max_pool_1d(&current_residual, self.pool_kernel_size);
            let pooled_input = if pooled.len() == block.pooled_size {
                pooled
            } else {
                // Only reachable when `residual.len() != seq_len`: truncate or zero-pad to
                // the width the block was built for.
                let mut adjusted = Array1::zeros(block.pooled_size);
                for (dst, &src) in adjusted.iter_mut().zip(pooled.iter()) {
                    *dst = src;
                }
                adjusted
            };
            let (backcast, forecast_reduced) = block.forward(&pooled_input);
            for (r, &b) in current_residual.iter_mut().zip(backcast.iter()) {
                *r = *r - b;
            }
            let forecast_full = interpolate_1d(&forecast_reduced, self.pred_len);
            for (acc, &f) in stack_forecast.iter_mut().zip(forecast_full.iter()) {
                *acc = *acc + f;
            }
        }
        (current_residual, stack_forecast)
    }
}
/// Channel-independent N-HiTS forecaster built from an `NHiTSConfig`.
#[derive(Debug)]
pub struct NHiTSModel<F: Float + Debug> {
/// Configuration the model was built with.
config: NHiTSConfig,
/// `stacks[ch]` holds the stacks for channel `ch`, applied in order by `forecast`.
stacks: Vec<Vec<NHiTSStack<F>>>,
}
impl<F: Float + FromPrimitive + Debug> NHiTSModel<F> {
pub fn new(config: NHiTSConfig) -> Result<Self> {
if config.seq_len == 0 {
return Err(TimeSeriesError::InvalidInput(
"seq_len must be positive".to_string(),
));
}
if config.pred_len == 0 {
return Err(TimeSeriesError::InvalidInput(
"pred_len must be positive".to_string(),
));
}
if config.n_channels == 0 {
return Err(TimeSeriesError::InvalidInput(
"n_channels must be positive".to_string(),
));
}
if config.n_stacks == 0 {
return Err(TimeSeriesError::InvalidInput(
"n_stacks must be positive".to_string(),
));
}
if config.n_blocks == 0 {
return Err(TimeSeriesError::InvalidInput(
"n_blocks must be positive".to_string(),
));
}
if config.n_pool_kernel_size.len() != config.n_stacks {
return Err(TimeSeriesError::InvalidInput(format!(
"n_pool_kernel_size length ({}) must equal n_stacks ({})",
config.n_pool_kernel_size.len(),
config.n_stacks
)));
}
let mut stacks = Vec::with_capacity(config.n_channels);
for ch in 0..config.n_channels {
let mut ch_stacks = Vec::with_capacity(config.n_stacks);
for (s, &pool_ks) in config.n_pool_kernel_size.iter().enumerate() {
let stack_seed = config
.seed
.wrapping_add(ch as u32 * 10000)
.wrapping_add(s as u32 * 1000);
ch_stacks.push(NHiTSStack::new(
config.seq_len,
config.pred_len,
config.n_blocks,
config.hidden_size,
pool_ks,
stack_seed,
));
}
stacks.push(ch_stacks);
}
Ok(Self { config, stacks })
}
pub fn forecast(&self, input: &Array2<F>) -> Result<Array2<F>> {
let (seq_len, n_ch) = input.dim();
if seq_len != self.config.seq_len {
return Err(TimeSeriesError::DimensionMismatch {
expected: self.config.seq_len,
actual: seq_len,
});
}
if n_ch != self.config.n_channels {
return Err(TimeSeriesError::DimensionMismatch {
expected: self.config.n_channels,
actual: n_ch,
});
}
let mut output = Array2::zeros((self.config.pred_len, n_ch));
for ch in 0..n_ch {
let mut channel_series: Array1<F> = Array1::zeros(seq_len);
for t in 0..seq_len {
channel_series[t] = input[[t, ch]];
}
let mut residual = channel_series.clone();
let mut total_forecast: Array1<F> = Array1::zeros(self.config.pred_len);
for stack in &self.stacks[ch] {
let (new_residual, stack_forecast) = stack.forward(&residual);
residual = new_residual;
for t in 0..self.config.pred_len {
total_forecast[t] = total_forecast[t] + stack_forecast[t];
}
}
for t in 0..self.config.pred_len {
output[[t, ch]] = total_forecast[t];
}
}
Ok(output)
}
pub fn config(&self) -> &NHiTSConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(seq_len, n_channels)` input whose entry `[t, c]` equals `0.1 * t + c`.
    fn make_input(seq_len: usize, n_channels: usize) -> Array2<f64> {
        Array2::from_shape_fn((seq_len, n_channels), |(t, c)| t as f64 * 0.1 + c as f64)
    }

    /// Asserts that every entry of `arr` is finite.
    fn assert_all_finite(arr: &Array2<f64>) {
        for ((t, ch), v) in arr.indexed_iter() {
            assert!(v.is_finite(), "Non-finite at [{t},{ch}]");
        }
    }

    #[test]
    fn test_default_config_values() {
        let cfg = NHiTSConfig::default();
        assert_eq!(cfg.seq_len, 96);
        assert_eq!(cfg.pred_len, 24);
        assert_eq!(cfg.n_channels, 1);
        assert_eq!(cfg.n_stacks, 3);
        assert_eq!(cfg.n_blocks, 1);
        assert_eq!(cfg.hidden_size, 256);
        assert_eq!(cfg.n_pool_kernel_size, vec![16, 8, 1]);
        assert_eq!(cfg.seed, 42);
    }

    #[test]
    fn test_model_creation_default() {
        let model = NHiTSModel::<f64>::new(NHiTSConfig::default());
        assert!(model.is_ok());
    }

    #[test]
    fn test_model_creation_invalid_seq_len() {
        let cfg = NHiTSConfig {
            seq_len: 0,
            ..NHiTSConfig::default()
        };
        assert!(NHiTSModel::<f64>::new(cfg).is_err());
    }

    #[test]
    fn test_model_creation_rejects_other_zero_sizes() {
        let zeroed = [
            NHiTSConfig {
                pred_len: 0,
                ..NHiTSConfig::default()
            },
            NHiTSConfig {
                n_channels: 0,
                ..NHiTSConfig::default()
            },
            NHiTSConfig {
                n_blocks: 0,
                ..NHiTSConfig::default()
            },
            NHiTSConfig {
                n_stacks: 0,
                n_pool_kernel_size: Vec::new(),
                ..NHiTSConfig::default()
            },
        ];
        for cfg in zeroed {
            assert!(
                NHiTSModel::<f64>::new(cfg.clone()).is_err(),
                "accepted invalid config {cfg:?}"
            );
        }
    }

    #[test]
    fn test_model_creation_invalid_kernel_size_len() {
        let cfg = NHiTSConfig {
            n_stacks: 3,
            n_pool_kernel_size: vec![8, 4],
            ..NHiTSConfig::default()
        };
        assert!(NHiTSModel::<f64>::new(cfg).is_err());
    }

    #[test]
    fn test_block_output_shapes() {
        let pooled_size = 12;
        let seq_len = 96;
        let forecast_size = 6;
        let hidden = 32;
        let block = NHiTSBlock::<f64>::new(pooled_size, seq_len, forecast_size, hidden, 42);
        let pooled = Array1::zeros(pooled_size);
        let (backcast, forecast) = block.forward(&pooled);
        assert_eq!(backcast.len(), seq_len, "backcast length mismatch");
        assert_eq!(forecast.len(), forecast_size, "forecast length mismatch");
    }

    #[test]
    fn test_stack_output_shapes() {
        let seq_len = 48;
        let pred_len = 12;
        let stack = NHiTSStack::<f64>::new(seq_len, pred_len, 2, 64, 8, 42);
        let residual = Array1::zeros(seq_len);
        let (new_res, forecast) = stack.forward(&residual);
        assert_eq!(new_res.len(), seq_len, "residual shape mismatch");
        assert_eq!(forecast.len(), pred_len, "forecast shape mismatch");
    }

    #[test]
    fn test_forecast_shape_single_channel() {
        let cfg = NHiTSConfig {
            seq_len: 48,
            pred_len: 12,
            n_channels: 1,
            n_stacks: 2,
            n_blocks: 1,
            hidden_size: 32,
            n_pool_kernel_size: vec![8, 1],
            seed: 42,
        };
        let model = NHiTSModel::<f64>::new(cfg).expect("model creation failed");
        let input = make_input(48, 1);
        let output = model.forecast(&input).expect("forecast failed");
        assert_eq!(output.dim(), (12, 1));
    }

    #[test]
    fn test_forecast_shape_multichannel() {
        let cfg = NHiTSConfig {
            seq_len: 96,
            pred_len: 24,
            n_channels: 7,
            n_stacks: 3,
            n_blocks: 1,
            hidden_size: 64,
            n_pool_kernel_size: vec![16, 8, 1],
            seed: 42,
        };
        let model = NHiTSModel::<f64>::new(cfg).expect("model creation failed");
        let input = make_input(96, 7);
        let output = model.forecast(&input).expect("forecast failed");
        assert_eq!(output.dim(), (24, 7));
    }

    #[test]
    fn test_forecast_output_is_finite() {
        let cfg = NHiTSConfig {
            seq_len: 32,
            pred_len: 8,
            n_channels: 3,
            n_stacks: 2,
            n_blocks: 2,
            hidden_size: 32,
            n_pool_kernel_size: vec![4, 1],
            seed: 7,
        };
        let model = NHiTSModel::<f64>::new(cfg).expect("model creation failed");
        let input = make_input(32, 3);
        let output = model.forecast(&input).expect("forecast failed");
        assert_eq!(output.dim(), (8, 3));
        assert_all_finite(&output);
    }

    #[test]
    fn test_wrong_seq_len_returns_error() {
        let cfg = NHiTSConfig {
            seq_len: 48,
            pred_len: 12,
            n_channels: 1,
            n_stacks: 1,
            n_blocks: 1,
            hidden_size: 16,
            n_pool_kernel_size: vec![4],
            seed: 1,
        };
        let model = NHiTSModel::<f64>::new(cfg).expect("model creation failed");
        let bad_input = make_input(32, 1);
        assert!(matches!(
            model.forecast(&bad_input),
            Err(TimeSeriesError::DimensionMismatch {
                expected: 48,
                actual: 32
            })
        ));
    }

    #[test]
    fn test_wrong_n_channels_returns_error() {
        let cfg = NHiTSConfig {
            seq_len: 32,
            pred_len: 8,
            n_channels: 3,
            n_stacks: 1,
            n_blocks: 1,
            hidden_size: 16,
            n_pool_kernel_size: vec![4],
            seed: 1,
        };
        let model = NHiTSModel::<f64>::new(cfg).expect("model creation failed");
        let bad_input = make_input(32, 5);
        assert!(matches!(
            model.forecast(&bad_input),
            Err(TimeSeriesError::DimensionMismatch {
                expected: 3,
                actual: 5
            })
        ));
    }

    #[test]
    fn test_n_stacks_effect_on_residual() {
        let base_cfg = NHiTSConfig {
            seq_len: 32,
            pred_len: 8,
            n_channels: 1,
            n_stacks: 1,
            n_blocks: 1,
            hidden_size: 16,
            n_pool_kernel_size: vec![4],
            seed: 42,
        };
        let three_stack_cfg = NHiTSConfig {
            seq_len: 32,
            pred_len: 8,
            n_channels: 1,
            n_stacks: 3,
            n_blocks: 1,
            hidden_size: 16,
            n_pool_kernel_size: vec![8, 4, 1],
            seed: 42,
        };
        let m1 = NHiTSModel::<f64>::new(base_cfg).expect("model1 creation failed");
        let m3 = NHiTSModel::<f64>::new(three_stack_cfg).expect("model3 creation failed");
        let input = make_input(32, 1);
        let out1 = m1.forecast(&input).expect("forecast1 failed");
        let out3 = m3.forecast(&input).expect("forecast3 failed");
        assert_eq!(out1.dim(), (8, 1));
        assert_eq!(out3.dim(), (8, 1));
        assert_all_finite(&out1);
        assert_all_finite(&out3);
    }

    #[test]
    fn test_max_pool_1d_basic() {
        let sig = Array1::from_vec(vec![1.0_f64, 3.0, 2.0, 4.0]);
        let out = max_pool_1d(&sig, 2);
        assert_eq!(out.len(), 2);
        assert!((out[0] - 3.0).abs() < 1e-12);
        assert!((out[1] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_max_pool_1d_ragged_last_window() {
        let sig = Array1::from_vec(vec![1.0_f64, 5.0, 2.0, 3.0, 9.0]);
        let out = max_pool_1d(&sig, 2);
        assert_eq!(out, Array1::from_vec(vec![5.0, 3.0, 9.0]));
    }

    #[test]
    fn test_interpolate_1d_basic() {
        let sig = Array1::from_vec(vec![0.0_f64, 1.0]);
        let out = interpolate_1d(&sig, 5);
        assert_eq!(out.len(), 5);
        for (i, &expected) in [0.0, 0.25, 0.5, 0.75, 1.0].iter().enumerate() {
            assert!(
                (out[i] - expected).abs() < 1e-10,
                "out[{i}] = {} != {expected}",
                out[i]
            );
        }
    }

    #[test]
    fn test_interpolate_1d_degenerate_and_downsample() {
        let single = Array1::from_vec(vec![3.0_f64]);
        assert_eq!(interpolate_1d(&single, 4), Array1::from_elem(4, 3.0));
        assert_eq!(interpolate_1d(&single, 0).len(), 0);
        assert_eq!(interpolate_1d(&Array1::<f64>::zeros(0), 3), Array1::zeros(3));

        let ramp = Array1::from_vec(vec![0.0_f64, 1.0, 2.0, 3.0, 4.0]);
        let down = interpolate_1d(&ramp, 3);
        for (i, &expected) in [0.0, 2.0, 4.0].iter().enumerate() {
            assert!((down[i] - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_n_blocks_greater_than_1() {
        let cfg = NHiTSConfig {
            seq_len: 32,
            pred_len: 8,
            n_channels: 2,
            n_stacks: 2,
            n_blocks: 3,
            hidden_size: 16,
            n_pool_kernel_size: vec![4, 1],
            seed: 99,
        };
        let model = NHiTSModel::<f64>::new(cfg).expect("model creation failed");
        let input = make_input(32, 2);
        let output = model.forecast(&input).expect("forecast failed");
        assert_eq!(output.dim(), (8, 2));
        assert_all_finite(&output);
    }

    #[test]
    fn test_pool_kernel_size_1_acts_as_identity() {
        let sig = Array1::from_vec(vec![2.0_f64, -1.0, 7.5]);
        assert_eq!(max_pool_1d(&sig, 1), sig);

        let cfg = NHiTSConfig {
            seq_len: 16,
            pred_len: 4,
            n_channels: 1,
            n_stacks: 1,
            n_blocks: 1,
            hidden_size: 8,
            n_pool_kernel_size: vec![1],
            seed: 0,
        };
        let model = NHiTSModel::<f64>::new(cfg).expect("model creation failed");
        let input = make_input(16, 1);
        let output = model.forecast(&input).expect("forecast failed");
        assert_eq!(output.dim(), (4, 1));
    }

    #[test]
    fn test_forecast_shape_standard_config() {
        let model =
            NHiTSModel::<f64>::new(NHiTSConfig::default()).expect("model creation failed");
        let input = make_input(96, 1);
        let output = model.forecast(&input).expect("forecast failed");
        assert_eq!(output.dim(), (24, 1));
    }
}