use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::Array1;
use scirs2_core::random::prelude::*;
use scirs2_core::random::rand_distributions::Distribution;
use std::f64::consts::PI;
/// Frequency-sweep law used by [`chirp_signal`].
///
/// In every case the instantaneous frequency starts at `f0` at `t = 0` and reaches
/// `f1` at the last sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChirpMethod {
    /// Instantaneous frequency changes linearly in time: `f0 + (f1 - f0) * t / t_end`.
    Linear,
    /// Instantaneous frequency changes geometrically: `f0 * (f1 / f0)^(t / t_end)`.
    Logarithmic,
    /// Instantaneous frequency follows a hyperbola in time between `f0` and `f1`.
    Hyperbolic,
}
fn make_rng(seed: u64) -> StdRng {
StdRng::seed_from_u64(seed)
}
/// Creates a zero-mean normal distribution with standard deviation `std`.
///
/// # Errors
///
/// Any rejection by the distribution constructor (e.g. an invalid `std`) is
/// reported as `DatasetsError::ComputationError`.
fn normal_dist(std: f64) -> Result<scirs2_core::random::rand_distributions::Normal<f64>> {
    use scirs2_core::random::rand_distributions::Normal;

    Normal::new(0.0_f64, std).map_err(|err| {
        DatasetsError::ComputationError(format!("Normal distribution creation failed: {err}"))
    })
}
/// Creates a uniform distribution over the range `[lo, hi)` (the semantics of the
/// underlying `Uniform::new` constructor).
///
/// # Errors
///
/// Any rejection by the distribution constructor (e.g. an empty range) is
/// reported as `DatasetsError::ComputationError`.
fn uniform_dist(
    lo: f64,
    hi: f64,
) -> Result<scirs2_core::random::rand_distributions::Uniform<f64>> {
    use scirs2_core::random::rand_distributions::Uniform;

    Uniform::new(lo, hi).map_err(|err| {
        DatasetsError::ComputationError(format!("Uniform distribution creation failed: {err}"))
    })
}
/// Returns the sample time stamps `i / fs` (in seconds when `fs` is in Hz) for
/// `i` in `0..n`.
fn time_axis(n: usize, fs: f64) -> Array1<f64> {
    Array1::from_shape_fn(n, |idx| idx as f64 / fs)
}
/// Validates the sample count and sampling rate shared by every generator.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat`, with the message prefixed by `func`,
/// when `n == 0` or when `fs` is not a finite, strictly positive number.
fn check_n_fs(func: &str, n: usize, fs: f64) -> Result<()> {
    if n == 0 {
        return Err(DatasetsError::InvalidFormat(format!(
            "{func}: n_samples must be > 0"
        )));
    }
    // A bare `fs <= 0.0` lets NaN and +inf through; those would turn every time
    // stamp `i / fs` into NaN or 0.
    if !fs.is_finite() || fs <= 0.0 {
        return Err(DatasetsError::InvalidFormat(format!(
            "{func}: fs must be finite and > 0"
        )));
    }
    Ok(())
}
/// Generates a synthetic electrocardiogram (ECG) trace.
///
/// Each heartbeat is a sum of five Gaussian bumps (P, Q, R, S and T waves). Beats
/// are spaced `60 / heart_rate` seconds apart, starting half an RR interval after
/// `t = 0`. When `noise_level > 0`, zero-mean Gaussian noise with standard
/// deviation `noise_level` is added to every sample.
///
/// # Arguments
///
/// * `n_samples` - number of samples to generate (must be > 0).
/// * `fs` - sampling rate in Hz (finite, > 0).
/// * `heart_rate` - beats per minute (finite, > 0).
/// * `noise_level` - standard deviation of the additive noise (>= 0).
/// * `seed` - RNG seed; identical arguments produce identical output.
///
/// # Returns
///
/// `(t, signal)` where `t[i] = i / fs`.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` for invalid arguments, including a heart
/// rate so high that successive beat times round to the same value. Errors from
/// building the noise distribution are propagated.
pub fn ecg_signal(
    n_samples: usize,
    fs: f64,
    heart_rate: f64,
    noise_level: f64,
    seed: u64,
) -> Result<(Array1<f64>, Array1<f64>)> {
    check_n_fs("ecg_signal", n_samples, fs)?;
    // NaN must be rejected explicitly, and +inf would give a zero RR interval that
    // never advances the beat loop below.
    if !heart_rate.is_finite() || heart_rate <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "ecg_signal: heart_rate must be > 0".to_string(),
        ));
    }
    if noise_level.is_nan() || noise_level < 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "ecg_signal: noise_level must be >= 0".to_string(),
        ));
    }

    // Beyond ~38.6 sigma, exp(-0.5 * arg^2) underflows to exactly 0.0 in f64, so
    // samples outside this many sigmas of a wave centre cannot change the sum.
    const SUPPORT_SIGMAS: f64 = 40.0;

    let mut rng = make_rng(seed);
    let t = time_axis(n_samples, fs);
    let rr_interval = 60.0 / heart_rate;
    // (offset from the beat time [s], amplitude, Gaussian width [s]) per wave.
    let waves: &[(f64, f64, f64)] = &[
        (-0.20, 0.20, 0.020),  // P
        (-0.05, -0.10, 0.010), // Q
        (0.00, 1.00, 0.008),   // R
        (0.05, -0.15, 0.010),  // S
        (0.20, 0.35, 0.040),   // T
    ];
    let mut signal = vec![0.0_f64; n_samples];
    let duration = n_samples as f64 / fs;
    let mut beat_t = rr_interval / 2.0;
    while beat_t < duration {
        for &(dt_rel, amp, sigma) in waves {
            let center = beat_t + dt_rel;
            let half_width = SUPPORT_SIGMAS * sigma;
            // `as usize` saturates (negatives -> 0). The extra sample on each side
            // absorbs rounding in the time <-> index conversion.
            let lo = ((center - half_width) * fs).floor().max(0.0) as usize;
            let hi = (((center + half_width) * fs).ceil() as usize)
                .saturating_add(1)
                .min(n_samples);
            for (i, s) in signal.iter_mut().enumerate().take(hi).skip(lo) {
                let arg = (t[i] - center) / sigma;
                *s += amp * (-0.5 * arg * arg).exp();
            }
        }
        let next_beat = beat_t + rr_interval;
        // Guard against an RR interval below the f64 resolution of `beat_t`, which
        // would otherwise spin forever.
        if next_beat <= beat_t {
            return Err(DatasetsError::InvalidFormat(
                "ecg_signal: heart_rate is too high to separate successive beats".to_string(),
            ));
        }
        beat_t = next_beat;
    }
    if noise_level > 0.0 {
        let noise_dist = normal_dist(noise_level)?;
        for s in signal.iter_mut() {
            *s += noise_dist.sample(&mut rng);
        }
    }
    Ok((t, Array1::from_vec(signal)))
}
/// Synthesises a seismic trace as a sum of randomly placed Ricker wavelets plus
/// background noise.
///
/// Each of the `n_events` wavelets has a 20 Hz dominant frequency and:
/// * an onset drawn uniformly from `[0, n_samples / fs)`;
/// * an amplitude drawn uniformly from `[0.5, 1.5)`;
/// * a polarity of `+1` or `-1` with equal probability.
///
/// Gaussian noise with a standard deviation of 5% of the trace's peak absolute
/// value is then added. If the trace is identically zero (e.g. `n_events == 0`),
/// a fixed standard deviation of 0.05 is used instead.
///
/// # Returns
///
/// `(t, trace)` where `t[i] = i / fs`. The output is fully determined by `seed`.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` for an invalid `n_samples`/`fs`, and
/// propagates distribution-construction errors.
pub fn seismic_trace(
    n_samples: usize,
    fs: f64,
    n_events: usize,
    seed: u64,
) -> Result<(Array1<f64>, Array1<f64>)> {
    check_n_fs("seismic_trace", n_samples, fs)?;
    let mut rng = make_rng(seed);
    let t = time_axis(n_samples, fs);
    let duration = n_samples as f64 / fs;
    let dominant_freq = 20.0_f64;

    let onset_dist = uniform_dist(0.0, duration)?;
    let amplitude_dist = uniform_dist(0.5, 1.5)?;
    let coin_dist = uniform_dist(0.0, 1.0)?;
    let fallback_noise = normal_dist(0.05)?;

    let mut trace = vec![0.0_f64; n_samples];
    for _ in 0..n_events {
        // The draw order (onset, amplitude, polarity) is part of the seeded output.
        let onset = onset_dist.sample(&mut rng);
        let amplitude = amplitude_dist.sample(&mut rng);
        let sign = if coin_dist.sample(&mut rng) >= 0.5 {
            -1.0_f64
        } else {
            1.0_f64
        };
        for (sample, &ti) in trace.iter_mut().zip(t.iter()) {
            let x = PI * dominant_freq * (ti - onset);
            *sample += sign * amplitude * (1.0 - 2.0 * x * x) * (-x * x).exp();
        }
    }

    let peak = trace.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
    let noise_sigma = peak * 0.05;
    let background = if noise_sigma > 0.0 {
        normal_dist(noise_sigma)?
    } else {
        fallback_noise
    };
    for sample in trace.iter_mut() {
        *sample += background.sample(&mut rng);
    }
    Ok((t, Array1::from_vec(trace)))
}
/// Generates a unit-amplitude swept-frequency sine (chirp).
///
/// The instantaneous frequency starts at `f0` at `t = 0` and reaches `f1` at the
/// last sample `t_end = (n_samples - 1) / fs`, following the law chosen by
/// `method` (see [`ChirpMethod`]). For the logarithmic and hyperbolic laws,
/// `f0 == f1` (within 1e-12) degenerates to a constant-frequency sine.
///
/// # Returns
///
/// `(t, signal)` where `t[i] = i / fs` and `signal[i] = sin(phase(t[i]))`. A
/// single-sample request yields `[0.0]`, because every law has zero phase at
/// `t = 0`.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when:
/// * `n_samples`/`fs` are invalid;
/// * `f0` is not > 0;
/// * `f0` or `f1` is not finite;
/// * `f1 <= 0` for the logarithmic or hyperbolic laws, which would take the
///   logarithm of a non-positive ratio or divide by zero.
pub fn chirp_signal(
    n_samples: usize,
    fs: f64,
    f0: f64,
    f1: f64,
    method: ChirpMethod,
) -> Result<(Array1<f64>, Array1<f64>)> {
    check_n_fs("chirp_signal", n_samples, fs)?;
    if f0.is_nan() || f0 <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "chirp_signal: f0 must be > 0".to_string(),
        ));
    }
    if !f0.is_finite() || !f1.is_finite() {
        return Err(DatasetsError::InvalidFormat(
            "chirp_signal: f0 and f1 must be finite".to_string(),
        ));
    }
    if matches!(method, ChirpMethod::Logarithmic | ChirpMethod::Hyperbolic) && f1 <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "chirp_signal: f1 must be > 0 for logarithmic and hyperbolic sweeps".to_string(),
        ));
    }

    let t = time_axis(n_samples, fs);
    // With one sample t_end is 0 and the sweep rates below would divide by zero
    // (producing NaN); the phase at t = 0 is 0 for every law.
    if n_samples == 1 {
        return Ok((t, Array1::from_vec(vec![0.0])));
    }
    let t_end = (n_samples - 1) as f64 / fs;

    // Sweep constants are hoisted out of the per-sample map; the arithmetic order
    // is kept so the per-sample values are unchanged.
    let signal: Vec<f64> = match method {
        ChirpMethod::Linear => {
            let k = (f1 - f0) / t_end;
            t.iter()
                .map(|&ti| (2.0 * PI * (f0 * ti + 0.5 * k * ti * ti)).sin())
                .collect()
        }
        ChirpMethod::Logarithmic | ChirpMethod::Hyperbolic if (f1 - f0).abs() < 1e-12 => {
            t.iter().map(|&ti| (2.0 * PI * f0 * ti).sin()).collect()
        }
        ChirpMethod::Logarithmic => {
            let ratio = f1 / f0;
            let scale = 2.0 * PI * f0 * t_end / ratio.ln();
            t.iter()
                .map(|&ti| (scale * (ratio.powf(ti / t_end) - 1.0)).sin())
                .collect()
        }
        ChirpMethod::Hyperbolic => {
            let coeff = f0 * f1 * t_end / (f1 - f0);
            let slope = (f1 - f0) / (f1 * t_end);
            let scale = -2.0 * PI * coeff;
            t.iter()
                .map(|&ti| {
                    // The argument stays in [f0/f1, 1] for valid inputs; the clamp
                    // only guards ln() against rounding down to 0.
                    let arg = (1.0 - slope * ti).max(1e-15);
                    (scale * arg.ln()).sin()
                })
                .collect()
        }
    };
    Ok((t, Array1::from_vec(signal)))
}
/// Generates an amplitude-modulated cosine with 100% modulation depth.
///
/// Sample `i` (at time `t = i / fs`) equals
/// `(1 + cos(2π·modulation_freq·t)) · cos(2π·carrier_freq·t)`, so the envelope
/// ranges over `[0, 2]`.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` for an invalid `n_samples`/`fs`.
pub fn am_signal(
    carrier_freq: f64,
    modulation_freq: f64,
    fs: f64,
    n_samples: usize,
) -> Result<Array1<f64>> {
    check_n_fs("am_signal", n_samples, fs)?;
    let omega_mod = 2.0 * PI * modulation_freq;
    let omega_carrier = 2.0 * PI * carrier_freq;
    Ok(Array1::from_shape_fn(n_samples, |idx| {
        let time = idx as f64 / fs;
        (1.0 + (omega_mod * time).cos()) * (omega_carrier * time).cos()
    }))
}
/// Generates a sinusoidally frequency-modulated cosine.
///
/// Sample `i` (at time `t = i / fs`) equals
/// `cos(2π·carrier_freq·t + beta·sin(2π·modulation_freq·t))`, where `beta` is
/// the modulation index.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` for an invalid `n_samples`/`fs`.
pub fn fm_signal(
    carrier_freq: f64,
    modulation_freq: f64,
    beta: f64,
    fs: f64,
    n_samples: usize,
) -> Result<Array1<f64>> {
    check_n_fs("fm_signal", n_samples, fs)?;
    let omega_carrier = 2.0 * PI * carrier_freq;
    let omega_mod = 2.0 * PI * modulation_freq;
    Ok(Array1::from_shape_fn(n_samples, |idx| {
        let time = idx as f64 / fs;
        (omega_carrier * time + beta * (omega_mod * time).sin()).cos()
    }))
}
/// Sums sinusoids and adds white Gaussian noise at a target SNR.
///
/// Component `k` contributes `amplitudes[k] · sin(2π·frequencies[k]·t + phases[k])`
/// at `t = i / fs`. The noise standard deviation is chosen so that the clean
/// signal's RMS over the noise RMS equals `noise_snr_db` decibels. A numerically
/// silent mixture (RMS <= 1e-15) gets a fixed noise standard deviation of 1e-6.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when:
/// * `n_samples`/`fs` are invalid;
/// * the three slices differ in length;
/// * no components are given.
///
/// Noise-distribution errors are propagated.
pub fn sinusoidal_mixture(
    frequencies: &[f64],
    amplitudes: &[f64],
    phases: &[f64],
    noise_snr_db: f64,
    fs: f64,
    n_samples: usize,
    seed: u64,
) -> Result<Array1<f64>> {
    check_n_fs("sinusoidal_mixture", n_samples, fs)?;
    let lengths_match =
        frequencies.len() == amplitudes.len() && amplitudes.len() == phases.len();
    if !lengths_match {
        return Err(DatasetsError::InvalidFormat(
            "sinusoidal_mixture: frequencies, amplitudes, and phases must have the same length"
                .to_string(),
        ));
    }
    if frequencies.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "sinusoidal_mixture: at least one frequency component is required".to_string(),
        ));
    }

    // Components are accumulated left-to-right, starting from +0.0, for each sample.
    let mut mixture: Vec<f64> = (0..n_samples)
        .map(|idx| {
            let time = idx as f64 / fs;
            frequencies
                .iter()
                .zip(amplitudes)
                .zip(phases)
                .fold(0.0_f64, |acc, ((&freq, &amp), &phase)| {
                    acc + amp * (2.0 * PI * freq * time + phase).sin()
                })
        })
        .collect();

    let mean_square = mixture.iter().map(|v| v * v).sum::<f64>() / n_samples as f64;
    let clean_rms = mean_square.sqrt();
    let sigma = if clean_rms > 1e-15 {
        clean_rms / 10.0_f64.powf(noise_snr_db / 20.0)
    } else {
        1e-6
    };

    let mut rng = make_rng(seed);
    let noise = normal_dist(sigma)?;
    mixture
        .iter_mut()
        .for_each(|sample| *sample += noise.sample(&mut rng));
    Ok(Array1::from_vec(mixture))
}
#[cfg(test)]
mod tests {
    // `PI` is not needed here; the former explicit import was unused and tripped
    // `unused_imports` (a hard error under `-D warnings`).
    use super::*;

