1use scirs2_core::ndarray::{Array1, Array2};
7use std::collections::HashMap;
8use std::f64::consts::PI;
9
10use crate::error::{MLError, Result};
11use quantrs2_circuit::prelude::*;
12use quantrs2_core::gate::GateOp;
13
/// A named scalar parameter that can take part in automatic differentiation.
///
/// `gradient` is an accumulator: it starts at `0.0` and is written by the
/// differentiation engine; `requires_grad` decides whether it is updated.
#[derive(Debug, Clone)]
pub struct DifferentiableParam {
    /// Identifier used to reference the parameter from a computation graph.
    pub name: String,
    /// Current numeric value of the parameter.
    pub value: f64,
    /// Gradient accumulated by the most recent backward pass.
    pub gradient: f64,
    /// When `false`, the parameter is treated as frozen (a constant).
    pub requires_grad: bool,
}

impl DifferentiableParam {
    /// Builds a trainable parameter (`requires_grad == true`) with a zero gradient.
    pub fn new(name: impl Into<String>, value: f64) -> Self {
        Self::with_grad_flag(name.into(), value, true)
    }

    /// Builds a frozen parameter (`requires_grad == false`) with a zero gradient.
    pub fn constant(name: impl Into<String>, value: f64) -> Self {
        Self::with_grad_flag(name.into(), value, false)
    }

