use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::config::ActivationFunction;
use super::lstm::LSTMCell;
use crate::error::{Result, TimeSeriesError};
#[derive(Debug, Clone)]
pub enum NBeatsBlockType {
Generic,
Trend,
Seasonality,
}
#[derive(Debug)]
pub struct NBeatsBlock<F: Float + Debug> {
#[allow(dead_code)]
block_type: NBeatsBlockType,
#[allow(dead_code)]
input_size: usize,
#[allow(dead_code)]
output_size: usize,
#[allow(dead_code)]
num_layers: usize,
#[allow(dead_code)]
layer_widths: Vec<usize>,
#[allow(dead_code)]
weights: Vec<Array2<F>>,
#[allow(dead_code)]
biases: Vec<Array1<F>>,
#[allow(dead_code)]
theta_weights: Array2<F>,
#[allow(dead_code)]
theta_bias: Array1<F>,
}
impl<F: Float + Debug + Clone + FromPrimitive> NBeatsBlock<F> {
pub fn new(
block_type: NBeatsBlockType,
input_size: usize,
output_size: usize,
layer_widths: Vec<usize>,
) -> Self {
let num_layers = layer_widths.len();
let mut weights = Vec::new();
let mut biases = Vec::new();
let mut prev_width = input_size;
for &width in &layer_widths {
let scale = F::from(2.0).expect("Failed to convert constant to float")
/ F::from(prev_width).expect("Failed to convert to float");
let std_dev = scale.sqrt();
weights.push(LSTMCell::random_matrix(width, prev_width, std_dev));
biases.push(Array1::zeros(width));
prev_width = width;
}
let theta_size = match block_type {
NBeatsBlockType::Generic => output_size + input_size,
NBeatsBlockType::Trend => 3, NBeatsBlockType::Seasonality => output_size / 2, };
let theta_scale = F::from(2.0).expect("Failed to convert constant to float")
/ F::from(prev_width).expect("Failed to convert to float");
let theta_std = theta_scale.sqrt();
Self {
block_type,
input_size,
output_size,
num_layers,
layer_widths,
weights,
biases,
theta_weights: LSTMCell::random_matrix(theta_size, prev_width, theta_std),
theta_bias: Array1::zeros(theta_size),
}
}
pub fn forward(&self, input: &Array1<F>) -> Result<(Array1<F>, Array1<F>)> {
if input.len() != self.input_size {
return Err(TimeSeriesError::DimensionMismatch {
expected: self.input_size,
actual: input.len(),
});
}
let backcast = Array1::zeros(self.input_size);
let forecast = Array1::zeros(self.output_size);
Ok((backcast, forecast))
}
}
#[derive(Debug, Clone)]
pub enum NBeatsStackType {
Generic,
Trend,
Seasonality,
}
#[derive(Debug)]
pub struct NBeatsStack<F: Float + Debug> {
#[allow(dead_code)]
stack_type: NBeatsStackType,
#[allow(dead_code)]
blocks: Vec<NBeatsBlock<F>>,
}
impl<F: Float + Debug + Clone + FromPrimitive> NBeatsStack<F> {
pub fn new(
stack_type: NBeatsStackType,
input_size: usize,
output_size: usize,
num_blocks: usize,
layer_widths: Vec<usize>,
) -> Self {
let mut blocks = Vec::new();
let block_type = match stack_type {
NBeatsStackType::Generic => NBeatsBlockType::Generic,
NBeatsStackType::Trend => NBeatsBlockType::Trend,
NBeatsStackType::Seasonality => NBeatsBlockType::Seasonality,
};
for _ in 0..num_blocks {
blocks.push(NBeatsBlock::new(
block_type.clone(),
input_size,
output_size,
layer_widths.clone(),
));
}
Self { stack_type, blocks }
}
pub fn forward(&self, input: &Array1<F>) -> Result<(Array1<F>, Array1<F>)> {
let residual = input.clone();
let forecast = Array1::zeros(0); Ok((residual, forecast))
}
}
#[derive(Debug)]
pub struct NBeatsModel<F: Float + Debug> {
#[allow(dead_code)]
stacks: Vec<NBeatsStack<F>>,
#[allow(dead_code)]
input_size: usize,
#[allow(dead_code)]
output_size: usize,
}
impl<F: Float + Debug + Clone + FromPrimitive> NBeatsModel<F> {
pub fn new(
input_size: usize,
output_size: usize,
stack_configs: Vec<(NBeatsStackType, usize, Vec<usize>)>, ) -> Self {
let mut stacks = Vec::new();
for (stack_type, num_blocks, layer_widths) in stack_configs {
stacks.push(NBeatsStack::new(
stack_type,
input_size,
output_size,
num_blocks,
layer_widths,
));
}
Self {
stacks,
input_size,
output_size,
}
}
pub fn forward(&self, input: &Array1<F>) -> Result<Array1<F>> {
if input.len() != self.input_size {
return Err(TimeSeriesError::DimensionMismatch {
expected: self.input_size,
actual: input.len(),
});
}
Ok(Array1::zeros(self.output_size))
}
pub fn forecast(&self, input: &Array1<F>) -> Result<Array1<F>> {
self.forward(input)
}
}