    #[test]
    fn test_ecg_shape() {
        let (t, sig) = ecg_signal(2000, 500.0, 70.0, 0.02, 42).expect("ecg failed");
        assert_eq!(t.len(), 2000);
        assert_eq!(sig.len(), 2000);
    }

    #[test]
    fn test_ecg_time_axis() {
        let fs = 250.0;
        let n = 500;
        let (t, _) = ecg_signal(n, fs, 60.0, 0.0, 0).expect("ecg failed");
        assert!((t[0] - 0.0).abs() < 1e-12);
        assert!((t[n - 1] - (n - 1) as f64 / fs).abs() < 1e-10);
    }

    #[test]
    fn test_ecg_periodicity() {
        let fs = 500.0;
        let n = 3000;
        let hr = 60.0;
        let (_, sig) = ecg_signal(n, fs, hr, 0.0, 42).expect("ecg failed");
        let mut peaks = vec![];
        for i in 1..n - 1 {
            if sig[i] > sig[i - 1] && sig[i] > sig[i + 1] && sig[i] > 0.5 {
                peaks.push(i);
            }
        }
        assert!(peaks.len() >= 2, "Expected at least 2 R-peaks, got {}", peaks.len());
        let interval = (peaks[1] - peaks[0]) as f64;
        let expected = fs * 60.0 / hr;
        let rel = (interval - expected).abs() / expected;
        assert!(rel < 0.05, "Peak interval {interval} vs expected {expected}");
    }

