#![cfg(test)]
use std::sync::Arc;
use tensorlogic_infer::beam_search::BeamSearchConfig;
use tensorlogic_ir::{TLExpr, Term};
use super::constraint::{ConstraintVerdict, RuleConstraint, TokenId};
use super::engine::RuleGuidedBeamSearch;
use super::mask::{HardMask, LogitMasker, SoftPenaltyMask};
fn mapper() -> impl Fn(TokenId) -> Option<String> + Send + Sync + 'static {
|tid: TokenId| match tid {
0 => Some("entity".into()),
1 => Some("Alice".into()),
2 => Some("Bob".into()),
3 => Some("Eve".into()),
_ => None,
}
}
#[test]
fn conjunction_constraint_hard_masks_all_non_shared_symbols() {
let a = TLExpr::Pred {
name: "entity".into(),
args: vec![Term::Const("Alice".into())],
};
let b = TLExpr::Pred {
name: "entity".into(),
args: vec![Term::Const("Bob".into())],
};
let expr = TLExpr::And(Box::new(a), Box::new(b));
let rc = RuleConstraint::compile(expr, mapper()).expect("compile");
let mut logits = vec![0.0_f64; 4];
HardMask::new().apply(&rc, &[], &mut logits).expect("apply");
assert_eq!(logits[0], 0.0);
assert_eq!(logits[1], f64::NEG_INFINITY);
assert_eq!(logits[2], f64::NEG_INFINITY);
assert_eq!(logits[3], f64::NEG_INFINITY);
}
#[test]
fn soft_mask_math_matches_lambda_times_violation() {
let expr = TLExpr::Pred {
name: "entity".into(),
args: vec![Term::Const("Alice".into())],
};
let rc = RuleConstraint::compile(expr, mapper()).expect("compile");
let mut logits = vec![3.0_f64; 5];
SoftPenaltyMask::new(0.75)
.expect("ctor")
.apply(&rc, &[], &mut logits)
.expect("apply");
assert!((logits[0] - 3.0).abs() < 1e-12);
assert!((logits[1] - 3.0).abs() < 1e-12);
assert_eq!(logits[2], f64::NEG_INFINITY);
assert_eq!(logits[3], f64::NEG_INFINITY);
assert!((logits[4] - 2.25).abs() < 1e-12);
}
#[test]
fn predicate_allow_list_evaluates_prefix_independent() {
let expr = TLExpr::Pred {
name: "entity".into(),
args: vec![Term::Const("Alice".into())],
};
let rc = RuleConstraint::compile(expr, mapper()).expect("compile");
assert_eq!(rc.evaluate(&[], 1), ConstraintVerdict::Allowed);
assert_eq!(rc.evaluate(&[7, 42, 99], 1), ConstraintVerdict::Allowed);
}
#[test]
fn integration_smoke_with_stubbed_decoder() {
let a = TLExpr::Pred {
name: "entity".into(),
args: vec![Term::Const("Alice".into())],
};
let b = TLExpr::Pred {
name: "entity".into(),
args: vec![Term::Const("Bob".into())],
};
let expr = TLExpr::Or(Box::new(a), Box::new(b));
let rc = RuleConstraint::compile(expr, mapper()).expect("compile");
let config = BeamSearchConfig {
beam_width: 2,
max_length: 3,
eos_token_id: None,
length_penalty: 0.0,
min_length: 1,
vocab_size: 4,
temperature: 1.0,
top_k_filter: None,
};
let decoder = RuleGuidedBeamSearch::new(config, rc, Arc::new(HardMask::new()));
let result = decoder
.decode(0, |beams: &[&[usize]]| {
Ok(beams.iter().map(|_| vec![1.0_f64, 0.5, 0.5, 0.5]).collect())
})
.expect("decode");
assert!(!result.hypotheses.is_empty());
for hyp in &result.hypotheses {
for tok in &hyp.tokens {
assert!(
(0..=2).contains(tok),
"hard-masked decoder emitted banned token {tok}: {:?}",
hyp.tokens
);
}
}
}
#[test]
fn negative_case_when_every_symbol_is_banned_decoder_still_terminates() {
let a = TLExpr::Pred {
name: "entity".into(),
args: vec![Term::Const("Alice".into())],
};
let c = TLExpr::Pred {
name: "user".into(),
args: vec![Term::Const("Charlie".into())],
};
let expr = TLExpr::And(Box::new(a), Box::new(c));
let rc = RuleConstraint::compile(expr, mapper()).expect("compile");
let masker: Arc<dyn LogitMasker> = Arc::new(SoftPenaltyMask::new(0.1).expect("ctor"));
let decoder = RuleGuidedBeamSearch::new(
BeamSearchConfig {
beam_width: 1,
max_length: 2,
eos_token_id: None,
length_penalty: 0.0,
min_length: 1,
vocab_size: 4,
temperature: 1.0,
top_k_filter: None,
},
rc,
masker,
);
let result = decoder
.decode(0, |beams: &[&[usize]]| {
Ok(beams.iter().map(|_| vec![0.0_f64; 4]).collect())
})
.expect("decode should not crash");
assert!(!result.hypotheses.is_empty());
}
#[test]
fn constraint_reports_source_and_support_flag() {
let expr = TLExpr::Pred {
name: "entity".into(),
args: vec![Term::Const("Alice".into())],
};
let rc = RuleConstraint::compile(expr.clone(), mapper()).expect("compile");
assert!(rc.is_supported());
assert_eq!(rc.source(), &expr);
let unsupported = TLExpr::Not(Box::new(expr));
let rc2 = RuleConstraint::compile(unsupported, mapper()).expect("compile");
assert!(!rc2.is_supported());
assert_eq!(rc2.evaluate(&[], 1), ConstraintVerdict::SoftPenalty(0.0));
}