use crate::ast_mcts::{AstMctsOracle, MctsResult, RuleId, mcts_search};
use crate::evolution::{Chromosome, GenePool, population::random_chromosome};
use crate::lineage::Lineage;
use crate::search::{EvalCandidate, SearchAlgorithm, fitness_cmp};
use crate::types::{Budget, EvolutionError, OracleVerdict, SearchStats};
use rand::RngCore;
use rand::rngs::StdRng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Budget that `AstMctsAlgorithm::new` stores as `mcts_budget`; the stored value is forwarded
/// to `mcts_search` each time candidates are regenerated.
pub const DEFAULT_MCTS_BUDGET: u64 = 64;
/// Value that `AstMctsAlgorithm::new` stores as `ucb1_c` (the square root of two).
/// NOTE(review): presumably the UCB1 exploration constant — confirm against `mcts_search`.
pub const DEFAULT_UCB1_C: f64 = std::f64::consts::SQRT_2;
/// Oracle passed to `mcts_search` that records every candidate string it is asked to evaluate.
struct InlineOracle<'a> {
/// Sink that receives an owned copy of each candidate handed to `eval`.
candidates: &'a mut Vec<String>,
/// Copy of the algorithm's `bypass_found` flag taken when the oracle is constructed.
prior_bypass: bool,
/// Scratch state that is stepped on every `eval` call; seeded from the caller's RNG.
jitter: u64,
}
/// Recording oracle: every candidate proposed by the tree search is captured and accepted.
impl AstMctsOracle for InlineOracle<'_> {
    /// Records `candidate`, advances the jitter state, and reports the candidate as accepted.
    ///
    /// The result is always `true`. The previous implementation branched on `prior_bypass`
    /// but returned `true` from both arms, so the redundant branch has been collapsed
    /// without changing behavior.
    fn eval(&mut self, candidate: &str) -> bool {
        self.candidates.push(candidate.to_owned());
        // One wrapping multiply-add (linear-congruential) step so the state differs per call;
        // the constants match Knuth's MMIX generator.
        self.jitter = self
            .jitter
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // NOTE(review): `prior_bypass` has never influenced the verdict. It is still read here
        // so the field stays live; if an earlier bypass was meant to change acceptance,
        // this is the place to implement it.
        let _prior_bypass = self.prior_bypass;
        true
    }
}
/// Search algorithm that generates candidates by running `mcts_search` over the best payload
/// found so far and tracks the best chromosome across evaluation rounds.
///
/// The type is serializable so it can be checkpointed; the two work queues (`in_flight`,
/// `pending`) are skipped by serde and therefore come back empty after a restore.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstMctsAlgorithm {
/// Best chromosome so far; replaced when a submitted candidate passes or has higher fitness.
best: Chromosome,
/// Gene pool used to draw random chromosomes while no payload is available.
gene_pool: GenePool,
/// Round counter; incremented (saturating) once per `submit_evaluations` call and stamped
/// onto the lineage of generated candidates.
generation: u32,
/// Saturating counter whose current value is used as the id of each queued candidate.
eval_counter: u64,
/// Payload string used as the starting point for `mcts_search`; empty until a seed or an
/// accepted candidate provides a non-empty payload.
best_payload: String,
/// Set once any submitted verdict has `passed == true`; makes `should_terminate` return true.
bypass_found: bool,
/// Per-rule accumulators keyed by `action.rule.0`: (summed visits, summed
/// `mean_reward * visits`). Falls back to an empty map when missing from a checkpoint.
#[serde(default)]
rule_stats: HashMap<u8, (u64, f64)>,
/// Candidates handed out by `request_evaluations` that still await a verdict, keyed by id.
/// Not serialized.
#[serde(skip)]
in_flight: HashMap<u64, Chromosome>,
/// Budget argument forwarded to `mcts_search`.
mcts_budget: u64,
/// Second tuning argument forwarded to `mcts_search`.
/// NOTE(review): presumably the UCB1 exploration constant — confirm against `mcts_search`.
ucb1_c: f64,
/// Generated `(id, chromosome)` pairs that have not been handed out yet. Not serialized.
#[serde(skip)]
pending: Vec<(u64, Chromosome)>,
}
impl AstMctsAlgorithm {
/// Creates an algorithm configured with `DEFAULT_MCTS_BUDGET` and `DEFAULT_UCB1_C`.
#[must_use]
pub fn new() -> Self {
Self::with_config(DEFAULT_MCTS_BUDGET, DEFAULT_UCB1_C)
}
/// Creates an algorithm with an explicit `mcts_budget` and `ucb1_c`.
///
/// Both values are stored unchanged and later forwarded to `mcts_search`. All other state
/// starts empty: the best chromosome is a placeholder whose `ast_mcts_payload` gene is the
/// empty string, and no bypass has been recorded.
#[must_use]
pub fn with_config(mcts_budget: u64, ucb1_c: f64) -> Self {
    // Placeholder until `initialize` or an accepted candidate supplies a real chromosome.
    let placeholder = Chromosome::new(vec![("ast_mcts_payload".into(), String::new())]);
    Self {
        mcts_budget,
        ucb1_c,
        best: placeholder,
        best_payload: String::new(),
        bypass_found: false,
        gene_pool: GenePool::default_wafrift(),
        generation: 0,
        eval_counter: 0,
        rule_stats: HashMap::new(),
        pending: Vec::new(),
        in_flight: HashMap::new(),
    }
}
/// Extracts the payload carried by `c`.
///
/// Prefers the `ast_mcts_payload` gene, falls back to the `payload` gene, and yields the
/// empty string when neither is present.
fn payload_from_chromosome(c: &Chromosome) -> &str {
    match c.gene("ast_mcts_payload") {
        Some(own) => own,
        None => c.gene("payload").unwrap_or(""),
    }
}
/// Refills `pending` with new candidates.
///
/// * With no payload to mutate, queues `n` random chromosomes drawn from the gene pool, each
///   carrying an empty `ast_mcts_payload` gene and a genesis lineage.
/// * Otherwise runs `mcts_search` from `best_payload`, folds the returned arm statistics into
///   `rule_stats`, queues the search's best payload, and then queues the distinct non-empty
///   payloads the search asked the oracle about until `pending` holds `n` entries.
/// * If nothing was queued at all, re-queues the current best so the caller still gets work.
fn replenish(&mut self, n: usize, rng: &mut StdRng) {
    if self.best_payload.is_empty() {
        for _ in 0..n {
            let mut fresh = random_chromosome(&self.gene_pool, rng);
            fresh.genes.push(("ast_mcts_payload".into(), String::new()));
            fresh.lineage = Lineage::genesis(self.generation);
            self.eval_counter = self.eval_counter.saturating_add(1);
            self.pending.push((self.eval_counter, fresh));
        }
        return;
    }

    // Every string the search evaluates is captured here by the inline oracle.
    let mut generated: Vec<String> = Vec::new();
    let mut inline = InlineOracle {
        jitter: rng.next_u64(),
        prior_bypass: self.bypass_found,
        candidates: &mut generated,
    };
    let result: Option<MctsResult> = mcts_search(
        &self.best_payload,
        self.mcts_budget,
        self.ucb1_c,
        &mut inline,
    );

    if let Some(outcome) = result.as_ref() {
        for &(action, visits, mean_reward) in &outcome.arm_stats {
            let (total_visits, total_reward) =
                self.rule_stats.entry(action.rule.0).or_insert((0, 0.0));
            *total_visits = total_visits.saturating_add(visits);
            // A non-finite mean contributes nothing rather than poisoning the running total.
            let weighted = if mean_reward.is_finite() {
                mean_reward * (visits as f64)
            } else {
                0.0
            };
            // A running total that is already non-finite is restarted from this contribution.
            *total_reward = if total_reward.is_finite() {
                *total_reward + weighted
            } else {
                weighted
            };
        }
    }

    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    if let Some(outcome) = result.as_ref()
        && !outcome.best_payload.is_empty()
    {
        // `seen` is still empty here, so this insertion always registers a new entry.
        seen.insert(outcome.best_payload.clone());
        self.enqueue_mutation(&outcome.best_payload, "ast_mcts:best_payload");
    }

    for candidate in generated {
        if self.pending.len() >= n {
            break;
        }
        let is_new = !candidate.is_empty() && seen.insert(candidate.clone());
        if is_new {
            self.enqueue_mutation(&candidate, "ast_mcts:inline_candidate");
        }
    }

    if self.pending.is_empty() {
        let mut fallback = self.best.clone();
        set_gene(&mut fallback, "ast_mcts_payload", &self.best_payload);
        fallback.lineage = Lineage::genesis(self.generation);
        self.eval_counter = self.eval_counter.saturating_add(1);
        self.pending.push((self.eval_counter, fallback));
    }
}

/// Queues a copy of `best` whose `ast_mcts_payload` gene is replaced by `payload`, recording
/// the change in its lineage under the given `operator` label.
fn enqueue_mutation(&mut self, payload: &str, operator: &'static str) {
    let mut child = self.best.clone();
    set_gene(&mut child, "ast_mcts_payload", payload);
    child.lineage = Lineage::mutation(
        &self.best,
        vec![crate::lineage::MutationOp {
            gene_name: "ast_mcts_payload".into(),
            from: self.best_payload.clone(),
            to: payload.to_owned(),
            operator: operator.into(),
        }],
        self.generation,
    );
    self.eval_counter = self.eval_counter.saturating_add(1);
    self.pending.push((self.eval_counter, child));
}
}
/// The default algorithm is the one produced by `AstMctsAlgorithm::new`.
impl Default for AstMctsAlgorithm {
/// Delegates to `Self::new`.
fn default() -> Self {
Self::new()
}
}
/// Sets gene `name` on `c` to `value`.
///
/// The first gene whose key equals `name` is overwritten; when no such gene exists a new
/// `(name, value)` pair is appended.
fn set_gene(c: &mut Chromosome, name: &str, value: &str) {
    let existing = c.genes.iter_mut().find(|(key, _)| key == name);
    match existing {
        Some((_, slot)) => *slot = value.to_owned(),
        None => c.genes.push((name.to_owned(), value.to_owned())),
    }
}
impl SearchAlgorithm for AstMctsAlgorithm {
/// Returns the fixed identifier of this algorithm, `"ast_mcts"`.
fn name(&self) -> &'static str {
"ast_mcts"
}
fn initialize(&mut self, population: Vec<Chromosome>, gene_pool: &GenePool, _rng: &mut StdRng) {
self.gene_pool = gene_pool.clone();
self.generation = 0;
self.eval_counter = 0;
self.bypass_found = false;
self.pending.clear();
self.in_flight.clear();
if let Some(seed) = population
.into_iter()
.max_by(|a, b| fitness_cmp(a.fitness, b.fitness))
{
let payload = Self::payload_from_chromosome(&seed).to_string();
self.best_payload = payload;
self.best = seed;
}
set_gene(&mut self.best, "ast_mcts_payload", &self.best_payload);
}
/// Hands out up to `n` queued candidates for evaluation.
///
/// The queue is regenerated only when it is completely empty, so a partially drained queue
/// can yield fewer than `n` candidates. A copy of every returned chromosome is kept in
/// `in_flight` under its id until a verdict is submitted. Returns an empty vector for `n == 0`.
fn request_evaluations(&mut self, n: usize, rng: &mut StdRng) -> Vec<EvalCandidate> {
    if n == 0 {
        return Vec::new();
    }
    if self.pending.is_empty() {
        self.replenish(n, rng);
    }

    let take = self.pending.len().min(n);
    let mut handed_out = Vec::with_capacity(take);
    for (id, chromosome) in self.pending.drain(..take) {
        self.in_flight.insert(id, chromosome.clone());
        handed_out.push(EvalCandidate { id, chromosome });
    }
    handed_out
}
/// Applies oracle verdicts to the candidates that were handed out earlier.
///
/// Results whose id is not in `in_flight` are ignored. A candidate becomes the new best when
/// its verdict passed or its fitness exceeds the current best's; a passing verdict also sets
/// `bypass_found`. `best_payload` is only replaced by a non-empty payload. The generation
/// counter advances (saturating) once per call, regardless of how many results were applied.
fn submit_evaluations(&mut self, results: Vec<(u64, OracleVerdict)>) {
    for (id, verdict) in results {
        let mut candidate = match self.in_flight.remove(&id) {
            Some(found) => found,
            None => continue,
        };
        candidate.record_verdict(&verdict);

        let improves = verdict.passed || candidate.fitness > self.best.fitness;
        if !improves {
            continue;
        }

        self.bypass_found |= verdict.passed;
        if let Some(payload) = candidate
            .gene("ast_mcts_payload")
            .filter(|payload| !payload.is_empty())
        {
            self.best_payload = payload.to_string();
        }
        self.best = candidate;
    }
    self.generation = self.generation.saturating_add(1);
}
/// Reports whether the search should stop.
///
/// True once a bypass has been found, or when `stats` has reached any of the request,
/// generation, or stagnation limits in `budget`.
fn should_terminate(&self, stats: &SearchStats, budget: &Budget) -> bool {
    if self.bypass_found {
        return true;
    }
    let out_of_requests = stats.evaluations >= budget.max_requests;
    let out_of_generations = stats.generation >= budget.max_generations;
    let stagnated = stats.stagnation_counter >= budget.stagnation_limit;
    out_of_requests || out_of_generations || stagnated
}
/// Returns the current best chromosome; this implementation always has one, so the result
/// is never `None`.
fn best(&self) -> Option<&Chromosome> {
Some(&self.best)
}
/// Serializes the algorithm state to JSON bytes.
///
/// # Errors
///
/// Returns `EvolutionError::SerializationFailed` when `serde_json` cannot encode the state.
fn checkpoint(&self) -> Result<Vec<u8>, EvolutionError> {
    let encoded = serde_json::to_vec(self).map_err(EvolutionError::SerializationFailed)?;
    Ok(encoded)
}
fn restore(&mut self, bytes: &[u8]) -> Result<(), EvolutionError> {
if bytes.len() > crate::types::MAX_CHECKPOINT_BYTES {
return Err(EvolutionError::OversizedData {
context: "ast_mcts checkpoint restore".into(),
size: bytes.len(),
max: crate::types::MAX_CHECKPOINT_BYTES,
});
}
*self = serde_json::from_slice(bytes).map_err(EvolutionError::DeserializationFailed)?;
Ok(())
}
/// Returns a boxed clone of this algorithm as a `SearchAlgorithm` trait object.
fn clone_box(&self) -> Box<dyn SearchAlgorithm> {
Box::new(self.clone())
}
/// Returns a one-element snapshot containing a clone of the current best chromosome.
fn population_snapshot(&self) -> Vec<Chromosome> {
vec![self.best.clone()]
}
}
#[must_use]
pub fn all_rule_names() -> Vec<&'static str> {
RuleId::ALL.iter().map(|r| r.name()).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Name of the gene that carries the payload under search.
    const PAYLOAD_GENE: &str = "ast_mcts_payload";

    /// Deterministic RNG so every test sees the same random sequence.
    fn make_rng() -> StdRng {
        StdRng::seed_from_u64(0x00C0_FFEE_BABE)
    }

    /// Builds a seed chromosome whose payload gene holds `payload`.
    fn seed_with(payload: &str) -> Chromosome {
        Chromosome::new(vec![(PAYLOAD_GENE.into(), payload.into())])
    }

    /// Creates a fresh algorithm, initializes it with `population`, and returns it together
    /// with the RNG that was used.
    fn initialized(population: Vec<Chromosome>) -> (AstMctsAlgorithm, StdRng) {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(population, &pool, &mut rng);
        (alg, rng)
    }

    /// Requests `n` candidates and, if any came back, submits a verdict for the first one.
    fn evaluate_first(alg: &mut AstMctsAlgorithm, rng: &mut StdRng, n: usize, passed: bool) {
        let candidates = alg.request_evaluations(n, rng);
        if let Some(first) = candidates.into_iter().next() {
            alg.submit_evaluations(vec![(first.id, OracleVerdict::from_bool(passed))]);
        }
    }

    #[test]
    fn name_is_ast_mcts() {
        assert_eq!(AstMctsAlgorithm::new().name(), "ast_mcts");
    }

    #[test]
    fn initialize_with_empty_population_sets_empty_best_payload() {
        let (alg, _rng) = initialized(Vec::new());
        assert!(alg.best_payload.is_empty());
    }

    #[test]
    fn initialize_with_sql_payload_captures_it() {
        let (alg, _rng) = initialized(vec![seed_with("'a'='a'")]);
        assert_eq!(alg.best_payload, "'a'='a'");
    }

    #[test]
    fn request_evaluations_returns_n_candidates() {
        let (mut alg, mut rng) = initialized(vec![seed_with("1=1")]);
        let batch = alg.request_evaluations(4, &mut rng);
        assert!(!batch.is_empty(), "must return at least one candidate");
        assert!(batch.len() <= 4);
    }

    #[test]
    fn request_evaluations_n_zero_returns_empty() {
        let (mut alg, mut rng) = initialized(Vec::new());
        let batch = alg.request_evaluations(0, &mut rng);
        assert!(batch.is_empty());
    }

    #[test]
    fn submit_evaluations_updates_best_on_pass() {
        let (mut alg, mut rng) = initialized(vec![seed_with("1=1")]);
        let first = alg
            .request_evaluations(3, &mut rng)
            .into_iter()
            .next()
            .unwrap();
        let first_payload = first
            .chromosome
            .gene(PAYLOAD_GENE)
            .unwrap_or("")
            .to_string();
        alg.submit_evaluations(vec![(first.id, OracleVerdict::from_bool(true))]);
        assert!(alg.bypass_found, "bypass_found must be set after a pass");
        assert_eq!(alg.best_payload, first_payload);
    }

    #[test]
    fn should_terminate_on_bypass() {
        let (mut alg, _rng) = initialized(Vec::new());
        alg.bypass_found = true;
        let stats = SearchStats::new();
        let budget = Budget::default();
        assert!(alg.should_terminate(&stats, &budget));
    }

    #[test]
    fn checkpoint_roundtrip_preserves_state() {
        let (mut original, _rng) = initialized(vec![seed_with("'x'='x'")]);
        original.bypass_found = true;
        let bytes = original.checkpoint().unwrap();

        let mut restored = AstMctsAlgorithm::new();
        restored.restore(&bytes).unwrap();
        assert_eq!(restored.best_payload, original.best_payload);
        assert_eq!(restored.bypass_found, original.bypass_found);
    }

    #[test]
    fn clone_box_produces_independent_instance() {
        let (mut alg, _rng) = initialized(vec![seed_with("1=1")]);
        let cloned = alg.clone_box();
        alg.bypass_found = true;
        assert!(!cloned.best().unwrap().has_gene("non_existent"));
        let _ = cloned.best();
    }

    #[test]
    fn all_rule_names_covers_all_16_rules() {
        let names = all_rule_names();
        assert_eq!(names.len(), 16, "all 16 RuleId variants must be named");
    }

    #[test]
    fn population_snapshot_returns_best() {
        let alg = AstMctsAlgorithm::new();
        let snapshot = alg.population_snapshot();
        assert_eq!(snapshot.len(), 1);
    }

    #[test]
    fn eval_counter_saturates_at_u64_max() {
        let (mut alg, mut rng) = initialized(vec![seed_with("1=1")]);
        alg.eval_counter = u64::MAX;
        let _ = alg.request_evaluations(1, &mut rng);
        assert_eq!(
            alg.eval_counter,
            u64::MAX,
            "eval_counter must saturate at u64::MAX, not wrap to 0"
        );
    }

    #[test]
    fn generation_saturates_at_u32_max() {
        let (mut alg, _rng) = initialized(Vec::new());
        alg.generation = u32::MAX;
        alg.submit_evaluations(vec![(0, OracleVerdict::from_bool(false))]);
        assert_eq!(
            alg.generation,
            u32::MAX,
            "generation must saturate at u32::MAX, not wrap to 0"
        );
    }

    #[test]
    fn rule_stats_nan_reward_does_not_poison_ucb1() {
        let (mut alg, mut rng) = initialized(vec![seed_with("1=1")]);
        alg.rule_stats.insert(0, (10, f64::NAN));
        evaluate_first(&mut alg, &mut rng, 2, true);
        for (visits, total) in alg.rule_stats.values() {
            assert!(
                total.is_finite() || *visits == 0,
                "rule_stats total must be finite after NaN reset, got {total}"
            );
        }
    }

    #[test]
    fn rule_stats_inf_reward_does_not_poison_ucb1() {
        let (mut alg, mut rng) = initialized(vec![seed_with("1=1")]);
        alg.rule_stats.insert(1, (5, f64::INFINITY));
        evaluate_first(&mut alg, &mut rng, 2, false);
        for (visits, total) in alg.rule_stats.values() {
            assert!(
                total.is_finite() || *visits == 0,
                "rule_stats total must be finite after Inf reset, got {total}"
            );
        }
    }

    #[test]
    fn rule_stats_nan_does_not_cross_contaminate_other_rules() {
        let (mut alg, mut rng) = initialized(vec![seed_with("1=1")]);
        alg.rule_stats.insert(0, (3, 2.5));
        alg.rule_stats.insert(1, (7, f64::NAN));
        evaluate_first(&mut alg, &mut rng, 1, true);
        if let Some((_, total)) = alg.rule_stats.get(&0) {
            assert!(
                total.is_finite(),
                "healthy rule_stats entry must remain finite, got {total}"
            );
        }
    }
}