    #[test]
    fn test_ecg_error_zero_samples() {
        assert!(ecg_signal(0, 500.0, 70.0, 0.0, 0).is_err());
    }

    #[test]
    fn test_ecg_error_negative_noise() {
        assert!(ecg_signal(100, 500.0, 70.0, -0.1, 0).is_err());
    }

    #[test]
    fn test_ecg_error_zero_heart_rate() {
        assert!(ecg_signal(100, 500.0, 0.0, 0.0, 0).is_err());
    }

    #[test]
    fn test_seismic_shape() {
        let (t, sig) = seismic_trace(1024, 200.0, 4, 0).expect("seismic failed");
        assert_eq!(t.len(), 1024);
        assert_eq!(sig.len(), 1024);
    }

    #[test]
    fn test_seismic_non_zero() {
        let (_, sig) = seismic_trace(500, 100.0, 3, 7).expect("seismic failed");
        let max_abs = sig.iter().copied().fold(0.0_f64, |a, v| a.max(v.abs()));
        assert!(max_abs > 0.0, "seismic signal is all zeros");
    }

    #[test]
    fn test_seismic_reproducible() {
        let a = seismic_trace(256, 100.0, 2, 3).expect("seismic failed");
        let b = seismic_trace(256, 100.0, 2, 3).expect("seismic failed");
        assert_eq!(a.1, b.1, "Seismic trace should be reproducible");
    }

