1use super::types::*;
7use std::collections::HashSet;
8
/// Keyword groups mapped to the performance hint each one implies.
///
/// A hint fires when ANY keyword in its group appears (plain substring
/// match) in the lowercased query; see `QueryParser::extract_performance_hints`.
const HINT_PATTERNS: &[(&[&str], PerformanceHint)] = &[
    (&["fast", "low latency", "real-time", "realtime"], PerformanceHint::LowLatency),
    (&["throughput", "high volume"], PerformanceHint::HighThroughput),
    (&["gpu"], PerformanceHint::GPURequired),
    (&["distributed", "multi-node", "cluster"], PerformanceHint::Distributed),
    (&["edge", "embedded", "iot"], PerformanceHint::EdgeDeployment),
    (&["sovereign", "gdpr", "local only", "eu ai act", "on-premise"], PerformanceHint::Sovereign),
];
26
/// Per-domain operational-complexity estimates.
///
/// Scanned in order by `QueryEngine::estimate_complexity`; the first entry
/// whose domain was detected wins. Domains not listed here default to
/// `OpComplexity::Low`.
const DOMAIN_COMPLEXITY: &[(ProblemDomain, OpComplexity)] = &[
    (ProblemDomain::DeepLearning, OpComplexity::High),
    (ProblemDomain::SpeechRecognition, OpComplexity::High),
    (ProblemDomain::GraphAnalytics, OpComplexity::Medium),
    (ProblemDomain::SupervisedLearning, OpComplexity::Medium),
    (ProblemDomain::UnsupervisedLearning, OpComplexity::Medium),
    (ProblemDomain::MediaProduction, OpComplexity::Medium),
];
37
/// Structured result of parsing a natural-language query.
#[derive(Debug, Clone)]
pub struct ParsedQuery {
    /// The query exactly as the caller supplied it (case preserved).
    pub original: String,
    /// Detected problem domains, deduplicated, in keyword-table order.
    pub domains: Vec<ProblemDomain>,
    /// Detected algorithm names, normalized to snake_case.
    pub algorithms: Vec<String>,
    /// Remaining significant words (length > 3, stopwords removed, lowercased).
    pub keywords: Vec<String>,
    /// Data-set size when the query mentions one ("5m rows", "large", ...).
    pub data_size: Option<DataSize>,
    /// Performance/deployment constraints inferred from hint keywords.
    pub performance_hints: Vec<PerformanceHint>,
    /// Known toolkit component names mentioned in the query.
    pub mentioned_components: Vec<String>,
}
60
/// Performance or deployment constraint inferred from query keywords.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceHint {
    /// e.g. "fast", "low latency", or an explicit "<N ms" budget.
    LowLatency,
    /// e.g. "throughput", "high volume".
    HighThroughput,
    /// "memory" combined with "low" or "efficient".
    LowMemory,
    /// The query explicitly mentions GPU.
    GPURequired,
    /// e.g. "distributed", "multi-node", "cluster".
    Distributed,
    /// e.g. "edge", "embedded", "iot".
    EdgeDeployment,
    /// Compliance/locality constraints: "gdpr", "sovereign", "eu ai act", ...
    Sovereign,
}
72
/// Keyword-table-driven parser for natural-language queries.
///
/// Construct via `QueryParser::new()`, which populates the lookup tables;
/// a bare `QueryParser::default()` has empty tables and matches nothing.
#[derive(Debug, Default)]
pub(crate) struct QueryParser {
    /// Known algorithm names and their spelling variants (lowercase).
    algorithm_keywords: HashSet<String>,
    /// Ordered (keyword, domain) routing table; order defines domain priority.
    domain_keywords: Vec<(String, ProblemDomain)>,
    /// Known toolkit component names (lowercase).
    component_names: HashSet<String>,
}
83
84impl QueryParser {
85 pub(crate) fn new() -> Self {
87 let mut parser = Self::default();
88 parser.initialize_keywords();
89 parser
90 }
91
    /// Populate the three lookup tables used by `parse`.
    ///
    /// All entries are lowercase because matching runs against a lowercased
    /// copy of the query. Since matching is plain substring containment,
    /// separator/spelling variants ("random forest", "random_forest",
    /// "randomforest") are listed explicitly.
    fn initialize_keywords(&mut self) {
        // Algorithm names, including common separator and spelling variants.
        self.algorithm_keywords.extend(
            [
                "random_forest",
                "random forest",
                "randomforest",
                "linear_regression",
                "linear regression",
                "linearregression",
                "logistic_regression",
                "logistic regression",
                "logisticregression",
                "decision_tree",
                "decision tree",
                "decisiontree",
                "gradient_boosting",
                "gradient boosting",
                "gbm",
                "xgboost",
                "lightgbm",
                "naive_bayes",
                "naive bayes",
                "naivebayes",
                "knn",
                "k-nearest",
                "nearest neighbor",
                "svm",
                "support vector",
                "supportvector",
                "kmeans",
                "k-means",
                "clustering",
                "pca",
                "principal component",
                "dimensionality reduction",
                "dbscan",
                "density clustering",
                "neural network",
                "deep learning",
                "transformer",
                "llm",
                "lora",
                "qlora",
                "fine-tuning",
                "fine tuning",
                "finetuning",
                "whisper",
                "speech recognition",
                "speech-to-text",
                "transcription",
                "asr",
            ]
            .map(String::from),
        );

        // Keyword → domain routing table. Order matters: `extract_domains`
        // keeps table order, and `QueryEngine::primary_domain` takes the
        // first detected domain overall.
        self.domain_keywords = vec![
            // Supervised learning
            ("classify".into(), ProblemDomain::SupervisedLearning),
            ("classification".into(), ProblemDomain::SupervisedLearning),
            ("predict".into(), ProblemDomain::SupervisedLearning),
            ("regression".into(), ProblemDomain::SupervisedLearning),
            ("train".into(), ProblemDomain::SupervisedLearning),
            ("supervised".into(), ProblemDomain::SupervisedLearning),
            // Unsupervised learning
            ("cluster".into(), ProblemDomain::UnsupervisedLearning),
            ("clustering".into(), ProblemDomain::UnsupervisedLearning),
            ("unsupervised".into(), ProblemDomain::UnsupervisedLearning),
            ("anomaly".into(), ProblemDomain::UnsupervisedLearning),
            ("outlier".into(), ProblemDomain::UnsupervisedLearning),
            // Deep learning
            ("neural".into(), ProblemDomain::DeepLearning),
            ("deep learning".into(), ProblemDomain::DeepLearning),
            ("transformer".into(), ProblemDomain::DeepLearning),
            ("llm".into(), ProblemDomain::DeepLearning),
            ("fine-tune".into(), ProblemDomain::DeepLearning),
            ("lora".into(), ProblemDomain::DeepLearning),
            // Inference / serving
            ("serve".into(), ProblemDomain::Inference),
            ("serving".into(), ProblemDomain::Inference),
            ("inference".into(), ProblemDomain::Inference),
            ("deploy".into(), ProblemDomain::Inference),
            ("production".into(), ProblemDomain::Inference),
            // Speech recognition
            ("speech".into(), ProblemDomain::SpeechRecognition),
            ("whisper".into(), ProblemDomain::SpeechRecognition),
            ("asr".into(), ProblemDomain::SpeechRecognition),
            ("transcription".into(), ProblemDomain::SpeechRecognition),
            ("speech-to-text".into(), ProblemDomain::SpeechRecognition),
            ("speech recognition".into(), ProblemDomain::SpeechRecognition),
            // Linear algebra
            ("matrix".into(), ProblemDomain::LinearAlgebra),
            ("tensor".into(), ProblemDomain::LinearAlgebra),
            ("vector".into(), ProblemDomain::LinearAlgebra),
            ("linear algebra".into(), ProblemDomain::LinearAlgebra),
            ("simd".into(), ProblemDomain::LinearAlgebra),
            // Vector search
            ("similarity".into(), ProblemDomain::VectorSearch),
            ("embedding".into(), ProblemDomain::VectorSearch),
            ("vector search".into(), ProblemDomain::VectorSearch),
            ("nearest neighbor".into(), ProblemDomain::VectorSearch),
            // Graph analytics
            ("graph".into(), ProblemDomain::GraphAnalytics),
            ("pagerank".into(), ProblemDomain::GraphAnalytics),
            ("pathfinding".into(), ProblemDomain::GraphAnalytics),
            ("community".into(), ProblemDomain::GraphAnalytics),
            // Python migration
            ("python".into(), ProblemDomain::PythonMigration),
            ("sklearn".into(), ProblemDomain::PythonMigration),
            ("scikit".into(), ProblemDomain::PythonMigration),
            ("numpy".into(), ProblemDomain::PythonMigration),
            ("pandas".into(), ProblemDomain::PythonMigration),
            ("pytorch".into(), ProblemDomain::PythonMigration),
            // C/C++ migration
            ("c code".into(), ProblemDomain::CMigration),
            ("c++".into(), ProblemDomain::CMigration),
            ("cpp".into(), ProblemDomain::CMigration),
            // Shell migration
            ("bash".into(), ProblemDomain::ShellMigration),
            ("shell".into(), ProblemDomain::ShellMigration),
            ("script".into(), ProblemDomain::ShellMigration),
            // Distributed compute
            ("distributed".into(), ProblemDomain::DistributedCompute),
            ("parallel".into(), ProblemDomain::DistributedCompute),
            ("multi-node".into(), ProblemDomain::DistributedCompute),
            ("cluster".into(), ProblemDomain::DistributedCompute),
            // Data pipelines
            ("data loading".into(), ProblemDomain::DataPipeline),
            ("csv".into(), ProblemDomain::DataPipeline),
            ("parquet".into(), ProblemDomain::DataPipeline),
            ("etl".into(), ProblemDomain::DataPipeline),
            // Model serving infrastructure
            ("lambda".into(), ProblemDomain::ModelServing),
            ("serverless".into(), ProblemDomain::ModelServing),
            ("container".into(), ProblemDomain::ModelServing),
            ("edge".into(), ProblemDomain::ModelServing),
            // Testing
            ("test".into(), ProblemDomain::Testing),
            ("coverage".into(), ProblemDomain::Testing),
            ("mutation".into(), ProblemDomain::Testing),
            // Profiling
            ("profile".into(), ProblemDomain::Profiling),
            ("trace".into(), ProblemDomain::Profiling),
            ("syscall".into(), ProblemDomain::Profiling),
            // Validation / formal verification
            ("validate".into(), ProblemDomain::Validation),
            ("quality".into(), ProblemDomain::Validation),
            ("formal verification".into(), ProblemDomain::Validation),
            ("kani".into(), ProblemDomain::Validation),
            ("contract".into(), ProblemDomain::Validation),
            ("provable".into(), ProblemDomain::Validation),
            ("proof".into(), ProblemDomain::Validation),
            ("harness".into(), ProblemDomain::Validation),
            ("bounded model checking".into(), ProblemDomain::Validation),
            ("verification".into(), ProblemDomain::Validation),
            // Parity / ground-truth testing
            ("parity".into(), ProblemDomain::Testing),
            ("falsification".into(), ProblemDomain::Testing),
            ("ground truth".into(), ProblemDomain::Testing),
            ("oracle test".into(), ProblemDomain::Testing),
            ("conversion parity".into(), ProblemDomain::Testing),
            ("quantization drift".into(), ProblemDomain::Testing),
            // Media production / course workflows
            ("video".into(), ProblemDomain::MediaProduction),
            ("render".into(), ProblemDomain::MediaProduction),
            ("mlt".into(), ProblemDomain::MediaProduction),
            ("encode".into(), ProblemDomain::MediaProduction),
            ("decode".into(), ProblemDomain::MediaProduction),
            ("transition".into(), ProblemDomain::MediaProduction),
            ("media".into(), ProblemDomain::MediaProduction),
            ("editing".into(), ProblemDomain::MediaProduction),
            ("ffmpeg".into(), ProblemDomain::MediaProduction),
            ("course".into(), ProblemDomain::MediaProduction),
            ("screencast".into(), ProblemDomain::MediaProduction),
            ("dissolve".into(), ProblemDomain::MediaProduction),
            ("fade".into(), ProblemDomain::MediaProduction),
            ("title card".into(), ProblemDomain::MediaProduction),
            ("audio".into(), ProblemDomain::MediaProduction),
            ("transcribe".into(), ProblemDomain::MediaProduction),
            ("subtitle".into(), ProblemDomain::MediaProduction),
            ("caption".into(), ProblemDomain::MediaProduction),
            ("vocabulary".into(), ProblemDomain::MediaProduction),
            ("key terms".into(), ProblemDomain::MediaProduction),
            ("reflection".into(), ProblemDomain::MediaProduction),
            ("landing page".into(), ProblemDomain::MediaProduction),
            ("outline".into(), ProblemDomain::MediaProduction),
            ("syllabus".into(), ProblemDomain::MediaProduction),
            ("svg".into(), ProblemDomain::MediaProduction),
            ("banner".into(), ProblemDomain::MediaProduction),
            ("thumbnail".into(), ProblemDomain::MediaProduction),
            ("grid protocol".into(), ProblemDomain::MediaProduction),
            ("content quality".into(), ProblemDomain::MediaProduction),
            ("completeness".into(), ProblemDomain::MediaProduction),
            ("conformance".into(), ProblemDomain::MediaProduction),
            ("freshness".into(), ProblemDomain::MediaProduction),
            ("transcript".into(), ProblemDomain::MediaProduction),
            ("tts".into(), ProblemDomain::MediaProduction),
            ("narration".into(), ProblemDomain::MediaProduction),
            ("text-to-speech".into(), ProblemDomain::MediaProduction),
            ("av sync".into(), ProblemDomain::MediaProduction),
            ("av-sync".into(), ProblemDomain::MediaProduction),
            ("publishing".into(), ProblemDomain::MediaProduction),
            ("coursepage".into(), ProblemDomain::MediaProduction),
        ];

        // Known toolkit component names, matched verbatim (lowercase).
        self.component_names.extend(
            [
                "trueno",
                "trueno-db",
                "trueno-graph",
                "trueno-viz",
                "trueno-rag",
                "aprender",
                "entrenar",
                "realizar",
                "depyler",
                "decy",
                "bashrs",
                "ruchy",
                "batuta",
                "repartir",
                "pforge",
                "certeza",
                "pmat",
                "renacer",
                "alimentar",
                "pacha",
                "whisper-apr",
                "simular",
                "probar",
                "pepita",
                "provable-contracts",
                "tiny-model-ground-truth",
                "forjar",
                "rmedia",
            ]
            .map(String::from),
        );
    }
332
333 pub(crate) fn parse(&self, query: &str) -> ParsedQuery {
335 let lower = query.to_lowercase();
336
337 ParsedQuery {
338 original: query.to_string(),
339 domains: self.extract_domains(&lower),
340 algorithms: self.extract_algorithms(&lower),
341 keywords: self.extract_keywords(&lower),
342 data_size: self.extract_data_size(&lower),
343 performance_hints: self.extract_performance_hints(&lower),
344 mentioned_components: self.extract_components(&lower),
345 }
346 }
347
348 fn extract_domains(&self, query: &str) -> Vec<ProblemDomain> {
349 let mut domains = Vec::new();
350 let mut seen = HashSet::new();
351
352 for (keyword, domain) in &self.domain_keywords {
353 if query.contains(keyword) && !seen.contains(domain) {
354 domains.push(*domain);
355 seen.insert(*domain);
356 }
357 }
358
359 domains
360 }
361
362 fn extract_algorithms(&self, query: &str) -> Vec<String> {
363 let mut algorithms = Vec::new();
364
365 for algo in &self.algorithm_keywords {
366 if query.contains(algo) {
367 let normalized = algo.replace([' ', '-'], "_").to_lowercase();
369 if !algorithms.contains(&normalized) {
370 algorithms.push(normalized);
371 }
372 }
373 }
374
375 algorithms
376 }
377
378 fn extract_keywords(&self, query: &str) -> Vec<String> {
379 let stopwords: HashSet<_> = [
381 "the", "and", "for", "with", "how", "what", "can", "does", "want", "need", "use",
382 "using", "have", "this", "that", "from", "into", "about", "which", "when", "where",
383 "should",
384 ]
385 .iter()
386 .map(|s| (*s).to_string())
387 .collect();
388
389 query
390 .split_whitespace()
391 .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
392 .filter(|w| w.len() > 3 && !stopwords.contains(*w))
393 .map(String::from)
394 .collect()
395 }
396
397 fn extract_data_size(&self, query: &str) -> Option<DataSize> {
398 for (suffix, multiplier) in [
403 ("m samples", 1_000_000),
404 ("m rows", 1_000_000),
405 ("k samples", 1_000),
406 ("k rows", 1_000),
407 ("million", 1_000_000),
408 ("thousand", 1_000),
409 ("billion", 1_000_000_000),
410 ] {
411 if let Some(idx) = query.find(suffix) {
412 let before = &query[..idx];
414 if let Some(num_str) = before.split_whitespace().last() {
415 if let Ok(num) = num_str.parse::<u64>() {
416 return Some(DataSize::samples(num * multiplier));
417 }
418 }
419 }
420 }
421
422 if query.contains("large") || query.contains("huge") || query.contains("big") {
424 return Some(DataSize::samples(1_000_000));
425 }
426 if query.contains("small") || query.contains("tiny") {
427 return Some(DataSize::samples(1_000));
428 }
429
430 None
431 }
432
433 fn extract_performance_hints(&self, query: &str) -> Vec<PerformanceHint> {
434 let mut hints = Vec::new();
435
436 for &(keywords, hint) in HINT_PATTERNS {
438 if keywords.iter().any(|kw| query.contains(kw)) {
439 hints.push(hint);
440 }
441 }
442
443 if !hints.contains(&PerformanceHint::LowLatency)
446 && query.contains('<')
447 && query.contains("ms")
448 {
449 hints.push(PerformanceHint::LowLatency);
450 }
451
452 if query.contains("memory") && (query.contains("low") || query.contains("efficient")) {
454 hints.push(PerformanceHint::LowMemory);
455 }
456
457 hints
458 }
459
460 fn extract_components(&self, query: &str) -> Vec<String> {
461 self.component_names
462 .iter()
463 .filter(|name| query.contains(name.as_str()) || query.contains(&name.replace('-', " ")))
464 .cloned()
465 .collect()
466 }
467}
468
/// Public facade over [`QueryParser`] adding convenience accessors:
/// primary domain/algorithm, GPU/distribution/sovereignty checks, and
/// operational-complexity estimation.
pub struct QueryEngine {
    /// Fully initialized keyword parser; built once in `new`.
    parser: QueryParser,
}
477
478impl Default for QueryEngine {
479 fn default() -> Self {
480 Self::new()
481 }
482}
483
484impl QueryEngine {
485 pub fn new() -> Self {
487 Self { parser: QueryParser::new() }
488 }
489
490 pub fn parse(&self, query: &str) -> ParsedQuery {
492 self.parser.parse(query)
493 }
494
495 pub fn primary_domain(&self, parsed: &ParsedQuery) -> Option<ProblemDomain> {
497 parsed.domains.first().copied()
498 }
499
500 pub fn primary_algorithm<'a>(&self, parsed: &'a ParsedQuery) -> Option<&'a str> {
502 parsed.algorithms.first().map(|s| s.as_str())
503 }
504
505 pub fn requires_gpu(&self, parsed: &ParsedQuery) -> bool {
507 parsed.performance_hints.contains(&PerformanceHint::GPURequired)
508 }
509
510 pub fn requires_distribution(&self, parsed: &ParsedQuery) -> bool {
512 parsed.performance_hints.contains(&PerformanceHint::Distributed)
513 || parsed.data_size.map(|s| s.is_large()).unwrap_or(false)
514 }
515
516 pub fn requires_sovereign(&self, parsed: &ParsedQuery) -> bool {
518 parsed.performance_hints.contains(&PerformanceHint::Sovereign)
519 }
520
521 pub fn estimate_complexity(&self, parsed: &ParsedQuery) -> OpComplexity {
523 if parsed.keywords.iter().any(|k| k.contains("matrix") || k.contains("matmul")) {
525 return OpComplexity::High;
526 }
527
528 for &(domain, complexity) in DOMAIN_COMPLEXITY {
530 if parsed.domains.contains(&domain) {
531 return complexity;
532 }
533 }
534
535 OpComplexity::Low
537 }
538}
539
#[cfg(test)]
mod tests {
    //! Unit tests covering construction, each extractor, the `QueryEngine`
    //! facade, and end-to-end routing scenarios (including the media
    //! production / course workflow table-driven cases).

