use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use rand::Rng;
use rand::rngs::StdRng;
use crate::core::state::ProgramIdx;
use crate::error::{GEPAError, Result};
/// Returns `true` when program `y` is covered on every frontier it belongs to.
///
/// A frontier (a value of `pareto_front_mapping`) that contains `y` counts as
/// covered when at least one of its programs is a member of `programs`. If
/// `programs` itself contains `y`, then `y` counts as covering its own
/// frontiers, so callers normally pass a set that excludes `y`.
///
/// A program that appears on no frontier at all is reported as *not*
/// dominated (rather than vacuously dominated).
pub fn is_dominated<Key>(
    y: ProgramIdx,
    programs: &HashSet<ProgramIdx>,
    pareto_front_mapping: &HashMap<Key, HashSet<ProgramIdx>>,
) -> bool
where
    Key: Eq + Hash,
{
    let mut fronts_with_y = pareto_front_mapping
        .values()
        .filter(|front| front.contains(&y))
        .peekable();

    // `all` on an empty iterator would return true; an unplaced program is not dominated.
    if fronts_with_y.peek().is_none() {
        return false;
    }

    // Every frontier holding `y` must share at least one program with `programs`.
    fronts_with_y.all(|front| !front.is_disjoint(programs))
}
pub fn remove_dominated_programs<Key>(
pareto_front_mapping: &HashMap<Key, HashSet<ProgramIdx>>,
scores: Option<&[f64]>,
) -> HashMap<Key, HashSet<ProgramIdx>>
where
Key: Eq + Hash + Clone,
{
let mut freq: HashMap<ProgramIdx, usize> = HashMap::new();
for front in pareto_front_mapping.values() {
for &p in front {
*freq.entry(p).or_insert(0) += 1;
}
}
let mut all_programs: Vec<ProgramIdx> = freq.keys().copied().collect();
all_programs.sort_by(|&a, &b| {
let sa = scores.and_then(|s| s.get(a)).copied().unwrap_or(1.0);
let sb = scores.and_then(|s| s.get(b)).copied().unwrap_or(1.0);
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
});
let mut dominated: HashSet<ProgramIdx> = HashSet::new();
let mut found_to_remove = true;
while found_to_remove {
found_to_remove = false;
for &y in &all_programs {
if dominated.contains(&y) {
continue;
}
let pool: HashSet<ProgramIdx> = all_programs
.iter()
.copied()
.filter(|&p| p != y && !dominated.contains(&p))
.collect();
if is_dominated(y, &pool, pareto_front_mapping) {
dominated.insert(y);
found_to_remove = true;
break; }
}
}
let dominators: HashSet<ProgramIdx> = all_programs
.iter()
.copied()
.filter(|p| !dominated.contains(p))
.collect();
let new_mapping: HashMap<Key, HashSet<ProgramIdx>> = pareto_front_mapping
.iter()
.map(|(key, front)| {
let filtered: HashSet<ProgramIdx> = front
.iter()
.copied()
.filter(|p| dominators.contains(p))
.collect();
(key.clone(), filtered)
})
.collect();
for (key, front) in pareto_front_mapping {
if !front.is_empty() {
debug_assert!(
new_mapping.get(key).is_some_and(|f| !f.is_empty()),
"Invariant violated: a non-empty frontier key lost all its programs \
after domination removal."
);
}
}
new_mapping
}
/// Returns the programs that survive domination pruning, sorted ascending with
/// duplicates removed.
///
/// Runs [`remove_dominated_programs`] with `scores` and gathers every program
/// still present on at least one frontier of the pruned mapping.
pub fn find_dominator_programs<Key>(
    pareto_front_programs: &HashMap<Key, HashSet<ProgramIdx>>,
    scores: &[f64],
) -> Vec<ProgramIdx>
where
    Key: Eq + Hash + Clone,
{
    let pruned = remove_dominated_programs(pareto_front_programs, Some(scores));

    // A program can sit on several frontiers; sort then dedup to list it once.
    let mut survivors: Vec<ProgramIdx> = pruned.into_values().flatten().collect();
    survivors.sort_unstable();
    survivors.dedup();
    survivors
}
/// Samples one program from the Pareto frontier after pruning dominated ones.
///
/// Dominated programs are first removed with [`remove_dominated_programs`].
/// Each surviving program is then drawn with probability proportional to the
/// number of frontiers it still appears on.
///
/// The draw depends only on the frontier contents and the state of `rng`:
/// candidates are placed in a canonical (sorted) order before sampling, so a
/// given seed always yields the same program.
///
/// # Errors
/// Returns [`GEPAError::EmptyFrontier`] if `pareto_front_programs` is empty or
/// no program remains on any frontier after pruning.
pub fn select_program_candidate_from_pareto_front<Key>(
    pareto_front_programs: &HashMap<Key, HashSet<ProgramIdx>>,
    scores: &[f64],
    rng: &mut StdRng,
) -> Result<ProgramIdx>
where
    Key: Eq + Hash + Clone,
{
    if pareto_front_programs.is_empty() {
        return Err(GEPAError::EmptyFrontier);
    }
    let filtered_mapping = remove_dominated_programs(pareto_front_programs, Some(scores));

    // One entry per (frontier, surviving program) pair, so a program's weight
    // equals the number of frontiers it is on.
    let mut sampling_list: Vec<ProgramIdx> =
        filtered_mapping.values().flatten().copied().collect();
    if sampling_list.is_empty() {
        return Err(GEPAError::EmptyFrontier);
    }
    // HashMap/HashSet iteration order is randomized per instance; without a
    // canonical order the same seeded `rng` could select different programs.
    sampling_list.sort_unstable();

    let chosen_idx = rng.gen_range(0..sampling_list.len());
    Ok(sampling_list[chosen_idx])
}
/// Returns the index of the largest value in `lst`, preferring the earliest
/// index when the maximum occurs more than once.
///
/// Values are compared with `>`. A NaN therefore never replaces the current
/// best, and no later value replaces a leading NaN.
///
/// # Errors
/// Returns [`GEPAError::NoCandidates`] if `lst` is empty.
pub fn idxmax(lst: &[f64]) -> Result<usize> {
    let Some((&head, tail)) = lst.split_first() else {
        return Err(GEPAError::NoCandidates);
    };

    let mut best_idx = 0;
    let mut best_val = head;
    for (offset, &candidate) in tail.iter().enumerate() {
        // Strictly greater: ties keep the earlier index.
        if candidate > best_val {
            best_idx = offset + 1;
            best_val = candidate;
        }
    }
    Ok(best_idx)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Builds a frontier mapping from `(key, programs)` pairs.
    fn make_frontier(entries: &[(u32, &[ProgramIdx])]) -> HashMap<u32, HashSet<ProgramIdx>> {
        entries
            .iter()
            .map(|(k, idxs)| (*k, idxs.iter().copied().collect()))
            .collect()
    }

    #[test]
    fn is_dominated_when_fully_covered() {
        let frontier = make_frontier(&[(0, &[0, 1]), (1, &[0, 1])]);
        let pool: HashSet<ProgramIdx> = [1].into();
        assert!(is_dominated(0, &pool, &frontier));
    }

    #[test]
    fn not_dominated_when_unique_on_one_key() {
        let frontier = make_frontier(&[(0, &[0, 1]), (1, &[0])]);
        let pool: HashSet<ProgramIdx> = [1].into();
        assert!(!is_dominated(0, &pool, &frontier));
    }

    #[test]
    fn not_dominated_when_not_in_any_front() {
        let frontier = make_frontier(&[(0, &[0, 1])]);
        let pool: HashSet<ProgramIdx> = [0, 1].into();
        assert!(!is_dominated(99, &pool, &frontier));
    }

    #[test]
    fn remove_dominated_programs_eliminates_weak_fully_covered_program() {
        let frontier = make_frontier(&[(0, &[0, 1]), (1, &[2]), (2, &[1])]);
        let scores = vec![0.2, 0.7, 0.9];
        let result = remove_dominated_programs(&frontier, Some(&scores));
        let surviving: HashSet<ProgramIdx> = result.values().flatten().copied().collect();
        assert!(
            !surviving.contains(&0),
            "program 0 should be dominated and removed"
        );
        assert!(
            surviving.contains(&1),
            "program 1 should survive (unique on key 2)"
        );
        assert!(
            surviving.contains(&2),
            "program 2 should survive (unique on key 1)"
        );
    }

    #[test]
    fn remove_dominated_preserves_all_when_no_one_is_dominated() {
        let frontier = make_frontier(&[(0, &[0]), (1, &[1]), (2, &[2])]);
        let scores = vec![0.5, 0.6, 0.7];
        let result = remove_dominated_programs(&frontier, Some(&scores));
        let surviving: HashSet<ProgramIdx> = result.values().flatten().copied().collect();
        assert_eq!(surviving, HashSet::from([0, 1, 2]));
    }

    #[test]
    fn remove_dominated_programs_single_program_survives() {
        let frontier = make_frontier(&[(0, &[0])]);
        let scores = vec![0.5];
        let result = remove_dominated_programs(&frontier, Some(&scores));
        let surviving: HashSet<ProgramIdx> = result.values().flatten().copied().collect();
        assert_eq!(surviving, HashSet::from([0]));
    }

    #[test]
    fn find_dominator_programs_returns_non_dominated_set() {
        let frontier = make_frontier(&[(0, &[0, 1]), (1, &[1, 2])]);
        let scores = vec![0.2, 0.9, 0.5];
        let dominators = find_dominator_programs(&frontier, &scores);
        assert_eq!(dominators, vec![1], "only program 1 should survive");
    }

    #[test]
    fn selection_returns_valid_program_index() {
        let frontier = make_frontier(&[(0, &[0, 1]), (1, &[1])]);
        let scores = vec![0.5, 0.9];
        let mut rng = StdRng::seed_from_u64(42);
        let selected = select_program_candidate_from_pareto_front(&frontier, &scores, &mut rng)
            .expect("selection should succeed");
        assert!(
            selected == 0 || selected == 1,
            "selected program must be in frontier"
        );
    }

    #[test]
    fn selection_is_deterministic_with_same_seed() {
        // NOTE: program 0 is dominated here, so only program 1 survives pruning;
        // this checks the seeded plumbing rather than the ordering of several survivors.
        let frontier = make_frontier(&[(0, &[0, 1]), (1, &[1])]);
        let scores = vec![0.5, 0.9];
        let mut rng1 = StdRng::seed_from_u64(1234);
        let mut rng2 = StdRng::seed_from_u64(1234);
        let sel1 =
            select_program_candidate_from_pareto_front(&frontier, &scores, &mut rng1).unwrap();
        let sel2 =
            select_program_candidate_from_pareto_front(&frontier, &scores, &mut rng2).unwrap();
        assert_eq!(sel1, sel2, "same seed should produce same selection");
    }

    #[test]
    fn selection_on_empty_frontier_returns_error() {
        let frontier: HashMap<u32, HashSet<ProgramIdx>> = HashMap::new();
        let scores = vec![];
        let mut rng = StdRng::seed_from_u64(0);
        let result = select_program_candidate_from_pareto_front(&frontier, &scores, &mut rng);
        assert!(
            matches!(result, Err(GEPAError::EmptyFrontier)),
            "expected EmptyFrontier error"
        );
    }

    #[test]
    fn frequency_weighting_favours_high_coverage_program() {
        // Program 0 is the only winner on key 0 and program 1 the only winner on
        // keys 1..=4, so neither is dominated and both survive pruning. Sampling
        // is proportional to frontier coverage: program 1 should be drawn about
        // 4/5 of the time, and program 0 must still be drawn occasionally.
        let frontier = make_frontier(&[(0, &[0]), (1, &[1]), (2, &[1]), (3, &[1]), (4, &[1])]);
        let scores = vec![0.5, 0.9];
        let mut rng = StdRng::seed_from_u64(999);
        let trials = 500u32;
        let mut count_1 = 0u32;
        for _ in 0..trials {
            let selected =
                select_program_candidate_from_pareto_front(&frontier, &scores, &mut rng).unwrap();
            assert!(selected <= 1, "selected program {selected} is not on the frontier");
            if selected == 1 {
                count_1 += 1;
            }
        }
        let fraction_1 = f64::from(count_1) / f64::from(trials);
        assert!(
            fraction_1 > 0.60,
            "expected program 1 to dominate selection (got {fraction_1:.2})"
        );
        assert!(
            count_1 < trials,
            "program 0 survives pruning and should be selected at least once"
        );
    }

    #[test]
    fn idxmax_returns_first_occurrence_of_max() {
        assert_eq!(idxmax(&[0.1, 0.9, 0.9]).unwrap(), 1);
        assert_eq!(idxmax(&[3.0, 1.0, 2.0]).unwrap(), 0);
    }

    #[test]
    fn idxmax_on_empty_returns_error() {
        assert!(matches!(idxmax(&[]), Err(GEPAError::NoCandidates)));
    }

    #[test]
    fn test_cascading_domination() {
        // C (lowest score) is examined first and is covered by A and B on its only
        // frontier; once C is gone, B is covered by A on both frontiers.
        let frontier = make_frontier(&[(0, &[0, 1]), (1, &[0, 1, 2])]);
        let scores = vec![0.9_f64, 0.5, 0.3];
        let result = remove_dominated_programs(&frontier, Some(&scores));
        let surviving: HashSet<ProgramIdx> = result.values().flatten().copied().collect();
        assert!(
            surviving.contains(&0),
            "A should survive (highest scorer, dominates all)"
        );
        assert!(!surviving.contains(&1), "B should be dominated by A");
        assert!(
            !surviving.contains(&2),
            "C should be dominated (its only frontier is covered by A and B)"
        );
    }

    #[test]
    fn test_remove_dominated_with_none_scores() {
        let frontier = make_frontier(&[(0, &[0, 1]), (1, &[0])]);
        let result = remove_dominated_programs(&frontier, None);
        let surviving: HashSet<ProgramIdx> = result.values().flatten().copied().collect();
        assert!(
            surviving.contains(&0),
            "program 0 must survive (unique on key 1)"
        );
    }

    #[test]
    fn test_selection_with_all_zero_scores() {
        let frontier = make_frontier(&[(0, &[0, 1, 2])]);
        let scores = vec![0.0_f64, 0.0, 0.0];
        let mut rng = StdRng::seed_from_u64(7);
        let selected = select_program_candidate_from_pareto_front(&frontier, &scores, &mut rng)
            .expect("should select even with all-zero scores");
        assert!(
            selected <= 2,
            "selected program {selected} must be a valid index (0..=2)"
        );
    }
}