use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::rand_distributions::Distribution;
use std::f64::consts::PI;
/// Builds the random number generator used by the seeded generators in this file.
///
/// The generator is a `StdRng` constructed with `seed_from_u64(seed)`, so the
/// only source of variation between calls is the `seed` argument.
fn make_rng(seed: u64) -> StdRng {
StdRng::seed_from_u64(seed)
}
/// Draws `n` samples from a normal distribution with mean 0 and standard
/// deviation 1, consuming randomness from `rng`.
///
/// # Errors
///
/// Returns `DatasetsError::ComputationError` if the distribution object
/// cannot be constructed.
fn standard_normals(n: usize, rng: &mut StdRng) -> Result<Vec<f64>> {
    let unit_normal = match scirs2_core::random::Normal::new(0.0_f64, 1.0_f64) {
        Ok(dist) => dist,
        Err(e) => {
            return Err(DatasetsError::ComputationError(format!(
                "Normal distribution creation failed: {e}"
            )));
        }
    };

    let mut draws = Vec::with_capacity(n);
    for _ in 0..n {
        // Reborrow so the caller's generator keeps advancing across draws.
        draws.push(unit_normal.sample(&mut *rng));
    }
    Ok(draws)
}
/// Simulates `n` observations of an ARMA(p, q) process driven by standard
/// normal shocks.
///
/// With `p = ar_params.len()` and `q = ma_params.len()`, each value follows
///
/// `y[t] = eps[t] + sum_i ar_params[i] * y[t-1-i] + sum_j ma_params[j] * eps[t-1-j]`
///
/// where lags that would reach before the start of the simulation are
/// skipped. The simulation runs for `max(max(p, q) + 50, 200)` warm-up steps
/// before the returned window, and those warm-up values are discarded.
///
/// # Arguments
///
/// * `n` - number of observations to return; must be non-zero.
/// * `ar_params` - autoregressive coefficients, lag 1 first.
/// * `ma_params` - moving-average coefficients, lag 1 first.
/// * `seed` - seed for the random number generator.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `n == 0`, and propagates any
/// error from drawing the shocks.
pub fn make_arma(
    n: usize,
    ar_params: &[f64],
    ma_params: &[f64],
    seed: u64,
) -> Result<Array1<f64>> {
    if n == 0 {
        return Err(DatasetsError::InvalidFormat(
            "make_arma: n must be > 0".to_string(),
        ));
    }

    let max_order = ar_params.len().max(ma_params.len());
    let warmup = (max_order + 50).max(200);
    let total_len = warmup + n;

    let mut rng = make_rng(seed);
    let shocks = standard_normals(total_len, &mut rng)?;

    let mut series = vec![0.0_f64; total_len];
    for t in 0..total_len {
        // Zipping against the reversed history pairs coefficient `i` with the
        // value at `t - 1 - i` and stops once the available history runs out.
        let with_ar = ar_params
            .iter()
            .zip(series[..t].iter().rev())
            .fold(shocks[t], |acc, (&phi, &lagged)| acc + phi * lagged);
        let value = ma_params
            .iter()
            .zip(shocks[..t].iter().rev())
            .fold(with_ar, |acc, (&theta, &lagged)| acc + theta * lagged);
        series[t] = value;
    }

    // Keep only the samples produced after the warm-up phase.
    Ok(Array1::from_vec(series.split_off(warmup)))
}
/// Integrates the Lorenz system with a fixed-step fourth-order Runge-Kutta
/// scheme and returns the sampled trajectory.
///
/// The equations integrated are
///
/// * `dx/dt = sigma * (y - x)`
/// * `dy/dt = x * (rho - z) - y`
/// * `dz/dt = x * y - beta * z`
///
/// Integration starts from `(1, 1, 1)`. The first 500 steps are discarded,
/// then `n` consecutive states are recorded, one per step of size `dt`.
///
/// # Arguments
///
/// * `n` - number of rows to record; must be non-zero.
/// * `dt` - integration step; must be finite and strictly positive.
/// * `sigma`, `rho`, `beta` - system parameters, used as given (not validated).
///
/// # Returns
///
/// An `(n, 3)` array whose columns are `x`, `y` and `z`.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `n == 0` or when `dt` is NaN,
/// infinite, zero or negative.
pub fn make_lorenz(
    n: usize,
    dt: f64,
    sigma: f64,
    rho: f64,
    beta: f64,
) -> Result<Array2<f64>> {
    // Number of integration steps discarded before recording starts.
    const TRANSIENT_STEPS: usize = 500;

    if n == 0 {
        return Err(DatasetsError::InvalidFormat(
            "make_lorenz: n must be > 0".to_string(),
        ));
    }
    // `dt <= 0.0` alone is false for NaN, which would otherwise slip through
    // and turn the whole trajectory into NaN.
    if dt.is_nan() || dt <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "make_lorenz: dt must be > 0".to_string(),
        ));
    }
    if dt.is_infinite() {
        return Err(DatasetsError::InvalidFormat(
            "make_lorenz: dt must be finite".to_string(),
        ));
    }

    let lorenz_deriv = |x: f64, y: f64, z: f64| -> (f64, f64, f64) {
        (sigma * (y - x), x * (rho - z) - y, x * y - beta * z)
    };

    // One classical RK4 step; shared by the transient and the recording loops.
    let rk4_step = |(x, y, z): (f64, f64, f64)| -> (f64, f64, f64) {
        let (k1x, k1y, k1z) = lorenz_deriv(x, y, z);
        let (k2x, k2y, k2z) =
            lorenz_deriv(x + 0.5 * dt * k1x, y + 0.5 * dt * k1y, z + 0.5 * dt * k1z);
        let (k3x, k3y, k3z) =
            lorenz_deriv(x + 0.5 * dt * k2x, y + 0.5 * dt * k2y, z + 0.5 * dt * k2z);
        let (k4x, k4y, k4z) = lorenz_deriv(x + dt * k3x, y + dt * k3y, z + dt * k3z);
        (
            x + dt / 6.0 * (k1x + 2.0 * k2x + 2.0 * k3x + k4x),
            y + dt / 6.0 * (k1y + 2.0 * k2y + 2.0 * k3y + k4y),
            z + dt / 6.0 * (k1z + 2.0 * k2z + 2.0 * k3z + k4z),
        )
    };

    let mut state = (1.0_f64, 1.0_f64, 1.0_f64);
    for _ in 0..TRANSIENT_STEPS {
        state = rk4_step(state);
    }

    let mut out = Array2::zeros((n, 3));
    for i in 0..n {
        // Record the current state first, then advance to the next one.
        let (x, y, z) = state;
        out[[i, 0]] = x;
        out[[i, 1]] = y;
        out[[i, 2]] = z;
        state = rk4_step(state);
    }
    Ok(out)
}
/// Integrates the Van der Pol oscillator with a fixed-step fourth-order
/// Runge-Kutta scheme and returns the sampled trajectory.
///
/// The first-order system integrated is
///
/// * `dx/dt = y`
/// * `dy/dt = mu * (1 - x^2) * y - x`
///
/// Integration starts from `(x, y) = (2, 0)`. Row 0 of the result is this
/// initial state, and each following row is one step of size `dt` later.
///
/// # Arguments
///
/// * `n` - number of rows to record; must be non-zero.
/// * `mu` - damping parameter; must be finite and non-negative.
/// * `dt` - integration step; must be finite and strictly positive.
///
/// # Returns
///
/// An `(n, 2)` array whose columns are `x` and `y`.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `n == 0`, when `dt` is NaN,
/// infinite, zero or negative, or when `mu` is NaN, infinite or negative.
pub fn make_van_der_pol(n: usize, mu: f64, dt: f64) -> Result<Array2<f64>> {
    if n == 0 {
        return Err(DatasetsError::InvalidFormat(
            "make_van_der_pol: n must be > 0".to_string(),
        ));
    }
    // Comparisons with NaN are always false, so NaN has to be rejected
    // explicitly; otherwise it would pass validation and poison the output.
    if dt.is_nan() || dt <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "make_van_der_pol: dt must be > 0".to_string(),
        ));
    }
    if dt.is_infinite() {
        return Err(DatasetsError::InvalidFormat(
            "make_van_der_pol: dt must be finite".to_string(),
        ));
    }
    if mu.is_nan() || mu < 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "make_van_der_pol: mu must be >= 0".to_string(),
        ));
    }
    if mu.is_infinite() {
        return Err(DatasetsError::InvalidFormat(
            "make_van_der_pol: mu must be finite".to_string(),
        ));
    }

    let vdp_deriv = |x: f64, y: f64| -> (f64, f64) { (y, mu * (1.0 - x * x) * y - x) };

    let mut x = 2.0_f64;
    let mut y = 0.0_f64;
    let mut out = Array2::zeros((n, 2));
    for i in 0..n {
        // Record the current state first, then advance it by one RK4 step.
        out[[i, 0]] = x;
        out[[i, 1]] = y;
        let (k1x, k1y) = vdp_deriv(x, y);
        let (k2x, k2y) = vdp_deriv(x + 0.5 * dt * k1x, y + 0.5 * dt * k1y);
        let (k3x, k3y) = vdp_deriv(x + 0.5 * dt * k2x, y + 0.5 * dt * k2y);
        let (k4x, k4y) = vdp_deriv(x + dt * k3x, y + dt * k3y);
        x += dt / 6.0 * (k1x + 2.0 * k2x + 2.0 * k3x + k4x);
        y += dt / 6.0 * (k1y + 2.0 * k2y + 2.0 * k3y + k4y);
    }
    Ok(out)
}
/// Generates a series made of a linear trend, a sinusoidal seasonal component
/// and optional Gaussian noise.
///
/// For `t` in `0..n` the value is
///
/// `trend * t + seasonality * sin(2 * PI * t / period) + eps[t]`
///
/// where `eps[t]` is drawn from a zero-mean normal distribution when
/// `noise > 0` and is exactly `0.0` when `noise == 0`.
///
/// # Arguments
///
/// * `n` - number of samples; must be non-zero.
/// * `trend` - slope of the linear component per time step.
/// * `seasonality` - amplitude of the sinusoid.
/// * `noise` - standard deviation of the noise; must not be NaN or negative.
///   Positive values below `1e-12` are raised to `1e-12`.
/// * `period` - length of one seasonal cycle in samples; must be non-zero.
/// * `seed` - seed for the random number generator.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `n == 0`, `period == 0`, or
/// `noise` is NaN or negative. Any failure reported by the distribution
/// constructor is returned as `DatasetsError::ComputationError`.
pub fn make_seasonal_ts(
    n: usize,
    trend: f64,
    seasonality: f64,
    noise: f64,
    period: usize,
    seed: u64,
) -> Result<Array1<f64>> {
    if n == 0 {
        return Err(DatasetsError::InvalidFormat(
            "make_seasonal_ts: n must be > 0".to_string(),
        ));
    }
    if period == 0 {
        return Err(DatasetsError::InvalidFormat(
            "make_seasonal_ts: period must be > 0".to_string(),
        ));
    }
    // `noise < 0.0` alone is false for NaN; without the explicit check a NaN
    // would be clamped away below and silently treated as "no noise".
    if noise.is_nan() || noise < 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "make_seasonal_ts: noise must be >= 0".to_string(),
        ));
    }

    let mut rng = make_rng(seed);
    // The distribution is built even when `noise == 0`; the floor keeps its
    // standard deviation strictly positive in that case (presumably so the
    // constructor never sees zero). It is only sampled when `noise > 0`.
    let noise_dist = scirs2_core::random::Normal::new(0.0_f64, noise.max(1e-12)).map_err(|e| {
        DatasetsError::ComputationError(format!("Normal distribution creation failed: {e}"))
    })?;

    let period_f = period as f64;
    let mut out = Array1::zeros(n);
    for t in 0..n {
        let t_f = t as f64;
        let seasonal = seasonality * (2.0 * PI * t_f / period_f).sin();
        let eps = if noise > 0.0 {
            noise_dist.sample(&mut rng)
        } else {
            0.0
        };
        out[t] = trend * t_f + seasonal + eps;
    }
    Ok(out)
}
/// Generates a piecewise-constant series whose mean switches at the given
/// changepoints, with optional Gaussian noise added to every sample.
///
/// The index range `0..n` is split at each changepoint into
/// `changepoints.len() + 1` segments; segment `k` covers the indices from the
/// previous boundary (inclusive) to the next one (exclusive) and has mean
/// `means[k]`. Noise is drawn only when `noise_std > 0`; otherwise every
/// sample equals its segment mean exactly.
///
/// # Arguments
///
/// * `n` - number of samples; must be non-zero.
/// * `changepoints` - strictly increasing indices, each in `1..n`.
/// * `means` - one mean per segment; length must be `changepoints.len() + 1`.
/// * `noise_std` - standard deviation of the noise; must not be NaN or
///   negative. Positive values below `1e-12` are raised to `1e-12`.
/// * `seed` - seed for the random number generator.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `n == 0`, when `means` has the
/// wrong length, when `noise_std` is NaN or negative, or when a changepoint is
/// out of range or not strictly greater than its predecessor. Any failure
/// reported by the distribution constructor is returned as
/// `DatasetsError::ComputationError`.
pub fn make_changepoint_series(
    n: usize,
    changepoints: &[usize],
    means: &[f64],
    noise_std: f64,
    seed: u64,
) -> Result<Array1<f64>> {
    if n == 0 {
        return Err(DatasetsError::InvalidFormat(
            "make_changepoint_series: n must be > 0".to_string(),
        ));
    }
    if means.len() != changepoints.len() + 1 {
        return Err(DatasetsError::InvalidFormat(format!(
            "make_changepoint_series: means.len() ({}) must be changepoints.len() + 1 ({})",
            means.len(),
            changepoints.len() + 1
        )));
    }
    // `noise_std < 0.0` alone is false for NaN; without the explicit check a
    // NaN would be clamped away below and silently treated as "no noise".
    if noise_std.is_nan() || noise_std < 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "make_changepoint_series: noise_std must be >= 0".to_string(),
        ));
    }
    for (i, &cp) in changepoints.iter().enumerate() {
        if cp == 0 || cp >= n {
            return Err(DatasetsError::InvalidFormat(format!(
                "make_changepoint_series: changepoint[{i}]={cp} is out of (0, n={n}) range"
            )));
        }
        if i > 0 && cp <= changepoints[i - 1] {
            return Err(DatasetsError::InvalidFormat(
                "make_changepoint_series: changepoints must be strictly increasing".to_string(),
            ));
        }
    }

    let mut rng = make_rng(seed);
    // The distribution is built even when `noise_std == 0`; the floor keeps
    // its standard deviation strictly positive in that case. It is only
    // sampled when `noise_std > 0`.
    let noise_dist =
        scirs2_core::random::Normal::new(0.0_f64, noise_std.max(1e-12)).map_err(|e| {
            DatasetsError::ComputationError(format!("Normal distribution creation failed: {e}"))
        })?;

    // Segment boundaries: 0, each changepoint, then n.
    let mut boundaries: Vec<usize> = Vec::with_capacity(changepoints.len() + 2);
    boundaries.push(0);
    boundaries.extend_from_slice(changepoints);
    boundaries.push(n);

    let mut out = Array1::zeros(n);
    // Consecutive boundary pairs line up one-to-one with `means` because the
    // length check above guarantees `means.len() == boundaries.len() - 1`.
    for (bounds, &mean) in boundaries.windows(2).zip(means) {
        for t in bounds[0]..bounds[1] {
            let eps = if noise_std > 0.0 {
                noise_dist.sample(&mut rng)
            } else {
                0.0
            };
            out[t] = mean + eps;
        }
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- make_arma ---------------------------------------------------------

    /// An AR(1) request for 200 samples yields exactly 200 samples.
    #[test]
    fn test_arma_length() {
        let series = make_arma(200, &[0.5], &[], 1).expect("make_arma failed");
        assert_eq!(series.len(), 200);
    }

    /// An ARMA(1,1) series has the requested length and is not constant.
    #[test]
    fn test_arma_arma11() {
        let series = make_arma(500, &[0.6], &[0.3], 7).expect("make_arma ARMA(1,1) failed");
        assert_eq!(series.len(), 500);
        let head = series[0];
        let varies = series.iter().any(|&v| (v - head).abs() > 1e-10);
        assert!(varies);
    }

    /// A pure MA(2) process (no AR terms) is supported.
    #[test]
    fn test_arma_pure_ma() {
        let series = make_arma(100, &[], &[0.5, 0.3], 99).expect("MA(2) failed");
        assert_eq!(series.len(), 100);
    }

    /// Identical arguments, including the seed, reproduce the same series.
    #[test]
    fn test_arma_determinism() {
        let first_run = make_arma(100, &[0.4], &[0.2], 42).expect("seed 1");
        let second_run = make_arma(100, &[0.4], &[0.2], 42).expect("seed 2");
        let identical = first_run
            .iter()
            .zip(second_run.iter())
            .all(|(a, b)| (a - b).abs() < 1e-12);
        assert!(identical);
    }

    /// Requesting zero samples is rejected.
    #[test]
    fn test_arma_error_n_zero() {
        let outcome = make_arma(0, &[0.5], &[], 1);
        assert!(outcome.is_err());
    }

    // ---- make_lorenz -------------------------------------------------------

    /// The trajectory has one row per sample and three columns.
    #[test]
    fn test_lorenz_shape() {
        let path = make_lorenz(1000, 0.01, 10.0, 28.0, 8.0 / 3.0).expect("lorenz failed");
        assert_eq!(path.nrows(), 1000);
        assert_eq!(path.ncols(), 3);
    }

    /// The first coordinate moves away from its initial recorded value.
    #[test]
    fn test_lorenz_not_constant() {
        let path = make_lorenz(200, 0.01, 10.0, 28.0, 8.0 / 3.0).expect("lorenz");
        let start_x = path[[0, 0]];
        let moved = path.column(0).iter().any(|&v| (v - start_x).abs() > 0.1);
        assert!(moved);
    }

    /// Requesting zero samples is rejected.
    #[test]
    fn test_lorenz_error_n_zero() {
        let outcome = make_lorenz(0, 0.01, 10.0, 28.0, 8.0 / 3.0);
        assert!(outcome.is_err());
    }

    /// A zero or negative step size is rejected.
    #[test]
    fn test_lorenz_error_dt_nonpositive() {
        let zero_step = make_lorenz(100, 0.0, 10.0, 28.0, 8.0 / 3.0);
        assert!(zero_step.is_err());
        let negative_step = make_lorenz(100, -0.01, 10.0, 28.0, 8.0 / 3.0);
        assert!(negative_step.is_err());
    }

    // ---- make_van_der_pol --------------------------------------------------

    /// The trajectory has one row per sample and two columns.
    #[test]
    fn test_vdp_shape() {
        let path = make_van_der_pol(400, 1.0, 0.01).expect("vdp failed");
        assert_eq!(path.nrows(), 400);
        assert_eq!(path.ncols(), 2);
    }

    /// Over 1000 steps the first coordinate swings to both sides of zero.
    #[test]
    fn test_vdp_oscillatory() {
        let path = make_van_der_pol(1000, 0.5, 0.01).expect("vdp oscillatory");
        let xs = path.column(0);
        let highest = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let lowest = xs.iter().copied().fold(f64::INFINITY, f64::min);
        assert!(highest > 0.5);
        assert!(lowest < -0.5);
    }

    /// Requesting zero samples is rejected.
    #[test]
    fn test_vdp_error_n_zero() {
        let outcome = make_van_der_pol(0, 1.0, 0.01);
        assert!(outcome.is_err());
    }

    /// A negative damping parameter is rejected.
    #[test]
    fn test_vdp_error_mu_negative() {
        let outcome = make_van_der_pol(100, -1.0, 0.01);
        assert!(outcome.is_err());
    }

    // ---- make_seasonal_ts --------------------------------------------------

    /// The series has the requested number of samples.
    #[test]
    fn test_seasonal_length() {
        let series = make_seasonal_ts(120, 0.1, 2.0, 0.5, 12, 42).expect("seasonal ts");
        assert_eq!(series.len(), 120);
    }

    /// With zero noise, two calls with identical arguments agree sample by sample.
    #[test]
    fn test_seasonal_no_noise_deterministic() {
        let first_run = make_seasonal_ts(50, 0.0, 1.0, 0.0, 10, 0).expect("s1");
        let second_run = make_seasonal_ts(50, 0.0, 1.0, 0.0, 10, 0).expect("s2");
        let identical = first_run
            .iter()
            .zip(second_run.iter())
            .all(|(a, b)| (a - b).abs() < 1e-12);
        assert!(identical);
    }

    /// A positive trend with no seasonality or noise makes the last value exceed the first.
    #[test]
    fn test_seasonal_trend_visible() {
        let series = make_seasonal_ts(100, 1.0, 0.0, 0.0, 10, 1).expect("seasonal trend");
        let (first, last) = (series[0], series[99]);
        assert!(last > first);
    }

    /// A zero period is rejected.
    #[test]
    fn test_seasonal_error_period_zero() {
        let outcome = make_seasonal_ts(100, 0.0, 1.0, 0.0, 0, 1);
        assert!(outcome.is_err());
    }

    // ---- make_changepoint_series -------------------------------------------

    /// The series has the requested number of samples.
    #[test]
    fn test_changepoint_length() {
        let series = make_changepoint_series(300, &[100, 200], &[0.0, 5.0, -3.0], 1.0, 42)
            .expect("changepoint");
        assert_eq!(series.len(), 300);
    }

    /// Without noise, each segment's sample mean equals its configured mean.
    #[test]
    fn test_changepoint_means_visible() {
        let series = make_changepoint_series(300, &[100, 200], &[0.0, 10.0, -5.0], 0.0, 0)
            .expect("changepoint no noise");

        let seg0_mean: f64 = series
            .slice(scirs2_core::ndarray::s![0..100])
            .mean()
            .unwrap_or(0.0);
        assert!((seg0_mean - 0.0).abs() < 1e-9, "seg0 mean={seg0_mean}");

        let seg1_mean: f64 = series
            .slice(scirs2_core::ndarray::s![100..200])
            .mean()
            .unwrap_or(0.0);
        assert!((seg1_mean - 10.0).abs() < 1e-9, "seg1 mean={seg1_mean}");

        let seg2_mean: f64 = series
            .slice(scirs2_core::ndarray::s![200..300])
            .mean()
            .unwrap_or(0.0);
        assert!((seg2_mean - (-5.0)).abs() < 1e-9, "seg2 mean={seg2_mean}");
    }

    /// Two changepoints need three means; supplying two is rejected.
    #[test]
    fn test_changepoint_error_wrong_means_len() {
        let outcome = make_changepoint_series(200, &[50, 150], &[0.0, 1.0], 0.5, 1);
        assert!(outcome.is_err());
    }

    /// Requesting zero samples is rejected.
    #[test]
    fn test_changepoint_error_n_zero() {
        let outcome = make_changepoint_series(0, &[], &[1.0], 0.5, 1);
        assert!(outcome.is_err());
    }

    /// With no changepoints and no noise the series is constant at the single mean.
    #[test]
    fn test_changepoint_no_changepoints() {
        let series = make_changepoint_series(100, &[], &[3.0], 0.0, 5).expect("single segment");
        assert_eq!(series.len(), 100);
        for idx in [0, 99] {
            assert!((series[idx] - 3.0).abs() < 1e-9);
        }
    }
}