    use super::*;

    /// Fresh parser with all keyword tables populated.
    fn parser() -> QueryParser {
        QueryParser::new()
    }

    /// Fresh engine wrapping a fully initialized parser.
    fn engine() -> QueryEngine {
        QueryEngine::new()
    }

    // ---------- construction & basic parsing ----------

    #[test]
    fn test_parser_new() {
        let p = parser();
        assert!(!p.algorithm_keywords.is_empty());
        assert!(!p.domain_keywords.is_empty());
        assert!(!p.component_names.is_empty());
    }

    #[test]
    fn test_parse_basic() {
        let parsed = parser().parse("Train a random forest classifier");

        assert_eq!(parsed.original, "Train a random forest classifier");
        assert!(!parsed.domains.is_empty());
        assert!(!parsed.algorithms.is_empty());
    }

    // ---------- domain extraction ----------

    #[test]
    fn test_extract_supervised_learning() {
        let parsed = parser().parse("I want to train a classification model");

        assert!(parsed.domains.contains(&ProblemDomain::SupervisedLearning));
    }

    #[test]
    fn test_extract_unsupervised_learning() {
        let parsed = parser().parse("Help me cluster my data for anomaly detection");

        assert!(parsed.domains.contains(&ProblemDomain::UnsupervisedLearning));
    }

