use super::{Rng, StreamGenerator, TaskType};
const LED_SEGMENTS: [[u8; 7]; 10] = [
[1, 1, 1, 1, 1, 1, 0], [0, 1, 1, 0, 0, 0, 0], [1, 1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, 1, 1], [1, 0, 1, 1, 0, 1, 1], [1, 0, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 1, 1], ];
#[derive(Debug, Clone)]
pub struct LED {
rng: Rng,
noise_rate: f64,
permutation: [usize; 7],
change_points: Vec<usize>,
sample_idx: usize,
next_change: usize,
drift_flag: bool,
}
impl LED {
pub fn new(seed: u64) -> Self {
Self::with_config(seed, 0.1, Vec::new())
}
pub fn with_config(seed: u64, noise_rate: f64, change_points: Vec<usize>) -> Self {
Self {
rng: Rng::new(seed),
noise_rate,
permutation: [0, 1, 2, 3, 4, 5, 6], change_points,
sample_idx: 0,
next_change: 0,
drift_flag: false,
}
}
fn shuffle_permutation(&mut self) {
for i in (1..7).rev() {
let j = self.rng.uniform_int(i + 1);
self.permutation.swap(i, j);
}
}
}
impl StreamGenerator for LED {
fn next_sample(&mut self) -> (Vec<f64>, f64) {
self.drift_flag = false;
if self.next_change < self.change_points.len()
&& self.sample_idx == self.change_points[self.next_change]
{
self.shuffle_permutation();
self.next_change += 1;
self.drift_flag = true;
}
let digit = self.rng.uniform_int(10);
let segments = &LED_SEGMENTS[digit];
let mut features = vec![0.0; 24];
for (seg_idx, &seg_val) in segments.iter().enumerate() {
let mut val = seg_val as f64;
if self.rng.bernoulli(self.noise_rate) {
val = 1.0 - val;
}
features[self.permutation[seg_idx]] = val;
}
for feat in features.iter_mut().skip(7) {
*feat = if self.rng.bernoulli(0.5) { 1.0 } else { 0.0 };
}
self.sample_idx += 1;
(features, digit as f64)
}
fn n_features(&self) -> usize {
24
}
fn task_type(&self) -> TaskType {
TaskType::MulticlassClassification { n_classes: 10 }
}
fn drift_occurred(&self) -> bool {
self.drift_flag
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn led_produces_correct_n_features() {
let mut gen = LED::new(42);
let (features, _) = gen.next_sample();
assert_eq!(
features.len(),
24,
"LED should produce 24 features, got {}",
features.len()
);
}
#[test]
fn led_task_type_is_multiclass_10() {
let gen = LED::new(42);
assert_eq!(
gen.task_type(),
TaskType::MulticlassClassification { n_classes: 10 },
"LED task type should be 10-class multiclass"
);
}
#[test]
fn led_labels_in_valid_range() {
let mut gen = LED::new(42);
for _ in 0..1000 {
let (_, target) = gen.next_sample();
let digit = target as usize;
assert!(digit < 10, "LED digit should be in 0..9, got {}", target);
assert_eq!(
target, digit as f64,
"LED digit should be an integer, got {}",
target
);
}
}
#[test]
fn led_features_are_binary() {
let mut gen = LED::new(42);
for _ in 0..500 {
let (features, _) = gen.next_sample();
for (j, &f) in features.iter().enumerate() {
assert!(
f == 0.0 || f == 1.0,
"LED feature {} should be 0 or 1, got {}",
j,
f
);
}
}
}
#[test]
fn led_produces_finite_values() {
let mut gen = LED::new(123);
for i in 0..500 {
let (features, target) = gen.next_sample();
for (j, f) in features.iter().enumerate() {
assert!(f.is_finite(), "feature {} at sample {} is not finite", j, i);
}
assert!(target.is_finite(), "target at sample {} is not finite", i);
}
}
#[test]
fn led_drift_at_change_points() {
let mut gen = LED::with_config(42, 0.1, vec![50, 100]);
for _ in 0..50 {
gen.next_sample();
}
gen.next_sample();
assert!(
gen.drift_occurred(),
"drift should occur at change point 50"
);
for _ in 52..100 {
gen.next_sample();
assert!(
!gen.drift_occurred(),
"no drift expected between change points"
);
}
}
#[test]
fn led_deterministic_with_same_seed() {
let mut gen1 = LED::new(42);
let mut gen2 = LED::new(42);
for _ in 0..200 {
let (f1, t1) = gen1.next_sample();
let (f2, t2) = gen2.next_sample();
assert_eq!(f1, f2, "same seed should produce identical features");
assert_eq!(t1, t2, "same seed should produce identical targets");
}
}
#[test]
fn led_all_digits_appear() {
let mut gen = LED::new(42);
let mut seen = [false; 10];
for _ in 0..1000 {
let (_, target) = gen.next_sample();
seen[target as usize] = true;
}
for (d, &s) in seen.iter().enumerate() {
assert!(s, "digit {} was never generated in 1000 samples", d);
}
}
}