// llm_prefix_match/lib.rs

/*!
llm-prefix-match: gate LLM outputs based on expected prefix patterns.

When you need a model to start its response a certain way (e.g. JSON,
a keyword, or a specific token), this crate lets you check, validate, or
strip that prefix before passing the response downstream.

```rust
use llm_prefix_match::{PrefixMatcher, MatchResult};

let m = PrefixMatcher::new().require_any(&["YES", "NO"]);
assert_eq!(m.check("YES, I agree"), MatchResult::Matched("YES".into()));
assert_eq!(m.check("Maybe"), MatchResult::NoMatch);
```
*/
16
/// Result of a prefix match check, as returned by [`PrefixMatcher::check`].
///
/// Compare it directly with `==`, or use [`MatchResult::is_match`] /
/// [`MatchResult::matched_prefix`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatchResult {
    /// The text started with this prefix. The payload is the configured
    /// prefix string, not the text's own spelling of it (relevant for
    /// case-insensitive matching).
    Matched(String),
    /// No configured prefix matched.
    NoMatch,
}
25
26impl MatchResult {
27    pub fn is_match(&self) -> bool { matches!(self, MatchResult::Matched(_)) }
28    pub fn matched_prefix(&self) -> Option<&str> {
29        if let MatchResult::Matched(s) = self { Some(s) } else { None }
30    }
31}
32
/// How configured prefixes are compared against the input text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatchMode {
    /// Prefixes must match the start of the text byte-for-byte.
    CaseSensitive,
    /// Prefixes match the start of the text regardless of letter case.
    CaseInsensitive,
}
39
40/// Checks whether LLM outputs start with any of a set of expected prefixes.
41#[derive(Debug, Clone, Default)]
42pub struct PrefixMatcher {
43    prefixes: Vec<String>,
44    mode: Option<MatchMode>,
45    trim_before_check: bool,
46}
47
48impl PrefixMatcher {
49    pub fn new() -> Self { Self::default() }
50
51    pub fn require(mut self, prefix: impl Into<String>) -> Self {
52        self.prefixes.push(prefix.into()); self
53    }
54
55    pub fn require_any(mut self, prefixes: &[&str]) -> Self {
56        self.prefixes.extend(prefixes.iter().map(|s| s.to_string())); self
57    }
58
59    pub fn case_insensitive(mut self) -> Self {
60        self.mode = Some(MatchMode::CaseInsensitive); self
61    }
62
63    pub fn trim(mut self) -> Self { self.trim_before_check = true; self }
64
65    /// Check if `text` starts with any configured prefix. Returns the first match.
66    pub fn check(&self, text: &str) -> MatchResult {
67        let candidate = if self.trim_before_check { text.trim_start() } else { text };
68        let is_ci = self.mode == Some(MatchMode::CaseInsensitive);
69        for prefix in &self.prefixes {
70            let matches = if is_ci {
71                candidate.to_lowercase().starts_with(&prefix.to_lowercase())
72            } else {
73                candidate.starts_with(prefix.as_str())
74            };
75            if matches {
76                return MatchResult::Matched(prefix.clone());
77            }
78        }
79        MatchResult::NoMatch
80    }
81
82    /// Strip the matched prefix from `text` and return the remainder, or None.
83    pub fn strip(&self, text: &str) -> Option<String> {
84        let candidate = if self.trim_before_check { text.trim_start() } else { text };
85        let is_ci = self.mode == Some(MatchMode::CaseInsensitive);
86        for prefix in &self.prefixes {
87            let matches = if is_ci {
88                candidate.to_lowercase().starts_with(&prefix.to_lowercase())
89            } else {
90                candidate.starts_with(prefix.as_str())
91            };
92            if matches {
93                return Some(candidate[prefix.len()..].trim_start().to_string());
94            }
95        }
96        None
97    }
98
99    /// True if text matches any prefix.
100    pub fn is_valid(&self, text: &str) -> bool {
101        self.check(text).is_match()
102    }
103
104    pub fn prefix_count(&self) -> usize { self.prefixes.len() }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_prefix_match() {
        let m = PrefixMatcher::new().require("YES");
        assert_eq!(m.check("YES I agree"), MatchResult::Matched("YES".into()));
    }

    #[test]
    fn single_prefix_no_match() {
        let m = PrefixMatcher::new().require("YES");
        assert_eq!(m.check("NO"), MatchResult::NoMatch);
    }

    #[test]
    fn multiple_prefixes_first_match_wins() {
        let m = PrefixMatcher::new().require("YES").require("NO");
        assert_eq!(m.check("YES: sure"), MatchResult::Matched("YES".into()));
    }

    #[test]
    fn first_declared_prefix_wins_when_several_match() {
        // Both "YE" and "YES" match; insertion order decides.
        let m = PrefixMatcher::new().require("YE").require("YES");
        assert_eq!(m.check("YES!"), MatchResult::Matched("YE".into()));
    }

    #[test]
    fn require_any() {
        let m = PrefixMatcher::new().require_any(&["YES", "NO", "MAYBE"]);
        assert!(m.check("MAYBE later").is_match());
    }

    #[test]
    fn no_prefixes_never_match() {
        let m = PrefixMatcher::new();
        assert_eq!(m.check("anything"), MatchResult::NoMatch);
        assert!(m.strip("anything").is_none());
    }

    #[test]
    fn empty_input_does_not_match_non_empty_prefix() {
        let m = PrefixMatcher::new().require("YES");
        assert_eq!(m.check(""), MatchResult::NoMatch);
    }

    #[test]
    fn case_insensitive_match() {
        let m = PrefixMatcher::new().require("yes").case_insensitive();
        assert!(m.check("YES I agree").is_match());
    }

    #[test]
    fn case_sensitive_no_match() {
        let m = PrefixMatcher::new().require("yes");
        assert!(!m.check("YES").is_match());
    }

    #[test]
    fn trim_leading_whitespace() {
        let m = PrefixMatcher::new().require("OK").trim();
        assert!(m.check("   OK great").is_match());
    }

    #[test]
    fn leading_whitespace_rejected_without_trim() {
        let m = PrefixMatcher::new().require("OK");
        assert!(!m.is_valid("  OK"));
    }

    #[test]
    fn strip_prefix_returns_remainder() {
        let m = PrefixMatcher::new().require("YES:");
        let rest = m.strip("YES: I agree");
        assert_eq!(rest.as_deref(), Some("I agree"));
    }

    #[test]
    fn strip_case_insensitive() {
        let m = PrefixMatcher::new().require("YES:").case_insensitive();
        assert_eq!(m.strip("yes:  fine").as_deref(), Some("fine"));
    }

    #[test]
    fn strip_with_trim() {
        let m = PrefixMatcher::new().require("OK").trim();
        assert_eq!(m.strip("   OK  done").as_deref(), Some("done"));
    }

    #[test]
    fn strip_always_trims_remainder() {
        // The remainder is left-trimmed even without `trim()`.
        let m = PrefixMatcher::new().require("YES");
        assert_eq!(m.strip("YES   x").as_deref(), Some("x"));
    }

    #[test]
    fn strip_multibyte_prefix_case_sensitive() {
        let m = PrefixMatcher::new().require("né");
        assert_eq!(m.strip("né ok").as_deref(), Some("ok"));
    }

    #[test]
    fn strip_no_match_returns_none() {
        let m = PrefixMatcher::new().require("YES");
        assert!(m.strip("NO").is_none());
    }

    #[test]
    fn is_valid() {
        let m = PrefixMatcher::new().require("OK");
        assert!(m.is_valid("OK then"));
        assert!(!m.is_valid("bad"));
    }

    #[test]
    fn prefix_count() {
        let m = PrefixMatcher::new().require("A").require("B").require("C");
        assert_eq!(m.prefix_count(), 3);
    }

    #[test]
    fn empty_prefix_always_matches() {
        let m = PrefixMatcher::new().require("");
        assert!(m.check("anything").is_match());
    }

    #[test]
    fn matched_prefix_accessor() {
        let m = PrefixMatcher::new().require("YES");
        let res = m.check("YES!");
        assert_eq!(res.matched_prefix(), Some("YES"));
    }
}