dynamo_llm/utils/
prefix_matcher.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Efficient multi-pattern marker detection with partial suffix matching
5//!
6//! This module provides utilities for detecting complete and partial marker patterns
7//! in streaming text, with support for detecting markers split across chunk boundaries.
8
9use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
10use std::collections::HashMap;
11
/// Outcome of scanning one chunk of streamed text for marker patterns.
///
/// Exactly one variant is produced per call to `MarkerMatcher::process_chunk`.
/// Every variant owns its strings, so the result can outlive the chunk that
/// produced it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatchResult {
    /// A whole marker was found.
    Complete {
        /// Text that precedes the marker (safe to emit).
        prefix: String,
        /// The pattern that matched, exactly as it was registered.
        marker: String,
        /// Byte offset of the marker within the searched text, i.e. the held
        /// partial buffer (if any) followed by the new chunk.
        marker_start: usize,
        /// Text that follows the marker.
        suffix: String,
    },
    /// The searched text ends with the beginning of one or more markers.
    Partial {
        /// Text that precedes the partial marker (safe to emit).
        prefix: String,
        /// The trailing text to hold back until more input arrives.
        partial: String,
        /// Every registered pattern that `partial` could still grow into.
        possible_patterns: Vec<String>,
    },
    /// Neither a whole marker nor the beginning of one was found.
    None {
        /// The entire searched text (safe to emit).
        content: String,
    },
}
41
42/// Efficient multi-pattern matcher with partial suffix detection
43pub struct MarkerMatcher {
44    /// All patterns we're looking for
45    patterns: Vec<String>,
46    /// Aho-Corasick matcher for complete patterns
47    complete_matcher: AhoCorasick,
48    /// Trie for partial matching
49    prefix_trie: PrefixTrie,
50    /// Maximum pattern length (for buffer limits)
51    max_pattern_len: usize,
52}
53
54impl MarkerMatcher {
55    /// Create a new matcher with the given patterns
56    pub fn new(patterns: Vec<String>) -> Result<Self, String> {
57        if patterns.is_empty() {
58            return Err("Cannot create MarkerMatcher with empty patterns".to_string());
59        }
60
61        let complete_matcher = AhoCorasickBuilder::new()
62            .match_kind(MatchKind::LeftmostFirst)
63            .build(&patterns)
64            .map_err(|e| format!("Failed to build Aho-Corasick matcher: {}", e))?;
65
66        let max_pattern_len = patterns.iter().map(|p| p.len()).max().unwrap_or(0);
67        let prefix_trie = PrefixTrie::new(&patterns);
68
69        Ok(Self {
70            patterns,
71            complete_matcher,
72            prefix_trie,
73            max_pattern_len,
74        })
75    }
76
77    /// Get the maximum pattern length
78    pub fn max_pattern_len(&self) -> usize {
79        self.max_pattern_len
80    }
81
82    /// Safe UTF-8 slicing that ensures we only slice at character boundaries
83    fn safe_slice(text: &str, start_byte: usize, end_byte: usize) -> String {
84        // Clamp indices to valid boundaries
85        let start = text
86            .char_indices()
87            .find(|(i, _)| *i >= start_byte)
88            .map(|(i, _)| i)
89            .unwrap_or(text.len());
90
91        let end = text
92            .char_indices()
93            .find(|(i, _)| *i >= end_byte)
94            .map(|(i, _)| i)
95            .unwrap_or(text.len());
96
97        text[start..end].to_string()
98    }
99
100    /// Process a chunk with an optional partial buffer from previous chunk
101    pub fn process_chunk(&self, chunk: &str, partial_buffer: &str) -> MatchResult {
102        // Combine buffer with new chunk
103        let combined = if partial_buffer.is_empty() {
104            chunk.to_string()
105        } else {
106            format!("{}{}", partial_buffer, chunk)
107        };
108
109        // First check for complete markers
110        if let Some(mat) = self.complete_matcher.find(&combined) {
111            let marker = &self.patterns[mat.pattern().as_usize()];
112            return MatchResult::Complete {
113                prefix: Self::safe_slice(&combined, 0, mat.start()),
114                marker: marker.clone(),
115                marker_start: mat.start(),
116                suffix: Self::safe_slice(&combined, mat.end(), combined.len()),
117            };
118        }
119
120        // No complete match - check for partial at ANY suffix position
121        // This is the key: check "n<T" → finds "<T" as partial
122        if let Some((partial_start, partial, patterns)) = self.find_partial_suffix(&combined) {
123            return MatchResult::Partial {
124                prefix: Self::safe_slice(&combined, 0, partial_start),
125                partial: partial.to_string(),
126                possible_patterns: patterns,
127            };
128        }
129
130        // No matches at all
131        MatchResult::None { content: combined }
132    }
133
134    /// Find the longest partial match in any suffix of the input
135    ///
136    /// This scans from left to right to find the EARLIEST partial match,
137    /// ensuring we emit as much content as possible while holding only the minimal partial.
138    fn find_partial_suffix<'a>(&self, text: &'a str) -> Option<(usize, &'a str, Vec<String>)> {
139        // Start from the beginning to find the EARLIEST partial match
140        // This ensures we emit as much as possible
141        // Use char_indices to get valid UTF-8 boundaries
142        for (i, _) in text.char_indices() {
143            let suffix = &text[i..];
144            if let Some(patterns) = self.prefix_trie.find_prefix_match(suffix) {
145                // This suffix is a prefix of one or more patterns
146                return Some((i, suffix, patterns));
147            }
148        }
149        None
150    }
151}
152
/// Trie over the marker patterns, used to decide whether a piece of text is an
/// incomplete prefix of at least one pattern.
#[derive(Debug)]
struct PrefixTrie {
    /// Root node: stands for the empty prefix and never lists any pattern.
    root: TrieNode,
}