    #[test]
    fn test_chirp_shape() {
        let (t, sig) = chirp_signal(4096, 1000.0, 10.0, 400.0, ChirpMethod::Linear)
            .expect("chirp failed");
        assert_eq!(t.len(), 4096);
        assert_eq!(sig.len(), 4096);
    }

    #[test]
    fn test_chirp_unit_amplitude() {
        for method in [ChirpMethod::Linear, ChirpMethod::Logarithmic, ChirpMethod::Hyperbolic] {
            let (_, sig) =
                chirp_signal(2048, 2000.0, 10.0, 800.0, method).expect("chirp failed");
            let max_abs = sig.iter().copied().fold(0.0_f64, |a, v| a.max(v.abs()));
            assert!(max_abs <= 1.0 + 1e-10, "method={method:?}, max={max_abs}");
        }
    }

    #[test]
    fn test_chirp_linear_frequency_sweep() {
        let fs = 8000.0_f64;
        let n = 8000_usize;
        let f0 = 100.0_f64;
        let f1 = 1000.0_f64;
        let (_, sig) = chirp_signal(n, fs, f0, f1, ChirpMethod::Linear).expect("chirp failed");
        let crossing_interval = |start: usize, count: usize| -> Option<f64> {
            let mut crossings = vec![];
            let mut i = start + 1;
            while i < n && crossings.len() < count {
                if sig[i - 1] < 0.0 && sig[i] >= 0.0 {
                    crossings.push(i);
                }
                i += 1;
            }
            if crossings.len() < 2 {
                return None;
            }
            let intervals: Vec<f64> =
                crossings.windows(2).map(|w| (w[1] - w[0]) as f64).collect();
            let mean = intervals.iter().sum::<f64>() / intervals.len() as f64;
            Some(mean)
        };
        let period_start = crossing_interval(0, 10).expect("not enough crossings at start");
        let period_end = crossing_interval(n - n / 4, 5).expect("not enough crossings at end");
        assert!(
            period_end < period_start,
            "Expected period_end ({period_end}) < period_start ({period_start})"
        );
    }

