use std::collections::HashSet;
use tensorlogic_ir::TLExpr;
#[cfg(test)]
use crate::rule_guided_decoder::error::RuleGuidedError;
use crate::rule_guided_decoder::error::RuleGuidedResult;
/// Identifier of a token in the decoder's vocabulary.
pub type TokenId = usize;

/// Outcome of checking one candidate token against a compiled rule constraint.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ConstraintVerdict {
    /// The token's symbol is permitted by the constraint.
    Allowed,
    /// The token maps to a symbol the constraint does not permit.
    Forbidden,
    /// The constraint cannot decide hard; the payload is a penalty weight
    /// (how the weight is applied is up to the caller).
    SoftPenalty(f64),
}
/// Callback translating a token id into the symbol it denotes, or `None` when
/// the token has no symbol (such tokens receive a soft penalty from `evaluate`).
pub type TokenSymbolMapper = dyn Fn(TokenId) -> Option<String> + Send + Sync;
/// A rule expression compiled into a per-token allow/forbid check.
///
/// Construct with `RuleConstraint::compile`, then call `evaluate` for each
/// candidate token.
pub struct RuleConstraint {
/// The rule expression this constraint was compiled from.
source: TLExpr,
/// Symbols a candidate token may map to. `compile` sets this to `Some` exactly
/// when `supported` is `true`.
allow_set: Option<HashSet<String>>,
/// Token-id -> symbol lookup consulted by `evaluate`.
mapper: Box<TokenSymbolMapper>,
/// Whether `source` only uses constructs the allow-set compiler understands.
supported: bool,
}
impl std::fmt::Debug for RuleConstraint {
    /// Prints every field except the mapper closure, which has no `Debug`
    /// representation; the trailing `..` marks the omission.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            source,
            allow_set,
            supported,
            ..
        } = self;
        let mut out = f.debug_struct("RuleConstraint");
        out.field("source", source);
        out.field("allow_set", allow_set);
        out.field("supported", supported);
        out.finish_non_exhaustive()
    }
}
impl RuleConstraint {
    /// Compiles `expr` into a constraint that uses `mapper` to translate token
    /// ids into symbols.
    ///
    /// Expressions outside the supported fragment still compile successfully.
    /// The result reports `is_supported() == false` and has no allow-set. Its
    /// `evaluate` returns `SoftPenalty(0.0)` for every token.
    ///
    /// # Errors
    ///
    /// Propagates any error reported while building the allow-set.
    pub fn compile<M>(expr: TLExpr, mapper: M) -> RuleGuidedResult<Self>
    where
        M: Fn(TokenId) -> Option<String> + Send + Sync + 'static,
    {
        let mut builder = AllowSetBuilder::default();
        let supported = builder.visit(&expr)?;
        // Only materialise an allow-set when the whole expression was understood.
        let allow_set = supported.then(|| builder.finalize());
        Ok(Self {
            source: expr,
            allow_set,
            mapper: Box::new(mapper),
            supported,
        })
    }

    /// Judges whether `candidate` may be emitted.
    ///
    /// * Unsupported constraint: `SoftPenalty(0.0)`. The mapper is not consulted.
    /// * Token the mapper cannot name: `SoftPenalty(1.0)`.
    /// * Symbol inside the allow-set: `Allowed`; any other symbol: `Forbidden`.
    ///
    /// `prefix` is currently ignored, since compiled constraints are context-free.
    pub fn evaluate(&self, prefix: &[TokenId], candidate: TokenId) -> ConstraintVerdict {
        let _ = prefix;
        let allowed = match (self.supported, &self.allow_set) {
            (true, Some(set)) => set,
            _ => return ConstraintVerdict::SoftPenalty(0.0),
        };
        match (self.mapper)(candidate) {
            None => ConstraintVerdict::SoftPenalty(1.0),
            Some(symbol) if allowed.contains(&symbol) => ConstraintVerdict::Allowed,
            Some(_) => ConstraintVerdict::Forbidden,
        }
    }

