use super::{Rng, StreamGenerator, TaskType};
/// Synthetic multi-query associative recall (MQAR) regression stream.
///
/// Each epoch holds `n_pairs` key/value pairs. The stream emits every pair once
/// in a bind phase, then replays the same keys in the same order in a recall
/// phase. After recall it draws a fresh epoch from `rng`.
#[derive(Debug, Clone)]
pub struct MqarStream {
/// Flat row-major key buffer: pair `i` occupies `keys[i * d_key..(i + 1) * d_key]`.
keys: Vec<f64>,
/// Flat row-major value buffer: pair `i` starts at `values[i * d_value]`.
values: Vec<f64>,
/// RNG used to draw every epoch (initial and regenerated).
rng: Rng,
/// Number of key/value pairs per epoch.
n_pairs: usize,
/// Key dimensionality; also the number of features per emitted sample.
d_key: usize,
/// Value dimensionality; only the first component is emitted as the target.
d_value: usize,
/// Index of the pair the next call to `next_sample` will emit.
pair_idx: usize,
/// `true` while replaying the current epoch (recall phase), `false` while binding.
in_recall: bool,
/// Set by the sample that completes a phase; cleared at the start of each `next_sample`.
drift_flag: bool,
}
impl MqarStream {
    /// Default number of key/value pairs per epoch.
    pub const DEFAULT_N_PAIRS: usize = 128;
    /// Default key dimensionality (number of features per sample).
    pub const DEFAULT_D_KEY: usize = 8;
    /// Default value dimensionality. It equals what [`MqarStream::new`] derives
    /// from [`MqarStream::DEFAULT_D_KEY`] (`d_key / 2`).
    pub const DEFAULT_D_VALUE: usize = 4;

    /// Creates a stream whose value dimensionality is derived from the key size
    /// as `max(d_key / 2, 1)`.
    ///
    /// Note that the argument order (`d_key` before `n_pairs`) differs from
    /// [`MqarStream::with_config`].
    ///
    /// # Panics
    ///
    /// Panics if `d_key` or `n_pairs` is zero.
    pub fn new(seed: u64, d_key: usize, n_pairs: usize) -> Self {
        let d_value = (d_key / 2).max(1);
        Self::with_config(seed, n_pairs, d_key, d_value)
    }

    /// Creates a stream with explicit dimensions.
    ///
    /// It draws the first epoch of `n_pairs` key/value pairs from an RNG seeded
    /// with `seed`. The stream starts in the bind phase.
    ///
    /// # Panics
    ///
    /// Panics if `n_pairs`, `d_key`, or `d_value` is zero. It also panics if
    /// `n_pairs * d_key` or `n_pairs * d_value` overflows `usize`.
    pub fn with_config(seed: u64, n_pairs: usize, d_key: usize, d_value: usize) -> Self {
        assert!(n_pairs > 0, "n_pairs must be > 0");
        assert!(d_key > 0, "d_key must be > 0");
        assert!(d_value > 0, "d_value must be > 0");
        let mut rng = Rng::new(seed);
        let (keys, values) = Self::generate_epoch(&mut rng, n_pairs, d_key, d_value);
        Self {
            keys,
            values,
            rng,
            n_pairs,
            d_key,
            d_value,
            pair_idx: 0,
            in_recall: false,
            drift_flag: false,
        }
    }

    /// Draws one epoch of `n_pairs` key/value pairs.
    ///
    /// Returns `(keys, values)` as flat row-major buffers of length
    /// `n_pairs * d_key` and `n_pairs * d_value`.
    ///
    /// - Each key is `d_key` draws of `rng.normal(0.0, 1.0)`, scaled to unit L2
    ///   norm. The norm is clamped to at least `1e-12` so the division is
    ///   always defined.
    /// - Each value component is drawn with `rng.uniform_range(-1.0, 1.0)`.
    ///
    /// For every pair the key is drawn before its value, so a given seed
    /// always produces the same sequence.
    ///
    /// # Panics
    ///
    /// Panics if either buffer length overflows `usize`.
    fn generate_epoch(
        rng: &mut Rng,
        n_pairs: usize,
        d_key: usize,
        d_value: usize,
    ) -> (Vec<f64>, Vec<f64>) {
        // Checked so an oversized configuration fails loudly instead of wrapping
        // silently in release builds.
        let keys_len = n_pairs
            .checked_mul(d_key)
            .expect("n_pairs * d_key overflows usize");
        let values_len = n_pairs
            .checked_mul(d_value)
            .expect("n_pairs * d_value overflows usize");
        let mut keys = Vec::with_capacity(keys_len);
        let mut values = Vec::with_capacity(values_len);
        for _ in 0..n_pairs {
            // Draw the key directly into `keys`, then normalise that tail in
            // place. This avoids a scratch allocation for every pair.
            let start = keys.len();
            keys.extend((0..d_key).map(|_| rng.normal(0.0, 1.0)));
            let key = &mut keys[start..];
            // Accumulate left to right, in draw order, so results stay
            // bit-for-bit reproducible.
            let norm_sq: f64 = key.iter().fold(0.0, |acc, v| acc + v * v);
            let norm = norm_sq.sqrt().max(1e-12);
            key.iter_mut().for_each(|v| *v /= norm);
            values.extend((0..d_value).map(|_| rng.uniform_range(-1.0, 1.0)));
        }
        (keys, values)
    }

