1use crate::assembler;
7use crate::cache::{CacheConfig, CacheKey, CachedResult, TranslationCache};
8use crate::error::{AssemblerError, NlResult};
9use crate::extractor;
10use crate::preprocess;
11use crate::types::{
12 DisambiguationOption, ExtractedEntities, Intent, TranslationResponse, ValidationStatus,
13};
14use crate::validator;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Instant;
18
/// Default minimum confidence for returning a translation as `Execute`; results at or
/// above this value are also the only ones written to the translation cache.
const EXECUTE_THRESHOLD: f32 = 0.85;
/// Default minimum confidence for returning a translation as `Confirm` (when it is below
/// the execute threshold); anything lower yields a `Disambiguate` response.
const CONFIRM_THRESHOLD: f32 = 0.65;

/// Default capacity of the translation cache built by `TranslatorConfig::default`.
const DEFAULT_CACHE_CAPACITY: usize = 128;

/// Default value of `TranslatorConfig::default_limit`.
const DEFAULT_RESULT_LIMIT: u32 = 100;
28
/// Configuration for [`Translator`].
///
/// Start from `TranslatorConfig::default()` and override individual fields with
/// struct-update syntax.
#[derive(Debug, Clone)]
pub struct TranslatorConfig {
    /// Legacy model-directory setting. It is handed to the model resolver after
    /// `model_dir_override`, and is only read when the `classifier` feature is enabled.
    pub model_dir: Option<String>,
    /// Working directory; included in the translation cache key.
    pub working_directory: Option<String>,
    /// Minimum confidence for an `Execute` response. Only valid commands at or above this
    /// threshold are inserted into the cache.
    pub execute_threshold: f32,
    /// Minimum confidence for a `Confirm` response (checked after `execute_threshold`);
    /// lower confidences produce a `Disambiguate` response.
    pub confirm_threshold: f32,
    /// Translation-cache settings; `None` disables caching entirely.
    pub cache_config: Option<CacheConfig>,
    /// Default result limit; included in the translation cache key.
    // NOTE(review): in this file the limit is only used for the cache key and is not
    // passed to the assembler — confirm where it reaches the generated command.
    pub default_limit: u32,
    /// Language list; included in the translation cache key.
    // NOTE(review): like `default_limit`, only used for the cache key here — confirm.
    pub languages: Vec<String>,
    /// Highest-priority model directory (e.g. from a CLI flag); first argument to the
    /// model resolver. Only read with the `classifier` feature.
    pub model_dir_override: Option<PathBuf>,
    /// Forwarded to the classifier loader; presumably relaxes model integrity checks —
    /// confirm against `IntentClassifier::load`.
    pub allow_unverified_model: bool,
    /// When no model directory resolves, download the baked model into the cache
    /// directory instead of falling back to rule-based classification.
    pub allow_model_download: bool,
    /// Download target for `allow_model_download`; defaults to the platform cache
    /// directory joined with `sqry/models`.
    pub model_cache_dir: Option<PathBuf>,
    /// Requested classifier pool size, passed through `resolve_pool_size`.
    pub classifier_pool_size: Option<usize>,
}
72
73impl Default for TranslatorConfig {
74 fn default() -> Self {
75 Self {
76 model_dir: None,
77 working_directory: None,
78 execute_threshold: EXECUTE_THRESHOLD,
79 confirm_threshold: CONFIRM_THRESHOLD,
80 cache_config: Some(CacheConfig {
81 capacity: DEFAULT_CACHE_CAPACITY,
82 ..Default::default()
83 }),
84 default_limit: DEFAULT_RESULT_LIMIT,
85 languages: Vec::new(),
86 model_dir_override: None,
87 allow_unverified_model: false,
88 allow_model_download: false,
89 model_cache_dir: None,
90 classifier_pool_size: None,
91 }
92 }
93}
94
/// Translates natural-language queries into validated `sqry` commands.
///
/// Every translation path funnels through `translate_shared(&self, ..)`, so shared
/// access is sufficient: the request counter is atomic and the cache is used through
/// `&self` methods.
pub struct Translator {
    /// Configuration supplied to `Translator::new`.
    config: TranslatorConfig,
    /// Number of translation requests served; incremented with `Relaxed` ordering.
    translations: AtomicU64,
    /// Translation cache built from `config.cache_config`; `None` disables caching.
    cache: Option<TranslationCache>,
    /// Pool of ML intent classifiers; `None` when no model directory was resolved, in
    /// which case intent classification is purely rule-based.
    #[cfg(feature = "classifier")]
    classifier_pool: Option<crate::classifier::ClassifierPool>,
}
113
114impl Translator {
115 pub fn new(config: TranslatorConfig) -> NlResult<Self> {
121 #[cfg(feature = "classifier")]
122 let classifier_pool = {
123 use crate::classifier::{
124 BAKED_MANIFEST, ClassifierPool, RealDirs, ResolverLevel, TrustMode,
125 ensure_model_in_cache, resolve_model_dir, resolve_pool_size,
126 };
127 use std::ffi::OsString;
128 use std::path::{Path, PathBuf};
129
130 let cli_override: Option<&Path> = config.model_dir_override.as_deref();
137 let legacy_path: Option<&Path> = config.model_dir.as_deref().map(Path::new);
138 let env_value: Option<OsString> = std::env::var_os("SQRY_NL_MODEL_DIR");
139 let env_ref = env_value.as_deref();
140 let exe = std::env::current_exe().ok();
141 let exe_ref = exe.as_deref();
142
143 let resolved =
144 resolve_model_dir(cli_override, legacy_path, env_ref, &RealDirs, exe_ref);
145
146 let resolved: Option<(PathBuf, ResolverLevel)> = match resolved {
154 Some(hit) => Some(hit),
155 None if config.allow_model_download => {
156 let cache_dir: PathBuf = config
157 .model_cache_dir
158 .clone()
159 .or_else(|| dirs::cache_dir().map(|p| p.join("sqry/models")))
160 .ok_or_else(|| {
161 crate::error::NlError::Config(
162 "no platform cache_dir available for model download".to_string(),
163 )
164 })?;
165 let dir = ensure_model_in_cache(&cache_dir, &BAKED_MANIFEST, true)?;
166 Some((dir, ResolverLevel::XdgCache))
167 }
168 None => None,
169 };
170
171 match resolved {
172 Some((model_dir, level)) => {
173 let trust_mode = TrustMode::from(level);
174 if matches!(trust_mode, TrustMode::Custom) {
179 tracing::warn!(
180 target: "sqry_nl::classifier",
181 model_dir = %model_dir.display(),
182 resolver_level = ?level,
183 "Loading NL classifier under custom trust mode — \
184 integrity rooted in user-supplied manifest.json. \
185 For trusted defaults use the XDG cache or the \
186 binary-adjacent install location."
187 );
188 }
189
190 let pool_size = resolve_pool_size(config.classifier_pool_size);
197 tracing::info!(
198 target: "sqry_nl::classifier",
199 model_dir = %model_dir.display(),
200 pool_size,
201 "Initialising NL classifier pool"
202 );
203 let model_dir_for_loader = model_dir.clone();
204 let pool = ClassifierPool::new(pool_size, move || {
205 crate::classifier::IntentClassifier::load(
206 &model_dir_for_loader,
207 config.allow_unverified_model,
208 trust_mode,
209 )
210 .map_err(crate::error::NlError::from)
211 })?;
212 Some(pool)
213 }
214 None => None,
215 }
216 };
217
218 let cache = config
220 .cache_config
221 .as_ref()
222 .map(|cfg| TranslationCache::with_config(cfg.clone()));
223
224 Ok(Self {
225 config,
226 translations: AtomicU64::new(0),
227 cache,
228 #[cfg(feature = "classifier")]
229 classifier_pool,
230 })
231 }
232
233 pub fn load_default() -> NlResult<Self> {
239 Self::new(TranslatorConfig::default())
240 }
241
242 pub fn translate(&mut self, input: &str) -> TranslationResponse {
266 self.translate_shared(input)
267 }
268
269 pub fn translate_shared(&self, input: &str) -> TranslationResponse {
276 self.translations.fetch_add(1, Ordering::Relaxed);
277 self.translate_impl(input)
278 }
279
280 fn translate_impl(&self, input: &str) -> TranslationResponse {
282 let start_time = Instant::now();
283
284 let cache_key = CacheKey::new(
286 input,
287 &self.config.languages,
288 self.config.working_directory.clone(),
289 self.config.default_limit,
290 );
291
292 if let Some(cached_response) = self.cached_response(&cache_key, start_time) {
294 return cached_response;
295 }
296
297 let preprocessed = match preprocess::preprocess_input(input) {
299 Ok(p) => p,
300 Err(e) => {
301 return TranslationResponse::Reject {
302 reason: format!("Preprocessing failed: {e}"),
303 suggestions: vec!["Try simplifying your query".to_string()],
304 };
305 }
306 };
307
308 let entities = extractor::extract_entities(&preprocessed.text);
310
311 let (intent, confidence) = self.classify_intent(&preprocessed.text, &entities);
313
314 let command = match assembler::assemble_command(&intent, &entities) {
316 Ok(cmd) => cmd,
317 Err(e) => return Self::handle_assembly_error(e, &entities),
318 };
319
320 self.handle_validation_result(
322 command, confidence, intent, &entities, cache_key, start_time,
323 )
324 }
325
326 fn cached_response(
327 &self,
328 cache_key: &CacheKey,
329 start_time: Instant,
330 ) -> Option<TranslationResponse> {
331 let cache = self.cache.as_ref()?;
332 let cached = cache.get(cache_key)?;
333 Some(TranslationResponse::Execute {
334 command: cached.command,
335 confidence: cached.confidence,
336 intent: cached.intent,
337 cached: true,
338 latency_ms: Self::elapsed_ms(start_time),
339 })
340 }
341
342 fn handle_validation_result(
343 &self,
344 command: String,
345 confidence: f32,
346 intent: Intent,
347 entities: &ExtractedEntities,
348 cache_key: CacheKey,
349 start_time: Instant,
350 ) -> TranslationResponse {
351 match validator::validate_command(&command) {
352 ValidationStatus::Valid => {
353 let latency_ms = Self::elapsed_ms(start_time);
354
355 if confidence >= self.config.execute_threshold
356 && let Some(ref cache) = self.cache
357 {
358 cache.put(
359 cache_key,
360 CachedResult {
361 command: command.clone(),
362 intent,
363 confidence,
364 created_at: Instant::now(),
365 },
366 );
367 }
368
369 self.create_response_with_latency(command, confidence, intent, entities, latency_ms)
370 }
371 ValidationStatus::RejectedMetachar => TranslationResponse::Reject {
372 reason: "Command contains disallowed shell characters".to_string(),
373 suggestions: vec![
374 "Avoid special characters like ;, |, &, $".to_string(),
375 "Use quoted strings for literal values".to_string(),
376 ],
377 },
378 ValidationStatus::RejectedEnvVar => TranslationResponse::Reject {
379 reason: "Command contains environment variable references".to_string(),
380 suggestions: vec![
381 "Use literal paths instead of $HOME, ${VAR}".to_string(),
382 "Specify the full path explicitly".to_string(),
383 ],
384 },
385 ValidationStatus::RejectedPathTraversal => TranslationResponse::Reject {
386 reason: "Command contains path traversal patterns".to_string(),
387 suggestions: vec![
388 "Use relative paths within the project".to_string(),
389 "Avoid .. in paths".to_string(),
390 ],
391 },
392 ValidationStatus::RejectedTooLong => TranslationResponse::Reject {
393 reason: "Generated command exceeds maximum length".to_string(),
394 suggestions: vec![
395 "Try a simpler query".to_string(),
396 "Reduce the number of filters".to_string(),
397 ],
398 },
399 ValidationStatus::RejectedWriteMode => TranslationResponse::Reject {
400 reason: "Command attempts write operation".to_string(),
401 suggestions: vec![
402 "NL translation only supports read operations".to_string(),
403 "Use CLI directly for write operations".to_string(),
404 ],
405 },
406 ValidationStatus::RejectedUnknown => {
407 let template_names = assembler::templates::TEMPLATES
408 .iter()
409 .map(|(name, _)| *name)
410 .collect::<Vec<_>>()
411 .join(", ");
412 let template_examples = ["query", "search", "trace-path"]
413 .into_iter()
414 .filter_map(assembler::templates::get_template)
415 .map(str::to_string)
416 .collect::<Vec<_>>()
417 .join(" | ");
418
419 TranslationResponse::Reject {
420 reason: "Command does not match any allowed template".to_string(),
421 suggestions: vec![
422 format!("Use supported command templates: {template_names}"),
423 format!("Examples: {template_examples}"),
424 "Try rephrasing your query".to_string(),
425 ],
426 }
427 }
428 }
429 }
430
431 fn elapsed_ms(start_time: Instant) -> u64 {
432 u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX)
433 }
434
    /// Classifies `text` into an intent plus confidence score.
    ///
    /// With the `classifier` feature and a pool built in `new`, a pooled ML classifier is
    /// tried first. If it returns an error, a warning is logged and the rule-based
    /// classifier (`classify_intent_rules`) decides instead. Without the feature, or
    /// without a pool, only the rules are used.
    // `self` is only read under the `classifier` feature, hence the lint allowance.
    #[allow(clippy::unused_self)]
    fn classify_intent(&self, text: &str, entities: &ExtractedEntities) -> (Intent, f32) {
        #[cfg(feature = "classifier")]
        if let Some(ref pool) = self.classifier_pool {
            // Check a classifier out of the pool. NOTE(review): whether `acquire` blocks
            // when every slot is busy depends on `ClassifierPool` — confirm.
            let guard = pool.acquire();
            // Presumably a mutex guard borrowed from the pooled slot.
            let mut classifier = guard.classifier().lock();
            match classifier.classify(text) {
                Ok(result) => return (result.intent, result.confidence),
                Err(e) => {
                    tracing::warn!(
                        target: "sqry_nl::classifier",
                        error = %e,
                        "Classifier failed, using rule-based fallback"
                    );
                }
            }
            // Release the classifier lock first (it borrows from `guard`), then return the
            // slot to the pool before running the rule-based fallback.
            drop(classifier);
            drop(guard);
        }

        Self::classify_intent_rules(text, entities)
    }
