use std::collections::HashSet;
use tensorlogic_ir::TLExpr;
#[cfg(test)]
use crate::rule_guided_decoder::error::RuleGuidedError;
use crate::rule_guided_decoder::error::RuleGuidedResult;
/// Identifier of a single vocabulary token.
pub type TokenId = usize;

/// Outcome of checking one candidate token against a [`RuleConstraint`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ConstraintVerdict {
    /// The token satisfies the rule.
    Allowed,
    /// The token violates the rule.
    Forbidden,
    /// The rule could not decide outright. `RuleConstraint::evaluate` emits
    /// `0.0` for rules outside the supported fragment and `1.0` for tokens
    /// the mapper cannot resolve to a symbol.
    SoftPenalty(f64),
}

/// Callback that resolves a token id to the symbol it denotes, or `None` when
/// the token has no associated symbol.
pub type TokenSymbolMapper = dyn Fn(TokenId) -> Option<String> + Send + Sync;
/// A token-level constraint compiled from a [`TLExpr`] rule.
pub struct RuleConstraint {
    /// The rule this constraint was compiled from.
    source: TLExpr,
    /// Symbols named by the rule; `None` when the rule could not be compiled.
    allow_set: Option<HashSet<String>>,
    /// When `true`, `allow_set` lists forbidden symbols instead of permitted ones.
    complement: bool,
    /// Resolves a candidate token id to the symbol it denotes, if any.
    mapper: Box<TokenSymbolMapper>,
    /// Whether `source` fell inside the compilable fragment.
    supported: bool,
}

impl std::fmt::Debug for RuleConstraint {
    // The mapper closure is opaque, so it is omitted and the output is marked
    // non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RuleConstraint");
        out.field("source", &self.source);
        out.field("allow_set", &self.allow_set);
        out.field("complement", &self.complement);
        out.field("supported", &self.supported);
        out.finish_non_exhaustive()
    }
}
impl RuleConstraint {
    /// Compiles `expr` into a constraint that judges candidate tokens.
    ///
    /// `mapper` resolves a token id to the symbol it denotes. If `expr` falls
    /// outside the fragment the allow-set builder understands, the constraint
    /// is still returned but marked unsupported (see [`Self::is_supported`]),
    /// and [`Self::evaluate`] then yields a neutral `SoftPenalty(0.0)`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported while classifying `expr`.
    pub fn compile<M>(expr: TLExpr, mapper: M) -> RuleGuidedResult<Self>
    where
        M: Fn(TokenId) -> Option<String> + Send + Sync + 'static,
    {
        let mut builder = AllowSetBuilder::default();
        let supported = builder.visit(&expr)?;
        // Only a supported rule yields a meaningful set; otherwise store nothing.
        let (allow_set, complement) = supported
            .then(|| builder.finalize())
            .map_or((None, false), |(set, comp)| (Some(set), comp));
        Ok(Self {
            mapper: Box::new(mapper),
            source: expr,
            allow_set,
            complement,
            supported,
        })
    }

    /// Judges whether `candidate` may be emitted next.
    ///
    /// * Unsupported rule: `SoftPenalty(0.0)`.
    /// * `candidate` has no symbol under the mapper: `SoftPenalty(1.0)`.
    /// * Otherwise `Allowed` when the symbol is permitted, `Forbidden` when not.
    ///
    /// `prefix` is currently ignored; the verdict depends only on `candidate`.
    pub fn evaluate(&self, prefix: &[TokenId], candidate: TokenId) -> ConstraintVerdict {
        let _ = prefix;
        let allow_set = match (self.supported, self.allow_set.as_ref()) {
            (true, Some(set)) => set,
            _ => return ConstraintVerdict::SoftPenalty(0.0),
        };
        match (self.mapper)(candidate) {
            None => ConstraintVerdict::SoftPenalty(1.0),
            // Membership is inverted when the set lists forbidden symbols.
            Some(name) if allow_set.contains(&name) != self.complement => {
                ConstraintVerdict::Allowed
            }
            Some(_) => ConstraintVerdict::Forbidden,
        }
    }

    /// Returns the compiled symbol set, or `None` for an unsupported rule.
    pub fn allow_set(&self) -> Option<&HashSet<String>> {
        self.allow_set.as_ref()
    }

    /// Returns `true` when the set lists forbidden rather than permitted symbols.
    pub fn is_complement(&self) -> bool {
        self.complement
    }

    /// Returns `true` when the source rule was compiled into a symbol set.
    pub fn is_supported(&self) -> bool {
        self.supported
    }

    /// Returns the rule this constraint was compiled from.
    pub fn source(&self) -> &TLExpr {
        &self.source
    }
}
/// Result of classifying a sub-expression: `Some((set, complement))` when the
/// expression lies in the supported fragment, `None` otherwise.
///
/// With `complement == false` the set lists permitted symbols; with
/// `complement == true` it lists forbidden ones.
type ClassifyResult = Option<(HashSet<String>, bool)>;

/// Compiles a [`TLExpr`] into a single, possibly complemented, symbol set.
#[derive(Default)]
struct AllowSetBuilder {
    /// Classification recorded by the last successful `visit`.
    current: Option<(HashSet<String>, bool)>,
}