    #[test]
    fn test_extract_deep_learning() {
        let parsed = parser().parse("Fine-tune a transformer with LoRA");

        assert!(parsed.domains.contains(&ProblemDomain::DeepLearning));
    }

    #[test]
    fn test_extract_inference() {
        let parsed = parser().parse("Deploy model for production inference");

        assert!(parsed.domains.contains(&ProblemDomain::Inference));
    }

    #[test]
    fn test_extract_python_migration() {
        let parsed = parser().parse("Convert my sklearn pipeline to Rust");

        assert!(parsed.domains.contains(&ProblemDomain::PythonMigration));
    }

    #[test]
    fn test_extract_linear_algebra() {
        let parsed = parser().parse("Fast matrix multiplication with SIMD");

        assert!(parsed.domains.contains(&ProblemDomain::LinearAlgebra));
    }

    #[test]
    fn test_extract_graph_analytics() {
        let parsed = parser().parse("Run pagerank on my graph");

        assert!(parsed.domains.contains(&ProblemDomain::GraphAnalytics));
    }

    #[test]
    fn test_extract_multiple_domains() {
        let parsed = parser().parse("Train a classifier on python sklearn data");

        assert!(parsed.domains.len() >= 2);
        assert!(parsed.domains.contains(&ProblemDomain::SupervisedLearning));
        assert!(parsed.domains.contains(&ProblemDomain::PythonMigration));
    }

