use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt;
use thiserror::Error;
/// Selects which group of named boolean variables (one map in
/// [`EvalContext`]) a condition reads from.
///
/// Serialized in `snake_case`: `"keywords"`, `"semantics"`, `"llm"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Section {
    /// The `keywords` variables.
    Keywords,
    /// The `semantics` variables.
    Semantics,
    /// The `llm` variables.
    Llm,
}
impl Section {
pub(crate) fn from_str(s: &str) -> Option<Self> {
match s {
"keywords" => Some(Self::Keywords),
"semantics" => Some(Self::Semantics),
"llm" => Some(Self::Llm),
_ => None,
}
}
}
impl fmt::Display for Section {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Self::Keywords => "keywords",
Self::Semantics => "semantics",
Self::Llm => "llm",
};
f.write_str(s)
}
}
/// How many hits of a [`QuantifierTarget`] must be `true` for a
/// [`ConditionExpr::Quantified`] condition to hold.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Quantifier {
    /// At least one hit is `true`.
    Any,
    /// Every hit is `true`, and there is at least one hit.
    All,
    /// At least `n` hits are `true`; `AtLeast(0)` always holds.
    AtLeast(u32),
}
/// A boolean condition over the variables of an [`EvalContext`].
///
/// Evaluate it with [`ConditionExpr::eval`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ConditionExpr {
    /// A constant value.
    Literal(bool),
    /// The value of variable `var` in `section`; evaluation fails with
    /// [`EvalError::UnknownReference`] if the variable is absent.
    Reference {
        section: Section,
        var: String,
    },
    /// `true` when any variable in `section` is `true`.
    Wildcard {
        section: Section,
    },
    /// `true` when any variable in `section` whose name starts with
    /// `prefix` is `true`.
    PrefixWildcard {
        section: Section,
        prefix: String,
    },
    /// Holds when `quantifier` is satisfied by the hits of `target`.
    Quantified {
        quantifier: Quantifier,
        target: Box<QuantifierTarget>,
    },
    /// Logical negation of the inner condition.
    Not(Box<ConditionExpr>),
    /// `true` when every operand is `true` (an empty list is `true`).
    And(Vec<ConditionExpr>),
    /// `true` when some operand is `true` (an empty list is `false`).
    Or(Vec<ConditionExpr>),
}
/// The collection of boolean hits a [`Quantifier`] is applied to.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum QuantifierTarget {
    /// One hit per variable in the section.
    SectionWildcard(Section),
    /// Hits derived from an expression; a top-level `Or` contributes
    /// separately for each of its operands.
    Inner(Box<ConditionExpr>),
}
#[derive(Debug, Error, PartialEq)]
pub enum EvalError {
#[error("condition references unknown variable `{section}.${var}`")]
UnknownReference { section: Section, var: String },
}
/// Boolean values of every named variable, grouped by [`Section`].
///
/// Each map goes from variable name to that variable's boolean value.
#[derive(Clone, Debug, Default)]
pub struct EvalContext {
    /// Variables of the `keywords` section.
    pub keywords: BTreeMap<String, bool>,
    /// Variables of the `semantics` section.
    pub semantics: BTreeMap<String, bool>,
    /// Variables of the `llm` section.
    pub llm: BTreeMap<String, bool>,
}
impl EvalContext {
    /// Borrows the variable map that backs `section`.
    fn section(&self, section: Section) -> &BTreeMap<String, bool> {
        let Self {
            keywords,
            semantics,
            llm,
        } = self;
        match section {
            Section::Keywords => keywords,
            Section::Semantics => semantics,
            Section::Llm => llm,
        }
    }
}
impl ConditionExpr {
    /// Evaluates the condition against the variable values in `ctx`.
    ///
    /// * `Literal` yields its constant; `Reference` yields the named variable.
    /// * `Wildcard` / `PrefixWildcard` are `true` when at least one variable in
    ///   the section (whose name starts with `prefix`) is `true`; if no
    ///   variable qualifies the result is `false`.
    /// * `And` / `Or` evaluate operands left to right and stop at the first
    ///   deciding operand, so later operands are never evaluated and cannot
    ///   fail. An empty `And` is `true`, an empty `Or` is `false`.
    /// * `Quantified` counts the `true` hits of its target: `Any` needs one,
    ///   `All` needs every hit `true` and at least one hit, `AtLeast(n)` needs
    ///   `n` of them.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::UnknownReference`] when an evaluated `Reference`
    /// names a variable missing from its section.
    pub fn eval(&self, ctx: &EvalContext) -> Result<bool, EvalError> {
        match self {
            Self::Literal(b) => Ok(*b),
            Self::Reference { section, var } => ctx
                .section(*section)
                .get(var)
                .copied()
                .ok_or_else(|| EvalError::UnknownReference {
                    section: *section,
                    var: var.clone(),
                }),
            Self::Wildcard { section } => Ok(ctx.section(*section).values().any(|hit| *hit)),
            Self::PrefixWildcard { section, prefix } => Ok(ctx
                .section(*section)
                .iter()
                .any(|(name, hit)| *hit && name.starts_with(prefix.as_str()))),
            Self::Not(inner) => Ok(!inner.eval(ctx)?),
            Self::And(items) => {
                for item in items {
                    if !item.eval(ctx)? {
                        return Ok(false);
                    }
                }
                Ok(true)
            }
            Self::Or(items) => {
                for item in items {
                    if item.eval(ctx)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            Self::Quantified { quantifier, target } => {
                let hits = collect_target_hits(target, ctx)?;
                // Count in `usize`: casting the count and length to `u32` could
                // wrap on very large hit lists and compare wrapped values.
                let matched = hits.iter().filter(|&&hit| hit).count();
                Ok(match quantifier {
                    Quantifier::Any => matched >= 1,
                    Quantifier::All => !hits.is_empty() && matched == hits.len(),
                    // A threshold that does not fit in `usize` is clamped to
                    // `usize::MAX`, which no real hit list can reach.
                    Quantifier::AtLeast(n) => {
                        matched >= usize::try_from(*n).unwrap_or(usize::MAX)
                    }
                })
            }
        }
    }
}
fn collect_target_hits(
target: &QuantifierTarget,
ctx: &EvalContext,
) -> Result<Vec<bool>, EvalError> {
match target {
QuantifierTarget::SectionWildcard(section) => {
Ok(ctx.section(*section).values().copied().collect())
}
QuantifierTarget::Inner(expr) => match expr.as_ref() {
ConditionExpr::Or(items) => items.iter().map(|i| i.eval(ctx)).collect(),
other => Ok(vec![other.eval(ctx)?]),
},
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `EvalContext` from `(name, value)` pairs for each section.
    fn ctx(
        keywords: &[(&str, bool)],
        semantics: &[(&str, bool)],
        llm: &[(&str, bool)],
    ) -> EvalContext {
        let to_map = |entries: &[(&str, bool)]| {
            entries
                .iter()
                .map(|(k, v)| ((*k).to_string(), *v))
                .collect::<BTreeMap<_, _>>()
        };
        EvalContext {
            keywords: to_map(keywords),
            semantics: to_map(semantics),
            llm: to_map(llm),
        }
    }

    #[test]
    fn wildcard_keywords_any_match_is_true() {
        let c = ConditionExpr::Wildcard {
            section: Section::Keywords,
        };
        assert!(c
            .eval(&ctx(&[("a", false), ("b", true)], &[], &[]))
            .unwrap());
        assert!(!c
            .eval(&ctx(&[("a", false), ("b", false)], &[], &[]))
            .unwrap());
        assert!(!c.eval(&ctx(&[], &[], &[])).unwrap());
    }

    #[test]
    fn prefix_wildcard_requires_a_true_hit_with_the_prefix() {
        let c = ConditionExpr::PrefixWildcard {
            section: Section::Semantics,
            prefix: "risk_".into(),
        };
        assert!(c
            .eval(&ctx(&[], &[("risk_fraud", true), ("other", false)], &[]))
            .unwrap());
        assert!(!c
            .eval(&ctx(&[], &[("risk_fraud", false), ("other", true)], &[]))
            .unwrap());
        // A matching name in a different section does not count.
        assert!(!c.eval(&ctx(&[("risk_fraud", true)], &[], &[])).unwrap());
    }

    #[test]
    fn any_of_wildcard_equivalent_to_bare_wildcard() {
        let bare = ConditionExpr::Wildcard {
            section: Section::Keywords,
        };
        let any_of = ConditionExpr::Quantified {
            quantifier: Quantifier::Any,
            target: Box::new(QuantifierTarget::SectionWildcard(Section::Keywords)),
        };
        for fixture in [
            ctx(&[("a", true)], &[], &[]),
            ctx(&[("a", false)], &[], &[]),
            ctx(&[("a", false), ("b", true)], &[], &[]),
        ] {
            assert_eq!(bare.eval(&fixture), any_of.eval(&fixture));
        }
    }

    #[test]
    fn all_of_section_requires_every_hit_and_nonempty() {
        let all = ConditionExpr::Quantified {
            quantifier: Quantifier::All,
            target: Box::new(QuantifierTarget::SectionWildcard(Section::Semantics)),
        };
        assert!(all
            .eval(&ctx(&[], &[("a", true), ("b", true)], &[]))
            .unwrap());
        assert!(!all
            .eval(&ctx(&[], &[("a", true), ("b", false)], &[]))
            .unwrap());
        assert!(!all.eval(&ctx(&[], &[], &[])).unwrap());
    }

    #[test]
    fn at_least_n_quantifier_counts_matches() {
        let two = ConditionExpr::Quantified {
            quantifier: Quantifier::AtLeast(2),
            target: Box::new(QuantifierTarget::SectionWildcard(Section::Llm)),
        };
        assert!(two
            .eval(&ctx(&[], &[], &[("a", true), ("b", true), ("c", false)]))
            .unwrap());
        assert!(!two
            .eval(&ctx(&[], &[], &[("a", true), ("b", false)]))
            .unwrap());
    }

    #[test]
    fn quantifier_over_or_counts_each_operand() {
        let refs: Vec<ConditionExpr> = ["a", "b", "c"]
            .iter()
            .map(|v| ConditionExpr::Reference {
                section: Section::Keywords,
                var: (*v).to_string(),
            })
            .collect();
        let at_least_two = ConditionExpr::Quantified {
            quantifier: Quantifier::AtLeast(2),
            target: Box::new(QuantifierTarget::Inner(Box::new(ConditionExpr::Or(refs)))),
        };
        assert!(at_least_two
            .eval(&ctx(&[("a", true), ("b", false), ("c", true)], &[], &[]))
            .unwrap());
        assert!(!at_least_two
            .eval(&ctx(&[("a", true), ("b", false), ("c", false)], &[], &[]))
            .unwrap());
        // Every operand is evaluated, so a missing variable is reported.
        assert!(at_least_two
            .eval(&ctx(&[("a", true), ("b", true)], &[], &[]))
            .is_err());
    }

    #[test]
    fn unknown_reference_returns_named_error() {
        let r = ConditionExpr::Reference {
            section: Section::Keywords,
            var: "missing_var".into(),
        };
        let err = r
            .eval(&ctx(&[("present", true)], &[], &[]))
            .expect_err("typo must fail loudly");
        assert_eq!(
            err,
            EvalError::UnknownReference {
                section: Section::Keywords,
                var: "missing_var".into()
            }
        );
        assert_eq!(
            err.to_string(),
            "condition references unknown variable `keywords.$missing_var`"
        );
    }

    #[test]
    fn boolean_combinators_short_circuit_correctly() {
        let kw_a = ConditionExpr::Reference {
            section: Section::Keywords,
            var: "a".into(),
        };
        let kw_b = ConditionExpr::Reference {
            section: Section::Keywords,
            var: "b".into(),
        };
        let expr = ConditionExpr::And(vec![
            ConditionExpr::Not(Box::new(kw_a.clone())),
            kw_b.clone(),
        ]);
        assert!(expr
            .eval(&ctx(&[("a", false), ("b", true)], &[], &[]))
            .unwrap());
        assert!(!expr
            .eval(&ctx(&[("a", true), ("b", true)], &[], &[]))
            .unwrap());
        assert!(!expr
            .eval(&ctx(&[("a", false), ("b", false)], &[], &[]))
            .unwrap());
        let or_expr = ConditionExpr::Or(vec![kw_a, kw_b]);
        assert!(or_expr
            .eval(&ctx(&[("a", false), ("b", true)], &[], &[]))
            .unwrap());
        assert!(!or_expr
            .eval(&ctx(&[("a", false), ("b", false)], &[], &[]))
            .unwrap());
    }

    #[test]
    fn combinators_skip_operands_after_the_deciding_one() {
        let missing = ConditionExpr::Reference {
            section: Section::Llm,
            var: "missing".into(),
        };
        let empty = ctx(&[], &[], &[]);
        assert_eq!(
            ConditionExpr::And(vec![ConditionExpr::Literal(false), missing.clone()]).eval(&empty),
            Ok(false)
        );
        assert_eq!(
            ConditionExpr::Or(vec![ConditionExpr::Literal(true), missing.clone()]).eval(&empty),
            Ok(true)
        );
        // An unknown reference before the deciding operand still fails.
        assert!(ConditionExpr::And(vec![missing, ConditionExpr::Literal(false)])
            .eval(&empty)
            .is_err());
        assert!(ConditionExpr::And(vec![]).eval(&empty).unwrap());
        assert!(!ConditionExpr::Or(vec![]).eval(&empty).unwrap());
    }

    #[test]
    fn section_name_round_trips_through_display() {
        for s in [Section::Keywords, Section::Semantics, Section::Llm] {
            assert_eq!(Section::from_str(&s.to_string()), Some(s));
        }
        assert_eq!(Section::from_str("Keywords"), None);
    }
}