    #[test]
    fn test_chirp_error_zero_f0() {
        assert!(chirp_signal(100, 1000.0, 0.0, 400.0, ChirpMethod::Linear).is_err());
    }

    #[test]
    fn test_am_shape() {
        let sig = am_signal(100.0, 5.0, 2000.0, 1000).expect("am failed");
        assert_eq!(sig.len(), 1000);
    }

    #[test]
    fn test_am_amplitude_bounds() {
        let sig = am_signal(100.0, 5.0, 2000.0, 2000).expect("am failed");
        let max_abs = sig.iter().copied().fold(0.0_f64, |a, v| a.max(v.abs()));
        assert!(max_abs <= 2.0 + 1e-10, "AM amplitude > 2: {max_abs}");
    }

    #[test]
    fn test_am_error_zero_fs() {
        assert!(am_signal(100.0, 5.0, 0.0, 100).is_err());
    }

    #[test]
    fn test_fm_shape() {
        let sig = fm_signal(200.0, 10.0, 2.5, 4000.0, 2000).expect("fm failed");
        assert_eq!(sig.len(), 2000);
    }

    #[test]
    fn test_fm_unit_amplitude() {
        let sig = fm_signal(200.0, 10.0, 5.0, 4000.0, 4000).expect("fm failed");
        let max_abs = sig.iter().copied().fold(0.0_f64, |a, v| a.max(v.abs()));
        assert!(max_abs <= 1.0 + 1e-10, "FM amplitude > 1: {max_abs}");
    }

