use crate::rule_guided_decoder::constraint::{ConstraintVerdict, RuleConstraint, TokenId};
use crate::rule_guided_decoder::error::{RuleGuidedError, RuleGuidedResult};
/// Strategy for rewriting a vector of logits so that decoding respects a
/// [`RuleConstraint`].
///
/// The `Send + Sync` supertraits mean a masker can be shared across threads.
/// The trait is usable as a trait object (`Box<dyn LogitMasker>`).
pub trait LogitMasker: Send + Sync {
    /// Adjusts `logits` in place according to `constraint`, given the tokens
    /// already decoded in `prefix`.
    ///
    /// NOTE(review): the implementations in this module treat each index of
    /// `logits` as the token id passed to `RuleConstraint::evaluate`. Callers
    /// presumably pass one logit per vocabulary entry; confirm with callers.
    ///
    /// # Errors
    ///
    /// The signature allows failure. The two implementations in this module
    /// (`HardMask`, `SoftPenaltyMask`) always return `Ok(())`.
    fn apply(
        &self,
        constraint: &RuleConstraint,
        prefix: &[TokenId],
        logits: &mut [f64],
    ) -> RuleGuidedResult<()>;
    /// Short, fixed identifier for the masking strategy (e.g. `"HardMask"`).
    fn name(&self) -> &'static str;
}
/// Strict masking strategy: tokens a constraint forbids become impossible,
/// and every other logit is left exactly as it was.
///
/// The type carries no configuration, so it is `Copy` and `Default`.
#[derive(Debug, Default, Clone, Copy)]
pub struct HardMask;

impl HardMask {
    /// Returns the (stateless) hard masker; usable in `const` contexts.
    pub const fn new() -> Self {
        HardMask
    }
}
impl LogitMasker for HardMask {
    /// Sends every token whose verdict is `ConstraintVerdict::Forbidden` to
    /// `f64::NEG_INFINITY` and leaves all other logits unchanged.
    ///
    /// Each index of `logits` is passed to `constraint.evaluate` as the token
    /// id, together with the already-decoded `prefix`. Soft-penalty verdicts
    /// are deliberately ignored: this strategy only distinguishes "allowed"
    /// from "forbidden".
    ///
    /// # Errors
    ///
    /// Never fails; always returns `Ok(())`.
    fn apply(
        &self,
        constraint: &RuleConstraint,
        prefix: &[TokenId],
        logits: &mut [f64],
    ) -> RuleGuidedResult<()> {
        logits
            .iter_mut()
            .enumerate()
            .for_each(|(token, slot)| match constraint.evaluate(prefix, token) {
                ConstraintVerdict::Forbidden => *slot = f64::NEG_INFINITY,
                // A hard mask is binary, so soft penalties pass through untouched.
                ConstraintVerdict::Allowed | ConstraintVerdict::SoftPenalty(_) => {}
            });
        Ok(())
    }

    /// Identifier reported for this strategy.
    fn name(&self) -> &'static str {
        "HardMask"
    }
}
/// Masking strategy that forbids tokens outright and discourages, but does
/// not forbid, tokens that receive a soft-penalty verdict.
///
/// Forbidden tokens are sent to `f64::NEG_INFINITY`. Soft-penalised tokens
/// have `lambda` times the non-negative part of their score subtracted from
/// their logit.
#[derive(Debug, Clone, Copy)]
pub struct SoftPenaltyMask {
    /// Penalty strength. [`SoftPenaltyMask::new`] only accepts finite,
    /// non-negative values. Because the field is public, callers can still
    /// set any value directly, bypassing that check.
    pub lambda: f64,
}

impl SoftPenaltyMask {
    /// Builds a soft-penalty masker with strength `lambda` (zero is accepted).
    ///
    /// # Errors
    ///
    /// Returns `RuleGuidedError::InvalidConfig` when `lambda` is NaN,
    /// infinite, or negative.
    pub fn new(lambda: f64) -> RuleGuidedResult<Self> {
        let acceptable = lambda.is_finite() && lambda >= 0.0;
        if acceptable {
            Ok(SoftPenaltyMask { lambda })
        } else {
            Err(RuleGuidedError::InvalidConfig(format!(
                "lambda must be a non-negative finite number, got {lambda}"
            )))
        }
    }
}
impl LogitMasker for SoftPenaltyMask {
    /// Applies the constraint to every logit. Each index of `logits` is
    /// passed to `constraint.evaluate` as the token id:
    ///
    /// * `Allowed`: logit untouched.
    /// * `Forbidden`: logit set to `f64::NEG_INFINITY`.
    /// * `SoftPenalty(score)`: logit reduced by `lambda * max(score, 0)`,
    ///   but only when both factors are strictly positive.
    ///
    /// # Errors
    ///
    /// Never fails; always returns `Ok(())`.
    fn apply(
        &self,
        constraint: &RuleConstraint,
        prefix: &[TokenId],
        logits: &mut [f64],
    ) -> RuleGuidedResult<()> {
        let lambda = self.lambda;
        for (token, slot) in logits.iter_mut().enumerate() {
            let weight = match constraint.evaluate(prefix, token) {
                ConstraintVerdict::Allowed => continue,
                ConstraintVerdict::Forbidden => {
                    *slot = f64::NEG_INFINITY;
                    continue;
                }
                // `f64::max` returns the non-NaN operand, so a NaN score
                // collapses to 0.0 here and the logit is left alone.
                ConstraintVerdict::SoftPenalty(score) => score.max(0.0),
            };
            // Skipping zero factors avoids `0.0 * inf == NaN`. It also means a
            // non-positive or NaN `lambda` (settable via the public field)
            // disables the penalty entirely.
            if lambda > 0.0 && weight > 0.0 {
                *slot -= lambda * weight;
            }
        }
        Ok(())
    }

