use super::{Rng, StreamGenerator, TaskType};
/// First triangular template, `max(7 - |i - 7|, 0)`.
///
/// Equals `i` for `i <= 7` (peak value 7 at `i = 7`), then `14 - i`, and is
/// clamped to 0 from `i = 14` onwards.
///
/// NOTE(review): Breiman's original waveform templates peak at height 6; this
/// module (and its tests) use 7 — confirm that is intended.
fn h1(i: usize) -> f64 {
    let distance_from_peak = (i as f64 - 7.0).abs();
    (7.0 - distance_from_peak).max(0.0)
}
/// Second triangular template, `max(7 - |i - 11|, 0)`.
///
/// Zero up to `i = 4`, rises to its peak value 7 at `i = 11`, and falls back
/// to 0 at `i = 18` (clamped to 0 beyond).
fn h2(i: usize) -> f64 {
    let distance_from_peak = (i as f64 - 11.0).abs();
    (7.0 - distance_from_peak).max(0.0)
}
/// Third triangular template, `max(7 - |i - 15|, 0)`.
///
/// Zero up to `i = 8`, rises to its peak value 7 at `i = 15`, and falls back
/// to 0 at `i = 22` (clamped to 0 beyond).
fn h3(i: usize) -> f64 {
    let distance_from_peak = (i as f64 - 15.0).abs();
    (7.0 - distance_from_peak).max(0.0)
}
#[derive(Debug, Clone)]
pub struct Waveform {
rng: Rng,
has_noise: bool,
}
impl Waveform {
pub fn new(seed: u64) -> Self {
Self::with_noise(seed, false)
}
pub fn with_noise(seed: u64, has_noise_features: bool) -> Self {
Self {
rng: Rng::new(seed),
has_noise: has_noise_features,
}
}
}
impl StreamGenerator for Waveform {
fn next_sample(&mut self) -> (Vec<f64>, f64) {
let class = self.rng.uniform_int(3);
let u1 = self.rng.uniform();
let u2 = 1.0 - u1;
let n_base = 21;
let n_total = if self.has_noise { 40 } else { 21 };
let mut features = Vec::with_capacity(n_total);
for i in 1..=n_base {
let base = match class {
0 => u1 * h1(i) + u2 * h2(i),
1 => u1 * h1(i) + u2 * h3(i),
2 => u1 * h2(i) + u2 * h3(i),
_ => unreachable!(),
};
features.push(base + self.rng.normal(0.0, 1.0));
}
if self.has_noise {
for _ in 0..19 {
features.push(self.rng.normal(0.0, 1.0));
}
}
(features, class as f64)
}
fn n_features(&self) -> usize {
if self.has_noise {
40
} else {
21
}
}
fn task_type(&self) -> TaskType {
TaskType::MulticlassClassification { n_classes: 3 }
}
fn drift_occurred(&self) -> bool {
false }
}
#[cfg(test)]
mod tests {
    // Locals are named `generator` rather than `gen`: `gen` is a reserved
    // keyword in the Rust 2024 edition.
    use super::*;

    #[test]
    fn waveform_produces_correct_n_features_default() {
        let mut generator = Waveform::new(42);
        let (features, _) = generator.next_sample();
        assert_eq!(
            features.len(),
            21,
            "Waveform (no noise) should produce 21 features, got {}",
            features.len()
        );
    }

    #[test]
    fn waveform_produces_correct_n_features_with_noise() {
        let mut generator = Waveform::with_noise(42, true);
        let (features, _) = generator.next_sample();
        assert_eq!(
            features.len(),
            40,
            "Waveform (with noise) should produce 40 features, got {}",
            features.len()
        );
    }

    /// The emitted sample length must always agree with `n_features()`.
    #[test]
    fn waveform_sample_length_matches_n_features() {
        for has_noise in [false, true] {
            let mut generator = Waveform::with_noise(7, has_noise);
            let (features, _) = generator.next_sample();
            assert_eq!(
                features.len(),
                generator.n_features(),
                "sample length must match n_features() (has_noise = {})",
                has_noise
            );
        }
    }

    #[test]
    fn waveform_task_type_is_multiclass_3() {
        let generator = Waveform::new(42);
        assert_eq!(
            generator.task_type(),
            TaskType::MulticlassClassification { n_classes: 3 },
            "Waveform task type should be 3-class multiclass"
        );
    }

    #[test]
    fn waveform_labels_in_valid_range() {
        let mut generator = Waveform::new(42);
        for _ in 0..1000 {
            let (_, target) = generator.next_sample();
            let c = target as usize;
            assert!(c < 3, "Waveform class should be 0, 1, or 2, got {}", target);
            assert_eq!(target, c as f64, "Waveform class should be an integer");
        }
    }

    #[test]
    fn waveform_produces_finite_values() {
        let mut generator = Waveform::with_noise(123, true);
        for i in 0..1000 {
            let (features, target) = generator.next_sample();
            for (j, f) in features.iter().enumerate() {
                assert!(f.is_finite(), "feature {} at sample {} is not finite", j, i);
            }
            assert!(target.is_finite(), "target at sample {} is not finite", i);
        }
    }

    #[test]
    fn waveform_no_drift() {
        let mut generator = Waveform::new(42);
        for _ in 0..500 {
            generator.next_sample();
            assert!(!generator.drift_occurred(), "Waveform should never signal drift");
        }
    }

    #[test]
    fn waveform_all_classes_appear() {
        let mut generator = Waveform::new(42);
        let mut seen = [false; 3];
        for _ in 0..300 {
            let (_, target) = generator.next_sample();
            seen[target as usize] = true;
        }
        for (c, &s) in seen.iter().enumerate() {
            assert!(s, "class {} was never generated in 300 samples", c);
        }
    }

    #[test]
    fn waveform_deterministic_with_same_seed() {
        let mut generator1 = Waveform::new(42);
        let mut generator2 = Waveform::new(42);
        for _ in 0..200 {
            let (f1, t1) = generator1.next_sample();
            let (f2, t2) = generator2.next_sample();
            assert_eq!(f1, f2, "same seed should produce identical features");
            assert_eq!(t1, t2, "same seed should produce identical targets");
        }
    }

    #[test]
    fn waveform_templates_have_correct_peaks() {
        assert_eq!(h1(7), 7.0, "h1 should peak at position 7");
        assert_eq!(h2(11), 7.0, "h2 should peak at position 11");
        assert_eq!(h3(15), 7.0, "h3 should peak at position 15");
        assert_eq!(h1(0), 0.0, "h1(0) should be 0");
        assert_eq!(h1(14), 0.0, "h1(14) should be 0");
    }
}