    /// The compiled allow-set, or `None` when the expression is unsupported.
    pub fn allow_set(&self) -> Option<&HashSet<String>> {
        self.allow_set.as_ref()
    }

    /// Whether the source expression was compiled into a hard allow-set.
    pub fn is_supported(&self) -> bool {
        self.supported
    }

    /// The rule expression this constraint was compiled from.
    pub fn source(&self) -> &TLExpr {
        &self.source
    }
}
/// Accumulates the allow-set computed for a single rule expression.
///
/// `current` stays `None` until `visit` succeeds on a supported expression.
#[derive(Default)]
struct AllowSetBuilder {
    current: Option<HashSet<String>>,
}

impl AllowSetBuilder {
    /// Classifies `expr` and stores its allow-set if the expression is supported.
    ///
    /// Returns `Ok(true)` when an allow-set was stored. Returns `Ok(false)` when
    /// the expression uses an unsupported construct; the builder is left
    /// untouched in that case.
    fn visit(&mut self, expr: &TLExpr) -> RuleGuidedResult<bool> {
        match self.classify(expr)? {
            Some(set) => {
                self.current = Some(set);
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Consumes the builder and returns the stored allow-set. Returns an empty
    /// set if `visit` never succeeded.
    fn finalize(self) -> HashSet<String> {
        self.current.unwrap_or_default()
    }

    /// Computes the allow-set for `expr`, or `None` if it is unsupported.
    ///
    /// * `Pred`: the predicate name plus every constant argument, including
    ///   constants wrapped in any number of `Typed` annotations. Variables
    ///   contribute nothing.
    /// * `And`: the intersection of both operands' sets.
    /// * `Or`: the union of both operands' sets.
    /// * Anything else, including `Not`: `None`.
    ///
    /// An `And`/`Or` whose operand is unsupported is itself unsupported.
    fn classify(&self, expr: &TLExpr) -> RuleGuidedResult<Option<HashSet<String>>> {
        match expr {
            TLExpr::Pred { name, args } => {
                let mut set = HashSet::with_capacity(1 + args.len());
                set.insert(name.clone());
                for arg in args {
                    Self::insert_const(arg, &mut set);
                }
                Ok(Some(set))
            }
            TLExpr::And(lhs, rhs) => Ok(self.classify_pair(lhs, rhs)?.map(|(mut l, r)| {
                // In-place intersection: avoids cloning every surviving symbol.
                l.retain(|sym| r.contains(sym));
                l
            })),
            TLExpr::Or(lhs, rhs) => Ok(self.classify_pair(lhs, rhs)?.map(|(mut l, r)| {
                // In-place union: moves the right-hand symbols instead of cloning.
                l.extend(r);
                l
            })),
            // The complement of a finite allow-set is not itself a finite
            // allow-set, so negation cannot be compiled into this form.
            TLExpr::Not(_) => Ok(None),
            _ => Ok(None),
        }
    }

    /// Classifies both operands of a binary connective, left first.
    ///
    /// Returns `None` as soon as either side is unsupported. The right side is
    /// not visited when the left one is already unsupported.
    fn classify_pair(
        &self,
        lhs: &TLExpr,
        rhs: &TLExpr,
    ) -> RuleGuidedResult<Option<(HashSet<String>, HashSet<String>)>> {
        let l = match self.classify(lhs)? {
            Some(s) => s,
            None => return Ok(None),
        };
        let r = match self.classify(rhs)? {
            Some(s) => s,
            None => return Ok(None),
        };
        Ok(Some((l, r)))
    }

    /// Adds the constant carried by `term`, if any, to `set`.
    ///
    /// Peels off nested `Typed` wrappers so the annotated constant is still
    /// found. Variables add nothing.
    fn insert_const(term: &tensorlogic_ir::Term, set: &mut HashSet<String>) {
        match term {
            tensorlogic_ir::Term::Const(s) => {
                set.insert(s.clone());
            }
            tensorlogic_ir::Term::Var(_) => {}
            tensorlogic_ir::Term::Typed { value, .. } => Self::insert_const(value.as_ref(), set),
        }
    }
}
/// No-op hook callable in const contexts.
///
/// NOTE(review): the body is empty. Presumably this is reserved for future
/// TLExpr-support extensions; confirm against callers.
pub const fn extend_tlexpr_support() {
    // Intentionally empty.
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;
    use ConstraintVerdict::{Allowed, Forbidden, SoftPenalty};

    /// Builds `name(c1, c2, ...)` where every argument is a constant term.
    fn pred(name: &str, consts: &[&str]) -> TLExpr {
        let args = consts.iter().map(|&c| Term::Const(c.into())).collect();
        TLExpr::Pred {
            name: name.into(),
            args,
        }
    }

    /// Fixed vocabulary used throughout: 1 -> "Alice", 2 -> "Bob",
    /// 3 -> "entity". Every other id is unmapped.
    fn vocab_mapper() -> impl Fn(TokenId) -> Option<String> + Send + Sync + 'static {
        const VOCAB: [(TokenId, &str); 3] = [(1, "Alice"), (2, "Bob"), (3, "entity")];
        |tid: TokenId| {
            VOCAB
                .iter()
                .find(|&&(id, _)| id == tid)
                .map(|&(_, sym)| sym.to_string())
        }
    }

    /// Compiles `expr` with the shared vocabulary mapper.
    fn compiled(expr: TLExpr) -> RuleConstraint {
        RuleConstraint::compile(expr, vocab_mapper()).expect("compile")
    }

    /// Checks each `(token, verdict)` pair with an empty prefix.
    fn assert_verdicts(rc: &RuleConstraint, expected: &[(TokenId, ConstraintVerdict)]) {
        for &(tid, want) in expected {
            assert_eq!(rc.evaluate(&[], tid), want, "verdict for token {}", tid);
        }
    }

    fn conj(a: TLExpr, b: TLExpr) -> TLExpr {
        TLExpr::And(Box::new(a), Box::new(b))
    }

    fn disj(a: TLExpr, b: TLExpr) -> TLExpr {
        TLExpr::Or(Box::new(a), Box::new(b))
    }

    #[test]
    fn predicate_allow_list_accepts_named_consts() {
        let rc = compiled(pred("entity", &["Alice"]));
        assert!(rc.is_supported());
        assert_verdicts(&rc, &[(1, Allowed), (2, Forbidden), (3, Allowed)]);
    }

    #[test]
    fn conjunction_intersects_allow_sets() {
        let rc = compiled(conj(pred("entity", &["Alice"]), pred("entity", &["Bob"])));
        assert!(rc.is_supported());
        assert_verdicts(&rc, &[(1, Forbidden), (2, Forbidden), (3, Allowed)]);
    }

    #[test]
    fn disjunction_unions_allow_sets() {
        let rc = compiled(disj(pred("entity", &["Alice"]), pred("entity", &["Bob"])));
        assert!(rc.is_supported());
        assert_verdicts(&rc, &[(1, Allowed), (2, Allowed), (3, Allowed)]);
    }

    #[test]
    fn unsupported_variant_returns_soft_noop() {
        let rc = compiled(TLExpr::Not(Box::new(pred("entity", &["Alice"]))));
        assert!(!rc.is_supported());
        assert_verdicts(&rc, &[(1, SoftPenalty(0.0))]);
    }

    #[test]
    fn unknown_token_yields_soft_penalty() {
        let rc = compiled(pred("entity", &["Alice"]));
        assert_verdicts(&rc, &[(99, SoftPenalty(1.0))]);
    }

    #[test]
    fn empty_intersection_forbids_all_known_tokens() {
        let rc = compiled(conj(pred("entity", &["Alice"]), pred("user", &["Charlie"])));
        assert!(rc.is_supported());
        assert_verdicts(&rc, &[(1, Forbidden), (3, Forbidden)]);
    }

    #[test]
    fn error_type_has_useful_display() {
        let err: RuleGuidedError =
            RuleGuidedError::CompilationError("synthetic failure".to_string());
        assert!(err.to_string().contains("synthetic"));
    }
}