1use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
10use std::collections::HashMap;
11
/// Outcome of scanning a chunk of text (plus any carried-over buffer) for
/// marker patterns.
#[derive(Debug, Clone, PartialEq)]
pub enum MatchResult {
    /// A whole marker pattern was found in the scanned text.
    Complete {
        /// Text that precedes the marker.
        prefix: String,
        /// The pattern that matched.
        marker: String,
        /// Byte offset at which the marker starts within the scanned text.
        marker_start: usize,
        /// Text that follows the marker.
        suffix: String,
    },
    /// No whole marker was found, but the scanned text ends with the beginning
    /// of at least one pattern, so the marker may be completed by later input.
    Partial {
        /// Text that precedes the partial marker.
        prefix: String,
        /// The trailing text that is the start of one or more patterns.
        partial: String,
        /// Every pattern that begins with `partial`.
        possible_patterns: Vec<String>,
    },
    /// Neither a whole marker nor the start of one was found.
    None {
        /// The scanned text, unchanged.
        content: String,
    },
}
41
/// Detects marker patterns in text, including markers that are only partially
/// present at the end of the text.
pub struct MarkerMatcher {
    /// The marker patterns; indexed by the pattern id reported by
    /// `complete_matcher`.
    patterns: Vec<String>,
    /// Aho-Corasick automaton used to find whole markers.
    complete_matcher: AhoCorasick,
    /// Trie over the patterns used to recognise a trailing partial marker.
    prefix_trie: PrefixTrie,
    /// Length in bytes of the longest pattern.
    max_pattern_len: usize,
}
53
54impl MarkerMatcher {
55 pub fn new(patterns: Vec<String>) -> Result<Self, String> {
57 if patterns.is_empty() {
58 return Err("Cannot create MarkerMatcher with empty patterns".to_string());
59 }
60
61 let complete_matcher = AhoCorasickBuilder::new()
62 .match_kind(MatchKind::LeftmostFirst)
63 .build(&patterns)
64 .map_err(|e| format!("Failed to build Aho-Corasick matcher: {}", e))?;
65
66 let max_pattern_len = patterns.iter().map(|p| p.len()).max().unwrap_or(0);
67 let prefix_trie = PrefixTrie::new(&patterns);
68
69 Ok(Self {
70 patterns,
71 complete_matcher,
72 prefix_trie,
73 max_pattern_len,
74 })
75 }
76
77 pub fn max_pattern_len(&self) -> usize {
79 self.max_pattern_len
80 }
81
82 fn safe_slice(text: &str, start_byte: usize, end_byte: usize) -> String {
84 let start = text
86 .char_indices()
87 .find(|(i, _)| *i >= start_byte)
88 .map(|(i, _)| i)
89 .unwrap_or(text.len());
90
91 let end = text
92 .char_indices()
93 .find(|(i, _)| *i >= end_byte)
94 .map(|(i, _)| i)
95 .unwrap_or(text.len());
96
97 text[start..end].to_string()
98 }
99
100 pub fn process_chunk(&self, chunk: &str, partial_buffer: &str) -> MatchResult {
102 let combined = if partial_buffer.is_empty() {
104 chunk.to_string()
105 } else {
106 format!("{}{}", partial_buffer, chunk)
107 };
108
109 if let Some(mat) = self.complete_matcher.find(&combined) {
111 let marker = &self.patterns[mat.pattern().as_usize()];
112 return MatchResult::Complete {
113 prefix: Self::safe_slice(&combined, 0, mat.start()),
114 marker: marker.clone(),
115 marker_start: mat.start(),
116 suffix: Self::safe_slice(&combined, mat.end(), combined.len()),
117 };
118 }
119
120 if let Some((partial_start, partial, patterns)) = self.find_partial_suffix(&combined) {
123 return MatchResult::Partial {
124 prefix: Self::safe_slice(&combined, 0, partial_start),
125 partial: partial.to_string(),
126 possible_patterns: patterns,
127 };
128 }
129
130 MatchResult::None { content: combined }
132 }
133
134 fn find_partial_suffix<'a>(&self, text: &'a str) -> Option<(usize, &'a str, Vec<String>)> {
139 for (i, _) in text.char_indices() {
143 let suffix = &text[i..];
144 if let Some(patterns) = self.prefix_trie.find_prefix_match(suffix) {
145 return Some((i, suffix, patterns));
147 }
148 }
149 None
150 }
151}
152
/// Character trie over a set of patterns, used to decide whether a piece of
/// text is the beginning of some pattern.
struct PrefixTrie {
    root: TrieNode,
}