impl AllowSetBuilder {
    /// Classifies `expr` and records the result.
    ///
    /// Returns `Ok(true)` when `expr` lies in the supported fragment and
    /// `Ok(false)` otherwise. On `Ok(false)` nothing is recorded.
    fn visit(&mut self, expr: &TLExpr) -> RuleGuidedResult<bool> {
        match self.classify(expr)? {
            Some(pair) => {
                self.current = Some(pair);
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Consumes the builder and returns the recorded `(set, complement)` pair.
    /// If nothing was recorded, it returns an empty, non-complemented set.
    fn finalize(self) -> (HashSet<String>, bool) {
        self.current.unwrap_or_default()
    }

    /// Maps `expr` to a `(set, complement)` pair.
    ///
    /// Returns `None` when `expr` falls outside the supported fragment:
    /// predicates combined with `And`, `Or`, `Not` and `Imply`, where binary
    /// operands must share polarity.
    fn classify(&self, expr: &TLExpr) -> RuleGuidedResult<ClassifyResult> {
        match expr {
            TLExpr::Pred { name, args } => {
                // A predicate admits its own name plus every constant argument;
                // variables contribute nothing.
                let mut set = HashSet::with_capacity(1 + args.len());
                set.insert(name.clone());
                set.extend(args.iter().filter_map(Self::const_symbol).map(str::to_owned));
                Ok(Some((set, false)))
            }
            TLExpr::And(lhs, rhs) => self.classify_binary(lhs, rhs, true),
            TLExpr::Or(lhs, rhs) => self.classify_binary(lhs, rhs, false),
            TLExpr::Not(inner) => Ok(self.classify(inner)?.map(|(set, comp)| (set, !comp))),
            TLExpr::Imply(premise, conclusion) => {
                // P -> Q is equivalent to (not P) or Q. Negating P merely flips its
                // polarity flag, so classify P directly instead of deep-cloning both
                // subtrees into a temporary `Or(Not(P), Q)` tree.
                let (premise_set, premise_comp) = match self.classify(premise)? {
                    Some(p) => p,
                    None => return Ok(None),
                };
                let consequent = match self.classify(conclusion)? {
                    Some(p) => p,
                    None => return Ok(None),
                };
                Ok(Self::combine((premise_set, !premise_comp), consequent, false))
            }
            _ => Ok(None),
        }
    }

    /// Classifies both operands of an `And` (`conjunction == true`) or an `Or`
    /// (`conjunction == false`) and merges them. Classification stops at the
    /// first unsupported operand.
    fn classify_binary(
        &self,
        lhs: &TLExpr,
        rhs: &TLExpr,
        conjunction: bool,
    ) -> RuleGuidedResult<ClassifyResult> {
        let left = match self.classify(lhs)? {
            Some(p) => p,
            None => return Ok(None),
        };
        let right = match self.classify(rhs)? {
            Some(p) => p,
            None => return Ok(None),
        };
        Ok(Self::combine(left, right, conjunction))
    }

    /// Merges two classified operands under `And` (`conjunction == true`) or `Or`.
    ///
    /// For positive sets, `And` intersects and `Or` unions. For complemented sets,
    /// De Morgan swaps the operations:
    /// - `¬L ∧ ¬R = ¬(L ∪ R)`
    /// - `¬L ∨ ¬R = ¬(L ∩ R)`
    ///
    /// Operands of mixed polarity are reported as unsupported (`None`).
    fn combine(
        left: (HashSet<String>, bool),
        right: (HashSet<String>, bool),
        conjunction: bool,
    ) -> ClassifyResult {
        let (mut set, complement) = left;
        let (other, other_complement) = right;
        if complement != other_complement {
            return None;
        }
        // Merge in place rather than cloning every string into a fresh set.
        if conjunction != complement {
            set.retain(|s| other.contains(s));
        } else {
            set.extend(other);
        }
        Some((set, complement))
    }

    /// Returns the constant name carried by `term`, looking through any number
    /// of `Typed` wrappers. Variables yield `None`.
    fn const_symbol(term: &tensorlogic_ir::Term) -> Option<&str> {
        match term {
            tensorlogic_ir::Term::Const(s) => Some(s.as_str()),
            tensorlogic_ir::Term::Var(_) => None,
            tensorlogic_ir::Term::Typed { value, .. } => Self::const_symbol(value.as_ref()),
        }
    }
}
/// Currently a no-op: the body is empty and calling it has no effect.
///
/// NOTE(review): its purpose is not evident from this file; confirm whether
/// callers rely on it or whether it can be removed.
pub const fn extend_tlexpr_support() {
    // Nothing to do.
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Builds a predicate `name(c1, c2, ...)` whose arguments are all constants.
    fn mk_pred(name: &str, consts: &[&str]) -> TLExpr {
        TLExpr::Pred {
            name: name.into(),
            args: consts.iter().map(|c| Term::Const((*c).into())).collect(),
        }
    }

    /// Mapper shared by all tests. Token 1 maps to "Alice", 2 to "Bob" and
    /// 3 to "entity"; every other token has no symbol.
    fn demo_mapper() -> impl Fn(TokenId) -> Option<String> + Send + Sync + 'static {
        |tid: TokenId| match tid {
            1 => Some("Alice".into()),
            2 => Some("Bob".into()),
            3 => Some("entity".into()),
            _ => None,
        }
    }

    #[test]
    fn predicate_allow_list_accepts_named_consts() {
        let expr = mk_pred("entity", &["Alice"]);
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 2), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Allowed);
    }

