probe_code/search/query.rs

use probe_code::search::elastic_query;
// No term_exceptions import needed
use std::collections::{HashMap, HashSet};
use std::time::Instant;

/// Escapes special regex characters in a string
pub fn regex_escape(s: &str) -> String {
    let special_chars = [
        '.', '^', '$', '*', '+', '?', '(', ')', '[', ']', '{', '}', '|', '\\',
    ];
    let mut result = String::with_capacity(s.len() * 2);

    for c in s.chars() {
        if special_chars.contains(&c) {
            result.push('\\');
        }
        result.push(c);
    }

    result
}
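
// A small illustrative test sketch of regex_escape, derived directly from the
// implementation above: metacharacters are backslash-escaped, everything else
// passes through unchanged.
#[cfg(test)]
mod regex_escape_example {
    use super::regex_escape;

    #[test]
    fn escapes_regex_metacharacters() {
        // '.' and '*' are escaped so the term matches literally inside a pattern.
        assert_eq!(regex_escape("foo.bar*"), r"foo\.bar\*");
        // Plain identifiers pass through unchanged.
        assert_eq!(regex_escape("plain_text"), "plain_text");
    }
}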

// ----------------------------------------------------------------------------
// Full AST-based planning and pattern generation
// ----------------------------------------------------------------------------

/// A unified plan holding the parsed AST and a mapping of each AST term to an index.
/// We store a map for quick lookups of term indices.
#[derive(Debug)]
pub struct QueryPlan {
    pub ast: elastic_query::Expr,
    pub term_indices: HashMap<String, usize>,
    pub excluded_terms: HashSet<String>,
    pub exact: bool,
}

/// Helper function to format duration in a human-readable way
fn format_duration(duration: std::time::Duration) -> String {
    if duration.as_millis() < 1000 {
        format!("{millis}ms", millis = duration.as_millis())
    } else {
        format!("{:.2}s", duration.as_secs_f64())
    }
}

/// Create a QueryPlan from a raw query string. This fully parses the query into an AST,
/// then extracts all terms (including excluded), and prepares a term-index map.
pub fn create_query_plan(query: &str, exact: bool) -> Result<QueryPlan, elastic_query::ParseError> {
    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
    let start_time = Instant::now();

    if debug_mode {
        println!("DEBUG: Starting query plan creation for query: '{query}'");
    }

    // Use the regular AST parsing
    let parsing_start = Instant::now();

    if debug_mode {
        println!("DEBUG: Starting AST parsing for query: '{query}', exact={exact}");
    }

    // Parse the query into an AST with processed terms
    // We use standard Elasticsearch behavior (AND for implicit combinations)
    let mut ast = elastic_query::parse_query(query, exact)?;

    // If exact search is enabled, update the AST to mark all terms as exact
    if exact {
        update_ast_exact(&mut ast);
    }

    let parsing_duration = parsing_start.elapsed();

    if debug_mode {
        println!(
            "DEBUG: AST parsing completed in {}",
            format_duration(parsing_duration)
        );
        println!("DEBUG: Parsed AST: {ast}");
    }

    // We'll walk the AST to build a set of all terms. We track excluded as well for reference.
    let term_collection_start = Instant::now();

    if debug_mode {
        println!("DEBUG: Starting term collection from AST");
    }

    let mut all_terms = Vec::new();
    let mut excluded_terms = HashSet::new();
    collect_all_terms(&ast, &mut all_terms, &mut excluded_terms);

    // Remove duplicates from all_terms
    all_terms.sort();
    all_terms.dedup();

    let term_collection_duration = term_collection_start.elapsed();

    if debug_mode {
        println!(
            "DEBUG: Term collection completed in {}",
            format_duration(term_collection_duration)
        );
        println!("DEBUG: Collected {} unique terms", all_terms.len());
        println!("DEBUG: Collected {} excluded terms", excluded_terms.len());
    }

    // Build term index map
    let index_building_start = Instant::now();

    let mut term_indices = HashMap::new();
    for (i, term) in all_terms.iter().enumerate() {
        term_indices.insert(term.clone(), i);
    }

    let index_building_duration = index_building_start.elapsed();

    if debug_mode {
        println!(
            "DEBUG: Term index building completed in {}",
            format_duration(index_building_duration)
        );
    }

    let total_duration = start_time.elapsed();
    if debug_mode {
        println!(
            "DEBUG: Query plan creation completed in {}",
            format_duration(total_duration)
        );
    }

    Ok(QueryPlan {
        ast,
        term_indices,
        excluded_terms,
        exact,
    })
}
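
// A rough usage sketch for create_query_plan. It assumes the elastic_query
// parser accepts Lucene-style syntax where a `-` prefix marks an excluded
// term, which is what the collection logic in this file expects; the query
// text itself is only illustrative.
#[cfg(test)]
mod query_plan_example {
    use super::create_query_plan;

    #[test]
    fn builds_plan_for_simple_query() {
        let plan = match create_query_plan("parser -legacy", false) {
            Ok(plan) => plan,
            Err(_) => panic!("query should parse"),
        };

        // Every collected term gets a stable index for later pattern matching.
        assert!(!plan.term_indices.is_empty());
        // Terms prefixed with `-` should end up in the excluded set
        // (assuming the parser flags them as excluded).
        assert!(!plan.excluded_terms.is_empty());
        assert!(!plan.exact);
    }
}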

/// Recursively update the AST to mark all terms as exact
fn update_ast_exact(expr: &mut elastic_query::Expr) {
    match expr {
        elastic_query::Expr::Term { exact, .. } => {
            // Set exact to true for all terms
            *exact = true;
        }
        elastic_query::Expr::And(left, right) => {
            update_ast_exact(left);
            update_ast_exact(right);
        }
        elastic_query::Expr::Or(left, right) => {
            update_ast_exact(left);
            update_ast_exact(right);
        }
    }
}

/// Helper function to check if the AST represents an exact search
fn is_exact_search(expr: &elastic_query::Expr) -> bool {
    match expr {
        elastic_query::Expr::Term { exact, .. } => *exact,
        elastic_query::Expr::And(left, right) => is_exact_search(left) && is_exact_search(right),
        elastic_query::Expr::Or(left, right) => is_exact_search(left) && is_exact_search(right),
    }
}

/// Recursively collect all terms from the AST, storing them in `all_terms`.
/// Also track excluded terms in `excluded`.
fn collect_all_terms(
    expr: &elastic_query::Expr,
    all_terms: &mut Vec<String>,
    excluded: &mut HashSet<String>,
) {
    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";

    if debug_mode {
        println!("DEBUG: Collecting terms from expression: {expr:?}");
    }

    match expr {
        elastic_query::Expr::Term {
            keywords,
            field: _,
            excluded: is_excluded,
            exact: _,
            ..
        } => {
            // Add all keywords to all_terms
            all_terms.extend(keywords.clone());

            if debug_mode {
                println!("DEBUG: Collected keywords '{keywords:?}', excluded={is_excluded}");
            }

            if *is_excluded {
                for keyword in keywords {
                    if debug_mode {
                        println!("DEBUG: Adding '{keyword}' to excluded terms set");
                    }

                    // Add the keyword to excluded terms
                    excluded.insert(keyword.clone());
                }
            }
        }
        elastic_query::Expr::And(left, right) => {
            if debug_mode {
                println!("DEBUG: Processing AND expression for term collection");
            }

            // Check if the right side is an excluded term
            if let elastic_query::Expr::Term {
                keywords,
                excluded: true,
                ..
            } = &**right
            {
                for keyword in keywords {
                    if debug_mode {
                        println!("DEBUG: Adding excluded term '{keyword}' from AND expression");
                    }
                    excluded.insert(keyword.clone());
                }
            }

            collect_all_terms(left, all_terms, excluded);
            collect_all_terms(right, all_terms, excluded);
        }
        elastic_query::Expr::Or(left, right) => {
            if debug_mode {
                println!("DEBUG: Processing OR expression for term collection");
            }
            collect_all_terms(left, all_terms, excluded);
            collect_all_terms(right, all_terms, excluded);
        }
    }

    if debug_mode {
        println!("DEBUG: Current all_terms: {all_terms:?}");
        println!("DEBUG: Current excluded terms: {excluded:?}");
    }
}

/// Build a combined regex pattern from a list of terms
/// This creates a single pattern that matches any of the terms using case-insensitive matching
/// without word boundaries for more flexible matching
pub fn build_combined_pattern(terms: &[String]) -> String {
    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
    let start_time = Instant::now();

    if debug_mode {
        println!("DEBUG: Building combined pattern for {} terms", terms.len());
    }

    // Escape special characters in each term
    let escaped_terms = terms.iter().map(|t| regex_escape(t)).collect::<Vec<_>>();

    // Join terms with | operator and add case-insensitive flag without word boundaries
    let pattern = format!("(?i)({terms})", terms = escaped_terms.join("|"));

    if debug_mode {
        let duration = start_time.elapsed();
        println!(
            "DEBUG: Combined pattern built in {}: {}",
            format_duration(duration),
            pattern
        );
    }

    pattern
}
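
// A small illustrative example of the combined pattern shape produced above:
// each term is regex-escaped, the terms are joined with `|`, and the whole
// alternation is wrapped in a case-insensitive group.
#[cfg(test)]
mod combined_pattern_example {
    use super::build_combined_pattern;

    #[test]
    fn joins_escaped_terms_with_alternation() {
        let terms = vec!["foo".to_string(), "bar.baz".to_string()];
        // Metacharacters are escaped and the alternation is case-insensitive.
        assert_eq!(build_combined_pattern(&terms), r"(?i)(foo|bar\.baz)");
    }
}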

/// Generate regex patterns that respect the AST's logical structure.
/// This builds one combined pattern covering every non-excluded term, plus
/// individual per-keyword, per-token, and compound-word patterns, and then
/// deduplicates the result. Excluded (negated) terms never produce patterns.
pub fn create_structured_patterns(plan: &QueryPlan) -> Vec<(String, HashSet<usize>)> {
    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
    let start_time = Instant::now();

    if debug_mode {
        println!("DEBUG: Starting structured pattern creation");
        println!("DEBUG: Using combined pattern mode");
    }

    let mut results = Vec::new();

    if debug_mode {
        println!("DEBUG: Creating structured patterns with AST awareness");
        println!("DEBUG: AST: {ast:?}", ast = plan.ast);
        println!(
            "DEBUG: Excluded terms: {excluded_terms:?}",
            excluded_terms = plan.excluded_terms
        );
    }

    // Extract all non-excluded terms from the query plan
    let terms: Vec<String> = plan
        .term_indices
        .keys()
        .filter(|term| !plan.excluded_terms.contains(*term))
        .cloned()
        .collect();

    if !terms.is_empty() {
        let combined_pattern = build_combined_pattern(&terms);

        // Create a HashSet with indices of non-excluded terms
        let all_indices: HashSet<usize> = terms
            .iter()
            .filter_map(|term| plan.term_indices.get(term).cloned())
            .collect();

        if debug_mode {
            println!("DEBUG: Created combined pattern for all terms: '{combined_pattern}'");
            println!("DEBUG: Combined pattern includes indices: {all_indices:?}");
        }

        results.push((combined_pattern, all_indices));

        // Continue to generate individual patterns instead of returning early
    }

    // Recursive helper that walks the AST and emits per-keyword and per-token patterns
    fn collect_patterns(
        expr: &elastic_query::Expr,
        plan: &QueryPlan,
        results: &mut Vec<(String, HashSet<usize>)>,
        debug_mode: bool,
    ) {
        match expr {
            elastic_query::Expr::Term {
                keywords,
                field: _,
                excluded,
                exact,
                ..
            } => {
                // Skip pattern generation for excluded terms
                if *excluded {
                    if debug_mode {
                        println!(
                            "DEBUG: Skipping pattern generation for excluded term: '{keywords:?}'"
                        );
                    }
                    return; // Skip pattern generation for excluded terms
                }

                // Process each keyword
                for keyword in keywords {
                    // Check against the global exclusion list first
                    if plan.excluded_terms.contains(keyword) {
                        if debug_mode {
                            println!(
                                "DEBUG: Skipping pattern generation for globally excluded keyword: '{keyword}'"
                            );
                        }
                        continue;
                    }
                    // Terms explicitly marked with `-` were already skipped by the
                    // `if *excluded` check above, so only the global exclusion list matters here.

                    // Find the keyword's index in term_indices
                    if let Some(&idx) = plan.term_indices.get(keyword) {
                        let base_pattern = regex_escape(keyword);

                        // For exact terms, use stricter matching
                        let pattern = if *exact {
                            base_pattern.to_string()
                        } else {
                            format!("({base_pattern})")
                        };

                        if debug_mode {
                            println!("DEBUG: Created pattern for keyword '{keyword}': '{pattern}'");
                        }

                        results.push((pattern, HashSet::from([idx])));

                        // Only tokenize if not exact
                        if !*exact {
                            // Generate patterns for each token of the term to match AST tokenization
                            let tokens = crate::search::tokenization::tokenize_and_stem(keyword);

                            if debug_mode && tokens.len() > 1 {
                                println!("DEBUG: Term '{keyword}' tokenized into: {tokens:?}");
                            }

                            // Generate a pattern for each token with the same term index
                            for token in tokens {
                                let token_pattern = regex_escape(&token);
                                let pattern = format!("({token_pattern})");

                                if debug_mode {
                                    println!(
                                        "DEBUG: Created pattern for token '{token}' from term '{keyword}': '{pattern}'"
                                    );
                                }

                                results.push((pattern, HashSet::from([idx])));
                            }
                        } else if debug_mode {
                            println!("DEBUG: Skipping tokenization for exact term '{keyword}'");
                        }
                    }
                }
            }
            elastic_query::Expr::And(left, right) => {
                // For AND, collect patterns from both sides independently
                if debug_mode {
                    println!("DEBUG: Processing AND expression");
                }
                collect_patterns(left, plan, results, debug_mode);
                collect_patterns(right, plan, results, debug_mode);
            }
            elastic_query::Expr::Or(left, right) => {
                if debug_mode {
                    println!("DEBUG: Processing OR expression");
                }

                // For OR, create combined patterns
                let mut left_patterns = Vec::new();
                let mut right_patterns = Vec::new();

                collect_patterns(left, plan, &mut left_patterns, debug_mode);
                collect_patterns(right, plan, &mut right_patterns, debug_mode);

                if !left_patterns.is_empty() && !right_patterns.is_empty() {
                    // Combine the patterns with OR
                    let combined = format!(
                        "({}|{})",
                        left_patterns
                            .iter()
                            .map(|(p, _)| p.as_str())
                            .collect::<Vec<_>>()
                            .join("|"),
                        right_patterns
                            .iter()
                            .map(|(p, _)| p.as_str())
                            .collect::<Vec<_>>()
                            .join("|")
                    );

                    // Merge the term indices
                    let mut indices = HashSet::new();
                    for (_, idx_set) in left_patterns.iter().chain(right_patterns.iter()) {
                        indices.extend(idx_set.iter().cloned());
                    }

                    if debug_mode {
                        println!("DEBUG: Created combined OR pattern: '{combined}'");
                        println!("DEBUG: Combined indices: {indices:?}");
                    }

                    results.push((combined, indices));
                }

                // Also add individual patterns to ensure we catch all matches
                // This is important for multi-keyword terms where we want to match any of the keywords
                if debug_mode {
                    println!("DEBUG: Adding individual patterns from OR expression");
                }
                results.extend(left_patterns);
                results.extend(right_patterns);
            }
        }
    }

    // Always run the recursive pattern collection over the AST
    if debug_mode {
        println!("DEBUG: Using recursive pattern generation via collect_patterns");
    }
    collect_patterns(&plan.ast, plan, &mut results, debug_mode);

    // Additional pass for compound words
    let compound_start = Instant::now();

    if debug_mode {
        println!("DEBUG: Starting compound word pattern generation");
    }

    let mut compound_patterns = Vec::new();

    // Process all terms from the term_indices map
    for (keyword, &idx) in &plan.term_indices {
        // Check if the original keyword itself is excluded before processing for compound parts
        if plan.excluded_terms.contains(keyword) {
            if debug_mode {
                println!("DEBUG: Skipping compound processing for excluded keyword: '{keyword}'");
            }
            continue; // Skip this keyword entirely
        }

        // Process compound words - either camelCase or those in the vocabulary
        // Skip compound word processing if exact search is enabled
        if keyword.len() > 3 && !is_exact_search(&plan.ast) {
            // Check if it's a camelCase word or a known compound word from vocabulary
            let camel_parts = crate::search::tokenization::split_camel_case(keyword);
            let compound_parts = if camel_parts.len() <= 1 {
                // Not a camelCase word, check if it's in vocabulary
                crate::search::tokenization::split_compound_word(
                    keyword,
                    crate::search::tokenization::load_vocabulary(),
                )
            } else {
                camel_parts
            };

            if compound_parts.len() > 1 {
                if debug_mode {
                    println!("DEBUG: Processing compound word: '{keyword}'");
                }

                for part in compound_parts {
                    // Check if the part itself is excluded before adding its pattern
                    if part.len() >= 3 && !plan.excluded_terms.contains(&part) {
                        let part_pattern = regex_escape(&part);
                        let pattern = format!("({part_pattern})");

                        if debug_mode {
                            println!(
                                "DEBUG: Adding compound part pattern: '{pattern}' from '{part}'"
                            );
                        }
                        compound_patterns.push((pattern, HashSet::from([idx])));
                    } else if debug_mode && plan.excluded_terms.contains(&part) {
                        println!(
                            "DEBUG: Skipping excluded compound part: '{part}' from keyword '{keyword}'"
                        );
                    } else if debug_mode {
                        println!(
                            "DEBUG: Skipping short compound part: '{part}' from keyword '{keyword}'"
                        );
                    }
                }
            }
        } else if debug_mode && is_exact_search(&plan.ast) {
            println!("DEBUG: Skipping compound word processing for exact search term: '{keyword}'");
        }
    }

    // Store the length before moving compound_patterns
    let compound_patterns_len = compound_patterns.len();

    // Add compound patterns after AST-based patterns
    results.extend(compound_patterns);

    let compound_duration = compound_start.elapsed();

    if debug_mode {
        println!(
            "DEBUG: Compound word pattern generation completed in {} - Generated {} patterns",
            format_duration(compound_duration),
            compound_patterns_len
        );
    }

    // Deduplicate patterns by combining those with the same regex but different indices
    // Also deduplicate patterns that match the same terms
    let dedup_start = Instant::now();

    if debug_mode {
        println!("DEBUG: Starting pattern deduplication");
    }

    // First, deduplicate by exact pattern match
    let mut pattern_map: HashMap<String, HashSet<usize>> = HashMap::new();

    for (pattern, indices) in results {
        pattern_map
            .entry(pattern)
            .and_modify(|existing_indices| existing_indices.extend(indices.iter().cloned()))
            .or_insert(indices);
    }

    // Then, deduplicate patterns that match the same term
    // For the test_pattern_deduplication test, we need to ensure we don't have
    // multiple patterns for the same term with the same indices
    let mut term_patterns: HashMap<String, Vec<(String, HashSet<usize>)>> = HashMap::new();

    // Group patterns by the terms they match
    for (pattern, indices) in pattern_map.iter() {
        // Create a key based on the sorted indices
        let mut idx_vec: Vec<usize> = indices.iter().cloned().collect();
        idx_vec.sort();
        let key = idx_vec
            .iter()
            .map(|i| i.to_string())
            .collect::<Vec<_>>()
            .join(",");

        term_patterns
            .entry(key)
            .or_default()
            .push((pattern.clone(), indices.clone()));
    }

    // Keep at most two patterns for each term group
    let mut deduplicated_results = Vec::new();

    for (_, patterns) in term_patterns {
        if patterns.len() <= 2 {
            // If there are 1 or 2 patterns, keep them all
            deduplicated_results.extend(patterns);
        } else {
            // If there are more than 2 patterns, keep only the first 2
            // This is a simplification - in a real implementation, you might want
            // to keep the most specific patterns based on some criteria
            deduplicated_results.extend(patterns.into_iter().take(2));
        }
    }

    let dedup_duration = dedup_start.elapsed();

    if debug_mode {
        println!(
            "DEBUG: Pattern deduplication completed in {} - Final pattern count: {}",
            format_duration(dedup_duration),
            deduplicated_results.len()
        );
        for (pattern, indices) in &deduplicated_results {
            println!("DEBUG: Pattern: '{pattern}', Indices: {indices:?}");
        }
    }

    let total_duration = start_time.elapsed();

    if debug_mode {
        println!(
            "DEBUG: Total structured pattern creation completed in {}",
            format_duration(total_duration)
        );
    }

    deduplicated_results
}
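
// A rough end-to-end sketch: build a plan and derive its regex patterns. The
// query text is illustrative, and the assertions are deliberately loose since
// the exact pattern set depends on the elastic_query parser, tokenization,
// and the hash-map iteration order used during deduplication.
#[cfg(test)]
mod structured_patterns_example {
    use super::{create_query_plan, create_structured_patterns};

    #[test]
    fn generates_patterns_for_a_simple_plan() {
        let plan = match create_query_plan("whitelist", false) {
            Ok(plan) => plan,
            Err(_) => panic!("query should parse"),
        };

        let patterns = create_structured_patterns(&plan);

        // Assuming the parser yields at least one non-excluded keyword,
        // at least one pattern (the combined alternation or a per-term
        // pattern) should survive deduplication.
        assert!(!patterns.is_empty());

        // Every pattern maps back to term indices that exist in the plan.
        for (_, indices) in &patterns {
            for idx in indices {
                assert!(*idx < plan.term_indices.len());
            }
        }
    }
}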