/// One trie node; the path from the root to it spells a pattern prefix.
#[derive(Debug, Default)]
struct TrieNode {
    /// Child nodes keyed by the next character of the prefix.
    children: HashMap<char, TrieNode>,
    /// Patterns that start with the prefix spelled by this node.
    matching_patterns: Vec<String>,
    /// Whether the prefix spelled by this node is itself a whole pattern.
    is_complete: bool,
}

impl PrefixTrie {
    /// Build a trie containing every string in `patterns`.
    fn new(patterns: &[String]) -> Self {
        let mut root = TrieNode::default();

        for pattern in patterns {
            let mut node = &mut root;

            for ch in pattern.chars() {
                node = node.children.entry(ch).or_default();

                // Register the pattern on every node along its path. The
                // `contains` guard keeps a pattern that was passed twice from
                // being listed twice.
                if !node.matching_patterns.contains(pattern) {
                    node.matching_patterns.push(pattern.clone());
                }
            }

            // `node` is now the node of the pattern's last character. An empty
            // pattern never leaves the root, which must not be marked complete.
            if !pattern.is_empty() {
                node.is_complete = true;
            }
        }

        Self { root }
    }

    /// Return the patterns that `text` is an incomplete prefix of.
    ///
    /// Yields `None` when `text` is empty, when it diverges from every
    /// pattern, or when it is exactly equal to a pattern (even if a longer
    /// pattern also starts with it).
    fn find_prefix_match(&self, text: &str) -> Option<Vec<String>> {
        let mut node = &self.root;

        for ch in text.chars() {
            // A missing edge means no pattern starts with `text`.
            node = node.children.get(&ch)?;
        }

        if node.is_complete || node.matching_patterns.is_empty() {
            None
        } else {
            Some(node.matching_patterns.clone())
        }
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// A marker at the very start of the chunk yields an empty prefix.
    #[test]
    fn test_complete_match() {
        let patterns = vec!["<TOOLCALL>".to_string(), "<tool_call>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        let result = matcher.process_chunk("<TOOLCALL>data", "");

        if let MatchResult::Complete {
            prefix,
            marker,
            suffix,
            ..
        } = result
        {
            assert_eq!(prefix, "");
            assert_eq!(marker, "<TOOLCALL>");
            assert_eq!(suffix, "data");
        } else {
            panic!("Expected complete match");
        }
    }

    /// The beginning of a marker at the end of the chunk is held back.
    #[test]
    fn test_partial_match_suffix() {
        let patterns = vec!["<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        // Test the key case: "n<T" should detect "<T" as partial
        let result = matcher.process_chunk("n<T", "");

        if let MatchResult::Partial {
            prefix,
            partial,
            possible_patterns,
        } = result
        {
            assert_eq!(prefix, "n");
            assert_eq!(partial, "<T");
            assert_eq!(possible_patterns, vec!["<TOOLCALL>"]);
        } else {
            panic!("Expected partial match, got: {:?}", result);
        }
    }

    /// A marker's first character in the middle of the text is not a partial.
    #[test]
    fn test_no_false_positive() {
        let patterns = vec!["<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        // Test case: "n < 5" should not trigger partial match
        let result = matcher.process_chunk("n < 5", "");

        if let MatchResult::None { content } = result {
            assert_eq!(content, "n < 5");
        } else {
            panic!("Expected no match, got: {:?}", result);
        }
    }

    /// A held partial plus the next chunk completes the marker.
    #[test]
    fn test_partial_buffer_combination() {
        let patterns = vec!["<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        // First chunk: partial "<"
        let result1 = matcher.process_chunk("<", "");
        let partial = if let MatchResult::Partial { partial, .. } = result1 {
            partial
        } else {
            panic!("Expected partial match");
        };

        // Second chunk: "TOOLCALL>" completes the pattern
        let result2 = matcher.process_chunk("TOOLCALL>", &partial);

        if let MatchResult::Complete { marker, .. } = result2 {
            assert_eq!(marker, "<TOOLCALL>");
        } else {
            panic!("Expected complete match, got: {:?}", result2);
        }
    }

    /// Text on both sides of a marker is split into prefix and suffix.
    #[test]
    fn test_prefix_with_content() {
        let patterns = vec!["<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        let result = matcher.process_chunk("text before <TOOLCALL> after", "");

        if let MatchResult::Complete {
            prefix,
            marker,
            suffix,
            ..
        } = result
        {
            assert_eq!(prefix, "text before ");
            assert_eq!(marker, "<TOOLCALL>");
            assert_eq!(suffix, " after");
        } else {
            panic!("Expected complete match");
        }
    }

    /// Constructing a matcher without any pattern is an error.
    #[test]
    fn test_empty_patterns() {
        let result = MarkerMatcher::new(vec![]);
        assert!(result.is_err());
    }

    /// Each registered pattern is matched, completely or partially.
    #[test]
    fn test_multiple_patterns() {
        let patterns = vec![
            "<TOOLCALL>".to_string(),
            "[TOOL_CALLS]".to_string(),
            "<tool_call>".to_string(),
        ];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        // Test different patterns
        let result1 = matcher.process_chunk("[TOOL_CALLS]", "");
        if let MatchResult::Complete { marker, .. } = result1 {
            assert_eq!(marker, "[TOOL_CALLS]");
        } else {
            panic!("Expected complete match for [TOOL_CALLS]");
        }

        // Test partial for different pattern
        let result2 = matcher.process_chunk("text<to", "");
        if let MatchResult::Partial {
            partial,
            possible_patterns,
            ..
        } = result2
        {
            assert_eq!(partial, "<to");
            assert!(possible_patterns.contains(&"<tool_call>".to_string()));
        } else {
            panic!("Expected partial match for <tool_call>");
        }
    }

    /// Text that diverges from a pattern mid-way is not held back.
    #[test]
    fn test_multiple_partial_matches_edge_case() {
        // Patterns: ["FooBar", "<TOOLCALL>"]
        // Input: "This is FooBaz which is a no, but <TOO"
        // "FooBaz ..." is not a prefix of "FooBar" (the 'z' does not match the
        // expected 'r'), so only the trailing "<TOO" is a valid partial.
        // Expected: hold "<TOO", emit "This is FooBaz which is a no, but ".
        let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        let result = matcher.process_chunk("This is FooBaz which is a no, but <TOO", "");

        if let MatchResult::Partial {
            prefix,
            partial,
            possible_patterns,
        } = result
        {
            // The algorithm correctly skips "FooBaz" (not a valid prefix) and finds "<TOO"
            assert_eq!(partial, "<TOO");
            assert_eq!(prefix, "This is FooBaz which is a no, but ");
            assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
        } else {
            panic!("Expected partial match for '<TOO>', got: {:?}", result);
        }
    }

    /// A pattern prefix followed by other text is skipped for a later partial.
    #[test]
    fn test_earliest_valid_partial_match() {
        // Patterns: ["FooBar", "<TOOLCALL>"]
        // Input: "Some text FooBa and then <TO"
        // "FooBa and then <TO" is not a prefix of "FooBar": after "FooBa" comes
        // a space where "FooBar" expects 'r'.
        // Expected: skip the invalid "FooBa..." and hold the valid "<TO".
        let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        let result = matcher.process_chunk("Some text FooBa and then <TO", "");

        if let MatchResult::Partial {
            prefix,
            partial,
            possible_patterns,
        } = result
        {
            // Should find "<TO" as the valid partial match
            assert_eq!(partial, "<TO");
            assert_eq!(prefix, "Some text FooBa and then ");
            assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
        } else {
            panic!("Expected partial match for '<TO>', got: {:?}", result);
        }
    }

    /// A pattern prefix that ends the text exactly is held back.
    #[test]
    fn test_partial_at_exact_end() {
        // Patterns: ["FooBar", "<TOOLCALL>"]
        // Input: "Some text ending with FooBa"
        // Expected: hold "FooBa" as partial (valid prefix of "FooBar")
        let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        let result = matcher.process_chunk("Some text ending with FooBa", "");

        if let MatchResult::Partial {
            prefix,
            partial,
            possible_patterns,
        } = result
        {
            // Should find "FooBa" as a valid partial match at the end
            assert_eq!(partial, "FooBa");
            assert_eq!(prefix, "Some text ending with ");
            assert!(possible_patterns.contains(&"FooBar".to_string()));
        } else {
            panic!("Expected partial match for 'FooBa', got: {:?}", result);
        }
    }

    /// Multi-byte text around an ASCII marker is sliced on character boundaries.
    #[test]
    fn test_unicode_complete_match() {
        let patterns = vec!["<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        // Test with emoji and multi-byte characters
        let result = matcher.process_chunk("Hello 👋 world <TOOLCALL>data 🚀", "");

        if let MatchResult::Complete {
            prefix,
            marker,
            suffix,
            ..
        } = result
        {
            assert_eq!(prefix, "Hello 👋 world ");
            assert_eq!(marker, "<TOOLCALL>");
            assert_eq!(suffix, "data 🚀");
        } else {
            panic!("Expected complete match, got: {:?}", result);
        }
    }

    /// A partial marker after multi-byte text is still detected.
    #[test]
    fn test_unicode_partial_match() {
        let patterns = vec!["<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        // Test partial after multi-byte characters
        let result = matcher.process_chunk("Text with 中文字符 and <TO", "");

        if let MatchResult::Partial {
            prefix,
            partial,
            possible_patterns,
        } = result
        {
            assert_eq!(prefix, "Text with 中文字符 and ");
            assert_eq!(partial, "<TO");
            assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
        } else {
            panic!("Expected partial match, got: {:?}", result);
        }
    }

    /// Full-width look-alikes of the marker's brackets must not match.
    #[test]
    fn test_unicode_no_false_positive() {
        let patterns = vec!["<TOOLCALL>".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        // U+FF1C / U+FF1E are the full-width forms of '<' and '>'. They are
        // written as escapes so that editors or text normalisation cannot
        // silently fold them back into the ASCII marker, which would (rightly)
        // produce a complete match and defeat the purpose of this test.
        let input = "Unicode test \u{FF1C}TOOLCALL\u{FF1E} full-width";
        let result = matcher.process_chunk(input, "");

        if let MatchResult::None { content } = result {
            assert_eq!(content, input);
        } else {
            panic!(
                "Expected no match for full-width characters, got: {:?}",
                result
            );
        }
    }

    /// Patterns that themselves contain multi-byte characters work too.
    #[test]
    fn test_unicode_pattern_itself() {
        let patterns = vec!["🔧工具".to_string(), "📞call".to_string()];
        let matcher = MarkerMatcher::new(patterns).unwrap();

        // Test complete match with unicode pattern
        let result1 = matcher.process_chunk("Start 🔧工具 end", "");
        if let MatchResult::Complete {
            prefix,
            marker,
            suffix,
            ..
        } = result1
        {
            assert_eq!(prefix, "Start ");
            assert_eq!(marker, "🔧工具");
            assert_eq!(suffix, " end");
        } else {
            panic!(
                "Expected complete match for unicode pattern, got: {:?}",
                result1
            );
        }

        // Test partial match with unicode pattern
        let result2 = matcher.process_chunk("Text 🔧工", "");
        if let MatchResult::Partial {
            prefix,
            partial,
            possible_patterns,
        } = result2
        {
            assert_eq!(prefix, "Text ");
            assert_eq!(partial, "🔧工");
            assert!(possible_patterns.contains(&"🔧工具".to_string()));
        } else {
            panic!(
                "Expected partial match for unicode pattern, got: {:?}",
                result2
            );
        }
    }
}