use crate::error::Result;
use crate::learn::{Alphabet, EquivalenceOracle};
use crate::sfa::{Sfa, StateId};
use std::collections::{HashMap, VecDeque};
/// Probably-approximately-correct (PAC) error bound for one sampling round.
///
/// Values are filled in by `PacBound::compute`.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct PacBound {
/// Error bound derived from the other three fields; `compute` caps it at `1.0`.
pub epsilon: f64,
/// Confidence parameter, stored exactly as passed to `compute`.
pub delta: f64,
/// Sample count the bound was computed for, stored as passed.
pub samples: u64,
/// Round index the bound was computed for, stored as passed.
pub round: u64,
}
impl PacBound {
    /// Computes the error bound certified by `samples` passing samples.
    ///
    /// The raw value is
    /// `(ln(1 / delta) + (round + 1) * ln 2) / max(samples, 1)`,
    /// and the stored `epsilon` is that value clamped into `[0, 1]`.
    ///
    /// * `samples` - number of samples drawn; `0` is treated as `1` for the
    ///   division only (the field keeps the original value).
    /// * `delta` - confidence parameter, meaningful in `(0, 1]`.
    /// * `round` - zero-based round index.
    ///
    /// Degenerate `delta` values never produce a misleading bound: `delta <= 0`
    /// (infinite or NaN raw value) yields the vacuous `epsilon == 1.0`, and
    /// `delta > 1` can no longer yield a negative `epsilon`.
    #[must_use]
    pub fn compute(samples: u64, delta: f64, round: u64) -> Self {
        // `max(1)` keeps the division finite when no samples were drawn.
        let draws = samples.max(1) as f64;
        let raw = ((1.0 / delta).ln() + (round as f64 + 1.0) * std::f64::consts::LN_2) / draws;
        // Deliberately not `clamp`: `f64::min` returns the non-NaN operand, so a
        // NaN raw value collapses to the vacuous 1.0 instead of propagating.
        // The trailing `max(0.0)` removes negative bounds caused by `delta > 1`.
        let epsilon = raw.min(1.0).max(0.0);
        PacBound {
            epsilon,
            delta,
            samples,
            round,
        }
    }
}
/// Small deterministic SplitMix64 pseudo-random generator.
///
/// The wrapped `u64` is the generator state; equal states produce equal
/// output streams.
#[derive(Debug, Clone)]
struct SplitMix64(u64);

