use crate::error::{Result, WafModelError};
use crate::oracle::WafOracle;
use crate::outcome::Outcome;
use crate::sfa::{BytePred, Sfa, StateId};
use std::collections::HashMap;
use wafrift_types::Request;
/// Finite symbolic alphabet the learners operate over.
///
/// Each symbol index names a class of bytes and carries one representative
/// byte. Every index except the last is a singleton class for one
/// distinguished byte; the last index ([`Alphabet::catch_all`]) stands for
/// every byte that is not distinguished.
#[derive(Debug, Clone)]
pub struct Alphabet {
    /// Representative byte per class; the catch-all representative is last.
    symbols: Vec<u8>,
}
impl Alphabet {
    /// Builds an alphabet from `distinguished` bytes (sorted, duplicates
    /// removed) followed by a catch-all class represented by `catch_all`.
    ///
    /// # Panics
    /// If `catch_all` is also one of the distinguished bytes.
    #[must_use]
    pub fn new(distinguished: Vec<u8>, catch_all: u8) -> Self {
        let mut symbols = distinguished;
        symbols.sort_unstable();
        symbols.dedup();
        // `symbols` is sorted and deduplicated, so a binary search is a valid membership test.
        assert!(
            symbols.binary_search(&catch_all).is_err(),
            "catch-all byte {catch_all} must not be a distinguished symbol"
        );
        symbols.push(catch_all);
        Self { symbols }
    }

    /// Number of symbol classes, including the catch-all.
    #[must_use]
    pub fn len(&self) -> usize {
        self.symbols.len()
    }