    // ---------- algorithm extraction ----------

    #[test]
    fn test_extract_random_forest() {
        let parsed = parser().parse("Train a random forest on my data");

        assert!(parsed.algorithms.iter().any(|a| a.contains("random_forest")));
    }

    #[test]
    fn test_extract_gradient_boosting() {
        let parsed = parser().parse("Use gradient boosting for regression");

        assert!(parsed.algorithms.iter().any(|a| a.contains("gradient_boosting") || a == "gbm"));
    }

    #[test]
    fn test_extract_kmeans() {
        let parsed = parser().parse("Cluster with k-means algorithm");

        assert!(parsed.algorithms.iter().any(|a| a.contains("kmeans") || a.contains("k_means")));
    }

    #[test]
    fn test_extract_lora() {
        let parsed = parser().parse("Fine-tune with LoRA");

        assert!(parsed.algorithms.iter().any(|a| a.contains("lora")));
    }

    // ---------- data-size extraction ----------

    #[test]
    fn test_extract_data_size_million() {
        let parsed = parser().parse("Train on 1 million samples");

        assert!(parsed.data_size.is_some());
        let size = parsed.data_size.expect("unexpected failure");
        assert!(size.is_large());
    }

    #[test]
    fn test_extract_data_size_1m() {
        let parsed = parser().parse("Process 5m rows of data");

        assert!(parsed.data_size.is_some());
    }

