Skip to main content

batuta/
sklearn_converter.rs

//! sklearn to Aprender conversion module (BATUTA-009)
//!
//! Converts Python scikit-learn (sklearn) algorithms to Rust Aprender equivalents
//! with automatic backend selection and ergonomic API mapping.
//!
//! # Conversion Strategy
//!
//! sklearn algorithms are mapped to equivalent Aprender algorithms:
//! - `LinearRegression()` → `LinearRegression::new()`
//! - `KMeans(n_clusters=3)` → `KMeans::new(3)`
//! - `DecisionTreeClassifier()` → `DecisionTreeClassifier::new()`
//! - `train_test_split()` → `train_test_split()`
//! - Model methods automatically use MoE routing for optimal performance
//!
//! # Example
//!
//! ```python
//! # Python sklearn code
//! from sklearn.linear_model import LinearRegression
//! from sklearn.model_selection import train_test_split
//!
//! X_train, X_test, y_train, y_test = train_test_split(X, y)
//! model = LinearRegression()
//! model.fit(X_train, y_train)
//! predictions = model.predict(X_test)
//! ```
//!
//! Converts to:
//!
//! ```rust,ignore
//! use aprender::linear_model::LinearRegression;
//! use aprender::model_selection::train_test_split;
//! use aprender::Estimator;
//!
//! let (X_train, X_test, y_train, y_test) = train_test_split(&X, &y, 0.25)?;
//! let mut model = LinearRegression::new();
//! model.fit(&X_train, &y_train)?;
//! let predictions = model.predict(&X_test)?;
//! ```
40
use std::collections::{BTreeMap, HashMap};
42
/// The scikit-learn estimators, utilities and metrics this module can name.
///
/// Variants are grouped by the sklearn package they come from. The type is
/// `Eq + Hash`, so it can be used directly as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// `DBSCAN` keeps sklearn's all-caps spelling, which clippy would otherwise flag.
#[allow(clippy::upper_case_acronyms)]
pub enum SklearnAlgorithm {
    // ---- Linear models ----
    LinearRegression,
    Ridge,
    Lasso,
    LogisticRegression,

    // ---- Clustering ----
    KMeans,
    DBSCAN,

    // ---- Tree models (single trees and forests) ----
    DecisionTreeClassifier,
    DecisionTreeRegressor,
    RandomForestClassifier,
    RandomForestRegressor,

    // ---- Preprocessing ----
    StandardScaler,
    MinMaxScaler,
    LabelEncoder,

    // ---- Model selection ----
    TrainTestSplit,
    CrossValidation,

