1use super::*;
7use crate::continuous_variable::Complex;
8use crate::{CircuitExecutor, CircuitResult, DeviceError, DeviceResult, QuantumDevice};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
/// Estimates gradients of circuit expectation values with respect to circuit parameters.
///
/// Holds a shared handle to a quantum device plus the configuration that drives every
/// gradient routine (shots, step sizes, clipping, parallelism).
pub struct QuantumGradientCalculator {
    /// Device handle shared behind an async read/write lock; a read guard is taken for each
    /// circuit execution.
    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
    /// Gradient settings supplied at construction.
    config: GradientConfig,
    /// Gradient method used for dispatch; `new` copies it from `config.method`.
    method: GradientMethod,
}
20
/// Tuning knobs for [`QuantumGradientCalculator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientConfig {
    /// Gradient estimation strategy; `QuantumGradientCalculator::new` copies it into the calculator.
    pub method: GradientMethod,
    /// Number of shots requested for every circuit execution.
    pub shots: usize,
    /// Half-width `h` of the central-difference stencil used by the `FiniteDifference` method.
    pub finite_diff_step: f64,
    /// Shift `s` applied to each parameter by parameter-shift gradients (default π/2).
    pub shift_amount: f64,
    /// Whether error mitigation should be applied.
    /// NOTE(review): not read by any code in this module — confirm whether it is used elsewhere.
    pub use_error_mitigation: bool,
    /// When true, parameter-shift gradients spawn one tokio task per shifted circuit.
    pub parallel_execution: bool,
    /// If `Some(c)`, parameter-shift gradients are clamped element-wise to `[-c, c]`.
    pub gradient_clipping: Option<f64>,
}
39
40impl Default for GradientConfig {
41 fn default() -> Self {
42 Self {
43 method: GradientMethod::ParameterShift,
44 shots: 1024,
45 finite_diff_step: 1e-4,
46 shift_amount: std::f64::consts::PI / 2.0,
47 use_error_mitigation: true,
48 parallel_execution: true,
49 gradient_clipping: Some(1.0),
50 }
51 }
52}
53
54impl QuantumGradientCalculator {
55 pub fn new(
57 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
58 config: GradientConfig,
59 ) -> DeviceResult<Self> {
60 let method = config.method.clone();
61
62 Ok(Self {
63 device,
64 config,
65 method,
66 })
67 }
68
69 pub async fn compute_gradients(
71 &self,
72 circuit: ParameterizedQuantumCircuit,
73 parameters: Vec<f64>,
74 ) -> DeviceResult<Vec<f64>> {
75 match self.method {
76 GradientMethod::ParameterShift => {
77 self.parameter_shift_gradients(circuit, parameters).await
78 }
79 GradientMethod::FiniteDifference => {
80 self.finite_difference_gradients(circuit, parameters).await
81 }
82 GradientMethod::LinearCombination => {
83 self.linear_combination_gradients(circuit, parameters).await
84 }
85 GradientMethod::QuantumNaturalGradient => {
86 self.quantum_natural_gradients(circuit, parameters).await
87 }
88 GradientMethod::Adjoint => self.adjoint_gradients(circuit, parameters).await,
89 }
90 }
91
92 async fn parameter_shift_gradients(
94 &self,
95 circuit: ParameterizedQuantumCircuit,
96 parameters: Vec<f64>,
97 ) -> DeviceResult<Vec<f64>> {
98 let mut gradients = vec![0.0; parameters.len()];
99 let shift = self.config.shift_amount;
100
101 if self.config.parallel_execution {
102 let mut tasks = Vec::new();
104
105 for i in 0..parameters.len() {
106 let mut params_plus = parameters.clone();
107 let mut params_minus = parameters.clone();
108 params_plus[i] += shift;
109 params_minus[i] -= shift;
110
111 let circuit_plus = circuit.clone();
112 let circuit_minus = circuit.clone();
113 let device_plus = self.device.clone();
114 let device_minus = self.device.clone();
115 let shots = self.config.shots;
116
117 let task_plus = tokio::spawn(async move {
118 let circuit_eval =
119 Self::evaluate_circuit_with_params(&circuit_plus, ¶ms_plus)?;
120 let device = device_plus.read().await;
121 Self::execute_circuit_helper(&*device, &circuit_eval, shots).await
122 });
123
124 let task_minus = tokio::spawn(async move {
125 let circuit_eval =
126 Self::evaluate_circuit_with_params(&circuit_minus, ¶ms_minus)?;
127 let device = device_minus.read().await;
128 Self::execute_circuit_helper(&*device, &circuit_eval, shots).await
129 });
130
131 tasks.push((i, task_plus, task_minus));
132 }
133
134 for (param_idx, task_plus, task_minus) in tasks {
136 let result_plus = task_plus
137 .await
138 .map_err(|e| DeviceError::InvalidInput(format!("Task error: {e}")))??;
139 let result_minus = task_minus
140 .await
141 .map_err(|e| DeviceError::InvalidInput(format!("Task error: {e}")))??;
142
143 let expectation_plus = self.compute_expectation_value(&result_plus)?;
144 let expectation_minus = self.compute_expectation_value(&result_minus)?;
145
146 gradients[param_idx] = (expectation_plus - expectation_minus) / 2.0;
147 }
148 } else {
149 for i in 0..parameters.len() {
151 let mut params_plus = parameters.clone();
152 let mut params_minus = parameters.clone();
153 params_plus[i] += shift;
154 params_minus[i] -= shift;
155
156 let circuit_plus = Self::evaluate_circuit_with_params(&circuit, ¶ms_plus)?;
157 let circuit_minus = Self::evaluate_circuit_with_params(&circuit, ¶ms_minus)?;
158
159 let device = self.device.read().await;
160 let result_plus =
161 Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots)
162 .await?;
163 let result_minus =
164 Self::execute_circuit_helper(&*device, &circuit_minus, self.config.shots)
165 .await?;
166
167 let expectation_plus = self.compute_expectation_value(&result_plus)?;
168 let expectation_minus = self.compute_expectation_value(&result_minus)?;
169
170 gradients[i] = (expectation_plus - expectation_minus) / 2.0;
171 }
172 }
173
174 if let Some(clip_value) = self.config.gradient_clipping {
176 for grad in &mut gradients {
177 *grad = grad.clamp(-clip_value, clip_value);
178 }
179 }
180
181 Ok(gradients)
182 }
183
184 async fn finite_difference_gradients(
186 &self,
187 circuit: ParameterizedQuantumCircuit,
188 parameters: Vec<f64>,
189 ) -> DeviceResult<Vec<f64>> {
190 let mut gradients = vec![0.0; parameters.len()];
191 let step = self.config.finite_diff_step;
192
193 for i in 0..parameters.len() {
194 let mut params_plus = parameters.clone();
195 let mut params_minus = parameters.clone();
196 params_plus[i] += step;
197 params_minus[i] -= step;
198
199 let circuit_plus = Self::evaluate_circuit_with_params(&circuit, ¶ms_plus)?;
200 let circuit_minus = Self::evaluate_circuit_with_params(&circuit, ¶ms_minus)?;
201
202 let device = self.device.read().await;
203 let result_plus =
204 Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots).await?;
205 let result_minus =
206 Self::execute_circuit_helper(&*device, &circuit_minus, self.config.shots).await?;
207
208 let expectation_plus = self.compute_expectation_value(&result_plus)?;
209 let expectation_minus = self.compute_expectation_value(&result_minus)?;
210
211 gradients[i] = (expectation_plus - expectation_minus) / (2.0 * step);
212 }
213
214 Ok(gradients)
215 }
216
217 async fn linear_combination_gradients(
219 &self,
220 circuit: ParameterizedQuantumCircuit,
221 parameters: Vec<f64>,
222 ) -> DeviceResult<Vec<f64>> {
223 let mut gradients = vec![0.0; parameters.len()];
226
227 for i in 0..parameters.len() {
228 let step = 1e-3;
230 let mut params_plus = parameters.clone();
231 params_plus[i] += step;
232
233 let circuit_original = Self::evaluate_circuit_with_params(&circuit, ¶meters)?;
234 let circuit_plus = Self::evaluate_circuit_with_params(&circuit, ¶ms_plus)?;
235
236 let device = self.device.read().await;
237 let result_original =
238 Self::execute_circuit_helper(&*device, &circuit_original, self.config.shots)
239 .await?;
240 let result_plus =
241 Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots).await?;
242
243 let expectation_original = self.compute_expectation_value(&result_original)?;
244 let expectation_plus = self.compute_expectation_value(&result_plus)?;
245
246 gradients[i] = (expectation_plus - expectation_original) / step;
247 }
248
249 Ok(gradients)
250 }
251
252 async fn quantum_natural_gradients(
254 &self,
255 circuit: ParameterizedQuantumCircuit,
256 parameters: Vec<f64>,
257 ) -> DeviceResult<Vec<f64>> {
258 let regular_gradients = self
260 .parameter_shift_gradients(circuit.clone(), parameters.clone())
261 .await?;
262
263 let fisher_matrix = self
265 .compute_quantum_fisher_information(&circuit, ¶meters)
266 .await?;
267
268 let natural_gradients = self.solve_linear_system(&fisher_matrix, ®ular_gradients)?;
270
271 Ok(natural_gradients)
272 }
273
274 async fn adjoint_gradients(
276 &self,
277 circuit: ParameterizedQuantumCircuit,
278 parameters: Vec<f64>,
279 ) -> DeviceResult<Vec<f64>> {
280 self.parameter_shift_gradients(circuit, parameters).await
284 }
285
286 async fn compute_quantum_fisher_information(
288 &self,
289 circuit: &ParameterizedQuantumCircuit,
290 parameters: &[f64],
291 ) -> DeviceResult<Vec<Vec<f64>>> {
292 let n_params = parameters.len();
293 let mut fisher_matrix = vec![vec![0.0; n_params]; n_params];
294 let shift = std::f64::consts::PI / 2.0;
295
296 for i in 0..n_params {
297 for j in i..n_params {
298 if i == j {
299 let mut params_plus = parameters.to_vec();
301 let mut params_minus = parameters.to_vec();
302 params_plus[i] += shift;
303 params_minus[i] -= shift;
304
305 let circuit_plus = Self::evaluate_circuit_with_params(circuit, ¶ms_plus)?;
306 let circuit_minus = Self::evaluate_circuit_with_params(circuit, ¶ms_minus)?;
307
308 let device = self.device.read().await;
309 let result_plus =
310 Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots)
311 .await?;
312 let result_minus =
313 Self::execute_circuit_helper(&*device, &circuit_minus, self.config.shots)
314 .await?;
315
316 let overlap = self.compute_state_overlap(&result_plus, &result_minus)?;
317 fisher_matrix[i][j] = (1.0 - overlap.real) / 2.0;
318 } else {
319 fisher_matrix[i][j] = 0.0;
322 fisher_matrix[j][i] = fisher_matrix[i][j];
323 }
324 }
325 }
326
327 for i in 0..n_params {
329 fisher_matrix[i][i] += 1e-6;
330 }
331
332 Ok(fisher_matrix)
333 }
334
335 fn compute_state_overlap(
337 &self,
338 result1: &CircuitResult,
339 result2: &CircuitResult,
340 ) -> DeviceResult<Complex> {
341 let mut overlap_real = 0.0;
345 let total_shots1 = result1.shots as f64;
346 let total_shots2 = result2.shots as f64;
347
348 for (bitstring, count1) in &result1.counts {
349 if let Some(count2) = result2.counts.get(bitstring) {
350 let prob1 = *count1 as f64 / total_shots1;
351 let prob2 = *count2 as f64 / total_shots2;
352 overlap_real += (prob1 * prob2).sqrt();
353 }
354 }
355
356 Ok(Complex::new(overlap_real, 0.0))
357 }
358
359 fn solve_linear_system(&self, matrix: &[Vec<f64>], vector: &[f64]) -> DeviceResult<Vec<f64>> {
361 let n = matrix.len();
362 if n != vector.len() {
363 return Err(DeviceError::InvalidInput(
364 "Matrix and vector dimensions don't match".to_string(),
365 ));
366 }
367
368 let mut augmented = matrix
370 .iter()
371 .zip(vector.iter())
372 .map(|(row, &b)| {
373 let mut aug_row = row.clone();
374 aug_row.push(b);
375 aug_row
376 })
377 .collect::<Vec<_>>();
378
379 for i in 0..n {
381 let mut max_row = i;
383 for k in i + 1..n {
384 if augmented[k][i].abs() > augmented[max_row][i].abs() {
385 max_row = k;
386 }
387 }
388 augmented.swap(i, max_row);
389
390 if augmented[i][i].abs() < 1e-10 {
392 return Err(DeviceError::InvalidInput(
393 "Singular matrix in linear system".to_string(),
394 ));
395 }
396
397 for k in i + 1..n {
399 let factor = augmented[k][i] / augmented[i][i];
400 for j in i..=n {
401 augmented[k][j] -= factor * augmented[i][j];
402 }
403 }
404 }
405
406 let mut solution = vec![0.0; n];
408 for i in (0..n).rev() {
409 solution[i] = augmented[i][n];
410 for j in i + 1..n {
411 solution[i] -= augmented[i][j] * solution[j];
412 }
413 solution[i] /= augmented[i][i];
414 }
415
416 Ok(solution)
417 }
418
419 async fn execute_circuit_helper(
421 device: &(dyn QuantumDevice + Send + Sync),
422 circuit: &ParameterizedQuantumCircuit,
423 shots: usize,
424 ) -> DeviceResult<CircuitResult> {
425 let mut counts = std::collections::HashMap::new();
428 counts.insert("0".repeat(circuit.num_qubits()), shots / 2);
429 counts.insert("1".repeat(circuit.num_qubits()), shots / 2);
430
431 Ok(CircuitResult {
432 counts,
433 shots,
434 metadata: std::collections::HashMap::new(),
435 })
436 }
437
438 fn evaluate_circuit_with_params(
440 circuit: &ParameterizedQuantumCircuit,
441 parameters: &[f64],
442 ) -> DeviceResult<ParameterizedQuantumCircuit> {
443 Ok(circuit.clone())
446 }
447
448 fn compute_expectation_value(&self, result: &CircuitResult) -> DeviceResult<f64> {
450 let mut expectation = 0.0;
452 let total_shots = result.shots as f64;
453
454 for (bitstring, count) in &result.counts {
455 let ones_count = bitstring.chars().filter(|&c| c == '1').count();
456 let probability = *count as f64 / total_shots;
457 expectation += ones_count as f64 * probability;
458 }
459
460 Ok(expectation)
461 }
462
463 pub async fn compute_observable_gradients(
465 &self,
466 circuit: ParameterizedQuantumCircuit,
467 parameters: Vec<f64>,
468 observable: Observable,
469 ) -> DeviceResult<Vec<f64>> {
470 match self.method {
471 GradientMethod::ParameterShift => {
472 self.parameter_shift_observable_gradients(circuit, parameters, observable)
473 .await
474 }
475 _ => {
476 self.compute_gradients(circuit, parameters).await
478 }
479 }
480 }
481
482 async fn parameter_shift_observable_gradients(
484 &self,
485 circuit: ParameterizedQuantumCircuit,
486 parameters: Vec<f64>,
487 observable: Observable,
488 ) -> DeviceResult<Vec<f64>> {
489 let mut gradients = vec![0.0; parameters.len()];
490 let shift = self.config.shift_amount;
491
492 for i in 0..parameters.len() {
493 let mut params_plus = parameters.clone();
494 let mut params_minus = parameters.clone();
495 params_plus[i] += shift;
496 params_minus[i] -= shift;
497
498 let circuit_plus = Self::evaluate_circuit_with_params(&circuit, ¶ms_plus)?;
499 let circuit_minus = Self::evaluate_circuit_with_params(&circuit, ¶ms_minus)?;
500
501 let device = self.device.read().await;
502 let result_plus =
503 Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots).await?;
504 let result_minus =
505 Self::execute_circuit_helper(&*device, &circuit_minus, self.config.shots).await?;
506
507 let expectation_plus =
508 self.compute_observable_expectation(&result_plus, &observable)?;
509 let expectation_minus =
510 self.compute_observable_expectation(&result_minus, &observable)?;
511
512 gradients[i] = (expectation_plus - expectation_minus) / 2.0;
513 }
514
515 Ok(gradients)
516 }
517
518 fn compute_observable_expectation(
520 &self,
521 result: &CircuitResult,
522 observable: &Observable,
523 ) -> DeviceResult<f64> {
524 let mut expectation = 0.0;
525 let total_shots = result.shots as f64;
526
527 for (bitstring, count) in &result.counts {
528 let probability = *count as f64 / total_shots;
529 let eigenvalue = observable.evaluate_bitstring(bitstring)?;
530 expectation += probability * eigenvalue;
531 }
532
533 Ok(expectation)
534 }
535}
536
/// An observable expressed as a sum of weighted Pauli-string terms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Observable {
    /// Terms whose evaluated values are summed by `evaluate_bitstring`.
    pub terms: Vec<ObservableTerm>,
}
542
/// One weighted Pauli-string term of an [`Observable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObservableTerm {
    /// Real coefficient multiplying the term's eigenvalue.
    pub coefficient: f64,
    /// `(qubit index, operator)` pairs; qubits not listed contribute nothing (identity).
    pub pauli_string: Vec<(usize, PauliOperator)>,
}
549
550impl Observable {
551 pub fn single_z(qubit: usize) -> Self {
553 Self {
554 terms: vec![ObservableTerm {
555 coefficient: 1.0,
556 pauli_string: vec![(qubit, PauliOperator::Z)],
557 }],
558 }
559 }
560
561 pub fn all_z(num_qubits: usize) -> Self {
563 let terms = (0..num_qubits)
564 .map(|i| ObservableTerm {
565 coefficient: 1.0,
566 pauli_string: vec![(i, PauliOperator::Z)],
567 })
568 .collect();
569
570 Self { terms }
571 }
572
573 pub fn evaluate_bitstring(&self, bitstring: &str) -> DeviceResult<f64> {
575 let mut value = 0.0;
576
577 for term in &self.terms {
578 let mut term_value = term.coefficient;
579
580 for (qubit_idx, pauli_op) in &term.pauli_string {
581 if let Some(bit_char) = bitstring.chars().nth(*qubit_idx) {
582 let bit_value = if bit_char == '1' { -1.0 } else { 1.0 };
583
584 match pauli_op {
585 PauliOperator::Z => term_value *= bit_value,
586 PauliOperator::I => {} PauliOperator::X | PauliOperator::Y => {
588 return Err(DeviceError::InvalidInput(
590 "X and Y Pauli measurements require basis rotation".to_string(),
591 ));
592 }
593 }
594 }
595 }
596
597 value += term_value;
598 }
599
600 Ok(value)
601 }
602}
603
/// Stateless numerical helpers for gradient-based optimization.
pub struct GradientUtils;