472
473 fn classify_intent_rules(text: &str, entities: &ExtractedEntities) -> (Intent, f32) {
475 let text_lower = text.to_lowercase();
476
477 if let Some(intent) = Self::classify_graph_intent(&text_lower) {
478 return intent;
479 }
480
481 if let Some(intent) = Self::classify_index_intent(&text_lower) {
482 return intent;
483 }
484
485 if let Some(intent) = Self::classify_text_search_intent(&text_lower, text) {
486 return intent;
487 }
488
489 if let Some(intent) = Self::classify_symbol_query_intent(&text_lower, entities) {
490 return intent;
491 }
492
493 if Self::is_ambiguous(&text_lower) {
494 return (Intent::Ambiguous, 0.3);
495 }
496
497 (Intent::SymbolQuery, 0.5)
498 }
499
500 fn classify_graph_intent(text_lower: &str) -> Option<(Intent, f32)> {
501 if Self::matches_callers(text_lower) {
502 return Some((Intent::FindCallers, 0.85));
503 }
504
505 if Self::matches_callees(text_lower) {
506 return Some((Intent::FindCallees, 0.85));
507 }
508
509 if Self::matches_trace_path(text_lower) {
510 return Some((Intent::TracePath, 0.8));
511 }
512
513 if Self::matches_visualize(text_lower) {
514 return Some((Intent::Visualize, 0.8));
515 }
516
517 None
518 }
519
520 fn matches_callers(text_lower: &str) -> bool {
521 text_lower.contains("callers")
522 || text_lower.contains("who calls")
523 || text_lower.contains("what calls")
524 || text_lower.contains("who uses")
525 || text_lower.contains("who depends")
526 || text_lower.contains("find usages")
527 || text_lower.contains("find all references")
528 || text_lower.contains("where is") && text_lower.contains("used")
529 }
530
531 fn matches_callees(text_lower: &str) -> bool {
532 text_lower.contains("callees")
533 || text_lower.contains("what does") && text_lower.contains("call")
534 || text_lower.contains("functions called by")
535 || text_lower.contains("methods called by")
536 || text_lower.contains("dependencies of")
537 || text_lower.contains("outgoing calls")
538 || text_lower.contains("what functions does")
539 || text_lower.contains("what methods does")
540 }
541
542 fn matches_trace_path(text_lower: &str) -> bool {
543 text_lower.contains("trace")
544 || text_lower.contains("path from")
545 || text_lower.contains("path to")
546 || text_lower.contains("path between")
547 || text_lower.contains("call chain")
548 || text_lower.contains("call sequence")
549 || (text_lower.contains("how does") && text_lower.contains("reach"))
550 || (text_lower.contains("how does") && text_lower.contains("flow"))
551 }
552
553 fn matches_visualize(text_lower: &str) -> bool {
554 text_lower.contains("visualize")
555 || text_lower.contains("diagram")
556 || text_lower.contains("draw")
557 || text_lower.contains("mermaid")
558 || text_lower.contains("dot graph")
559 || (text_lower.contains("generate") && text_lower.contains("graph"))
560 || (text_lower.contains("show") && text_lower.contains("visual"))
561 }
562
563 fn classify_index_intent(text_lower: &str) -> Option<(Intent, f32)> {
564 if (text_lower.contains("index") && text_lower.contains("status"))
565 || text_lower.starts_with("index status")
566 || text_lower.contains("is index")
567 || text_lower.contains("check index")
568 || text_lower.contains("index info")
569 || text_lower.contains("index stat")
570 || text_lower.contains("indexed")
571 || text_lower.contains("what files are indexed")
572 || text_lower.contains("how many symbols")
573 || text_lower.contains("when was index")
574 {
575 return Some((Intent::IndexStatus, 0.85));
576 }
577
578 None
579 }
580
581 fn classify_text_search_intent(text_lower: &str, text: &str) -> Option<(Intent, f32)> {
582 let is_predicate_query = Self::is_predicate_query(text_lower);
583
584 if text_lower.starts_with("grep")
585 || text_lower.starts_with("search for")
586 || text_lower.contains("grep for")
587 || text_lower.contains("grep ")
588 || text_lower.contains("look for")
589 || (text_lower.contains("search") && !text_lower.contains("code search"))
590 || text_lower.contains("todo")
591 || text_lower.contains("fixme")
592 || text_lower.contains("deprecated")
593 || text_lower.contains("copyright")
594 || text_lower.contains("hardcoded")
595 || text.contains('!')
596 || (!is_predicate_query && text_lower.contains("unsafe"))
597 || text_lower.contains(" pub ")
598 || text_lower.contains(" mut ")
599 || (!is_predicate_query && text_lower.contains("async"))
600 || text_lower.contains("unsafe blocks")
601 || text_lower.contains("impl blocks")
602 || text_lower.contains("import")
603 || text_lower.contains("use statement")
604 || text_lower.contains("require")
605 {
606 return Some((Intent::TextSearch, 0.8));
607 }
608
609 None
610 }
611
612 fn classify_symbol_query_intent(
613 text_lower: &str,
614 entities: &ExtractedEntities,
615 ) -> Option<(Intent, f32)> {
616 if text_lower.starts_with("find")
617 || text_lower.starts_with("show")
618 || text_lower.starts_with("list")
619 || text_lower.starts_with("where is")
620 || text_lower.starts_with("where are")
621 || text_lower.contains("function")
622 || text_lower.contains("method")
623 || text_lower.contains("class")
624 || text_lower.contains("struct")
625 || text_lower.contains("enum")
626 || text_lower.contains("trait")
627 || text_lower.contains("interface")
628 || text_lower.contains("module")
629 || text_lower.contains("constant")
630 || text_lower.contains("variable")
631 || text_lower.contains("public")
632 || text_lower.contains("private")
633 || text_lower.contains("defined")
634 {
635 return Some((Intent::SymbolQuery, 0.8));
636 }
637
638 if entities.kind.is_some() {
639 return Some((Intent::SymbolQuery, 0.85));
640 }
641
642 if !entities.symbols.is_empty() {
643 return Some((Intent::SymbolQuery, 0.7));
644 }
645
646 if !entities.languages.is_empty() {
647 return Some((Intent::SymbolQuery, 0.65));
648 }
649
650 None
651 }
652
653 fn is_predicate_query(text_lower: &str) -> bool {
654 text_lower.contains("functions")
655 || text_lower.contains("methods")
656 || text_lower.contains("function")
657 || text_lower.contains("method")
658 }
659
660 fn is_ambiguous(text_lower: &str) -> bool {
661 text_lower.split_whitespace().count() <= 2
662 }
663
664 fn create_response_with_latency(
666 &self,
667 command: String,
668 confidence: f32,
669 intent: Intent,
670 entities: &ExtractedEntities,
671 latency_ms: u64,
672 ) -> TranslationResponse {
673 if confidence >= self.config.execute_threshold {
674 TranslationResponse::Execute {
675 command,
676 confidence,
677 intent,
678 cached: false,
679 latency_ms,
680 }
681 } else if confidence >= self.config.confirm_threshold {
682 let prompt = format!(
683 "I'll run: {}\nConfidence: {:.0}%. Proceed? [y/N]",
684 command,
685 confidence * 100.0
686 );
687 TranslationResponse::Confirm {
688 command,
689 confidence,
690 prompt,
691 }
692 } else {
693 let options = Self::generate_disambiguation_options(entities);
695 TranslationResponse::Disambiguate {
696 options,
697 prompt: "I'm not sure what you mean. Did you want to:".to_string(),
698 }
699 }
700 }
701
702 #[allow(dead_code)]
704 fn create_response(
705 &self,
706 command: String,
707 confidence: f32,
708 intent: Intent,
709 entities: &ExtractedEntities,
710 ) -> TranslationResponse {
711 self.create_response_with_latency(command, confidence, intent, entities, 0)
712 }
713
714 fn generate_disambiguation_options(entities: &ExtractedEntities) -> Vec<DisambiguationOption> {
716 let mut options = Vec::new();
717
718 if let Some(symbol) = entities.primary_symbol() {
719 options.push(DisambiguationOption {
720 command: format!("sqry query \"{symbol}\""),
721 intent: Intent::SymbolQuery,
722 description: format!("Search for symbol \"{symbol}\""),
723 confidence: 0.5,
724 });
725 options.push(DisambiguationOption {
726 command: format!("sqry graph direct-callers \"{symbol}\""),
727 intent: Intent::FindCallers,
728 description: format!("Find callers of \"{symbol}\""),
729 confidence: 0.4,
730 });
731 } else {
732 options.push(DisambiguationOption {
733 command: "sqry query \"<symbol>\"".to_string(),
734 intent: Intent::SymbolQuery,
735 description: "Search for a specific symbol".to_string(),
736 confidence: 0.3,
737 });
738 }
739
740 options
741 }
742
743 fn handle_assembly_error(
745 error: AssemblerError,
746 entities: &ExtractedEntities,
747 ) -> TranslationResponse {
748 match error {
749 AssemblerError::MissingSymbol => {
750 let suggestions = if entities.languages.is_empty() {
751 vec![
752 "Specify what symbol or pattern you're looking for".to_string(),
753 "Example: find \"authenticate\" in rust".to_string(),
754 ]
755 } else {
756 vec![
757 format!(
758 "Try: find <symbol name> in {}",
759 entities.languages.join(", ")
760 ),
761 "Specify what you're looking for in quotes".to_string(),
762 ]
763 };
764 TranslationResponse::Reject {
765 reason: "Could not determine what to search for".to_string(),
766 suggestions,
767 }
768 }
769 AssemblerError::AmbiguousIntent => TranslationResponse::Disambiguate {
770 options: vec![
771 DisambiguationOption {
772 command: "sqry query \"<symbol>\"".to_string(),
773 intent: Intent::SymbolQuery,
774 description: "Search for symbols matching a pattern".to_string(),
775 confidence: 0.3,
776 },
777 DisambiguationOption {
778 command: "sqry graph direct-callers \"<symbol>\"".to_string(),
779 intent: Intent::FindCallers,
780 description: "Find callers of a function".to_string(),
781 confidence: 0.3,
782 },
783 ],
784 prompt: "Please clarify what you'd like to do:".to_string(),
785 },
786 AssemblerError::MissingTracePath => TranslationResponse::Reject {
787 reason: "Trace path requires both source and target symbols".to_string(),
788 suggestions: vec![
789 "Specify two symbols: trace path from X to Y".to_string(),
790 "Example: trace path from login to database".to_string(),
791 ],
792 },
793 AssemblerError::CommandTooLong { .. } => TranslationResponse::Reject {
794 reason: "Generated command is too long".to_string(),
795 suggestions: vec![
796 "Try a simpler query".to_string(),
797 "Reduce the number of filters".to_string(),
798 ],
799 },
800 AssemblerError::NoTemplate(intent_name) => TranslationResponse::Reject {
801 reason: format!("No template available for intent: {intent_name}"),
802 suggestions: vec![
803 "Try a different query type".to_string(),
804 "Supported queries: symbol search, callers, callees, trace path".to_string(),
805 ],
806 },
807 }
808 }
809
810 #[must_use]
812 pub fn translation_count(&self) -> u64 {
813 self.translations.load(Ordering::Relaxed)
814 }
815
816 #[must_use]
820 pub fn cache_stats(&self) -> Option<crate::cache::CacheStats> {
821 self.cache
822 .as_ref()
823 .map(super::cache::TranslationCache::stats)
824 }
825
826 #[must_use]
830 pub fn cache_hit_rate(&self) -> Option<f64> {
831 self.cache
832 .as_ref()
833 .map(super::cache::TranslationCache::hit_rate)
834 }
835
836 pub fn clear_cache(&self) {
840 if let Some(ref cache) = self.cache {
841 cache.clear();
842 }
843 }
844}
845
846impl std::fmt::Debug for Translator {
847 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848 let mut debug = f.debug_struct("Translator");
849 debug
850 .field("translations", &self.translations.load(Ordering::Relaxed))
851 .field("cache_enabled", &self.cache.is_some());
852 #[cfg(feature = "classifier")]
853 debug.field("classifier_pool", &self.classifier_pool);
854 debug.finish()
855 }
856}
857
#[cfg(test)]
mod tests {
    use super::*;

