use serde_json::Value;
use crate::core::adapter::Candidate;
use crate::core::data_loader::DataId;
use crate::core::state::GEPAState;
use crate::error::Result;
/// Strategy for choosing which named components should be returned for a
/// given candidate.
///
/// Implementors must be `Send + Sync`. The trait is generic over the data-id
/// type `Id` used by [`GEPAState`].
pub trait ComponentSelector<Id: DataId>: Send + Sync {
/// Returns the names of the selected components.
///
/// # Arguments
///
/// * `state` - Mutable optimizer state. Implementations may update it; for
///   example, `RoundRobinSelector` advances a per-candidate cursor stored in
///   the state.
/// * `trajectories` - JSON values made available to the selector. The
///   selectors defined in this file ignore them.
/// * `subsample_scores` - Scores made available to the selector. The
///   selectors defined in this file ignore them.
/// * `candidate_idx` - Index of the candidate; `RoundRobinSelector` uses it
///   to look up that candidate's cursor in `state`.
/// * `candidate` - The candidate itself; `AllComponentSelector` returns its
///   keys.
///
/// # Errors
///
/// The signature allows implementations to fail with the crate's error
/// type. Neither selector in this file constructs an error.
fn select_components(
&self,
state: &mut GEPAState<Id>,
trajectories: &[Value],
subsample_scores: &[f64],
candidate_idx: usize,
candidate: &Candidate,
) -> Result<Vec<String>>;
}
/// Selector that returns exactly one component name per call, stepping
/// through the state's named predictors in order and wrapping around at the
/// end. The position is tracked per candidate inside the optimizer state, so
/// the selector itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct RoundRobinSelector;
impl<Id: DataId> ComponentSelector<Id> for RoundRobinSelector {
    /// Picks the next named predictor for `candidate_idx` and advances that
    /// candidate's cursor by one, wrapping back to the first predictor after
    /// the last.
    ///
    /// Returns an empty list (and leaves `state` untouched) when the state
    /// has no named predictors; otherwise returns a single-element list.
    /// The trajectories, scores and candidate arguments are not consulted.
    ///
    /// # Panics
    ///
    /// Indexes `state.named_predictor_id_to_update_next` with
    /// `candidate_idx`, so an index without an entry there panics.
    fn select_components(
        &self,
        state: &mut GEPAState<Id>,
        _trajectories: &[Value],
        _subsample_scores: &[f64],
        candidate_idx: usize,
        _candidate: &Candidate,
    ) -> Result<Vec<String>> {
        match state.list_of_named_predictors.len() {
            // Nothing to rotate through.
            0 => Ok(Vec::new()),
            predictor_count => {
                let cursor = &mut state.named_predictor_id_to_update_next[candidate_idx];
                // Reduce first so a stored value beyond the current predictor
                // count still maps to a valid position.
                let current = *cursor % predictor_count;
                *cursor = (current + 1) % predictor_count;
                let chosen = state.list_of_named_predictors[current].clone();
                Ok(vec![chosen])
            }
        }
    }
}
/// Selector that returns every component name of the candidate it is given,
/// on every call. It keeps no data of its own.
#[derive(Debug, Clone, Copy, Default)]
pub struct AllComponentSelector;
impl<Id: DataId> ComponentSelector<Id> for AllComponentSelector {
    /// Returns a copy of every key of `candidate`, in the order the
    /// candidate's `keys()` iterator yields them.
    ///
    /// The state, trajectories, scores and candidate index are not consulted,
    /// and this implementation never returns an error.
    fn select_components(
        &self,
        _state: &mut GEPAState<Id>,
        _trajectories: &[Value],
        _subsample_scores: &[f64],
        _candidate_idx: usize,
        candidate: &Candidate,
    ) -> Result<Vec<String>> {
        let mut names = Vec::new();
        names.extend(candidate.keys().cloned());
        Ok(names)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::state::{FrontierType, ValsetEvaluation};
    use std::collections::HashSet;

    /// Wraps `seed` in a fresh state built from a single-entry
    /// `ValsetEvaluation`, so every test starts from the same fixture.
    fn state_from_seed(seed: Candidate) -> GEPAState<usize> {
        let eval = ValsetEvaluation::from_vecs(
            vec![0usize],
            vec![serde_json::json!("out")],
            vec![0.5],
            None,
        );
        GEPAState::new(seed, eval, FrontierType::Instance, None)
            .expect("construction should succeed")
    }

    /// Builds a state whose seed candidate has the three components
    /// `alpha`, `beta` and `gamma`.
    fn make_multi_component_state() -> GEPAState<usize> {
        let mut seed = Candidate::new();
        seed.insert("alpha".into(), "first".into());
        seed.insert("beta".into(), "second".into());
        seed.insert("gamma".into(), "third".into());
        state_from_seed(seed)
    }

    /// Runs one round-robin selection for `candidate_idx` with empty
    /// trajectories and scores, panicking if the selector reports an error.
    fn round_robin_pick(
        state: &mut GEPAState<usize>,
        candidate_idx: usize,
        candidate: &Candidate,
    ) -> Vec<String> {
        RoundRobinSelector
            .select_components(state, &[], &[], candidate_idx, candidate)
            .expect("should select")
    }

    #[test]
    fn round_robin_cycles_through_all_components() {
        let mut state = make_multi_component_state();
        let candidate = state.program_candidates[0].clone();
        let mut seen: HashSet<String> = HashSet::new();
        for _ in 0..3 {
            let names = round_robin_pick(&mut state, 0, &candidate);
            assert_eq!(names.len(), 1);
            seen.extend(names);
        }
        // Three calls over three components must visit each one once.
        assert_eq!(seen.len(), 3);
    }

    #[test]
    fn round_robin_wraps_around() {
        let mut state = make_multi_component_state();
        let candidate = state.program_candidates[0].clone();
        let mut calls: Vec<String> = Vec::new();
        for _ in 0..6 {
            let names = round_robin_pick(&mut state, 0, &candidate);
            calls.push(names[0].clone());
        }
        // The second lap must repeat the first lap exactly.
        assert_eq!(calls[0], calls[3]);
        assert_eq!(calls[1], calls[4]);
        assert_eq!(calls[2], calls[5]);
    }

    #[test]
    fn all_component_selector_returns_all_components() {
        let mut state = make_multi_component_state();
        let candidate = state.program_candidates[0].clone();
        let selector = AllComponentSelector;
        let names = selector
            .select_components(&mut state, &[], &[], 0, &candidate)
            .expect("should select");
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
        assert!(names.contains(&"gamma".to_string()));
    }

    #[test]
    fn test_round_robin_single_component() {
        let mut seed = Candidate::new();
        seed.insert("only".into(), "value".into());
        let mut state = state_from_seed(seed);
        let candidate = state.program_candidates[0].clone();
        // With a single component every call must return that component.
        for _ in 0..3 {
            let picked = round_robin_pick(&mut state, 0, &candidate);
            assert_eq!(picked, vec!["only".to_string()]);
        }
    }

    #[test]
    fn round_robin_per_candidate_independence() {
        let mut state = make_multi_component_state();
        let mut cand2 = Candidate::new();
        cand2.insert("alpha".into(), "a2".into());
        cand2.insert("beta".into(), "b2".into());
        cand2.insert("gamma".into(), "g2".into());
        let eval2 = ValsetEvaluation::from_vecs(
            vec![0usize],
            vec![serde_json::json!("out2")],
            vec![0.8],
            None,
        );
        state
            .update_state_with_new_program(vec![0], cand2, eval2, 1)
            .expect("update should succeed");
        let candidate0 = state.program_candidates[0].clone();
        let candidate1 = state.program_candidates[1].clone();

        // Advance candidate 0 twice; candidate 1 has not been selected yet.
        let _ = round_robin_pick(&mut state, 0, &candidate0);
        let _ = round_robin_pick(&mut state, 0, &candidate0);

        // Snapshot both cursors so the effect of the next call can be
        // attributed to candidate 1 alone.
        let predictor_count = state.list_of_named_predictors.len();
        let cursor0_before = state.named_predictor_id_to_update_next[0];
        let cursor1_before = state.named_predictor_id_to_update_next[1];
        let expected_name =
            state.list_of_named_predictors[cursor1_before % predictor_count].clone();

        let next_for_1 = round_robin_pick(&mut state, 1, &candidate1);

        // The pick follows candidate 1's own cursor ...
        assert_eq!(next_for_1, vec![expected_name]);
        // ... which advances by exactly one step ...
        assert_eq!(
            state.named_predictor_id_to_update_next[1],
            (cursor1_before % predictor_count + 1) % predictor_count
        );
        // ... while candidate 0's cursor is left untouched.
        assert_eq!(state.named_predictor_id_to_update_next[0], cursor0_before);
    }
}