Skip to main content

batuta/oracle/
query_engine.rs

//! Query Engine for Oracle Mode
//!
//! Parses natural language queries and extracts structured information
//! for the recommendation engine.
5
6use super::types::*;
7use std::collections::HashSet;
8
9// =============================================================================
10// Table-Driven Constants
11// =============================================================================
12
13/// Keyword-to-PerformanceHint mapping for simple OR-match patterns.
14/// Each entry: a list of keywords (any match triggers the hint) and the hint variant.
15///
16/// Note: LowLatency also has a compound check ("<" AND "ms") handled separately.
17/// Note: LowMemory uses AND logic ("memory" AND ("low" OR "efficient")) handled separately.
18const HINT_PATTERNS: &[(&[&str], PerformanceHint)] = &[
19    (&["fast", "low latency", "real-time", "realtime"], PerformanceHint::LowLatency),
20    (&["throughput", "high volume"], PerformanceHint::HighThroughput),
21    (&["gpu"], PerformanceHint::GPURequired),
22    (&["distributed", "multi-node", "cluster"], PerformanceHint::Distributed),
23    (&["edge", "embedded", "iot"], PerformanceHint::EdgeDeployment),
24    (&["sovereign", "gdpr", "local only", "eu ai act", "on-premise"], PerformanceHint::Sovereign),
25];
26
27/// Domain-to-OpComplexity mapping for estimate_complexity().
28/// Checked in order; first match wins. Domains not listed fall through to Low.
29const DOMAIN_COMPLEXITY: &[(ProblemDomain, OpComplexity)] = &[
30    (ProblemDomain::DeepLearning, OpComplexity::High),
31    (ProblemDomain::SpeechRecognition, OpComplexity::High),
32    (ProblemDomain::GraphAnalytics, OpComplexity::Medium),
33    (ProblemDomain::SupervisedLearning, OpComplexity::Medium),
34    (ProblemDomain::UnsupervisedLearning, OpComplexity::Medium),
35    (ProblemDomain::MediaProduction, OpComplexity::Medium),
36];
37
38// =============================================================================
39// Query Parser
40// =============================================================================
41
42/// Parsed query with extracted information
43#[derive(Debug, Clone)]
44pub struct ParsedQuery {
45    /// Original query text
46    pub original: String,
47    /// Detected problem domains
48    pub domains: Vec<ProblemDomain>,
49    /// Detected algorithms/techniques
50    pub algorithms: Vec<String>,
51    /// Extracted keywords
52    pub keywords: Vec<String>,
53    /// Detected data size indicators
54    pub data_size: Option<DataSize>,
55    /// Performance requirements detected
56    pub performance_hints: Vec<PerformanceHint>,
57    /// Component mentions
58    pub mentioned_components: Vec<String>,
59}
60
/// A performance or deployment requirement inferred from the wording of a query.
///
/// The trigger phrases noted on each variant are matched as substrings of the
/// lowercased query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceHint {
    /// "fast", "low latency", "real-time", "realtime", or a "<" together with "ms".
    LowLatency,
    /// "throughput" or "high volume".
    HighThroughput,
    /// "memory" combined with "low" or "efficient".
    LowMemory,
    /// "gpu".
    GPURequired,
    /// "distributed", "multi-node" or "cluster".
    Distributed,
    /// "edge", "embedded" or "iot".
    EdgeDeployment,
    /// "sovereign", "gdpr", "local only", "eu ai act" or "on-premise".
    Sovereign,
}
72
73/// Query parser for natural language queries
74#[derive(Debug, Default)]
75pub(crate) struct QueryParser {
76    /// Known algorithm keywords
77    algorithm_keywords: HashSet<String>,
78    /// Problem domain keywords
79    domain_keywords: Vec<(String, ProblemDomain)>,
80    /// Component names
81    component_names: HashSet<String>,
82}
83
84impl QueryParser {
85    /// Create a new query parser
86    pub(crate) fn new() -> Self {
87        let mut parser = Self::default();
88        parser.initialize_keywords();
89        parser
90    }
91
92    fn initialize_keywords(&mut self) {
93        // Algorithm keywords
94        self.algorithm_keywords.extend(
95            [
96                "random_forest",
97                "random forest",
98                "randomforest",
99                "linear_regression",
100                "linear regression",
101                "linearregression",
102                "logistic_regression",
103                "logistic regression",
104                "logisticregression",
105                "decision_tree",
106                "decision tree",
107                "decisiontree",
108                "gradient_boosting",
109                "gradient boosting",
110                "gbm",
111                "xgboost",
112                "lightgbm",
113                "naive_bayes",
114                "naive bayes",
115                "naivebayes",
116                "knn",
117                "k-nearest",
118                "nearest neighbor",
119                "svm",
120                "support vector",
121                "supportvector",
122                "kmeans",
123                "k-means",
124                "clustering",
125                "pca",
126                "principal component",
127                "dimensionality reduction",
128                "dbscan",
129                "density clustering",
130                "neural network",
131                "deep learning",
132                "transformer",
133                "llm",
134                "lora",
135                "qlora",
136                "fine-tuning",
137                "fine tuning",
138                "finetuning",
139                "whisper",
140                "speech recognition",
141                "speech-to-text",
142                "transcription",
143                "asr",
144            ]
145            .map(String::from),
146        );
147
148        // Domain keywords
149        self.domain_keywords = vec![
150            // Supervised Learning
151            ("classify".into(), ProblemDomain::SupervisedLearning),
152            ("classification".into(), ProblemDomain::SupervisedLearning),
153            ("predict".into(), ProblemDomain::SupervisedLearning),
154            ("regression".into(), ProblemDomain::SupervisedLearning),
155            ("train".into(), ProblemDomain::SupervisedLearning),
156            ("supervised".into(), ProblemDomain::SupervisedLearning),
157            // Unsupervised Learning
158            ("cluster".into(), ProblemDomain::UnsupervisedLearning),
159            ("clustering".into(), ProblemDomain::UnsupervisedLearning),
160            ("unsupervised".into(), ProblemDomain::UnsupervisedLearning),
161            ("anomaly".into(), ProblemDomain::UnsupervisedLearning),
162            ("outlier".into(), ProblemDomain::UnsupervisedLearning),
163            // Deep Learning
164            ("neural".into(), ProblemDomain::DeepLearning),
165            ("deep learning".into(), ProblemDomain::DeepLearning),
166            ("transformer".into(), ProblemDomain::DeepLearning),
167            ("llm".into(), ProblemDomain::DeepLearning),
168            ("fine-tune".into(), ProblemDomain::DeepLearning),
169            ("lora".into(), ProblemDomain::DeepLearning),
170            // Inference
171            ("serve".into(), ProblemDomain::Inference),
172            ("serving".into(), ProblemDomain::Inference),
173            ("inference".into(), ProblemDomain::Inference),
174            ("deploy".into(), ProblemDomain::Inference),
175            ("production".into(), ProblemDomain::Inference),
176            // Speech Recognition
177            ("speech".into(), ProblemDomain::SpeechRecognition),
178            ("whisper".into(), ProblemDomain::SpeechRecognition),
179            ("asr".into(), ProblemDomain::SpeechRecognition),
180            ("transcription".into(), ProblemDomain::SpeechRecognition),
181            ("speech-to-text".into(), ProblemDomain::SpeechRecognition),
182            ("speech recognition".into(), ProblemDomain::SpeechRecognition),
183            // Linear Algebra
184            ("matrix".into(), ProblemDomain::LinearAlgebra),
185            ("tensor".into(), ProblemDomain::LinearAlgebra),
186            ("vector".into(), ProblemDomain::LinearAlgebra),
187            ("linear algebra".into(), ProblemDomain::LinearAlgebra),
188            ("simd".into(), ProblemDomain::LinearAlgebra),
189            // Vector Search
190            ("similarity".into(), ProblemDomain::VectorSearch),
191            ("embedding".into(), ProblemDomain::VectorSearch),
192            ("vector search".into(), ProblemDomain::VectorSearch),
193            ("nearest neighbor".into(), ProblemDomain::VectorSearch),
194            // Graph Analytics
195            ("graph".into(), ProblemDomain::GraphAnalytics),
196            ("pagerank".into(), ProblemDomain::GraphAnalytics),
197            ("pathfinding".into(), ProblemDomain::GraphAnalytics),
198            ("community".into(), ProblemDomain::GraphAnalytics),
199            // Python Migration
200            ("python".into(), ProblemDomain::PythonMigration),
201            ("sklearn".into(), ProblemDomain::PythonMigration),
202            ("scikit".into(), ProblemDomain::PythonMigration),
203            ("numpy".into(), ProblemDomain::PythonMigration),
204            ("pandas".into(), ProblemDomain::PythonMigration),
205            ("pytorch".into(), ProblemDomain::PythonMigration),
206            // C Migration
207            ("c code".into(), ProblemDomain::CMigration),
208            ("c++".into(), ProblemDomain::CMigration),
209            ("cpp".into(), ProblemDomain::CMigration),
210            // Shell Migration
211            ("bash".into(), ProblemDomain::ShellMigration),
212            ("shell".into(), ProblemDomain::ShellMigration),
213            ("script".into(), ProblemDomain::ShellMigration),
214            // Distribution
215            ("distributed".into(), ProblemDomain::DistributedCompute),
216            ("parallel".into(), ProblemDomain::DistributedCompute),
217            ("multi-node".into(), ProblemDomain::DistributedCompute),
218            ("cluster".into(), ProblemDomain::DistributedCompute),
219            // Data Pipeline
220            ("data loading".into(), ProblemDomain::DataPipeline),
221            ("csv".into(), ProblemDomain::DataPipeline),
222            ("parquet".into(), ProblemDomain::DataPipeline),
223            ("etl".into(), ProblemDomain::DataPipeline),
224            // Model Serving
225            ("lambda".into(), ProblemDomain::ModelServing),
226            ("serverless".into(), ProblemDomain::ModelServing),
227            ("container".into(), ProblemDomain::ModelServing),
228            ("edge".into(), ProblemDomain::ModelServing),
229            // Quality
230            ("test".into(), ProblemDomain::Testing),
231            ("coverage".into(), ProblemDomain::Testing),
232            ("mutation".into(), ProblemDomain::Testing),
233            ("profile".into(), ProblemDomain::Profiling),
234            ("trace".into(), ProblemDomain::Profiling),
235            ("syscall".into(), ProblemDomain::Profiling),
236            ("validate".into(), ProblemDomain::Validation),
237            ("quality".into(), ProblemDomain::Validation),
238            // Formal Verification
239            ("formal verification".into(), ProblemDomain::Validation),
240            ("kani".into(), ProblemDomain::Validation),
241            ("contract".into(), ProblemDomain::Validation),
242            ("provable".into(), ProblemDomain::Validation),
243            ("proof".into(), ProblemDomain::Validation),
244            ("harness".into(), ProblemDomain::Validation),
245            ("bounded model checking".into(), ProblemDomain::Validation),
246            ("verification".into(), ProblemDomain::Validation),
247            // Model Parity / Ground Truth
248            ("parity".into(), ProblemDomain::Testing),
249            ("falsification".into(), ProblemDomain::Testing),
250            ("ground truth".into(), ProblemDomain::Testing),
251            ("oracle test".into(), ProblemDomain::Testing),
252            ("conversion parity".into(), ProblemDomain::Testing),
253            ("quantization drift".into(), ProblemDomain::Testing),
254            // Media Production
255            ("video".into(), ProblemDomain::MediaProduction),
256            ("render".into(), ProblemDomain::MediaProduction),
257            ("mlt".into(), ProblemDomain::MediaProduction),
258            ("encode".into(), ProblemDomain::MediaProduction),
259            ("decode".into(), ProblemDomain::MediaProduction),
260            ("transition".into(), ProblemDomain::MediaProduction),
261            ("media".into(), ProblemDomain::MediaProduction),
262            ("editing".into(), ProblemDomain::MediaProduction),
263            ("ffmpeg".into(), ProblemDomain::MediaProduction),
264            ("course".into(), ProblemDomain::MediaProduction),
265            ("screencast".into(), ProblemDomain::MediaProduction),
266            ("dissolve".into(), ProblemDomain::MediaProduction),
267            ("fade".into(), ProblemDomain::MediaProduction),
268            ("title card".into(), ProblemDomain::MediaProduction),
269            ("audio".into(), ProblemDomain::MediaProduction),
270            ("transcribe".into(), ProblemDomain::MediaProduction),
271            ("subtitle".into(), ProblemDomain::MediaProduction),
272            ("caption".into(), ProblemDomain::MediaProduction),
273            ("vocabulary".into(), ProblemDomain::MediaProduction),
274            ("key terms".into(), ProblemDomain::MediaProduction),
275            ("reflection".into(), ProblemDomain::MediaProduction),
276            ("landing page".into(), ProblemDomain::MediaProduction),
277            ("outline".into(), ProblemDomain::MediaProduction),
278            ("syllabus".into(), ProblemDomain::MediaProduction),
279            ("svg".into(), ProblemDomain::MediaProduction),
280            ("banner".into(), ProblemDomain::MediaProduction),
281            ("thumbnail".into(), ProblemDomain::MediaProduction),
282            ("grid protocol".into(), ProblemDomain::MediaProduction),
283            ("content quality".into(), ProblemDomain::MediaProduction),
284            ("completeness".into(), ProblemDomain::MediaProduction),
285            ("conformance".into(), ProblemDomain::MediaProduction),
286            ("freshness".into(), ProblemDomain::MediaProduction),
287            ("transcript".into(), ProblemDomain::MediaProduction),
288            ("tts".into(), ProblemDomain::MediaProduction),
289            ("narration".into(), ProblemDomain::MediaProduction),
290            ("text-to-speech".into(), ProblemDomain::MediaProduction),
291            ("av sync".into(), ProblemDomain::MediaProduction),
292            ("av-sync".into(), ProblemDomain::MediaProduction),
293            ("publishing".into(), ProblemDomain::MediaProduction),
294            ("coursepage".into(), ProblemDomain::MediaProduction),
295        ];
296
297        // Component names
298        self.component_names.extend(
299            [
300                "trueno",
301                "trueno-db",
302                "trueno-graph",
303                "trueno-viz",
304                "trueno-rag",
305                "aprender",
306                "entrenar",
307                "realizar",
308                "depyler",
309                "decy",
310                "bashrs",
311                "ruchy",
312                "batuta",
313                "repartir",
314                "pforge",
315                "certeza",
316                "pmat",
317                "renacer",
318                "alimentar",
319                "pacha",
320                "whisper-apr",
321                "simular",
322                "probar",
323                "pepita",
324                "provable-contracts",
325                "tiny-model-ground-truth",
326                "forjar",
327                "rmedia",
328            ]
329            .map(String::from),
330        );
331    }
332
333    /// Parse a natural language query
334    pub(crate) fn parse(&self, query: &str) -> ParsedQuery {
335        let lower = query.to_lowercase();
336
337        ParsedQuery {
338            original: query.to_string(),
339            domains: self.extract_domains(&lower),
340            algorithms: self.extract_algorithms(&lower),
341            keywords: self.extract_keywords(&lower),
342            data_size: self.extract_data_size(&lower),
343            performance_hints: self.extract_performance_hints(&lower),
344            mentioned_components: self.extract_components(&lower),
345        }
346    }
347
348    fn extract_domains(&self, query: &str) -> Vec<ProblemDomain> {
349        let mut domains = Vec::new();
350        let mut seen = HashSet::new();
351
352        for (keyword, domain) in &self.domain_keywords {
353            if query.contains(keyword) && !seen.contains(domain) {
354                domains.push(*domain);
355                seen.insert(*domain);
356            }
357        }
358
359        domains
360    }
361
362    fn extract_algorithms(&self, query: &str) -> Vec<String> {
363        let mut algorithms = Vec::new();
364
365        for algo in &self.algorithm_keywords {
366            if query.contains(algo) {
367                // Normalize algorithm name
368                let normalized = algo.replace([' ', '-'], "_").to_lowercase();
369                if !algorithms.contains(&normalized) {
370                    algorithms.push(normalized);
371                }
372            }
373        }
374
375        algorithms
376    }
377
378    fn extract_keywords(&self, query: &str) -> Vec<String> {
379        // Extract significant keywords (words > 3 chars, not stopwords)
380        let stopwords: HashSet<_> = [
381            "the", "and", "for", "with", "how", "what", "can", "does", "want", "need", "use",
382            "using", "have", "this", "that", "from", "into", "about", "which", "when", "where",
383            "should",
384        ]
385        .iter()
386        .map(|s| (*s).to_string())
387        .collect();
388
389        query
390            .split_whitespace()
391            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
392            .filter(|w| w.len() > 3 && !stopwords.contains(*w))
393            .map(String::from)
394            .collect()
395    }
396
397    fn extract_data_size(&self, query: &str) -> Option<DataSize> {
398        // Look for patterns like "1M samples", "100K rows", "1GB data"
399        // Note: regex patterns defined for reference but simple string matching used below
400        // for patterns like: r"(\d+)\s*[mM]\s*(samples?|rows?|records?|items?)"
401
402        for (suffix, multiplier) in [
403            ("m samples", 1_000_000),
404            ("m rows", 1_000_000),
405            ("k samples", 1_000),
406            ("k rows", 1_000),
407            ("million", 1_000_000),
408            ("thousand", 1_000),
409            ("billion", 1_000_000_000),
410        ] {
411            if let Some(idx) = query.find(suffix) {
412                // Look for number before suffix
413                let before = &query[..idx];
414                if let Some(num_str) = before.split_whitespace().last() {
415                    if let Ok(num) = num_str.parse::<u64>() {
416                        return Some(DataSize::samples(num * multiplier));
417                    }
418                }
419            }
420        }
421
422        // Look for "large" / "small" / "huge" indicators
423        if query.contains("large") || query.contains("huge") || query.contains("big") {
424            return Some(DataSize::samples(1_000_000));
425        }
426        if query.contains("small") || query.contains("tiny") {
427            return Some(DataSize::samples(1_000));
428        }
429
430        None
431    }
432
433    fn extract_performance_hints(&self, query: &str) -> Vec<PerformanceHint> {
434        let mut hints = Vec::new();
435
436        // Table-driven: simple OR-match patterns
437        for &(keywords, hint) in HINT_PATTERNS {
438            if keywords.iter().any(|kw| query.contains(kw)) {
439                hints.push(hint);
440            }
441        }
442
443        // Compound check for LowLatency: "<" AND "ms" (e.g. "<10ms")
444        // Only add if not already matched by the table-driven keywords above
445        if !hints.contains(&PerformanceHint::LowLatency)
446            && query.contains('<')
447            && query.contains("ms")
448        {
449            hints.push(PerformanceHint::LowLatency);
450        }
451
452        // Compound AND logic: "memory" AND ("low" OR "efficient")
453        if query.contains("memory") && (query.contains("low") || query.contains("efficient")) {
454            hints.push(PerformanceHint::LowMemory);
455        }
456
457        hints
458    }
459
460    fn extract_components(&self, query: &str) -> Vec<String> {
461        self.component_names
462            .iter()
463            .filter(|name| query.contains(name.as_str()) || query.contains(&name.replace('-', " ")))
464            .cloned()
465            .collect()
466    }
467}
468
469// =============================================================================
470// Query Engine
471// =============================================================================
472
473/// Query engine that processes queries and generates responses
474pub struct QueryEngine {
475    parser: QueryParser,
476}
477
478impl Default for QueryEngine {
479    fn default() -> Self {
480        Self::new()
481    }
482}
483
484impl QueryEngine {
485    /// Create a new query engine
486    pub fn new() -> Self {
487        Self { parser: QueryParser::new() }
488    }
489
490    /// Parse a query string
491    pub fn parse(&self, query: &str) -> ParsedQuery {
492        self.parser.parse(query)
493    }
494
495    /// Determine the primary problem domain from a parsed query
496    pub fn primary_domain(&self, parsed: &ParsedQuery) -> Option<ProblemDomain> {
497        parsed.domains.first().copied()
498    }
499
500    /// Determine the primary algorithm mentioned
501    pub fn primary_algorithm<'a>(&self, parsed: &'a ParsedQuery) -> Option<&'a str> {
502        parsed.algorithms.first().map(|s| s.as_str())
503    }
504
505    /// Check if the query requires GPU
506    pub fn requires_gpu(&self, parsed: &ParsedQuery) -> bool {
507        parsed.performance_hints.contains(&PerformanceHint::GPURequired)
508    }
509
510    /// Check if the query requires distributed computing
511    pub fn requires_distribution(&self, parsed: &ParsedQuery) -> bool {
512        parsed.performance_hints.contains(&PerformanceHint::Distributed)
513            || parsed.data_size.map(|s| s.is_large()).unwrap_or(false)
514    }
515
516    /// Check if the query requires sovereign/local execution
517    pub fn requires_sovereign(&self, parsed: &ParsedQuery) -> bool {
518        parsed.performance_hints.contains(&PerformanceHint::Sovereign)
519    }
520
521    /// Estimate operation complexity from parsed query
522    pub fn estimate_complexity(&self, parsed: &ParsedQuery) -> OpComplexity {
523        // Matrix operations are high complexity (keyword-based, checked first)
524        if parsed.keywords.iter().any(|k| k.contains("matrix") || k.contains("matmul")) {
525            return OpComplexity::High;
526        }
527
528        // Table-driven domain-to-complexity lookup (first match wins)
529        for &(domain, complexity) in DOMAIN_COMPLEXITY {
530            if parsed.domains.contains(&domain) {
531                return complexity;
532            }
533        }
534
535        // Default to low
536        OpComplexity::Low
537    }
538}
539
540// =============================================================================
541// Tests
542// =============================================================================
543
544#[cfg(test)]
545mod tests {
546    use super::*;
547
548    // =========================================================================
549    // Test Fixtures
550    // =========================================================================
551
552    /// Create a fresh QueryParser for tests
553    fn parser() -> QueryParser {
554        QueryParser::new()
555    }
556
557    /// Create a fresh QueryEngine for tests
558    fn engine() -> QueryEngine {
559        QueryEngine::new()
560    }
561
562    // =========================================================================
563    // QueryParser Tests
564    // =========================================================================
565
566    #[test]
567    fn test_parser_new() {
568        let p = parser();
569        assert!(!p.algorithm_keywords.is_empty());
570        assert!(!p.domain_keywords.is_empty());
571        assert!(!p.component_names.is_empty());
572    }
573
574    #[test]
575    fn test_parse_basic() {
576        let parsed = parser().parse("Train a random forest classifier");
577
578        assert_eq!(parsed.original, "Train a random forest classifier");
579        assert!(!parsed.domains.is_empty());
580        assert!(!parsed.algorithms.is_empty());
581    }
582
583    // =========================================================================
584    // Domain Extraction Tests
585    // =========================================================================
586
587    #[test]
588    fn test_extract_supervised_learning() {
589        let parsed = parser().parse("I want to train a classification model");
590
591        assert!(parsed.domains.contains(&ProblemDomain::SupervisedLearning));
592    }
593
594    #[test]
595    fn test_extract_unsupervised_learning() {
596        let parsed = parser().parse("Help me cluster my data for anomaly detection");
597
598        assert!(parsed.domains.contains(&ProblemDomain::UnsupervisedLearning));
599    }
600
601    #[test]
602    fn test_extract_deep_learning() {
603        let parsed = parser().parse("Fine-tune a transformer with LoRA");
604
605        assert!(parsed.domains.contains(&ProblemDomain::DeepLearning));
606    }
607
608    #[test]
609    fn test_extract_inference() {
610        let parsed = parser().parse("Deploy model for production inference");
611
612        assert!(parsed.domains.contains(&ProblemDomain::Inference));
613    }
614
615    #[test]
616    fn test_extract_python_migration() {
617        let parsed = parser().parse("Convert my sklearn pipeline to Rust");
618
619        assert!(parsed.domains.contains(&ProblemDomain::PythonMigration));
620    }
621
622    #[test]
623    fn test_extract_linear_algebra() {
624        let parsed = parser().parse("Fast matrix multiplication with SIMD");
625
626        assert!(parsed.domains.contains(&ProblemDomain::LinearAlgebra));
627    }
628
629    #[test]
630    fn test_extract_graph_analytics() {
631        let parsed = parser().parse("Run pagerank on my graph");
632
633        assert!(parsed.domains.contains(&ProblemDomain::GraphAnalytics));
634    }
635
636    #[test]
637    fn test_extract_multiple_domains() {
638        let parsed = parser().parse("Train a classifier on python sklearn data");
639
640        assert!(parsed.domains.len() >= 2);
641        assert!(parsed.domains.contains(&ProblemDomain::SupervisedLearning));
642        assert!(parsed.domains.contains(&ProblemDomain::PythonMigration));
643    }
644
645    // =========================================================================
646    // Algorithm Extraction Tests
647    // =========================================================================
648
649    #[test]
650    fn test_extract_random_forest() {
651        let parsed = parser().parse("Train a random forest on my data");
652
653        assert!(parsed.algorithms.iter().any(|a| a.contains("random_forest")));
654    }
655
656    #[test]
657    fn test_extract_gradient_boosting() {
658        let parsed = parser().parse("Use gradient boosting for regression");
659
660        assert!(parsed.algorithms.iter().any(|a| a.contains("gradient_boosting") || a == "gbm"));
661    }
662
663    #[test]
664    fn test_extract_kmeans() {
665        let parsed = parser().parse("Cluster with k-means algorithm");
666
667        assert!(parsed.algorithms.iter().any(|a| a.contains("kmeans") || a.contains("k_means")));
668    }
669
670    #[test]
671    fn test_extract_lora() {
672        let parsed = parser().parse("Fine-tune with LoRA");
673
674        assert!(parsed.algorithms.iter().any(|a| a.contains("lora")));
675    }
676
677    // =========================================================================
678    // Data Size Extraction Tests
679    // =========================================================================
680
681    #[test]
682    fn test_extract_data_size_million() {
683        let parsed = parser().parse("Train on 1 million samples");
684
685        assert!(parsed.data_size.is_some());
686        let size = parsed.data_size.expect("unexpected failure");
687        assert!(size.is_large());
688    }
689
690    #[test]
691    fn test_extract_data_size_1m() {
692        let parsed = parser().parse("Process 5m rows of data");
693
694        assert!(parsed.data_size.is_some());
695    }
696
697    #[test]
698    fn test_extract_data_size_thousand() {
699        let parsed = parser().parse("Test on 10 thousand samples");
700
701        assert!(parsed.data_size.is_some());
702        let size = parsed.data_size.expect("unexpected failure");
703        assert!(!size.is_large());
704    }
705
706    #[test]
707    fn test_extract_data_size_large_indicator() {
708        let parsed = parser().parse("Handle large dataset");
709
710        assert!(parsed.data_size.is_some());
711        assert!(parsed.data_size.expect("unexpected failure").is_large());
712    }
713
714    #[test]
715    fn test_extract_data_size_small_indicator() {
716        let parsed = parser().parse("Small dataset for testing");
717
718        assert!(parsed.data_size.is_some());
719        assert!(!parsed.data_size.expect("unexpected failure").is_large());
720    }
721
722    // =========================================================================
723    // Performance Hints Tests
724    // =========================================================================
725
726    #[test]
727    fn test_extract_low_latency() {
728        let parsed = parser().parse("Need fast inference with low latency");
729
730        assert!(parsed.performance_hints.contains(&PerformanceHint::LowLatency));
731    }
732
733    #[test]
734    fn test_extract_gpu_required() {
735        let parsed = parser().parse("Train model on GPU");
736
737        assert!(parsed.performance_hints.contains(&PerformanceHint::GPURequired));
738    }
739
740    #[test]
741    fn test_extract_distributed() {
742        let parsed = parser().parse("Distributed training on multi-node cluster");
743
744        assert!(parsed.performance_hints.contains(&PerformanceHint::Distributed));
745    }
746
747    #[test]
748    fn test_extract_edge_deployment() {
749        let parsed = parser().parse("Deploy model to edge devices");
750
751        assert!(parsed.performance_hints.contains(&PerformanceHint::EdgeDeployment));
752    }
753
754    #[test]
755    fn test_extract_sovereign() {
756        let parsed = parser().parse("GDPR compliant, sovereign execution");
757
758        assert!(parsed.performance_hints.contains(&PerformanceHint::Sovereign));
759    }
760
761    #[test]
762    fn test_extract_eu_ai_act() {
763        let parsed = parser().parse("Must comply with EU AI Act");
764
765        assert!(parsed.performance_hints.contains(&PerformanceHint::Sovereign));
766    }
767
768    // =========================================================================
769    // Component Extraction Tests
770    // =========================================================================
771
772    #[test]
773    fn test_extract_component_trueno() {
774        let parsed = parser().parse("Use trueno for tensor operations");
775
776        assert!(parsed.mentioned_components.contains(&"trueno".to_string()));
777    }
778
779    #[test]
780    fn test_extract_component_aprender() {
781        let parsed = parser().parse("Train with aprender random forest");
782
783        assert!(parsed.mentioned_components.contains(&"aprender".to_string()));
784    }
785
786    #[test]
787    fn test_extract_multiple_components() {
788        let parsed = parser().parse("Use depyler to convert sklearn to aprender");
789
790        assert!(parsed.mentioned_components.contains(&"depyler".to_string()));
791        assert!(parsed.mentioned_components.contains(&"aprender".to_string()));
792    }
793
794    // =========================================================================
795    // QueryEngine Tests
796    // =========================================================================
797
798    #[test]
799    fn test_query_engine_new() {
800        let e = engine();
801        let parsed = e.parse("Test query");
802        assert!(!parsed.original.is_empty());
803    }
804
805    #[test]
806    fn test_query_engine_default() {
807        let e = QueryEngine::default();
808        let parsed = e.parse("Test");
809        assert_eq!(parsed.original, "Test");
810    }
811
812    #[test]
813    fn test_primary_domain() {
814        let e = engine();
815        let parsed = e.parse("Train a classifier");
816
817        let domain = e.primary_domain(&parsed);
818        assert!(domain.is_some());
819        assert_eq!(domain.expect("unexpected failure"), ProblemDomain::SupervisedLearning);
820    }
821
822    #[test]
823    fn test_primary_algorithm() {
824        let e = engine();
825        let parsed = e.parse("Use random forest");
826
827        let algo = e.primary_algorithm(&parsed);
828        assert!(algo.is_some());
829        assert!(algo.expect("unexpected failure").contains("random_forest"));
830    }
831
832    #[test]
833    fn test_requires_gpu() {
834        let e = engine();
835
836        let parsed = e.parse("Train on GPU");
837        assert!(e.requires_gpu(&parsed));
838
839        let parsed = e.parse("Simple CPU training");
840        assert!(!e.requires_gpu(&parsed));
841    }
842
843    #[test]
844    fn test_requires_distribution() {
845        let e = engine();
846
847        let parsed = e.parse("Distributed training");
848        assert!(e.requires_distribution(&parsed));
849
850        let parsed = e.parse("Train on 1 billion samples");
851        assert!(e.requires_distribution(&parsed));
852
853        let parsed = e.parse("Small local training");
854        assert!(!e.requires_distribution(&parsed));
855    }
856
857    #[test]
858    fn test_requires_sovereign() {
859        let e = engine();
860
861        let parsed = e.parse("GDPR compliant local execution");
862        assert!(e.requires_sovereign(&parsed));
863
864        let parsed = e.parse("Cloud training");
865        assert!(!e.requires_sovereign(&parsed));
866    }
867
868    #[test]
869    fn test_estimate_complexity_high() {
870        let e = engine();
871
872        let parsed = e.parse("Matrix multiplication");
873        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::High);
874
875        let parsed = e.parse("Deep learning training");
876        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::High);
877    }
878
879    #[test]
880    fn test_estimate_complexity_medium() {
881        let e = engine();
882
883        let parsed = e.parse("Train a classifier");
884        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::Medium);
885
886        let parsed = e.parse("Graph pagerank");
887        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::Medium);
888    }
889
890    #[test]
891    fn test_estimate_complexity_low() {
892        let e = engine();
893
894        let parsed = e.parse("Simple data loading");
895        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::Low);
896    }
897
898    // =========================================================================
899    // Integration Tests
900    // =========================================================================
901
902    #[test]
903    fn test_full_query_parsing() {
904        let e = engine();
905        let parsed =
906            e.parse("I need to train a random forest on 1 million samples with GPU acceleration");
907
908        // Should detect supervised learning
909        assert!(parsed.domains.contains(&ProblemDomain::SupervisedLearning));
910
911        // Should detect random forest
912        assert!(parsed.algorithms.iter().any(|a| a.contains("random_forest")));
913
914        // Should detect large data
915        assert!(parsed.data_size.is_some());
916        assert!(parsed.data_size.expect("unexpected failure").is_large());
917
918        // Should detect GPU requirement
919        assert!(parsed.performance_hints.contains(&PerformanceHint::GPURequired));
920    }
921
922    #[test]
923    fn test_sklearn_migration_query() {
924        let e = engine();
925        let parsed = e.parse("Convert my sklearn pipeline with RandomForest to Rust aprender");
926
927        assert!(parsed.domains.contains(&ProblemDomain::PythonMigration));
928        assert!(parsed.algorithms.iter().any(|a| a.contains("random")));
929        assert!(parsed.mentioned_components.contains(&"aprender".to_string()));
930    }
931
932    #[test]
933    fn test_inference_query() {
934        let e = engine();
935        let parsed = e.parse("Deploy model to AWS Lambda with <10ms latency");
936
937        assert!(parsed.domains.contains(&ProblemDomain::Inference));
938        assert!(parsed.domains.contains(&ProblemDomain::ModelServing));
939        assert!(parsed.performance_hints.contains(&PerformanceHint::LowLatency));
940    }
941
942    // =========================================================================
943    // Falsification: Course Creation Workflow Coverage
944    // =========================================================================
945
946    #[test]
947    fn test_media_production_course_workflows() {
948        let p = parser();
949
950        // Every concept from resolve-pipeline make demo* MUST route to MediaProduction
951        let must_route: &[(&str, &str)] = &[
952            ("render a course video", "demo: rmedia course"),
953            ("transcribe audio from course", "demo: --transcribe"),
954            ("check transcript quality", "demo-course: whisper-apr score"),
955            ("generate course outline", "demo-outline: rmedia outline"),
956            ("extract key terms from transcripts", "demo-key-terms"),
957            ("generate reflection prompts", "demo-reflection"),
958            ("convert svg banner to png", "demo-banners: svg2png"),
959            ("generate landing page for course", "demo-coursera-page"),
960            ("tts narration for course video", "demo-tts: espeak-ng"),
961            ("check av sync on rendered video", "demo-av-check: probador"),
962            ("coursera publishing pipeline", "demo-coursera"),
963            ("subtitle burn in", "existing capability"),
964            ("generate thumbnail for video", "thumbnail generation"),
965            ("vocabulary enrichment from transcripts", "vocab-enrich skill"),
966            ("course quality scoring", "coursera-score target"),
967            ("content completeness check", "quality dimension"),
968            ("syllabus generation", "outline variant"),
969        ];
970
971        for (query, context) in must_route {
972            let parsed = p.parse(query);
973            assert!(
974                parsed.domains.contains(&ProblemDomain::MediaProduction),
975                "FAIL: '{}' ({}) did not route to MediaProduction. Got: {:?}",
976                query,
977                context,
978                parsed.domains,
979            );
980        }
981    }
982
983    #[test]
984    fn test_media_production_complexity_medium() {
985        let e = engine();
986        let parsed = e.parse("render course video with encoding");
987        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::Medium);
988    }
989
990    /// Harder falsification: bare concepts without "video"/"course" crutch words
991    #[test]
992    fn test_media_production_bare_concepts() {
993        let p = parser();
994
995        let must_route: &[(&str, &str)] = &[
996            ("tts narration pipeline", "demo-tts"),
997            ("text-to-speech for lectures", "demo-tts variant"),
998            ("av sync verification", "demo-av-check"),
999            ("publishing workflow", "coursera pipeline"),
1000            ("generate coursepage prompt", "demo-coursera-page"),
1001        ];
1002
1003        for (query, context) in must_route {
1004            let parsed = p.parse(query);
1005            assert!(
1006                parsed.domains.contains(&ProblemDomain::MediaProduction),
1007                "FAIL: '{}' ({}) did not route to MediaProduction. Got: {:?}",
1008                query,
1009                context,
1010                parsed.domains,
1011            );
1012        }
1013    }
1014}