use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, Uniform};
/// Strategy used to combine rule-based rows with diffusion-generated rows.
///
/// The enum carries no data, so it also derives `Eq` and `Hash`; this lets it be
/// used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlendStrategy {
    /// Weighted average of the two sources, value by value.
    Interpolate,
    /// Each output row is taken whole from one of the two sources, chosen at random.
    Select,
    /// Per-column choice between the sources; see [`HybridGenerator::blend_ensemble`].
    ///
    /// When passed to [`HybridGenerator::blend`], which receives no column list,
    /// this strategy falls back to interpolation.
    Ensemble,
}
/// Blends rule-based rows with diffusion-generated rows.
///
/// Both inputs are row-major tables (`&[Vec<f64>]`); the blending strategy is
/// chosen per call via [`BlendStrategy`] or [`HybridGenerator::blend_ensemble`].
#[derive(Debug, Clone)]
pub struct HybridGenerator {
/// Share given to the diffusion source; `new` clamps it to `[0.0, 1.0]`.
/// `0.0` favours the rule-based input entirely, `1.0` the diffusion input.
weight: f64,
}
impl HybridGenerator {
    /// Creates a generator that mixes in diffusion output with the given weight.
    ///
    /// `weight` is clamped to `[0.0, 1.0]`: `0.0` means rule-based only and `1.0`
    /// means diffusion only. A NaN weight is treated as `0.0`.
    pub fn new(weight: f64) -> Self {
        // `f64::clamp` passes NaN through unchanged; a NaN weight would turn every
        // interpolated value into NaN and make `Select` silently ignore diffusion.
        let weight = if weight.is_nan() {
            0.0
        } else {
            weight.clamp(0.0, 1.0)
        };
        Self { weight }
    }

    /// Returns the diffusion weight stored by [`HybridGenerator::new`].
    pub fn weight(&self) -> f64 {
        self.weight
    }

    /// Blends `rule_based` and `diffusion` row by row using `strategy`.
    ///
    /// Only the first `min(rule_based.len(), diffusion.len())` rows are used; an
    /// empty vector is returned when either input has no rows.
    ///
    /// * [`BlendStrategy::Interpolate`] – weighted average of matching values.
    /// * [`BlendStrategy::Select`] – each row is copied whole from one source,
    ///   chosen with a ChaCha8 generator seeded from `seed`.
    /// * [`BlendStrategy::Ensemble`] – this entry point receives no column list,
    ///   so it falls back to interpolation; call
    ///   [`HybridGenerator::blend_ensemble`] for per-column selection.
    ///
    /// `seed` is only consulted by the `Select` strategy.
    pub fn blend(
        &self,
        rule_based: &[Vec<f64>],
        diffusion: &[Vec<f64>],
        strategy: BlendStrategy,
        seed: u64,
    ) -> Vec<Vec<f64>> {
        let n_rows = rule_based.len().min(diffusion.len());
        if n_rows == 0 {
            return Vec::new();
        }
        match strategy {
            BlendStrategy::Interpolate | BlendStrategy::Ensemble => {
                self.blend_interpolate(rule_based, diffusion, n_rows)
            }
            BlendStrategy::Select => self.blend_select(rule_based, diffusion, n_rows, seed),
        }
    }

