use somatize_core::event::MetricRecord;
/// Strategy for deciding whether a trial should be stopped early based on an
/// intermediate metric value.
///
/// Implementors must be `Send + Sync`.
pub trait Pruner: Send + Sync {
/// Evaluates `current_value` of `metric_name` at `step` against `history`.
///
/// * `metric_name` - name of the metric being evaluated.
/// * `current_value` - the value under evaluation.
/// * `step` - the step at which `current_value` was reported.
/// * `history` - metric records of the trials to compare against.
///
/// The implementations in this file return `Some(reason)`, a human-readable
/// explanation, when the value falls below their threshold, and `None`
/// otherwise.
fn should_prune(
&self,
metric_name: &str,
current_value: f64,
step: usize,
history: &[TrialMetricHistory],
) -> Option<String>;
}
/// The metric records of a single trial, used as comparison data by
/// [`Pruner::should_prune`].
pub struct TrialMetricHistory {
/// Identifier of the trial these records belong to.
pub trial_id: String,
/// Metric records for the trial; the pruners in this file select entries by
/// metric name and step.
pub metrics: Vec<MetricRecord>,
}
/// Pruner that compares a value against the median of the values other
/// trials reported at the same step.
pub struct MedianPruner {
    /// Steps strictly below this number are never pruned.
    pub n_warmup_steps: usize,
    /// Minimum number of trials that must have a value at the step before a
    /// pruning decision is made.
    pub min_trials: usize,
}

impl MedianPruner {
    /// Creates a pruner with the given warm-up length and `min_trials` set to 1.
    pub fn new(n_warmup_steps: usize) -> Self {
        let min_trials = 1;
        Self {
            n_warmup_steps,
            min_trials,
        }
    }

    /// Returns the pruner with `min_trials` replaced, leaving every other
    /// setting untouched (builder style).
    pub fn with_min_trials(self, min_trials: usize) -> Self {
        Self { min_trials, ..self }
    }
}
impl Pruner for MedianPruner {
    /// Prunes when `current_value` is strictly below the median of the values
    /// recorded for `metric_name` at `step` across `history`.
    ///
    /// Returns `None` (keep the trial) when:
    /// * `step` is still inside the warm-up window (`step < n_warmup_steps`),
    /// * no trial has a value at `step`, or fewer than `min_trials` do,
    /// * `current_value` is not below the median.
    ///
    /// Otherwise returns `Some(reason)` with a human-readable explanation.
    ///
    /// If a trial holds several records for the same metric and step, the
    /// last one in its `metrics` list is used. For an even number of values
    /// the median is the mean of the two middle values.
    fn should_prune(
        &self,
        metric_name: &str,
        current_value: f64,
        step: usize,
        history: &[TrialMetricHistory],
    ) -> Option<String> {
        if step < self.n_warmup_steps {
            return None;
        }

        // One value per trial: the most recent record matching metric and step.
        let mut values_at_step: Vec<f64> = history
            .iter()
            .filter_map(|trial| {
                trial
                    .metrics
                    .iter()
                    .rev()
                    .find(|m| m.name == metric_name && m.step == step)
                    .map(|m| m.value)
            })
            .collect();

        // An empty sample has no median. Checking this explicitly keeps
        // `min_trials == 0` from reaching the index arithmetic below.
        if values_at_step.is_empty() || values_at_step.len() < self.min_trials {
            return None;
        }

        // `total_cmp` is a genuine total order even when NaN is present, so
        // the sort is well-defined (NaNs end up at the extremes).
        values_at_step.sort_unstable_by(f64::total_cmp);

        // For an odd count both indices coincide on the middle element; for an
        // even count they straddle the centre.
        let upper = values_at_step.len() / 2;
        let lower = (values_at_step.len() - 1) / 2;
        let median = if lower == upper {
            values_at_step[upper]
        } else {
            (values_at_step[lower] + values_at_step[upper]) / 2.0
        };

        (current_value < median).then(|| {
            format!("value {current_value:.4} below median {median:.4} at step {step}")
        })
    }
}
/// Pruner that compares a value against a percentile of the values other
/// trials reported at the same step.
pub struct PercentilePruner {
    /// Percentile used as the threshold, expressed on a 0–100 scale (it is
    /// divided by 100 when the threshold is located).
    pub percentile: f64,
    /// Steps strictly below this number are never pruned.
    pub n_warmup_steps: usize,
    /// Minimum number of trials that must have a value at the step before a
    /// pruning decision is made.
    pub min_trials: usize,
}

impl PercentilePruner {
    /// Creates a pruner with the given percentile and warm-up length;
    /// `min_trials` starts at 1.
    pub fn new(percentile: f64, n_warmup_steps: usize) -> Self {
        Self {
            percentile,
            n_warmup_steps,
            min_trials: 1,
        }
    }