    /// Returns a copy of the key for the pair at `pair_idx`.
    fn current_key(&self) -> Vec<f64> {
        let start = self.pair_idx * self.d_key;
        self.keys[start..start + self.d_key].to_vec()
    }

    /// Returns the regression target for the pair at `pair_idx`.
    ///
    /// The target is the first component of that pair's value vector. The
    /// remaining `d_value - 1` components are drawn but never emitted.
    fn current_target(&self) -> f64 {
        self.values[self.pair_idx * self.d_value]
    }

    /// Number of key/value pairs per epoch.
    pub fn n_pairs(&self) -> usize {
        self.n_pairs
    }

    /// Key dimensionality, i.e. the length of every emitted feature vector.
    pub fn d_key(&self) -> usize {
        self.d_key
    }

    /// Value dimensionality of each stored pair.
    pub fn d_value(&self) -> usize {
        self.d_value
    }

    /// Returns `true` while the stream is replaying the current epoch's keys
    /// (recall phase) and `false` during the bind phase.
    pub fn in_recall_phase(&self) -> bool {
        self.in_recall
    }
}
impl StreamGenerator for MqarStream {
    /// Emits the current pair as `(key, target)` and advances the cursor.
    ///
    /// Every epoch is emitted twice:
    /// - once in the bind phase;
    /// - once in the recall phase, in the same order.
    ///
    /// When recall finishes, a fresh epoch replaces the old one before the next
    /// bind phase starts. The sample that completes either phase raises the
    /// drift flag. Every other sample leaves it cleared.
    fn next_sample(&mut self) -> (Vec<f64>, f64) {
        let sample = (self.current_key(), self.current_target());

        self.pair_idx += 1;
        let phase_finished = self.pair_idx >= self.n_pairs;
        self.drift_flag = phase_finished;

        if phase_finished {
            self.pair_idx = 0;
            if self.in_recall {
                // Recall just ended: draw a new epoch for the upcoming bind phase.
                let (fresh_keys, fresh_values) =
                    Self::generate_epoch(&mut self.rng, self.n_pairs, self.d_key, self.d_value);
                self.keys = fresh_keys;
                self.values = fresh_values;
            }
            // Bind -> recall, or recall -> bind.
            self.in_recall = !self.in_recall;
        }

        sample
    }

    /// Features per sample: one per key dimension.
    fn n_features(&self) -> usize {
        self.d_key
    }

    /// Targets are real-valued, so the task is regression.
    fn task_type(&self) -> TaskType {
        TaskType::Regression
    }

    /// `true` only right after the sample that completed a bind or recall phase.
    fn drift_occurred(&self) -> bool {
        self.drift_flag
    }
}
#[cfg(test)]
mod tests {
    // Locals are named `stream` rather than `gen`: `gen` is a reserved keyword
    // in the Rust 2024 edition and would not compile there.
    use super::*;

    #[test]
    fn mqar_produces_correct_n_features() {
        let d_key = MqarStream::DEFAULT_D_KEY;
        let mut stream = MqarStream::new(42, d_key, MqarStream::DEFAULT_N_PAIRS);
        let (features, _) = stream.next_sample();
        assert_eq!(
            features.len(),
            d_key,
            "features should have d_key={} dims, got {}",
            d_key,
            features.len()
        );
    }

    #[test]
    fn mqar_task_type_is_regression() {
        let stream =
            MqarStream::new(42, MqarStream::DEFAULT_D_KEY, MqarStream::DEFAULT_N_PAIRS);
        assert_eq!(stream.task_type(), TaskType::Regression);
    }

    #[test]
    fn mqar_produces_finite_values() {
        let mut stream = MqarStream::new(7, 8, 32);
        for i in 0..512 {
            let (features, target) = stream.next_sample();
            for (j, f) in features.iter().enumerate() {
                assert!(f.is_finite(), "feature {} at sample {} is not finite", j, i);
            }
            assert!(target.is_finite(), "target at sample {} is not finite", i);
        }
    }

    #[test]
    fn mqar_deterministic_with_same_seed() {
        let mut stream_a = MqarStream::new(42, 8, 32);
        let mut stream_b = MqarStream::new(42, 8, 32);
        for _ in 0..512 {
            let (f1, t1) = stream_a.next_sample();
            let (f2, t2) = stream_b.next_sample();
            assert_eq!(f1, f2, "same seed should produce identical features");
            assert_eq!(t1, t2, "same seed should produce identical targets");
        }
    }

