use rand::Rng;
use rand::SeedableRng;
use rand::rngs::StdRng;
use crate::synth_data::{QUALITY_INPUT_DIM, QUALITY_OUTPUT_DIM};
/// Number of categorical input slots: the eight categorical groups have
/// widths 9, 11, 14, 7, 5, 16, 8 and 4.
pub const CATEGORICAL_DIMS: usize = 9 + 11 + 14 + 7 + 5 + 16 + 8 + 4;
/// Number of numerical input features; the generator clamps each to `[0, 1]`.
pub const NUMERICAL_DIMS: usize = 22;
/// Number of outcome probabilities at the front of each target vector.
pub const OUTCOME_DIMS: usize = 12;
/// Number of auxiliary sigmoid probabilities following the outcome block.
pub const AUX_DIMS: usize = 8;
/// Number of hidden experts the synthetic router chooses between.
pub const N_EXPERTS: usize = 4;
/// Width of the random low-dimensional projection fed to the router and heads.
const ROUTED_DIM: usize = 8;
/// One synthetic training example, as produced by [`make_quality_decision_dataset`].
#[derive(Debug, Clone)]
pub struct QualitySample {
    /// Feature vector: the categorical block followed by the numerical block.
    pub input: Vec<f32>,
    /// Target vector: outcome (softmax) probabilities followed by auxiliary
    /// (sigmoid) probabilities.
    pub target: Vec<f32>,
    /// Index of the expert the hidden router selected when this sample was generated.
    pub expert_assignment: usize,
    /// Position of this sample within the generated dataset.
    pub index: usize,
}
impl QualitySample {
    /// Returns `true` when `input` has exactly `QUALITY_INPUT_DIM` entries and
    /// `target` has exactly `QUALITY_OUTPUT_DIM` entries.
    pub fn shape_ok(&self) -> bool {
        let input_matches = self.input.len() == QUALITY_INPUT_DIM;
        input_matches && self.target.len() == QUALITY_OUTPUT_DIM
    }
}
/// Widths of the eight categorical feature groups. Each sample activates one
/// slot per group, so the widths must sum to [`CATEGORICAL_DIMS`].
const CATEGORICAL_BLOCK_DIMS: [usize; 8] = [9, 11, 14, 7, 5, 16, 8, 4];