    /// Returns the pruner with `min_trials` replaced (builder style),
    /// mirroring `MedianPruner::with_min_trials`.
    pub fn with_min_trials(mut self, min_trials: usize) -> Self {
        self.min_trials = min_trials;
        self
    }
}
impl Pruner for PercentilePruner {
    /// Prunes when `current_value` is strictly below the configured
    /// percentile of the values recorded for `metric_name` at `step` across
    /// `history`.
    ///
    /// Returns `None` (keep the trial) when:
    /// * `step` is still inside the warm-up window (`step < n_warmup_steps`),
    /// * no trial has a value at `step`, or fewer than `min_trials` do,
    /// * `current_value` is not below the threshold.
    ///
    /// Otherwise returns `Some(reason)` with a human-readable explanation.
    ///
    /// The threshold is the sorted value at index
    /// `floor(percentile / 100 * count)`, clamped to the last element. If a
    /// trial holds several records for the same metric and step, the last one
    /// in its `metrics` list is used.
    fn should_prune(
        &self,
        metric_name: &str,
        current_value: f64,
        step: usize,
        history: &[TrialMetricHistory],
    ) -> Option<String> {
        if step < self.n_warmup_steps {
            return None;
        }

        // One value per trial: the most recent record matching metric and step.
        let mut values_at_step: Vec<f64> = history
            .iter()
            .filter_map(|trial| {
                trial
                    .metrics
                    .iter()
                    .rev()
                    .find(|m| m.name == metric_name && m.step == step)
                    .map(|m| m.value)
            })
            .collect();

        // An empty sample has no percentile. Checking this explicitly keeps
        // `min_trials == 0` from underflowing `len() - 1` below.
        if values_at_step.is_empty() || values_at_step.len() < self.min_trials {
            return None;
        }

        // `total_cmp` is a genuine total order even when NaN is present, so
        // the sort is well-defined (NaNs end up at the extremes).
        values_at_step.sort_unstable_by(f64::total_cmp);

        let last = values_at_step.len() - 1;
        // Float-to-usize `as` saturates: a negative or NaN rank becomes 0 and
        // an oversized rank is clamped to the last element just below.
        let rank = ((self.percentile / 100.0) * values_at_step.len() as f64).floor() as usize;
        let threshold = values_at_step[rank.min(last)];

        (current_value < threshold).then(|| {
            format!(
                "value {current_value:.4} below p{:.0} threshold {threshold:.4} at step {step}",
                self.percentile
            )
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds one `TrialMetricHistory` per column of `values_per_step`.
    ///
    /// Row `s` holds the "f1" values reported at step `s`; column `t` belongs
    /// to trial `t`. The number of trials is taken from the first row, and a
    /// trial simply has no record for a step whose row is too short.
    fn make_history(values_per_step: &[Vec<f64>]) -> Vec<TrialMetricHistory> {
        let trial_count = values_per_step[0].len();
        let mut histories = Vec::with_capacity(trial_count);
        for trial_idx in 0..trial_count {
            let mut metrics = Vec::new();
            for (step, row) in values_per_step.iter().enumerate() {
                if let Some(&value) = row.get(trial_idx) {
                    metrics.push(MetricRecord {
                        name: "f1".into(),
                        value,
                        step,
                        timestamp: Utc::now(),
                    });
                }
            }
            histories.push(TrialMetricHistory {
                trial_id: format!("t{trial_idx}"),
                metrics,
            });
        }
        histories
    }

    #[test]
    fn median_no_prune_during_warmup() {
        let history = make_history(&[vec![0.9, 0.8, 0.7]]);
        let verdict = MedianPruner::new(5).should_prune("f1", 0.1, 3, &history);
        assert!(verdict.is_none());
    }

    #[test]
    fn median_prunes_below_median() {
        let history = make_history(&[vec![0.7, 0.8, 0.9]]);
        let verdict = MedianPruner::new(0).should_prune("f1", 0.5, 0, &history);
        assert!(verdict.is_some());
    }

    #[test]
    fn median_keeps_above_median() {
        let history = make_history(&[vec![0.7, 0.8, 0.9]]);
        let verdict = MedianPruner::new(0).should_prune("f1", 0.85, 0, &history);
        assert!(verdict.is_none());
    }

    #[test]
    fn median_no_prune_insufficient_history() {
        let history = make_history(&[vec![0.7, 0.8]]);
        let pruner = MedianPruner::new(0).with_min_trials(5);
        let verdict = pruner.should_prune("f1", 0.1, 0, &history);
        assert!(verdict.is_none());
    }

    #[test]
    fn median_empty_history() {
        let verdict = MedianPruner::new(0).should_prune("f1", 0.5, 0, &[]);
        assert!(verdict.is_none());
    }

    #[test]
    fn percentile_prunes_below_threshold() {
        let history = make_history(&[vec![0.5, 0.9, 0.3, 0.7]]);
        let verdict = PercentilePruner::new(25.0, 0).should_prune("f1", 0.2, 0, &history);
        assert!(verdict.is_some());
    }

    #[test]
    fn percentile_keeps_above_threshold() {
        let history = make_history(&[vec![0.5, 0.9, 0.3, 0.7]]);
        let verdict = PercentilePruner::new(25.0, 0).should_prune("f1", 0.6, 0, &history);
        assert!(verdict.is_none());
    }

    #[test]
    fn percentile_warmup_respected() {
        let history = make_history(&[vec![0.9]]);
        let verdict = PercentilePruner::new(50.0, 10).should_prune("f1", 0.1, 5, &history);
        assert!(verdict.is_none());
    }
}