    #[test]
    fn test_extract_data_size_thousand() {
        let parsed = parser().parse("Test on 10 thousand samples");

        assert!(parsed.data_size.is_some());
        let size = parsed.data_size.expect("unexpected failure");
        assert!(!size.is_large());
    }

    #[test]
    fn test_extract_data_size_large_indicator() {
        let parsed = parser().parse("Handle large dataset");

        assert!(parsed.data_size.is_some());
        assert!(parsed.data_size.expect("unexpected failure").is_large());
    }

    #[test]
    fn test_extract_data_size_small_indicator() {
        let parsed = parser().parse("Small dataset for testing");

        assert!(parsed.data_size.is_some());
        assert!(!parsed.data_size.expect("unexpected failure").is_large());
    }

    // ---------- performance-hint extraction ----------

    #[test]
    fn test_extract_low_latency() {
        let parsed = parser().parse("Need fast inference with low latency");

        assert!(parsed.performance_hints.contains(&PerformanceHint::LowLatency));
    }

    #[test]
    fn test_extract_gpu_required() {
        let parsed = parser().parse("Train model on GPU");

        assert!(parsed.performance_hints.contains(&PerformanceHint::GPURequired));
    }

    #[test]
    fn test_extract_distributed() {
        let parsed = parser().parse("Distributed training on multi-node cluster");

        assert!(parsed.performance_hints.contains(&PerformanceHint::Distributed));
    }