    #[test]
    fn conjunction_intersects_allow_sets() {
        let a = mk_pred("entity", &["Alice"]);
        let b = mk_pred("entity", &["Bob"]);
        let expr = TLExpr::And(Box::new(a), Box::new(b));
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 2), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Allowed);
    }

    #[test]
    fn disjunction_unions_allow_sets() {
        let a = mk_pred("entity", &["Alice"]);
        let b = mk_pred("entity", &["Bob"]);
        let expr = TLExpr::Or(Box::new(a), Box::new(b));
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 2), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Allowed);
    }

    #[test]
    fn unsupported_variant_returns_soft_noop() {
        let body = mk_pred("entity", &["Alice"]);
        let expr = TLExpr::Exists {
            var: "x".to_string(),
            domain: "Person".to_string(),
            body: Box::new(body),
        };
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(!rc.is_supported());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::SoftPenalty(0.0));
    }

    #[test]
    fn unknown_token_yields_soft_penalty() {
        let expr = mk_pred("entity", &["Alice"]);
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert_eq!(rc.evaluate(&[], 99), ConstraintVerdict::SoftPenalty(1.0));
    }

    #[test]
    fn empty_intersection_forbids_all_known_tokens() {
        let a = mk_pred("entity", &["Alice"]);
        let b = mk_pred("user", &["Charlie"]);
        let expr = TLExpr::And(Box::new(a), Box::new(b));
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Forbidden);
    }

    #[test]
    fn error_type_has_useful_display() {
        let err: RuleGuidedError =
            RuleGuidedError::CompilationError("synthetic failure".to_string());
        assert!(err.to_string().contains("synthetic"));
    }

    #[test]
    fn not_pred_forbids_inner_allows_rest() {
        let inner = mk_pred("entity", &["Alice"]);
        let expr = TLExpr::Not(Box::new(inner));
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert!(rc.is_complement());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 2), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 99), ConstraintVerdict::SoftPenalty(1.0));
    }

    #[test]
    fn double_negation_is_identity() {
        let inner = mk_pred("entity", &["Alice"]);
        let single = TLExpr::Not(Box::new(inner.clone()));
        let double = TLExpr::Not(Box::new(single));
        let rc = RuleConstraint::compile(double, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert!(!rc.is_complement());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 2), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Allowed);
    }

    #[test]
    fn conjunction_of_negations_forbids_union() {
        // ¬A ∧ ¬B = ¬(A ∪ B): anything named by either predicate is forbidden.
        let a = mk_pred("person", &["Alice"]);
        let b = mk_pred("person", &["Bob"]);
        let expr = TLExpr::And(
            Box::new(TLExpr::Not(Box::new(a))),
            Box::new(TLExpr::Not(Box::new(b))),
        );
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert!(rc.is_complement());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 2), ConstraintVerdict::Forbidden);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Allowed);
    }

    #[test]
    fn disjunction_of_negations_forbids_intersection() {
        // ¬A ∨ ¬B = ¬(A ∩ B): only symbols shared by both predicates are forbidden.
        let a = mk_pred("entity", &["Alice"]);
        let b = mk_pred("entity", &["Bob"]);
        let expr = TLExpr::Or(
            Box::new(TLExpr::Not(Box::new(a))),
            Box::new(TLExpr::Not(Box::new(b))),
        );
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert!(rc.is_complement());
        assert_eq!(rc.allow_set().map(|s| s.contains("entity")), Some(true));
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 2), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Forbidden);
    }

    #[test]
    fn imply_with_mixed_polarity_is_unsupported() {
        // P -> Q is rewritten as (not P) or Q. With positive P and Q the two
        // disjuncts have opposite polarity, which the builder rejects.
        let p = mk_pred("entity", &["Alice"]);
        let q = mk_pred("entity", &["Bob"]);
        let expr = TLExpr::Imply(Box::new(p), Box::new(q));
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(!rc.is_supported());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::SoftPenalty(0.0));
    }

    #[test]
    fn imply_p_q_same_complement_succeeds() {
        // (not A) -> B becomes A or B, which unions the two positive sets.
        let a = mk_pred("entity", &["Alice"]);
        let b = mk_pred("entity", &["Bob"]);
        let not_a = TLExpr::Not(Box::new(a));
        let expr = TLExpr::Imply(Box::new(not_a), Box::new(b));
        let rc = RuleConstraint::compile(expr, demo_mapper()).expect("compile");
        assert!(rc.is_supported());
        assert!(!rc.is_complement());
        assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 2), ConstraintVerdict::Allowed);
        assert_eq!(rc.evaluate(&[], 3), ConstraintVerdict::Allowed);
    }
}