    /// Translator with the stock configuration, used by every test in this module.
    fn default_translator() -> Translator {
        Translator::load_default().unwrap()
    }

    #[test]
    fn test_translator_creation() {
        let translator = default_translator();
        assert_eq!(translator.translation_count(), 0);
    }

    #[test]
    fn test_translate_simple_query() {
        let mut translator = default_translator();
        let response = translator.translate("find authentication functions");

        // Any outcome is acceptable except failing to identify a search subject.
        if let TranslationResponse::Reject { reason, .. } = &response {
            assert!(!reason.contains("Could not determine"));
        }
        assert_eq!(translator.translation_count(), 1);
    }

    #[test]
    fn test_translate_with_language() {
        let mut translator = default_translator();

        // Only runnable responses carry a command to inspect; other outcomes are tolerated.
        if let TranslationResponse::Execute { command, .. }
        | TranslationResponse::Confirm { command, .. } =
            translator.translate("find authentication in rust")
        {
            assert!(command.contains("--language rust"));
        }
    }

    #[test]
    fn test_translate_callers() {
        let mut translator = default_translator();

        match translator.translate("who calls authenticate") {
            TranslationResponse::Execute { intent, .. } => {
                assert_eq!(intent, Intent::FindCallers);
            }
            TranslationResponse::Confirm { command, .. } => {
                assert!(
                    command.contains("graph direct-callers") || command.contains("authenticate")
                );
            }
            _ => {}
        }
    }

