Skip to main content

sqry_nl/
translator.rs

1//! Main Translator API for natural language to sqry command translation.
2//!
3//! This module ties together all the components of the translation pipeline:
4//! preprocess → extract → classify → assemble → validate → cache
5
6use crate::assembler;
7use crate::cache::{CacheConfig, CacheKey, CachedResult, TranslationCache};
8use crate::error::{AssemblerError, NlResult};
9use crate::extractor;
10use crate::preprocess;
11use crate::types::{
12    DisambiguationOption, ExtractedEntities, Intent, TranslationResponse, ValidationStatus,
13};
14use crate::validator;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::time::Instant;
17
/// Default minimum confidence for a validated command to be returned as
/// `TranslationResponse::Execute` (see `TranslatorConfig::execute_threshold`).
const EXECUTE_THRESHOLD: f32 = 0.85;

/// Default minimum confidence for a validated command to be returned as
/// `TranslationResponse::Confirm` (see `TranslatorConfig::confirm_threshold`);
/// anything lower produces a `Disambiguate` response.
const CONFIRM_THRESHOLD: f32 = 0.65;

/// Default translation-cache capacity (number of entries).
const DEFAULT_CACHE_CAPACITY: usize = 128;

/// Default result limit; it is part of the cache key.
const DEFAULT_RESULT_LIMIT: u32 = 100;
27
28/// Configuration for the Translator.
29#[derive(Debug, Clone)]
30pub struct TranslatorConfig {
31    /// Path to model directory (for classifier feature).
32    /// Directory should contain: `intent_classifier.onnx`, `tokenizer.json`, and optionally
33    /// `calibration.json` or `temperature.json` for confidence calibration
34    pub model_dir: Option<String>,
35    /// Context: current working directory for relative paths.
36    pub working_directory: Option<String>,
37    /// Custom confidence thresholds.
38    pub execute_threshold: f32,
39    pub confirm_threshold: f32,
40    /// Cache configuration. Set to None to disable caching.
41    pub cache_config: Option<CacheConfig>,
42    /// Default result limit (affects cache key generation).
43    pub default_limit: u32,
44    /// Languages to restrict searches to (affects cache key generation).
45    pub languages: Vec<String>,
46}
47
48impl Default for TranslatorConfig {
49    fn default() -> Self {
50        Self {
51            model_dir: None,
52            working_directory: None,
53            execute_threshold: EXECUTE_THRESHOLD,
54            confirm_threshold: CONFIRM_THRESHOLD,
55            cache_config: Some(CacheConfig {
56                capacity: DEFAULT_CACHE_CAPACITY,
57                ..Default::default()
58            }),
59            default_limit: DEFAULT_RESULT_LIMIT,
60            languages: Vec::new(),
61        }
62    }
63}
64
65/// The main Translator struct that provides the `translate()` API.
66pub struct Translator {
67    config: TranslatorConfig,
68    /// Translation counter for stats
69    translations: AtomicU64,
70    /// Translation cache for repeated queries (Step 7)
71    cache: Option<TranslationCache>,
72    #[cfg(feature = "classifier")]
73    classifier: Option<crate::classifier::IntentClassifier>,
74}
75
76impl Translator {
77    /// Create a new Translator with the given configuration.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the classifier fails to load (when classifier feature is enabled).
82    pub fn new(config: TranslatorConfig) -> NlResult<Self> {
83        #[cfg(feature = "classifier")]
84        let classifier = if let Some(model_dir) = &config.model_dir {
85            use std::path::Path;
86            Some(crate::classifier::IntentClassifier::load(Path::new(
87                model_dir,
88            ))?)
89        } else {
90            None
91        };
92
93        // Initialize cache if configured
94        let cache = config
95            .cache_config
96            .as_ref()
97            .map(|cfg| TranslationCache::with_config(cfg.clone()));
98
99        Ok(Self {
100            config,
101            translations: AtomicU64::new(0),
102            cache,
103            #[cfg(feature = "classifier")]
104            classifier,
105        })
106    }
107
108    /// Create a Translator with default configuration.
109    ///
110    /// # Errors
111    ///
112    /// Returns an error if initialization fails.
113    pub fn load_default() -> NlResult<Self> {
114        Self::new(TranslatorConfig::default())
115    }
116
117    /// Translate a natural language query to a sqry command.
118    ///
119    /// # Arguments
120    ///
121    /// * `input` - The natural language query to translate
122    ///
123    /// # Returns
124    ///
125    /// A `TranslationResponse` indicating:
126    /// - `Execute`: High confidence, safe to run automatically
127    /// - `Confirm`: Medium confidence, ask user to confirm
128    /// - `Disambiguate`: Low confidence, need user clarification
129    /// - `Reject`: Cannot translate safely
130    ///
131    /// # Note
132    ///
133    /// This method requires `&mut self` when the classifier feature is enabled.
134    pub fn translate(&mut self, input: &str) -> TranslationResponse {
135        self.translations.fetch_add(1, Ordering::Relaxed);
136        self.translate_impl(input)
137    }
138
139    /// Internal translation implementation.
140    fn translate_impl(&mut self, input: &str) -> TranslationResponse {
141        let start_time = Instant::now();
142
143        // Create cache key from input and context
144        let cache_key = CacheKey::new(
145            input,
146            &self.config.languages,
147            self.config.working_directory.clone(),
148            self.config.default_limit,
149        );
150
151        // Check cache first
152        if let Some(cached_response) = self.cached_response(&cache_key, start_time) {
153            return cached_response;
154        }
155
156        // Step 1: Preprocess
157        let preprocessed = match preprocess::preprocess_input(input) {
158            Ok(p) => p,
159            Err(e) => {
160                return TranslationResponse::Reject {
161                    reason: format!("Preprocessing failed: {e}"),
162                    suggestions: vec!["Try simplifying your query".to_string()],
163                };
164            }
165        };
166
167        // Step 2: Extract entities
168        let entities = extractor::extract_entities(&preprocessed.text);
169
170        // Step 3: Classify intent
171        let (intent, confidence) = self.classify_intent(&preprocessed.text, &entities);
172
173        // Step 4: Assemble command
174        let command = match assembler::assemble_command(&intent, &entities) {
175            Ok(cmd) => cmd,
176            Err(e) => return Self::handle_assembly_error(e, &entities),
177        };
178
179        // Step 5: Validate command
180        self.handle_validation_result(
181            command, confidence, intent, &entities, cache_key, start_time,
182        )
183    }
184
185    fn cached_response(
186        &self,
187        cache_key: &CacheKey,
188        start_time: Instant,
189    ) -> Option<TranslationResponse> {
190        let cache = self.cache.as_ref()?;
191        let cached = cache.get(cache_key)?;
192        Some(TranslationResponse::Execute {
193            command: cached.command,
194            confidence: cached.confidence,
195            intent: cached.intent,
196            cached: true,
197            latency_ms: Self::elapsed_ms(start_time),
198        })
199    }
200
201    fn handle_validation_result(
202        &self,
203        command: String,
204        confidence: f32,
205        intent: Intent,
206        entities: &ExtractedEntities,
207        cache_key: CacheKey,
208        start_time: Instant,
209    ) -> TranslationResponse {
210        match validator::validate_command(&command) {
211            ValidationStatus::Valid => {
212                let latency_ms = Self::elapsed_ms(start_time);
213
214                if confidence >= self.config.execute_threshold
215                    && let Some(ref cache) = self.cache
216                {
217                    cache.put(
218                        cache_key,
219                        CachedResult {
220                            command: command.clone(),
221                            intent,
222                            confidence,
223                            created_at: Instant::now(),
224                        },
225                    );
226                }
227
228                self.create_response_with_latency(command, confidence, intent, entities, latency_ms)
229            }
230            ValidationStatus::RejectedMetachar => TranslationResponse::Reject {
231                reason: "Command contains disallowed shell characters".to_string(),
232                suggestions: vec![
233                    "Avoid special characters like ;, |, &, $".to_string(),
234                    "Use quoted strings for literal values".to_string(),
235                ],
236            },
237            ValidationStatus::RejectedEnvVar => TranslationResponse::Reject {
238                reason: "Command contains environment variable references".to_string(),
239                suggestions: vec![
240                    "Use literal paths instead of $HOME, ${VAR}".to_string(),
241                    "Specify the full path explicitly".to_string(),
242                ],
243            },
244            ValidationStatus::RejectedPathTraversal => TranslationResponse::Reject {
245                reason: "Command contains path traversal patterns".to_string(),
246                suggestions: vec![
247                    "Use relative paths within the project".to_string(),
248                    "Avoid .. in paths".to_string(),
249                ],
250            },
251            ValidationStatus::RejectedTooLong => TranslationResponse::Reject {
252                reason: "Generated command exceeds maximum length".to_string(),
253                suggestions: vec![
254                    "Try a simpler query".to_string(),
255                    "Reduce the number of filters".to_string(),
256                ],
257            },
258            ValidationStatus::RejectedWriteMode => TranslationResponse::Reject {
259                reason: "Command attempts write operation".to_string(),
260                suggestions: vec![
261                    "NL translation only supports read operations".to_string(),
262                    "Use CLI directly for write operations".to_string(),
263                ],
264            },
265            ValidationStatus::RejectedUnknown => {
266                let template_names = assembler::templates::TEMPLATES
267                    .iter()
268                    .map(|(name, _)| *name)
269                    .collect::<Vec<_>>()
270                    .join(", ");
271                let template_examples = ["query", "search", "trace-path"]
272                    .into_iter()
273                    .filter_map(assembler::templates::get_template)
274                    .map(str::to_string)
275                    .collect::<Vec<_>>()
276                    .join(" | ");
277
278                TranslationResponse::Reject {
279                    reason: "Command does not match any allowed template".to_string(),
280                    suggestions: vec![
281                        format!("Use supported command templates: {template_names}"),
282                        format!("Examples: {template_examples}"),
283                        "Try rephrasing your query".to_string(),
284                    ],
285                }
286            }
287        }
288    }
289
/// Whole milliseconds elapsed since `start_time`, saturating at `u64::MAX`.
fn elapsed_ms(start_time: Instant) -> u64 {
    let millis: u128 = start_time.elapsed().as_millis();
    millis.try_into().unwrap_or(u64::MAX)
}
293
294    /// Classify the intent of the query.
295    #[allow(clippy::unused_self)] // Uses self when classifier feature is enabled.
296    fn classify_intent(&mut self, text: &str, entities: &ExtractedEntities) -> (Intent, f32) {
297        #[cfg(feature = "classifier")]
298        if let Some(ref mut classifier) = self.classifier {
299            match classifier.classify(text) {
300                Ok(result) => return (result.intent, result.confidence),
301                Err(e) => {
302                    // Log and fall back to rules
303                    eprintln!("Classifier failed, using fallback: {e}");
304                }
305            }
306        }
307
308        // Fallback: rule-based classification
309        Self::classify_intent_rules(text, entities)
310    }
311
312    /// Rule-based intent classification fallback.
313    fn classify_intent_rules(text: &str, entities: &ExtractedEntities) -> (Intent, f32) {
314        let text_lower = text.to_lowercase();
315
316        if let Some(intent) = Self::classify_graph_intent(&text_lower) {
317            return intent;
318        }
319
320        if let Some(intent) = Self::classify_index_intent(&text_lower) {
321            return intent;
322        }
323
324        if let Some(intent) = Self::classify_text_search_intent(&text_lower, text) {
325            return intent;
326        }
327
328        if let Some(intent) = Self::classify_symbol_query_intent(&text_lower, entities) {
329            return intent;
330        }
331
332        if Self::is_ambiguous(&text_lower) {
333            return (Intent::Ambiguous, 0.3);
334        }
335
336        (Intent::SymbolQuery, 0.5)
337    }
338
339    fn classify_graph_intent(text_lower: &str) -> Option<(Intent, f32)> {
340        if Self::matches_callers(text_lower) {
341            return Some((Intent::FindCallers, 0.85));
342        }
343
344        if Self::matches_callees(text_lower) {
345            return Some((Intent::FindCallees, 0.85));
346        }
347
348        if Self::matches_trace_path(text_lower) {
349            return Some((Intent::TracePath, 0.8));
350        }
351
352        if Self::matches_visualize(text_lower) {
353            return Some((Intent::Visualize, 0.8));
354        }
355
356        None
357    }
358
/// Returns `true` when the lower-cased query asks who calls or uses a symbol.
///
/// The "where is … used" form requires `used` as a whole word. Queries such as
/// "where is the panic caused", "where is unused code" or "where is `is_used`
/// defined" are therefore not treated as caller lookups.
fn matches_callers(text_lower: &str) -> bool {
    const PHRASES: &[&str] = &[
        "callers",
        "who calls",
        "what calls",
        "who uses",
        "who depends",
        "find usages",
        "find all references",
    ];

    // Split on characters that cannot appear in an identifier, so `is_used`
    // stays a single token.
    let has_word = |target: &str| {
        text_lower
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .any(|word| word == target)
    };

    PHRASES.iter().any(|phrase| text_lower.contains(*phrase))
        || (text_lower.contains("where is") && has_word("used"))
}
369
/// Returns `true` when the lower-cased query asks what a symbol calls.
fn matches_callees(text_lower: &str) -> bool {
    const PHRASES: &[&str] = &[
        "callees",
        "functions called by",
        "methods called by",
        "dependencies of",
        "outgoing calls",
        "what functions does",
        "what methods does",
    ];

    let asks_what_it_calls = text_lower.contains("what does") && text_lower.contains("call");
    asks_what_it_calls || PHRASES.iter().any(|phrase| text_lower.contains(*phrase))
}
380
/// Returns `true` when the lower-cased query asks for a call path between symbols.
///
/// `trace` must appear as a whole word. Words and identifiers that merely
/// contain it (`backtrace`, `stacktrace`, `tracer`, `trace_id`) no longer turn
/// an ordinary symbol lookup into a trace-path query, which needs both a source
/// and a target symbol.
fn matches_trace_path(text_lower: &str) -> bool {
    const PHRASES: &[&str] = &[
        "path from",
        "path to",
        "path between",
        "call chain",
        "call sequence",
    ];

    // Split on characters that cannot appear in an identifier, so `trace_id`
    // stays a single token.
    let mentions_trace = text_lower
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .any(|word| word == "trace");
    let asks_how_it_flows = text_lower.contains("how does")
        && (text_lower.contains("reach") || text_lower.contains("flow"));

    mentions_trace
        || asks_how_it_flows
        || PHRASES.iter().any(|phrase| text_lower.contains(*phrase))
}
391
/// Returns `true` when the lower-cased query asks for a diagram or visualisation.
///
/// `draw` must appear as a whole word. Identifiers such as `withdraw`, `drawer`
/// or `overdraw` no longer turn a symbol lookup into a visualisation request.
fn matches_visualize(text_lower: &str) -> bool {
    const PHRASES: &[&str] = &["visualize", "diagram", "mermaid", "dot graph"];

    // Split on characters that cannot appear in an identifier.
    let mentions_draw = text_lower
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .any(|word| word == "draw");

    mentions_draw
        || PHRASES.iter().any(|phrase| text_lower.contains(*phrase))
        || (text_lower.contains("generate") && text_lower.contains("graph"))
        || (text_lower.contains("show") && text_lower.contains("visual"))
}
401
402    fn classify_index_intent(text_lower: &str) -> Option<(Intent, f32)> {
403        if (text_lower.contains("index") && text_lower.contains("status"))
404            || text_lower.starts_with("index status")
405            || text_lower.contains("is index")
406            || text_lower.contains("check index")
407            || text_lower.contains("index info")
408            || text_lower.contains("index stat")
409            || text_lower.contains("indexed")
410            || text_lower.contains("what files are indexed")
411            || text_lower.contains("how many symbols")
412            || text_lower.contains("when was index")
413        {
414            return Some((Intent::IndexStatus, 0.85));
415        }
416
417        None
418    }
419
420    fn classify_text_search_intent(text_lower: &str, text: &str) -> Option<(Intent, f32)> {
421        let is_predicate_query = Self::is_predicate_query(text_lower);
422
423        if text_lower.starts_with("grep")
424            || text_lower.starts_with("search for")
425            || text_lower.contains("grep for")
426            || text_lower.contains("grep ")
427            || text_lower.contains("look for")
428            || (text_lower.contains("search") && !text_lower.contains("code search"))
429            || text_lower.contains("todo")
430            || text_lower.contains("fixme")
431            || text_lower.contains("deprecated")
432            || text_lower.contains("copyright")
433            || text_lower.contains("hardcoded")
434            || text.contains('!')
435            || (!is_predicate_query && text_lower.contains("unsafe"))
436            || text_lower.contains(" pub ")
437            || text_lower.contains(" mut ")
438            || (!is_predicate_query && text_lower.contains("async"))
439            || text_lower.contains("unsafe blocks")
440            || text_lower.contains("impl blocks")
441            || text_lower.contains("import")
442            || text_lower.contains("use statement")
443            || text_lower.contains("require")
444        {
445            return Some((Intent::TextSearch, 0.8));
446        }
447
448        None
449    }
450
451    fn classify_symbol_query_intent(
452        text_lower: &str,
453        entities: &ExtractedEntities,
454    ) -> Option<(Intent, f32)> {
455        if text_lower.starts_with("find")
456            || text_lower.starts_with("show")
457            || text_lower.starts_with("list")
458            || text_lower.starts_with("where is")
459            || text_lower.starts_with("where are")
460            || text_lower.contains("function")
461            || text_lower.contains("method")
462            || text_lower.contains("class")
463            || text_lower.contains("struct")
464            || text_lower.contains("enum")
465            || text_lower.contains("trait")
466            || text_lower.contains("interface")
467            || text_lower.contains("module")
468            || text_lower.contains("constant")
469            || text_lower.contains("variable")
470            || text_lower.contains("public")
471            || text_lower.contains("private")
472            || text_lower.contains("defined")
473        {
474            return Some((Intent::SymbolQuery, 0.8));
475        }
476
477        if entities.kind.is_some() {
478            return Some((Intent::SymbolQuery, 0.85));
479        }
480
481        if !entities.symbols.is_empty() {
482            return Some((Intent::SymbolQuery, 0.7));
483        }
484
485        if !entities.languages.is_empty() {
486            return Some((Intent::SymbolQuery, 0.65));
487        }
488
489        None
490    }
491
/// Returns `true` when the lower-cased query mentions functions or methods.
fn is_predicate_query(text_lower: &str) -> bool {
    // The plural forms contain the singular ones, so the singulars suffice.
    ["function", "method"]
        .iter()
        .any(|noun| text_lower.contains(*noun))
}
498
/// Treats a query of at most two whitespace-separated words as too short to
/// classify confidently.
fn is_ambiguous(text_lower: &str) -> bool {
    text_lower.split_whitespace().nth(2).is_none()
}
502
503    /// Create the appropriate response based on confidence level (with latency tracking).
504    fn create_response_with_latency(
505        &self,
506        command: String,
507        confidence: f32,
508        intent: Intent,
509        entities: &ExtractedEntities,
510        latency_ms: u64,
511    ) -> TranslationResponse {
512        if confidence >= self.config.execute_threshold {
513            TranslationResponse::Execute {
514                command,
515                confidence,
516                intent,
517                cached: false,
518                latency_ms,
519            }
520        } else if confidence >= self.config.confirm_threshold {
521            let prompt = format!(
522                "I'll run: {}\nConfidence: {:.0}%. Proceed? [y/N]",
523                command,
524                confidence * 100.0
525            );
526            TranslationResponse::Confirm {
527                command,
528                confidence,
529                prompt,
530            }
531        } else {
532            // Disambiguate - present options to user
533            let options = Self::generate_disambiguation_options(entities);
534            TranslationResponse::Disambiguate {
535                options,
536                prompt: "I'm not sure what you mean. Did you want to:".to_string(),
537            }
538        }
539    }
540
541    /// Create the appropriate response based on confidence level.
542    #[allow(dead_code)]
543    fn create_response(
544        &self,
545        command: String,
546        confidence: f32,
547        intent: Intent,
548        entities: &ExtractedEntities,
549    ) -> TranslationResponse {
550        self.create_response_with_latency(command, confidence, intent, entities, 0)
551    }
552
553    /// Generate disambiguation options when confidence is low.
554    fn generate_disambiguation_options(entities: &ExtractedEntities) -> Vec<DisambiguationOption> {
555        let mut options = Vec::new();
556
557        if let Some(symbol) = entities.primary_symbol() {
558            options.push(DisambiguationOption {
559                command: format!("sqry query \"{symbol}\""),
560                intent: Intent::SymbolQuery,
561                description: format!("Search for symbol \"{symbol}\""),
562                confidence: 0.5,
563            });
564            options.push(DisambiguationOption {
565                command: format!("sqry graph direct-callers \"{symbol}\""),
566                intent: Intent::FindCallers,
567                description: format!("Find callers of \"{symbol}\""),
568                confidence: 0.4,
569            });
570        } else {
571            options.push(DisambiguationOption {
572                command: "sqry query \"<symbol>\"".to_string(),
573                intent: Intent::SymbolQuery,
574                description: "Search for a specific symbol".to_string(),
575                confidence: 0.3,
576            });
577        }
578
579        options
580    }
581
582    /// Handle assembly errors with helpful suggestions.
583    fn handle_assembly_error(
584        error: AssemblerError,
585        entities: &ExtractedEntities,
586    ) -> TranslationResponse {
587        match error {
588            AssemblerError::MissingSymbol => {
589                let suggestions = if entities.languages.is_empty() {
590                    vec![
591                        "Specify what symbol or pattern you're looking for".to_string(),
592                        "Example: find \"authenticate\" in rust".to_string(),
593                    ]
594                } else {
595                    vec![
596                        format!(
597                            "Try: find <symbol name> in {}",
598                            entities.languages.join(", ")
599                        ),
600                        "Specify what you're looking for in quotes".to_string(),
601                    ]
602                };
603                TranslationResponse::Reject {
604                    reason: "Could not determine what to search for".to_string(),
605                    suggestions,
606                }
607            }
608            AssemblerError::AmbiguousIntent => TranslationResponse::Disambiguate {
609                options: vec![
610                    DisambiguationOption {
611                        command: "sqry query \"<symbol>\"".to_string(),
612                        intent: Intent::SymbolQuery,
613                        description: "Search for symbols matching a pattern".to_string(),
614                        confidence: 0.3,
615                    },
616                    DisambiguationOption {
617                        command: "sqry graph direct-callers \"<symbol>\"".to_string(),
618                        intent: Intent::FindCallers,
619                        description: "Find callers of a function".to_string(),
620                        confidence: 0.3,
621                    },
622                ],
623                prompt: "Please clarify what you'd like to do:".to_string(),
624            },
625            AssemblerError::MissingTracePath => TranslationResponse::Reject {
626                reason: "Trace path requires both source and target symbols".to_string(),
627                suggestions: vec![
628                    "Specify two symbols: trace path from X to Y".to_string(),
629                    "Example: trace path from login to database".to_string(),
630                ],
631            },
632            AssemblerError::CommandTooLong { .. } => TranslationResponse::Reject {
633                reason: "Generated command is too long".to_string(),
634                suggestions: vec![
635                    "Try a simpler query".to_string(),
636                    "Reduce the number of filters".to_string(),
637                ],
638            },
639            AssemblerError::NoTemplate(intent_name) => TranslationResponse::Reject {
640                reason: format!("No template available for intent: {intent_name}"),
641                suggestions: vec![
642                    "Try a different query type".to_string(),
643                    "Supported queries: symbol search, callers, callees, trace path".to_string(),
644                ],
645            },
646        }
647    }
648
649    /// Get translation count.
650    #[must_use]
651    pub fn translation_count(&self) -> u64 {
652        self.translations.load(Ordering::Relaxed)
653    }
654
655    /// Get cache statistics.
656    ///
657    /// Returns `None` if caching is disabled.
658    #[must_use]
659    pub fn cache_stats(&self) -> Option<crate::cache::CacheStats> {
660        self.cache
661            .as_ref()
662            .map(super::cache::TranslationCache::stats)
663    }
664
665    /// Get cache hit rate (0.0-1.0).
666    ///
667    /// Returns `None` if caching is disabled.
668    #[must_use]
669    pub fn cache_hit_rate(&self) -> Option<f64> {
670        self.cache
671            .as_ref()
672            .map(super::cache::TranslationCache::hit_rate)
673    }
674
675    /// Clear the translation cache.
676    ///
677    /// Does nothing if caching is disabled.
678    pub fn clear_cache(&self) {
679        if let Some(ref cache) = self.cache {
680            cache.clear();
681        }
682    }
683}
684
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a translator with the default configuration.
    fn default_translator() -> Translator {
        Translator::load_default().unwrap()
    }

    /// Extracts the command from an `Execute` or `Confirm` response.
    fn runnable_command(response: TranslationResponse) -> Option<String> {
        match response {
            TranslationResponse::Execute { command, .. }
            | TranslationResponse::Confirm { command, .. } => Some(command),
            _ => None,
        }
    }

    #[test]
    fn test_translator_creation() {
        assert_eq!(default_translator().translation_count(), 0);
    }

    #[test]
    fn test_translate_simple_query() {
        let mut translator = default_translator();
        let response = translator.translate("find authentication functions");

        // A validation rejection is acceptable; a missing-symbol rejection is not.
        if let TranslationResponse::Reject { reason, .. } = &response {
            assert!(!reason.contains("Could not determine"));
        }
        assert_eq!(translator.translation_count(), 1);
    }

    #[test]
    fn test_translate_with_language() {
        let mut translator = default_translator();
        let response = translator.translate("find authentication in rust");

        // Disambiguate or Reject is acceptable for the rule-based fallback.
        if let Some(command) = runnable_command(response) {
            assert!(command.contains("--language rust"));
        }
    }

    #[test]
    fn test_translate_callers() {
        let mut translator = default_translator();

        match translator.translate("who calls authenticate") {
            TranslationResponse::Execute { intent, .. } => {
                assert_eq!(intent, Intent::FindCallers);
            }
            // Confirm carries no intent, so check the command text instead.
            TranslationResponse::Confirm { command, .. } => {
                assert!(
                    command.contains("graph direct-callers") || command.contains("authenticate")
                );
            }
            _ => {}
        }
    }

    #[test]
    fn test_custom_thresholds() {
        let mut translator = Translator::new(TranslatorConfig {
            execute_threshold: 0.99,
            confirm_threshold: 0.90,
            ..Default::default()
        })
        .unwrap();

        // With thresholds this high, a plain lookup should not auto-execute.
        let response = translator.translate("find foo");
        assert!(!matches!(response, TranslationResponse::Execute { .. }));
    }

    #[test]
    fn test_kind_only_query() {
        let mut translator = default_translator();

        // The kind is emitted as a `kind:` predicate in the query expression.
        let command = runnable_command(translator.translate("list all traits"))
            .expect("Expected Execute or Confirm response");
        assert!(command.contains("kind:trait"));
    }

    #[test]
    fn test_snake_case_symbol() {
        let mut translator = default_translator();

        let command = runnable_command(translator.translate("find user_id variable"))
            .expect("Expected Execute or Confirm response");
        assert!(command.contains("user_id"));
    }
}
786
// Predicate translation regression tests
#[cfg(test)]
mod predicate_translation_tests {
    use super::*;

    /// Translates `query` with a default-configured translator and returns the
    /// command. Panics unless the response is `Execute` or `Confirm`.
    fn command_for(query: &str) -> String {
        let mut translator =
            Translator::new(TranslatorConfig::default()).expect("Translator init failed");

        match translator.translate(query) {
            TranslationResponse::Execute { command, .. }
            | TranslationResponse::Confirm { command, .. } => command,
            _ => panic!("should execute or confirm"),
        }
    }

    #[test]
    fn test_async_functions_translation() {
        assert!(command_for("find async functions").contains("async:true"));
    }

    #[test]
    fn test_unsafe_functions_translation() {
        assert!(command_for("find unsafe functions").contains("unsafe:true"));
    }

    #[test]
    fn test_public_async_functions_translation() {
        let command = command_for("find public async functions");
        assert!(command.contains("visibility:public"));
        assert!(command.contains("async:true"));
    }
}