impl GradientUtils {
    /// Central-difference gradient of `f` at `parameters`.
    ///
    /// Each component is `[f(θ + h·e_i) − f(θ − h·e_i)] / (2h)` with `h = step_size`.
    pub fn central_difference(
        f: impl Fn(&[f64]) -> f64,
        parameters: &[f64],
        step_size: f64,
    ) -> Vec<f64> {
        (0..parameters.len())
            .map(|i| {
                let mut params_plus = parameters.to_vec();
                let mut params_minus = parameters.to_vec();
                params_plus[i] += step_size;
                params_minus[i] -= step_size;

                (f(&params_plus) - f(&params_minus)) / (2.0 * step_size)
            })
            .collect()
    }

    /// Rescales `gradients` in place so their Euclidean norm is at most `max_norm`.
    ///
    /// Vectors already within the bound are left untouched. A negative or NaN `max_norm` has
    /// no meaningful interpretation and leaves the input unchanged. Otherwise a negative bound
    /// would flip every gradient's sign, and an all-zero vector would turn into NaNs via
    /// `0 · −∞`.
    pub fn clip_gradients(gradients: &mut [f64], max_norm: f64) {
        // `!(x >= 0.0)` rejects both negative values and NaN.
        if !(max_norm >= 0.0) {
            return;
        }

        let norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
        // With max_norm >= 0, `norm > max_norm` implies norm > 0, so `scale` is finite.
        if norm > max_norm {
            let scale = max_norm / norm;
            for grad in gradients.iter_mut() {
                *grad *= scale;
            }
        }
    }

