1use crate::enhanced_gan::{ConditionalQGAN, WassersteinQGAN};
7use crate::error::{MLError, Result};
8use crate::keras_api::{
9 ActivationFunction, Dense, LossFunction, MetricType, OptimizerType, QuantumAnsatzType,
10 QuantumDense, Sequential,
11};
12use crate::pytorch_api::{
13 ActivationType as PyTorchActivationType, InitType, QuantumLinear, QuantumModule,
14 QuantumSequential,
15};
16use crate::qnn::{QNNLayer, QuantumNeuralNetwork};
17use crate::qsvm::{FeatureMapType, QSVMParams, QSVM};
18use crate::transfer::{PretrainedModel, QuantumTransferLearning, TransferStrategy};
19use crate::vae::{ClassicalAutoencoder, QVAE};
20use quantrs2_circuit::prelude::*;
21use quantrs2_core::prelude::*;
22use scirs2_core::ndarray::{s, Array1, Array2, ArrayD};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::path::Path;
26
/// Registry of pre-trained quantum models.
///
/// Holds the metadata of every known model plus a lazily-filled cache of
/// model instances created by `ModelZoo::load_model`.
pub struct ModelZoo {
    // Registry key (e.g. "mnist_qnn") -> metadata for every known model.
    models: HashMap<String, ModelMetadata>,
    // Registry key -> model instance, inserted on first `load_model` call.
    cache: HashMap<String, Box<dyn QuantumModel>>,
}
34
/// Descriptive information about a model in the zoo.
///
/// Serialized with serde when the catalog is exported/imported as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Human-readable display name (searched by `ModelZoo::search`).
    pub name: String,
    /// Free-text description (also searched by `ModelZoo::search`).
    pub description: String,
    /// Task family the model belongs to.
    pub category: ModelCategory,
    /// Expected input shape (e.g. `[784]` for the MNIST model).
    pub input_shape: Vec<usize>,
    /// Produced output shape.
    pub output_shape: Vec<usize>,
    /// Number of qubits the model uses.
    pub num_qubits: usize,
    /// Number of trainable parameters as recorded in the catalog.
    pub num_parameters: usize,
    /// Name of the dataset the model was trained/evaluated on.
    pub dataset: String,
    /// Reported accuracy as a fraction (printed as a percentage); `None` if not applicable.
    pub accuracy: Option<f64>,
    /// Size of the model in bytes.
    pub size_bytes: usize,
    /// Creation date string (built-in entries use `YYYY-MM-DD`).
    pub created_date: String,
    /// Model version string.
    pub version: String,
    /// Hardware/backend requirements for running the model.
    pub requirements: ModelRequirements,
}
65
66#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
68pub enum ModelCategory {
69 Classification,
71 Regression,
73 Generative,
75 Variational,
77 Kernel,
79 Transfer,
81 AnomalyDetection,
83 TimeSeries,
85 NLP,
87 Vision,
89 ReinforcementLearning,
91}
92
/// Minimum hardware characteristics and supported backends for a model.
///
/// Checked against a device by `utils::check_device_compatibility`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRequirements {
    /// Minimum number of qubits the device must provide.
    pub min_qubits: usize,
    /// Minimum coherence time the device must provide (printed in μs).
    pub coherence_time: f64,
    /// Minimum gate fidelity the device must provide.
    pub gate_fidelity: f64,
    /// Names of simulator/backends the model is listed as supporting.
    pub backends: Vec<String>,
}
105
/// Common interface for every model the [`ModelZoo`] can instantiate.
///
/// Models are stored as `Box<dyn QuantumModel>` in the zoo's cache; the trait
/// requires `Send + Sync`.
pub trait QuantumModel: Send + Sync {
    /// Display name of the model.
    fn name(&self) -> &str;

    /// Runs inference on `input` and returns the model output.
    fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>>;

    /// Returns the model's catalog metadata.
    fn metadata(&self) -> &ModelMetadata;

    /// Persists the model using `path` as a file-name prefix.
    fn save(&self, path: &str) -> Result<()>;

    /// Loads a model from `path`.
    ///
    /// The `Self: Sized` bound removes this constructor from the trait-object
    /// vtable, which keeps `dyn QuantumModel` usable.
    fn load(path: &str) -> Result<Box<dyn QuantumModel>>
    where
        Self: Sized;

    /// Human-readable description of the model architecture.
    fn architecture(&self) -> String;

