use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use crate::optimizer::gepa::GEPACandidate;
/// A collection of [`GEPACandidate`]s tracked per example index.
///
/// For every example index the frontier records which candidate ids hold the
/// best (or tied) score, and a candidate is kept only while it is listed for
/// at least one example.
#[derive(Debug, Clone)]
pub struct ParetoFrontier {
/// Candidates currently retained on the frontier.
candidates: Vec<GEPACandidate>,
/// Example index -> ids of the candidates recorded as best (or tied) on it.
example_to_best: HashMap<usize, Vec<usize>>,
/// Candidate id -> example indices recorded as wins for that candidate;
/// the set size is used as the candidate's coverage.
candidate_to_examples: HashMap<usize, HashSet<usize>>,
/// Id that will be assigned to the next candidate passed to `add_candidate`.
next_id: usize,
}
impl ParetoFrontier {
    /// Two scores closer than this are treated as tied.
    const SCORE_EPSILON: f32 = 1e-6;

    /// Creates an empty frontier whose first assigned candidate id is `0`.
    pub fn new() -> Self {
        Self {
            candidates: Vec::new(),
            example_to_best: HashMap::new(),
            candidate_to_examples: HashMap::new(),
            next_id: 0,
        }
    }

    /// Number of candidates currently retained.
    pub fn len(&self) -> usize {
        self.candidates.len()
    }

    /// Returns `true` when no candidate is retained.
    pub fn is_empty(&self) -> bool {
        self.candidates.is_empty()
    }

    /// Read-only view of the retained candidates, in insertion order.
    pub fn candidates(&self) -> &[GEPACandidate] {
        &self.candidates
    }

    /// Offers `candidate` to the frontier with its per-example `scores`.
    ///
    /// `scores[i]` is the candidate's score on example index `i`. The candidate
    /// is accepted when, on at least one example, its score beats or ties
    /// (within [`Self::SCORE_EPSILON`]) the best score recorded so far, or the
    /// example has no recorded best yet. NaN scores never count as a win, so a
    /// NaN cannot claim an example and block later real scores.
    ///
    /// On acceptance the candidate's `id` and `example_scores` are overwritten,
    /// incumbents it strictly beats are dropped from those examples, and any
    /// incumbent left without a single winning example is removed.
    ///
    /// Returns `true` if the candidate was added, `false` otherwise (including
    /// when `scores` is empty). An id is consumed from the counter either way.
    pub fn add_candidate(&mut self, mut candidate: GEPACandidate, scores: &[f32]) -> bool {
        candidate.id = self.next_id;
        self.next_id += 1;

        let wins_on: HashSet<usize> = scores
            .iter()
            .enumerate()
            .filter(|&(example_idx, &score)| self.wins_example(example_idx, score))
            .map(|(example_idx, _)| example_idx)
            .collect();

        if wins_on.is_empty() {
            return false;
        }

        candidate.example_scores = scores.to_vec();

        for &example_idx in &wins_on {
            let new_score = scores[example_idx];
            // Borrow the two fields separately so the retain closure can read
            // candidate scores while the best-id list is being mutated.
            let candidates = &self.candidates;
            let best_ids = self.example_to_best.entry(example_idx).or_default();
            // Keep only incumbents that still tie or beat the newcomer here.
            best_ids.retain(|&id| match Self::score_in(candidates, id, example_idx) {
                Some(existing) => Self::ties_or_beats(existing, new_score),
                None => false,
            });
            best_ids.push(candidate.id);
        }

        // The new id is already present in `example_to_best`, so pruning (which
        // also rebuilds coverage) accounts for it before the push below.
        self.prune_dominated();
        self.candidates.push(candidate);
        true
    }

    /// Drops candidates that are no longer best on any example and rebuilds
    /// `candidate_to_examples` from `example_to_best`.
    ///
    /// Rebuilding keeps coverage in sync: an incumbent that lost some (but not
    /// all) of its examples must not keep being credited for the lost ones.
    fn prune_dominated(&mut self) {
        let mut coverage: HashMap<usize, HashSet<usize>> = HashMap::new();
        for (&example_idx, best_ids) in &self.example_to_best {
            for &id in best_ids {
                coverage.entry(id).or_default().insert(example_idx);
            }
        }
        self.candidates.retain(|c| coverage.contains_key(&c.id));
        self.candidate_to_examples = coverage;
    }

    /// `true` when `score` is greater than, equal to, or within
    /// [`Self::SCORE_EPSILON`] of `reference`.
    ///
    /// The explicit `>=` also covers equal infinite scores, whose difference is
    /// NaN and would otherwise fail the epsilon test.
    fn ties_or_beats(score: f32, reference: f32) -> bool {
        score >= reference || (score - reference).abs() < Self::SCORE_EPSILON
    }