    /// Heavy-ball momentum: `buffer_i ← momentum · buffer_i + gradient_i`.
    ///
    /// Returns the updated buffer values. The buffer is resized (zero-filled or truncated) to
    /// `gradients.len()` first.
    pub fn apply_momentum(
        gradients: &[f64],
        momentum_buffer: &mut Vec<f64>,
        momentum: f64,
    ) -> Vec<f64> {
        if momentum_buffer.len() != gradients.len() {
            momentum_buffer.resize(gradients.len(), 0.0);
        }

        momentum_buffer
            .iter_mut()
            .zip(gradients)
            .map(|(velocity, &grad)| {
                *velocity = momentum.mul_add(*velocity, grad);
                *velocity
            })
            .collect()
    }
}
661
662pub fn create_parameter_shift_calculator(
664 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
665 shots: usize,
666) -> DeviceResult<QuantumGradientCalculator> {
667 let config = GradientConfig {
668 method: GradientMethod::ParameterShift,
669 shots,
670 ..Default::default()
671 };
672
673 QuantumGradientCalculator::new(device, config)
674}
675
676pub fn create_finite_difference_calculator(
678 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
679 step_size: f64,
680) -> DeviceResult<QuantumGradientCalculator> {
681 let config = GradientConfig {
682 method: GradientMethod::FiniteDifference,
683 finite_diff_step: step_size,
684 ..Default::default()
685 };
686
687 QuantumGradientCalculator::new(device, config)
688}
689
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_mock_quantum_device;

    /// A calculator built from the default config keeps the default method and shot count.
    #[tokio::test]
    async fn test_gradient_calculator_creation() {
        let calculator =
            QuantumGradientCalculator::new(create_mock_quantum_device(), GradientConfig::default())
                .expect("QuantumGradientCalculator creation should succeed with default config");

        assert_eq!(calculator.config.method, GradientMethod::ParameterShift);
        assert_eq!(calculator.config.shots, 1024);
    }

    /// `single_z` yields one unit-weight term; `all_z(n)` yields `n` terms.
    #[test]
    fn test_observable_creation() {
        let single = Observable::single_z(0);
        assert_eq!(single.terms.len(), 1);
        assert_eq!(single.terms[0].coefficient, 1.0);

        assert_eq!(Observable::all_z(4).terms.len(), 4);
    }

    /// `Z_0` evaluates to +1 on "0" and -1 on "1".
    #[test]
    fn test_observable_evaluation() {
        let z0 = Observable::single_z(0);

        let on_zero = z0
            .evaluate_bitstring("0")
            .expect("Observable evaluation should succeed for bitstring '0'");
        let on_one = z0
            .evaluate_bitstring("1")
            .expect("Observable evaluation should succeed for bitstring '1'");

        assert_eq!(on_zero, 1.0);
        assert_eq!(on_one, -1.0);
    }

    /// Central differences of x² + 2y² at (1, 2) approximate (2, 8).
    #[test]
    fn test_gradient_utils() {
        let quadratic = |p: &[f64]| p[0] * p[0] + 2.0 * p[1] * p[1];
        let grad = GradientUtils::central_difference(quadratic, &[1.0, 2.0], 1e-5);

        assert_eq!(grad.len(), 2);
        for (got, want) in grad.iter().zip([2.0, 8.0]) {
            assert!((got - want).abs() < 1e-3);
        }
    }

    /// Clipping (3, 4) to norm 2 yields a vector of norm exactly 2.
    #[test]
    fn test_gradient_clipping() {
        let mut grad = vec![3.0, 4.0];
        GradientUtils::clip_gradients(&mut grad, 2.0);

        let norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
        assert!((norm - 2.0).abs() < 1e-10);
    }

    /// 0.9·0.5 + 1 = 1.45 and 0.9·(-0.5) + 2 = 1.55.
    #[test]
    fn test_momentum() {
        let mut buffer = vec![0.5, -0.5];
        let updated = GradientUtils::apply_momentum(&[1.0, 2.0], &mut buffer, 0.9);

        for (got, want) in updated.iter().zip([1.45, 1.55]) {
            assert!((got - want).abs() < 1e-10);
        }
    }
}