use crate::trigram::{extract_trigrams, Trigram};
/// Collects the trigrams of every literal sequence found in a regex pattern.
///
/// The pattern is first reduced to its literal sequences with
/// [`extract_literal_sequences`]; each literal is then passed to
/// `extract_trigrams` and the results are merged.
///
/// # Returns
///
/// The merged trigrams, sorted and with duplicates removed. The vector is
/// empty when the pattern yields no literal sequence; per the debug message
/// emitted in that case, callers are expected to fall back to a full scan.
pub fn extract_trigrams_from_regex(pattern: &str) -> Vec<Trigram> {
    let literals = extract_literal_sequences(pattern);

    // Nothing to index on: report it and hand back an empty set.
    if literals.is_empty() {
        log::debug!("No literals found in regex pattern '{}', will fall back to full scan", pattern);
        return Vec::new();
    }
    log::debug!("Extracted {} literal sequences from regex: {:?}", literals.len(), literals);

    let mut unique: Vec<Trigram> = Vec::new();
    literals
        .iter()
        .for_each(|literal| unique.extend(extract_trigrams(literal)));

    // Sorting first makes equal trigrams adjacent so `dedup` removes all repeats.
    unique.sort_unstable();
    unique.dedup();

    log::debug!("Extracted {} unique trigrams from regex pattern", unique.len());
    unique
}
pub fn extract_literal_sequences(pattern: &str) -> Vec<String> {
let mut sequences = Vec::new();
let mut current = String::new();
let mut chars = pattern.chars().peekable();
let mut has_case_insensitive_flag = false;
while let Some(ch) = chars.next() {
match ch {
'.' | '*' | '+' | '?' | '|' | '[' | ']' | '^' | '$' => {
if current.len() >= 3 {
sequences.push(current.clone());
}
current.clear();
}
'(' => {
if current.len() >= 3 {
sequences.push(current.clone());
}
current.clear();
if chars.peek() == Some(&'?') {
chars.next();
if let Some(&flag_ch) = chars.peek() {
match flag_ch {
':' => {
chars.next(); }
'i' | 'm' | 's' | 'x' | '-' => {
if flag_ch == 'i' {
has_case_insensitive_flag = true;
}
while let Some(&next_ch) = chars.peek() {
if next_ch == 'i' {
has_case_insensitive_flag = true;
}
chars.next();
if next_ch == ')' {
break;
}
}
}
_ => {
while let Some(&next_ch) = chars.peek() {
chars.next();
if next_ch == ')' {
break;
}
}
}
}
}
}
}
')' => {
if current.len() >= 3 {
sequences.push(current.clone());
}
current.clear();
}
'{' => {
if current.len() >= 3 {
sequences.push(current.clone());
}
current.clear();
while let Some(&next_ch) = chars.peek() {
chars.next();
if next_ch == '}' {
break;
}
}
}
'}' => {
if current.len() >= 3 {
sequences.push(current.clone());
}
current.clear();
}
'\\' => {
if let Some(&next_ch) = chars.peek() {
match next_ch {
's' | 'd' | 'w' | 'S' | 'D' | 'W' | 'n' | 't' | 'r' | 'b' | 'B' => {
chars.next(); if current.len() >= 3 {
sequences.push(current.clone());
}
current.clear();
}
_ => {
chars.next(); current.push(next_ch);
}
}
} else {
if current.len() >= 3 {
sequences.push(current.clone());
}
current.clear();
}
}
_ => {
current.push(ch);
}
}
}
if current.len() >= 3 {
sequences.push(current);
}
if has_case_insensitive_flag {
log::debug!("Case-insensitive flag detected in pattern, cannot use trigram optimization");
return vec![];
}
sequences
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A pattern without metacharacters comes back as one literal.
    #[test]
    fn test_extract_literal_sequences_simple() {
        assert_eq!(extract_literal_sequences("hello"), vec!["hello"]);
    }

    /// `.*` splits the pattern; the two-byte `fn` is too short to be kept.
    #[test]
    fn test_extract_literal_sequences_with_wildcard() {
        assert_eq!(extract_literal_sequences("fn.*test"), vec!["test"]);
    }

    /// Both sides of a wildcard are reported when each is long enough.
    #[test]
    fn test_extract_literal_sequences_multiple() {
        assert_eq!(
            extract_literal_sequences("class.*Controller"),
            vec!["class", "Controller"]
        );
    }

    /// A pattern made only of metacharacters has no literals.
    #[test]
    fn test_extract_literal_sequences_no_literals() {
        assert_eq!(extract_literal_sequences(".*"), Vec::<String>::new());
    }

    /// Literals shorter than three bytes are dropped.
    #[test]
    fn test_extract_literal_sequences_short_literals() {
        assert_eq!(extract_literal_sequences("fn.*test"), vec!["test"]);
    }

    /// An escaped metacharacter is part of the literal.
    #[test]
    fn test_extract_literal_sequences_escaped() {
        assert_eq!(extract_literal_sequences("test\\.txt"), vec!["test.txt"]);
    }

    /// `\s` ends the current literal instead of contributing an `s`.
    #[test]
    fn test_extract_literal_sequences_whitespace_escape() {
        assert_eq!(extract_literal_sequences("fn\\s+extract"), vec!["extract"]);
    }

    /// `\b` on either side does not leak a `b` into the literal.
    #[test]
    fn test_extract_literal_sequences_word_boundary() {
        assert_eq!(
            extract_literal_sequences("\\bListUsersController\\b"),
            vec!["ListUsersController"]
        );
    }

    /// The literal `extract` is expected to produce five trigrams.
    #[test]
    fn test_extract_trigrams_simple_literal() {
        assert_eq!(extract_trigrams_from_regex("extract").len(), 5);
    }

    /// Only `test` survives, and it is expected to produce two trigrams.
    #[test]
    fn test_extract_trigrams_with_wildcard() {
        assert_eq!(extract_trigrams_from_regex("fn.*test").len(), 2);
    }

    /// Trigrams of both literals are merged into one result.
    #[test]
    fn test_extract_trigrams_multiple_literals() {
        let count = extract_trigrams_from_regex("class.*Controller").len();
        assert!(count >= 10);
    }

    /// No literals means no trigrams.
    #[test]
    fn test_extract_trigrams_no_literals() {
        assert_eq!(extract_trigrams_from_regex(".*").len(), 0);
    }

    /// Alternation branches inside a group both contribute trigrams.
    #[test]
    fn test_extract_trigrams_complex_pattern() {
        let count = extract_trigrams_from_regex("(function|const)").len();
        assert!(count >= 6);
    }

    /// Each branch of a two-way alternation is its own literal.
    #[test]
    fn test_extract_literal_sequences_alternation() {
        assert_eq!(
            extract_literal_sequences("(SymbolWriter|ContentWriter)"),
            vec!["SymbolWriter", "ContentWriter"]
        );
    }

    /// Each branch of a three-way alternation is its own literal.
    #[test]
    fn test_extract_literal_sequences_three_way_alternation() {
        assert_eq!(
            extract_literal_sequences("(Indexer|QueryEngine|CacheManager)"),
            vec!["Indexer", "QueryEngine", "CacheManager"]
        );
    }

    /// `(?i)` disables literal extraction altogether.
    #[test]
    fn test_extract_literal_sequences_case_insensitive_flag() {
        assert!(extract_literal_sequences("(?i)queryengine").is_empty());
    }

    /// `(?m)` is consumed without affecting the literal that follows.
    #[test]
    fn test_extract_literal_sequences_multiline_flag() {
        assert_eq!(extract_literal_sequences("(?m)^test"), vec!["test"]);
    }

    /// `(?:` is skipped and the group content is scanned normally.
    #[test]
    fn test_extract_literal_sequences_non_capturing_group() {
        assert_eq!(
            extract_literal_sequences("(?:test|func)"),
            vec!["test", "func"]
        );
    }

    /// The body of `{2,3}` must not surface as the literal `2,3`.
    #[test]
    fn test_extract_literal_sequences_quantifier_no_false_literal() {
        let literals = extract_literal_sequences("a{2,3}test");
        assert_eq!(literals, vec!["test"]);
        assert!(!literals.contains(&"2,3".to_string()));
    }

    /// A range quantifier splits the surrounding text into two literals.
    #[test]
    fn test_extract_literal_sequences_quantifier_range() {
        let literals = extract_literal_sequences("test{1,5}word");
        assert_eq!(literals, vec!["test", "word"]);
        assert!(!literals.contains(&"1,5".to_string()));
    }

    /// An exact-count quantifier splits the surrounding text into two literals.
    #[test]
    fn test_extract_literal_sequences_quantifier_exact() {
        let literals = extract_literal_sequences("test{3}word");
        assert_eq!(literals, vec!["test", "word"]);
        assert!(!literals.contains(&"3".to_string()));
    }

    /// `i` is detected when it is combined with other flags.
    #[test]
    fn test_extract_literal_sequences_combined_flags() {
        assert!(extract_literal_sequences("(?im)test").is_empty());
    }

    /// `(?i)` suppresses every literal in the pattern, not just the first.
    #[test]
    fn test_extract_literal_sequences_flag_before_literal() {
        assert!(extract_literal_sequences("(?i)test.*function").is_empty());
    }
}