1use crate::{CircuitExecutor, CircuitResult, DeviceError, DeviceResult, QuantumDevice};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::RwLock;
13
14pub mod classical_integration;
15pub mod gradients;
16pub mod hardware_acceleration;
17pub mod inference;
18pub mod optimization;
19pub mod quantum_neural_networks;
20pub mod training;
21pub mod variational_algorithms;
22
23pub use classical_integration::*;
24pub use gradients::*;
25pub use hardware_acceleration::*;
26pub use inference::*;
27pub use optimization::*;
28pub use quantum_neural_networks::*;
29pub use training::*;
30pub use variational_algorithms::*;
31
/// Entry point for quantum machine learning workloads on one quantum device.
///
/// Bundles the shared device handle, the active [`QMLConfig`], the per-epoch training
/// history, a [`ModelRegistry`] of trained/imported models, and the hardware
/// acceleration manager.
pub struct QMLAccelerator {
    /// Shared device handle; read-locked for availability and property queries and
    /// cloned into trainers, inference engines, optimizers and benchmark engines.
    pub device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
    /// Configuration handed to trainers, inference/benchmark engines and the hardware manager.
    pub config: QMLConfig,
    /// Epoch records passed mutably to the trainer by `train_model`
    /// (presumably appended per epoch — confirm in `QuantumTrainer::train`).
    pub training_history: Vec<TrainingEpoch>,
    /// Models registered after `train_model` or `import_model`, keyed by model id.
    pub model_registry: ModelRegistry,
    /// Initialized by `connect`, shut down by `disconnect`.
    pub hardware_manager: HardwareAccelerationManager,
    /// Set by a successful `connect`, cleared by `disconnect`; required by
    /// `train_model` and `inference`.
    pub is_connected: bool,
}
47
/// Settings handed to trainers, inference engines, benchmark engines and the
/// hardware acceleration manager.
///
/// See the `Default` impl for baseline values; `create_vqc_accelerator` and
/// `create_qaoa_accelerator` build presets on top of it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QMLConfig {
    /// Maximum number of qubits (default `20`; the presets set it from their size argument).
    pub max_qubits: usize,
    /// Optimizer family (default [`OptimizerType::Adam`]).
    pub optimizer: OptimizerType,
    /// Optimizer step size (default `0.01`).
    pub learning_rate: f64,
    /// Maximum number of training epochs (default `1000`).
    pub max_epochs: usize,
    /// Convergence tolerance (default `1e-6`).
    /// NOTE(review): the quantity it is compared against is not visible here — confirm
    /// whether it applies to the loss change or the gradient norm.
    pub convergence_tolerance: f64,
    /// Training batch size (default `32`).
    pub batch_size: usize,
    /// Hardware-acceleration switch (default `true`); reported as
    /// `hardware_acceleration_enabled` by `QMLAccelerator::get_diagnostics`.
    pub enable_hardware_acceleration: bool,
    /// Gradient estimation method (default [`GradientMethod::ParameterShift`]).
    pub gradient_method: GradientMethod,
    /// Noise-resilience level (default [`NoiseResilienceLevel::Medium`]).
    pub noise_resilience: NoiseResilienceLevel,
    /// Maximum circuit depth (default `100`; the QAOA preset lowers it to `50`).
    pub max_circuit_depth: usize,
    /// Parameter update frequency (default `10`).
    /// NOTE(review): the unit (epochs, batches, iterations?) is not visible here — confirm.
    pub parameter_update_frequency: usize,
}
74
75impl Default for QMLConfig {
76 fn default() -> Self {
77 Self {
78 max_qubits: 20,
79 optimizer: OptimizerType::Adam,
80 learning_rate: 0.01,
81 max_epochs: 1000,
82 convergence_tolerance: 1e-6,
83 batch_size: 32,
84 enable_hardware_acceleration: true,
85 gradient_method: GradientMethod::ParameterShift,
86 noise_resilience: NoiseResilienceLevel::Medium,
87 max_circuit_depth: 100,
88 parameter_update_frequency: 10,
89 }
90 }
91}
92
/// Optimizer family selected through [`QMLConfig::optimizer`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizerType {
    /// Plain gradient descent.
    GradientDescent,
    /// Adam (adaptive moment estimation); the value used by `QMLConfig::default`.
    Adam,
    /// AdaGrad (per-parameter adaptive step sizes).
    AdaGrad,
    /// RMSprop.
    RMSprop,
    /// SPSA (simultaneous perturbation stochastic approximation).
    SPSA,
    /// Quantum natural gradient.
    QuantumNaturalGradient,
    /// Nelder–Mead simplex search.
    NelderMead,
    /// COBYLA (constrained optimization by linear approximation); selected by
    /// `create_qaoa_accelerator`.
    COBYLA,
}
113
/// Gradient estimation strategy selected through [`QMLConfig::gradient_method`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GradientMethod {
    /// Parameter-shift rule; the value used by `QMLConfig::default`.
    ParameterShift,
    /// Finite-difference approximation; selected by `create_qaoa_accelerator`.
    FiniteDifference,
    /// Linear-combination based estimator.
    /// NOTE(review): the exact scheme is defined outside this module — confirm.
    LinearCombination,
    /// Quantum natural gradient.
    QuantumNaturalGradient,
    /// Adjoint differentiation.
    Adjoint,
}
128
/// Noise-resilience setting carried in [`QMLConfig::noise_resilience`]
/// (`Medium` in `QMLConfig::default`).
///
/// NOTE(review): what each level changes is not visible in this module — confirm
/// where the value is consumed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NoiseResilienceLevel {
    Low,
    Medium,
    High,
    Adaptive,
}
137
/// Record of one training epoch, as kept in `QMLAccelerator::training_history`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingEpoch {
    /// Epoch index.
    pub epoch: usize,
    /// Loss recorded for this epoch; aggregated by `TrainingStatistics::from_history`.
    pub loss: f64,
    /// Accuracy for this epoch, if measured.
    pub accuracy: Option<f64>,
    /// Parameter values recorded with this epoch.
    pub parameters: Vec<f64>,
    /// Gradient norm recorded for this epoch.
    pub gradient_norm: f64,
    /// Learning rate in effect for this epoch.
    pub learning_rate: f64,
    /// Total time for the epoch; summed and averaged by `TrainingStatistics::from_history`
    /// and by the accelerator's average-training-time diagnostic.
    pub execution_time: Duration,
    /// Fidelity estimate, if available.
    pub quantum_fidelity: Option<f64>,
    /// Time spent on classical preprocessing.
    /// NOTE(review): the test fixtures have this plus `quantum_execution_time` equal to
    /// `execution_time` — confirm whether that is an invariant.
    pub classical_preprocessing_time: Duration,
    /// Time spent executing on the quantum device.
    pub quantum_execution_time: Duration,
}
152
/// Kind of QML model to train, run or benchmark.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum QMLModelType {
    /// Variational quantum classifier.
    VQC,
    /// Quantum neural network.
    QNN,
    /// Quantum approximate optimization algorithm.
    QAOA,
    /// Variational quantum eigensolver.
    VQE,
    /// Quantum generative adversarial network.
    QGAN,
    /// Quantum convolutional neural network.
    QCNN,
    /// Hybrid classical/quantum network.
    HybridNetwork,
}
171
172impl QMLAccelerator {
173 pub fn new(
175 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
176 config: QMLConfig,
177 ) -> DeviceResult<Self> {
178 let model_registry = ModelRegistry::new();
179 let hardware_manager = HardwareAccelerationManager::new(&config)?;
180
181 Ok(Self {
182 device,
183 config,
184 training_history: Vec::new(),
185 model_registry,
186 hardware_manager,
187 is_connected: false,
188 })
189 }
190
191 pub async fn connect(&mut self) -> DeviceResult<()> {
193 let device = self.device.read().await;
194 if !device.is_available().await? {
195 return Err(DeviceError::DeviceNotInitialized(
196 "Quantum device not available".to_string(),
197 ));
198 }
199
200 self.hardware_manager.initialize().await?;
201 self.is_connected = true;
202 Ok(())
203 }
204
205 pub async fn disconnect(&mut self) -> DeviceResult<()> {
207 self.hardware_manager.shutdown().await?;
208 self.is_connected = false;
209 Ok(())
210 }
211
212 pub async fn train_model(
214 &mut self,
215 model_type: QMLModelType,
216 training_data: training::TrainingData,
217 validation_data: Option<training::TrainingData>,
218 ) -> DeviceResult<training::TrainingResult> {
219 if !self.is_connected {
220 return Err(DeviceError::DeviceNotInitialized(
221 "QML accelerator not connected".to_string(),
222 ));
223 }
224
225 let mut trainer = QuantumTrainer::new(self.device.clone(), &self.config, model_type)?;
226
227 let result = trainer
228 .train(training_data, validation_data, &mut self.training_history)
229 .await?;
230
231 self.model_registry
233 .register_model(result.model_id.clone(), result.model.clone())?;
234
235 Ok(result)
236 }
237
238 pub async fn inference(
240 &self,
241 model_id: &str,
242 input_data: InferenceData,
243 ) -> DeviceResult<InferenceResult> {
244 if !self.is_connected {
245 return Err(DeviceError::DeviceNotInitialized(
246 "QML accelerator not connected".to_string(),
247 ));
248 }
249
250 let model = self.model_registry.get_model(model_id)?;
251 let inference_engine = QuantumInferenceEngine::new(self.device.clone(), &self.config)?;
252
253 inference_engine.inference(model, input_data).await
254 }
255
256 pub async fn optimize_parameters(
258 &mut self,
259 initial_parameters: Vec<f64>,
260 objective_function: Box<dyn ObjectiveFunction + Send + Sync>,
261 ) -> DeviceResult<OptimizationResult> {
262 let mut optimizer =
263 create_gradient_optimizer(self.device.clone(), OptimizerType::Adam, 0.01);
264
265 optimizer.optimize(initial_parameters, objective_function)
266 }
267
268 pub async fn compute_gradients(
270 &self,
271 circuit: ParameterizedQuantumCircuit,
272 parameters: Vec<f64>,
273 ) -> DeviceResult<Vec<f64>> {
274 let gradient_calculator =
275 QuantumGradientCalculator::new(self.device.clone(), GradientConfig::default())?;
276
277 gradient_calculator
278 .compute_gradients(circuit, parameters)
279 .await
280 }
281
282 pub fn get_training_statistics(&self) -> TrainingStatistics {
284 TrainingStatistics::from_history(&self.training_history)
285 }
286
287 pub async fn export_model(
289 &self,
290 model_id: &str,
291 format: ModelExportFormat,
292 ) -> DeviceResult<Vec<u8>> {
293 let model = self.model_registry.get_model(model_id)?;
294 model.export(format).await
295 }
296
297 pub async fn import_model(
299 &mut self,
300 model_data: Vec<u8>,
301 format: ModelExportFormat,
302 ) -> DeviceResult<String> {
303 let model = QMLModel::import(model_data, format).await?;
304 let model_id = format!("imported_model_{}", uuid::Uuid::new_v4());
305
306 self.model_registry
307 .register_model(model_id.clone(), model)?;
308 Ok(model_id)
309 }
310
311 pub async fn get_acceleration_metrics(&self) -> HardwareAccelerationMetrics {
313 self.hardware_manager.get_metrics().await
314 }
315
316 pub async fn benchmark_performance(
318 &self,
319 model_type: QMLModelType,
320 problem_size: usize,
321 ) -> DeviceResult<PerformanceBenchmark> {
322 let benchmark_engine = PerformanceBenchmarkEngine::new(self.device.clone(), &self.config)?;
323
324 benchmark_engine.benchmark(model_type, problem_size).await
325 }
326
327 pub async fn get_diagnostics(&self) -> QMLDiagnostics {
329 let device = self.device.read().await;
330 let device_props = device.properties().await.unwrap_or_default();
331
332 QMLDiagnostics {
333 is_connected: self.is_connected,
334 total_models: self.model_registry.model_count(),
335 training_epochs_completed: self.training_history.len(),
336 hardware_acceleration_enabled: self.config.enable_hardware_acceleration,
337 active_model_count: self.model_registry.active_model_count(),
338 average_training_time: self.calculate_average_training_time(),
339 quantum_advantage_ratio: self.hardware_manager.get_quantum_advantage_ratio().await,
340 device_properties: device_props,
341 }
342 }
343
344 fn calculate_average_training_time(&self) -> Duration {
345 if self.training_history.is_empty() {
346 return Duration::from_secs(0);
347 }
348
349 let total_time: Duration = self
350 .training_history
351 .iter()
352 .map(|epoch| epoch.execution_time)
353 .sum();
354
355 total_time / self.training_history.len() as u32
356 }
357}
358
/// Input for `QMLAccelerator::inference`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceData {
    /// Feature vector fed to the model.
    pub features: Vec<f64>,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
