1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//! Rule-Guided Sampling Decoder.
//!
//! This module implements a decoder that biases the beam-search algorithm
//! shipped in [`tensorlogic_infer::beam_search`] to prefer token sequences
//! consistent with a user-supplied [`tensorlogic_ir::TLExpr`] logical
//! constraint.
//!
//! Two enforcement strategies are available:
//!
//! * **Hard masking** — forbidden tokens have their logits set to
//!   `f64::NEG_INFINITY` and are therefore eliminated from the candidate
//!   pool.
//! * **Soft re-weighting** — tokens that merely *violate* the constraint
//!   (without being outright forbidden) receive a log-probability penalty
//!   of `-lambda * violation_score`. Forbidden tokens are still fully
//!   banned in soft mode; the penalty applies only to tokens for which the
//!   constraint returns the `SoftPenalty` verdict.
//!
//! ## Public surface
//!
//! * [`RuleConstraint`] — wraps a `TLExpr` and compiles a vocabulary-level
//!   allow-list via a caller-supplied token-to-symbol mapper.
//! * [`ConstraintVerdict`] — per-token classification result.
//! * [`LogitMasker`] — trait implemented by [`HardMask`] and
//!   [`SoftPenaltyMask`].
//! * [`RuleGuidedBeamSearch`] — façade that plugs a constraint and a masker
//!   into [`tensorlogic_infer::beam_search::BeamSearchDecoder`].
//! * [`RuleGuidedError`] / [`RuleGuidedResult`] — error taxonomy.
//!
//! ## Example
//!
//! ```no_run
//! use std::sync::Arc;
//! use tensorlogic_infer::beam_search::BeamSearchConfig;
//! use tensorlogic_ir::{TLExpr, Term};
//! use tensorlogic_trustformers::rule_guided_decoder::{
//!     HardMask, LogitMasker, RuleConstraint, RuleGuidedBeamSearch,
//! };
//!
//! let expr = TLExpr::Pred {
//!     name: "entity".into(),
//!     args: vec![Term::Const("Alice".into())],
//! };
//! let mapper = |tid: usize| match tid {
//!     0 => Some("entity".into()),
//!     1 => Some("Alice".into()),
//!     _ => None,
//! };
//! let constraint = RuleConstraint::compile(expr, mapper).expect("compile");
//! let mask: Arc<dyn LogitMasker> = Arc::new(HardMask::new());
//! let cfg = BeamSearchConfig {
//!     beam_width: 2,
//!     max_length: 4,
//!     vocab_size: 2,
//!     ..BeamSearchConfig::default()
//! };
//! let decoder = RuleGuidedBeamSearch::new(cfg, constraint, mask);
//! // `decoder.decode(bos, score_fn)` now returns a `BeamSearchResult` whose
//! // hypotheses never contain a token that the constraint forbids.
//! ```
pub use ;
pub use RuleGuidedBeamSearch;
pub use ;
pub use ;