use std::collections::VecDeque;
use super::WelfordRace;
/// Settings that control when and how strongly distillation is applied.
#[derive(Debug, Clone)]
pub struct DistillationConfig {
    /// How many consecutive distillation passes a candidate must spend off
    /// the Pareto front before it is flagged as distilling.
    pub trigger_after_dominated_samples: u64,
    /// Upper bound handed to each candidate's replay buffer when a new
    /// `(features, target)` pair is recorded.
    pub replay_buffer_size: usize,
    /// Sample weight used when the winner is trained on pseudo-targets
    /// produced by a distilling candidate.
    pub distill_weight: f64,
}
impl Default for DistillationConfig {
    /// Builds the stock configuration: trigger after 500 dominated samples,
    /// keep up to 1000 replay pairs, and weight distilled samples at 0.3.
    fn default() -> Self {
        let trigger_after_dominated_samples = 500;
        let replay_buffer_size = 1000;
        let distill_weight = 0.3;
        Self {
            trigger_after_dominated_samples,
            replay_buffer_size,
            distill_weight,
        }
    }
}
/// Read-only summary of distillation activity, as handed out by
/// `WelfordRace::distillation_stats`.
#[derive(Clone, Debug)]
pub struct DistillationStats {
    /// Set to `false` when distillation is enabled through
    /// `with_distillation`.
    pub disabled: bool,
    /// Incremented once for every distilling candidate whose non-empty
    /// replay buffer was replayed into the winner during a pass.
    pub n_distillations_triggered: u64,
    /// Sample count (`stats.n` of the first candidate) recorded at the most
    /// recent replay, or `None` if none has happened yet.
    pub last_distillation_at_samples: Option<u64>,
    /// Indices of the candidates flagged as distilling after the latest pass.
    pub candidates_currently_distilling: Vec<usize>,
}
/// Per-candidate bookkeeping used by the distillation pass.
pub(super) struct CandidateDistillState {
/// FIFO of recorded `(features, target)` pairs; `push` drops from the front
/// once the caller-supplied capacity is reached.
pub replay: VecDeque<(Vec<f64>, f64)>,
/// Number of consecutive distillation passes in which this candidate was
/// absent from the Pareto front; reset to zero whenever it is on the front.
pub samples_dominated: u64,
/// Set once `samples_dominated` reaches the configured trigger threshold,
/// and cleared again when the candidate is back on the Pareto front.
pub is_distilling: bool,
}
impl CandidateDistillState {
    /// Creates a state with an empty replay buffer, a zeroed dominated-sample
    /// counter, and the distilling flag cleared.
    pub fn new() -> Self {
        Self {
            replay: VecDeque::new(),
            samples_dominated: 0,
            is_distilling: false,
        }
    }

    /// Records a `(features, target)` pair, keeping at most `capacity` pairs.
    ///
    /// The oldest pairs are evicted first. A `capacity` of zero empties the
    /// buffer and stores nothing, so a zero-sized replay buffer really is
    /// empty rather than silently holding one sample.
    ///
    /// * `features` - feature vector of the sample; copied into the buffer.
    /// * `target` - target value observed for the sample.
    /// * `capacity` - maximum number of pairs retained after this call.
    pub fn push(&mut self, features: &[f64], target: f64, capacity: usize) {
        if capacity == 0 {
            self.replay.clear();
            return;
        }
        // Evict in a loop rather than with a single pop so the bound also
        // holds when `capacity` is smaller than it was on earlier calls.
        while self.replay.len() >= capacity {
            self.replay.pop_front();
        }
        self.replay.push_back((features.to_vec(), target));
    }
}
impl WelfordRace {
    /// Enables distillation with `cfg` and resets all distillation bookkeeping.
    ///
    /// `cfg.distill_weight` is forced into `[f64::MIN_POSITIVE, 1.0]` before it
    /// is stored; a NaN weight is mapped to the lower bound so it can never
    /// reach the winner's training call. One fresh per-candidate state is
    /// created for every candidate currently in the race, and the statistics
    /// are reset with `disabled` set to `false`.
    pub fn with_distillation(mut self, cfg: DistillationConfig) -> Self {
        // `max`/`min` return the non-NaN operand, whereas `clamp` would hand a
        // NaN straight back. For every non-NaN input the result is the same
        // as `clamp(f64::MIN_POSITIVE, 1.0)`.
        let weight = cfg.distill_weight.max(f64::MIN_POSITIVE).min(1.0);
        self.distill_cfg = Some(DistillationConfig {
            distill_weight: weight,
            ..cfg
        });
        self.distill_state = self
            .candidates
            .iter()
            .map(|_| CandidateDistillState::new())
            .collect();
        self.distill_stats = DistillationStats {
            disabled: false,
            n_distillations_triggered: 0,
            last_distillation_at_samples: None,
            candidates_currently_distilling: Vec::new(),
        };
        self
    }