    /// Training hyper-parameters recorded for this model.
    fn training_config(&self) -> TrainingConfig;
}
131
/// Training hyper-parameters reported by `QuantumModel::training_config`.
///
/// NOTE(review): in this file these values are descriptive only; nothing here
/// consumes them to drive training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Name of the loss function (e.g. "categorical_crossentropy").
    pub loss_function: String,
    /// Name of the optimizer (e.g. "adam", "cobyla").
    pub optimizer: String,
    /// Optimizer learning rate.
    pub learning_rate: f64,
    /// Number of training epochs.
    pub epochs: usize,
    /// Mini-batch size.
    pub batch_size: usize,
    /// Fraction of data held out for validation.
    pub validation_split: f64,
}
148
149impl ModelZoo {
150 pub fn new() -> Self {
152 let mut zoo = Self {
153 models: HashMap::new(),
154 cache: HashMap::new(),
155 };
156
157 zoo.register_builtin_models();
159 zoo
160 }
161
162 fn register_builtin_models(&mut self) {
164 self.models.insert(
166 "mnist_qnn".to_string(),
167 ModelMetadata {
168 name: "MNIST Quantum Neural Network".to_string(),
169 description: "Pre-trained quantum neural network for MNIST digit classification"
170 .to_string(),
171 category: ModelCategory::Classification,
172 input_shape: vec![784],
173 output_shape: vec![10],
174 num_qubits: 8,
175 num_parameters: 32,
176 dataset: "MNIST".to_string(),
177 accuracy: Some(0.92),
178 size_bytes: 1024,
179 created_date: "2024-01-15".to_string(),
180 version: "1.0".to_string(),
181 requirements: ModelRequirements {
182 min_qubits: 8,
183 coherence_time: 100.0,
184 gate_fidelity: 0.99,
185 backends: vec!["statevector".to_string(), "qasm".to_string()],
186 },
187 },
188 );
189
190 self.models.insert(
192 "iris_qsvm".to_string(),
193 ModelMetadata {
194 name: "Iris Quantum SVM".to_string(),
195 description: "Pre-trained quantum SVM for Iris flower classification".to_string(),
196 category: ModelCategory::Classification,
197 input_shape: vec![4],
198 output_shape: vec![3],
199 num_qubits: 4,
200 num_parameters: 16,
201 dataset: "Iris".to_string(),
202 accuracy: Some(0.97),
203 size_bytes: 512,
204 created_date: "2024-01-20".to_string(),
205 version: "1.0".to_string(),
206 requirements: ModelRequirements {
207 min_qubits: 4,
208 coherence_time: 50.0,
209 gate_fidelity: 0.995,
210 backends: vec!["statevector".to_string()],
211 },
212 },
213 );
214
215 self.models.insert(
217 "h2_vqe".to_string(),
218 ModelMetadata {
219 name: "H2 Molecule VQE".to_string(),
220 description: "Pre-trained VQE for hydrogen molecule ground state".to_string(),
221 category: ModelCategory::Variational,
222 input_shape: vec![1], output_shape: vec![1], num_qubits: 4,
225 num_parameters: 8,
226 dataset: "H2 PES".to_string(),
227 accuracy: Some(0.999), size_bytes: 256,
229 created_date: "2024-01-25".to_string(),
230 version: "1.0".to_string(),
231 requirements: ModelRequirements {
232 min_qubits: 4,
233 coherence_time: 200.0,
234 gate_fidelity: 0.999,
235 backends: vec!["statevector".to_string()],
236 },
237 },
238 );
239
240 self.models.insert(
242 "portfolio_qaoa".to_string(),
243 ModelMetadata {
244 name: "Portfolio Optimization QAOA".to_string(),
245 description: "Pre-trained QAOA for portfolio optimization problems".to_string(),
246 category: ModelCategory::Variational,
247 input_shape: vec![100], output_shape: vec![10], num_qubits: 10,
250 num_parameters: 20,
251 dataset: "S&P 500".to_string(),
252 accuracy: None,
253 size_bytes: 2048,
254 created_date: "2024-02-01".to_string(),
255 version: "1.0".to_string(),
256 requirements: ModelRequirements {
257 min_qubits: 10,
258 coherence_time: 150.0,
259 gate_fidelity: 0.98,
260 backends: vec!["statevector".to_string(), "aer".to_string()],
261 },
262 },
263 );
264
265 self.models.insert(
267 "qae_anomaly".to_string(),
268 ModelMetadata {
269 name: "Quantum Autoencoder for Anomaly Detection".to_string(),
270 description: "Pre-trained quantum autoencoder for detecting anomalies in data"
271 .to_string(),
272 category: ModelCategory::AnomalyDetection,
273 input_shape: vec![16],
274 output_shape: vec![16],
275 num_qubits: 6,
276 num_parameters: 24,
277 dataset: "Credit Card Fraud".to_string(),
278 accuracy: Some(0.94),
279 size_bytes: 1536,
280 created_date: "2024-02-05".to_string(),
281 version: "1.0".to_string(),
282 requirements: ModelRequirements {
283 min_qubits: 6,
284 coherence_time: 120.0,
285 gate_fidelity: 0.995,
286 backends: vec!["statevector".to_string()],
287 },
288 },
289 );
290
291 self.models.insert(
293 "qts_forecaster".to_string(),
294 ModelMetadata {
295 name: "Quantum Time Series Forecaster".to_string(),
296 description: "Pre-trained quantum model for time series forecasting".to_string(),
297 category: ModelCategory::TimeSeries,
298 input_shape: vec![20], output_shape: vec![1], num_qubits: 8,
301 num_parameters: 40,
302 dataset: "Stock Prices".to_string(),
303 accuracy: Some(0.89),
304 size_bytes: 2560,
305 created_date: "2024-02-10".to_string(),
306 version: "1.0".to_string(),
307 requirements: ModelRequirements {
308 min_qubits: 8,
309 coherence_time: 100.0,
310 gate_fidelity: 0.99,
311 backends: vec!["statevector".to_string(), "mps".to_string()],
312 },
313 },
314 );
315 }
316
317 pub fn list_models(&self) -> Vec<&ModelMetadata> {
319 self.models.values().collect()
320 }
321
322 pub fn list_by_category(&self, category: &ModelCategory) -> Vec<&ModelMetadata> {
324 self.models
325 .values()
326 .filter(|meta| {
327 std::mem::discriminant(&meta.category) == std::mem::discriminant(category)
328 })
329 .collect()
330 }
331
332 pub fn search(&self, query: &str) -> Vec<&ModelMetadata> {
334 let query_lower = query.to_lowercase();
335 self.models
336 .values()
337 .filter(|meta| {
338 meta.name.to_lowercase().contains(&query_lower)
339 || meta.description.to_lowercase().contains(&query_lower)
340 })
341 .collect()
342 }
343
344 pub fn get_metadata(&self, name: &str) -> Option<&ModelMetadata> {
346 self.models.get(name)
347 }
348
349 pub fn load_model(&mut self, name: &str) -> Result<&dyn QuantumModel> {
351 if !self.cache.contains_key(name) {
352 let model = self.create_model(name)?;
353 self.cache.insert(name.to_string(), model);
354 }
355
356 Ok(self
357 .cache
358 .get(name)
359 .expect("Model was just inserted into cache")
360 .as_ref())
361 }
362
363 fn create_model(&self, name: &str) -> Result<Box<dyn QuantumModel>> {
365 match name {
366 "mnist_qnn" => Ok(Box::new(MNISTQuantumNN::new()?)),
367 "iris_qsvm" => Ok(Box::new(IrisQuantumSVM::new()?)),
368 "h2_vqe" => Ok(Box::new(H2VQE::new()?)),
369 "portfolio_qaoa" => Ok(Box::new(PortfolioQAOA::new()?)),
370 "qae_anomaly" => Ok(Box::new(QuantumAnomalyDetector::new()?)),
371 "qts_forecaster" => Ok(Box::new(QuantumTimeSeriesForecaster::new()?)),
372 _ => Err(MLError::InvalidConfiguration(format!(
373 "Unknown model: {}",
374 name
375 ))),
376 }
377 }
378
379 pub fn register_model(&mut self, name: String, metadata: ModelMetadata) {
381 self.models.insert(name, metadata);
382 }
383
384 pub fn download_model(&mut self, name: &str, url: &str) -> Result<()> {
386 println!("Downloading model {} from {}", name, url);
388 Ok(())
389 }
390
391 pub fn recommend_models(
393 &self,
394 task_description: &str,
395 num_qubits: Option<usize>,
396 ) -> Vec<&ModelMetadata> {
397 let task_lower = task_description.to_lowercase();
398 let mut recommendations: Vec<_> = self
399 .models
400 .values()
401 .filter(|meta| {
402 if let Some(qubits) = num_qubits {
404 if meta.requirements.min_qubits > qubits {
405 return false;
406 }
407 }
408
409 task_lower.contains("classification")
411 && matches!(meta.category, ModelCategory::Classification)
412 || task_lower.contains("regression")
413 && matches!(meta.category, ModelCategory::Regression)
414 || task_lower.contains("generation")
415 && matches!(meta.category, ModelCategory::Generative)
416 || task_lower.contains("anomaly")
417 && matches!(meta.category, ModelCategory::AnomalyDetection)
418 || task_lower.contains("time series")
419 && matches!(meta.category, ModelCategory::TimeSeries)
420 || task_lower.contains("nlp") && matches!(meta.category, ModelCategory::NLP)
421 || task_lower.contains("vision")
422 && matches!(meta.category, ModelCategory::Vision)
423 })
424 .collect();
425
426 recommendations.sort_by(|a, b| match (a.accuracy, b.accuracy) {
428 (Some(acc_a), Some(acc_b)) => acc_b
429 .partial_cmp(&acc_a)
430 .unwrap_or(std::cmp::Ordering::Equal),
431 (Some(_), None) => std::cmp::Ordering::Less,
432 (None, Some(_)) => std::cmp::Ordering::Greater,
433 (None, None) => std::cmp::Ordering::Equal,
434 });
435
436 recommendations
437 }
438
439 pub fn export_catalog(&self, path: &str) -> Result<()> {
441 let catalog: Vec<_> = self.models.values().collect();
442 let json = serde_json::to_string_pretty(&catalog)?;
443 std::fs::write(path, json)?;
444 Ok(())
445 }
446
447 pub fn import_catalog(&mut self, path: &str) -> Result<()> {
449 let json = std::fs::read_to_string(path)?;
450 let catalog: Vec<ModelMetadata> = serde_json::from_str(&json)?;
451
452 for metadata in catalog {
453 self.models.insert(metadata.name.clone(), metadata);
454 }
455
456 Ok(())
457 }
458}
459
/// MNIST digit classifier: a quantum dense layer followed by a softmax dense layer.
pub struct MNISTQuantumNN {
    // Keras-style sequential model built in `MNISTQuantumNN::new`.
    model: Sequential,
    // Catalog metadata returned by `QuantumModel::metadata`.
    metadata: ModelMetadata,
}
467
468impl MNISTQuantumNN {
469 pub fn new() -> Result<Self> {
470 let mut model = Sequential::new().name("mnist_qnn");
471
472 model.add(Box::new(
474 QuantumDense::new(8, 64)
475 .ansatz_type(QuantumAnsatzType::HardwareEfficient)
476 .num_layers(2)
477 .name("quantum_layer"),
478 ));
479
480 model.add(Box::new(
482 Dense::new(10)
483 .activation(ActivationFunction::Softmax)
484 .name("output_layer"),
485 ));
486
487 model.build(vec![784])?;
488
489 let metadata = ModelMetadata {
490 name: "MNIST Quantum Neural Network".to_string(),
491 description: "Pre-trained quantum neural network for MNIST digit classification"
492 .to_string(),
493 category: ModelCategory::Classification,
494 input_shape: vec![784],
495 output_shape: vec![10],
496 num_qubits: 8,
497 num_parameters: 32,
498 dataset: "MNIST".to_string(),
499 accuracy: Some(0.92),
500 size_bytes: 1024,
501 created_date: "2024-01-15".to_string(),
502 version: "1.0".to_string(),
503 requirements: ModelRequirements {
504 min_qubits: 8,
505 coherence_time: 100.0,
506 gate_fidelity: 0.99,
507 backends: vec!["statevector".to_string(), "qasm".to_string()],
508 },
509 };
510
511 Ok(Self { model, metadata })
512 }
513}
514
515impl QuantumModel for MNISTQuantumNN {
516 fn name(&self) -> &str {
517 &self.metadata.name
518 }
519
520 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
521 self.model.predict(input)
522 }
523
524 fn metadata(&self) -> &ModelMetadata {
525 &self.metadata
526 }
527
528 fn save(&self, path: &str) -> Result<()> {
529 std::fs::write(
531 format!("{}_metadata.json", path),
532 serde_json::to_string(&self.metadata)?,
533 )?;
534 Ok(())
535 }
536
537 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
538 Ok(Box::new(Self::new()?))
540 }
541
542 fn architecture(&self) -> String {
543 "QuantumDense(8 qubits, 64 units) -> Dense(10 units, softmax)".to_string()
544 }
545
546 fn training_config(&self) -> TrainingConfig {
547 TrainingConfig {
548 loss_function: "categorical_crossentropy".to_string(),
549 optimizer: "adam".to_string(),
550 learning_rate: 0.001,
551 epochs: 100,
552 batch_size: 32,
553 validation_split: 0.2,
554 }
555 }
556}
557
/// Iris flower classifier backed by a quantum SVM.
pub struct IrisQuantumSVM {
    // QSVM configured in `IrisQuantumSVM::new` (ZZ feature map, 4 qubits).
    model: QSVM,
    // Catalog metadata returned by `QuantumModel::metadata`.
    metadata: ModelMetadata,
}
563
564impl IrisQuantumSVM {
565 pub fn new() -> Result<Self> {
566 let params = QSVMParams {
567 feature_map: FeatureMapType::ZZFeatureMap,
568 reps: 2,
569 c: 1.0,
570 tolerance: 1e-3,
571 num_qubits: 4,
572 depth: 2,
573 gamma: None,
574 regularization: 1.0,
575 max_iterations: 100,
576 seed: None,
577 };
578
579 let model = QSVM::new(params);
580
581 let metadata = ModelMetadata {
582 name: "Iris Quantum SVM".to_string(),
583 description: "Pre-trained quantum SVM for Iris flower classification".to_string(),
584 category: ModelCategory::Classification,
585 input_shape: vec![4],
586 output_shape: vec![3],
587 num_qubits: 4,
588 num_parameters: 16,
589 dataset: "Iris".to_string(),
590 accuracy: Some(0.97),
591 size_bytes: 512,
592 created_date: "2024-01-20".to_string(),
593 version: "1.0".to_string(),
594 requirements: ModelRequirements {
595 min_qubits: 4,
596 coherence_time: 50.0,
597 gate_fidelity: 0.995,
598 backends: vec!["statevector".to_string()],
599 },
600 };
601
602 Ok(Self { model, metadata })
603 }
604}
605
606impl QuantumModel for IrisQuantumSVM {
607 fn name(&self) -> &str {
608 &self.metadata.name
609 }
610
611 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
612 let input_2d = input
614 .clone()
615 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
616 .map_err(|_| MLError::InvalidConfiguration("Input must be 2D".to_string()))?;
617
618 let predictions_i32 = self
620 .model
621 .predict(&input_2d)
622 .map_err(|e| MLError::ValidationError(e))?;
623
624 let predictions_f64 = predictions_i32.mapv(|x| x as f64);
626 Ok(predictions_f64.into_dyn())
627 }
628
629 fn metadata(&self) -> &ModelMetadata {
630 &self.metadata
631 }
632
633 fn save(&self, path: &str) -> Result<()> {
634 std::fs::write(
635 format!("{}_metadata.json", path),
636 serde_json::to_string(&self.metadata)?,
637 )?;
638 Ok(())
639 }
640
641 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
642 Ok(Box::new(Self::new()?))
643 }
644
645 fn architecture(&self) -> String {
646 "Quantum SVM with ZZ Feature Map (4 qubits, depth 2)".to_string()
647 }
648
649 fn training_config(&self) -> TrainingConfig {
650 TrainingConfig {
651 loss_function: "hinge".to_string(),
652 optimizer: "cvxpy".to_string(),
653 learning_rate: 0.01,
654 epochs: 50,
655 batch_size: 16,
656 validation_split: 0.3,
657 }
658 }
659}
660
/// VQE model for the hydrogen molecule ground-state energy.
pub struct H2VQE {
    // Catalog metadata returned by `QuantumModel::metadata`.
    metadata: ModelMetadata,
    // Stored variational angles (8 values set in `H2VQE::new`).
    // NOTE(review): not read by `predict` in this file.
    optimal_parameters: Array1<f64>,
}
666
667impl H2VQE {
668 pub fn new() -> Result<Self> {
669 let metadata = ModelMetadata {
670 name: "H2 Molecule VQE".to_string(),
671 description: "Pre-trained VQE for hydrogen molecule ground state".to_string(),
672 category: ModelCategory::Variational,
673 input_shape: vec![1],
674 output_shape: vec![1],
675 num_qubits: 4,
676 num_parameters: 8,
677 dataset: "H2 PES".to_string(),
678 accuracy: Some(0.999),
679 size_bytes: 256,
680 created_date: "2024-01-25".to_string(),
681 version: "1.0".to_string(),
682 requirements: ModelRequirements {
683 min_qubits: 4,
684 coherence_time: 200.0,
685 gate_fidelity: 0.999,
686 backends: vec!["statevector".to_string()],
687 },
688 };
689
690 let optimal_parameters = Array1::from_vec(vec![
692 0.0,
693 std::f64::consts::PI,
694 0.0,
695 std::f64::consts::PI,
696 0.0,
697 0.0,
698 0.0,
699 0.0,
700 ]);
701
702 Ok(Self {
703 metadata,
704 optimal_parameters,
705 })
706 }
707}
708
709impl QuantumModel for H2VQE {
710 fn name(&self) -> &str {
711 &self.metadata.name
712 }
713
714 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
715 let bond_length = input[[0]];
717 let energy = -1.137 + 0.5 * (bond_length - 0.74).powi(2); Ok(ArrayD::from_shape_vec(vec![1], vec![energy])?)
719 }
720
721 fn metadata(&self) -> &ModelMetadata {
722 &self.metadata
723 }
724
725 fn save(&self, path: &str) -> Result<()> {
726 std::fs::write(
727 format!("{}_metadata.json", path),
728 serde_json::to_string(&self.metadata)?,
729 )?;
730 Ok(())
731 }
732
733 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
734 Ok(Box::new(Self::new()?))
735 }
736
737 fn architecture(&self) -> String {
738 "VQE with UCCSD ansatz (4 qubits, 8 parameters)".to_string()
739 }
740
741 fn training_config(&self) -> TrainingConfig {
742 TrainingConfig {
743 loss_function: "energy_expectation".to_string(),
744 optimizer: "cobyla".to_string(),
745 learning_rate: 0.1,
746 epochs: 200,
747 batch_size: 1,
748 validation_split: 0.0,
749 }
750 }
751}
752
/// QAOA-style portfolio selector over 10 assets.
pub struct PortfolioQAOA {
    // Catalog metadata returned by `QuantumModel::metadata`.
    metadata: ModelMetadata,
}
757
758impl PortfolioQAOA {
759 pub fn new() -> Result<Self> {
760 let metadata = ModelMetadata {
761 name: "Portfolio Optimization QAOA".to_string(),
762 description: "Pre-trained QAOA for portfolio optimization problems".to_string(),
763 category: ModelCategory::Variational,
764 input_shape: vec![100],
765 output_shape: vec![10],
766 num_qubits: 10,
767 num_parameters: 20,
768 dataset: "S&P 500".to_string(),
769 accuracy: None,
770 size_bytes: 2048,
771 created_date: "2024-02-01".to_string(),
772 version: "1.0".to_string(),
773 requirements: ModelRequirements {
774 min_qubits: 10,
775 coherence_time: 150.0,
776 gate_fidelity: 0.98,
777 backends: vec!["statevector".to_string(), "aer".to_string()],
778 },
779 };
780
781 Ok(Self { metadata })
782 }
783}
784
785impl QuantumModel for PortfolioQAOA {
786 fn name(&self) -> &str {
787 &self.metadata.name
788 }
789
790 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
791 let returns = input.slice(s![..10]);
793 let weights = returns.mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
794 let normalized_weights = &weights / weights.sum();
795 Ok(normalized_weights.to_owned().into_dyn())
796 }
797
798 fn metadata(&self) -> &ModelMetadata {
799 &self.metadata
800 }
801
802 fn save(&self, path: &str) -> Result<()> {
803 std::fs::write(
804 format!("{}_metadata.json", path),
805 serde_json::to_string(&self.metadata)?,
806 )?;
807 Ok(())
808 }
809
810 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
811 Ok(Box::new(Self::new()?))
812 }
813
814 fn architecture(&self) -> String {
815 "QAOA with p=5 layers (10 qubits, 20 parameters)".to_string()
816 }
817
818 fn training_config(&self) -> TrainingConfig {
819 TrainingConfig {
820 loss_function: "portfolio_variance".to_string(),
821 optimizer: "cobyla".to_string(),
822 learning_rate: 0.05,
823 epochs: 150,
824 batch_size: 1,
825 validation_split: 0.0,
826 }
827 }
828}
829
/// Quantum-autoencoder anomaly detector.
pub struct QuantumAnomalyDetector {
    // Catalog metadata returned by `QuantumModel::metadata`.
    metadata: ModelMetadata,
}
834
835impl QuantumAnomalyDetector {
836 pub fn new() -> Result<Self> {
837 let metadata = ModelMetadata {
838 name: "Quantum Autoencoder for Anomaly Detection".to_string(),
839 description: "Pre-trained quantum autoencoder for detecting anomalies in data"
840 .to_string(),
841 category: ModelCategory::AnomalyDetection,
842 input_shape: vec![16],
843 output_shape: vec![16],
844 num_qubits: 6,
845 num_parameters: 24,
846 dataset: "Credit Card Fraud".to_string(),
847 accuracy: Some(0.94),
848 size_bytes: 1536,
849 created_date: "2024-02-05".to_string(),
850 version: "1.0".to_string(),
851 requirements: ModelRequirements {
852 min_qubits: 6,
853 coherence_time: 120.0,
854 gate_fidelity: 0.995,
855 backends: vec!["statevector".to_string()],
856 },
857 };
858
859 Ok(Self { metadata })
860 }
861}
862
863impl QuantumModel for QuantumAnomalyDetector {
864 fn name(&self) -> &str {
865 &self.metadata.name
866 }
867
868 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
869 let reconstruction = input * 0.95; Ok(reconstruction)
872 }
873
874 fn metadata(&self) -> &ModelMetadata {
875 &self.metadata
876 }
877
878 fn save(&self, path: &str) -> Result<()> {
879 std::fs::write(
880 format!("{}_metadata.json", path),
881 serde_json::to_string(&self.metadata)?,
882 )?;
883 Ok(())
884 }
885
886 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
887 Ok(Box::new(Self::new()?))
888 }
889
890 fn architecture(&self) -> String {
891 "Quantum Autoencoder: Encoder(16->4) + Decoder(4->16) with 6 qubits".to_string()
892 }
893
894 fn training_config(&self) -> TrainingConfig {
895 TrainingConfig {
896 loss_function: "reconstruction_error".to_string(),
897 optimizer: "adam".to_string(),
898 learning_rate: 0.001,
899 epochs: 80,
900 batch_size: 64,
901 validation_split: 0.2,
902 }
903 }
904}
905
/// One-step-ahead time-series forecaster over a 20-value window.
pub struct QuantumTimeSeriesForecaster {
    // Catalog metadata returned by `QuantumModel::metadata`.
    metadata: ModelMetadata,
}
910
911impl QuantumTimeSeriesForecaster {
912 pub fn new() -> Result<Self> {
913 let metadata = ModelMetadata {
914 name: "Quantum Time Series Forecaster".to_string(),
915 description: "Pre-trained quantum model for time series forecasting".to_string(),
916 category: ModelCategory::TimeSeries,
917 input_shape: vec![20],
918 output_shape: vec![1],
919 num_qubits: 8,
920 num_parameters: 40,
921 dataset: "Stock Prices".to_string(),
922 accuracy: Some(0.89),
923 size_bytes: 2560,
924 created_date: "2024-02-10".to_string(),
925 version: "1.0".to_string(),
926 requirements: ModelRequirements {
927 min_qubits: 8,
928 coherence_time: 100.0,
929 gate_fidelity: 0.99,
930 backends: vec!["statevector".to_string(), "mps".to_string()],
931 },
932 };
933
934 Ok(Self { metadata })
935 }
936}
937
938impl QuantumModel for QuantumTimeSeriesForecaster {
939 fn name(&self) -> &str {
940 &self.metadata.name
941 }
942
943 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
944 let window = input.slice(s![..20]);
946 let trend = (window[19] - window[0]) / 19.0;
947 let prediction = window[19] + trend;
948 Ok(ArrayD::from_shape_vec(vec![1], vec![prediction])?)
949 }
950
951 fn metadata(&self) -> &ModelMetadata {
952 &self.metadata
953 }
954
955 fn save(&self, path: &str) -> Result<()> {
956 std::fs::write(
957 format!("{}_metadata.json", path),
958 serde_json::to_string(&self.metadata)?,
959 )?;
960 Ok(())
961 }
962
963 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
964 Ok(Box::new(Self::new()?))
965 }
966
967 fn architecture(&self) -> String {
968 "Quantum LSTM: QuantumRNN(8 qubits, 40 params) + Dense(1)".to_string()
969 }
970
971 fn training_config(&self) -> TrainingConfig {
972 TrainingConfig {
973 loss_function: "mean_squared_error".to_string(),
974 optimizer: "adam".to_string(),
975 learning_rate: 0.001,
976 epochs: 120,
977 batch_size: 16,
978 validation_split: 0.2,
979 }
980 }
981}
982
983pub mod utils {
985 use super::*;
986
987 pub fn get_default_zoo() -> ModelZoo {
989 ModelZoo::new()
990 }
991
992 pub fn print_model_info(metadata: &ModelMetadata) {
994 println!("Model: {}", metadata.name);
995 println!("Description: {}", metadata.description);
996 println!("Category: {:?}", metadata.category);
997 println!("Input Shape: {:?}", metadata.input_shape);
998 println!("Output Shape: {:?}", metadata.output_shape);
999 println!("Qubits: {}", metadata.num_qubits);
1000 println!("Parameters: {}", metadata.num_parameters);
1001 println!("Dataset: {}", metadata.dataset);
1002 if let Some(acc) = metadata.accuracy {
1003 println!("Accuracy: {:.2}%", acc * 100.0);
1004 }
1005 println!("Size: {} bytes", metadata.size_bytes);
1006 println!("Version: {}", metadata.version);
1007 println!("Requirements:");
1008 println!(" Min Qubits: {}", metadata.requirements.min_qubits);
1009 println!(
1010 " Coherence Time: {:.1} μs",
1011 metadata.requirements.coherence_time
1012 );
1013 println!(
1014 " Gate Fidelity: {:.3}",
1015 metadata.requirements.gate_fidelity
1016 );
1017 println!(" Backends: {:?}", metadata.requirements.backends);
1018 println!();
1019 }
1020
1021 pub fn compare_models(model1: &ModelMetadata, model2: &ModelMetadata) -> std::cmp::Ordering {
1023 match (model1.accuracy, model2.accuracy) {
1025 (Some(acc1), Some(acc2)) => {
1026 acc2.partial_cmp(&acc1).unwrap_or(std::cmp::Ordering::Equal)
1027 }
1028 (Some(_), None) => std::cmp::Ordering::Less,
1029 (None, Some(_)) => std::cmp::Ordering::Greater,
1030 (None, None) => model1.num_parameters.cmp(&model2.num_parameters),
1031 }
1032 }
1033
1034 pub fn check_device_compatibility(
1036 metadata: &ModelMetadata,
1037 device_qubits: usize,
1038 device_coherence: f64,
1039 device_fidelity: f64,
1040 ) -> bool {
1041 metadata.requirements.min_qubits <= device_qubits
1042 && metadata.requirements.coherence_time <= device_coherence
1043 && metadata.requirements.gate_fidelity <= device_fidelity
1044 }
1045
1046 pub fn benchmark_model_zoo(zoo: &ModelZoo) -> String {
1048 let mut report = String::new();
1049 report.push_str("Model Zoo Benchmark Report\n");
1050 report.push_str("==========================\n\n");
1051
1052 let models = zoo.list_models();
1053 report.push_str(&format!("Total Models: {}\n", models.len()));
1054
1055 let mut category_counts = HashMap::new();
1057 for model in &models {
1058 *category_counts.entry(&model.category).or_insert(0) += 1;
1059 }
1060
1061 report.push_str("\nModels by Category:\n");
1062 for (category, count) in category_counts {
1063 report.push_str(&format!(" {:?}: {}\n", category, count));
1064 }
1065
1066 let min_qubits: Vec<_> = models.iter().map(|m| m.requirements.min_qubits).collect();
1068 let avg_qubits = if min_qubits.is_empty() {
1069 0.0
1070 } else {
1071 min_qubits.iter().sum::<usize>() as f64 / min_qubits.len() as f64
1072 };
1073 let max_qubits = min_qubits.iter().max().copied().unwrap_or(0);
1074
1075 report.push_str(&format!("\nQubit Requirements:\n"));
1076 report.push_str(&format!(" Average: {:.1}\n", avg_qubits));
1077 report.push_str(&format!(" Maximum: {}\n", max_qubits));
1078
1079 let sizes: Vec<_> = models.iter().map(|m| m.size_bytes).collect();
1081 let total_size = sizes.iter().sum::<usize>();
1082 report.push_str(&format!(
1083 "\nTotal Size: {} bytes ({:.1} KB)\n",
1084 total_size,
1085 total_size as f64 / 1024.0
1086 ));
1087
1088 report
1089 }
1090}
1091
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_zoo_creation() {
        let zoo = ModelZoo::new();
        assert!(!zoo.list_models().is_empty());
    }

    #[test]
    fn test_model_search() {
        let zoo = ModelZoo::new();
        let results = zoo.search("mnist");
        assert!(!results.is_empty());
        assert!(results[0].name.to_lowercase().contains("mnist"));
    }

    #[test]
    fn test_category_filtering() {
        let zoo = ModelZoo::new();
        let classification_models = zoo.list_by_category(&ModelCategory::Classification);
        assert!(!classification_models.is_empty());

        for model in classification_models {
            assert!(matches!(model.category, ModelCategory::Classification));
        }
    }

    #[test]
    fn test_model_recommendations() {
        let zoo = ModelZoo::new();
        let recommendations = zoo.recommend_models("classification task", Some(8));
        assert!(!recommendations.is_empty());

        for model in recommendations {
            assert!(model.requirements.min_qubits <= 8);
        }
    }

    #[test]
    fn test_model_metadata() {
        let zoo = ModelZoo::new();
        let metadata = zoo.get_metadata("mnist_qnn");
        assert!(metadata.is_some());

        let meta = metadata.expect("mnist_qnn metadata should exist");
        assert_eq!(meta.name, "MNIST Quantum Neural Network");
        assert_eq!(meta.num_qubits, 8);
    }

    #[test]
    fn test_device_compatibility() {
        let zoo = ModelZoo::new();
        let metadata = zoo
            .get_metadata("mnist_qnn")
            .expect("mnist_qnn metadata should exist");

        assert!(utils::check_device_compatibility(
            metadata, 10, 150.0, 0.995
        ));

        assert!(!utils::check_device_compatibility(
            metadata, 4, 150.0, 0.995
        ));
    }

    #[test]
    fn test_model_instantiation() {
        let mnist_model = MNISTQuantumNN::new();
        assert!(mnist_model.is_ok());

        let model = mnist_model.expect("MNISTQuantumNN creation should succeed");
        assert_eq!(model.name(), "MNIST Quantum Neural Network");
        assert_eq!(model.metadata().num_qubits, 8);
    }

    #[test]
    fn test_catalog_export_import() {
        let zoo = ModelZoo::new();

        // Portable, per-process temp file so concurrent runs don't collide.
        let path = std::env::temp_dir().join(format!(
            "quantrs2_model_zoo_catalog_{}.json",
            std::process::id()
        ));
        let path_str = path.to_str().expect("temp path should be valid UTF-8");

        let export_result = zoo.export_catalog(path_str);
        assert!(export_result.is_ok());

        let mut new_zoo = ModelZoo::new();
        new_zoo.models.clear();
        let import_result = new_zoo.import_catalog(path_str);
        // Clean up before asserting so a failure doesn't leave the file behind.
        let _ = std::fs::remove_file(&path);
        assert!(import_result.is_ok());

        assert!(!new_zoo.list_models().is_empty());
        assert_eq!(new_zoo.list_models().len(), zoo.list_models().len());
    }
}