1use std::collections::HashMap;
42
43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
45#[allow(clippy::upper_case_acronyms)]
46pub enum SklearnAlgorithm {
47 LinearRegression,
49 Ridge,
50 Lasso,
51 LogisticRegression,
52
53 KMeans,
55 DBSCAN,
56
57 DecisionTreeClassifier,
59 DecisionTreeRegressor,
60 RandomForestClassifier,
61 RandomForestRegressor,
62
63 StandardScaler,
65 MinMaxScaler,
66 LabelEncoder,
67
68 TrainTestSplit,
70 CrossValidation,
71
72 Accuracy,
74 Precision,
75 Recall,
76 F1Score,
77 MeanSquaredError,
78 R2Score,
79}
80
81impl SklearnAlgorithm {
82 pub fn complexity(&self) -> crate::backend::OpComplexity {
84 use crate::backend::OpComplexity;
85
86 match self {
87 SklearnAlgorithm::StandardScaler
89 | SklearnAlgorithm::MinMaxScaler
90 | SklearnAlgorithm::LabelEncoder
91 | SklearnAlgorithm::TrainTestSplit => OpComplexity::Low,
92
93 SklearnAlgorithm::LinearRegression
95 | SklearnAlgorithm::Ridge
96 | SklearnAlgorithm::Lasso
97 | SklearnAlgorithm::LogisticRegression
98 | SklearnAlgorithm::Accuracy
99 | SklearnAlgorithm::Precision
100 | SklearnAlgorithm::Recall
101 | SklearnAlgorithm::F1Score
102 | SklearnAlgorithm::MeanSquaredError
103 | SklearnAlgorithm::R2Score => OpComplexity::Medium,
104
105 SklearnAlgorithm::DecisionTreeClassifier
107 | SklearnAlgorithm::DecisionTreeRegressor
108 | SklearnAlgorithm::RandomForestClassifier
109 | SklearnAlgorithm::RandomForestRegressor
110 | SklearnAlgorithm::KMeans
111 | SklearnAlgorithm::DBSCAN
112 | SklearnAlgorithm::CrossValidation => OpComplexity::High,
113 }
114 }
115
116 pub fn sklearn_module(&self) -> &str {
118 match self {
119 SklearnAlgorithm::LinearRegression
120 | SklearnAlgorithm::Ridge
121 | SklearnAlgorithm::Lasso
122 | SklearnAlgorithm::LogisticRegression => "sklearn.linear_model",
123
124 SklearnAlgorithm::KMeans | SklearnAlgorithm::DBSCAN => "sklearn.cluster",
125
126 SklearnAlgorithm::DecisionTreeClassifier | SklearnAlgorithm::DecisionTreeRegressor => {
127 "sklearn.tree"
128 }
129
130 SklearnAlgorithm::RandomForestClassifier | SklearnAlgorithm::RandomForestRegressor => {
131 "sklearn.ensemble"
132 }
133
134 SklearnAlgorithm::StandardScaler
135 | SklearnAlgorithm::MinMaxScaler
136 | SklearnAlgorithm::LabelEncoder => "sklearn.preprocessing",
137
138 SklearnAlgorithm::TrainTestSplit | SklearnAlgorithm::CrossValidation => {
139 "sklearn.model_selection"
140 }
141
142 SklearnAlgorithm::Accuracy
143 | SklearnAlgorithm::Precision
144 | SklearnAlgorithm::Recall
145 | SklearnAlgorithm::F1Score
146 | SklearnAlgorithm::MeanSquaredError
147 | SklearnAlgorithm::R2Score => "sklearn.metrics",
148 }
149 }
150}
151
152#[derive(Debug, Clone)]
154pub struct AprenderAlgorithm {
155 pub code_template: String,
157 pub imports: Vec<String>,
159 pub complexity: crate::backend::OpComplexity,
161 pub usage_pattern: String,
163}
164
165pub struct SklearnConverter {
167 algorithm_map: HashMap<SklearnAlgorithm, AprenderAlgorithm>,
169 backend_selector: crate::backend::BackendSelector,
171}
172
173impl Default for SklearnConverter {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
179impl SklearnConverter {
180 pub fn new() -> Self {
182 let mut algorithm_map = HashMap::new();
183
184 algorithm_map.insert(
186 SklearnAlgorithm::LinearRegression,
187 AprenderAlgorithm {
188 code_template: "LinearRegression::new()".to_string(),
189 imports: vec![
190 "use aprender::linear_model::LinearRegression;".to_string(),
191 "use aprender::Estimator;".to_string(),
192 ],
193 complexity: crate::backend::OpComplexity::Medium,
194 usage_pattern: "let mut model = LinearRegression::new();\nmodel.fit(&X_train, &y_train)?;\nlet predictions = model.predict(&X_test)?;".to_string(),
195 },
196 );
197
198 algorithm_map.insert(
199 SklearnAlgorithm::LogisticRegression,
200 AprenderAlgorithm {
201 code_template: "LogisticRegression::new()".to_string(),
202 imports: vec![
203 "use aprender::classification::LogisticRegression;".to_string(),
204 "use aprender::Estimator;".to_string(),
205 ],
206 complexity: crate::backend::OpComplexity::Medium,
207 usage_pattern: "let mut model = LogisticRegression::new();\nmodel.fit(&X_train, &y_train)?;\nlet predictions = model.predict(&X_test)?;".to_string(),
208 },
209 );
210
211 algorithm_map.insert(
213 SklearnAlgorithm::KMeans,
214 AprenderAlgorithm {
215 code_template: "KMeans::new({n_clusters})".to_string(),
216 imports: vec![
217 "use aprender::cluster::KMeans;".to_string(),
218 "use aprender::UnsupervisedEstimator;".to_string(),
219 ],
220 complexity: crate::backend::OpComplexity::High,
221 usage_pattern: "let mut model = KMeans::new(3);\nmodel.fit(&X)?;\nlet labels = model.predict(&X)?;".to_string(),
222 },
223 );
224
225 algorithm_map.insert(
227 SklearnAlgorithm::DecisionTreeClassifier,
228 AprenderAlgorithm {
229 code_template: "DecisionTreeClassifier::new()".to_string(),
230 imports: vec![
231 "use aprender::tree::DecisionTreeClassifier;".to_string(),
232 "use aprender::Estimator;".to_string(),
233 ],
234 complexity: crate::backend::OpComplexity::High,
235 usage_pattern: "let mut model = DecisionTreeClassifier::new();\nmodel.fit(&X_train, &y_train)?;\nlet predictions = model.predict(&X_test)?;".to_string(),
236 },
237 );
238
239 algorithm_map.insert(
241 SklearnAlgorithm::StandardScaler,
242 AprenderAlgorithm {
243 code_template: "StandardScaler::new()".to_string(),
244 imports: vec![
245 "use aprender::preprocessing::StandardScaler;".to_string(),
246 "use aprender::Transformer;".to_string(),
247 ],
248 complexity: crate::backend::OpComplexity::Low,
249 usage_pattern: "let mut scaler = StandardScaler::new();\nscaler.fit(&X_train)?;\nlet X_train_scaled = scaler.transform(&X_train)?;".to_string(),
250 },
251 );
252
253 algorithm_map.insert(
255 SklearnAlgorithm::TrainTestSplit,
256 AprenderAlgorithm {
257 code_template: "train_test_split(&X, &y, {test_size})".to_string(),
258 imports: vec!["use aprender::model_selection::train_test_split;".to_string()],
259 complexity: crate::backend::OpComplexity::Low,
260 usage_pattern:
261 "let (X_train, X_test, y_train, y_test) = train_test_split(&X, &y, 0.25)?;"
262 .to_string(),
263 },
264 );
265
266 algorithm_map.insert(
268 SklearnAlgorithm::Accuracy,
269 AprenderAlgorithm {
270 code_template: "accuracy_score(&y_true, &y_pred)".to_string(),
271 imports: vec!["use aprender::metrics::accuracy_score;".to_string()],
272 complexity: crate::backend::OpComplexity::Medium,
273 usage_pattern: "let acc = accuracy_score(&y_true, &y_pred)?;".to_string(),
274 },
275 );
276
277 algorithm_map.insert(
278 SklearnAlgorithm::MeanSquaredError,
279 AprenderAlgorithm {
280 code_template: "mean_squared_error(&y_true, &y_pred)".to_string(),
281 imports: vec!["use aprender::metrics::mean_squared_error;".to_string()],
282 complexity: crate::backend::OpComplexity::Medium,
283 usage_pattern: "let mse = mean_squared_error(&y_true, &y_pred)?;".to_string(),
284 },
285 );
286
287 Self { algorithm_map, backend_selector: crate::backend::BackendSelector::new() }
288 }
289
290 pub fn convert(&self, algorithm: &SklearnAlgorithm) -> Option<&AprenderAlgorithm> {
292 self.algorithm_map.get(algorithm)
293 }
294
295 pub fn recommend_backend(
297 &self,
298 algorithm: &SklearnAlgorithm,
299 data_size: usize,
300 ) -> crate::backend::Backend {
301 self.backend_selector.select_with_moe(algorithm.complexity(), data_size)
302 }
303
304 pub fn available_algorithms(&self) -> Vec<&SklearnAlgorithm> {
306 self.algorithm_map.keys().collect()
307 }
308
309 pub fn conversion_report(&self) -> String {
311 let mut report = String::from("sklearn → Aprender Conversion Map\n");
312 report.push_str("===================================\n\n");
313
314 let mut by_module: HashMap<&str, Vec<(&SklearnAlgorithm, &AprenderAlgorithm)>> =
316 HashMap::new();
317
318 for (alg, aprender_alg) in &self.algorithm_map {
319 by_module.entry(alg.sklearn_module()).or_default().push((alg, aprender_alg));
320 }
321
322 for (module, algorithms) in &by_module {
323 report.push_str(&format!("## {}\n\n", module));
324
325 for (alg, aprender_alg) in algorithms {
326 report.push_str(&format!("{:?}:\n", alg));
327 report.push_str(&format!(" Template: {}\n", aprender_alg.code_template));
328 report.push_str(&format!(" Complexity: {:?}\n", aprender_alg.complexity));
329 report.push_str(&format!(" Imports: {}\n", aprender_alg.imports.join(", ")));
330 report.push_str(&format!(
331 " Usage:\n {}\n\n",
332 aprender_alg.usage_pattern.replace('\n', "\n ")
333 ));
334 }
335 report.push('\n');
336 }
337
338 report
339 }
340}
341
342#[cfg(test)]
343mod tests {
344 use super::*;
345
346 #[test]
347 fn test_converter_creation() {
348 let converter = SklearnConverter::new();
349 assert!(!converter.available_algorithms().is_empty());
350 }
351
352 #[test]
353 fn test_algorithm_complexity() {
354 assert_eq!(
355 SklearnAlgorithm::LinearRegression.complexity(),
356 crate::backend::OpComplexity::Medium
357 );
358 assert_eq!(
359 SklearnAlgorithm::StandardScaler.complexity(),
360 crate::backend::OpComplexity::Low
361 );
362 assert_eq!(SklearnAlgorithm::KMeans.complexity(), crate::backend::OpComplexity::High);
363 }
364
365 #[test]
366 fn test_linear_regression_conversion() {
367 let converter = SklearnConverter::new();
368 let aprender_alg =
369 converter.convert(&SklearnAlgorithm::LinearRegression).expect("unexpected failure");
370 assert!(aprender_alg.code_template.contains("LinearRegression"));
371 assert!(aprender_alg.imports.iter().any(|i| i.contains("linear_model")));
372 }
373
374 #[test]
375 fn test_kmeans_conversion() {
376 let converter = SklearnConverter::new();
377 let aprender_alg = converter.convert(&SklearnAlgorithm::KMeans).expect("conversion failed");
378 assert!(aprender_alg.code_template.contains("KMeans"));
379 assert!(aprender_alg.imports.iter().any(|i| i.contains("cluster")));
380 }
381
382 #[test]
383 fn test_backend_recommendation() {
384 let converter = SklearnConverter::new();
385
386 let backend = converter.recommend_backend(&SklearnAlgorithm::StandardScaler, 100);
388 assert_eq!(backend, crate::backend::Backend::Scalar);
389
390 let backend = converter.recommend_backend(&SklearnAlgorithm::LinearRegression, 50_000);
392 assert_eq!(backend, crate::backend::Backend::SIMD);
393
394 let backend = converter.recommend_backend(&SklearnAlgorithm::KMeans, 100_000);
396 assert_eq!(backend, crate::backend::Backend::GPU);
397 }
398
399 #[test]
400 fn test_sklearn_module_paths() {
401 assert_eq!(SklearnAlgorithm::LinearRegression.sklearn_module(), "sklearn.linear_model");
402 assert_eq!(SklearnAlgorithm::KMeans.sklearn_module(), "sklearn.cluster");
403 assert_eq!(SklearnAlgorithm::StandardScaler.sklearn_module(), "sklearn.preprocessing");
404 }
405
406 #[test]
407 fn test_conversion_report() {
408 let converter = SklearnConverter::new();
409 let report = converter.conversion_report();
410 assert!(report.contains("sklearn → Aprender"));
411 assert!(report.contains("LinearRegression"));
412 assert!(report.contains("Complexity"));
413 }
414
415 #[test]
420 fn test_all_sklearn_algorithms_exist() {
421 let algs = vec![
423 SklearnAlgorithm::LinearRegression,
424 SklearnAlgorithm::Ridge,
425 SklearnAlgorithm::Lasso,
426 SklearnAlgorithm::LogisticRegression,
427 SklearnAlgorithm::KMeans,
428 SklearnAlgorithm::DBSCAN,
429 SklearnAlgorithm::DecisionTreeClassifier,
430 SklearnAlgorithm::DecisionTreeRegressor,
431 SklearnAlgorithm::RandomForestClassifier,
432 SklearnAlgorithm::RandomForestRegressor,
433 SklearnAlgorithm::StandardScaler,
434 SklearnAlgorithm::MinMaxScaler,
435 SklearnAlgorithm::LabelEncoder,
436 SklearnAlgorithm::TrainTestSplit,
437 SklearnAlgorithm::CrossValidation,
438 SklearnAlgorithm::Accuracy,
439 SklearnAlgorithm::Precision,
440 SklearnAlgorithm::Recall,
441 SklearnAlgorithm::F1Score,
442 SklearnAlgorithm::MeanSquaredError,
443 SklearnAlgorithm::R2Score,
444 ];
445 assert_eq!(algs.len(), 21); }
447
448 #[test]
449 fn test_algorithm_equality() {
450 assert_eq!(SklearnAlgorithm::LinearRegression, SklearnAlgorithm::LinearRegression);
451 assert_ne!(SklearnAlgorithm::LinearRegression, SklearnAlgorithm::KMeans);
452 }
453
454 #[test]
455 fn test_algorithm_clone() {
456 let alg1 = SklearnAlgorithm::DecisionTreeClassifier;
457 let alg2 = alg1.clone();
458 assert_eq!(alg1, alg2);
459 }
460
461 #[test]
462 fn test_complexity_low_algorithms() {
463 let low_algs = vec![
464 SklearnAlgorithm::StandardScaler,
465 SklearnAlgorithm::MinMaxScaler,
466 SklearnAlgorithm::LabelEncoder,
467 SklearnAlgorithm::TrainTestSplit,
468 ];
469
470 for alg in low_algs {
471 assert_eq!(alg.complexity(), crate::backend::OpComplexity::Low);
472 }
473 }
474
475 #[test]
476 fn test_complexity_medium_algorithms() {
477 let medium_algs = vec![
478 SklearnAlgorithm::LinearRegression,
479 SklearnAlgorithm::Ridge,
480 SklearnAlgorithm::Lasso,
481 SklearnAlgorithm::LogisticRegression,
482 SklearnAlgorithm::Accuracy,
483 SklearnAlgorithm::Precision,
484 SklearnAlgorithm::Recall,
485 SklearnAlgorithm::F1Score,
486 SklearnAlgorithm::MeanSquaredError,
487 SklearnAlgorithm::R2Score,
488 ];
489
490 for alg in medium_algs {
491 assert_eq!(alg.complexity(), crate::backend::OpComplexity::Medium);
492 }
493 }
494
495 #[test]
496 fn test_complexity_high_algorithms() {
497 let high_algs = vec![
498 SklearnAlgorithm::DecisionTreeClassifier,
499 SklearnAlgorithm::DecisionTreeRegressor,
500 SklearnAlgorithm::RandomForestClassifier,
501 SklearnAlgorithm::RandomForestRegressor,
502 SklearnAlgorithm::KMeans,
503 SklearnAlgorithm::DBSCAN,
504 SklearnAlgorithm::CrossValidation,
505 ];
506
507 for alg in high_algs {
508 assert_eq!(alg.complexity(), crate::backend::OpComplexity::High);
509 }
510 }
511
512 #[test]
513 fn test_sklearn_module_linear_model() {
514 let linear_algs = vec![
515 SklearnAlgorithm::LinearRegression,
516 SklearnAlgorithm::Ridge,
517 SklearnAlgorithm::Lasso,
518 SklearnAlgorithm::LogisticRegression,
519 ];
520
521 for alg in linear_algs {
522 assert_eq!(alg.sklearn_module(), "sklearn.linear_model");
523 }
524 }
525
526 #[test]
527 fn test_sklearn_module_cluster() {
528 let cluster_algs = vec![SklearnAlgorithm::KMeans, SklearnAlgorithm::DBSCAN];
529
530 for alg in cluster_algs {
531 assert_eq!(alg.sklearn_module(), "sklearn.cluster");
532 }
533 }
534
535 #[test]
536 fn test_sklearn_module_tree() {
537 let tree_algs =
538 vec![SklearnAlgorithm::DecisionTreeClassifier, SklearnAlgorithm::DecisionTreeRegressor];
539
540 for alg in tree_algs {
541 assert_eq!(alg.sklearn_module(), "sklearn.tree");
542 }
543 }
544
545 #[test]
546 fn test_sklearn_module_ensemble() {
547 let ensemble_algs =
548 vec![SklearnAlgorithm::RandomForestClassifier, SklearnAlgorithm::RandomForestRegressor];
549
550 for alg in ensemble_algs {
551 assert_eq!(alg.sklearn_module(), "sklearn.ensemble");
552 }
553 }
554
555 #[test]
556 fn test_sklearn_module_preprocessing() {
557 let preprocessing_algs = vec![
558 SklearnAlgorithm::StandardScaler,
559 SklearnAlgorithm::MinMaxScaler,
560 SklearnAlgorithm::LabelEncoder,
561 ];
562
563 for alg in preprocessing_algs {
564 assert_eq!(alg.sklearn_module(), "sklearn.preprocessing");
565 }
566 }
567
568 #[test]
569 fn test_sklearn_module_model_selection() {
570 let model_selection_algs =
571 vec![SklearnAlgorithm::TrainTestSplit, SklearnAlgorithm::CrossValidation];
572
573 for alg in model_selection_algs {
574 assert_eq!(alg.sklearn_module(), "sklearn.model_selection");
575 }
576 }
577
578 #[test]
579 fn test_sklearn_module_metrics() {
580 let metrics_algs = vec![
581 SklearnAlgorithm::Accuracy,
582 SklearnAlgorithm::Precision,
583 SklearnAlgorithm::Recall,
584 SklearnAlgorithm::F1Score,
585 SklearnAlgorithm::MeanSquaredError,
586 SklearnAlgorithm::R2Score,
587 ];
588
589 for alg in metrics_algs {
590 assert_eq!(alg.sklearn_module(), "sklearn.metrics");
591 }
592 }
593
594 #[test]
599 fn test_aprender_algorithm_construction() {
600 let alg = AprenderAlgorithm {
601 code_template: "test_template".to_string(),
602 imports: vec!["use test;".to_string()],
603 complexity: crate::backend::OpComplexity::Medium,
604 usage_pattern: "let x = test();".to_string(),
605 };
606
607 assert_eq!(alg.code_template, "test_template");
608 assert_eq!(alg.imports.len(), 1);
609 assert_eq!(alg.complexity, crate::backend::OpComplexity::Medium);
610 assert!(alg.usage_pattern.contains("test()"));
611 }
612
613 #[test]
614 fn test_aprender_algorithm_clone() {
615 let alg1 = AprenderAlgorithm {
616 code_template: "template".to_string(),
617 imports: vec!["import".to_string()],
618 complexity: crate::backend::OpComplexity::High,
619 usage_pattern: "usage".to_string(),
620 };
621
622 let alg2 = alg1.clone();
623 assert_eq!(alg1.code_template, alg2.code_template);
624 assert_eq!(alg1.imports, alg2.imports);
625 assert_eq!(alg1.complexity, alg2.complexity);
626 }
627
628 #[test]
633 fn test_converter_default() {
634 let converter = SklearnConverter::default();
635 assert!(!converter.available_algorithms().is_empty());
636 }
637
638 #[test]
639 fn test_convert_all_mapped_algorithms() {
640 let converter = SklearnConverter::new();
641
642 let mapped_algs = vec![
644 SklearnAlgorithm::LinearRegression,
645 SklearnAlgorithm::LogisticRegression,
646 SklearnAlgorithm::KMeans,
647 SklearnAlgorithm::DecisionTreeClassifier,
648 SklearnAlgorithm::StandardScaler,
649 SklearnAlgorithm::TrainTestSplit,
650 SklearnAlgorithm::Accuracy,
651 SklearnAlgorithm::MeanSquaredError,
652 ];
653
654 for alg in mapped_algs {
655 assert!(converter.convert(&alg).is_some(), "Missing mapping for {:?}", alg);
656 }
657 }
658
659 #[test]
660 fn test_convert_unmapped_algorithm() {
661 let converter = SklearnConverter::new();
662
663 let result = converter.convert(&SklearnAlgorithm::Ridge);
666 let _ = result;
668 }
669
670 #[test]
671 fn test_logistic_regression_conversion() {
672 let converter = SklearnConverter::new();
673 let alg =
674 converter.convert(&SklearnAlgorithm::LogisticRegression).expect("unexpected failure");
675
676 assert!(alg.code_template.contains("LogisticRegression"));
677 assert!(alg.imports.iter().any(|i| i.contains("classification")));
678 assert_eq!(alg.complexity, crate::backend::OpComplexity::Medium);
679 }
680
681 #[test]
682 fn test_decision_tree_conversion() {
683 let converter = SklearnConverter::new();
684 let alg = converter
685 .convert(&SklearnAlgorithm::DecisionTreeClassifier)
686 .expect("unexpected failure");
687
688 assert!(alg.code_template.contains("DecisionTreeClassifier"));
689 assert!(alg.imports.iter().any(|i| i.contains("tree")));
690 assert_eq!(alg.complexity, crate::backend::OpComplexity::High);
691 }
692
693 #[test]
694 fn test_standard_scaler_conversion() {
695 let converter = SklearnConverter::new();
696 let alg = converter.convert(&SklearnAlgorithm::StandardScaler).expect("unexpected failure");
697
698 assert!(alg.code_template.contains("StandardScaler"));
699 assert!(alg.imports.iter().any(|i| i.contains("preprocessing")));
700 assert_eq!(alg.complexity, crate::backend::OpComplexity::Low);
701 }
702
703 #[test]
704 fn test_train_test_split_conversion() {
705 let converter = SklearnConverter::new();
706 let alg = converter.convert(&SklearnAlgorithm::TrainTestSplit).expect("unexpected failure");
707
708 assert!(alg.code_template.contains("train_test_split"));
709 assert!(alg.imports.iter().any(|i| i.contains("model_selection")));
710 }
711
712 #[test]
713 fn test_accuracy_conversion() {
714 let converter = SklearnConverter::new();
715 let alg = converter.convert(&SklearnAlgorithm::Accuracy).expect("conversion failed");
716
717 assert!(alg.code_template.contains("accuracy_score"));
718 assert!(alg.imports.iter().any(|i| i.contains("metrics")));
719 }
720
721 #[test]
722 fn test_mse_conversion() {
723 let converter = SklearnConverter::new();
724 let alg =
725 converter.convert(&SklearnAlgorithm::MeanSquaredError).expect("unexpected failure");
726
727 assert!(alg.code_template.contains("mean_squared_error"));
728 assert!(alg.imports.iter().any(|i| i.contains("metrics")));
729 }
730
731 #[test]
732 fn test_available_algorithms() {
733 let converter = SklearnConverter::new();
734 let algs = converter.available_algorithms();
735
736 assert!(!algs.is_empty());
737 assert!(algs.len() >= 8);
739 }
740
741 #[test]
742 fn test_recommend_backend_low_complexity() {
743 let converter = SklearnConverter::new();
744
745 let backend = converter.recommend_backend(&SklearnAlgorithm::StandardScaler, 10);
747 assert_eq!(backend, crate::backend::Backend::Scalar);
748 }
749
750 #[test]
751 fn test_recommend_backend_medium_complexity() {
752 let converter = SklearnConverter::new();
753
754 let backend = converter.recommend_backend(&SklearnAlgorithm::LinearRegression, 50_000);
756 assert_eq!(backend, crate::backend::Backend::SIMD);
757 }
758
759 #[test]
760 fn test_recommend_backend_high_complexity() {
761 let converter = SklearnConverter::new();
762
763 let backend =
765 converter.recommend_backend(&SklearnAlgorithm::RandomForestClassifier, 500_000);
766 assert_eq!(backend, crate::backend::Backend::GPU);
767 }
768
769 #[test]
770 fn test_recommend_backend_clustering() {
771 let converter = SklearnConverter::new();
772
773 let backend = converter.recommend_backend(&SklearnAlgorithm::KMeans, 1_000_000);
775 assert_eq!(backend, crate::backend::Backend::GPU);
776 }
777
778 #[test]
779 fn test_conversion_report_structure() {
780 let converter = SklearnConverter::new();
781 let report = converter.conversion_report();
782
783 assert!(report.contains("sklearn → Aprender"));
785 assert!(report.contains("==="));
786 assert!(report.contains("##")); assert!(report.contains("Template:"));
788 assert!(report.contains("Imports:"));
789 assert!(report.contains("Usage:"));
790 }
791
792 #[test]
793 fn test_conversion_report_has_modules() {
794 let converter = SklearnConverter::new();
795 let report = converter.conversion_report();
796
797 assert!(report.contains("sklearn"));
799 }
800
801 #[test]
802 fn test_conversion_report_has_all_algorithms() {
803 let converter = SklearnConverter::new();
804 let report = converter.conversion_report();
805
806 assert!(
808 report.contains("LinearRegression")
809 || report.contains("KMeans")
810 || report.contains("StandardScaler")
811 );
812 }
813
814 #[test]
815 fn test_usage_patterns_not_empty() {
816 let converter = SklearnConverter::new();
817
818 for alg in converter.available_algorithms() {
819 if let Some(aprender_alg) = converter.convert(alg) {
820 assert!(
821 !aprender_alg.usage_pattern.is_empty(),
822 "Empty usage pattern for {:?}",
823 alg
824 );
825 assert!(
826 !aprender_alg.code_template.is_empty(),
827 "Empty code template for {:?}",
828 alg
829 );
830 assert!(!aprender_alg.imports.is_empty(), "Empty imports for {:?}", alg);
831 }
832 }
833 }
834
835 #[test]
836 fn test_imports_are_valid_rust() {
837 let converter = SklearnConverter::new();
838
839 for alg in converter.available_algorithms() {
840 if let Some(aprender_alg) = converter.convert(alg) {
841 for import in &aprender_alg.imports {
842 assert!(import.starts_with("use "), "Invalid import syntax: {}", import);
843 assert!(import.ends_with(';'), "Import missing semicolon: {}", import);
844 }
845 }
846 }
847 }
848
849 #[test]
850 fn test_linear_models_have_estimator_trait() {
851 let converter = SklearnConverter::new();
852
853 let linear_models =
854 vec![SklearnAlgorithm::LinearRegression, SklearnAlgorithm::LogisticRegression];
855
856 for alg in linear_models {
857 if let Some(aprender_alg) = converter.convert(&alg) {
858 assert!(
859 aprender_alg.imports.iter().any(|i| i.contains("Estimator")),
860 "Linear model {:?} should import Estimator trait",
861 alg
862 );
863 }
864 }
865 }
866
867 #[test]
868 fn test_clustering_has_unsupervised_trait() {
869 let converter = SklearnConverter::new();
870
871 if let Some(kmeans_alg) = converter.convert(&SklearnAlgorithm::KMeans) {
872 assert!(
873 kmeans_alg.imports.iter().any(|i| i.contains("UnsupervisedEstimator")),
874 "KMeans should import UnsupervisedEstimator trait"
875 );
876 }
877 }
878
879 #[test]
880 fn test_preprocessing_has_transformer_trait() {
881 let converter = SklearnConverter::new();
882
883 if let Some(scaler_alg) = converter.convert(&SklearnAlgorithm::StandardScaler) {
884 assert!(
885 scaler_alg.imports.iter().any(|i| i.contains("Transformer")),
886 "StandardScaler should import Transformer trait"
887 );
888 }
889 }
890}