/// A single trie position, i.e. one distinct prefix shared by some patterns.
#[derive(Debug)]
struct TrieNode {
    /// Next positions, keyed by the following character.
    children: HashMap<char, TrieNode>,
    /// Patterns that start with the prefix this node stands for (no duplicates).
    matching_patterns: Vec<String>,
    /// Set when at least one pattern ends exactly at this node.
    is_complete: bool,
}

impl TrieNode {
    /// Creates a node with no children, no patterns, and `is_complete` unset.
    fn empty() -> Self {
        Self {
            children: HashMap::new(),
            matching_patterns: Vec::new(),
            is_complete: false,
        }
    }
}

impl PrefixTrie {
    /// Builds the trie by inserting every pattern character by character.
    fn new(patterns: &[String]) -> Self {
        let mut root = TrieNode::empty();

        for pattern in patterns {
            let mut node = &mut root;
            let mut remaining = pattern.chars().peekable();

            while let Some(ch) = remaining.next() {
                node = node.children.entry(ch).or_insert_with(TrieNode::empty);

                // Register the pattern on every node along its path, once.
                if !node.matching_patterns.iter().any(|known| known == pattern) {
                    node.matching_patterns.push(pattern.to_owned());
                }

                // Nothing left to consume: this node is where the pattern ends.
                if remaining.peek().is_none() {
                    node.is_complete = true;
                }
            }
        }

        Self { root }
    }

