use crate::{lag_matrix, lag_matrix_2d, LagError, LagMatrix, MatrixLayout};
use ndarray::prelude::*;
use ndarray::{Array1, OwnedRepr};
/// Conversion of `ndarray` arrays into matrices of time-lagged copies.
///
/// Implementors treat their own contents as the time series to lag and return an owned
/// [`Array2`] holding the lagged copies, padded with a caller-supplied fill value.
pub trait LagMatrixFromArray<A: Copy> {
    /// Creates a matrix of time-lagged copies of `self`.
    ///
    /// # Arguments
    /// * `lags` - The lag amounts to produce (e.g. `0..=3`); each item yields one lagged copy.
    /// * `fill` - The value written into positions that a lag shifts out of range.
    /// * `stride` - The requested row stride of the underlying buffer; forwarded unchanged to
    ///   the slice-based lagging function.
    ///
    /// # Errors
    /// Returns a [`LagError`] if the array's memory layout is unsupported or if the underlying
    /// lagging function rejects the input.
    fn lag_matrix<R: IntoIterator<Item = usize>>(
        &self,
        lags: R,
        fill: A,
        stride: usize,
    ) -> Result<Array2<A>, LagError>;
}
impl<A: Copy> LagMatrixFromArray<A> for Array1<A> {
    /// Lags a single one-dimensional series.
    ///
    /// The series must be viewable as a slice (contiguous, standard order) so it can be passed
    /// to [`crate::lag_matrix`]. The lagged buffer is then wrapped by `make_array`, with rows
    /// corresponding to the requested lags.
    ///
    /// # Errors
    /// * [`LagError::InvalidMemoryLayout`] if the array cannot be viewed as a slice.
    /// * Any error returned by [`crate::lag_matrix`].
    fn lag_matrix<R: IntoIterator<Item = usize>>(
        &self,
        lags: R,
        fill: A,
        stride: usize,
    ) -> Result<Array2<A>, LagError> {
        match self.as_slice() {
            None => Err(LagError::InvalidMemoryLayout),
            Some(series) => {
                let lagged = lag_matrix(series, lags, fill, stride)?;
                Ok(make_array(lagged))
            }
        }
    }
}
impl<A> LagMatrixFromArray<A> for Array2<A>
where
    A: Copy,
{
    /// Lags every series held in a two-dimensional array.
    ///
    /// How the buffer is interpreted depends on the array's memory layout:
    /// * Standard (C / row-major) layout: each row is one series of `ncols()` values. The
    ///   buffer is passed to [`crate::lag_matrix_2d`] as `MatrixLayout::RowMajor(ncols())` and
    ///   the result is built by `make_array_2d_row_major`.
    /// * Fortran (column-major) layout: the buffer is passed as
    ///   `MatrixLayout::ColumnMajor(nrows())` and the result is built by
    ///   `make_array_2d_column_major`.
    ///
    /// NOTE(review): in the Fortran case the raw buffer is forwarded in memory order, so which
    /// values form a series is decided by `lag_matrix_2d`'s reading of that order rather than
    /// by the array's logical columns (compare the input and expected output of
    /// `test_lag_2d_columnwise`) — confirm this is intended.
    ///
    /// # Errors
    /// * [`LagError::InvalidMemoryLayout`] if the array is not contiguous, or is contiguous but
    ///   in neither C nor Fortran order (e.g. it has a reversed axis with a negative stride).
    /// * Any error returned by [`crate::lag_matrix_2d`].
    fn lag_matrix<R: IntoIterator<Item = usize>>(
        &self,
        lags: R,
        fill: A,
        stride: usize,
    ) -> Result<Array2<A>, LagError> {
        // `as_slice_memory_order` also accepts contiguous arrays with negative strides, so
        // contiguity alone does not imply C or Fortran order; both are checked explicitly
        // below instead of assuming "not standard" means "column-major".
        let slice = self
            .as_slice_memory_order()
            .ok_or(LagError::InvalidMemoryLayout)?;

        if self.is_standard_layout() {
            let series_len = self.ncols();
            let lagged = lag_matrix_2d(
                slice,
                MatrixLayout::RowMajor(series_len),
                lags,
                fill,
                stride,
            )?;
            Ok(make_array_2d_row_major(lagged))
        } else if self.t().is_standard_layout() {
            // The transposed view is C-ordered exactly when `self` is in Fortran order.
            let series_len = self.nrows();
            let lagged = lag_matrix_2d(
                slice,
                MatrixLayout::ColumnMajor(series_len),
                lags,
                fill,
                stride,
            )?;
            Ok(make_array_2d_column_major(lagged))
        } else {
            Err(LagError::InvalidMemoryLayout)
        }
    }
}
/// Wraps the buffer produced by the one-dimensional lagging function into an [`Array2`].
///
/// The buffer holds one row per lag. Each row contains `series_length` values, followed by
/// padding up to `row_stride` elements. The returned array therefore has shape
/// `(num_lags, series_length)`. When rows are padded, the array keeps `row_stride` as its row
/// stride so the padding is skipped without copying.
///
/// # Panics
/// Panics if the shape/stride described by `matrix` does not fit its buffer. That would mean
/// the lagging function broke its own invariants.
fn make_array<A>(matrix: LagMatrix<A>) -> ArrayBase<OwnedRepr<A>, Ix2> {
    // Rows are lags and columns are time steps. Using `(series_length, num_lags)` here would
    // misread (or, with padding, overrun) the buffer whenever `num_lags != series_length`.
    let shape = (matrix.num_lags, matrix.series_length);
    if matrix.row_stride == matrix.series_length {
        // Unpadded rows: the buffer is already in standard (row-major) order.
        Array2::<A>::from_shape_vec(shape, matrix.data).expect("the shape is valid")
    } else {
        Array2::<A>::from_shape_vec(shape.strides((matrix.row_stride, 1)), matrix.data)
            .expect("the shape is valid")
    }
}
/// Wraps the buffer produced by lagging a row-major (one series per row) matrix into an
/// [`Array2`].
///
/// The buffer holds `series_count * num_lags` rows of `series_length` values each, padded to
/// `row_stride` elements. The unit tests show the rows ordered lag by lag, with one row per
/// series within each lag. The returned array has shape
/// `(series_count * num_lags, series_length)` in both the padded and unpadded case.
///
/// # Panics
/// Panics if the shape/stride described by `matrix` does not fit its buffer. That would mean
/// the lagging function broke its own invariants.
fn make_array_2d_row_major<A>(matrix: LagMatrix<A>) -> ArrayBase<OwnedRepr<A>, Ix2> {
    // Both branches must describe the same logical matrix. The strided form
    // `(rows, cols).strides((cols, 1))` is exactly standard `(rows, cols)`, so the unpadded
    // branch must not swap the dimensions.
    let shape = (matrix.series_count * matrix.num_lags, matrix.series_length);
    let row_stride = matrix.row_stride;
    let unpadded = row_stride == matrix.series_length;
    let data = matrix.into_vec();
    if unpadded {
        Array2::<A>::from_shape_vec(shape, data).expect("the shape is valid")
    } else {
        Array2::<A>::from_shape_vec(shape.strides((row_stride, 1)), data)
            .expect("the shape is valid")
    }
}
/// Wraps the buffer produced by lagging a column-major matrix into an [`Array2`].
///
/// The buffer holds one row per time step. Each row contains `series_count * num_lags`
/// values (the series interleaved within each lag, as in the unit tests), padded to
/// `row_stride` elements. The returned array has shape
/// `(series_length, series_count * num_lags)`: rows are time steps and columns are
/// (lag, series) pairs.
///
/// # Panics
/// Panics if the shape/stride described by `matrix` does not fit its buffer. That would mean
/// the lagging function broke its own invariants.
fn make_array_2d_column_major<A>(matrix: LagMatrix<A>) -> ArrayBase<OwnedRepr<A>, Ix2> {
    let row_width = matrix.series_count * matrix.num_lags;
    let shape = (matrix.series_length, row_width);
    let row_stride = matrix.row_stride;
    let data = matrix.into_vec();
    // The buffer is time-step-major, so it is unpadded exactly when `row_stride` equals the
    // row width. The previous check against `series_length` selected a lag-major reading that
    // disagreed with the strided one and panicked when `series_length != row_width`.
    if row_stride == row_width {
        Array2::<A>::from_shape_vec(shape, data).expect("the shape is valid")
    } else {
        Array2::<A>::from_shape_vec(shape.strides((row_stride, 1)), data)
            .expect("the shape is valid")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder written into lagged gaps. Infinity is used because, unlike NaN, it compares
    /// equal to itself.
    const FILL: f64 = f64::INFINITY;

    #[test]
    #[rustfmt::skip]
    fn test_lag() {
        let series = Array1::from(vec![42.0, 40.0, 38.0, 36.0]);
        let lagged = series.lag_matrix(0..=3, FILL, 0).unwrap();
        assert_eq!(lagged.ncols(), 4);
        assert_eq!(lagged.nrows(), 4);
        let expected = [
            42.0, 40.0, 38.0, 36.0,
            FILL, 42.0, 40.0, 38.0,
            FILL, FILL, 42.0, 40.0,
            FILL, FILL, FILL, 42.0,
        ];
        assert_eq!(lagged.as_slice().unwrap(), &expected);
    }

    #[test]
    #[rustfmt::skip]
    fn test_strided_lag_2() {
        let series = Array1::from(vec![42.0, 40.0, 38.0, 36.0]);
        let lagged = series.lag_matrix(0..=3, FILL, 8).unwrap();
        assert_eq!(lagged.ncols(), 4);
        assert_eq!(lagged.nrows(), 4);
        let expected = [
            42.0, 40.0, 38.0, 36.0,
            FILL, 42.0, 40.0, 38.0,
            FILL, FILL, 42.0, 40.0,
            FILL, FILL, FILL, 42.0,
        ];
        assert_eq!(lagged.as_standard_layout().as_slice().unwrap(), &expected);
    }

    #[test]
    #[rustfmt::skip]
    fn test_lag_2d_rowwise() {
        // Two series, stored one per row in standard (row-major) order.
        let values = vec![
             1.0,  2.0,  3.0,  4.0,
            -1.0, -2.0, -3.0, -4.0,
        ];
        let series = Array2::from_shape_vec((2, 4), values).unwrap();
        let lagged = series.lag_matrix(0..=3, FILL, 5).unwrap();
        assert_eq!(lagged.ncols(), 4);
        assert_eq!(lagged.nrows(), 8);
        let expected = [
             1.0,  2.0,  3.0,  4.0,
            -1.0, -2.0, -3.0, -4.0,
            FILL,  1.0,  2.0,  3.0,
            FILL, -1.0, -2.0, -3.0,
            FILL, FILL,  1.0,  2.0,
            FILL, FILL, -1.0, -2.0,
            FILL, FILL, FILL,  1.0,
            FILL, FILL, FILL, -1.0,
        ];
        assert_eq!(lagged.as_standard_layout().as_slice().unwrap(), &expected);
    }

    #[test]
    #[rustfmt::skip]
    fn test_lag_2d_columnwise() {
        // Built as a 2x4 array and then transposed, so the resulting 4x2 array is stored in
        // Fortran (column-major) memory order.
        let values = vec![
            1.0, -1.0,
            2.0, -2.0,
            3.0, -3.0,
            4.0, -4.0,
        ];
        let series = Array2::from_shape_vec((2, 4), values)
            .unwrap()
            .reversed_axes();
        let lagged = series.lag_matrix(0..=3, FILL, 9).unwrap();
        assert_eq!(lagged.ncols(), 8);
        assert_eq!(lagged.nrows(), 4);
        let expected = [
            1.0, -1.0, FILL, FILL, FILL, FILL, FILL, FILL,
            2.0, -2.0,  1.0, -1.0, FILL, FILL, FILL, FILL,
            3.0, -3.0,  2.0, -2.0,  1.0, -1.0, FILL, FILL,
            4.0, -4.0,  3.0, -3.0,  2.0, -2.0,  1.0, -1.0,
        ];
        assert_eq!(lagged.as_standard_layout().as_slice().unwrap(), &expected);
    }
}