365
/// Output of `QMLAccelerator::inference`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResult {
    /// Model prediction.
    pub prediction: f64,
    /// Confidence score, if the inference engine provides one.
    pub confidence: Option<f64>,
    /// Fidelity estimate for the quantum execution, if available.
    pub quantum_fidelity: Option<f64>,
    /// Time taken by the inference.
    pub execution_time: Duration,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
375
/// A trained or imported QML model; serializable via `QMLModel::export` and
/// restorable via `QMLModel::import`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QMLModel {
    /// Kind of model.
    pub model_type: QMLModelType,
    /// Model parameter values.
    pub parameters: Vec<f64>,
    /// Structural summary of the model's circuit.
    pub circuit_structure: CircuitStructure,
    /// Free-form training metadata.
    pub training_metadata: HashMap<String, String>,
    /// Named performance metrics.
    pub performance_metrics: HashMap<String, f64>,
}
385
386impl QMLModel {
387 pub async fn export(&self, format: ModelExportFormat) -> DeviceResult<Vec<u8>> {
388 match format {
389 ModelExportFormat::JSON => serde_json::to_vec(self)
390 .map_err(|e| DeviceError::InvalidInput(format!("JSON export error: {e}"))),
391 ModelExportFormat::Binary => {
392 bincode::serde::encode_to_vec(self, bincode::config::standard())
393 .map_err(|e| DeviceError::InvalidInput(format!("Binary export error: {e:?}")))
394 }
395 ModelExportFormat::ONNX => {
396 Err(DeviceError::InvalidInput(
398 "ONNX export not yet implemented".to_string(),
399 ))
400 }
401 }
402 }
403
404 pub async fn import(data: Vec<u8>, format: ModelExportFormat) -> DeviceResult<Self> {
405 match format {
406 ModelExportFormat::JSON => serde_json::from_slice(&data)
407 .map_err(|e| DeviceError::InvalidInput(format!("JSON import error: {e}"))),
408 ModelExportFormat::Binary => {
409 bincode::serde::decode_from_slice(&data, bincode::config::standard())
410 .map(|(v, _consumed)| v)
411 .map_err(|e| DeviceError::InvalidInput(format!("Binary import error: {e:?}")))
412 }
413 ModelExportFormat::ONNX => Err(DeviceError::InvalidInput(
414 "ONNX import not yet implemented".to_string(),
415 )),
416 }
417 }
418}
419
/// Serialization format for `QMLModel::export` / `QMLModel::import`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelExportFormat {
    /// JSON via `serde_json`.
    JSON,
    /// Binary via `bincode` with its standard configuration.
    Binary,
    /// ONNX; export and import currently return an "not yet implemented" error.
    ONNX,
}
427
/// Structural summary of a model's circuit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitStructure {
    /// Number of qubits.
    pub num_qubits: usize,
    /// Circuit depth.
    pub depth: usize,
    /// Names of the gate types used (the tests use e.g. `"RY"` and `"CNOT"`).
    pub gate_types: Vec<String>,
    /// Number of circuit parameters.
    pub parameter_count: usize,
    /// Number of entangling gates.
    pub entangling_gates: usize,
}
437
/// Summary of a training history, built by `TrainingStatistics::from_history`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStatistics {
    /// Number of epochs in the history.
    pub total_epochs: usize,
    /// Loss of the last epoch (`0.0` for an empty history).
    pub final_loss: f64,
    /// Minimum loss over all epochs (`f64::INFINITY` for an empty history).
    pub best_loss: f64,
    /// Arithmetic mean of the epoch losses (`0.0` for an empty history).
    pub average_loss: f64,
    /// Convergence epoch; `from_history` currently always leaves this `None`.
    pub convergence_epoch: Option<usize>,
    /// Sum of the epoch execution times.
    pub total_training_time: Duration,
    /// Mean epoch execution time.
    pub average_epoch_time: Duration,
}
449
450impl TrainingStatistics {
451 pub fn from_history(history: &[TrainingEpoch]) -> Self {
452 if history.is_empty() {
453 return Self {
454 total_epochs: 0,
455 final_loss: 0.0,
456 best_loss: f64::INFINITY,
457 average_loss: 0.0,
458 convergence_epoch: None,
459 total_training_time: Duration::from_secs(0),
460 average_epoch_time: Duration::from_secs(0),
461 };
462 }
463
464 let total_epochs = history.len();
465 let final_loss = history
467 .last()
468 .expect("history should not be empty after early return check")
469 .loss;
470 let best_loss = history.iter().map(|e| e.loss).fold(f64::INFINITY, f64::min);
471 let average_loss = history.iter().map(|e| e.loss).sum::<f64>() / total_epochs as f64;
472 let total_training_time = history.iter().map(|e| e.execution_time).sum();
473 let average_epoch_time = total_training_time / total_epochs as u32;
474
475 Self {
476 total_epochs,
477 final_loss,
478 best_loss,
479 average_loss,
480 convergence_epoch: None, total_training_time,
482 average_epoch_time,
483 }
484 }
485}
486
/// Snapshot returned by `QMLAccelerator::get_diagnostics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QMLDiagnostics {
    /// The accelerator's `is_connected` flag.
    pub is_connected: bool,
    /// Number of models in the registry, active or not.
    pub total_models: usize,
    /// Number of epoch records in the training history.
    pub training_epochs_completed: usize,
    /// Copy of `QMLConfig::enable_hardware_acceleration`.
    pub hardware_acceleration_enabled: bool,
    /// Number of registry models currently flagged active.
    pub active_model_count: usize,
    /// Mean epoch execution time over the training history (zero if empty).
    pub average_training_time: Duration,
    /// Ratio reported by the hardware acceleration manager.
    pub quantum_advantage_ratio: f64,
    /// Device properties, or an empty map if the property query failed.
    pub device_properties: HashMap<String, String>,
}
499
500pub struct ModelRegistry {
502 models: HashMap<String, QMLModel>,
503 active_models: HashMap<String, bool>,
504}
505
506impl Default for ModelRegistry {
507 fn default() -> Self {
508 Self::new()
509 }
510}
511
512impl ModelRegistry {
513 pub fn new() -> Self {
514 Self {
515 models: HashMap::new(),
516 active_models: HashMap::new(),
517 }
518 }
519
520 pub fn register_model(&mut self, id: String, model: QMLModel) -> DeviceResult<()> {
521 self.models.insert(id.clone(), model);
522 self.active_models.insert(id, true);
523 Ok(())
524 }
525
526 pub fn get_model(&self, id: &str) -> DeviceResult<&QMLModel> {
527 self.models
528 .get(id)
529 .ok_or_else(|| DeviceError::InvalidInput(format!("Model {id} not found")))
530 }
531
532 pub fn model_count(&self) -> usize {
533 self.models.len()
534 }
535
536 pub fn active_model_count(&self) -> usize {
537 self.active_models
538 .values()
539 .filter(|&&active| active)
540 .count()
541 }
542
543 pub fn deactivate_model(&mut self, id: &str) -> DeviceResult<()> {
544 if self.active_models.contains_key(id) {
545 self.active_models.insert(id.to_string(), false);
546 Ok(())
547 } else {
548 Err(DeviceError::InvalidInput(format!("Model {id} not found")))
549 }
550 }
551}
552
553pub fn create_vqc_accelerator(
555 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
556 num_qubits: usize,
557) -> DeviceResult<QMLAccelerator> {
558 let config = QMLConfig {
559 max_qubits: num_qubits,
560 optimizer: OptimizerType::Adam,
561 gradient_method: GradientMethod::ParameterShift,
562 ..Default::default()
563 };
564
565 QMLAccelerator::new(device, config)
566}
567
568pub fn create_qaoa_accelerator(
570 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
571 problem_size: usize,
572) -> DeviceResult<QMLAccelerator> {
573 let config = QMLConfig {
574 max_qubits: problem_size,
575 optimizer: OptimizerType::COBYLA,
576 gradient_method: GradientMethod::FiniteDifference,
577 max_circuit_depth: 50,
578 ..Default::default()
579 };
580
581 QMLAccelerator::new(device, config)
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;

    #[tokio::test]
    async fn test_qml_accelerator_creation() {
        let device = create_mock_quantum_device();
        let accelerator = QMLAccelerator::new(device, QMLConfig::default())
            .expect("QML accelerator creation should succeed with mock device");

        assert_eq!(accelerator.config.max_qubits, 20);
        assert!(!accelerator.is_connected);
    }

    // The registry API is fully synchronous, so this test needs no async runtime.
    #[test]
    fn test_model_registry() {
        let mut registry = ModelRegistry::new();
        assert_eq!(registry.model_count(), 0);

        let model = QMLModel {
            model_type: QMLModelType::VQC,
            parameters: vec![0.1, 0.2, 0.3],
            circuit_structure: CircuitStructure {
                num_qubits: 4,
                depth: 10,
                gate_types: vec!["RY".to_string(), "CNOT".to_string()],
                parameter_count: 8,
                entangling_gates: 4,
            },
            training_metadata: HashMap::new(),
            performance_metrics: HashMap::new(),
        };

        registry
            .register_model("test_model".to_string(), model)
            .expect("registering model should succeed");
        assert_eq!(registry.model_count(), 1);
        assert_eq!(registry.active_model_count(), 1);

        let retrieved = registry
            .get_model("test_model")
            .expect("retrieving registered model should succeed");
        assert_eq!(retrieved.model_type, QMLModelType::VQC);

        // Deactivation keeps the model stored but stops counting it as active.
        registry
            .deactivate_model("test_model")
            .expect("deactivating a registered model should succeed");
        assert_eq!(registry.model_count(), 1);
        assert_eq!(registry.active_model_count(), 0);

        assert!(registry.get_model("missing_model").is_err());
        assert!(registry.deactivate_model("missing_model").is_err());
    }

    #[test]
    fn test_training_statistics() {
        let history = vec![
            TrainingEpoch {
                epoch: 0,
                loss: 1.0,
                accuracy: Some(0.5),
                parameters: vec![0.1],
                gradient_norm: 0.5,
                learning_rate: 0.01,
                execution_time: Duration::from_millis(100),
                quantum_fidelity: Some(0.95),
                classical_preprocessing_time: Duration::from_millis(10),
                quantum_execution_time: Duration::from_millis(90),
            },
            TrainingEpoch {
                epoch: 1,
                loss: 0.5,
                accuracy: Some(0.7),
                parameters: vec![0.2],
                gradient_norm: 0.3,
                learning_rate: 0.01,
                execution_time: Duration::from_millis(120),
                quantum_fidelity: Some(0.96),
                classical_preprocessing_time: Duration::from_millis(15),
                quantum_execution_time: Duration::from_millis(105),
            },
        ];

        let stats = TrainingStatistics::from_history(&history);
        assert_eq!(stats.total_epochs, 2);
        assert_eq!(stats.final_loss, 0.5);
        assert_eq!(stats.best_loss, 0.5);
        assert_eq!(stats.average_loss, 0.75);
        assert_eq!(stats.convergence_epoch, None);
        assert_eq!(stats.total_training_time, Duration::from_millis(220));
        assert_eq!(stats.average_epoch_time, Duration::from_millis(110));
    }

    #[test]
    fn test_training_statistics_empty_history() {
        let stats = TrainingStatistics::from_history(&[]);
        assert_eq!(stats.total_epochs, 0);
        assert_eq!(stats.final_loss, 0.0);
        assert_eq!(stats.best_loss, f64::INFINITY);
        assert_eq!(stats.average_loss, 0.0);
        assert_eq!(stats.total_training_time, Duration::ZERO);
        assert_eq!(stats.average_epoch_time, Duration::ZERO);
    }
}