    #[test]
    fn test_extract_edge_deployment() {
        let parsed = parser().parse("Deploy model to edge devices");

        assert!(parsed.performance_hints.contains(&PerformanceHint::EdgeDeployment));
    }

    #[test]
    fn test_extract_sovereign() {
        let parsed = parser().parse("GDPR compliant, sovereign execution");

        assert!(parsed.performance_hints.contains(&PerformanceHint::Sovereign));
    }

    #[test]
    fn test_extract_eu_ai_act() {
        let parsed = parser().parse("Must comply with EU AI Act");

        assert!(parsed.performance_hints.contains(&PerformanceHint::Sovereign));
    }

    // ---------- component mention extraction ----------

    #[test]
    fn test_extract_component_trueno() {
        let parsed = parser().parse("Use trueno for tensor operations");

        assert!(parsed.mentioned_components.contains(&"trueno".to_string()));
    }

    #[test]
    fn test_extract_component_aprender() {
        let parsed = parser().parse("Train with aprender random forest");

        assert!(parsed.mentioned_components.contains(&"aprender".to_string()));
    }

    #[test]
    fn test_extract_multiple_components() {
        let parsed = parser().parse("Use depyler to convert sklearn to aprender");

        assert!(parsed.mentioned_components.contains(&"depyler".to_string()));
        assert!(parsed.mentioned_components.contains(&"aprender".to_string()));
    }