/// Generates a synthetic mixture-of-experts dataset for the quality-decision model.
///
/// A random projection collapses each `QUALITY_INPUT_DIM`-wide input to
/// `ROUTED_DIM` features. A random linear router then picks one of `N_EXPERTS`
/// experts by argmax. That expert's random linear heads produce the target:
///
/// * `OUTCOME_DIMS` softmax probabilities computed from noisy logits, followed by
/// * `AUX_DIMS` sigmoid probabilities, each flipped (`p -> 1 - p`) with probability 0.05.
///
/// Inputs are `CATEGORICAL_DIMS` categorical slots followed by `NUMERICAL_DIMS`
/// values in `[0, 1]`. Each categorical group has one slot at `1.0`, plus an
/// occasional `0.7` in the neighbouring slot.
///
/// # Arguments
/// * `n_samples` - number of samples to generate; `0` returns an empty vector
///   without seeding an RNG.
/// * `seed` - seed for `StdRng`. The same seed yields the same dataset for a
///   given `rand` version; `StdRng`'s algorithm is not guaranteed stable across releases.
///
/// # Returns
/// `n_samples` samples. Each sample's `index` is its position and
/// `expert_assignment` is the expert that generated its target.
pub fn make_quality_decision_dataset(n_samples: usize, seed: u64) -> Vec<QualitySample> {
    if n_samples == 0 {
        return Vec::new();
    }
    let mut rng = StdRng::seed_from_u64(seed);

    // Every weight matrix is row-major `rows x cols` with `rows` = output width,
    // which is how `mat_vec(m, x, cols, rows)` reads it. (The draw count,
    // `rows * cols`, is unchanged, so the generated values are unaffected.)
    let p_routed = gauss_matrix(
        &mut rng,
        ROUTED_DIM,
        QUALITY_INPUT_DIM,
        1.0 / (QUALITY_INPUT_DIM as f32).sqrt(),
    );
    let router_w = gauss_matrix(&mut rng, N_EXPERTS, ROUTED_DIM, 1.0);

    // Per-expert outcome heads: weights and bias for one expert, then the next.
    // The RNG draw order matters for reproducibility.
    let mut outcome_w: Vec<Vec<f32>> = Vec::with_capacity(N_EXPERTS);
    let mut outcome_b: Vec<f32> = Vec::with_capacity(N_EXPERTS * OUTCOME_DIMS);
    for _ in 0..N_EXPERTS {
        outcome_w.push(gauss_matrix(&mut rng, OUTCOME_DIMS, ROUTED_DIM, 1.0));
        for _ in 0..OUTCOME_DIMS {
            outcome_b.push(gauss(&mut rng, 0.0, 0.1));
        }
    }

    // Per-expert auxiliary heads, drawn after all outcome heads.
    let mut aux_w: Vec<Vec<f32>> = Vec::with_capacity(N_EXPERTS);
    let mut aux_b: Vec<f32> = Vec::with_capacity(N_EXPERTS * AUX_DIMS);
    for _ in 0..N_EXPERTS {
        aux_w.push(gauss_matrix(&mut rng, AUX_DIMS, ROUTED_DIM, 1.0));
        for _ in 0..AUX_DIMS {
            aux_b.push(gauss(&mut rng, 0.0, 0.1));
        }
    }

    // Fixed per-feature offsets shared by every sample (in the pre-/1000 scale).
    let num_bias: Vec<f32> = (0..NUMERICAL_DIMS)
        .map(|_| gauss(&mut rng, 0.0, 50.0))
        .collect();

    let mut out: Vec<QualitySample> = Vec::with_capacity(n_samples);
    for idx in 0..n_samples {
        let cat = sample_categorical(&mut rng);
        let num = sample_numerical(&mut rng, &num_bias);
        let mut input: Vec<f32> = Vec::with_capacity(QUALITY_INPUT_DIM);
        input.extend_from_slice(&cat);
        input.extend_from_slice(&num);
        debug_assert_eq!(input.len(), QUALITY_INPUT_DIM);

        let x_collapsed = mat_vec(&p_routed, &input, QUALITY_INPUT_DIM, ROUTED_DIM);
        let router_logits = mat_vec(&router_w, &x_collapsed, ROUTED_DIM, N_EXPERTS);
        let expert = argmax(&router_logits);

        // Outcome head of the selected expert, with per-sample logit noise.
        let b_e = &outcome_b[expert * OUTCOME_DIMS..(expert + 1) * OUTCOME_DIMS];
        let mut logits = mat_vec(&outcome_w[expert], &x_collapsed, ROUTED_DIM, OUTCOME_DIMS);
        for (logit, &bias) in logits.iter_mut().zip(b_e) {
            *logit += bias + gauss(&mut rng, 0.0, 0.1);
        }
        let outcome_soft = softmax(&logits);

        let mut target: Vec<f32> = Vec::with_capacity(QUALITY_OUTPUT_DIM);
        target.extend_from_slice(&outcome_soft);

        // Auxiliary head of the selected expert, with 5% label-flip noise.
        let b_aux_e = &aux_b[expert * AUX_DIMS..(expert + 1) * AUX_DIMS];
        let aux_logits = mat_vec(&aux_w[expert], &x_collapsed, ROUTED_DIM, AUX_DIMS);
        for (&logit, &bias) in aux_logits.iter().zip(b_aux_e) {
            let p = sigmoid(logit + bias);
            let p = if rng.gen_range(0.0f32..1.0f32) < 0.05 {
                1.0 - p
            } else {
                p
            };
            target.push(p.clamp(0.0, 1.0));
        }
        debug_assert_eq!(target.len(), QUALITY_OUTPUT_DIM);

        out.push(QualitySample {
            input,
            target,
            expert_assignment: expert,
            index: idx,
        });
    }
    out
}

/// Draws one categorical feature vector of length [`CATEGORICAL_DIMS`].
///
/// For each group in [`CATEGORICAL_BLOCK_DIMS`], a uniformly chosen primary slot
/// is set to `1.0`. With probability 0.15 the next slot in the same group
/// (wrapping around) is set to `0.7`.
fn sample_categorical<R: Rng>(rng: &mut R) -> Vec<f32> {
    let mut cat = vec![0.0f32; CATEGORICAL_DIMS];
    let mut off = 0usize;
    for &d in &CATEGORICAL_BLOCK_DIMS {
        let primary = rng.gen_range(0..d);
        cat[off + primary] = 1.0;
        // A width-1 group has no distinct neighbour. The short-circuit also
        // skips the RNG draw in that case.
        if d > 1 && rng.gen_range(0.0f32..1.0f32) < 0.15 {
            let secondary = (primary + 1) % d;
            cat[off + secondary] = 0.7;
        }
        off += d;
    }
    debug_assert_eq!(off, CATEGORICAL_DIMS);
    cat
}