    // ---- Metrics ----
    Accuracy,
    Precision,
    Recall,
    F1Score,
    MeanSquaredError,
    R2Score,
}
80
81impl SklearnAlgorithm {
82    /// Get the computational complexity for MoE routing
83    pub fn complexity(&self) -> crate::backend::OpComplexity {
84        use crate::backend::OpComplexity;
85
86        match self {
87            // Preprocessing operations are Low complexity
88            SklearnAlgorithm::StandardScaler
89            | SklearnAlgorithm::MinMaxScaler
90            | SklearnAlgorithm::LabelEncoder
91            | SklearnAlgorithm::TrainTestSplit => OpComplexity::Low,
92
93            // Linear models and metrics are Medium complexity
94            SklearnAlgorithm::LinearRegression
95            | SklearnAlgorithm::Ridge
96            | SklearnAlgorithm::Lasso
97            | SklearnAlgorithm::LogisticRegression
98            | SklearnAlgorithm::Accuracy
99            | SklearnAlgorithm::Precision
100            | SklearnAlgorithm::Recall
101            | SklearnAlgorithm::F1Score
102            | SklearnAlgorithm::MeanSquaredError
103            | SklearnAlgorithm::R2Score => OpComplexity::Medium,
104
105            // Tree models, ensemble methods, and clustering are High complexity
106            SklearnAlgorithm::DecisionTreeClassifier
107            | SklearnAlgorithm::DecisionTreeRegressor
108            | SklearnAlgorithm::RandomForestClassifier
109            | SklearnAlgorithm::RandomForestRegressor
110            | SklearnAlgorithm::KMeans
111            | SklearnAlgorithm::DBSCAN
112            | SklearnAlgorithm::CrossValidation => OpComplexity::High,
113        }
114    }
115
116    /// Get the sklearn module path
117    pub fn sklearn_module(&self) -> &str {
118        match self {
119            SklearnAlgorithm::LinearRegression
120            | SklearnAlgorithm::Ridge
121            | SklearnAlgorithm::Lasso
122            | SklearnAlgorithm::LogisticRegression => "sklearn.linear_model",
123
124            SklearnAlgorithm::KMeans | SklearnAlgorithm::DBSCAN => "sklearn.cluster",
125
126            SklearnAlgorithm::DecisionTreeClassifier | SklearnAlgorithm::DecisionTreeRegressor => {
127                "sklearn.tree"
128            }
129
130            SklearnAlgorithm::RandomForestClassifier | SklearnAlgorithm::RandomForestRegressor => {
131                "sklearn.ensemble"
132            }
133
134            SklearnAlgorithm::StandardScaler
135            | SklearnAlgorithm::MinMaxScaler
136            | SklearnAlgorithm::LabelEncoder => "sklearn.preprocessing",
137
138            SklearnAlgorithm::TrainTestSplit | SklearnAlgorithm::CrossValidation => {
139                "sklearn.model_selection"
140            }
141
142            SklearnAlgorithm::Accuracy
143            | SklearnAlgorithm::Precision
144            | SklearnAlgorithm::Recall
145            | SklearnAlgorithm::F1Score
146            | SklearnAlgorithm::MeanSquaredError
147            | SklearnAlgorithm::R2Score => "sklearn.metrics",
148        }
149    }
150}
151
152/// Aprender equivalent algorithm
153#[derive(Debug, Clone)]
154pub struct AprenderAlgorithm {
155    /// Rust code template for the algorithm
156    pub code_template: String,
157    /// Required imports
158    pub imports: Vec<String>,
159    /// Computational complexity
160    pub complexity: crate::backend::OpComplexity,
161    /// Typical usage pattern
162    pub usage_pattern: String,
163}
164
165/// sklearn to Aprender converter
166pub struct SklearnConverter {
167    /// Algorithm mapping
168    algorithm_map: HashMap<SklearnAlgorithm, AprenderAlgorithm>,
169    /// Backend selector for MoE routing
170    backend_selector: crate::backend::BackendSelector,
171}
172
173impl Default for SklearnConverter {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179impl SklearnConverter {
180    /// Create a new sklearn converter with default mappings
181    pub fn new() -> Self {
182        let mut algorithm_map = HashMap::new();
183
184        // Linear Models
185        algorithm_map.insert(
186            SklearnAlgorithm::LinearRegression,
187            AprenderAlgorithm {
188                code_template: "LinearRegression::new()".to_string(),
189                imports: vec![
190                    "use aprender::linear_model::LinearRegression;".to_string(),
191                    "use aprender::Estimator;".to_string(),
192                ],
193                complexity: crate::backend::OpComplexity::Medium,
194                usage_pattern: "let mut model = LinearRegression::new();\nmodel.fit(&X_train, &y_train)?;\nlet predictions = model.predict(&X_test)?;".to_string(),
195            },
196        );
197
198        algorithm_map.insert(
199            SklearnAlgorithm::LogisticRegression,
200            AprenderAlgorithm {
201                code_template: "LogisticRegression::new()".to_string(),
202                imports: vec![
203                    "use aprender::classification::LogisticRegression;".to_string(),
204                    "use aprender::Estimator;".to_string(),
205                ],
206                complexity: crate::backend::OpComplexity::Medium,
207                usage_pattern: "let mut model = LogisticRegression::new();\nmodel.fit(&X_train, &y_train)?;\nlet predictions = model.predict(&X_test)?;".to_string(),
208            },
209        );
210
211        // Clustering
212        algorithm_map.insert(
213            SklearnAlgorithm::KMeans,
214            AprenderAlgorithm {
215                code_template: "KMeans::new({n_clusters})".to_string(),
216                imports: vec![
217                    "use aprender::cluster::KMeans;".to_string(),
218                    "use aprender::UnsupervisedEstimator;".to_string(),
219                ],
220                complexity: crate::backend::OpComplexity::High,
221                usage_pattern: "let mut model = KMeans::new(3);\nmodel.fit(&X)?;\nlet labels = model.predict(&X)?;".to_string(),
222            },
223        );
224
225        // Tree Models
226        algorithm_map.insert(
227            SklearnAlgorithm::DecisionTreeClassifier,
228            AprenderAlgorithm {
229                code_template: "DecisionTreeClassifier::new()".to_string(),
230                imports: vec![
231                    "use aprender::tree::DecisionTreeClassifier;".to_string(),
232                    "use aprender::Estimator;".to_string(),
233                ],
234                complexity: crate::backend::OpComplexity::High,
235                usage_pattern: "let mut model = DecisionTreeClassifier::new();\nmodel.fit(&X_train, &y_train)?;\nlet predictions = model.predict(&X_test)?;".to_string(),
236            },
237        );
238
239        // Preprocessing
240        algorithm_map.insert(
241            SklearnAlgorithm::StandardScaler,
242            AprenderAlgorithm {
243                code_template: "StandardScaler::new()".to_string(),
244                imports: vec![
245                    "use aprender::preprocessing::StandardScaler;".to_string(),
246                    "use aprender::Transformer;".to_string(),
247                ],
248                complexity: crate::backend::OpComplexity::Low,
249                usage_pattern: "let mut scaler = StandardScaler::new();\nscaler.fit(&X_train)?;\nlet X_train_scaled = scaler.transform(&X_train)?;".to_string(),
250            },
251        );
252
253        // Model Selection
254        algorithm_map.insert(
255            SklearnAlgorithm::TrainTestSplit,
256            AprenderAlgorithm {
257                code_template: "train_test_split(&X, &y, {test_size})".to_string(),
258                imports: vec!["use aprender::model_selection::train_test_split;".to_string()],
259                complexity: crate::backend::OpComplexity::Low,
260                usage_pattern:
261                    "let (X_train, X_test, y_train, y_test) = train_test_split(&X, &y, 0.25)?;"
262                        .to_string(),
263            },
264        );
265
266        // Metrics
267        algorithm_map.insert(
268            SklearnAlgorithm::Accuracy,
269            AprenderAlgorithm {
270                code_template: "accuracy_score(&y_true, &y_pred)".to_string(),
271                imports: vec!["use aprender::metrics::accuracy_score;".to_string()],
272                complexity: crate::backend::OpComplexity::Medium,
273                usage_pattern: "let acc = accuracy_score(&y_true, &y_pred)?;".to_string(),
274            },
275        );
276
277        algorithm_map.insert(
278            SklearnAlgorithm::MeanSquaredError,
279            AprenderAlgorithm {
280                code_template: "mean_squared_error(&y_true, &y_pred)".to_string(),
281                imports: vec!["use aprender::metrics::mean_squared_error;".to_string()],
282                complexity: crate::backend::OpComplexity::Medium,
283                usage_pattern: "let mse = mean_squared_error(&y_true, &y_pred)?;".to_string(),
284            },
285        );
286
287        Self { algorithm_map, backend_selector: crate::backend::BackendSelector::new() }
288    }
289
290    /// Convert a sklearn algorithm to Aprender
291    pub fn convert(&self, algorithm: &SklearnAlgorithm) -> Option<&AprenderAlgorithm> {
292        self.algorithm_map.get(algorithm)
293    }
294
295    /// Get recommended backend for an algorithm
296    pub fn recommend_backend(
297        &self,
298        algorithm: &SklearnAlgorithm,
299        data_size: usize,
300    ) -> crate::backend::Backend {
301        self.backend_selector.select_with_moe(algorithm.complexity(), data_size)
302    }
303
304    /// Get all available conversions
305    pub fn available_algorithms(&self) -> Vec<&SklearnAlgorithm> {
306        self.algorithm_map.keys().collect()
307    }
308
309    /// Generate conversion report
310    pub fn conversion_report(&self) -> String {
311        let mut report = String::from("sklearn → Aprender Conversion Map\n");
312        report.push_str("===================================\n\n");
313
314        // Group by module
315        let mut by_module: HashMap<&str, Vec<(&SklearnAlgorithm, &AprenderAlgorithm)>> =
316            HashMap::new();
317
318        for (alg, aprender_alg) in &self.algorithm_map {
319            by_module.entry(alg.sklearn_module()).or_default().push((alg, aprender_alg));
320        }
321
322        for (module, algorithms) in &by_module {
323            report.push_str(&format!("## {}\n\n", module));
324
325            for (alg, aprender_alg) in algorithms {
326                report.push_str(&format!("{:?}:\n", alg));
327                report.push_str(&format!("  Template: {}\n", aprender_alg.code_template));
328                report.push_str(&format!("  Complexity: {:?}\n", aprender_alg.complexity));
329                report.push_str(&format!("  Imports: {}\n", aprender_alg.imports.join(", ")));
330                report.push_str(&format!(
331                    "  Usage:\n    {}\n\n",
332                    aprender_alg.usage_pattern.replace('\n', "\n    ")
333                ));
334            }
335            report.push('\n');
336        }
337
338        report
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{Backend, OpComplexity};

    /// Every variant of `SklearnAlgorithm`, in declaration order.
    const ALL_ALGORITHMS: [SklearnAlgorithm; 21] = [
        SklearnAlgorithm::LinearRegression,
        SklearnAlgorithm::Ridge,
        SklearnAlgorithm::Lasso,
        SklearnAlgorithm::LogisticRegression,
        SklearnAlgorithm::KMeans,
        SklearnAlgorithm::DBSCAN,
        SklearnAlgorithm::DecisionTreeClassifier,
        SklearnAlgorithm::DecisionTreeRegressor,
        SklearnAlgorithm::RandomForestClassifier,
        SklearnAlgorithm::RandomForestRegressor,
        SklearnAlgorithm::StandardScaler,
        SklearnAlgorithm::MinMaxScaler,
        SklearnAlgorithm::LabelEncoder,
        SklearnAlgorithm::TrainTestSplit,
        SklearnAlgorithm::CrossValidation,
        SklearnAlgorithm::Accuracy,
        SklearnAlgorithm::Precision,
        SklearnAlgorithm::Recall,
        SklearnAlgorithm::F1Score,
        SklearnAlgorithm::MeanSquaredError,
        SklearnAlgorithm::R2Score,
    ];

    /// Look up `algorithm` in a fresh converter, panicking (instead of silently
    /// skipping the assertions) when no mapping is registered.
    fn mapping_for(algorithm: SklearnAlgorithm) -> AprenderAlgorithm {
        SklearnConverter::new()
            .convert(&algorithm)
            .cloned()
            .unwrap_or_else(|| panic!("Missing mapping for {:?}", algorithm))
    }

    /// True when any import line of `mapping` contains `needle`.
    fn imports_mention(mapping: &AprenderAlgorithm, needle: &str) -> bool {
        mapping.imports.iter().any(|import| import.contains(needle))
    }

    /// Assert that every algorithm in `algorithms` has the `expected` complexity.
    fn assert_complexity(algorithms: &[SklearnAlgorithm], expected: OpComplexity) {
        for algorithm in algorithms {
            assert_eq!(algorithm.complexity(), expected, "complexity of {:?}", algorithm);
        }
    }

    /// Assert that every algorithm in `algorithms` reports the `expected` module.
    fn assert_module(algorithms: &[SklearnAlgorithm], expected: &str) {
        for algorithm in algorithms {
            assert_eq!(algorithm.sklearn_module(), expected, "module of {:?}", algorithm);
        }
    }

    #[test]
    fn test_converter_creation() {
        let converter = SklearnConverter::new();
        assert!(!converter.available_algorithms().is_empty());
    }

    #[test]
    fn test_algorithm_complexity() {
        assert_eq!(SklearnAlgorithm::LinearRegression.complexity(), OpComplexity::Medium);
        assert_eq!(SklearnAlgorithm::StandardScaler.complexity(), OpComplexity::Low);
        assert_eq!(SklearnAlgorithm::KMeans.complexity(), OpComplexity::High);
    }

    #[test]
    fn test_linear_regression_conversion() {
        let aprender_alg = mapping_for(SklearnAlgorithm::LinearRegression);
        assert!(aprender_alg.code_template.contains("LinearRegression"));
        assert!(imports_mention(&aprender_alg, "linear_model"));
    }

    #[test]
    fn test_kmeans_conversion() {
        let aprender_alg = mapping_for(SklearnAlgorithm::KMeans);
        assert!(aprender_alg.code_template.contains("KMeans"));
        assert!(imports_mention(&aprender_alg, "cluster"));
    }

    #[test]
    fn test_backend_recommendation() {
        let converter = SklearnConverter::new();

        // Small dataset with preprocessing should use Scalar
        let backend = converter.recommend_backend(&SklearnAlgorithm::StandardScaler, 100);
        assert_eq!(backend, Backend::Scalar);

        // Large dataset with linear model should use SIMD
        let backend = converter.recommend_backend(&SklearnAlgorithm::LinearRegression, 50_000);
        assert_eq!(backend, Backend::SIMD);

        // Large dataset with clustering should use GPU
        let backend = converter.recommend_backend(&SklearnAlgorithm::KMeans, 100_000);
        assert_eq!(backend, Backend::GPU);
    }

    #[test]
    fn test_sklearn_module_paths() {
        assert_eq!(SklearnAlgorithm::LinearRegression.sklearn_module(), "sklearn.linear_model");
        assert_eq!(SklearnAlgorithm::KMeans.sklearn_module(), "sklearn.cluster");
        assert_eq!(SklearnAlgorithm::StandardScaler.sklearn_module(), "sklearn.preprocessing");
    }

    #[test]
    fn test_conversion_report() {
        let report = SklearnConverter::new().conversion_report();
        assert!(report.contains("sklearn → Aprender"));
        assert!(report.contains("LinearRegression"));
        assert!(report.contains("Complexity"));
    }

    // ============================================================================
    // SKLEARN ALGORITHM ENUM TESTS
    // ============================================================================

    #[test]
    fn test_all_sklearn_algorithms_exist() {
        // All 21 variants can be constructed and are pairwise distinct.
        let unique: std::collections::HashSet<&SklearnAlgorithm> = ALL_ALGORITHMS.iter().collect();
        assert_eq!(ALL_ALGORITHMS.len(), 21);
        assert_eq!(unique.len(), ALL_ALGORITHMS.len());
    }

    #[test]
    fn test_algorithm_equality() {
        assert_eq!(SklearnAlgorithm::LinearRegression, SklearnAlgorithm::LinearRegression);
        assert_ne!(SklearnAlgorithm::LinearRegression, SklearnAlgorithm::KMeans);
    }

    #[test]
    fn test_algorithm_clone() {
        let alg1 = SklearnAlgorithm::DecisionTreeClassifier;
        let alg2 = alg1.clone();
        assert_eq!(alg1, alg2);
    }

    #[test]
    fn test_complexity_low_algorithms() {
        assert_complexity(
            &[
                SklearnAlgorithm::StandardScaler,
                SklearnAlgorithm::MinMaxScaler,
                SklearnAlgorithm::LabelEncoder,
                SklearnAlgorithm::TrainTestSplit,
            ],
            OpComplexity::Low,
        );
    }

    #[test]
    fn test_complexity_medium_algorithms() {
        assert_complexity(
            &[
                SklearnAlgorithm::LinearRegression,
                SklearnAlgorithm::Ridge,
                SklearnAlgorithm::Lasso,
                SklearnAlgorithm::LogisticRegression,
                SklearnAlgorithm::Accuracy,
                SklearnAlgorithm::Precision,
                SklearnAlgorithm::Recall,
                SklearnAlgorithm::F1Score,
                SklearnAlgorithm::MeanSquaredError,
                SklearnAlgorithm::R2Score,
            ],
            OpComplexity::Medium,
        );
    }

    #[test]
    fn test_complexity_high_algorithms() {
        assert_complexity(
            &[
                SklearnAlgorithm::DecisionTreeClassifier,
                SklearnAlgorithm::DecisionTreeRegressor,
                SklearnAlgorithm::RandomForestClassifier,
                SklearnAlgorithm::RandomForestRegressor,
                SklearnAlgorithm::KMeans,
                SklearnAlgorithm::DBSCAN,
                SklearnAlgorithm::CrossValidation,
            ],
            OpComplexity::High,
        );
    }

    #[test]
    fn test_sklearn_module_linear_model() {
        assert_module(
            &[
                SklearnAlgorithm::LinearRegression,
                SklearnAlgorithm::Ridge,
                SklearnAlgorithm::Lasso,
                SklearnAlgorithm::LogisticRegression,
            ],
            "sklearn.linear_model",
        );
    }

    #[test]
    fn test_sklearn_module_cluster() {
        assert_module(&[SklearnAlgorithm::KMeans, SklearnAlgorithm::DBSCAN], "sklearn.cluster");
    }

    #[test]
    fn test_sklearn_module_tree() {
        assert_module(
            &[SklearnAlgorithm::DecisionTreeClassifier, SklearnAlgorithm::DecisionTreeRegressor],
            "sklearn.tree",
        );
    }

    #[test]
    fn test_sklearn_module_ensemble() {
        assert_module(
            &[SklearnAlgorithm::RandomForestClassifier, SklearnAlgorithm::RandomForestRegressor],
            "sklearn.ensemble",
        );
    }

    #[test]
    fn test_sklearn_module_preprocessing() {
        assert_module(
            &[
                SklearnAlgorithm::StandardScaler,
                SklearnAlgorithm::MinMaxScaler,
                SklearnAlgorithm::LabelEncoder,
            ],
            "sklearn.preprocessing",
        );
    }

    #[test]
    fn test_sklearn_module_model_selection() {
        assert_module(
            &[SklearnAlgorithm::TrainTestSplit, SklearnAlgorithm::CrossValidation],
            "sklearn.model_selection",
        );
    }

    #[test]
    fn test_sklearn_module_metrics() {
        assert_module(
            &[
                SklearnAlgorithm::Accuracy,
                SklearnAlgorithm::Precision,
                SklearnAlgorithm::Recall,
                SklearnAlgorithm::F1Score,
                SklearnAlgorithm::MeanSquaredError,
                SklearnAlgorithm::R2Score,
            ],
            "sklearn.metrics",
        );
    }

    // ============================================================================
    // APRENDER ALGORITHM STRUCT TESTS
    // ============================================================================

    #[test]
    fn test_aprender_algorithm_construction() {
        let alg = AprenderAlgorithm {
            code_template: "test_template".to_string(),
            imports: vec!["use test;".to_string()],
            complexity: OpComplexity::Medium,
            usage_pattern: "let x = test();".to_string(),
        };

        assert_eq!(alg.code_template, "test_template");
        assert_eq!(alg.imports.len(), 1);
        assert_eq!(alg.complexity, OpComplexity::Medium);
        assert!(alg.usage_pattern.contains("test()"));
    }

    #[test]
    fn test_aprender_algorithm_clone() {
        let alg1 = AprenderAlgorithm {
            code_template: "template".to_string(),
            imports: vec!["import".to_string()],
            complexity: OpComplexity::High,
            usage_pattern: "usage".to_string(),
        };

        let alg2 = alg1.clone();
        assert_eq!(alg1.code_template, alg2.code_template);
        assert_eq!(alg1.imports, alg2.imports);
        assert_eq!(alg1.complexity, alg2.complexity);
        assert_eq!(alg1.usage_pattern, alg2.usage_pattern);
    }

    // ============================================================================
    // SKLEARN CONVERTER TESTS
    // ============================================================================

    #[test]
    fn test_converter_default() {
        let converter = SklearnConverter::default();
        assert!(!converter.available_algorithms().is_empty());
    }

    #[test]
    fn test_convert_all_mapped_algorithms() {
        let converter = SklearnConverter::new();

        // All algorithms that should have mappings
        let mapped_algs = [
            SklearnAlgorithm::LinearRegression,
            SklearnAlgorithm::LogisticRegression,
            SklearnAlgorithm::KMeans,
            SklearnAlgorithm::DecisionTreeClassifier,
            SklearnAlgorithm::StandardScaler,
            SklearnAlgorithm::TrainTestSplit,
            SklearnAlgorithm::Accuracy,
            SklearnAlgorithm::MeanSquaredError,
        ];

        for alg in &mapped_algs {
            assert!(converter.convert(alg).is_some(), "Missing mapping for {:?}", alg);
        }
    }

    #[test]
    fn test_convert_unmapped_algorithm() {
        let converter = SklearnConverter::new();

        // Ridge has no registered mapping, so the lookup must report `None`
        // rather than panic. Update this test if a Ridge mapping is added.
        assert!(converter.convert(&SklearnAlgorithm::Ridge).is_none());
    }

    #[test]
    fn test_logistic_regression_conversion() {
        let alg = mapping_for(SklearnAlgorithm::LogisticRegression);

        assert!(alg.code_template.contains("LogisticRegression"));
        assert!(imports_mention(&alg, "classification"));
        assert_eq!(alg.complexity, OpComplexity::Medium);
    }

    #[test]
    fn test_decision_tree_conversion() {
        let alg = mapping_for(SklearnAlgorithm::DecisionTreeClassifier);

        assert!(alg.code_template.contains("DecisionTreeClassifier"));
        assert!(imports_mention(&alg, "tree"));
        assert_eq!(alg.complexity, OpComplexity::High);
    }

    #[test]
    fn test_standard_scaler_conversion() {
        let alg = mapping_for(SklearnAlgorithm::StandardScaler);

        assert!(alg.code_template.contains("StandardScaler"));
        assert!(imports_mention(&alg, "preprocessing"));
        assert_eq!(alg.complexity, OpComplexity::Low);
    }

    #[test]
    fn test_train_test_split_conversion() {
        let alg = mapping_for(SklearnAlgorithm::TrainTestSplit);

        assert!(alg.code_template.contains("train_test_split"));
        assert!(imports_mention(&alg, "model_selection"));
    }

    #[test]
    fn test_accuracy_conversion() {
        let alg = mapping_for(SklearnAlgorithm::Accuracy);

        assert!(alg.code_template.contains("accuracy_score"));
        assert!(imports_mention(&alg, "metrics"));
    }

    #[test]
    fn test_mse_conversion() {
        let alg = mapping_for(SklearnAlgorithm::MeanSquaredError);

        assert!(alg.code_template.contains("mean_squared_error"));
        assert!(imports_mention(&alg, "metrics"));
    }

    #[test]
    fn test_available_algorithms() {
        let converter = SklearnConverter::new();
        let algs = converter.available_algorithms();

        assert!(!algs.is_empty());
        // Should have at least the mapped algorithms
        assert!(algs.len() >= 8);
    }

    #[test]
    fn test_recommend_backend_low_complexity() {
        let converter = SklearnConverter::new();

        // Small data size with low complexity should use Scalar
        let backend = converter.recommend_backend(&SklearnAlgorithm::StandardScaler, 10);
        assert_eq!(backend, Backend::Scalar);
    }

    #[test]
    fn test_recommend_backend_medium_complexity() {
        let converter = SklearnConverter::new();

        // Medium data size with medium complexity should use SIMD
        let backend = converter.recommend_backend(&SklearnAlgorithm::LinearRegression, 50_000);
        assert_eq!(backend, Backend::SIMD);
    }

    #[test]
    fn test_recommend_backend_high_complexity() {
        let converter = SklearnConverter::new();

        // Large data size with high complexity should use GPU
        let backend =
            converter.recommend_backend(&SklearnAlgorithm::RandomForestClassifier, 500_000);
        assert_eq!(backend, Backend::GPU);
    }

    #[test]
    fn test_recommend_backend_clustering() {
        let converter = SklearnConverter::new();

        // Clustering with large data should use GPU
        let backend = converter.recommend_backend(&SklearnAlgorithm::KMeans, 1_000_000);
        assert_eq!(backend, Backend::GPU);
    }

    #[test]
    fn test_conversion_report_structure() {
        let report = SklearnConverter::new().conversion_report();

        // Check report contains expected sections
        assert!(report.contains("sklearn → Aprender"));
        assert!(report.contains("==="));
        assert!(report.contains("##")); // Module headers
        assert!(report.contains("Template:"));
        assert!(report.contains("Imports:"));
        assert!(report.contains("Usage:"));
    }

    #[test]
    fn test_conversion_report_has_modules() {
        let converter = SklearnConverter::new();
        let report = converter.conversion_report();

        // Every mapped algorithm's sklearn module must have its own heading.
        for alg in converter.available_algorithms() {
            let heading = format!("## {}\n", alg.sklearn_module());
            assert!(report.contains(&heading), "Missing module heading for {:?}", alg);
        }
    }

    #[test]
    fn test_conversion_report_has_all_algorithms() {
        let converter = SklearnConverter::new();
        let report = converter.conversion_report();

        // Every mapped algorithm must be listed under its variant name.
        for alg in converter.available_algorithms() {
            let entry = format!("{:?}:\n", alg);
            assert!(report.contains(&entry), "{:?} missing from report", alg);
        }
    }

    #[test]
    fn test_usage_patterns_not_empty() {
        let converter = SklearnConverter::new();

        for alg in converter.available_algorithms() {
            let aprender_alg =
                converter.convert(alg).expect("available algorithm must be convertible");
            assert!(!aprender_alg.usage_pattern.is_empty(), "Empty usage pattern for {:?}", alg);
            assert!(!aprender_alg.code_template.is_empty(), "Empty code template for {:?}", alg);
            assert!(!aprender_alg.imports.is_empty(), "Empty imports for {:?}", alg);
        }
    }

    #[test]
    fn test_imports_are_valid_rust() {
        let converter = SklearnConverter::new();

        for alg in converter.available_algorithms() {
            let aprender_alg =
                converter.convert(alg).expect("available algorithm must be convertible");
            for import in &aprender_alg.imports {
                assert!(import.starts_with("use "), "Invalid import syntax: {}", import);
                assert!(import.ends_with(';'), "Import missing semicolon: {}", import);
            }
        }
    }

    #[test]
    fn test_linear_models_have_estimator_trait() {
        for alg in [SklearnAlgorithm::LinearRegression, SklearnAlgorithm::LogisticRegression] {
            let aprender_alg = mapping_for(alg.clone());
            assert!(
                imports_mention(&aprender_alg, "Estimator"),
                "Linear model {:?} should import Estimator trait",
                alg
            );
        }
    }

    #[test]
    fn test_clustering_has_unsupervised_trait() {
        let kmeans_alg = mapping_for(SklearnAlgorithm::KMeans);
        assert!(
            imports_mention(&kmeans_alg, "UnsupervisedEstimator"),
            "KMeans should import UnsupervisedEstimator trait"
        );
    }

    #[test]
    fn test_preprocessing_has_transformer_trait() {
        let scaler_alg = mapping_for(SklearnAlgorithm::StandardScaler);
        assert!(
            imports_mention(&scaler_alg, "Transformer"),
            "StandardScaler should import Transformer trait"
        );
    }
}