    // ---------- QueryEngine facade ----------

    #[test]
    fn test_query_engine_new() {
        let e = engine();
        let parsed = e.parse("Test query");
        assert!(!parsed.original.is_empty());
    }

    #[test]
    fn test_query_engine_default() {
        let e = QueryEngine::default();
        let parsed = e.parse("Test");
        assert_eq!(parsed.original, "Test");
    }

    #[test]
    fn test_primary_domain() {
        let e = engine();
        let parsed = e.parse("Train a classifier");

        let domain = e.primary_domain(&parsed);
        assert!(domain.is_some());
        assert_eq!(domain.expect("unexpected failure"), ProblemDomain::SupervisedLearning);
    }

    #[test]
    fn test_primary_algorithm() {
        let e = engine();
        let parsed = e.parse("Use random forest");

        let algo = e.primary_algorithm(&parsed);
        assert!(algo.is_some());
        assert!(algo.expect("unexpected failure").contains("random_forest"));
    }

    #[test]
    fn test_requires_gpu() {
        let e = engine();

        let parsed = e.parse("Train on GPU");
        assert!(e.requires_gpu(&parsed));

        let parsed = e.parse("Simple CPU training");
        assert!(!e.requires_gpu(&parsed));
    }

    #[test]
    fn test_requires_distribution() {
        let e = engine();

        let parsed = e.parse("Distributed training");
        assert!(e.requires_distribution(&parsed));

        let parsed = e.parse("Train on 1 billion samples");
        assert!(e.requires_distribution(&parsed));

        let parsed = e.parse("Small local training");
        assert!(!e.requires_distribution(&parsed));
    }

    #[test]
    fn test_requires_sovereign() {
        let e = engine();

        let parsed = e.parse("GDPR compliant local execution");
        assert!(e.requires_sovereign(&parsed));

        let parsed = e.parse("Cloud training");
        assert!(!e.requires_sovereign(&parsed));
    }

    #[test]
    fn test_estimate_complexity_high() {
        let e = engine();

        let parsed = e.parse("Matrix multiplication");
        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::High);