    #[test]
    fn mqar_recall_keys_match_bind_keys() {
        let n = 16;
        let mut stream = MqarStream::with_config(1234, n, 4, 2);
        let bind_pairs: Vec<(Vec<f64>, f64)> = (0..n).map(|_| stream.next_sample()).collect();
        let recall_pairs: Vec<(Vec<f64>, f64)> =
            (0..n).map(|_| stream.next_sample()).collect();
        assert_eq!(bind_pairs.len(), recall_pairs.len());
        for (i, ((bf, bt), (rf, rt))) in bind_pairs.iter().zip(recall_pairs.iter()).enumerate() {
            assert_eq!(bf, rf, "bind and recall keys must match at pair {}", i);
            assert!(
                (bt - rt).abs() < 1e-12,
                "bind and recall targets must match at pair {}: bind={}, recall={}",
                i,
                bt,
                rt
            );
        }
    }

    #[test]
    fn mqar_phase_boundary_drift_flag() {
        let n = 8;
        let mut stream = MqarStream::with_config(99, n, 4, 2);
        for i in 0..n - 1 {
            stream.next_sample();
            assert!(!stream.drift_occurred(), "no drift expected at bind sample {}", i);
        }
        stream.next_sample();
        assert!(stream.drift_occurred(), "drift expected at bind→recall boundary");
        for i in 0..n - 1 {
            stream.next_sample();
            assert!(!stream.drift_occurred(), "no drift expected at recall sample {}", i);
        }
        stream.next_sample();
        assert!(stream.drift_occurred(), "drift expected at recall→bind boundary");
    }

    #[test]
    fn mqar_drift_flag_fires_only_at_phase_ends_across_epochs() {
        let n = 5;
        let mut stream = MqarStream::with_config(7, n, 3, 1);
        assert!(!stream.drift_occurred(), "no drift expected before the first sample");
        for step in 1..=4 * n {
            stream.next_sample();
            assert_eq!(
                stream.drift_occurred(),
                step % n == 0,
                "unexpected drift flag after sample {}",
                step
            );
        }
    }

    #[test]
    fn mqar_recall_phase_flag_toggles() {
        let n = 4;
        let mut stream = MqarStream::with_config(11, n, 2, 1);
        assert!(!stream.in_recall_phase(), "stream must start in the bind phase");
        for _ in 0..n {
            stream.next_sample();
        }
        assert!(stream.in_recall_phase(), "recall phase expected after n samples");
        for _ in 0..n {
            stream.next_sample();
        }
        assert!(!stream.in_recall_phase(), "bind phase expected after 2n samples");
    }

    #[test]
    fn mqar_new_epoch_after_recall() {
        let n = 8;
        let mut stream = MqarStream::with_config(2024, n, 4, 2);
        let first_epoch: Vec<Vec<f64>> = (0..n).map(|_| stream.next_sample().0).collect();
        for _ in 0..n {
            stream.next_sample();
        }
        let second_epoch: Vec<Vec<f64>> = (0..n).map(|_| stream.next_sample().0).collect();
        assert_ne!(
            first_epoch, second_epoch,
            "a fresh epoch should be drawn after the recall phase"
        );
    }

    #[test]
    fn mqar_custom_config_dimensions() {
        let mut stream = MqarStream::with_config(1, 32, 6, 3);
        let (features, _) = stream.next_sample();
        assert_eq!(features.len(), 6);
        assert_eq!(stream.n_pairs(), 32);
        assert_eq!(stream.d_key(), 6);
        assert_eq!(stream.d_value(), 3);
        assert_eq!(stream.n_features(), 6);
        let stream2 = MqarStream::new(1, 8, 16);
        assert_eq!(stream2.n_features(), 8);
        assert_eq!(stream2.n_pairs(), 16);
        assert_eq!(stream2.d_value(), 4);
    }

    #[test]
    fn mqar_new_derives_d_value() {
        let defaults = MqarStream::new(0, MqarStream::DEFAULT_D_KEY, 4);
        assert_eq!(defaults.d_value(), MqarStream::DEFAULT_D_VALUE);
        // d_key / 2 == 0 must be clamped up to 1.
        let tiny = MqarStream::new(3, 1, 4);
        assert_eq!(tiny.d_value(), 1);
    }

    #[test]
    #[should_panic(expected = "n_pairs must be > 0")]
    fn mqar_rejects_zero_pairs() {
        let _ = MqarStream::with_config(0, 0, 4, 2);
    }

    #[test]
    fn mqar_keys_are_unit_norm() {
        let n = 16;
        let mut stream = MqarStream::with_config(55, n, 8, 4);
        for i in 0..n {
            let (features, _) = stream.next_sample();
            let norm_sq: f64 = features.iter().map(|v| v * v).sum();
            assert!(
                (norm_sq.sqrt() - 1.0).abs() < 1e-9,
                "bind key {} should be unit-norm, got norm_sq={}",
                i,
                norm_sq
            );
        }
    }
}