    /// Builds rows whose columns listed in `diffusion_columns` come from
    /// `diffusion` and whose remaining columns come from `rule_based`.
    ///
    /// Rows are paired by index up to the shorter input, and each output row is
    /// as wide as the narrower of its two source rows. Column indices beyond
    /// that width are ignored. The generator's weight is not used.
    pub fn blend_ensemble(
        &self,
        rule_based: &[Vec<f64>],
        diffusion: &[Vec<f64>],
        diffusion_columns: &[usize],
    ) -> Vec<Vec<f64>> {
        rule_based
            .iter()
            .zip(diffusion)
            .map(|(rule_row, diff_row)| {
                rule_row
                    .iter()
                    .zip(diff_row)
                    .enumerate()
                    .map(|(col, (&rule, &diff))| {
                        if diffusion_columns.contains(&col) {
                            diff
                        } else {
                            rule
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Mixes a single pair of values as `(1 - w) * rule + w * diff`.
    ///
    /// The end points return the selected source exactly: multiplying the
    /// ignored value by zero would otherwise let a NaN or infinity in that
    /// source leak into the result.
    fn interpolate_value(w: f64, rule: f64, diff: f64) -> f64 {
        if w <= 0.0 {
            rule
        } else if w >= 1.0 {
            diff
        } else {
            (1.0 - w) * rule + w * diff
        }
    }

    /// Interpolates the first `n_rows` row pairs value by value.
    ///
    /// Each output row is as wide as the narrower of its two source rows.
    fn blend_interpolate(
        &self,
        rule_based: &[Vec<f64>],
        diffusion: &[Vec<f64>],
        n_rows: usize,
    ) -> Vec<Vec<f64>> {
        let w = self.weight;
        rule_based
            .iter()
            .zip(diffusion)
            .take(n_rows)
            .map(|(rule_row, diff_row)| {
                rule_row
                    .iter()
                    .zip(diff_row)
                    .map(|(&rule, &diff)| Self::interpolate_value(w, rule, diff))
                    .collect()
            })
            .collect()
    }

    /// Picks, for each of the first `n_rows` row pairs, either the diffusion row
    /// or the rule-based row.
    ///
    /// One uniform draw is made per row from a ChaCha8 generator seeded with
    /// `seed`; the diffusion row is chosen when the draw is below the weight.
    /// The chosen row is cloned whole, so its width is that of its own source.
    fn blend_select(
        &self,
        rule_based: &[Vec<f64>],
        diffusion: &[Vec<f64>],
        n_rows: usize,
        seed: u64,
    ) -> Vec<Vec<f64>> {
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let uniform = Uniform::new(0.0_f64, 1.0).expect("valid uniform params");
        rule_based
            .iter()
            .zip(diffusion)
            .take(n_rows)
            .map(|(rule_row, diff_row)| {
                let roll: f64 = uniform.sample(&mut rng);
                if roll < self.weight {
                    diff_row.clone()
                } else {
                    rule_row.clone()
                }
            })
            .collect()
    }
}
#[cfg(test)]
// Kept from the original module: permits `unwrap` in test code in case the crate
// denies `clippy::unwrap_used` (NOTE(review): no test currently unwraps).
#[allow(clippy::unwrap_used)]
mod tests {
    // Locals are named `generator` rather than `gen`, because `gen` is a reserved
    // keyword starting with the Rust 2024 edition.
    use super::*;

    /// A weight of 0.5 yields the midpoint of each value pair.
    #[test]
    fn test_interpolation_produces_blended_output() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let diffusion = vec![vec![20.0, 40.0], vec![50.0, 60.0]];
        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        assert_eq!(blended.len(), 2);
        assert!((blended[0][0] - 15.0).abs() < 1e-10);
        assert!((blended[0][1] - 30.0).abs() < 1e-10);
        assert!((blended[1][0] - 40.0).abs() < 1e-10);
        assert!((blended[1][1] - 50.0).abs() < 1e-10);
    }

    /// With a weight of 0.5, row selection draws from both sources.
    #[test]
    fn test_select_picks_from_both_sources() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![0.0]; 1000];
        let diffusion = vec![vec![1.0]; 1000];
        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Select, 42);
        assert_eq!(blended.len(), 1000);
        let count_diffusion = blended.iter().filter(|r| r[0] > 0.5).count();
        let count_rule = blended.iter().filter(|r| r[0] < 0.5).count();
        assert!(
            count_diffusion > 100,
            "Expected diffusion picks, got {}",
            count_diffusion
        );
        assert!(
            count_rule > 100,
            "Expected rule-based picks, got {}",
            count_rule
        );
    }

    /// Only the listed columns are taken from the diffusion input.
    #[test]
    fn test_ensemble_uses_correct_columns() {
        let generator = HybridGenerator::new(0.5);
        let rules = vec![vec![1.0, 2.0, 3.0]];
        let diffusion = vec![vec![10.0, 20.0, 30.0]];
        let diffusion_cols = vec![1];
        let blended = generator.blend_ensemble(&rules, &diffusion, &diffusion_cols);
        assert_eq!(blended.len(), 1);
        assert!(
            (blended[0][0] - 1.0).abs() < 1e-10,
            "Column 0 should be rule-based"
        );
        assert!(
            (blended[0][1] - 20.0).abs() < 1e-10,
            "Column 1 should be diffusion"
        );
        assert!(
            (blended[0][2] - 3.0).abs() < 1e-10,
            "Column 2 should be rule-based"
        );
    }

    /// A weight of 0 reproduces the rule-based input.
    #[test]
    fn test_weight_zero_returns_rule_based() {
        let generator = HybridGenerator::new(0.0);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        for (rule_row, blend_row) in rules.iter().zip(blended.iter()) {
            for (&r, &b) in rule_row.iter().zip(blend_row.iter()) {
                assert!(
                    (r - b).abs() < 1e-10,
                    "weight=0 should return rule-based: {} vs {}",
                    r,
                    b
                );
            }
        }
    }

    /// A weight of 1 reproduces the diffusion input.
    #[test]
    fn test_weight_one_returns_diffusion() {
        let generator = HybridGenerator::new(1.0);
        let rules = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diffusion = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let blended = generator.blend(&rules, &diffusion, BlendStrategy::Interpolate, 0);
        for (diff_row, blend_row) in diffusion.iter().zip(blended.iter()) {
            for (&d, &b) in diff_row.iter().zip(blend_row.iter()) {
                assert!(
                    (d - b).abs() < 1e-10,
                    "weight=1 should return diffusion: {} vs {}",
                    d,
                    b
                );
            }
        }
    }

    /// Empty inputs produce empty output for both public entry points.
    #[test]
    fn test_empty_inputs() {
        let generator = HybridGenerator::new(0.5);
        let empty: Vec<Vec<f64>> = vec![];
        let result = generator.blend(&empty, &empty, BlendStrategy::Interpolate, 0);
        assert!(result.is_empty());
        let result = generator.blend_ensemble(&empty, &empty, &[0]);
        assert!(result.is_empty());
    }

    /// Out-of-range weights are clamped into `[0.0, 1.0]`.
    #[test]
    fn test_weight_clamping() {
        let gen_low = HybridGenerator::new(-0.5);
        assert!((gen_low.weight() - 0.0).abs() < 1e-10);
        let gen_high = HybridGenerator::new(1.5);
        assert!((gen_high.weight() - 1.0).abs() < 1e-10);
    }
}