        let parsed = e.parse("Deep learning training");
        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::High);
    }

    #[test]
    fn test_estimate_complexity_medium() {
        let e = engine();

        let parsed = e.parse("Train a classifier");
        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::Medium);

        let parsed = e.parse("Graph pagerank");
        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::Medium);
    }

    #[test]
    fn test_estimate_complexity_low() {
        let e = engine();

        let parsed = e.parse("Simple data loading");
        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::Low);
    }

    // ---------- end-to-end scenarios ----------

    #[test]
    fn test_full_query_parsing() {
        let e = engine();
        let parsed =
            e.parse("I need to train a random forest on 1 million samples with GPU acceleration");

        assert!(parsed.domains.contains(&ProblemDomain::SupervisedLearning));

        assert!(parsed.algorithms.iter().any(|a| a.contains("random_forest")));

        assert!(parsed.data_size.is_some());
        assert!(parsed.data_size.expect("unexpected failure").is_large());

        assert!(parsed.performance_hints.contains(&PerformanceHint::GPURequired));
    }

    #[test]
    fn test_sklearn_migration_query() {
        let e = engine();
        let parsed = e.parse("Convert my sklearn pipeline with RandomForest to Rust aprender");

        assert!(parsed.domains.contains(&ProblemDomain::PythonMigration));
        assert!(parsed.algorithms.iter().any(|a| a.contains("random")));
        assert!(parsed.mentioned_components.contains(&"aprender".to_string()));
    }

    #[test]
    fn test_inference_query() {
        let e = engine();
        let parsed = e.parse("Deploy model to AWS Lambda with <10ms latency");

        assert!(parsed.domains.contains(&ProblemDomain::Inference));
        assert!(parsed.domains.contains(&ProblemDomain::ModelServing));
        assert!(parsed.performance_hints.contains(&PerformanceHint::LowLatency));
    }

    // ---------- media-production routing (table-driven) ----------

    #[test]
    fn test_media_production_course_workflows() {
        let p = parser();

        // (query, context) pairs that must all route to MediaProduction;
        // context names the workflow/target the query represents.
        let must_route: &[(&str, &str)] = &[
            ("render a course video", "demo: rmedia course"),
            ("transcribe audio from course", "demo: --transcribe"),
            ("check transcript quality", "demo-course: whisper-apr score"),
            ("generate course outline", "demo-outline: rmedia outline"),
            ("extract key terms from transcripts", "demo-key-terms"),
            ("generate reflection prompts", "demo-reflection"),
            ("convert svg banner to png", "demo-banners: svg2png"),
            ("generate landing page for course", "demo-coursera-page"),
            ("tts narration for course video", "demo-tts: espeak-ng"),
            ("check av sync on rendered video", "demo-av-check: probador"),
            ("coursera publishing pipeline", "demo-coursera"),
            ("subtitle burn in", "existing capability"),
            ("generate thumbnail for video", "thumbnail generation"),
            ("vocabulary enrichment from transcripts", "vocab-enrich skill"),
            ("course quality scoring", "coursera-score target"),
            ("content completeness check", "quality dimension"),
            ("syllabus generation", "outline variant"),
        ];

        for (query, context) in must_route {
            let parsed = p.parse(query);
            assert!(
                parsed.domains.contains(&ProblemDomain::MediaProduction),
                "FAIL: '{}' ({}) did not route to MediaProduction. Got: {:?}",
                query,
                context,
                parsed.domains,
            );
        }
    }

    #[test]
    fn test_media_production_complexity_medium() {
        let e = engine();
        let parsed = e.parse("render course video with encoding");
        assert_eq!(e.estimate_complexity(&parsed), OpComplexity::Medium);
    }

    #[test]
    fn test_media_production_bare_concepts() {
        let p = parser();

        // Bare concept phrases (no explicit media nouns) that must still route.
        let must_route: &[(&str, &str)] = &[
            ("tts narration pipeline", "demo-tts"),
            ("text-to-speech for lectures", "demo-tts variant"),
            ("av sync verification", "demo-av-check"),
            ("publishing workflow", "coursera pipeline"),
            ("generate coursepage prompt", "demo-coursera-page"),
        ];

        for (query, context) in must_route {
            let parsed = p.parse(query);
            assert!(
                parsed.domains.contains(&ProblemDomain::MediaProduction),
                "FAIL: '{}' ({}) did not route to MediaProduction. Got: {:?}",
                query,
                context,
                parsed.domains,
            );
        }
    }
}