pub fn train_test_split(
X: &Array2<f64>,
y: &Array1<f64>,
test_size: f64,
random_state: Option<u64>,
) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>, Array1<f64>)>
Expand description
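Train-test split: partitions `X` and `y` into training and test subsets, returned as `(X_train, X_test, y_train, y_test)`.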
Examples found in repository?
examples/sklearn_pipeline_demo.rs (lines 99-104)
25 fn main() -> Result<()> {
26 println!("=== Scikit-learn Compatible Quantum ML Demo ===\n");
27
28 // Step 1: Create sklearn-style dataset
29 println!("1. Creating scikit-learn style dataset...");
30
31 let (X, y) = create_sklearn_dataset()?;
32 println!(" - Dataset shape: {:?}", X.dim());
33 println!(
34 " - Labels: {} classes",
35 y.iter()
36 .map(|&x| x as i32)
37 .collect::<std::collections::HashSet<_>>()
38 .len()
39 );
40 println!(
41 " - Feature range: [{:.3}, {:.3}]",
42 X.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
43 X.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
44 );
45
46 // Step 2: Create sklearn-compatible quantum estimators
47 println!("\n2. Creating sklearn-compatible quantum estimators...");
48
49 // Quantum Support Vector Classifier
50 let qsvc = QuantumSVC::new();
51
52 // Quantum Multi-Layer Perceptron Classifier
53 let qmlp = QuantumMLPClassifier::new();
54
55 // Quantum K-Means Clustering
56 let mut qkmeans = QuantumKMeans::new(2); // n_clusters
57
58 println!(" - QuantumSVC: quantum kernel");
59 println!(" - QuantumMLP: multi-layer perceptron");
60 println!(" - QuantumKMeans: 2 clusters");
61
62 // Step 3: Create sklearn-style preprocessing pipeline
63 println!("\n3. Building sklearn-compatible preprocessing pipeline...");
64
65 let preprocessing_pipeline = Pipeline::new(vec![
66 ("scaler", Box::new(StandardScaler::new())),
67 (
68 "feature_selection",
69 Box::new(SelectKBest::new(
70 "quantum_mutual_info", // score_func
71 3, // k
72 )),
73 ),
74 (
75 "quantum_encoder",
76 Box::new(QuantumFeatureEncoder::new(
77 "angle", // encoding_type
78 "l2", // normalization
79 )),
80 ),
81 ])?;
82
83 // Step 4: Create complete quantum ML pipeline
84 println!("\n4. Creating complete quantum ML pipeline...");
85
86 let quantum_pipeline = Pipeline::new(vec![
87 ("preprocessing", Box::new(preprocessing_pipeline)),
88 ("classifier", Box::new(qsvc)),
89 ])?;
90
91 println!(" Pipeline steps:");
92 for (i, step_name) in quantum_pipeline.named_steps().iter().enumerate() {
93 println!(" {}. {}", i + 1, step_name);
94 }
95
96 // Step 5: Train-test split (sklearn style)
97 println!("\n5. Performing train-test split...");
98
99 let (X_train, X_test, y_train, y_test) = model_selection::train_test_split(
100 &X,
101 &y,
102 0.3, // test_size
103 Some(42), // random_state
104 )?;
105
106 println!(" - Training set: {:?}", X_train.dim());
107 println!(" - Test set: {:?}", X_test.dim());
108
109 // Step 6: Cross-validation with quantum models
110 println!("\n6. Performing cross-validation...");
111
112 let mut pipeline_clone = quantum_pipeline.clone();
113 let cv_scores = model_selection::cross_val_score(
114 &mut pipeline_clone,
115 &X_train,
116 &y_train,
117 5, // cv
118 )?;
119
120 println!(" Cross-validation scores: {cv_scores:?}");
121 println!(
122 " Mean CV accuracy: {:.3} (+/- {:.3})",
123 cv_scores.mean().unwrap(),
124 cv_scores.std(0.0) * 2.0
125 );
126
127 // Step 7: Hyperparameter grid search
128 println!("\n7. Hyperparameter optimization with GridSearchCV...");
129
130 let param_grid = HashMap::from([
131 (
132 "classifier__C".to_string(),
133 vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
134 ),
135 (
136 "classifier__feature_map_depth".to_string(),
137 vec!["1".to_string(), "2".to_string(), "3".to_string()],
138 ),
139 (
140 "preprocessing__feature_selection__k".to_string(),
141 vec!["2".to_string(), "3".to_string(), "4".to_string()],
142 ),
143 ]);
144
145 let mut grid_search = model_selection::GridSearchCV::new(
146 quantum_pipeline, // estimator
147 param_grid,
148 3, // cv
149 );
150
151 grid_search.fit(&X_train, &y_train)?;
152
153 println!(" Best parameters: {:?}", grid_search.best_params_);
154 println!(
155 " Best cross-validation score: {:.3}",
156 grid_search.best_score_
157 );
158
159 // Step 8: Train best model and evaluate
160 println!("\n8. Training best model and evaluation...");
161
162 let best_model = grid_search.best_estimator_;
163 let y_pred = best_model.predict(&X_test)?;
164
165 // Calculate metrics using sklearn-style functions
166 let y_test_int = y_test.mapv(|x| x.round() as i32);
167 let accuracy = metrics::accuracy_score(&y_test_int, &y_pred);
168 let precision = metrics::precision_score(&y_test_int, &y_pred, "weighted"); // average
169 let recall = metrics::recall_score(&y_test_int, &y_pred, "weighted"); // average
170 let f1 = metrics::f1_score(&y_test_int, &y_pred, "weighted"); // average
171
172 println!(" Test Results:");
173 println!(" - Accuracy: {accuracy:.3}");
174 println!(" - Precision: {precision:.3}");
175 println!(" - Recall: {recall:.3}");
176 println!(" - F1-score: {f1:.3}");
177
178 // Step 9: Classification report
179 println!("\n9. Detailed classification report...");
180
181 let classification_report = metrics::classification_report(
182 &y_test_int,
183 &y_pred,
184 vec!["Class 0", "Class 1"], // target_names
185 3, // digits
186 );
187 println!("{classification_report}");
188
189 // Step 10: Feature importance analysis
190 println!("\n10. Feature importance analysis...");
191
192 if let Some(feature_importances) = best_model.feature_importances() {
193 println!(" Quantum Feature Importances:");
194 for (i, importance) in feature_importances.iter().enumerate() {
195 println!(" - Feature {i}: {importance:.4}");
196 }
197 }
198
199 // Step 11: Model comparison with classical sklearn models
200 println!("\n11. Comparing with classical sklearn models...");
201
202 let classical_models = vec![
203 (
204 "Logistic Regression",
205 Box::new(LogisticRegression::new()) as Box<dyn SklearnClassifier>,
206 ),
207 (
208 "Random Forest",
209 Box::new(RandomForestClassifier::new()) as Box<dyn SklearnClassifier>,
210 ),
211 ("SVM", Box::new(SVC::new()) as Box<dyn SklearnClassifier>),
212 ];
213
214 let mut comparison_results = Vec::new();
215
216 for (name, mut model) in classical_models {
217 model.fit(&X_train, Some(&y_train))?;
218 let y_pred_classical = model.predict(&X_test)?;
219 let classical_accuracy = metrics::accuracy_score(&y_test_int, &y_pred_classical);
220 comparison_results.push((name, classical_accuracy));
221 }
222
223 println!(" Model Comparison:");
224 println!(" - Quantum Pipeline: {accuracy:.3}");
225 for (name, classical_accuracy) in comparison_results {
226 println!(" - {name}: {classical_accuracy:.3}");
227 }
228
229 // Step 12: Clustering with quantum K-means
230 println!("\n12. Quantum clustering analysis...");
231
232 let cluster_labels = qkmeans.fit_predict(&X)?;
233 let silhouette_score = metrics::silhouette_score(&X, &cluster_labels, "euclidean"); // metric
234 let calinski_score = metrics::calinski_harabasz_score(&X, &cluster_labels);
235
236 println!(" Clustering Results:");
237 println!(" - Silhouette Score: {silhouette_score:.3}");
238 println!(" - Calinski-Harabasz Score: {calinski_score:.3}");
239 println!(
240 " - Unique clusters found: {}",
241 cluster_labels
242 .iter()
243 .collect::<std::collections::HashSet<_>>()
244 .len()
245 );
246
247 // Step 13: Model persistence (sklearn style)
248 println!("\n13. Model persistence (sklearn joblib style)...");
249
250 // Save model
251 best_model.save("quantum_sklearn_model.joblib")?;
252 println!(" - Model saved to: quantum_sklearn_model.joblib");
253
254 // Load model
255 let loaded_model = QuantumSVC::load("quantum_sklearn_model.joblib")?;
256 let test_subset = X_test.slice(s![..5, ..]).to_owned();
257 let y_pred_loaded = loaded_model.predict(&test_subset)?;
258 println!(" - Model loaded and tested on 5 samples");
259
260 // Step 14: Advanced sklearn utilities
261 println!("\n14. Advanced sklearn utilities...");
262
263 // Learning curves (commented out - function not available)
264 // let (train_sizes, train_scores, val_scores) = model_selection::learning_curve(...)?;
265 println!(" Learning Curve Analysis: (Mock results)");
266 let train_sizes = [0.1, 0.33, 0.55, 0.78, 1.0];
267 let train_scores = [0.65, 0.72, 0.78, 0.82, 0.85];
268 let val_scores = [0.62, 0.70, 0.76, 0.79, 0.81];
269
270 for (i, &size) in train_sizes.iter().enumerate() {
271 println!(
272 " - {:.0}% data: train={:.3}, val={:.3}",
273 size * 100.0,
274 train_scores[i],
275 val_scores[i]
276 );
277 }
278
279 // Validation curves (commented out - function not available)
280 // let (train_scores_val, test_scores_val) = model_selection::validation_curve(...)?;
281 println!(" Validation Curve (C parameter): (Mock results)");
282 let param_range = [0.1, 0.5, 1.0, 2.0, 5.0];
283 let train_scores_val = [0.70, 0.75, 0.80, 0.78, 0.75];
284 let test_scores_val = [0.68, 0.73, 0.78, 0.76, 0.72];
285
286 for (i, &param_value) in param_range.iter().enumerate() {
287 println!(
288 " - C={}: train={:.3}, test={:.3}",
289 param_value, train_scores_val[i], test_scores_val[i]
290 );
291 }
292
293 // Step 15: Quantum-specific sklearn extensions
294 println!("\n15. Quantum-specific sklearn extensions...");
295
296 // Quantum feature analysis
297 let quantum_feature_analysis = analyze_quantum_features(&best_model, &X_test)?;
298 println!(" Quantum Feature Analysis:");
299 println!(
300 " - Quantum advantage score: {:.3}",
301 quantum_feature_analysis.advantage_score
302 );
303 println!(
304 " - Feature entanglement: {:.3}",
305 quantum_feature_analysis.entanglement_measure
306 );
307 println!(
308 " - Circuit depth efficiency: {:.3}",
309 quantum_feature_analysis.circuit_efficiency
310 );
311
312 // Quantum model interpretation
313 let sample_row = X_test.row(0).to_owned();
314 let quantum_interpretation = interpret_quantum_model(&best_model, &sample_row)?;
315 println!(" Quantum Model Interpretation (sample 0):");
316 println!(
317 " - Quantum state fidelity: {:.3}",
318 quantum_interpretation.state_fidelity
319 );
320 println!(
321 " - Feature contributions: {:?}",
322 quantum_interpretation.feature_contributions
323 );
324
325 println!("\n=== Scikit-learn Integration Demo Complete ===");
326
327 Ok(())
328 }