    /// Score that the candidate with `id` recorded for `example_idx`, if that
    /// candidate is in `candidates` and has a score at that index.
    fn score_in(candidates: &[GEPACandidate], id: usize, example_idx: usize) -> Option<f32> {
        candidates
            .iter()
            .find(|c| c.id == id)?
            .example_scores
            .get(example_idx)
            .copied()
    }

    /// Highest score among the candidates recorded as best on `example_idx`,
    /// or `None` when the example has no usable record yet.
    fn best_score_for(&self, example_idx: usize) -> Option<f32> {
        self.example_to_best
            .get(&example_idx)?
            .iter()
            .filter_map(|&id| Self::score_in(&self.candidates, id, example_idx))
            // `f32::max` ignores a NaN operand, so this cannot panic.
            .reduce(f32::max)
    }

    /// Whether `score` would put a new candidate on the frontier for
    /// `example_idx`.
    fn wins_example(&self, example_idx: usize, score: f32) -> bool {
        if score.is_nan() {
            return false;
        }
        match self.best_score_for(example_idx) {
            Some(best) => Self::ties_or_beats(score, best),
            None => true,
        }
    }

    /// Number of examples currently credited to the candidate with `id`.
    fn coverage_of(&self, id: usize) -> usize {
        self.candidate_to_examples.get(&id).map_or(0, HashSet::len)
    }

    /// Picks a random candidate with probability proportional to its coverage
    /// (the number of examples it is credited with).
    ///
    /// Returns `None` for an empty frontier. If every candidate has zero
    /// coverage the first candidate is returned.
    pub fn sample_proportional_to_coverage(&self) -> Option<&GEPACandidate> {
        if self.candidates.is_empty() {
            return None;
        }

        let coverages: Vec<usize> = self
            .candidates
            .iter()
            .map(|c| self.coverage_of(c.id))
            .collect();
        let total_coverage: usize = coverages.iter().sum();
        if total_coverage == 0 {
            return self.candidates.first();
        }

        // Draw a point in [0, total) and walk the cumulative coverage.
        let mut rng = rand::thread_rng();
        let mut target = rng.gen_range(0..total_coverage);
        for (candidate, &coverage) in self.candidates.iter().zip(&coverages) {
            if target < coverage {
                return Some(candidate);
            }
            target -= coverage;
        }

        // Not expected to be reached because `target < total_coverage`; kept
        // as a safe fallback.
        self.candidates.last()
    }

    /// Returns the candidate with the highest `average_score()`, or `None`
    /// when the frontier is empty.
    ///
    /// A NaN average ranks below every real average instead of panicking.
    /// Among equal averages the later candidate wins (`Iterator::max_by`).
    pub fn best_by_average(&self) -> Option<&GEPACandidate> {
        self.candidates.iter().max_by(|a, b| {
            let (avg_a, avg_b) = (a.average_score(), b.average_score());
            avg_a
                .partial_cmp(&avg_b)
                // `None` means at least one side is NaN: order the NaN side lower.
                .unwrap_or_else(|| avg_b.is_nan().cmp(&avg_a.is_nan()))
        })
    }

    /// Summarizes the frontier: candidate count, number of example indices with
    /// a recorded best, and mean / max / min coverage per candidate (all zero
    /// when there are no candidates).
    pub fn statistics(&self) -> ParetoStatistics {
        let coverage_per_candidate: Vec<usize> = self
            .candidates
            .iter()
            .map(|c| self.coverage_of(c.id))
            .collect();

        let avg_coverage = if coverage_per_candidate.is_empty() {
            0.0
        } else {
            coverage_per_candidate.iter().sum::<usize>() as f32
                / coverage_per_candidate.len() as f32
        };

        ParetoStatistics {
            num_candidates: self.candidates.len(),
            num_examples_covered: self.example_to_best.len(),
            avg_coverage,
            max_coverage: coverage_per_candidate.iter().copied().max().unwrap_or(0),
            min_coverage: coverage_per_candidate.iter().copied().min().unwrap_or(0),
        }
    }
}
impl Default for ParetoFrontier {
fn default() -> Self {
Self::new()
}
}
/// Summary figures describing a `ParetoFrontier`, as returned by its
/// `statistics` method.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParetoStatistics {
/// Number of candidates held by the frontier.
pub num_candidates: usize,
/// Number of example indices that have a recorded best-candidate entry.
pub num_examples_covered: usize,
/// Mean coverage (examples credited) per candidate; `0.0` with no candidates.
pub avg_coverage: f32,
/// Largest coverage of any candidate; `0` with no candidates.
pub max_coverage: usize,
/// Smallest coverage of any candidate; `0` with no candidates.
pub min_coverage: usize,
}