    /// Shared constructor: every fresh parameter starts with a zero gradient.
    fn with_grad_flag(name: String, value: f64, requires_grad: bool) -> Self {
        DifferentiableParam {
            name,
            value,
            gradient: 0.0,
            requires_grad,
        }
    }
}
48
/// A node of a scalar expression graph that `AutoDiff` can evaluate and differentiate.
#[derive(Debug, Clone)]
pub enum ComputationNode {
    /// Reference to a registered parameter, looked up by name.
    Parameter(String),
    /// A literal constant; contributes no gradient.
    Constant(f64),
    /// Sum of two sub-expressions.
    Add(Box<ComputationNode>, Box<ComputationNode>),
    /// Product of two sub-expressions.
    Mul(Box<ComputationNode>, Box<ComputationNode>),
    /// Sine of a sub-expression.
    Sin(Box<ComputationNode>),
    /// Cosine of a sub-expression.
    Cos(Box<ComputationNode>),
    /// Exponential of a sub-expression.
    Exp(Box<ComputationNode>),
    /// Expectation value of a parameterized quantum circuit.
    ///
    /// NOTE(review): `AutoDiff` does not simulate a circuit here; it evaluates
    /// this node as `cos` of the sum of the named parameters' values and never
    /// reads `observable`.
    Expectation {
        /// Names of the parameters feeding the circuit.
        circuit_params: Vec<String>,
        /// Observable identifier (not read by `AutoDiff`).
        observable: String,
    },
}
72
/// Reverse-mode automatic differentiation over a single [`ComputationNode`] graph.
pub struct AutoDiff {
    /// Registered parameters, keyed by their `name`.
    parameters: HashMap<String, DifferentiableParam>,
    /// Graph evaluated by `forward` and differentiated by `backward`; `None` until set.
    graph: Option<ComputationNode>,
    /// Cleared at the start of every `forward` call.
    // NOTE(review): nothing in this file ever inserts into this cache.
    forward_cache: HashMap<String, f64>,
}
82
83impl AutoDiff {
84 pub fn new() -> Self {
86 Self {
87 parameters: HashMap::new(),
88 graph: None,
89 forward_cache: HashMap::new(),
90 }
91 }
92
93 pub fn register_parameter(&mut self, param: DifferentiableParam) {
95 self.parameters.insert(param.name.clone(), param);
96 }
97
98 pub fn set_graph(&mut self, graph: ComputationNode) {
100 self.graph = Some(graph);
101 }
102
103 pub fn forward(&mut self) -> Result<f64> {
105 self.forward_cache.clear();
106
107 if let Some(graph) = self.graph.clone() {
108 self.evaluate_node(&graph)
109 } else {
110 Err(MLError::InvalidConfiguration(
111 "No computation graph set".to_string(),
112 ))
113 }
114 }
115
116 pub fn backward(&mut self, loss_gradient: f64) -> Result<()> {
118 for param in self.parameters.values_mut() {
120 param.gradient = 0.0;
121 }
122
123 if let Some(graph) = self.graph.clone() {
124 self.backpropagate(&graph, loss_gradient)?;
125 }
126
127 Ok(())
128 }
129
130 fn evaluate_node(&mut self, node: &ComputationNode) -> Result<f64> {
132 match node {
133 ComputationNode::Parameter(name) => {
134 self.parameters.get(name).map(|p| p.value).ok_or_else(|| {
135 MLError::InvalidConfiguration(format!("Unknown parameter: {}", name))
136 })
137 }
138 ComputationNode::Constant(value) => Ok(*value),
139 ComputationNode::Add(left, right) => {
140 let l = self.evaluate_node(left)?;
141 let r = self.evaluate_node(right)?;
142 Ok(l + r)
143 }
144 ComputationNode::Mul(left, right) => {
145 let l = self.evaluate_node(left)?;
146 let r = self.evaluate_node(right)?;
147 Ok(l * r)
148 }
149 ComputationNode::Sin(inner) => {
150 let x = self.evaluate_node(inner)?;
151 Ok(x.sin())
152 }
153 ComputationNode::Cos(inner) => {
154 let x = self.evaluate_node(inner)?;
155 Ok(x.cos())
156 }
157 ComputationNode::Exp(inner) => {
158 let x = self.evaluate_node(inner)?;
159 Ok(x.exp())
160 }
161 ComputationNode::Expectation {
162 circuit_params,
163 observable,
164 } => {
165 let mut sum = 0.0;
167 for param_name in circuit_params {
168 if let Some(param) = self.parameters.get(param_name) {
169 sum += param.value;
170 }
171 }
172 Ok(sum.cos()) }
174 }
175 }
176
177 fn backpropagate(&mut self, node: &ComputationNode, grad: f64) -> Result<()> {
179 match node {
180 ComputationNode::Parameter(name) => {
181 if let Some(param) = self.parameters.get_mut(name) {
182 if param.requires_grad {
183 param.gradient += grad;
184 }
185 }
186 }
187 ComputationNode::Constant(_) => {
188 }
190 ComputationNode::Add(left, right) => {
191 self.backpropagate(left, grad)?;
193 self.backpropagate(right, grad)?;
194 }
195 ComputationNode::Mul(left, right) => {
196 let l_val = self.evaluate_node(left)?;
198 let r_val = self.evaluate_node(right)?;
199 self.backpropagate(left, grad * r_val)?;
200 self.backpropagate(right, grad * l_val)?;
201 }
202 ComputationNode::Sin(inner) => {
203 let x = self.evaluate_node(inner)?;
205 self.backpropagate(inner, grad * x.cos())?;
206 }
207 ComputationNode::Cos(inner) => {
208 let x = self.evaluate_node(inner)?;
210 self.backpropagate(inner, grad * (-x.sin()))?;
211 }
212 ComputationNode::Exp(inner) => {
213 let x = self.evaluate_node(inner)?;
215 self.backpropagate(inner, grad * x.exp())?;
216 }
217 ComputationNode::Expectation { circuit_params, .. } => {
218 for param_name in circuit_params {
220 let shift_grad = self.parameter_shift_gradient(param_name, PI / 2.0)?;
221 if let Some(param) = self.parameters.get_mut(param_name) {
222 if param.requires_grad {
223 param.gradient += grad * shift_grad;
224 }
225 }
226 }
227 }
228 }
229 Ok(())
230 }
231
232 fn parameter_shift_gradient(&self, param_name: &str, shift: f64) -> Result<f64> {
234 Ok(0.5) }
238
239 pub fn gradients(&self) -> HashMap<String, f64> {
241 self.parameters
242 .iter()
243 .filter(|(_, p)| p.requires_grad)
244 .map(|(name, param)| (name.clone(), param.gradient))
245 .collect()
246 }
247
248 pub fn update_parameters(&mut self, learning_rate: f64) {
250 for param in self.parameters.values_mut() {
251 if param.requires_grad {
252 param.value -= learning_rate * param.gradient;
253 }
254 }
255 }
256}
257
/// Gradient utilities for a black-box parameterized circuit evaluated through a
/// user-supplied closure.
pub struct QuantumAutoDiff {
    /// Embedded scalar autodiff engine.
    // NOTE(review): initialised in `new` but never read anywhere in this file.
    autodiff: AutoDiff,
    /// Maps a parameter vector to a scalar objective (presumably a circuit
    /// expectation value — confirm with callers).
    executor: Box<dyn Fn(&[f64]) -> f64>,
}
265
266impl QuantumAutoDiff {
267 pub fn new<F>(executor: F) -> Self
269 where
270 F: Fn(&[f64]) -> f64 + 'static,
271 {
272 Self {
273 autodiff: AutoDiff::new(),
274 executor: Box::new(executor),
275 }
276 }
277
278 pub fn parameter_shift_gradients(&self, params: &[f64], shift: f64) -> Result<Vec<f64>> {
280 let mut gradients = vec![0.0; params.len()];
281
282 for (i, _) in params.iter().enumerate() {
283 let mut params_plus = params.to_vec();
285 params_plus[i] += shift;
286 let val_plus = (self.executor)(¶ms_plus);
287
288 let mut params_minus = params.to_vec();
290 params_minus[i] -= shift;
291 let val_minus = (self.executor)(¶ms_minus);
292
293 gradients[i] = (val_plus - val_minus) / (2.0 * shift.sin());
295 }
296
297 Ok(gradients)
298 }
299
300 pub fn natural_gradients(
302 &self,
303 params: &[f64],
304 gradients: &[f64],
305 regularization: f64,
306 ) -> Result<Vec<f64>> {
307 let n = params.len();
308 let mut fisher = Array2::<f64>::zeros((n, n));
309
310 for i in 0..n {
312 for j in 0..n {
313 fisher[[i, j]] = self.compute_fisher_element(params, i, j)?;
314 }
315 }
316
317 for i in 0..n {
319 fisher[[i, i]] += regularization;
320 }
321
322 self.solve_linear_system(&fisher, gradients)
324 }
325
326 fn compute_fisher_element(&self, params: &[f64], i: usize, j: usize) -> Result<f64> {
331 let shift = PI / 2.0;
332
333 let mut p_pp = params.to_vec();
334 let mut p_pm = params.to_vec();
335 let mut p_mp = params.to_vec();
336 let mut p_mm = params.to_vec();
337
338 p_pp[i] += shift;
339 p_pp[j] += shift;
340
341 p_pm[i] += shift;
342 p_pm[j] -= shift;
343
344 p_mp[i] -= shift;
345 p_mp[j] += shift;
346
347 p_mm[i] -= shift;
348 p_mm[j] -= shift;
349
350 let e_pp = (self.executor)(&p_pp);
351 let e_pm = (self.executor)(&p_pm);
352 let e_mp = (self.executor)(&p_mp);
353 let e_mm = (self.executor)(&p_mm);
354
355 Ok((e_pp - e_pm - e_mp + e_mm) / 4.0)
356 }
357
358 fn solve_linear_system(&self, matrix: &Array2<f64>, rhs: &[f64]) -> Result<Vec<f64>> {
363 let n = rhs.len();
364 if matrix.nrows() != n || matrix.ncols() != n {
365 return Err(MLError::DimensionMismatch(format!(
366 "Matrix ({} x {}) incompatible with rhs length {}",
367 matrix.nrows(),
368 matrix.ncols(),
369 n
370 )));
371 }
372
373 let mut a: Vec<Vec<f64>> = (0..n)
375 .map(|i| {
376 let mut row: Vec<f64> = (0..n).map(|j| matrix[[i, j]]).collect();
377 row.push(rhs[i]);
378 row
379 })
380 .collect();
381
382 for k in 0..n {
384 let mut max_val = a[k][k].abs();
386 let mut max_idx = k;
387 for row in (k + 1)..n {
388 let val = a[row][k].abs();
389 if val > max_val {
390 max_val = val;
391 max_idx = row;
392 }
393 }
394
395 if max_val < 1e-12 {
396 return Err(MLError::NumericalError(format!(
397 "Singular matrix: |pivot| = {:.2e} < 1e-12 at column {}",
398 max_val, k
399 )));
400 }
401
402 if max_idx != k {
404 a.swap(k, max_idx);
405 }
406
407 let pivot = a[k][k];
408
409 for i in (k + 1)..n {
411 let factor = a[i][k] / pivot;
412 for col in k..=n {
413 let sub = factor * a[k][col];
414 a[i][col] -= sub;
415 }
416 }
417 }
418
419 let mut x = vec![0.0_f64; n];
421 for i in (0..n).rev() {
422 let mut sum = a[i][n]; for j in (i + 1)..n {
424 sum -= a[i][j] * x[j];
425 }
426 x[i] = sum / a[i][i];
427 }
428
429 Ok(x)
430 }
431}
432
/// Eager reverse-mode "tape": every `add`/`mul` computes its value immediately
/// and records the operation so `gradient` can replay the tape backwards.
#[derive(Debug, Clone)]
pub struct GradientTape {
    /// Recorded operations, in execution order.
    operations: Vec<Operation>,
    /// Current value of every named variable and intermediate result.
    variables: HashMap<String, f64>,
}

