use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::Normal;
/// Evaluates the event envelope of one class at normalized position `t`.
///
/// * `class_idx` - `0` yields a flat plateau, `1` a logistic ramp that rises with `t`,
///   and any other value the mirrored ramp that falls with `t`.
/// * `t` - position inside the event; the ramp is the logistic curve of `6t - 3`,
///   so it is centered on `t = 0.5`.
/// * `amplitude` - height the envelope is scaled by.
///
/// Returns the envelope value (without any noise).
#[inline]
fn cbf_shape(class_idx: usize, t: f64, amplitude: f64) -> f64 {
    // Plateau: the height does not depend on the position inside the event.
    if class_idx == 0 {
        return amplitude;
    }

    // Logistic ramp written as e^z / (1 + e^z); the exponential is evaluated once.
    let growth = (6.0 * t - 3.0).exp();
    let ramp = growth / (1.0 + growth);

    if class_idx == 1 {
        amplitude * ramp
    } else {
        amplitude * (1.0 - ramp)
    }
}
pub fn make_cylinder_bell_funnel(
n_samples_per_class: usize,
n_timestamps: usize,
random_seed: Option<u64>,
) -> (Vec<Vec<f64>>, Vec<String>) {
let mut rng = match random_seed {
Some(seed) => StdRng::seed_from_u64(seed),
None => StdRng::from_entropy(),
};
let noise_dist = Normal::new(0.0, 0.1).unwrap();
let total = n_samples_per_class * 3;
let mut x = Vec::with_capacity(total);
let mut y = Vec::with_capacity(total);
for class_idx in 0..3 {
let label = match class_idx {
0 => "cylinder",
1 => "bell",
_ => "funnel",
};
for _ in 0..n_samples_per_class {
let mut ts = vec![0.0; n_timestamps];
let amplitude: f64 = rng.gen_range(4.0..8.0);
let onset = rng.gen_range(n_timestamps / 8..n_timestamps / 4);
let duration = rng.gen_range(n_timestamps / 4..n_timestamps / 2);
let end = (onset + duration).min(n_timestamps);
for i in onset..end {
let t = (i - onset) as f64 / duration as f64;
ts[i] = cbf_shape(class_idx, t, amplitude);
}
for v in &mut ts {
*v += rng.sample(noise_dist);
}
x.push(ts);
y.push(label.to_string());
}
}
(x, y)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of entries in `labels` that are exactly `wanted`.
    fn count_label(labels: &[String], wanted: &str) -> usize {
        labels
            .iter()
            .filter(|label| label.as_str() == wanted)
            .count()
    }

    /// Output dimensions: three classes' worth of series, each of the requested length.
    #[test]
    fn test_cbf_shape() {
        let (series, labels) = make_cylinder_bell_funnel(10, 128, Some(42));
        assert_eq!(series.len(), 30);
        assert_eq!(labels.len(), 30);
        series.iter().for_each(|ts| assert_eq!(ts.len(), 128));
    }

    /// Every class label appears exactly `n_samples_per_class` times.
    #[test]
    fn test_cbf_labels() {
        let (_, labels) = make_cylinder_bell_funnel(5, 64, Some(42));
        assert_eq!(count_label(&labels, "cylinder"), 5);
        assert_eq!(count_label(&labels, "bell"), 5);
        assert_eq!(count_label(&labels, "funnel"), 5);
    }

    /// The same seed must reproduce the same series.
    #[test]
    fn test_cbf_deterministic() {
        let first = make_cylinder_bell_funnel(3, 32, Some(123)).0;
        let second = make_cylinder_bell_funnel(3, 32, Some(123)).0;
        assert_eq!(first, second);
    }
}