use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::f64::consts::PI;

use crate::error::{MLError, Result};
use quantrs2_circuit::prelude::*;
use quantrs2_core::gate::GateOp;

/// A parameter that can be differentiated.
#[derive(Debug, Clone)]
pub struct DifferentiableParam {
    /// Parameter name.
    pub name: String,
    /// Current value.
    pub value: f64,
    /// Accumulated gradient.
    pub gradient: f64,
    /// Whether this parameter is trainable.
    pub requires_grad: bool,
}

impl DifferentiableParam {
    /// Creates a trainable parameter with the given initial value.
    pub fn new(name: impl Into<String>, value: f64) -> Self {
        Self {
            name: name.into(),
            value,
            gradient: 0.0,
            requires_grad: true,
        }
    }

    /// Creates a constant (non-trainable) parameter.
    pub fn constant(name: impl Into<String>, value: f64) -> Self {
        Self {
            name: name.into(),
            value,
            gradient: 0.0,
            requires_grad: false,
        }
    }
}

/// A node in the computation graph.
#[derive(Debug, Clone)]
pub enum ComputationNode {
    /// A named, differentiable parameter.
    Parameter(String),
    /// A constant value.
    Constant(f64),
    /// Sum of two nodes.
    Add(Box<ComputationNode>, Box<ComputationNode>),
    /// Product of two nodes.
    Mul(Box<ComputationNode>, Box<ComputationNode>),
    /// Sine of a node.
    Sin(Box<ComputationNode>),
    /// Cosine of a node.
    Cos(Box<ComputationNode>),
    /// Exponential of a node.
    Exp(Box<ComputationNode>),
    /// Expectation value of an observable for a parameterized quantum circuit.
    Expectation {
        circuit_params: Vec<String>,
        observable: String,
    },
}

/// A simple reverse-mode automatic differentiation engine.
pub struct AutoDiff {
    /// Registered parameters, keyed by name.
    parameters: HashMap<String, DifferentiableParam>,
    /// The computation graph to evaluate and differentiate.
    graph: Option<ComputationNode>,
    /// Cache of forward-pass values.
    forward_cache: HashMap<String, f64>,
}

impl AutoDiff {
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
            graph: None,
            forward_cache: HashMap::new(),
        }
    }

    /// Registers a parameter with the engine.
    pub fn register_parameter(&mut self, param: DifferentiableParam) {
        self.parameters.insert(param.name.clone(), param);
    }

    /// Sets the computation graph to evaluate and differentiate.
    pub fn set_graph(&mut self, graph: ComputationNode) {
        self.graph = Some(graph);
    }

    /// Evaluates the computation graph with the current parameter values.
    pub fn forward(&mut self) -> Result<f64> {
        self.forward_cache.clear();

        if let Some(graph) = self.graph.clone() {
            self.evaluate_node(&graph)
        } else {
            Err(MLError::InvalidConfiguration(
                "No computation graph set".to_string(),
            ))
        }
    }

    /// Accumulates gradients for all trainable parameters by backpropagating
    /// `loss_gradient` through the graph.
    pub fn backward(&mut self, loss_gradient: f64) -> Result<()> {
        // Zero out gradients from any previous backward pass.
        for param in self.parameters.values_mut() {
            param.gradient = 0.0;
        }

        if let Some(graph) = self.graph.clone() {
            self.backpropagate(&graph, loss_gradient)?;
        }

        Ok(())
    }

    fn evaluate_node(&mut self, node: &ComputationNode) -> Result<f64> {
        match node {
            ComputationNode::Parameter(name) => {
                self.parameters.get(name).map(|p| p.value).ok_or_else(|| {
                    MLError::InvalidConfiguration(format!("Unknown parameter: {}", name))
                })
            }
            ComputationNode::Constant(value) => Ok(*value),
            ComputationNode::Add(left, right) => {
                let l = self.evaluate_node(left)?;
                let r = self.evaluate_node(right)?;
                Ok(l + r)
            }
            ComputationNode::Mul(left, right) => {
                let l = self.evaluate_node(left)?;
                let r = self.evaluate_node(right)?;
                Ok(l * r)
            }
            ComputationNode::Sin(inner) => {
                let x = self.evaluate_node(inner)?;
                Ok(x.sin())
            }
            ComputationNode::Cos(inner) => {
                let x = self.evaluate_node(inner)?;
                Ok(x.cos())
            }
            ComputationNode::Exp(inner) => {
                let x = self.evaluate_node(inner)?;
                Ok(x.exp())
            }
            ComputationNode::Expectation {
                circuit_params,
                observable: _,
            } => {
                // Placeholder expectation value: a full implementation would
                // build the parameterized circuit, execute it, and measure the
                // observable. Here we simply return cos of the parameter sum.
                let mut sum = 0.0;
                for param_name in circuit_params {
                    if let Some(param) = self.parameters.get(param_name) {
                        sum += param.value;
                    }
                }
                Ok(sum.cos())
            }
        }
    }

    fn backpropagate(&mut self, node: &ComputationNode, grad: f64) -> Result<()> {
        match node {
            ComputationNode::Parameter(name) => {
                if let Some(param) = self.parameters.get_mut(name) {
                    if param.requires_grad {
                        param.gradient += grad;
                    }
                }
            }
            ComputationNode::Constant(_) => {
                // Constants carry no gradient.
            }
            ComputationNode::Add(left, right) => {
                // d(l + r)/dl = d(l + r)/dr = 1
                self.backpropagate(left, grad)?;
                self.backpropagate(right, grad)?;
            }
            ComputationNode::Mul(left, right) => {
                // d(l * r)/dl = r and d(l * r)/dr = l
                let l_val = self.evaluate_node(left)?;
                let r_val = self.evaluate_node(right)?;
                self.backpropagate(left, grad * r_val)?;
                self.backpropagate(right, grad * l_val)?;
            }
            ComputationNode::Sin(inner) => {
                // d(sin x)/dx = cos x
                let x = self.evaluate_node(inner)?;
                self.backpropagate(inner, grad * x.cos())?;
            }
            ComputationNode::Cos(inner) => {
                // d(cos x)/dx = -sin x
                let x = self.evaluate_node(inner)?;
                self.backpropagate(inner, grad * (-x.sin()))?;
            }
            ComputationNode::Exp(inner) => {
                // d(exp x)/dx = exp x
                let x = self.evaluate_node(inner)?;
                self.backpropagate(inner, grad * x.exp())?;
            }
            ComputationNode::Expectation { circuit_params, .. } => {
                // Quantum expectation values are differentiated with the
                // parameter-shift rule.
                for param_name in circuit_params {
                    let shift_grad = self.parameter_shift_gradient(param_name, PI / 2.0)?;
                    if let Some(param) = self.parameters.get_mut(param_name) {
                        if param.requires_grad {
                            param.gradient += grad * shift_grad;
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Parameter-shift gradient for a single circuit parameter.
    ///
    /// Placeholder: a full implementation would re-execute the circuit with the
    /// parameter shifted up and down and combine the two expectation values.
    fn parameter_shift_gradient(&self, _param_name: &str, _shift: f64) -> Result<f64> {
        Ok(0.5)
    }

    /// Returns the current gradients of all trainable parameters.
    pub fn gradients(&self) -> HashMap<String, f64> {
        self.parameters
            .iter()
            .filter(|(_, p)| p.requires_grad)
            .map(|(name, param)| (name.clone(), param.gradient))
            .collect()
    }

    /// Applies one step of plain gradient descent to all trainable parameters.
    pub fn update_parameters(&mut self, learning_rate: f64) {
        for param in self.parameters.values_mut() {
            if param.requires_grad {
                param.value -= learning_rate * param.gradient;
            }
        }
    }
}

/// Automatic differentiation for quantum circuits, built around a
/// user-supplied executor that maps parameter values to an expectation value.
pub struct QuantumAutoDiff {
    /// Classical autodiff engine.
    autodiff: AutoDiff,
    /// Executes the parameterized circuit and returns an expectation value.
    executor: Box<dyn Fn(&[f64]) -> f64>,
}

impl QuantumAutoDiff {
    pub fn new<F>(executor: F) -> Self
    where
        F: Fn(&[f64]) -> f64 + 'static,
    {
        Self {
            autodiff: AutoDiff::new(),
            executor: Box::new(executor),
        }
    }

    /// Computes gradients of the executor output with respect to each
    /// parameter using the parameter-shift rule:
    /// df/dtheta_i = (f(theta + s*e_i) - f(theta - s*e_i)) / (2 sin s).
    pub fn parameter_shift_gradients(&self, params: &[f64], shift: f64) -> Result<Vec<f64>> {
        let mut gradients = vec![0.0; params.len()];

        for (i, _) in params.iter().enumerate() {
            let mut params_plus = params.to_vec();
            params_plus[i] += shift;
            let val_plus = (self.executor)(&params_plus);

            let mut params_minus = params.to_vec();
            params_minus[i] -= shift;
            let val_minus = (self.executor)(&params_minus);

            gradients[i] = (val_plus - val_minus) / (2.0 * shift.sin());
        }

        Ok(gradients)
    }

    /// Computes natural gradients by preconditioning the plain gradients with
    /// the (regularized) quantum Fisher information matrix.
    pub fn natural_gradients(
        &self,
        params: &[f64],
        gradients: &[f64],
        regularization: f64,
    ) -> Result<Vec<f64>> {
        let n = params.len();
        let mut fisher = Array2::<f64>::zeros((n, n));

        // Build the Fisher information matrix.
        for i in 0..n {
            for j in 0..n {
                fisher[[i, j]] = self.compute_fisher_element(params, i, j)?;
            }
        }

        // Regularize the diagonal to keep the matrix well conditioned.
        for i in 0..n {
            fisher[[i, i]] += regularization;
        }

        self.solve_linear_system(&fisher, gradients)
    }

    /// Placeholder Fisher-matrix element; a full implementation would estimate
    /// it from overlaps of parameter-shifted circuit states.
    fn compute_fisher_element(&self, _params: &[f64], i: usize, j: usize) -> Result<f64> {
        if i == j {
            Ok(1.0 + 0.1 * fastrand::f64())
        } else {
            Ok(0.1 * fastrand::f64())
        }
    }

    /// Placeholder linear solve; currently returns the right-hand side unchanged.
    fn solve_linear_system(&self, _matrix: &Array2<f64>, rhs: &[f64]) -> Result<Vec<f64>> {
        Ok(rhs.to_vec())
    }
}

/// A tape that records operations for reverse-mode differentiation.
#[derive(Debug, Clone)]
pub struct GradientTape {
    /// Recorded operations, in execution order.
    operations: Vec<Operation>,
    /// Current values of all variables, including temporaries.
    variables: HashMap<String, f64>,
}

/// An operation recorded on the tape.
#[derive(Debug, Clone)]
enum Operation {
    /// Binds a named variable to a value.
    Assign { var: String, value: f64 },
    /// `result = left + right`
    Add {
        result: String,
        left: String,
        right: String,
    },
    /// `result = left * right`
    Mul {
        result: String,
        left: String,
        right: String,
    },
    /// `result` is the expectation value of a circuit over `params`.
    Quantum { result: String, params: Vec<String> },
}

impl GradientTape {
    pub fn new() -> Self {
        Self {
            operations: Vec::new(),
            variables: HashMap::new(),
        }
    }

    pub fn variable(&mut self, name: impl Into<String>, value: f64) -> String {
        let name = name.into();
        self.variables.insert(name.clone(), value);
        self.operations.push(Operation::Assign {
            var: name.clone(),
            value,
        });
        name
    }

    pub fn add(&mut self, left: &str, right: &str) -> String {
        let result = format!("tmp_{}", self.operations.len());
        let left_val = self.variables[left];
        let right_val = self.variables[right];
        self.variables.insert(result.clone(), left_val + right_val);
        self.operations.push(Operation::Add {
            result: result.clone(),
            left: left.to_string(),
            right: right.to_string(),
        });
        result
    }

    pub fn mul(&mut self, left: &str, right: &str) -> String {
        let result = format!("tmp_{}", self.operations.len());
        let left_val = self.variables[left];
        let right_val = self.variables[right];
        self.variables.insert(result.clone(), left_val * right_val);
        self.operations.push(Operation::Mul {
            result: result.clone(),
            left: left.to_string(),
            right: right.to_string(),
        });
        result
    }

    pub fn gradient(&self, output: &str, inputs: &[&str]) -> HashMap<String, f64> {
        let mut gradients: HashMap<String, f64> = HashMap::new();

        // Seed the output gradient.
        gradients.insert(output.to_string(), 1.0);

        // Walk the tape in reverse, applying the chain rule.
        for op in self.operations.iter().rev() {
            match op {
                Operation::Add {
                    result,
                    left,
                    right,
                } => {
                    if let Some(&grad) = gradients.get(result) {
                        *gradients.entry(left.clone()).or_insert(0.0) += grad;
                        *gradients.entry(right.clone()).or_insert(0.0) += grad;
                    }
                }
                Operation::Mul {
                    result,
                    left,
                    right,
                } => {
                    if let Some(&grad) = gradients.get(result) {
                        let left_val = self.variables[left];
                        let right_val = self.variables[right];
                        *gradients.entry(left.clone()).or_insert(0.0) += grad * right_val;
                        *gradients.entry(right.clone()).or_insert(0.0) += grad * left_val;
                    }
                }
                _ => {}
            }
        }

        // Collect gradients for the requested inputs.
        inputs
            .iter()
            .map(|&input| {
                (
                    input.to_string(),
                    gradients.get(input).copied().unwrap_or(0.0),
                )
            })
            .collect()
    }
}

pub mod optimizers {
    use super::*;

    pub trait Optimizer {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>);

        fn reset(&mut self);
    }

    pub struct SGD {
        learning_rate: f64,
        momentum: f64,
        velocities: HashMap<String, f64>,
    }

    impl SGD {
        pub fn new(learning_rate: f64, momentum: f64) -> Self {
            Self {
                learning_rate,
                momentum,
                velocities: HashMap::new(),
            }
        }
    }

    impl Optimizer for SGD {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            for (name, grad) in gradients {
                let velocity = self.velocities.entry(name.clone()).or_insert(0.0);
                *velocity = self.momentum * *velocity - self.learning_rate * grad;

                if let Some(param) = params.get_mut(name) {
                    *param += *velocity;
                }
            }
        }

        fn reset(&mut self) {
            self.velocities.clear();
        }
    }

    pub struct Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
        t: usize,
        m: HashMap<String, f64>,
        v: HashMap<String, f64>,
    }

    impl Adam {
        pub fn new(learning_rate: f64) -> Self {
            Self {
                learning_rate,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
                t: 0,
                m: HashMap::new(),
                v: HashMap::new(),
            }
        }
    }

    impl Optimizer for Adam {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            self.t += 1;
            let t = self.t as f64;

            for (name, grad) in gradients {
                let m_t = self.m.entry(name.clone()).or_insert(0.0);
                let v_t = self.v.entry(name.clone()).or_insert(0.0);

                // Update biased first and second moment estimates.
                *m_t = self.beta1 * *m_t + (1.0 - self.beta1) * grad;
                *v_t = self.beta2 * *v_t + (1.0 - self.beta2) * grad * grad;

                // Bias-corrected moment estimates.
                let m_hat = *m_t / (1.0 - self.beta1.powf(t));
                let v_hat = *v_t / (1.0 - self.beta2.powf(t));

                if let Some(param) = params.get_mut(name) {
                    *param -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
                }
            }
        }

        fn reset(&mut self) {
            self.t = 0;
            self.m.clear();
            self.v.clear();
        }
    }

    /// Quantum natural gradient (QNG) optimizer.
    ///
    /// Note: this simplified version currently applies a plain gradient step;
    /// the Fisher-information preconditioning that would use `regularization`
    /// is not yet wired in.
    pub struct QNG {
        learning_rate: f64,
        regularization: f64,
    }

    impl QNG {
        pub fn new(learning_rate: f64, regularization: f64) -> Self {
            Self {
                learning_rate,
                regularization,
            }
        }
    }

    impl Optimizer for QNG {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            for (name, grad) in gradients {
                if let Some(param) = params.get_mut(name) {
                    *param -= self.learning_rate * grad;
                }
            }
        }

        fn reset(&mut self) {}
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_autodiff_basic() {
        let mut autodiff = AutoDiff::new();

        autodiff.register_parameter(DifferentiableParam::new("x", 2.0));
        autodiff.register_parameter(DifferentiableParam::new("y", 3.0));

        let graph = ComputationNode::Mul(
            Box::new(ComputationNode::Parameter("x".to_string())),
            Box::new(ComputationNode::Parameter("y".to_string())),
        );
        autodiff.set_graph(graph);

        let result = autodiff.forward().unwrap();
        assert_eq!(result, 6.0);

        autodiff.backward(1.0).unwrap();
        let gradients = autodiff.gradients();

        assert_eq!(gradients["x"], 3.0);
        assert_eq!(gradients["y"], 2.0);
    }

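    // Additional sketch (not part of the original tests): checks the chain
    // rule through a trig node, assuming the same AutoDiff API as above.
    // For sin(x) at x = PI/3 the gradient should be cos(PI/3).
    #[test]
    fn test_autodiff_sin_gradient() {
        let mut autodiff = AutoDiff::new();
        autodiff.register_parameter(DifferentiableParam::new("x", PI / 3.0));

        let graph = ComputationNode::Sin(Box::new(ComputationNode::Parameter("x".to_string())));
        autodiff.set_graph(graph);

        let result = autodiff.forward().unwrap();
        assert!((result - (PI / 3.0).sin()).abs() < 1e-12);

        autodiff.backward(1.0).unwrap();
        let gradients = autodiff.gradients();
        assert!((gradients["x"] - (PI / 3.0).cos()).abs() < 1e-12);
    }
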
    #[test]
    fn test_gradient_tape() {
        let mut tape = GradientTape::new();

        let x = tape.variable("x", 2.0);
        let y = tape.variable("y", 3.0);
        let z = tape.mul(&x, &y);

        let gradients = tape.gradient(&z, &[&x, &y]);

        assert_eq!(gradients[&x], 3.0);
        assert_eq!(gradients[&y], 2.0);
    }

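    // Additional sketch (not part of the original tests): a composite tape
    // expression z = x * y + x, so dz/dx = y + 1 and dz/dy = x.
    #[test]
    fn test_gradient_tape_composite() {
        let mut tape = GradientTape::new();

        let x = tape.variable("x", 2.0);
        let y = tape.variable("y", 3.0);
        let xy = tape.mul(&x, &y);
        let z = tape.add(&xy, &x);

        let gradients = tape.gradient(&z, &[&x, &y]);

        assert!((gradients[&x] - 4.0).abs() < 1e-12);
        assert!((gradients[&y] - 2.0).abs() < 1e-12);
    }
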
    #[test]
    fn test_optimizers() {
        use optimizers::*;

        let mut params = HashMap::new();
        params.insert("x".to_string(), 5.0);

        let mut gradients = HashMap::new();
        gradients.insert("x".to_string(), 2.0);

        let mut sgd = SGD::new(0.1, 0.0);
        sgd.step(&mut params, &gradients);
        assert!((params["x"] - 4.8).abs() < 1e-6);

        params.insert("x".to_string(), 5.0);
        let mut adam = Adam::new(0.1);
        adam.step(&mut params, &gradients);
        assert!(params["x"] < 5.0);
    }

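    // Additional sketch (not part of the original tests): checks that SGD
    // momentum accumulates velocity across steps under the update rule
    // v <- momentum * v - lr * g, x <- x + v.
    #[test]
    fn test_sgd_momentum_accumulates() {
        use optimizers::*;

        let mut params = HashMap::new();
        params.insert("x".to_string(), 5.0);

        let mut gradients = HashMap::new();
        gradients.insert("x".to_string(), 2.0);

        let mut sgd = SGD::new(0.1, 0.9);
        sgd.step(&mut params, &gradients);
        // v = -0.2, x = 4.8
        assert!((params["x"] - 4.8).abs() < 1e-9);

        sgd.step(&mut params, &gradients);
        // v = 0.9 * (-0.2) - 0.2 = -0.38, x = 4.8 - 0.38 = 4.42
        assert!((params["x"] - 4.42).abs() < 1e-9);
    }
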
    #[test]
    fn test_parameter_shift() {
        let executor = |params: &[f64]| -> f64 { params[0].cos() + params[1].sin() };

        let qad = QuantumAutoDiff::new(executor);
        let params = vec![PI / 4.0, PI / 3.0];

        let gradients = qad.parameter_shift_gradients(&params, PI / 2.0).unwrap();
        assert_eq!(gradients.len(), 2);
    }
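
    // Additional sketch (not part of the original tests): for this executor the
    // parameter-shift rule with shift = PI/2 is exact, giving -sin(p0) and
    // cos(p1) analytically.
    #[test]
    fn test_parameter_shift_matches_analytic() {
        let executor = |params: &[f64]| -> f64 { params[0].cos() + params[1].sin() };

        let qad = QuantumAutoDiff::new(executor);
        let params = vec![PI / 4.0, PI / 3.0];

        let gradients = qad.parameter_shift_gradients(&params, PI / 2.0).unwrap();
        assert!((gradients[0] + (PI / 4.0).sin()).abs() < 1e-10);
        assert!((gradients[1] - (PI / 3.0).cos()).abs() < 1e-10);
    }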
}