impl SplitMix64 {
    /// Advances the state by a fixed odd increment and returns the mixed value.
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let stage0 = self.0;
        let stage1 = (stage0 ^ (stage0 >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        stage2 ^ (stage2 >> 31)
    }

    /// Draws an index in `0..n` by reducing one output modulo `n`.
    ///
    /// # Panics
    ///
    /// Panics when `n == 0` (remainder by zero); one output is consumed first.
    fn below(&mut self, n: usize) -> usize {
        let bound = n as u64;
        (self.next_u64() % bound) as usize
    }

    /// Draws a length in `0..=max`.
    ///
    /// Each step continues while the two low bits of a fresh output are not
    /// both zero, and stops without drawing once `max` steps were taken.
    fn geo_len(&mut self, max: usize) -> usize {
        (0..max)
            .take_while(|_| self.next_u64() & 3 != 0)
            .count()
    }
}
/// Feeds `word` through `sfa` from its start state and returns the state reached.
///
/// Each symbol index is translated to a byte with `alpha.byte_of` before the
/// automaton takes a step; an empty word yields the start state.
fn run_abstract(sfa: &Sfa, alpha: &Alphabet, word: &[usize]) -> StateId {
    word.iter().fold(sfa.start_state(), |state, &sym| {
        sfa.step_byte(state, alpha.byte_of(sym))
    })
}
/// Maps every state reachable from the start state to one word that reaches it.
///
/// States are discovered breadth-first with symbols tried in ascending index
/// order, so the recorded word is the first one found for that state. The start
/// state maps to the empty word.
fn state_cover(sfa: &Sfa, alpha: &Alphabet) -> HashMap<StateId, Vec<usize>> {
    let start = sfa.start_state();
    let mut words: HashMap<StateId, Vec<usize>> = HashMap::new();
    words.insert(start, Vec::new());

    let mut frontier = VecDeque::new();
    frontier.push_back(start);

    while let Some(state) = frontier.pop_front() {
        // Copied out because `words` is mutated while successors are added.
        let prefix = words[&state].clone();
        for sym in 0..alpha.len() {
            let next = sfa.step_byte(state, alpha.byte_of(sym));
            if words.contains_key(&next) {
                continue;
            }
            let mut word = prefix.clone();
            word.push(sym);
            words.insert(next, word);
            frontier.push_back(next);
        }
    }
    words
}
/// Collects suffixes that separate pairs of reachable states by acceptance.
///
/// The result always starts with the empty word. For every unordered pair of
/// reachable states a breadth-first search over state pairs looks for the first
/// suffix after which exactly one of the two states is accepting; each such
/// suffix is appended once. Pairs with no separating suffix contribute nothing.
/// The order of the appended suffixes follows the iteration order of the map
/// returned by `state_cover`.
fn characterizing_set(sfa: &Sfa, alpha: &Alphabet) -> Vec<Vec<usize>> {
    let reachable: Vec<StateId> = state_cover(sfa, alpha).into_keys().collect();
    // Searches give up once a suffix grows beyond this length.
    let depth_cap = reachable.len() + 1;

    let mut known: std::collections::HashSet<Vec<usize>> = std::collections::HashSet::new();
    known.insert(Vec::new());
    let mut suffixes: Vec<Vec<usize>> = vec![Vec::new()];

    for (idx, &left) in reachable.iter().enumerate() {
        for &right in &reachable[idx + 1..] {
            let mut explored = std::collections::HashSet::new();
            explored.insert((left, right));
            let mut pending = VecDeque::new();
            pending.push_back((left, right, Vec::<usize>::new()));

            while let Some((a, b, suffix)) = pending.pop_front() {
                if sfa.is_accepting(a) != sfa.is_accepting(b) {
                    // First separating suffix for this pair: record and stop.
                    if !known.contains(&suffix) {
                        known.insert(suffix.clone());
                        suffixes.push(suffix);
                    }
                    break;
                }
                if suffix.len() > depth_cap {
                    break;
                }
                for sym in 0..alpha.len() {
                    let next = (
                        sfa.step_byte(a, alpha.byte_of(sym)),
                        sfa.step_byte(b, alpha.byte_of(sym)),
                    );
                    if explored.contains(&next) {
                        continue;
                    }
                    explored.insert(next);
                    let mut longer = suffix.clone();
                    longer.push(sym);
                    pending.push_back((next.0, next.1, longer));
                }
            }
        }
    }
    suffixes
}
/// Equivalence oracle whose test words are built from a state cover, a set of
/// enumerated middle words, and a characterizing set of suffixes.
#[derive(Debug, Clone, Copy)]
pub struct WMethodEq {
/// Controls middle-word depth: the oracle enumerates every middle word of
/// length `0..=extra_states + 1`.
pub extra_states: usize,
}
impl EquivalenceOracle for WMethodEq {
    /// Tests `hyp` on every word `access · middle · suffix`.
    ///
    /// * `access` ranges over the words returned by `state_cover`.
    /// * `middle` ranges over all words of length `0..=extra_states + 1`.
    /// * `suffix` ranges over the words returned by `characterizing_set`.
    ///
    /// Access words and suffixes are visited shortest-first, ties broken
    /// lexicographically, so the sequence of membership queries and the
    /// counterexample returned are identical from run to run. Iterating the
    /// cover map directly would follow `HashMap`'s randomized order.
    ///
    /// Returns the first word on which `mq` and `hyp` disagree, or `None` when
    /// all words agree.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `mq`.
    fn find_counterexample(
        &mut self,
        hyp: &Sfa,
        alpha: &Alphabet,
        mq: &mut dyn FnMut(&[usize]) -> Result<bool>,
    ) -> Result<Option<Vec<usize>>> {
        let cover = state_cover(hyp, alpha);
        let mut prefixes: Vec<&[usize]> = cover.values().map(Vec::as_slice).collect();
        prefixes.sort_unstable_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)));

        let mut wset = characterizing_set(hyp, alpha);
        wset.sort_unstable_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)));

        // `middles[level_start..level_end]` is the most recently completed
        // length level; each pass extends exactly that level by one symbol.
        let k = alpha.len();
        let mut middles: Vec<Vec<usize>> = vec![Vec::new()];
        let mut level_start = 0;
        for _ in 0..=self.extra_states {
            let level_end = middles.len();
            if level_start == level_end {
                // Empty alphabet: nothing can ever be extended.
                break;
            }
            for idx in level_start..level_end {
                for sym in 0..k {
                    let mut longer = middles[idx].clone();
                    longer.push(sym);
                    middles.push(longer);
                }
            }
            level_start = level_end;
        }

        // One buffer is reused for every test word; it is only moved out when
        // it is returned as the counterexample.
        let mut word: Vec<usize> = Vec::new();
        for &access in &prefixes {
            for mid in &middles {
                for suf in &wset {
                    word.clear();
                    word.extend_from_slice(access);
                    word.extend_from_slice(mid);
                    word.extend_from_slice(suf);
                    let truth = mq(&word)?;
                    if run_abstract_accepts(hyp, alpha, &word) != truth {
                        return Ok(Some(word));
                    }
                }
            }
        }
        Ok(None)
    }
}
/// Reports whether `sfa` is in an accepting state after reading `word`.
fn run_abstract_accepts(sfa: &Sfa, alpha: &Alphabet, word: &[usize]) -> bool {
    let end_state = run_abstract(sfa, alpha, word);
    sfa.is_accepting(end_state)
}
/// Equivalence oracle that repeatedly picks a `(state, symbol)` arm by an
/// exploration score and tests one word built from it.
#[derive(Debug, Clone)]
pub struct UcbBanditEq {
/// Number of selection rounds attempted per `find_counterexample` call.
pub budget: usize,
/// Upper limit on the length of the random suffix appended to each test word.
pub max_suffix: usize,
/// Base seed; combined with `total` to seed the generator for each call.
pub seed: u64,
// How many times each `(state, symbol)` arm has been selected so far.
counts: HashMap<(StateId, usize), u32>,
// Selection rounds performed so far; persists across calls.
total: u32,
}
impl UcbBanditEq {
#[must_use]
pub fn new(budget: usize, max_suffix: usize, seed: u64) -> Self {
UcbBanditEq {
budget,
max_suffix,
seed,
counts: HashMap::new(),
total: 0,
}
}
#[must_use]
pub fn arms_explored(&self) -> usize {
self.counts.len()
}
}
impl EquivalenceOracle for UcbBanditEq {
fn find_counterexample(
&mut self,
hyp: &Sfa,
alpha: &Alphabet,
mq: &mut dyn FnMut(&[usize]) -> Result<bool>,
) -> Result<Option<Vec<usize>>> {
let cover = state_cover(hyp, alpha);
let mut rng = SplitMix64(self.seed ^ u64::from(self.total).wrapping_mul(0x100));
for _ in 0..self.budget {
self.total += 1;
let lnt = (f64::from(self.total) + 1.0).ln();
let mut best: Option<(StateId, usize)> = None;
let mut best_score = f64::NEG_INFINITY;
for &s in cover.keys() {
for sym in 0..alpha.len() {
let n = *self.counts.get(&(s, sym)).unwrap_or(&0);
let score = if n == 0 {
f64::INFINITY
} else {
(2.0 * lnt / f64::from(n)).sqrt()
};
if score > best_score {
best_score = score;
best = Some((s, sym));
}
}
}
let (s, sym) = best.ok_or(crate::error::WafModelError::EmptySearchSpace)?;
*self.counts.entry((s, sym)).or_insert(0) += 1;
let mut word = cover[&s].clone();
word.push(sym);
let suf_len = rng.geo_len(self.max_suffix);
for _ in 0..suf_len {
word.push(rng.below(alpha.len()));
}
let truth = mq(&word)?;
if run_abstract_accepts(hyp, alpha, &word) != truth {
return Ok(Some(word));
}
}
Ok(None)
}
}
/// Equivalence oracle that tests a hypothesis on pseudo-randomly sampled words.
#[derive(Debug, Clone)]
pub struct SampledEq {
/// Maximum number of words drawn per `find_counterexample` call.
pub samples: u64,
/// Upper limit on the length of each sampled word.
pub max_len: usize,
/// Confidence parameter forwarded to `PacBound::compute`.
pub delta: f64,
/// Base seed; combined with `round` to seed the generator for each call.
pub seed: u64,
// Round counter mixed into the seed and passed to `PacBound::compute`.
round: u64,
// Bound from the most recent call if it found no counterexample, else `None`.
last: Option<PacBound>,
}
impl SampledEq {
#[must_use]
pub fn new(samples: u64, max_len: usize, delta: f64, seed: u64) -> Self {
SampledEq {
samples,
max_len,
delta,
seed,
round: 0,
last: None,
}
}
#[must_use]
pub fn last_bound(&self) -> Option<PacBound> {
self.last
}
}
impl EquivalenceOracle for SampledEq {
    /// Tests `hyp` on up to `samples` pseudo-random words.
    ///
    /// Word lengths come from `SplitMix64::geo_len` capped at `max_len`;
    /// symbols are drawn uniformly from the alphabet. With an empty alphabet
    /// only the empty word exists, so every sample is the empty word instead
    /// of panicking on a draw from an empty range.
    ///
    /// Every completed call is one round, whether or not it finds a
    /// counterexample. The next hypothesis is therefore checked on a fresh
    /// sample stream, and the `round` passed to `PacBound::compute` is the
    /// number of earlier completed calls, which is what its `(round + 1)` term
    /// accounts for.
    ///
    /// Returns the first sampled word on which `mq` and `hyp` disagree
    /// (clearing the stored bound), or `None` after all samples agree
    /// (storing a new bound).
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `mq`; the round counter is not
    /// advanced, so a retry draws the same samples.
    fn find_counterexample(
        &mut self,
        hyp: &Sfa,
        alpha: &Alphabet,
        mq: &mut dyn FnMut(&[usize]) -> Result<bool>,
    ) -> Result<Option<Vec<usize>>> {
        let round = self.round;
        let mut rng = SplitMix64(self.seed ^ round.wrapping_mul(0x9E37_79B9));
        let symbols = alpha.len();
        // No symbols means no non-empty words; a zero cap also keeps
        // `below(0)` from ever being called.
        let max_len = if symbols == 0 { 0 } else { self.max_len };

        for _ in 0..self.samples {
            let len = rng.geo_len(max_len);
            let word: Vec<usize> = (0..len).map(|_| rng.below(symbols)).collect();
            let truth = mq(&word)?;
            if run_abstract_accepts(hyp, alpha, &word) != truth {
                self.last = None;
                self.round = round.saturating_add(1);
                return Ok(Some(word));
            }
        }
        self.last = Some(PacBound::compute(self.samples, self.delta, round));
        self.round = round.saturating_add(1);
        Ok(None)
    }
}
/// Equivalence oracle that delegates to a list of oracles, tried in order.
pub struct ChainedEq {
// Oracles consulted front to back by `find_counterexample`.
oracles: Vec<Box<dyn EquivalenceOracle>>,
}
impl ChainedEq {
    /// Wraps `oracles`, keeping the given order.
    #[must_use]
    pub fn new(oracles: Vec<Box<dyn EquivalenceOracle>>) -> Self {
        Self { oracles }
    }
}
impl EquivalenceOracle for ChainedEq {
    /// Asks each wrapped oracle in order and returns the first counterexample.
    ///
    /// Oracles after the one that produced a counterexample are not consulted.
    /// Returns `None` when every oracle (or an empty chain) reports none.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error produced by a wrapped oracle.
    fn find_counterexample(
        &mut self,
        hyp: &Sfa,
        alpha: &Alphabet,
        mq: &mut dyn FnMut(&[usize]) -> Result<bool>,
    ) -> Result<Option<Vec<usize>>> {
        let mut found = None;
        for oracle in self.oracles.iter_mut() {
            found = oracle.find_counterexample(hyp, alpha, mq)?;
            if found.is_some() {
                break;
            }
        }
        Ok(found)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::WafModelError;
    use crate::sfa::BytePred;

    /// Hypothesis used by the bandit tests.
    ///
    /// NOTE(review): presumably a single non-accepting state (id 0) that loops
    /// to itself on every byte — confirm against `Sfa::new`'s argument order.
    fn reject_all_sfa() -> Sfa {
        Sfa::new(0, vec![false], vec![vec![(BytePred::any(), 0)]])
    }

    /// With a zero budget the selection loop never runs, so even an empty
    /// alphabet must not surface `EmptySearchSpace`.
    #[test]
    fn ucb_bandit_zero_budget_returns_ok_none() {
        let hyp = reject_all_sfa();
        let alpha = Alphabet::new(vec![], b'\x00');
        let mut oracle = UcbBanditEq::new(0, 4, 1);
        let result = oracle.find_counterexample(&hyp, &alpha, &mut |_| Ok(false));
        assert!(
            result.is_ok(),
            "zero-budget UCB bandit must return Ok, got {result:?}"
        );
        assert!(
            result.unwrap().is_none(),
            "zero-budget UCB bandit must return Ok(None)"
        );
    }

    /// The membership oracle agrees with the hypothesis on every word, so the
    /// whole budget is spent without finding a counterexample.
    #[test]
    fn ucb_bandit_single_state_no_counterexample_returns_ok_none() {
        let hyp = reject_all_sfa();
        let alpha = Alphabet::new(vec![b'a'], b'\x00');
        let mut oracle = UcbBanditEq::new(10, 2, 42);
        let result = oracle.find_counterexample(&hyp, &alpha, &mut |_| Ok(false));
        assert!(
            result.is_ok(),
            "UCB bandit on consistent oracle must not return Err: {result:?}"
        );
        // The test name promises `Ok(None)`; `is_ok` alone would also accept
        // a spurious counterexample.
        assert!(
            result.unwrap().is_none(),
            "UCB bandit on consistent oracle must return Ok(None)"
        );
    }

    /// Pins the error variant and its non-empty `Display` text.
    #[test]
    fn empty_search_space_error_variant_is_correct() {
        let none: Option<(StateId, usize)> = None;
        let err = none.ok_or(WafModelError::EmptySearchSpace).unwrap_err();
        assert!(
            matches!(err, WafModelError::EmptySearchSpace),
            "expected EmptySearchSpace variant, got {err:?}"
        );
        assert!(
            !err.to_string().is_empty(),
            "EmptySearchSpace Display must not be empty"
        );
    }
}