    #[test]
    fn test_fm_error_zero_samples() {
        assert!(fm_signal(200.0, 10.0, 1.0, 1000.0, 0).is_err());
    }

    #[test]
    fn test_mixture_shape() {
        let sig = sinusoidal_mixture(
            &[50.0, 120.0],
            &[1.0, 0.5],
            &[0.0, 0.0],
            30.0,
            2000.0,
            4096,
            99,
        )
        .expect("mixture failed");
        assert_eq!(sig.len(), 4096);
    }

    #[test]
    fn test_mixture_snr_improves_noise() {
        let make = |snr: f64| {
            sinusoidal_mixture(&[100.0], &[1.0], &[0.0], snr, 4000.0, 8000, 1)
                .expect("mixture failed")
        };
        let high = make(40.0);
        let low = make(5.0);
        let rms = |s: &Array1<f64>| (s.iter().map(|v| v * v).sum::<f64>() / s.len() as f64).sqrt();
        let rms_high = rms(&high);
        let rms_low = rms(&low);
        let pure_rms = (0.5_f64).sqrt();
        let diff_high = (rms_high - pure_rms).abs();
        let diff_low = (rms_low - pure_rms).abs();
        assert!(
            diff_high <= diff_low,
            "High-SNR signal deviates more from ideal: high={diff_high}, low={diff_low}"
        );
    }

    #[test]
    fn test_mixture_error_mismatched_slices() {
        assert!(
            sinusoidal_mixture(&[100.0, 200.0], &[1.0], &[0.0], 30.0, 2000.0, 1024, 0).is_err()
        );
    }

    #[test]
    fn test_mixture_error_empty_frequencies() {
        assert!(sinusoidal_mixture(&[], &[], &[], 30.0, 2000.0, 1024, 0).is_err());
    }

    #[test]
    fn test_reproducibility() {
        let a = ecg_signal(500, 250.0, 60.0, 0.1, 77).expect("ecg failed");
        let b = ecg_signal(500, 250.0, 60.0, 0.1, 77).expect("ecg failed");
        assert_eq!(a.1, b.1, "ECG should be reproducible");
        let c = sinusoidal_mixture(&[100.0], &[1.0], &[0.0], 20.0, 2000.0, 1000, 5)
            .expect("mixture failed");
        let d = sinusoidal_mixture(&[100.0], &[1.0], &[0.0], 20.0, 2000.0, 1000, 5)
            .expect("mixture failed");
        assert_eq!(c, d, "Mixture should be reproducible");
    }
}