    /// Returns a copy of the current distillation statistics.
    pub fn distillation_stats(&self) -> DistillationStats {
        self.distill_stats.clone()
    }

    /// Runs one distillation pass for a single observed sample.
    ///
    /// Does nothing when no distillation config is set. Otherwise, in order:
    ///
    /// 1. Grows the per-candidate state so every candidate has an entry.
    /// 2. Resets the dominated counter (and distilling flag) of every
    ///    candidate listed in `pareto_front`; increments it for all others and
    ///    flags them as distilling once the configured threshold is reached.
    /// 3. For every distilling candidate other than `winner_idx` that has a
    ///    non-empty replay buffer, replays the buffered features: the
    ///    candidate's own prediction is used as a pseudo-target on which the
    ///    winner is trained with the configured distillation weight. Each such
    ///    replay bumps `n_distillations_triggered` and records the first
    ///    candidate's `stats.n` in `last_distillation_at_samples`.
    /// 4. Records `(features, target)` in every candidate's replay buffer.
    /// 5. Refreshes `candidates_currently_distilling`.
    ///
    /// * `features` / `target` - the sample being recorded for future replay.
    /// * `pareto_front` - indices of the candidates currently on the front.
    /// * `winner_idx` - index of the candidate that receives distilled training.
    ///
    /// # Panics
    ///
    /// Panics if a replay is actually performed while `winner_idx` is not a
    /// valid index into the candidate list.
    pub(super) fn run_distillation_pass(
        &mut self,
        features: &[f64],
        target: f64,
        pareto_front: &[usize],
        winner_idx: usize,
    ) {
        // Copy the scalar settings out so no borrow of the config is held
        // while the state is mutated (and no clone of the struct is needed).
        let (trigger_after, buffer_size, weight) = match &self.distill_cfg {
            Some(cfg) => (
                cfg.trigger_after_dominated_samples,
                cfg.replay_buffer_size,
                cfg.distill_weight,
            ),
            None => return,
        };

        let n = self.candidates.len();
        // Only ever grow: `resize_with` would otherwise truncate extra entries.
        if self.distill_state.len() < n {
            self.distill_state.resize_with(n, CandidateDistillState::new);
        }

        for (i, state) in self.distill_state.iter_mut().enumerate().take(n) {
            if pareto_front.contains(&i) {
                state.samples_dominated = 0;
                state.is_distilling = false;
            } else {
                state.samples_dominated += 1;
                if state.samples_dominated >= trigger_after {
                    state.is_distilling = true;
                }
            }
        }

        for i in 0..n {
            if i == winner_idx || !self.distill_state[i].is_distilling {
                continue;
            }
            if self.distill_state[i].replay.is_empty() {
                continue;
            }
            // The buffer is borrowed in place instead of being deep-cloned on
            // every pass: `distill_state` and `candidates` are separate
            // fields, so reading one while training through the other is fine.
            for (feat, _orig_target) in &self.distill_state[i].replay {
                let pseudo_target = self.candidates[i].model.predict(feat);
                self.candidates[winner_idx]
                    .model
                    .train_one_weighted(feat, pseudo_target, weight);
            }
            let n_samples_now = self.candidates.first().map(|c| c.stats.n).unwrap_or(0);
            self.distill_stats.n_distillations_triggered += 1;
            self.distill_stats.last_distillation_at_samples = Some(n_samples_now);
        }

        // Record the sample only after replaying, so the current sample is
        // never part of the buffer distilled during this same pass.
        for state in self.distill_state.iter_mut().take(n) {
            state.push(features, target, buffer_size);
        }

        self.distill_stats.candidates_currently_distilling = self
            .distill_state
            .iter()
            .take(n)
            .enumerate()
            .filter(|(_, state)| state.is_distilling)
            .map(|(i, _)| i)
            .collect();
    }
}