/// Draws one numerical value per entry of `num_bias`.
///
/// Each value is a uniform base in `[0, 1000)` plus that feature's offset plus
/// `N(0, 5^2)` jitter. The sum is divided by 1000 and clamped to `[0, 1]`.
fn sample_numerical<R: Rng>(rng: &mut R, num_bias: &[f32]) -> Vec<f32> {
    let mut num: Vec<f32> = Vec::with_capacity(num_bias.len());
    for &bias in num_bias {
        let base: f32 = rng.gen_range(0.0f32..1000.0f32);
        let jitter: f32 = gauss(rng, 0.0, 5.0);
        num.push(((base + bias + jitter) / 1000.0f32).clamp(0.0, 1.0));
    }
    num
}
/// Draws one sample from `N(mean, std^2)` using the cosine branch of the
/// Box-Muller transform.
///
/// `u1` is drawn from `[1e-7, 1)` rather than `[0, 1)` so `ln(u1)` stays
/// finite. As a consequence `|z|` never exceeds roughly 5.7 standard deviations.
fn gauss<R: Rng>(rng: &mut R, mean: f32, std: f32) -> f32 {
    let u1: f32 = rng.gen_range((1.0e-7f32)..1.0f32);
    let u2: f32 = rng.gen_range(0.0f32..1.0f32);
    let radius = (-2.0f32 * u1.ln()).sqrt();
    let theta = 2.0f32 * std::f32::consts::PI * u2;
    let standard_normal = radius * theta.cos();
    mean + std * standard_normal
}
/// Returns `rows * cols` independent `N(0, std^2)` draws as one flat buffer.
///
/// Only the product `rows * cols` affects the result. The caller decides how
/// to interpret the layout.
fn gauss_matrix<R: Rng>(rng: &mut R, rows: usize, cols: usize, std: f32) -> Vec<f32> {
    let element_count = rows * cols;
    (0..element_count).map(|_| gauss(rng, 0.0, std)).collect()
}
/// Computes `y = M * x` where `m` holds a row-major `rows x cols` matrix.
///
/// Only the first `rows * cols` entries of `m` and the first `cols` entries of
/// `x` are read. Each output is accumulated left to right, starting from `0.0`.
/// Panics if `m` or `x` is too short for a row that must be computed.
fn mat_vec(m: &[f32], x: &[f32], cols: usize, rows: usize) -> Vec<f32> {
    (0..rows)
        .map(|row| {
            let start = row * cols;
            m[start..start + cols]
                .iter()
                .zip(&x[..cols])
                .fold(0.0f32, |acc, (w, v)| acc + w * v)
        })
        .collect()
}
/// Numerically-stable softmax: subtracts the maximum logit before exponentiating.
///
/// If the exponentials do not sum to a positive number, a uniform distribution
/// over all entries is returned instead. That happens, for example, when a NaN
/// or `+inf` logit makes the sum NaN. An empty input yields an empty output.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - peak).exp()).collect();
    let total: f32 = exps.iter().sum();
    if total > 0.0 {
        exps.into_iter().map(|e| e / total).collect()
    } else {
        let uniform = 1.0f32 / (exps.len() as f32);
        vec![uniform; exps.len()]
    }
}
/// Logistic function `1 / (1 + e^-x)`.
///
/// Inputs beyond +/-40 are snapped to exactly `1.0` / `0.0`. A NaN input
/// propagates as NaN.
fn sigmoid(x: f32) -> f32 {
    const SATURATION: f32 = 40.0;
    match x {
        v if v > SATURATION => 1.0,
        v if v < -SATURATION => 0.0,
        v => 1.0 / (1.0 + (-v).exp()),
    }
}
/// Index of the largest element, preferring the first on ties.
///
/// NaN entries never win. Returns `0` for an empty slice or when no entry
/// exceeds negative infinity.
fn argmax(xs: &[f32]) -> usize {
    let (winner, _) = xs.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(best_idx, best_val), (idx, &val)| {
            if val > best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        },
    );
    winner
}