sklearn_pipeline_demo/sklearn_pipeline_demo.rs

//! Scikit-learn Compatible Quantum ML Pipeline Example
//!
//! This example demonstrates the scikit-learn compatibility layer, showing how to use
//! quantum models with familiar sklearn APIs including pipelines, cross-validation, and grid search.

use quantrs2_ml::prelude::*;
use quantrs2_ml::sklearn_compatibility::{
    metrics, model_selection, Pipeline, QuantumFeatureEncoder, SelectKBest, SklearnFit,
    StandardScaler,
};
use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;

#[allow(non_snake_case)]
fn main() -> Result<()> {
    println!("=== Scikit-learn Compatible Quantum ML Demo ===\n");

    // Step 1: Create sklearn-style dataset
    println!("1. Creating scikit-learn style dataset...");

    let (X, y) = create_sklearn_dataset()?;
    println!("   - Dataset shape: {:?}", X.dim());
    println!(
        "   - Labels: {} classes",
        y.iter()
            .map(|&x| x as i32)
            .collect::<std::collections::HashSet<_>>()
            .len()
    );
    println!(
        "   - Feature range: [{:.3}, {:.3}]",
        X.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
        X.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
    );

    // Step 2: Create sklearn-compatible quantum estimators
    println!("\n2. Creating sklearn-compatible quantum estimators...");

    // Quantum Support Vector Classifier
    let qsvc = QuantumSVC::new();

    // Quantum Multi-Layer Perceptron Classifier (constructed for illustration only;
    // the leading underscore silences the unused-variable warning, as it is not used below)
    let _qmlp = QuantumMLPClassifier::new();

    // Quantum K-Means Clustering
    let mut qkmeans = QuantumKMeans::new(2); // n_clusters

    println!("   - QuantumSVC: quantum kernel");
    println!("   - QuantumMLP: multi-layer perceptron");
    println!("   - QuantumKMeans: 2 clusters");

    // Step 3: Create sklearn-style preprocessing pipeline
    println!("\n3. Building sklearn-compatible preprocessing pipeline...");

    let preprocessing_pipeline = Pipeline::new(vec![
        ("scaler", Box::new(StandardScaler::new())),
        (
            "feature_selection",
            Box::new(SelectKBest::new(
                "quantum_mutual_info", // score_func
                3,                     // k
            )),
        ),
        (
            "quantum_encoder",
            Box::new(QuantumFeatureEncoder::new(
                "angle", // encoding_type
                "l2",    // normalization
            )),
        ),
    ])?;

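    // The hedged `standardize` sketch near the end of this file illustrates the
    // per-column z-scoring a StandardScaler-style step is assumed to perform.
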
    // Step 4: Create complete quantum ML pipeline
    println!("\n4. Creating complete quantum ML pipeline...");

    let quantum_pipeline = Pipeline::new(vec![
        ("preprocessing", Box::new(preprocessing_pipeline)),
        ("classifier", Box::new(qsvc)),
    ])?;

    println!("   Pipeline steps:");
    for (i, step_name) in quantum_pipeline.named_steps().iter().enumerate() {
        println!("   {}. {}", i + 1, step_name);
    }

    // Step 5: Train-test split (sklearn style)
    println!("\n5. Performing train-test split...");

    let (X_train, X_test, y_train, y_test) = model_selection::train_test_split(
        &X,
        &y,
        0.3,      // test_size
        Some(42), // random_state
    )?;

    println!("   - Training set: {:?}", X_train.dim());
    println!("   - Test set: {:?}", X_test.dim());

    // Step 6: Cross-validation with quantum models
    println!("\n6. Performing cross-validation...");

    let mut pipeline_clone = quantum_pipeline.clone();
    let cv_scores = model_selection::cross_val_score(
        &mut pipeline_clone,
        &X_train,
        &y_train,
        5, // cv
    )?;

    println!("   Cross-validation scores: {cv_scores:?}");
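    // Quoting mean +/- two standard deviations follows sklearn's convention of
    // reporting a roughly 95% interval around the cross-validation estimate.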
    println!(
        "   Mean CV accuracy: {:.3} (+/- {:.3})",
        cv_scores.mean().unwrap(),
        cv_scores.std(0.0) * 2.0
    );

    // Step 7: Hyperparameter grid search
    println!("\n7. Hyperparameter optimization with GridSearchCV...");
    let param_grid = HashMap::from([
        (
            "classifier__C".to_string(),
            vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
        ),
        (
            "classifier__feature_map_depth".to_string(),
            vec!["1".to_string(), "2".to_string(), "3".to_string()],
        ),
        (
            "preprocessing__feature_selection__k".to_string(),
            vec!["2".to_string(), "3".to_string(), "4".to_string()],
        ),
    ]);

    let mut grid_search = model_selection::GridSearchCV::new(
        quantum_pipeline, // estimator
        param_grid,
        3, // cv
    );

    grid_search.fit(&X_train, &y_train)?;

    println!("   Best parameters: {:?}", grid_search.best_params_);
    println!(
        "   Best cross-validation score: {:.3}",
        grid_search.best_score_
    );

    // Step 8: Train best model and evaluate
    println!("\n8. Training and evaluating the best model...");

    let best_model = grid_search.best_estimator_;
    let y_pred = best_model.predict(&X_test)?;

    // Calculate metrics using sklearn-style functions (these expect integer class labels)
    let y_test_int = y_test.mapv(|x| x.round() as i32);
    let accuracy = metrics::accuracy_score(&y_test_int, &y_pred);
    let precision = metrics::precision_score(&y_test_int, &y_pred, "weighted"); // average
    let recall = metrics::recall_score(&y_test_int, &y_pred, "weighted"); // average
    let f1 = metrics::f1_score(&y_test_int, &y_pred, "weighted"); // average

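    // The hedged `manual_accuracy` sketch at the end of this file illustrates the
    // computation a sklearn-style `accuracy_score` is expected to perform.
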
    println!("   Test Results:");
    println!("   - Accuracy: {accuracy:.3}");
    println!("   - Precision: {precision:.3}");
    println!("   - Recall: {recall:.3}");
    println!("   - F1-score: {f1:.3}");

    // Step 9: Classification report
    println!("\n9. Detailed classification report...");

    let classification_report = metrics::classification_report(
        &y_test_int,
        &y_pred,
        vec!["Class 0", "Class 1"], // target_names
        3,                          // digits
    );
    println!("{classification_report}");

    // Step 10: Feature importance analysis
    println!("\n10. Feature importance analysis...");

    if let Some(feature_importances) = best_model.feature_importances() {
        println!("    Quantum Feature Importances:");
        for (i, importance) in feature_importances.iter().enumerate() {
            println!("    - Feature {i}: {importance:.4}");
        }
    }

    // Step 11: Model comparison with classical sklearn models
    println!("\n11. Comparing with classical sklearn models...");

    let classical_models = vec![
        (
            "Logistic Regression",
            Box::new(LogisticRegression::new()) as Box<dyn SklearnClassifier>,
        ),
        (
            "Random Forest",
            Box::new(RandomForestClassifier::new()) as Box<dyn SklearnClassifier>,
        ),
        ("SVM", Box::new(SVC::new()) as Box<dyn SklearnClassifier>),
    ];

    let mut comparison_results = Vec::new();

    for (name, mut model) in classical_models {
        model.fit(&X_train, Some(&y_train))?;
        let y_pred_classical = model.predict(&X_test)?;
        let classical_accuracy = metrics::accuracy_score(&y_test_int, &y_pred_classical);
        comparison_results.push((name, classical_accuracy));
    }

    println!("    Model Comparison:");
    println!("    - Quantum Pipeline: {accuracy:.3}");
    for (name, classical_accuracy) in comparison_results {
        println!("    - {name}: {classical_accuracy:.3}");
    }

    // Step 12: Clustering with quantum K-means
    println!("\n12. Quantum clustering analysis...");

    let cluster_labels = qkmeans.fit_predict(&X)?;
    let silhouette_score = metrics::silhouette_score(&X, &cluster_labels, "euclidean"); // metric
    let calinski_score = metrics::calinski_harabasz_score(&X, &cluster_labels);

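    // Silhouette scores lie in [-1, 1] (higher means better-separated clusters); the
    // Calinski-Harabasz index is unbounded above, with higher likewise better.
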
    println!("    Clustering Results:");
    println!("    - Silhouette Score: {silhouette_score:.3}");
    println!("    - Calinski-Harabasz Score: {calinski_score:.3}");
    println!(
        "    - Unique clusters found: {}",
        cluster_labels
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len()
    );

    // Step 13: Model persistence (sklearn style)
    println!("\n13. Model persistence (sklearn joblib style)...");

    // Save model
    best_model.save("quantum_sklearn_model.joblib")?;
    println!("    - Model saved to: quantum_sklearn_model.joblib");

    // Load model and run a quick smoke test on a few samples
    let loaded_model = QuantumSVC::load("quantum_sklearn_model.joblib")?;
    let test_subset = X_test.slice(s![..5, ..]).to_owned();
    let _y_pred_loaded = loaded_model.predict(&test_subset)?; // underscore: predictions unused here
    println!("    - Model loaded and tested on 5 samples");

    // Step 14: Advanced sklearn utilities
    println!("\n14. Advanced sklearn utilities...");

    // Learning curves (model_selection::learning_curve is not yet available)
    // let (train_sizes, train_scores, val_scores) = model_selection::learning_curve(...)?;
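    // See the hedged `sketch_learning_curve` helper after `main` for one way to
    // approximate a learning curve with the `cross_val_score` API used in step 6.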
    println!("    Learning Curve Analysis: (Mock results)");
    let train_sizes = [0.1, 0.33, 0.55, 0.78, 1.0];
    let train_scores = [0.65, 0.72, 0.78, 0.82, 0.85];
    let val_scores = [0.62, 0.70, 0.76, 0.79, 0.81];

    for (i, &size) in train_sizes.iter().enumerate() {
        println!(
            "    - {:.0}% data: train={:.3}, val={:.3}",
            size * 100.0,
            train_scores[i],
            val_scores[i]
        );
    }

    // Validation curves (model_selection::validation_curve is not yet available)
    // let (train_scores_val, test_scores_val) = model_selection::validation_curve(...)?;
    println!("    Validation Curve (C parameter): (Mock results)");
    let param_range = [0.1, 0.5, 1.0, 2.0, 5.0];
    let train_scores_val = [0.70, 0.75, 0.80, 0.78, 0.75];
    let test_scores_val = [0.68, 0.73, 0.78, 0.76, 0.72];

    for (i, &param_value) in param_range.iter().enumerate() {
        println!(
            "    - C={}: train={:.3}, test={:.3}",
            param_value, train_scores_val[i], test_scores_val[i]
        );
    }

    // Step 15: Quantum-specific sklearn extensions
    println!("\n15. Quantum-specific sklearn extensions...");

    // Quantum feature analysis
    let quantum_feature_analysis = analyze_quantum_features(&best_model, &X_test)?;
    println!("    Quantum Feature Analysis:");
    println!(
        "    - Quantum advantage score: {:.3}",
        quantum_feature_analysis.advantage_score
    );
    println!(
        "    - Feature entanglement: {:.3}",
        quantum_feature_analysis.entanglement_measure
    );
    println!(
        "    - Circuit depth efficiency: {:.3}",
        quantum_feature_analysis.circuit_efficiency
    );

    // Quantum model interpretation
    let sample_row = X_test.row(0).to_owned();
    let quantum_interpretation = interpret_quantum_model(&best_model, &sample_row)?;
    println!("    Quantum Model Interpretation (sample 0):");
    println!(
        "    - Quantum state fidelity: {:.3}",
        quantum_interpretation.state_fidelity
    );
    println!(
        "    - Feature contributions: {:?}",
        quantum_interpretation.feature_contributions
    );

    println!("\n=== Scikit-learn Integration Demo Complete ===");

    Ok(())
}

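// A minimal learning-curve sketch, not part of quantrs2_ml's API: it reuses the
// `cross_val_score` helper exercised in step 6 and assumes scores come back as an
// ndarray `Array1<f64>` and that `cv` is a `usize`. Since
// `model_selection::learning_curve` is not yet available (see step 14), this only
// illustrates the idea of scoring progressively larger subsets of the training data.
#[allow(non_snake_case, dead_code)]
fn sketch_learning_curve(
    pipeline: &mut Pipeline,
    X: &Array2<f64>,
    y: &Array1<f64>,
    fractions: &[f64],
    cv: usize,
) -> Result<Vec<(f64, f64)>> {
    let mut points = Vec::new();
    for &frac in fractions {
        // Score on the first `frac` of the rows; a real implementation would shuffle
        // first and also return per-fold training scores.
        let n = ((X.nrows() as f64) * frac).round() as usize;
        let X_sub = X.slice(s![..n, ..]).to_owned();
        let y_sub = y.slice(s![..n]).to_owned();
        let scores = model_selection::cross_val_score(pipeline, &X_sub, &y_sub, cv)?;
        points.push((frac, scores.mean().unwrap_or(0.0)));
    }
    Ok(points)
}
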
#[allow(non_snake_case)]
fn create_sklearn_dataset() -> Result<(Array2<f64>, Array1<f64>)> {
    let num_samples = 300;
    let num_features = 4;

    // Create a dataset similar to sklearn's make_classification
    let X = Array2::from_shape_fn((num_samples, num_features), |(i, j)| {
        let base = (i as f64).mul_add(0.02, j as f64 * 0.5);
        let noise = fastrand::f64().mul_add(0.3, -0.15);
        base.sin() + noise
    });

    // Create separable classes
    let y = Array1::from_shape_fn(num_samples, |i| {
        let feature_sum = (0..num_features).map(|j| X[[i, j]]).sum::<f64>();
        if feature_sum > 0.0 {
            1.0
        } else {
            0.0
        }
    });

    Ok((X, y))
}

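// A hand-rolled z-scoring helper, included purely to illustrate what the pipeline's
// StandardScaler-style step is assumed to compute (per-column mean 0, variance 1);
// this is a sketch under that assumption, not quantrs2_ml's actual implementation.
#[allow(non_snake_case, dead_code)]
fn standardize(X: &Array2<f64>) -> Array2<f64> {
    let mean = X.mean_axis(Axis(0)).expect("non-empty matrix");
    let std = X.std_axis(Axis(0), 0.0); // ddof = 0, matching sklearn's population std
    Array2::from_shape_fn(X.dim(), |(i, j)| {
        // Guard zero-variance columns to avoid dividing by zero.
        if std[j] > 0.0 {
            (X[[i, j]] - mean[j]) / std[j]
        } else {
            0.0
        }
    })
}
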
#[allow(non_snake_case)] // X is standard ML convention for feature matrix
fn analyze_quantum_features(
    model: &dyn SklearnClassifier,
    X: &Array2<f64>,
) -> Result<QuantumFeatureAnalysis> {
    // Analyze quantum-specific properties
    let predictions_quantum = model.predict(X)?;

    // Create classical baseline for comparison
    let mut classical_model = LogisticRegression::new();
    SklearnFit::fit(
        &mut classical_model,
        X,
        &predictions_quantum.mapv(f64::from),
    )?; // Use quantum predictions as targets
    let predictions_classical = classical_model.predict(X)?;

    // "Advantage" here is measured simply as the mean disagreement rate between the
    // quantum model's predictions and the classical surrogate's predictions.
    let advantage_score = predictions_quantum
        .iter()
        .zip(predictions_classical.iter())
        .map(|(&q, &c)| (f64::from(q) - f64::from(c)).abs())
        .sum::<f64>()
        / predictions_quantum.len() as f64;

    Ok(QuantumFeatureAnalysis {
        advantage_score,
        entanglement_measure: 0.75, // Mock value
        circuit_efficiency: 0.85,   // Mock value
    })
}

fn interpret_quantum_model(
    model: &dyn SklearnClassifier,
    sample: &Array1<f64>,
) -> Result<QuantumInterpretation> {
    // Promote the single sample to a (1, n_features) matrix, as predict expects
    let prediction = model.predict(&sample.clone().insert_axis(Axis(0)))?;

    Ok(QuantumInterpretation {
        state_fidelity: 0.92,                            // Mock value
        feature_contributions: vec![0.3, 0.2, 0.4, 0.1], // Mock values
        prediction: f64::from(prediction[0]),
    })
}

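// For illustration only: a hand-rolled accuracy computation matching what a sklearn-style
// `accuracy_score` returns (the fraction of exactly matching labels). This hypothetical
// helper mirrors, but does not replace, the `metrics::accuracy_score` used above.
#[allow(dead_code)]
fn manual_accuracy(y_true: &Array1<i32>, y_pred: &Array1<i32>) -> f64 {
    let correct = y_true
        .iter()
        .zip(y_pred.iter())
        .filter(|(t, p)| t == p)
        .count();
    correct as f64 / y_true.len() as f64
}
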
// Supporting structures and trait implementations

struct QuantumFeatureAnalysis {
    advantage_score: f64,
    entanglement_measure: f64,
    circuit_efficiency: f64,
}

struct QuantumInterpretation {
    state_fidelity: f64,
    feature_contributions: Vec<f64>,
    prediction: f64,
}

// Mock implementations for classical sklearn models
struct LogisticRegression {
    fitted: bool,
}

impl LogisticRegression {
    const fn new() -> Self {
        Self { fitted: false }
    }
}

#[allow(non_snake_case)]
impl SklearnEstimator for LogisticRegression {
    fn fit(&mut self, _X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        self.fitted = true;
        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        HashMap::new()
    }

    fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

#[allow(non_snake_case)]
impl SklearnClassifier for LogisticRegression {
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        if !self.fitted {
            return Err(MLError::InvalidConfiguration(
                "Model not fitted".to_string(),
            ));
        }
        // Mock predictions
        Ok(Array1::from_shape_fn(X.nrows(), |i| i32::from(i % 2 == 0)))
    }

    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
        // Row index is unused in these fixed mock probabilities
        Ok(Array2::from_shape_fn((X.nrows(), 2), |(_, j)| {
            if j == 0 {
                0.3
            } else {
                0.7
            }
        }))
    }

    fn classes(&self) -> &[i32] {
        &[0, 1]
    }
}

#[allow(non_snake_case)]
impl SklearnFit for LogisticRegression {
    fn fit(&mut self, _X: &Array2<f64>, _y: &Array1<f64>) -> Result<()> {
        self.fitted = true;
        Ok(())
    }
}

struct RandomForestClassifier {
    fitted: bool,
}

impl RandomForestClassifier {
    const fn new() -> Self {
        Self { fitted: false }
    }
}

#[allow(non_snake_case)]
impl SklearnEstimator for RandomForestClassifier {
    fn fit(&mut self, _X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        self.fitted = true;
        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        HashMap::new()
    }

    fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

#[allow(non_snake_case)]
impl SklearnClassifier for RandomForestClassifier {
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        if !self.fitted {
            return Err(MLError::InvalidConfiguration(
                "Model not fitted".to_string(),
            ));
        }
        // Mock predictions with higher accuracy
        Ok(Array1::from_shape_fn(X.nrows(), |i| {
            i32::from((i * 3) % 4 != 0)
        }))
    }

    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
        // Row index is unused in these fixed mock probabilities
        Ok(Array2::from_shape_fn((X.nrows(), 2), |(_, j)| {
            if j == 0 {
                0.4
            } else {
                0.6
            }
        }))
    }

    fn classes(&self) -> &[i32] {
        &[0, 1]
    }
}

#[allow(non_snake_case)]
impl SklearnFit for RandomForestClassifier {
    fn fit(&mut self, _X: &Array2<f64>, _y: &Array1<f64>) -> Result<()> {
        self.fitted = true;
        Ok(())
    }
}

struct SVC {
    fitted: bool,
}

impl SVC {
    const fn new() -> Self {
        Self { fitted: false }
    }
}

#[allow(non_snake_case)]
impl SklearnEstimator for SVC {
    fn fit(&mut self, _X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
        self.fitted = true;
        Ok(())
    }

    fn get_params(&self) -> HashMap<String, String> {
        HashMap::new()
    }

    fn set_params(&mut self, _params: HashMap<String, String>) -> Result<()> {
        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

#[allow(non_snake_case)]
impl SklearnClassifier for SVC {
    fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        if !self.fitted {
            return Err(MLError::InvalidConfiguration(
                "Model not fitted".to_string(),
            ));
        }
        // Mock predictions
        Ok(Array1::from_shape_fn(X.nrows(), |i| i32::from(i % 3 != 0)))
    }

    fn predict_proba(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
        // Row index is unused in these fixed mock probabilities
        Ok(Array2::from_shape_fn((X.nrows(), 2), |(_, j)| {
            if j == 0 {
                0.35
            } else {
                0.65
            }
        }))
    }

    fn classes(&self) -> &[i32] {
        &[0, 1]
    }
}

#[allow(non_snake_case)]
impl SklearnFit for SVC {
    fn fit(&mut self, _X: &Array2<f64>, _y: &Array1<f64>) -> Result<()> {
        self.fitted = true;
        Ok(())
    }
}