/// One entry recorded on a [`GradientTape`].
#[derive(Debug, Clone)]
enum Operation {
    /// Created by `variable`; ignored when computing gradients.
    Assign { var: String, value: f64 },
    /// `result = left + right`.
    Add {
        result: String,
        left: String,
        right: String,
    },
    /// `result = left * right`.
    Mul {
        result: String,
        left: String,
        right: String,
    },
    /// Not produced by any `GradientTape` method in this file; ignored by `gradient`.
    Quantum { result: String, params: Vec<String> },
}

impl GradientTape {
    /// Creates an empty tape.
    pub fn new() -> Self {
        GradientTape {
            operations: Vec::new(),
            variables: HashMap::new(),
        }
    }

    /// Defines (or overwrites) variable `name` with `value` and returns its name.
    pub fn variable(&mut self, name: impl Into<String>, value: f64) -> String {
        let var: String = name.into();
        self.variables.insert(var.clone(), value);
        self.operations.push(Operation::Assign {
            var: var.clone(),
            value,
        });
        var
    }

    /// Records `left + right` and returns the generated result name (`tmp_<k>`).
    ///
    /// # Panics
    ///
    /// Panics if `left` or `right` is not a known variable.
    pub fn add(&mut self, left: &str, right: &str) -> String {
        self.record_binary(
            left,
            right,
            |a, b| a + b,
            |result, left, right| Operation::Add {
                result,
                left,
                right,
            },
        )
    }

