use crate::equiv_query::PacBound;
use crate::error::{Result, WafModelError};
use crate::learn::Alphabet;
use crate::sfa::{BytePred, Sfa};
/// Layout version stamped into every artifact written by [`LearnedModel::capture`].
///
/// [`LearnedModel::from_toml`] refuses artifacts whose `schema_version` differs from this value.
pub const SCHEMA_VERSION: u32 = 1;
/// One outgoing transition of a serialized automaton state.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
struct Edge {
    /// Destination state, as an index into the artifact's state list.
    to: usize,
    /// Guard predicate in the hex encoding produced by `BytePred::to_hex`.
    pred: String,
}
/// Serialized form of a single automaton state.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
struct StateRow {
    /// Whether this state is accepting.
    accept: bool,
    /// Outgoing transitions, one entry per guard.
    edge: Vec<Edge>,
}
/// Learning metadata recorded alongside a model artifact.
#[derive(Debug, Clone, PartialEq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Provenance {
    /// Identifier of the oracle the model was learned against.
    pub oracle_id: String,
    /// Optional ruleset fingerprint; left out of the artifact when `None`
    /// and treated as `None` when absent on load.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ruleset_fingerprint: Option<String>,
    /// Membership-query count, as reported by the caller.
    pub membership_queries: u64,
    /// Equivalence-round count, as reported by the caller.
    pub equivalence_rounds: u64,
    /// Optional PAC bound; left out of the artifact when `None`
    /// and treated as `None` when absent on load.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pac: Option<PacBound>,
}
/// Versioned, serializable snapshot of a learned automaton and its alphabet.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct LearnedModel {
    /// Artifact layout version; must equal [`SCHEMA_VERSION`] for the artifact to load.
    schema_version: u32,
    /// Index of the initial state within `state`.
    start: usize,
    /// Raw alphabet symbols, in the order returned by `Alphabet::raw_symbols`.
    alphabet: Vec<u8>,
    /// How this model was learned.
    pub provenance: Provenance,
    /// One row per automaton state; a state's number is its position here.
    state: Vec<StateRow>,
}
impl LearnedModel {
#[must_use]
pub fn capture(alpha: &Alphabet, sfa: &Sfa, provenance: Provenance) -> Self {
let (start, accept, delta) = sfa.export();
let state = accept
.iter()
.zip(delta.iter())
.map(|(&acc, edges)| StateRow {
accept: acc,
edge: edges
.iter()
.map(|(p, t)| Edge {
to: *t,
pred: p.to_hex(),
})
.collect(),
})
.collect();
LearnedModel {
schema_version: SCHEMA_VERSION,
provenance,
alphabet: alpha.raw_symbols().to_vec(),
start,
state,
}
}
pub fn to_toml(&self) -> Result<String> {
toml::to_string_pretty(self).map_err(|e| WafModelError::Artifact(e.to_string()))
}
pub fn from_toml(src: &str) -> Result<Self> {
let m: LearnedModel =
toml::from_str(src).map_err(|e| WafModelError::Artifact(e.to_string()))?;
if m.schema_version != SCHEMA_VERSION {
return Err(WafModelError::Artifact(format!(
"unsupported schema version {} (this build understands {})",
m.schema_version, SCHEMA_VERSION
)));
}
Ok(m)
}
pub fn alphabet(&self) -> Result<Alphabet> {
if self.alphabet.is_empty() {
return Err(WafModelError::Artifact("empty alphabet".into()));
}
let mut d = self.alphabet.clone();
d.sort_unstable();
let before = d.len();
d.dedup();
if d.len() != before {
return Err(WafModelError::Artifact("duplicate alphabet symbols".into()));
}
Ok(Alphabet::from_raw_symbols(self.alphabet.clone()))
}
pub fn sfa(&self) -> Result<Sfa> {
let n = self.state.len();
if self.start >= n {
return Err(WafModelError::Artifact(format!(
"start state {} out of range (n={n})",
self.start
)));
}
let mut accept = Vec::with_capacity(n);
let mut delta: Vec<Vec<(BytePred, usize)>> = Vec::with_capacity(n);
for (si, row) in self.state.iter().enumerate() {
accept.push(row.accept);
let mut trans = Vec::with_capacity(row.edge.len());
let mut cover = BytePred::none();
for e in &row.edge {
if e.to >= n {
return Err(WafModelError::Artifact(format!(
"state {si}: edge target {} out of range",
e.to
)));
}
let p = BytePred::from_hex(&e.pred).ok_or_else(|| {
WafModelError::Artifact(format!("state {si}: malformed predicate hex"))
})?;
if !cover.and(p).is_empty() {
return Err(WafModelError::Artifact(format!(
"state {si}: overlapping guards (non-deterministic artifact)"
)));
}
cover = cover.or(p);
trans.push((p, e.to));
}
if cover != BytePred::any() {
return Err(WafModelError::Artifact(format!(
"state {si}: guards are not total (incomplete artifact)"
)));
}
delta.push(trans);
}
Ok(Sfa::import(self.start, accept, delta))
}
}