    /// Returns the patterns that begin with `text`, provided `text` follows a
    /// path in the trie and no pattern ends exactly at `text`.
    ///
    /// Returns `None` for empty text, for text that leaves the trie, and for
    /// text that equals a whole pattern.
    fn find_prefix_match(&self, text: &str) -> Option<Vec<String>> {
        // Walk down one character at a time; a missing child aborts the walk.
        let node = text
            .chars()
            .try_fold(&self.root, |node, ch| node.children.get(&ch))?;

        let is_proper_prefix = !node.is_complete && !node.matching_patterns.is_empty();
        is_proper_prefix.then(|| node.matching_patterns.clone())
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a matcher from string literals, panicking if construction fails.
    fn matcher_for(patterns: &[&str]) -> MarkerMatcher {
        let owned: Vec<String> = patterns.iter().map(|&p| p.to_owned()).collect();
        MarkerMatcher::new(owned).expect("test patterns must build a matcher")
    }

    /// Expected `Complete` result; the marker starts right after `prefix`.
    fn complete_result(prefix: &str, marker: &str, suffix: &str) -> MatchResult {
        MatchResult::Complete {
            prefix: prefix.to_owned(),
            marker: marker.to_owned(),
            marker_start: prefix.len(),
            suffix: suffix.to_owned(),
        }
    }

    /// Expected `Partial` result.
    fn partial_result(prefix: &str, partial: &str, possible: &[&str]) -> MatchResult {
        MatchResult::Partial {
            prefix: prefix.to_owned(),
            partial: partial.to_owned(),
            possible_patterns: possible.iter().map(|&p| p.to_owned()).collect(),
        }
    }

    /// Expected `None` result.
    fn none_result(content: &str) -> MatchResult {
        MatchResult::None {
            content: content.to_owned(),
        }
    }

    #[test]
    fn test_complete_match() {
        let matcher = matcher_for(&["<TOOLCALL>", "<tool_call>"]);

        assert_eq!(
            matcher.process_chunk("<TOOLCALL>data", ""),
            complete_result("", "<TOOLCALL>", "data")
        );
    }

    #[test]
    fn test_partial_match_suffix() {
        let matcher = matcher_for(&["<TOOLCALL>"]);

        // The trailing "<T" could still grow into the marker.
        assert_eq!(
            matcher.process_chunk("n<T", ""),
            partial_result("n", "<T", &["<TOOLCALL>"])
        );
    }

    #[test]
    fn test_no_false_positive() {
        let matcher = matcher_for(&["<TOOLCALL>"]);

        // A "<" in the middle of the text is not a trailing partial marker.
        assert_eq!(matcher.process_chunk("n < 5", ""), none_result("n < 5"));
    }

    #[test]
    fn test_partial_buffer_combination() {
        let matcher = matcher_for(&["<TOOLCALL>"]);

        // First chunk only delivers the start of the marker.
        let carried = match matcher.process_chunk("<", "") {
            MatchResult::Partial { partial, .. } => partial,
            other => panic!("Expected partial match, got: {:?}", other),
        };

        // Feeding the held-back text with the next chunk completes the marker.
        assert_eq!(
            matcher.process_chunk("TOOLCALL>", &carried),
            complete_result("", "<TOOLCALL>", "")
        );
    }

    #[test]
    fn test_prefix_with_content() {
        let matcher = matcher_for(&["<TOOLCALL>"]);

        assert_eq!(
            matcher.process_chunk("text before <TOOLCALL> after", ""),
            complete_result("text before ", "<TOOLCALL>", " after")
        );
    }

    #[test]
    fn test_empty_patterns() {
        let result = MarkerMatcher::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_multiple_patterns() {
        let matcher = matcher_for(&["<TOOLCALL>", "[TOOL_CALLS]", "<tool_call>"]);

        assert_eq!(
            matcher.process_chunk("[TOOL_CALLS]", ""),
            complete_result("", "[TOOL_CALLS]", "")
        );

        // "<to" is case-sensitive: only the lowercase marker remains possible.
        assert_eq!(
            matcher.process_chunk("text<to", ""),
            partial_result("text", "<to", &["<tool_call>"])
        );
    }

    #[test]
    fn test_multiple_partial_matches_edge_case() {
        let matcher = matcher_for(&["FooBar", "<TOOLCALL>"]);

        // "FooBaz" starts like "FooBar" but diverges mid-text; only the
        // trailing "<TOO" is a live partial marker.
        assert_eq!(
            matcher.process_chunk("This is FooBaz which is a no, but <TOO", ""),
            partial_result("This is FooBaz which is a no, but ", "<TOO", &["<TOOLCALL>"])
        );
    }

    #[test]
    fn test_earliest_valid_partial_match() {
        let matcher = matcher_for(&["FooBar", "<TOOLCALL>"]);

        // "FooBa" is followed by more text, so it cannot be a partial marker;
        // the partial must be the one that reaches the end of the text.
        assert_eq!(
            matcher.process_chunk("Some text FooBa and then <TO", ""),
            partial_result("Some text FooBa and then ", "<TO", &["<TOOLCALL>"])
        );
    }

    #[test]
    fn test_partial_at_exact_end() {
        let matcher = matcher_for(&["FooBar", "<TOOLCALL>"]);

        assert_eq!(
            matcher.process_chunk("Some text ending with FooBa", ""),
            partial_result("Some text ending with ", "FooBa", &["FooBar"])
        );
    }

    #[test]
    fn test_unicode_complete_match() {
        let matcher = matcher_for(&["<TOOLCALL>"]);

        // Multi-byte characters on both sides of the marker; `marker_start`
        // is a byte offset, which `complete_result` derives from the prefix.
        assert_eq!(
            matcher.process_chunk("Hello 👋 world <TOOLCALL>data 🚀", ""),
            complete_result("Hello 👋 world ", "<TOOLCALL>", "data 🚀")
        );
    }

    #[test]
    fn test_unicode_partial_match() {
        let matcher = matcher_for(&["<TOOLCALL>"]);

        assert_eq!(
            matcher.process_chunk("Text with 中文字符 and <TO", ""),
            partial_result("Text with 中文字符 and ", "<TO", &["<TOOLCALL>"])
        );
    }

    #[test]
    fn test_unicode_no_false_positive() {
        let matcher = matcher_for(&["<TOOLCALL>"]);

        // Full-width angle brackets (U+FF1C / U+FF1E) look like the ASCII ones
        // but must not match. Written as escapes so that text normalisation
        // cannot silently turn them back into ASCII.
        let text = "Unicode test \u{FF1C}TOOLCALL\u{FF1E} full-width";

        assert_eq!(matcher.process_chunk(text, ""), none_result(text));
    }

    #[test]
    fn test_unicode_pattern_itself() {
        let matcher = matcher_for(&["🔧工具", "📞call"]);

        assert_eq!(
            matcher.process_chunk("Start 🔧工具 end", ""),
            complete_result("Start ", "🔧工具", " end")
        );

        // The partial ends on a char boundary inside the multi-byte pattern.
        assert_eq!(
            matcher.process_chunk("Text 🔧工", ""),
            partial_result("Text ", "🔧工", &["🔧工具"])
        );
    }
}