    /// Records `left * right` and returns the generated result name (`tmp_<k>`).
    ///
    /// # Panics
    ///
    /// Panics if `left` or `right` is not a known variable.
    pub fn mul(&mut self, left: &str, right: &str) -> String {
        self.record_binary(
            left,
            right,
            |a, b| a * b,
            |result, left, right| Operation::Mul {
                result,
                left,
                right,
            },
        )
    }

    /// Shared body of `add`/`mul`: names the result after the current tape
    /// length, stores its value and appends the matching operation.
    fn record_binary(
        &mut self,
        left: &str,
        right: &str,
        combine: fn(f64, f64) -> f64,
        make_op: fn(String, String, String) -> Operation,
    ) -> String {
        let result = format!("tmp_{}", self.operations.len());
        let value = combine(self.variables[left], self.variables[right]);
        self.variables.insert(result.clone(), value);
        self.operations
            .push(make_op(result.clone(), left.to_string(), right.to_string()));
        result
    }

    /// Returns `d(output)/d(input)` for each name in `inputs`; inputs that do not
    /// influence `output` map to `0.0`.
    ///
    /// NOTE(review): `mul` operand values are read from the *current* variable
    /// map, so redefining a variable after it was used changes the result.
    pub fn gradient(&self, output: &str, inputs: &[&str]) -> HashMap<String, f64> {
        let mut adjoints: HashMap<String, f64> = HashMap::new();
        adjoints.insert(output.to_string(), 1.0);

        for op in self.operations.iter().rev() {
            match op {
                Operation::Add {
                    result,
                    left,
                    right,
                } => {
                    if let Some(upstream) = adjoints.get(result).copied() {
                        Self::accumulate(&mut adjoints, left, upstream);
                        Self::accumulate(&mut adjoints, right, upstream);
                    }
                }
                Operation::Mul {
                    result,
                    left,
                    right,
                } => {
                    if let Some(upstream) = adjoints.get(result).copied() {
                        let lhs = self.variables[left];
                        let rhs = self.variables[right];
                        Self::accumulate(&mut adjoints, left, upstream * rhs);
                        Self::accumulate(&mut adjoints, right, upstream * lhs);
                    }
                }
                Operation::Assign { .. } | Operation::Quantum { .. } => {}
            }
        }

        let mut out = HashMap::with_capacity(inputs.len());
        for &name in inputs {
            out.insert(name.to_string(), adjoints.get(name).copied().unwrap_or(0.0));
        }
        out
    }

    /// Adds `delta` to the adjoint of `key`, starting from `0.0`.
    fn accumulate(adjoints: &mut HashMap<String, f64>, key: &str, delta: f64) {
        *adjoints.entry(key.to_string()).or_insert(0.0) += delta;
    }
}
559
/// Gradient-based optimizers that update name-keyed parameter maps.
pub mod optimizers {
    use super::*;