    /// Identifier reported for this strategy.
    fn name(&self) -> &'static str {
        "SoftPenaltyMask"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule_guided_decoder::constraint::RuleConstraint;
    use tensorlogic_ir::{TLExpr, Term};

    // Token vocabulary for the fixture: ids 0..=2 map to strings, anything
    // else maps to `None`.
    fn mapper() -> impl Fn(TokenId) -> Option<String> + Send + Sync + 'static {
        |tid: TokenId| match tid {
            0 => Some("entity".into()),
            1 => Some("Alice".into()),
            2 => Some("Bob".into()),
            _ => None,
        }
    }

    // Constraint `entity(Alice)`. The tests below rely on its verdicts:
    // tokens 0 and 1 allowed, token 2 ("Bob") forbidden, token 3 (unmapped)
    // soft-penalised (the lambda = 2.5 test implies a score of 1.0).
    fn alice_only() -> RuleConstraint {
        let expr = TLExpr::Pred {
            name: "entity".into(),
            args: vec![Term::Const("Alice".into())],
        };
        RuleConstraint::compile(expr, mapper()).expect("compile")
    }

    #[test]
    fn hard_mask_sets_forbidden_to_neg_infinity() {
        let rc = alice_only();
        let mut logits = vec![0.0_f64, 0.0, 0.0, 0.0];
        HardMask::new().apply(&rc, &[], &mut logits).expect("apply");
        assert_eq!(logits[0], 0.0);
        assert_eq!(logits[1], 0.0);
        assert_eq!(logits[2], f64::NEG_INFINITY);
        assert_eq!(logits[3], 0.0);
    }

    #[test]
    fn soft_mask_applies_log_penalty() {
        let rc = alice_only();
        let mut logits = vec![0.0_f64, 0.0, 0.0, 0.0];
        SoftPenaltyMask::new(2.5)
            .expect("ctor")
            .apply(&rc, &[], &mut logits)
            .expect("apply");
        assert_eq!(logits[0], 0.0);
        assert_eq!(logits[1], 0.0);
        assert_eq!(logits[2], f64::NEG_INFINITY);
        assert!((logits[3] - (-2.5)).abs() < 1e-12);
    }

    // Starting from non-zero logits distinguishes "subtract the penalty"
    // from "overwrite with -penalty", which zero-initialised inputs cannot.
    #[test]
    fn soft_mask_subtracts_penalty_from_existing_logits() {
        let rc = alice_only();
        let mut logits = vec![1.0_f64, 2.0, 3.0, 4.0];
        SoftPenaltyMask::new(0.5)
            .expect("ctor")
            .apply(&rc, &[], &mut logits)
            .expect("apply");
        assert_eq!(logits[0], 1.0);
        assert_eq!(logits[1], 2.0);
        assert_eq!(logits[2], f64::NEG_INFINITY);
        assert!((logits[3] - 3.5).abs() < 1e-12);
    }

    #[test]
    fn hard_mask_preserves_non_forbidden_logits() {
        let rc = alice_only();
        let mut logits = vec![1.0_f64, -2.0, 3.0, 4.0];
        HardMask::new().apply(&rc, &[], &mut logits).expect("apply");
        assert_eq!(logits[0], 1.0);
        assert_eq!(logits[1], -2.0);
        assert_eq!(logits[2], f64::NEG_INFINITY);
        assert_eq!(logits[3], 4.0);
    }

    #[test]
    fn soft_mask_rejects_negative_lambda() {
        let err = SoftPenaltyMask::new(-0.1).expect_err("should reject");
        assert!(err.to_string().contains("non-negative"));
    }

    // `new` also promises finiteness; pin that NaN and both infinities fail.
    #[test]
    fn soft_mask_rejects_non_finite_lambda() {
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let err = SoftPenaltyMask::new(bad).expect_err("should reject non-finite lambda");
            assert!(err.to_string().contains("non-negative"));
        }
    }

    #[test]
    fn soft_mask_zero_lambda_is_noop() {
        let rc = alice_only();
        let mut logits = vec![0.0_f64, 0.0, 0.0, 0.0];
        SoftPenaltyMask::new(0.0)
            .expect("ctor")
            .apply(&rc, &[], &mut logits)
            .expect("apply");
        assert_eq!(logits[0], 0.0);
        assert_eq!(logits[1], 0.0);
        assert_eq!(logits[2], f64::NEG_INFINITY);
        assert_eq!(logits[3], 0.0);
    }

    // An empty vocabulary slice is a valid (if degenerate) input.
    #[test]
    fn maskers_accept_empty_logits() {
        let rc = alice_only();
        let mut logits: Vec<f64> = Vec::new();
        HardMask::new()
            .apply(&rc, &[], &mut logits)
            .expect("hard apply");
        SoftPenaltyMask::new(1.0)
            .expect("ctor")
            .apply(&rc, &[], &mut logits)
            .expect("soft apply");
        assert!(logits.is_empty());
    }

    #[test]
    fn masker_name_reports_strategy() {
        let hard: Box<dyn LogitMasker> = Box::new(HardMask::new());
        let soft: Box<dyn LogitMasker> = Box::new(SoftPenaltyMask::new(1.0).expect("ctor"));
        assert_eq!(hard.name(), "HardMask");
        assert_eq!(soft.name(), "SoftPenaltyMask");
    }
}