    /// Whether the alphabet has no classes (never true for alphabets built by
    /// `new` or `from_raw_symbols`).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.symbols.is_empty()
    }

    /// Index of the catch-all class (always the last index).
    #[must_use]
    pub fn catch_all(&self) -> usize {
        self.symbols.len() - 1
    }

    /// Maps a word of symbol indices to its representative bytes.
    ///
    /// # Panics
    /// If any index is out of range.
    #[must_use]
    pub fn concretize(&self, word: &[usize]) -> Vec<u8> {
        word.iter().map(|&idx| self.byte_of(idx)).collect()
    }

    /// Representative byte of symbol `i`.
    ///
    /// # Panics
    /// If `i` is out of range.
    #[must_use]
    pub fn byte_of(&self, i: usize) -> u8 {
        self.symbols[i]
    }

    /// The raw representative bytes, catch-all last (inverse of `from_raw_symbols`).
    #[must_use]
    pub fn raw_symbols(&self) -> &[u8] {
        self.symbols.as_slice()
    }

    /// Rebuilds an alphabet from raw representative bytes, taking them
    /// verbatim; the last byte represents the catch-all class.
    ///
    /// # Panics
    /// If `symbols` is empty or contains a duplicate byte.
    #[must_use]
    pub fn from_raw_symbols(symbols: Vec<u8>) -> Self {
        assert!(!symbols.is_empty(), "alphabet must have ≥1 class");
        let mut sorted = symbols.clone();
        sorted.sort_unstable();
        let total = sorted.len();
        sorted.dedup();
        assert_eq!(total, sorted.len(), "duplicate alphabet symbols");
        Self { symbols }
    }

    /// Byte predicate for symbol `i`.
    ///
    /// A distinguished symbol matches exactly its byte. The catch-all matches
    /// every byte that is not one of the distinguished bytes.
    ///
    /// # Panics
    /// If `i` is out of range.
    #[must_use]
    pub fn guard(&self, i: usize) -> BytePred {
        let last = self.catch_all();
        if i != last {
            return BytePred::byte(self.symbols[i]);
        }
        let mut distinguished = BytePred::none();
        for &b in &self.symbols[..last] {
            distinguished.insert(b);
        }
        !distinguished
    }
}
/// Memoising membership-query adapter.
///
/// A query word (symbol indices) is concretised through `alpha`, turned into
/// a request by `build`, and classified by `oracle`. The answer is `true`
/// exactly when the verdict is `Outcome::Pass`. Answers are cached per word,
/// so the cache size equals the number of distinct queries sent to the oracle.
struct Mq<'a, B> {
    /// The WAF being learned.
    oracle: &'a mut dyn WafOracle,
    /// Builds a request around a concrete byte payload.
    build: &'a B,
    /// Answers already obtained, keyed by symbol word.
    cache: HashMap<Vec<usize>, bool>,
    /// Alphabet used to concretise symbol words.
    alpha: &'a Alphabet,
}
impl<'a, B> Mq<'a, B>
where
    B: Fn(&[u8]) -> Request,
{
    /// Answers the membership query for `word`, hitting the oracle only on a
    /// cache miss.
    ///
    /// # Errors
    /// Propagates errors from the WAF oracle; failed queries are not cached.
    fn ask(&mut self, word: &[usize]) -> Result<bool> {
        match self.cache.get(word) {
            Some(&cached) => Ok(cached),
            None => {
                let payload = self.alpha.concretize(word);
                let request = (self.build)(&payload);
                let verdict = self.oracle.classify(&request)?;
                let passed = matches!(verdict, Outcome::Pass);
                self.cache.insert(word.to_vec(), passed);
                Ok(passed)
            }
        }
    }
}
/// Equivalence oracle used by the active learners ([`l_star`], [`kv_learn`]).
///
/// Given a hypothesis, it either produces a word on which the hypothesis and
/// the target disagree, or reports that it found none.
pub trait EquivalenceOracle {
    /// Searches for a counterexample to `hyp`.
    ///
    /// Words are sequences of symbol indices into `alpha`. `mq` answers
    /// membership queries on such words. In this module's learners it is
    /// backed by the cached WAF oracle, where `true` means the concretised
    /// request was classified as `Outcome::Pass`.
    ///
    /// Implementations should return `Ok(Some(w))` only for a word `w` on
    /// which `hyp` and `mq` disagree. `Ok(None)` means none was found; for
    /// approximate oracles this does not prove equivalence.
    fn find_counterexample(
        &mut self,
        hyp: &Sfa,
        alpha: &Alphabet,
        mq: &mut dyn FnMut(&[usize]) -> Result<bool>,
    ) -> Result<Option<Vec<usize>>>;
}
/// Equivalence oracle that approximates an equivalence query by testing
/// every word up to a fixed length, shortest first.
#[derive(Debug, Clone, Copy)]
pub struct BoundedExhaustiveEq {
    /// Longest word length (in alphabet symbols) that is tested.
    pub max_len: usize,
    /// Optional cap on `mq` calls per `find_counterexample` invocation;
    /// `None` means unbounded. When the cap is hit the search gives up and
    /// reports "no counterexample".
    pub max_queries: Option<u64>,
}
impl EquivalenceOracle for BoundedExhaustiveEq {
    /// Breadth-first search over all words of length `0..=max_len`, returning
    /// the first word on which `hyp` and `mq` disagree.
    ///
    /// Returns `Ok(None)` when no disagreement is found or when `max_queries`
    /// `mq` calls have been made. Words are also skipped silently once a BFS
    /// level reaches [`BoundedExhaustiveEq::FRONTIER_CAP`]. `Ok(None)`
    /// therefore does not prove equivalence.
    ///
    /// # Errors
    /// Propagates errors from `mq`.
    fn find_counterexample(
        &mut self,
        hyp: &Sfa,
        alpha: &Alphabet,
        mq: &mut dyn FnMut(&[usize]) -> Result<bool>,
    ) -> Result<Option<Vec<usize>>> {
        let k = alpha.len();
        let query_cap = self.max_queries.unwrap_or(u64::MAX);
        let mut queries_used: u64 = 0;
        let mut frontier: Vec<Vec<usize>> = vec![Vec::new()];
        for len in 0..=self.max_len {
            // Words of length `max_len` are tested but never extended. Their
            // children would be discarded unchecked, and building that level
            // could allocate up to FRONTIER_CAP vectors for nothing.
            let extend = len < self.max_len;
            let mut next = Vec::new();
            for w in frontier {
                if queries_used >= query_cap {
                    return Ok(None);
                }
                queries_used += 1;
                let truth = mq(&w)?;
                if hyp.accepts(&alpha.concretize(&w)) != truth {
                    return Ok(Some(w));
                }
                if !extend || next.len() >= Self::FRONTIER_CAP {
                    continue;
                }
                next.extend((0..k).map(|sym| {
                    let mut child = Vec::with_capacity(w.len() + 1);
                    child.extend_from_slice(&w);
                    child.push(sym);
                    child
                }));
            }
            frontier = next;
        }
        Ok(None)
    }
}
impl BoundedExhaustiveEq {
    /// Size at which a BFS level stops growing. Once the next level holds
    /// this many words, the remaining words of the current level are still
    /// tested but no longer extended.
    pub const FRONTIER_CAP: usize = 1_000_000;
}
/// Maximum number of states [`passive_learn`] creates. Once reached, prefixes
/// whose row has not been seen before are mapped to state 0 instead.
pub const PASSIVE_LEARN_MAX_STATES: usize = 100_000;
/// Result of a learning run.
#[derive(Debug)]
pub struct LearnReport {
    /// The learned automaton.
    pub sfa: Sfa,
    /// Number of distinct membership queries issued (the size of the query
    /// cache when the run finished).
    pub membership_queries: u64,
    /// Number of counterexamples processed; always 0 for passive learning.
    pub equivalence_rounds: u64,
}
/// L* observation table.
///
/// `rows` memoises the row of a prefix `u`, i.e. `[MQ(u·e) for e in E]` in
/// `E` order. Callers clear it whenever `e` grows, since cached rows become
/// stale.
struct Table {
    /// Access prefixes (`S`).
    s: Vec<Vec<usize>>,
    /// Distinguishing suffixes (`E`); the learners keep ε in it.
    e: Vec<Vec<usize>>,
    /// Memoised rows, keyed by prefix.
    rows: HashMap<Vec<usize>, Vec<bool>>,
}
impl Table {
    /// Returns the row of `u`: one membership answer per suffix in `E`.
    ///
    /// # Errors
    /// Propagates the first error from `mq`; nothing is cached in that case.
    fn row<F: FnMut(&[usize]) -> Result<bool>>(
        &mut self,
        u: &[usize],
        mq: &mut F,
    ) -> Result<Vec<bool>> {
        if let Some(r) = self.rows.get(u) {
            return Ok(r.clone());
        }
        let mut r = Vec::with_capacity(self.e.len());
        // `self.e` is only read here and `mq` is a separate borrow, so there is
        // no need to clone the whole suffix set on every miss. One scratch
        // buffer is reused for every `u·e`.
        let mut w = Vec::with_capacity(u.len());
        for e in &self.e {
            w.clear();
            w.extend_from_slice(u);
            w.extend_from_slice(e);
            r.push(mq(&w)?);
        }
        self.rows.insert(u.to_vec(), r.clone());
        Ok(r)
    }
}
/// Builds the L* hypothesis automaton from an observation table.
///
/// There is one state per distinct row among `S`, represented by the first
/// prefix in `S` that has that row.
/// - A state accepts iff its representative's row is `true` in the ε column.
/// - The transition on symbol `a` goes to the state whose row equals the row
///   of `representative·a`.
/// - The start state is the state of ε's row.
///
/// # Errors
/// `WafModelError::TableNotClosed` if `E` lacks ε, or if some successor row
/// (or ε's row) does not occur among the rows of `S`. Errors from `mq` are
/// propagated.
fn build_hypothesis<F: FnMut(&[usize]) -> Result<bool>>(
    t: &mut Table,
    alpha: &Alphabet,
    mq: &mut F,
) -> Result<Sfa> {
    let mut reps: Vec<Vec<usize>> = Vec::new();
    let mut state_of_row: HashMap<Vec<bool>, StateId> = HashMap::new();
    for prefix in t.s.clone() {
        let signature = t.row(&prefix, mq)?;
        let fresh = reps.len();
        // Only a freshly inserted id equals `fresh`; existing ids are smaller.
        if *state_of_row.entry(signature).or_insert(fresh) == fresh {
            reps.push(prefix);
        }
    }

    let eps_col = t
        .e
        .iter()
        .position(Vec::is_empty)
        .ok_or(WafModelError::TableNotClosed)?;

    let mut accept = Vec::with_capacity(reps.len());
    for rep in &reps {
        accept.push(t.row(rep, mq)?[eps_col]);
    }

    let mut delta: Vec<Vec<(BytePred, StateId)>> = Vec::with_capacity(reps.len());
    for rep in &reps {
        let mut edges = Vec::with_capacity(alpha.len());
        for sym in 0..alpha.len() {
            let mut successor = rep.clone();
            successor.push(sym);
            let succ_row = t.row(&successor, mq)?;
            let target = state_of_row
                .get(&succ_row)
                .copied()
                .ok_or(WafModelError::TableNotClosed)?;
            edges.push((alpha.guard(sym), target));
        }
        delta.push(edges);
    }

    let eps_row = t.row(&[], mq)?;
    let start = state_of_row
        .get(&eps_row)
        .copied()
        .ok_or(WafModelError::TableNotClosed)?;
    Ok(Sfa::new(start, accept, delta))
}
/// Learns an automaton from membership queries alone (no equivalence queries).
///
/// Each prefix is characterised by its *signature*: the membership answers
/// for `prefix · e` over every suffix `e` of length `0..=depth`. ε comes
/// first, so entry 0 is the prefix's own acceptance. States are discovered
/// breadth-first from ε, and prefixes with equal signatures share a state.
/// The suffix set has `1 + k + … + k^depth` entries for a `k`-symbol alphabet.
///
/// Approximations:
/// - New states are only created for prefixes of length `<= depth`. A longer
///   prefix goes to an existing state with the same signature, or to state 0.
/// - Once [`PASSIVE_LEARN_MAX_STATES`] states exist, unseen signatures also
///   go to state 0.
///
/// The report always has `equivalence_rounds == 0`.
///
/// # Errors
/// Propagates errors from the WAF oracle.
pub fn passive_learn<B>(
    oracle: &mut dyn WafOracle,
    build: &B,
    alpha: &Alphabet,
    depth: usize,
) -> Result<LearnReport>
where
    B: Fn(&[u8]) -> Request,
{
    use std::collections::hash_map::Entry;

    let mut mqx = Mq {
        oracle,
        build,
        cache: HashMap::new(),
        alpha,
    };

    // Suffix set in BFS order: ε, then every word of length 1, 2, …, depth.
    let mut suffixes: Vec<Vec<usize>> = vec![Vec::new()];
    let mut level_start = 0;
    for _ in 0..depth {
        let level_end = suffixes.len();
        for idx in level_start..level_end {
            for sym in 0..alpha.len() {
                let mut longer = suffixes[idx].clone();
                longer.push(sym);
                suffixes.push(longer);
            }
        }
        level_start = level_end;
    }

    let signature = |mqx: &mut Mq<B>, prefix: &[usize]| -> Result<Vec<bool>> {
        suffixes
            .iter()
            .map(|suffix| {
                let mut word = Vec::with_capacity(prefix.len() + suffix.len());
                word.extend_from_slice(prefix);
                word.extend_from_slice(suffix);
                mqx.ask(&word)
            })
            .collect()
    };

    let root_sig = signature(&mut mqx, &[])?;
    let mut access: Vec<Vec<usize>> = vec![Vec::new()];
    let mut accept = vec![root_sig[0]];
    let mut state_ids: HashMap<Vec<bool>, StateId> = HashMap::new();
    state_ids.insert(root_sig, 0);
    let mut delta: Vec<Vec<(BytePred, StateId)>> = vec![Vec::new()];

    let mut queue = vec![0usize];
    let mut head = 0;
    while let Some(&state) = queue.get(head) {
        head += 1;
        let prefix = access[state].clone();
        for sym in 0..alpha.len() {
            let mut ext = prefix.clone();
            ext.push(sym);
            let sig = signature(&mut mqx, &ext)?;
            let target = if ext.len() > depth {
                // Beyond the exploration depth: merge with a known state, else state 0.
                state_ids.get(&sig).copied().unwrap_or(0)
            } else {
                match state_ids.entry(sig) {
                    Entry::Occupied(slot) => *slot.get(),
                    Entry::Vacant(_) if access.len() >= PASSIVE_LEARN_MAX_STATES => 0,
                    Entry::Vacant(slot) => {
                        let id = access.len();
                        accept.push(slot.key()[0]);
                        slot.insert(id);
                        access.push(ext);
                        delta.push(Vec::new());
                        queue.push(id);
                        id
                    }
                }
            };
            delta[state].push((alpha.guard(sym), target));
        }
    }

    Ok(LearnReport {
        sfa: Sfa::new(0, accept, delta),
        membership_queries: mqx.cache.len() as u64,
        equivalence_rounds: 0,
    })
}
fn l_star_impl<B>(
oracle: &mut dyn WafOracle,
build: &B,
alpha: &Alphabet,
eq: &mut dyn EquivalenceOracle,
budget: u64,
) -> Result<LearnReport>
where
B: Fn(&[u8]) -> Request,
{
let mut mqx = Mq {
oracle,
build,
cache: HashMap::new(),
alpha,
};
let mut t = Table {
s: vec![vec![]],
e: vec![vec![]],
rows: HashMap::new(),
};
let mut rounds = 0u64;
loop {
loop {
let s_rows: std::collections::HashSet<Vec<bool>> = {
let mut set = std::collections::HashSet::new();
for s in t.s.clone() {
let r = {
let mut ask = |w: &[usize]| mqx.ask(w);
t.row(&s, &mut ask)?
};
set.insert(r);
}
set
};
let mut added = false;
'close: for s in t.s.clone() {
for a in 0..alpha.len() {
let mut sa = s.clone();
sa.push(a);
let r = {
let mut ask = |w: &[usize]| mqx.ask(w);
t.row(&sa, &mut ask)?
};
if !s_rows.contains(&r) {
t.s.push(sa);
added = true;
break 'close;
}
}
}
if added {
continue;
}
let mut fix: Option<Vec<usize>> = None;
'cons: for i in 0..t.s.len() {
for j in (i + 1)..t.s.len() {
let (si, sj) = (t.s[i].clone(), t.s[j].clone());
let (ri, rj) = {
let mut ask = |w: &[usize]| mqx.ask(w);
(t.row(&si, &mut ask)?, t.row(&sj, &mut ask)?)
};
if ri != rj {
continue;
}
for a in 0..alpha.len() {
for ei in 0..t.e.len() {
let e = t.e[ei].clone();
let mut wia = si.clone();
wia.push(a);
wia.extend_from_slice(&e);
let mut wja = sj.clone();
wja.push(a);
wja.extend_from_slice(&e);
let (a1, a2) = {
let mut ask = |w: &[usize]| mqx.ask(w);
(ask(&wia)?, ask(&wja)?)
};
if a1 != a2 {
let mut suffix = vec![a];
suffix.extend_from_slice(&e);
fix = Some(suffix);
break 'cons;
}
}
}
}
}
if let Some(suffix) = fix {
if !t.e.contains(&suffix) {
t.e.push(suffix);
t.rows.clear();
}
continue;
}
break;
}
if mqx.cache.len() as u64 >= budget {
return Err(crate::error::WafModelError::BudgetExhausted {
queries: mqx.cache.len() as u64,
});
}
let hyp = {
let mut ask = |w: &[usize]| mqx.ask(w);
build_hypothesis(&mut t, alpha, &mut ask)?
};
let ce = {
let mut ask = |w: &[usize]| mqx.ask(w);
eq.find_counterexample(&hyp, alpha, &mut ask)?
};
match ce {
None => {
return Ok(LearnReport {
sfa: hyp,
membership_queries: mqx.cache.len() as u64,
equivalence_rounds: rounds,
});
}
Some(c) => {
rounds += 1;
for i in 0..=c.len() {
let suf = c[i..].to_vec();
if !t.e.contains(&suf) {
t.e.push(suf);
}
}
t.rows.clear();
}
}
}
}
/// Learns the WAF's language with Angluin's L*, without a query budget.
///
/// Learning ends when `eq` reports no counterexample.
///
/// # Errors
/// Propagates errors from the WAF oracle, from `eq`, and from hypothesis
/// construction.
pub fn l_star<B>(
    oracle: &mut dyn WafOracle,
    build: &B,
    alpha: &Alphabet,
    eq: &mut dyn EquivalenceOracle,
) -> Result<LearnReport>
where
    B: Fn(&[u8]) -> Request,
{
    // `u64::MAX` distinct queries can never be reached, so this is unbudgeted.
    l_star_budgeted(oracle, build, alpha, eq, u64::MAX)
}
/// Like [`l_star`], but gives up once the number of distinct membership
/// queries reaches `max_queries`.
///
/// The budget is checked between learning steps, so the final count may
/// exceed `max_queries` slightly.
///
/// # Errors
/// `WafModelError::BudgetExhausted` when the budget is hit, plus everything
/// [`l_star`] can return.
pub fn l_star_budgeted<B>(
    oracle: &mut dyn WafOracle,
    build: &B,
    alpha: &Alphabet,
    eq: &mut dyn EquivalenceOracle,
    max_queries: u64,
) -> Result<LearnReport>
where
    B: Fn(&[u8]) -> Request,
{
    l_star_impl(oracle, build, alpha, eq, max_queries)
}
/// Node of the Kearns–Vazirani discrimination tree.
enum Node {
    /// A hypothesis state id (an index into the learner's access words).
    Leaf(usize),
    /// A discriminator. A word `u` continues into `accept_child` when
    /// `MQ(u · suffix)` is true, and into `reject_child` otherwise.
    Inner {
        suffix: Vec<usize>,
        accept_child: Box<Node>,
        reject_child: Box<Node>,
    },
}
/// State of the Kearns–Vazirani learner.
struct Kv<'a, B> {
    /// Cached membership oracle.
    mqx: Mq<'a, B>,
    /// One access word per hypothesis state; the state id is the index.
    access: Vec<Vec<usize>>,
    /// Discrimination tree whose leaves hold state ids.
    tree: Node,
}
impl<'a, B> Kv<'a, B>
where
    B: Fn(&[u8]) -> Request,
{
    /// Sifts `word` down the discrimination tree and returns the state id at
    /// the leaf it reaches.
    ///
    /// At every inner node it asks `MQ(word · suffix)`. It descends to
    /// `accept_child` on `true` and to `reject_child` otherwise.
    fn sift(&mut self, word: &[usize]) -> Result<usize> {
        let mut cursor: &Node = &self.tree;
        loop {
            let (suffix, on_accept, on_reject) = match cursor {
                Node::Leaf(id) => return Ok(*id),
                Node::Inner {
                    suffix,
                    accept_child,
                    reject_child,
                } => (suffix, &**accept_child, &**reject_child),
            };
            let mut probe = word.to_vec();
            probe.extend_from_slice(suffix);
            cursor = if self.mqx.ask(&probe)? {
                on_accept
            } else {
                on_reject
            };
        }
    }

    /// Builds the current hypothesis automaton.
    ///
    /// - State `q` accepts iff `MQ(access[q])`.
    /// - The transition from `q` on symbol `a` goes to `sift(access[q]·a)`.
    /// - The start state is `sift(ε)`.
    fn hypothesis(&mut self, alpha: &Alphabet) -> Result<Sfa> {
        let n = self.access.len();
        let mut accept = Vec::with_capacity(n);
        for q in 0..n {
            let acc = self.mqx.ask(&self.access[q])?;
            accept.push(acc);
        }
        let mut delta: Vec<Vec<(BytePred, StateId)>> = Vec::with_capacity(n);
        for q in 0..n {
            let mut edges = Vec::with_capacity(alpha.len());
            for sym in 0..alpha.len() {
                let mut succ = self.access[q].clone();
                succ.push(sym);
                edges.push((alpha.guard(sym), self.sift(&succ)?));
            }
            delta.push(edges);
        }
        let start = self.sift(&[])?;
        Ok(Sfa::new(start, accept, delta))
    }
}
/// Replaces every leaf labelled `target` in the subtree rooted at `node` with
/// `replacement`. Deep copies are used wherever more than one subtree is
/// searched.
fn replace_leaf(node: &mut Node, target: usize, replacement: Node) {
    if let Node::Inner {
        accept_child,
        reject_child,
        ..
    } = node
    {
        let copy = replacement_clone(&replacement);
        replace_leaf(accept_child, target, copy);
        replace_leaf(reject_child, target, replacement);
    } else if matches!(node, Node::Leaf(id) if *id == target) {
        *node = replacement;
    }
}
/// Deep-copies a discrimination-tree node.
fn replacement_clone(n: &Node) -> Node {
    match n {
        Node::Inner {
            suffix,
            accept_child,
            reject_child,
        } => {
            let accept_copy = replacement_clone(accept_child);
            let reject_copy = replacement_clone(reject_child);
            Node::Inner {
                suffix: suffix.to_vec(),
                accept_child: Box::new(accept_copy),
                reject_child: Box::new(reject_copy),
            }
        }
        Node::Leaf(id) => Node::Leaf(*id),
    }
}
pub fn kv_learn<B>(
oracle: &mut dyn WafOracle,
build: &B,
alpha: &Alphabet,
eq: &mut dyn EquivalenceOracle,
) -> Result<LearnReport>
where
B: Fn(&[u8]) -> Request,
{
let mut kv = Kv {
mqx: Mq {
oracle,
build,
cache: HashMap::new(),
alpha,
},
access: vec![vec![]],
tree: Node::Leaf(0),
};
let mut rounds = 0u64;
loop {
let hyp = kv.hypothesis(alpha)?;
let ce = {
let cache_ref = &mut kv.mqx;
let mut ask = |w: &[usize]| cache_ref.ask(w);
eq.find_counterexample(&hyp, alpha, &mut ask)?
};
let Some(c) = ce else {
return Ok(LearnReport {
sfa: hyp,
membership_queries: kv.mqx.cache.len() as u64,
equivalence_rounds: rounds,
});
};
rounds += 1;
let n = c.len();
if n == 0 {
return Err(crate::error::WafModelError::Oracle(
"equivalence oracle returned an empty counterexample — \
Rivest–Schapire decomposition is undefined for ε"
.into(),
));
}
let state_word = |k: usize, kv: &mut Kv<B>| -> Result<Vec<usize>> {
let id = {
let pref = c[..k].to_vec();
kv.sift(&pref)?
};
Ok(kv.access[id].clone())
};
let alpha_at = |k: usize, kv: &mut Kv<B>| -> Result<bool> {
let mut w = state_word(k, kv)?;
w.extend_from_slice(&c[k..]);
kv.mqx.ask(&w)
};
let g0 = alpha_at(0, &mut kv)?;
let gn = alpha_at(n, &mut kv)?;
debug_assert_ne!(g0, gn, "Rivest–Schapire precondition: γ0 ≠ γn");
let (mut lo, mut hi) = (0usize, n);
while hi - lo > 1 {
let mid = (lo + hi) / 2;
if alpha_at(mid, &mut kv)? == g0 {
lo = mid;
} else {
hi = mid;
}
}
let i = lo;
let new_suffix = c[i + 1..].to_vec();
let new_access = {
let mut w = state_word(i, &mut kv)?;
w.push(c[i]);
w
};
let split_leaf = {
let pref = c[..i + 1].to_vec();
kv.sift(&pref)?
};
let new_id = kv.access.len();
kv.access.push(new_access.clone());
let old_access = kv.access[split_leaf].clone();
let mut old_probe = old_access;
old_probe.extend_from_slice(&new_suffix);
let old_goes_accept = kv.mqx.ask(&old_probe)?;
let (accept_child, reject_child) = if old_goes_accept {
(
Box::new(Node::Leaf(split_leaf)),
Box::new(Node::Leaf(new_id)),
)
} else {
(
Box::new(Node::Leaf(new_id)),
Box::new(Node::Leaf(split_leaf)),
)
};
let replacement = Node::Inner {
suffix: new_suffix,
accept_child,
reject_child,
};
replace_leaf(&mut kv.tree, split_leaf, replacement);
}
}