    /// A (possibly stateful) update rule for named scalar parameters.
    pub trait Optimizer {
        /// Applies one update to `params` using `gradients`.
        ///
        /// Gradient entries without a matching key in `params` leave `params`
        /// untouched, though an implementation may still record state for them.
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>);

        /// Discards any accumulated optimizer state.
        fn reset(&mut self);
    }

    /// Stochastic gradient descent with classical momentum.
    ///
    /// Per parameter: `v ← momentum·v − learning_rate·g`, then `θ ← θ + v`.
    pub struct SGD {
        learning_rate: f64,
        momentum: f64,
        /// Per-parameter velocity, created lazily at `0.0`.
        velocities: HashMap<String, f64>,
    }

    impl SGD {
        /// Creates an optimizer; a `momentum` of `0.0` gives plain gradient descent.
        pub fn new(learning_rate: f64, momentum: f64) -> Self {
            SGD {
                learning_rate,
                momentum,
                velocities: HashMap::new(),
            }
        }
    }

    impl Optimizer for SGD {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            for (name, &g) in gradients.iter() {
                let velocity = self.velocities.entry(name.clone()).or_insert(0.0);
                let next = self.momentum * *velocity - self.learning_rate * g;
                *velocity = next;

                if let Some(value) = params.get_mut(name) {
                    *value += next;
                }
            }
        }

        fn reset(&mut self) {
            self.velocities.clear();
        }
    }

    /// Adam optimizer with bias-corrected first/second moment estimates.
    ///
    /// Defaults: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`.
    pub struct Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
        /// Number of `step` calls since construction or the last `reset`.
        t: usize,
        /// First-moment (mean) estimates per parameter.
        m: HashMap<String, f64>,
        /// Second-moment (uncentered variance) estimates per parameter.
        v: HashMap<String, f64>,
    }

    impl Adam {
        /// Creates an Adam optimizer with the default moment coefficients.
        pub fn new(learning_rate: f64) -> Self {
            Adam {
                learning_rate,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
                t: 0,
                m: HashMap::new(),
                v: HashMap::new(),
            }
        }
    }

    impl Optimizer for Adam {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            self.t += 1;
            let t = self.t as f64;
            // Bias-correction denominators are the same for every parameter in a step.
            let m_correction = 1.0 - self.beta1.powf(t);
            let v_correction = 1.0 - self.beta2.powf(t);

            for (name, &g) in gradients.iter() {
                let first = self.m.entry(name.clone()).or_insert(0.0);
                *first = self.beta1 * *first + (1.0 - self.beta1) * g;
                let m_hat = *first / m_correction;

                let second = self.v.entry(name.clone()).or_insert(0.0);
                *second = self.beta2 * *second + (1.0 - self.beta2) * g * g;
                let v_hat = *second / v_correction;

                if let Some(value) = params.get_mut(name) {
                    *value -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
                }
            }
        }

        fn reset(&mut self) {
            self.t = 0;
            self.m.clear();
            self.v.clear();
        }
    }

    /// Placeholder "quantum natural gradient" optimizer.
    ///
    /// NOTE(review): `step` performs plain gradient descent (`θ ← θ − lr·g`);
    /// `regularization` is stored but never read and no metric tensor is applied.
    /// `QuantumAutoDiff::natural_gradients` produces metric-preconditioned gradients
    /// that could be fed in instead.
    pub struct QNG {
        learning_rate: f64,
        regularization: f64,
    }

    impl QNG {
        /// Creates the optimizer; see the type-level note about `regularization`.
        pub fn new(learning_rate: f64, regularization: f64) -> Self {
            QNG {
                learning_rate,
                regularization,
            }
        }
    }

    impl Optimizer for QNG {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            for (name, &g) in gradients.iter() {
                if let Some(value) = params.get_mut(name) {
                    *value -= self.learning_rate * g;
                }
            }
        }

        /// Stateless: nothing to clear.
        fn reset(&mut self) {}
    }
}
691
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_autodiff_basic() {
        let mut autodiff = AutoDiff::new();

        autodiff.register_parameter(DifferentiableParam::new("x", 2.0));
        autodiff.register_parameter(DifferentiableParam::new("y", 3.0));

        // f = x * y
        let graph = ComputationNode::Mul(
            Box::new(ComputationNode::Parameter("x".to_string())),
            Box::new(ComputationNode::Parameter("y".to_string())),
        );
        autodiff.set_graph(graph);

        let result = autodiff.forward().expect("forward pass should succeed");
        assert_eq!(result, 6.0);

        autodiff
            .backward(1.0)
            .expect("backward pass should succeed");
        let gradients = autodiff.gradients();

        // df/dx = y, df/dy = x
        assert_eq!(gradients["x"], 3.0);
        assert_eq!(gradients["y"], 2.0);
    }

    #[test]
    fn test_autodiff_unary_ops_and_constants() {
        let mut autodiff = AutoDiff::new();
        autodiff.register_parameter(DifferentiableParam::new("x", 0.5));
        autodiff.register_parameter(DifferentiableParam::new("y", 1.2));
        autodiff.register_parameter(DifferentiableParam::constant("c", 3.0));

        // f = exp(x) + c * cos(y)
        let graph = ComputationNode::Add(
            Box::new(ComputationNode::Exp(Box::new(ComputationNode::Parameter(
                "x".to_string(),
            )))),
            Box::new(ComputationNode::Mul(
                Box::new(ComputationNode::Parameter("c".to_string())),
                Box::new(ComputationNode::Cos(Box::new(ComputationNode::Parameter(
                    "y".to_string(),
                )))),
            )),
        );
        autodiff.set_graph(graph);

        let value = autodiff.forward().expect("forward pass should succeed");
        assert!((value - (0.5f64.exp() + 3.0 * 1.2f64.cos())).abs() < 1e-12);

        autodiff
            .backward(1.0)
            .expect("backward pass should succeed");
        let gradients = autodiff.gradients();

        assert!((gradients["x"] - 0.5f64.exp()).abs() < 1e-12);
        assert!((gradients["y"] - (-3.0 * 1.2f64.sin())).abs() < 1e-12);
        // Constants are excluded from the gradient map.
        assert!(!gradients.contains_key("c"));
    }

    #[test]
    fn test_expectation_forward() {
        let mut autodiff = AutoDiff::new();
        autodiff.register_parameter(DifferentiableParam::new("a", 0.3));
        autodiff.register_parameter(DifferentiableParam::new("b", 0.4));
        autodiff.set_graph(ComputationNode::Expectation {
            circuit_params: vec!["a".to_string(), "b".to_string()],
            observable: "Z".to_string(),
        });

        let value = autodiff.forward().expect("forward pass should succeed");
        assert!((value - 0.7f64.cos()).abs() < 1e-12);
    }

    #[test]
    fn test_gradient_tape() {
        let mut tape = GradientTape::new();

        let x = tape.variable("x", 2.0);
        let y = tape.variable("y", 3.0);
        let z = tape.mul(&x, &y);

        let gradients = tape.gradient(&z, &[&x, &y]);

        assert_eq!(gradients[&x], 3.0);
        assert_eq!(gradients[&y], 2.0);
    }

    #[test]
    fn test_optimizers() {
        use optimizers::*;

        let mut params = HashMap::new();
        params.insert("x".to_string(), 5.0);

        let mut gradients = HashMap::new();
        gradients.insert("x".to_string(), 2.0);

        // Momentum-free SGD: 5.0 - 0.1 * 2.0
        let mut sgd = SGD::new(0.1, 0.0);
        sgd.step(&mut params, &gradients);
        assert!((params["x"] - 4.8).abs() < 1e-6);

        params.insert("x".to_string(), 5.0);
        let mut adam = Adam::new(0.1);
        adam.step(&mut params, &gradients);
        // A positive gradient must decrease the parameter.
        assert!(params["x"] < 5.0);
    }

    #[test]
    fn test_parameter_shift() {
        let executor = |params: &[f64]| -> f64 { params[0].cos() + params[1].sin() };

        let qad = QuantumAutoDiff::new(executor);
        let params = vec![PI / 4.0, PI / 3.0];

        let gradients = qad
            .parameter_shift_gradients(&params, PI / 2.0)
            .expect("parameter shift gradients should succeed");
        assert_eq!(gradients.len(), 2);
        // The shift rule is exact for sinusoidal objectives:
        // d/da cos(a) = -sin(a), d/db sin(b) = cos(b).
        assert!((gradients[0] - (-(PI / 4.0).sin())).abs() < 1e-12);
        assert!((gradients[1] - (PI / 3.0).cos()).abs() < 1e-12);
    }
}