1use crate::enhanced_gan::{ConditionalQGAN, WassersteinQGAN};
7use crate::error::{MLError, Result};
8use crate::keras_api::{
9 ActivationFunction, Dense, LossFunction, MetricType, OptimizerType, QuantumAnsatzType,
10 QuantumDense, Sequential,
11};
12use crate::pytorch_api::{
13 ActivationType as PyTorchActivationType, InitType, QuantumLinear, QuantumModule,
14 QuantumSequential,
15};
16use crate::qnn::{QNNLayer, QuantumNeuralNetwork};
17use crate::qsvm::{FeatureMapType, QSVMParams, QSVM};
18use crate::transfer::{PretrainedModel, QuantumTransferLearning, TransferStrategy};
19use crate::vae::{ClassicalAutoencoder, QVAE};
20use ndarray::{s, Array1, Array2, ArrayD};
21use quantrs2_circuit::prelude::*;
22use quantrs2_core::prelude::*;
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::path::Path;
26
27pub struct ModelZoo {
29 models: HashMap<String, ModelMetadata>,
31 cache: HashMap<String, Box<dyn QuantumModel>>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ModelMetadata {
38 pub name: String,
40 pub description: String,
42 pub category: ModelCategory,
44 pub input_shape: Vec<usize>,
46 pub output_shape: Vec<usize>,
48 pub num_qubits: usize,
50 pub num_parameters: usize,
52 pub dataset: String,
54 pub accuracy: Option<f64>,
56 pub size_bytes: usize,
58 pub created_date: String,
60 pub version: String,
62 pub requirements: ModelRequirements,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
68pub enum ModelCategory {
69 Classification,
71 Regression,
73 Generative,
75 Variational,
77 Kernel,
79 Transfer,
81 AnomalyDetection,
83 TimeSeries,
85 NLP,
87 Vision,
89 ReinforcementLearning,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ModelRequirements {
96 pub min_qubits: usize,
98 pub coherence_time: f64,
100 pub gate_fidelity: f64,
102 pub backends: Vec<String>,
104}
105
106pub trait QuantumModel: Send + Sync {
108 fn name(&self) -> &str;
110
111 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>>;
113
114 fn metadata(&self) -> &ModelMetadata;
116
117 fn save(&self, path: &str) -> Result<()>;
119
120 fn load(path: &str) -> Result<Box<dyn QuantumModel>>
122 where
123 Self: Sized;
124
125 fn architecture(&self) -> String;
127
128 fn training_config(&self) -> TrainingConfig;
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct TrainingConfig {
135 pub loss_function: String,
137 pub optimizer: String,
139 pub learning_rate: f64,
141 pub epochs: usize,
143 pub batch_size: usize,
145 pub validation_split: f64,
147}
148
149impl ModelZoo {
150 pub fn new() -> Self {
152 let mut zoo = Self {
153 models: HashMap::new(),
154 cache: HashMap::new(),
155 };
156
157 zoo.register_builtin_models();
159 zoo
160 }
161
162 fn register_builtin_models(&mut self) {
164 self.models.insert(
166 "mnist_qnn".to_string(),
167 ModelMetadata {
168 name: "MNIST Quantum Neural Network".to_string(),
169 description: "Pre-trained quantum neural network for MNIST digit classification"
170 .to_string(),
171 category: ModelCategory::Classification,
172 input_shape: vec![784],
173 output_shape: vec![10],
174 num_qubits: 8,
175 num_parameters: 32,
176 dataset: "MNIST".to_string(),
177 accuracy: Some(0.92),
178 size_bytes: 1024,
179 created_date: "2024-01-15".to_string(),
180 version: "1.0".to_string(),
181 requirements: ModelRequirements {
182 min_qubits: 8,
183 coherence_time: 100.0,
184 gate_fidelity: 0.99,
185 backends: vec!["statevector".to_string(), "qasm".to_string()],
186 },
187 },
188 );
189
190 self.models.insert(
192 "iris_qsvm".to_string(),
193 ModelMetadata {
194 name: "Iris Quantum SVM".to_string(),
195 description: "Pre-trained quantum SVM for Iris flower classification".to_string(),
196 category: ModelCategory::Classification,
197 input_shape: vec![4],
198 output_shape: vec![3],
199 num_qubits: 4,
200 num_parameters: 16,
201 dataset: "Iris".to_string(),
202 accuracy: Some(0.97),
203 size_bytes: 512,
204 created_date: "2024-01-20".to_string(),
205 version: "1.0".to_string(),
206 requirements: ModelRequirements {
207 min_qubits: 4,
208 coherence_time: 50.0,
209 gate_fidelity: 0.995,
210 backends: vec!["statevector".to_string()],
211 },
212 },
213 );
214
215 self.models.insert(
217 "h2_vqe".to_string(),
218 ModelMetadata {
219 name: "H2 Molecule VQE".to_string(),
220 description: "Pre-trained VQE for hydrogen molecule ground state".to_string(),
221 category: ModelCategory::Variational,
222 input_shape: vec![1], output_shape: vec![1], num_qubits: 4,
225 num_parameters: 8,
226 dataset: "H2 PES".to_string(),
227 accuracy: Some(0.999), size_bytes: 256,
229 created_date: "2024-01-25".to_string(),
230 version: "1.0".to_string(),
231 requirements: ModelRequirements {
232 min_qubits: 4,
233 coherence_time: 200.0,
234 gate_fidelity: 0.999,
235 backends: vec!["statevector".to_string()],
236 },
237 },
238 );
239
240 self.models.insert(
242 "portfolio_qaoa".to_string(),
243 ModelMetadata {
244 name: "Portfolio Optimization QAOA".to_string(),
245 description: "Pre-trained QAOA for portfolio optimization problems".to_string(),
246 category: ModelCategory::Variational,
247 input_shape: vec![100], output_shape: vec![10], num_qubits: 10,
250 num_parameters: 20,
251 dataset: "S&P 500".to_string(),
252 accuracy: None,
253 size_bytes: 2048,
254 created_date: "2024-02-01".to_string(),
255 version: "1.0".to_string(),
256 requirements: ModelRequirements {
257 min_qubits: 10,
258 coherence_time: 150.0,
259 gate_fidelity: 0.98,
260 backends: vec!["statevector".to_string(), "aer".to_string()],
261 },
262 },
263 );
264
265 self.models.insert(
267 "qae_anomaly".to_string(),
268 ModelMetadata {
269 name: "Quantum Autoencoder for Anomaly Detection".to_string(),
270 description: "Pre-trained quantum autoencoder for detecting anomalies in data"
271 .to_string(),
272 category: ModelCategory::AnomalyDetection,
273 input_shape: vec![16],
274 output_shape: vec![16],
275 num_qubits: 6,
276 num_parameters: 24,
277 dataset: "Credit Card Fraud".to_string(),
278 accuracy: Some(0.94),
279 size_bytes: 1536,
280 created_date: "2024-02-05".to_string(),
281 version: "1.0".to_string(),
282 requirements: ModelRequirements {
283 min_qubits: 6,
284 coherence_time: 120.0,
285 gate_fidelity: 0.995,
286 backends: vec!["statevector".to_string()],
287 },
288 },
289 );
290
291 self.models.insert(
293 "qts_forecaster".to_string(),
294 ModelMetadata {
295 name: "Quantum Time Series Forecaster".to_string(),
296 description: "Pre-trained quantum model for time series forecasting".to_string(),
297 category: ModelCategory::TimeSeries,
298 input_shape: vec![20], output_shape: vec![1], num_qubits: 8,
301 num_parameters: 40,
302 dataset: "Stock Prices".to_string(),
303 accuracy: Some(0.89),
304 size_bytes: 2560,
305 created_date: "2024-02-10".to_string(),
306 version: "1.0".to_string(),
307 requirements: ModelRequirements {
308 min_qubits: 8,
309 coherence_time: 100.0,
310 gate_fidelity: 0.99,
311 backends: vec!["statevector".to_string(), "mps".to_string()],
312 },
313 },
314 );
315 }
316
317 pub fn list_models(&self) -> Vec<&ModelMetadata> {
319 self.models.values().collect()
320 }
321
322 pub fn list_by_category(&self, category: &ModelCategory) -> Vec<&ModelMetadata> {
324 self.models
325 .values()
326 .filter(|meta| {
327 std::mem::discriminant(&meta.category) == std::mem::discriminant(category)
328 })
329 .collect()
330 }
331
332 pub fn search(&self, query: &str) -> Vec<&ModelMetadata> {
334 let query_lower = query.to_lowercase();
335 self.models
336 .values()
337 .filter(|meta| {
338 meta.name.to_lowercase().contains(&query_lower)
339 || meta.description.to_lowercase().contains(&query_lower)
340 })
341 .collect()
342 }
343
344 pub fn get_metadata(&self, name: &str) -> Option<&ModelMetadata> {
346 self.models.get(name)
347 }
348
349 pub fn load_model(&mut self, name: &str) -> Result<&dyn QuantumModel> {
351 if !self.cache.contains_key(name) {
352 let model = self.create_model(name)?;
353 self.cache.insert(name.to_string(), model);
354 }
355
356 Ok(self.cache.get(name).unwrap().as_ref())
357 }
358
359 fn create_model(&self, name: &str) -> Result<Box<dyn QuantumModel>> {
361 match name {
362 "mnist_qnn" => Ok(Box::new(MNISTQuantumNN::new()?)),
363 "iris_qsvm" => Ok(Box::new(IrisQuantumSVM::new()?)),
364 "h2_vqe" => Ok(Box::new(H2VQE::new()?)),
365 "portfolio_qaoa" => Ok(Box::new(PortfolioQAOA::new()?)),
366 "qae_anomaly" => Ok(Box::new(QuantumAnomalyDetector::new()?)),
367 "qts_forecaster" => Ok(Box::new(QuantumTimeSeriesForecaster::new()?)),
368 _ => Err(MLError::InvalidConfiguration(format!(
369 "Unknown model: {}",
370 name
371 ))),
372 }
373 }
374
375 pub fn register_model(&mut self, name: String, metadata: ModelMetadata) {
377 self.models.insert(name, metadata);
378 }
379
380 pub fn download_model(&mut self, name: &str, url: &str) -> Result<()> {
382 println!("Downloading model {} from {}", name, url);
384 Ok(())
385 }
386
387 pub fn recommend_models(
389 &self,
390 task_description: &str,
391 num_qubits: Option<usize>,
392 ) -> Vec<&ModelMetadata> {
393 let task_lower = task_description.to_lowercase();
394 let mut recommendations: Vec<_> = self
395 .models
396 .values()
397 .filter(|meta| {
398 if let Some(qubits) = num_qubits {
400 if meta.requirements.min_qubits > qubits {
401 return false;
402 }
403 }
404
405 task_lower.contains("classification")
407 && matches!(meta.category, ModelCategory::Classification)
408 || task_lower.contains("regression")
409 && matches!(meta.category, ModelCategory::Regression)
410 || task_lower.contains("generation")
411 && matches!(meta.category, ModelCategory::Generative)
412 || task_lower.contains("anomaly")
413 && matches!(meta.category, ModelCategory::AnomalyDetection)
414 || task_lower.contains("time series")
415 && matches!(meta.category, ModelCategory::TimeSeries)
416 || task_lower.contains("nlp") && matches!(meta.category, ModelCategory::NLP)
417 || task_lower.contains("vision")
418 && matches!(meta.category, ModelCategory::Vision)
419 })
420 .collect();
421
422 recommendations.sort_by(|a, b| match (a.accuracy, b.accuracy) {
424 (Some(acc_a), Some(acc_b)) => acc_b.partial_cmp(&acc_a).unwrap(),
425 (Some(_), None) => std::cmp::Ordering::Less,
426 (None, Some(_)) => std::cmp::Ordering::Greater,
427 (None, None) => std::cmp::Ordering::Equal,
428 });
429
430 recommendations
431 }
432
433 pub fn export_catalog(&self, path: &str) -> Result<()> {
435 let catalog: Vec<_> = self.models.values().collect();
436 let json = serde_json::to_string_pretty(&catalog)?;
437 std::fs::write(path, json)?;
438 Ok(())
439 }
440
441 pub fn import_catalog(&mut self, path: &str) -> Result<()> {
443 let json = std::fs::read_to_string(path)?;
444 let catalog: Vec<ModelMetadata> = serde_json::from_str(&json)?;
445
446 for metadata in catalog {
447 self.models.insert(metadata.name.clone(), metadata);
448 }
449
450 Ok(())
451 }
452}
453
454pub struct MNISTQuantumNN {
458 model: Sequential,
459 metadata: ModelMetadata,
460}
461
462impl MNISTQuantumNN {
463 pub fn new() -> Result<Self> {
464 let mut model = Sequential::new().name("mnist_qnn");
465
466 model.add(Box::new(
468 QuantumDense::new(8, 64)
469 .ansatz_type(QuantumAnsatzType::HardwareEfficient)
470 .num_layers(2)
471 .name("quantum_layer"),
472 ));
473
474 model.add(Box::new(
476 Dense::new(10)
477 .activation(ActivationFunction::Softmax)
478 .name("output_layer"),
479 ));
480
481 model.build(vec![784])?;
482
483 let metadata = ModelMetadata {
484 name: "MNIST Quantum Neural Network".to_string(),
485 description: "Pre-trained quantum neural network for MNIST digit classification"
486 .to_string(),
487 category: ModelCategory::Classification,
488 input_shape: vec![784],
489 output_shape: vec![10],
490 num_qubits: 8,
491 num_parameters: 32,
492 dataset: "MNIST".to_string(),
493 accuracy: Some(0.92),
494 size_bytes: 1024,
495 created_date: "2024-01-15".to_string(),
496 version: "1.0".to_string(),
497 requirements: ModelRequirements {
498 min_qubits: 8,
499 coherence_time: 100.0,
500 gate_fidelity: 0.99,
501 backends: vec!["statevector".to_string(), "qasm".to_string()],
502 },
503 };
504
505 Ok(Self { model, metadata })
506 }
507}
508
509impl QuantumModel for MNISTQuantumNN {
510 fn name(&self) -> &str {
511 &self.metadata.name
512 }
513
514 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
515 self.model.predict(input)
516 }
517
518 fn metadata(&self) -> &ModelMetadata {
519 &self.metadata
520 }
521
522 fn save(&self, path: &str) -> Result<()> {
523 std::fs::write(
525 format!("{}_metadata.json", path),
526 serde_json::to_string(&self.metadata)?,
527 )?;
528 Ok(())
529 }
530
531 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
532 Ok(Box::new(Self::new()?))
534 }
535
536 fn architecture(&self) -> String {
537 "QuantumDense(8 qubits, 64 units) -> Dense(10 units, softmax)".to_string()
538 }
539
540 fn training_config(&self) -> TrainingConfig {
541 TrainingConfig {
542 loss_function: "categorical_crossentropy".to_string(),
543 optimizer: "adam".to_string(),
544 learning_rate: 0.001,
545 epochs: 100,
546 batch_size: 32,
547 validation_split: 0.2,
548 }
549 }
550}
551
552pub struct IrisQuantumSVM {
554 model: QSVM,
555 metadata: ModelMetadata,
556}
557
558impl IrisQuantumSVM {
559 pub fn new() -> Result<Self> {
560 let params = QSVMParams {
561 feature_map: FeatureMapType::ZZFeatureMap,
562 reps: 2,
563 c: 1.0,
564 tolerance: 1e-3,
565 num_qubits: 4,
566 depth: 2,
567 gamma: None,
568 regularization: 1.0,
569 max_iterations: 100,
570 seed: None,
571 };
572
573 let model = QSVM::new(params);
574
575 let metadata = ModelMetadata {
576 name: "Iris Quantum SVM".to_string(),
577 description: "Pre-trained quantum SVM for Iris flower classification".to_string(),
578 category: ModelCategory::Classification,
579 input_shape: vec![4],
580 output_shape: vec![3],
581 num_qubits: 4,
582 num_parameters: 16,
583 dataset: "Iris".to_string(),
584 accuracy: Some(0.97),
585 size_bytes: 512,
586 created_date: "2024-01-20".to_string(),
587 version: "1.0".to_string(),
588 requirements: ModelRequirements {
589 min_qubits: 4,
590 coherence_time: 50.0,
591 gate_fidelity: 0.995,
592 backends: vec!["statevector".to_string()],
593 },
594 };
595
596 Ok(Self { model, metadata })
597 }
598}
599
600impl QuantumModel for IrisQuantumSVM {
601 fn name(&self) -> &str {
602 &self.metadata.name
603 }
604
605 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
606 let input_2d = input
608 .clone()
609 .into_dimensionality::<ndarray::Ix2>()
610 .map_err(|_| MLError::InvalidConfiguration("Input must be 2D".to_string()))?;
611
612 let predictions_i32 = self
614 .model
615 .predict(&input_2d)
616 .map_err(|e| MLError::ValidationError(e))?;
617
618 let predictions_f64 = predictions_i32.mapv(|x| x as f64);
620 Ok(predictions_f64.into_dyn())
621 }
622
623 fn metadata(&self) -> &ModelMetadata {
624 &self.metadata
625 }
626
627 fn save(&self, path: &str) -> Result<()> {
628 std::fs::write(
629 format!("{}_metadata.json", path),
630 serde_json::to_string(&self.metadata)?,
631 )?;
632 Ok(())
633 }
634
635 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
636 Ok(Box::new(Self::new()?))
637 }
638
639 fn architecture(&self) -> String {
640 "Quantum SVM with ZZ Feature Map (4 qubits, depth 2)".to_string()
641 }
642
643 fn training_config(&self) -> TrainingConfig {
644 TrainingConfig {
645 loss_function: "hinge".to_string(),
646 optimizer: "cvxpy".to_string(),
647 learning_rate: 0.01,
648 epochs: 50,
649 batch_size: 16,
650 validation_split: 0.3,
651 }
652 }
653}
654
655pub struct H2VQE {
657 metadata: ModelMetadata,
658 optimal_parameters: Array1<f64>,
659}
660
661impl H2VQE {
662 pub fn new() -> Result<Self> {
663 let metadata = ModelMetadata {
664 name: "H2 Molecule VQE".to_string(),
665 description: "Pre-trained VQE for hydrogen molecule ground state".to_string(),
666 category: ModelCategory::Variational,
667 input_shape: vec![1],
668 output_shape: vec![1],
669 num_qubits: 4,
670 num_parameters: 8,
671 dataset: "H2 PES".to_string(),
672 accuracy: Some(0.999),
673 size_bytes: 256,
674 created_date: "2024-01-25".to_string(),
675 version: "1.0".to_string(),
676 requirements: ModelRequirements {
677 min_qubits: 4,
678 coherence_time: 200.0,
679 gate_fidelity: 0.999,
680 backends: vec!["statevector".to_string()],
681 },
682 };
683
684 let optimal_parameters = Array1::from_vec(vec![
686 0.0,
687 std::f64::consts::PI,
688 0.0,
689 std::f64::consts::PI,
690 0.0,
691 0.0,
692 0.0,
693 0.0,
694 ]);
695
696 Ok(Self {
697 metadata,
698 optimal_parameters,
699 })
700 }
701}
702
703impl QuantumModel for H2VQE {
704 fn name(&self) -> &str {
705 &self.metadata.name
706 }
707
708 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
709 let bond_length = input[[0]];
711 let energy = -1.137 + 0.5 * (bond_length - 0.74).powi(2); Ok(ArrayD::from_shape_vec(vec![1], vec![energy])?)
713 }
714
715 fn metadata(&self) -> &ModelMetadata {
716 &self.metadata
717 }
718
719 fn save(&self, path: &str) -> Result<()> {
720 std::fs::write(
721 format!("{}_metadata.json", path),
722 serde_json::to_string(&self.metadata)?,
723 )?;
724 Ok(())
725 }
726
727 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
728 Ok(Box::new(Self::new()?))
729 }
730
731 fn architecture(&self) -> String {
732 "VQE with UCCSD ansatz (4 qubits, 8 parameters)".to_string()
733 }
734
735 fn training_config(&self) -> TrainingConfig {
736 TrainingConfig {
737 loss_function: "energy_expectation".to_string(),
738 optimizer: "cobyla".to_string(),
739 learning_rate: 0.1,
740 epochs: 200,
741 batch_size: 1,
742 validation_split: 0.0,
743 }
744 }
745}
746
747pub struct PortfolioQAOA {
749 metadata: ModelMetadata,
750}
751
752impl PortfolioQAOA {
753 pub fn new() -> Result<Self> {
754 let metadata = ModelMetadata {
755 name: "Portfolio Optimization QAOA".to_string(),
756 description: "Pre-trained QAOA for portfolio optimization problems".to_string(),
757 category: ModelCategory::Variational,
758 input_shape: vec![100],
759 output_shape: vec![10],
760 num_qubits: 10,
761 num_parameters: 20,
762 dataset: "S&P 500".to_string(),
763 accuracy: None,
764 size_bytes: 2048,
765 created_date: "2024-02-01".to_string(),
766 version: "1.0".to_string(),
767 requirements: ModelRequirements {
768 min_qubits: 10,
769 coherence_time: 150.0,
770 gate_fidelity: 0.98,
771 backends: vec!["statevector".to_string(), "aer".to_string()],
772 },
773 };
774
775 Ok(Self { metadata })
776 }
777}
778
779impl QuantumModel for PortfolioQAOA {
780 fn name(&self) -> &str {
781 &self.metadata.name
782 }
783
784 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
785 let returns = input.slice(s![..10]);
787 let weights = returns.mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
788 let normalized_weights = &weights / weights.sum();
789 Ok(normalized_weights.to_owned().into_dyn())
790 }
791
792 fn metadata(&self) -> &ModelMetadata {
793 &self.metadata
794 }
795
796 fn save(&self, path: &str) -> Result<()> {
797 std::fs::write(
798 format!("{}_metadata.json", path),
799 serde_json::to_string(&self.metadata)?,
800 )?;
801 Ok(())
802 }
803
804 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
805 Ok(Box::new(Self::new()?))
806 }
807
808 fn architecture(&self) -> String {
809 "QAOA with p=5 layers (10 qubits, 20 parameters)".to_string()
810 }
811
812 fn training_config(&self) -> TrainingConfig {
813 TrainingConfig {
814 loss_function: "portfolio_variance".to_string(),
815 optimizer: "cobyla".to_string(),
816 learning_rate: 0.05,
817 epochs: 150,
818 batch_size: 1,
819 validation_split: 0.0,
820 }
821 }
822}
823
824pub struct QuantumAnomalyDetector {
826 metadata: ModelMetadata,
827}
828
829impl QuantumAnomalyDetector {
830 pub fn new() -> Result<Self> {
831 let metadata = ModelMetadata {
832 name: "Quantum Autoencoder for Anomaly Detection".to_string(),
833 description: "Pre-trained quantum autoencoder for detecting anomalies in data"
834 .to_string(),
835 category: ModelCategory::AnomalyDetection,
836 input_shape: vec![16],
837 output_shape: vec![16],
838 num_qubits: 6,
839 num_parameters: 24,
840 dataset: "Credit Card Fraud".to_string(),
841 accuracy: Some(0.94),
842 size_bytes: 1536,
843 created_date: "2024-02-05".to_string(),
844 version: "1.0".to_string(),
845 requirements: ModelRequirements {
846 min_qubits: 6,
847 coherence_time: 120.0,
848 gate_fidelity: 0.995,
849 backends: vec!["statevector".to_string()],
850 },
851 };
852
853 Ok(Self { metadata })
854 }
855}
856
857impl QuantumModel for QuantumAnomalyDetector {
858 fn name(&self) -> &str {
859 &self.metadata.name
860 }
861
862 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
863 let reconstruction = input * 0.95; Ok(reconstruction)
866 }
867
868 fn metadata(&self) -> &ModelMetadata {
869 &self.metadata
870 }
871
872 fn save(&self, path: &str) -> Result<()> {
873 std::fs::write(
874 format!("{}_metadata.json", path),
875 serde_json::to_string(&self.metadata)?,
876 )?;
877 Ok(())
878 }
879
880 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
881 Ok(Box::new(Self::new()?))
882 }
883
884 fn architecture(&self) -> String {
885 "Quantum Autoencoder: Encoder(16->4) + Decoder(4->16) with 6 qubits".to_string()
886 }
887
888 fn training_config(&self) -> TrainingConfig {
889 TrainingConfig {
890 loss_function: "reconstruction_error".to_string(),
891 optimizer: "adam".to_string(),
892 learning_rate: 0.001,
893 epochs: 80,
894 batch_size: 64,
895 validation_split: 0.2,
896 }
897 }
898}
899
900pub struct QuantumTimeSeriesForecaster {
902 metadata: ModelMetadata,
903}
904
905impl QuantumTimeSeriesForecaster {
906 pub fn new() -> Result<Self> {
907 let metadata = ModelMetadata {
908 name: "Quantum Time Series Forecaster".to_string(),
909 description: "Pre-trained quantum model for time series forecasting".to_string(),
910 category: ModelCategory::TimeSeries,
911 input_shape: vec![20],
912 output_shape: vec![1],
913 num_qubits: 8,
914 num_parameters: 40,
915 dataset: "Stock Prices".to_string(),
916 accuracy: Some(0.89),
917 size_bytes: 2560,
918 created_date: "2024-02-10".to_string(),
919 version: "1.0".to_string(),
920 requirements: ModelRequirements {
921 min_qubits: 8,
922 coherence_time: 100.0,
923 gate_fidelity: 0.99,
924 backends: vec!["statevector".to_string(), "mps".to_string()],
925 },
926 };
927
928 Ok(Self { metadata })
929 }
930}
931
932impl QuantumModel for QuantumTimeSeriesForecaster {
933 fn name(&self) -> &str {
934 &self.metadata.name
935 }
936
937 fn predict(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>> {
938 let window = input.slice(s![..20]);
940 let trend = (window[19] - window[0]) / 19.0;
941 let prediction = window[19] + trend;
942 Ok(ArrayD::from_shape_vec(vec![1], vec![prediction])?)
943 }
944
945 fn metadata(&self) -> &ModelMetadata {
946 &self.metadata
947 }
948
949 fn save(&self, path: &str) -> Result<()> {
950 std::fs::write(
951 format!("{}_metadata.json", path),
952 serde_json::to_string(&self.metadata)?,
953 )?;
954 Ok(())
955 }
956
957 fn load(path: &str) -> Result<Box<dyn QuantumModel>> {
958 Ok(Box::new(Self::new()?))
959 }
960
961 fn architecture(&self) -> String {
962 "Quantum LSTM: QuantumRNN(8 qubits, 40 params) + Dense(1)".to_string()
963 }
964
965 fn training_config(&self) -> TrainingConfig {
966 TrainingConfig {
967 loss_function: "mean_squared_error".to_string(),
968 optimizer: "adam".to_string(),
969 learning_rate: 0.001,
970 epochs: 120,
971 batch_size: 16,
972 validation_split: 0.2,
973 }
974 }
975}
976
977pub mod utils {
979 use super::*;
980
981 pub fn get_default_zoo() -> ModelZoo {
983 ModelZoo::new()
984 }
985
986 pub fn print_model_info(metadata: &ModelMetadata) {
988 println!("Model: {}", metadata.name);
989 println!("Description: {}", metadata.description);
990 println!("Category: {:?}", metadata.category);
991 println!("Input Shape: {:?}", metadata.input_shape);
992 println!("Output Shape: {:?}", metadata.output_shape);
993 println!("Qubits: {}", metadata.num_qubits);
994 println!("Parameters: {}", metadata.num_parameters);
995 println!("Dataset: {}", metadata.dataset);
996 if let Some(acc) = metadata.accuracy {
997 println!("Accuracy: {:.2}%", acc * 100.0);
998 }
999 println!("Size: {} bytes", metadata.size_bytes);
1000 println!("Version: {}", metadata.version);
1001 println!("Requirements:");
1002 println!(" Min Qubits: {}", metadata.requirements.min_qubits);
1003 println!(
1004 " Coherence Time: {:.1} μs",
1005 metadata.requirements.coherence_time
1006 );
1007 println!(
1008 " Gate Fidelity: {:.3}",
1009 metadata.requirements.gate_fidelity
1010 );
1011 println!(" Backends: {:?}", metadata.requirements.backends);
1012 println!();
1013 }
1014
1015 pub fn compare_models(model1: &ModelMetadata, model2: &ModelMetadata) -> std::cmp::Ordering {
1017 match (model1.accuracy, model2.accuracy) {
1019 (Some(acc1), Some(acc2)) => acc2.partial_cmp(&acc1).unwrap(),
1020 (Some(_), None) => std::cmp::Ordering::Less,
1021 (None, Some(_)) => std::cmp::Ordering::Greater,
1022 (None, None) => model1.num_parameters.cmp(&model2.num_parameters),
1023 }
1024 }
1025
1026 pub fn check_device_compatibility(
1028 metadata: &ModelMetadata,
1029 device_qubits: usize,
1030 device_coherence: f64,
1031 device_fidelity: f64,
1032 ) -> bool {
1033 metadata.requirements.min_qubits <= device_qubits
1034 && metadata.requirements.coherence_time <= device_coherence
1035 && metadata.requirements.gate_fidelity <= device_fidelity
1036 }
1037
1038 pub fn benchmark_model_zoo(zoo: &ModelZoo) -> String {
1040 let mut report = String::new();
1041 report.push_str("Model Zoo Benchmark Report\n");
1042 report.push_str("==========================\n\n");
1043
1044 let models = zoo.list_models();
1045 report.push_str(&format!("Total Models: {}\n", models.len()));
1046
1047 let mut category_counts = HashMap::new();
1049 for model in &models {
1050 *category_counts.entry(&model.category).or_insert(0) += 1;
1051 }
1052
1053 report.push_str("\nModels by Category:\n");
1054 for (category, count) in category_counts {
1055 report.push_str(&format!(" {:?}: {}\n", category, count));
1056 }
1057
1058 let min_qubits: Vec<_> = models.iter().map(|m| m.requirements.min_qubits).collect();
1060 let avg_qubits = min_qubits.iter().sum::<usize>() as f64 / min_qubits.len() as f64;
1061 let max_qubits = *min_qubits.iter().max().unwrap();
1062
1063 report.push_str(&format!("\nQubit Requirements:\n"));
1064 report.push_str(&format!(" Average: {:.1}\n", avg_qubits));
1065 report.push_str(&format!(" Maximum: {}\n", max_qubits));
1066
1067 let sizes: Vec<_> = models.iter().map(|m| m.size_bytes).collect();
1069 let total_size = sizes.iter().sum::<usize>();
1070 report.push_str(&format!(
1071 "\nTotal Size: {} bytes ({:.1} KB)\n",
1072 total_size,
1073 total_size as f64 / 1024.0
1074 ));
1075
1076 report
1077 }
1078}
1079
#[cfg(test)]
mod tests {
    use super::*;

    /// A new zoo comes with the built-in models registered.
    #[test]
    fn test_model_zoo_creation() {
        let zoo = ModelZoo::new();
        assert!(!zoo.list_models().is_empty());
    }

    /// Searching is case-insensitive and finds the MNIST model.
    #[test]
    fn test_model_search() {
        let zoo = ModelZoo::new();
        let results = zoo.search("mnist");
        assert!(!results.is_empty());
        assert!(results[0].name.to_lowercase().contains("mnist"));
    }

    /// Category filtering returns only models of the requested category.
    #[test]
    fn test_category_filtering() {
        let zoo = ModelZoo::new();
        let classification_models = zoo.list_by_category(&ModelCategory::Classification);
        assert!(!classification_models.is_empty());

        for model in classification_models {
            assert!(matches!(model.category, ModelCategory::Classification));
        }
    }

    /// Recommendations respect the qubit budget.
    #[test]
    fn test_model_recommendations() {
        let zoo = ModelZoo::new();
        let recommendations = zoo.recommend_models("classification task", Some(8));
        assert!(!recommendations.is_empty());

        for model in recommendations {
            assert!(model.requirements.min_qubits <= 8);
        }
    }

    /// Metadata lookup by identifier returns the expected entry.
    #[test]
    fn test_model_metadata() {
        let zoo = ModelZoo::new();
        let metadata = zoo.get_metadata("mnist_qnn");
        assert!(metadata.is_some());

        let meta = metadata.unwrap();
        assert_eq!(meta.name, "MNIST Quantum Neural Network");
        assert_eq!(meta.num_qubits, 8);
    }

    /// Device compatibility depends on the number of available qubits.
    #[test]
    fn test_device_compatibility() {
        let zoo = ModelZoo::new();
        let metadata = zoo.get_metadata("mnist_qnn").unwrap();

        // Enough qubits, coherence time and fidelity.
        assert!(utils::check_device_compatibility(
            metadata, 10, 150.0, 0.995
        ));

        // Too few qubits.
        assert!(!utils::check_device_compatibility(
            metadata, 4, 150.0, 0.995
        ));
    }

    /// The MNIST model can be instantiated directly.
    #[test]
    fn test_model_instantiation() {
        let mnist_model = MNISTQuantumNN::new();
        assert!(mnist_model.is_ok());

        let model = mnist_model.unwrap();
        assert_eq!(model.name(), "MNIST Quantum Neural Network");
        assert_eq!(model.metadata().num_qubits, 8);
    }

    /// A catalog survives an export/import round trip.
    #[test]
    fn test_catalog_export_import() {
        let zoo = ModelZoo::new();

        // Use the platform temp directory and a per-process file name instead
        // of a hard-coded `/tmp` path, so the test is portable and does not
        // collide with other test processes.
        let catalog_path = std::env::temp_dir().join(format!(
            "model_zoo_test_catalog_{}.json",
            std::process::id()
        ));
        let catalog_path_str = catalog_path
            .to_str()
            .expect("temporary directory path should be valid UTF-8");

        let export_result = zoo.export_catalog(catalog_path_str);
        assert!(export_result.is_ok());

        let mut new_zoo = ModelZoo::new();
        // Start from an empty registry so only imported entries are counted.
        new_zoo.models.clear();
        let import_result = new_zoo.import_catalog(catalog_path_str);

        // Best-effort cleanup before asserting, so a failure leaves no file behind.
        let _ = std::fs::remove_file(&catalog_path);

        assert!(import_result.is_ok());
        assert!(!new_zoo.list_models().is_empty());
        // Built-in display names are unique, so no entry is lost on import.
        assert_eq!(new_zoo.list_models().len(), zoo.list_models().len());
    }
}
1174}