    #[test]
    fn test_custom_thresholds() {
        let strict = TranslatorConfig {
            execute_threshold: 0.99,
            confirm_threshold: 0.90,
            ..Default::default()
        };
        let mut translator = Translator::new(strict).unwrap();

        let response = translator.translate("find foo");
        assert!(!matches!(response, TranslationResponse::Execute { .. }));
    }

    #[test]
    fn test_kind_only_query() {
        let mut translator = default_translator();

        let (TranslationResponse::Execute { command, .. }
        | TranslationResponse::Confirm { command, .. }) = translator.translate("list all traits")
        else {
            panic!("Expected Execute or Confirm response");
        };
        assert!(command.contains("kind:trait"));
    }

    #[test]
    fn test_snake_case_symbol() {
        let mut translator = default_translator();

        let (TranslationResponse::Execute { command, .. }
        | TranslationResponse::Confirm { command, .. }) =
            translator.translate("find user_id variable")
        else {
            panic!("Expected Execute or Confirm response");
        };
        assert!(command.contains("user_id"));
    }
}
959
#[cfg(test)]
mod predicate_translation_tests {
    use super::*;

    /// Translates `query` with a fresh default translator and returns the runnable command.
    /// Panics unless the response is `Execute` or `Confirm`.
    fn runnable_command(query: &str) -> String {
        let mut translator =
            Translator::new(TranslatorConfig::default()).expect("Translator init failed");
        match translator.translate(query) {
            TranslationResponse::Execute { command, .. }
            | TranslationResponse::Confirm { command, .. } => command,
            _ => panic!("should execute or confirm"),
        }
    }

    #[test]
    fn test_async_functions_translation() {
        let command = runnable_command("find async functions");
        assert!(command.contains("async:true"));
    }

    #[test]
    fn test_unsafe_functions_translation() {
        let command = runnable_command("find unsafe functions");
        assert!(command.contains("unsafe:true"));
    }

    #[test]
    fn test_public_async_functions_translation() {
        let command = runnable_command("find public async functions");
        assert!(command.contains("visibility:public"));
        assert!(command.contains("async:true"));
    }
}
1010}