Skip to main content

quantrs2_ml/
autodiff.rs

1//! Automatic differentiation for quantum machine learning.
2//!
3//! This module provides SciRS2-style automatic differentiation capabilities
4//! for computing gradients of quantum circuits and variational algorithms.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use std::collections::HashMap;
8use std::f64::consts::PI;
9
10use crate::error::{MLError, Result};
11use quantrs2_circuit::prelude::*;
12use quantrs2_core::gate::GateOp;
13
/// A named scalar parameter of a quantum circuit that can take part in
/// gradient computation.
#[derive(Debug, Clone)]
pub struct DifferentiableParam {
    /// Identifier of the parameter
    pub name: String,
    /// Present numeric value
    pub value: f64,
    /// Accumulated gradient
    pub gradient: f64,
    /// `true` when gradients should be accumulated for this parameter
    pub requires_grad: bool,
}

impl DifferentiableParam {
    /// Builds a trainable parameter: the gradient starts at `0.0` and
    /// `requires_grad` is set.
    pub fn new(name: impl Into<String>, value: f64) -> Self {
        let name = name.into();
        Self {
            name,
            value,
            gradient: 0.0,
            requires_grad: true,
        }
    }

    /// Builds a constant parameter: same initial state as
    /// [`DifferentiableParam::new`], except that `requires_grad` is cleared.
    pub fn constant(name: impl Into<String>, value: f64) -> Self {
        Self {
            requires_grad: false,
            ..Self::new(name, value)
        }
    }
}
48
/// Computation graph node for automatic differentiation
///
/// A graph is a tree: every operation owns its operands through `Box`.
/// `PartialEq` is derived so graphs can be compared structurally (note that a
/// `Constant` holding NaN never compares equal, following `f64` semantics).
#[derive(Debug, Clone, PartialEq)]
pub enum ComputationNode {
    /// Input parameter, referenced by name
    Parameter(String),
    /// Constant value
    Constant(f64),
    /// Addition operation
    Add(Box<ComputationNode>, Box<ComputationNode>),
    /// Multiplication operation
    Mul(Box<ComputationNode>, Box<ComputationNode>),
    /// Sine function
    Sin(Box<ComputationNode>),
    /// Cosine function
    Cos(Box<ComputationNode>),
    /// Exponential function
    Exp(Box<ComputationNode>),
    /// Quantum expectation value
    Expectation {
        /// Names of the parameters that feed the circuit
        circuit_params: Vec<String>,
        /// Identifier of the measured observable (not read by the evaluation
        /// code in this file)
        observable: String,
    },
}
72
73/// Automatic differentiation engine
74pub struct AutoDiff {
75    /// Parameters registry
76    parameters: HashMap<String, DifferentiableParam>,
77    /// Computation graph
78    graph: Option<ComputationNode>,
79    /// Cached forward values
80    forward_cache: HashMap<String, f64>,
81}
82
83impl AutoDiff {
84    /// Create a new AutoDiff engine
85    pub fn new() -> Self {
86        Self {
87            parameters: HashMap::new(),
88            graph: None,
89            forward_cache: HashMap::new(),
90        }
91    }
92
93    /// Register a parameter
94    pub fn register_parameter(&mut self, param: DifferentiableParam) {
95        self.parameters.insert(param.name.clone(), param);
96    }
97
98    /// Set computation graph
99    pub fn set_graph(&mut self, graph: ComputationNode) {
100        self.graph = Some(graph);
101    }
102
103    /// Forward pass - compute value
104    pub fn forward(&mut self) -> Result<f64> {
105        self.forward_cache.clear();
106
107        if let Some(graph) = self.graph.clone() {
108            self.evaluate_node(&graph)
109        } else {
110            Err(MLError::InvalidConfiguration(
111                "No computation graph set".to_string(),
112            ))
113        }
114    }
115
116    /// Backward pass - compute gradients
117    pub fn backward(&mut self, loss_gradient: f64) -> Result<()> {
118        // Reset gradients
119        for param in self.parameters.values_mut() {
120            param.gradient = 0.0;
121        }
122
123        if let Some(graph) = self.graph.clone() {
124            self.backpropagate(&graph, loss_gradient)?;
125        }
126
127        Ok(())
128    }
129
130    /// Evaluate a computation node
131    fn evaluate_node(&mut self, node: &ComputationNode) -> Result<f64> {
132        match node {
133            ComputationNode::Parameter(name) => {
134                self.parameters.get(name).map(|p| p.value).ok_or_else(|| {
135                    MLError::InvalidConfiguration(format!("Unknown parameter: {}", name))
136                })
137            }
138            ComputationNode::Constant(value) => Ok(*value),
139            ComputationNode::Add(left, right) => {
140                let l = self.evaluate_node(left)?;
141                let r = self.evaluate_node(right)?;
142                Ok(l + r)
143            }
144            ComputationNode::Mul(left, right) => {
145                let l = self.evaluate_node(left)?;
146                let r = self.evaluate_node(right)?;
147                Ok(l * r)
148            }
149            ComputationNode::Sin(inner) => {
150                let x = self.evaluate_node(inner)?;
151                Ok(x.sin())
152            }
153            ComputationNode::Cos(inner) => {
154                let x = self.evaluate_node(inner)?;
155                Ok(x.cos())
156            }
157            ComputationNode::Exp(inner) => {
158                let x = self.evaluate_node(inner)?;
159                Ok(x.exp())
160            }
161            ComputationNode::Expectation {
162                circuit_params,
163                observable,
164            } => {
165                // Simplified - would compute actual expectation value
166                let mut sum = 0.0;
167                for param_name in circuit_params {
168                    if let Some(param) = self.parameters.get(param_name) {
169                        sum += param.value;
170                    }
171                }
172                Ok(sum.cos()) // Placeholder
173            }
174        }
175    }
176
177    /// Backpropagate gradients through the graph
178    fn backpropagate(&mut self, node: &ComputationNode, grad: f64) -> Result<()> {
179        match node {
180            ComputationNode::Parameter(name) => {
181                if let Some(param) = self.parameters.get_mut(name) {
182                    if param.requires_grad {
183                        param.gradient += grad;
184                    }
185                }
186            }
187            ComputationNode::Constant(_) => {
188                // No gradient for constants
189            }
190            ComputationNode::Add(left, right) => {
191                // Gradient distributes equally for addition
192                self.backpropagate(left, grad)?;
193                self.backpropagate(right, grad)?;
194            }
195            ComputationNode::Mul(left, right) => {
196                // Product rule
197                let l_val = self.evaluate_node(left)?;
198                let r_val = self.evaluate_node(right)?;
199                self.backpropagate(left, grad * r_val)?;
200                self.backpropagate(right, grad * l_val)?;
201            }
202            ComputationNode::Sin(inner) => {
203                // d/dx sin(x) = cos(x)
204                let x = self.evaluate_node(inner)?;
205                self.backpropagate(inner, grad * x.cos())?;
206            }
207            ComputationNode::Cos(inner) => {
208                // d/dx cos(x) = -sin(x)
209                let x = self.evaluate_node(inner)?;
210                self.backpropagate(inner, grad * (-x.sin()))?;
211            }
212            ComputationNode::Exp(inner) => {
213                // d/dx exp(x) = exp(x)
214                let x = self.evaluate_node(inner)?;
215                self.backpropagate(inner, grad * x.exp())?;
216            }
217            ComputationNode::Expectation { circuit_params, .. } => {
218                // Use parameter shift rule for quantum gradients
219                for param_name in circuit_params {
220                    let shift_grad = self.parameter_shift_gradient(param_name, PI / 2.0)?;
221                    if let Some(param) = self.parameters.get_mut(param_name) {
222                        if param.requires_grad {
223                            param.gradient += grad * shift_grad;
224                        }
225                    }
226                }
227            }
228        }
229        Ok(())
230    }
231
232    /// Compute gradient using parameter shift rule
233    fn parameter_shift_gradient(&self, param_name: &str, shift: f64) -> Result<f64> {
234        // Simplified parameter shift rule
235        // In practice, would evaluate circuit with ±shift
236        Ok(0.5) // Placeholder
237    }
238
239    /// Get all gradients
240    pub fn gradients(&self) -> HashMap<String, f64> {
241        self.parameters
242            .iter()
243            .filter(|(_, p)| p.requires_grad)
244            .map(|(name, param)| (name.clone(), param.gradient))
245            .collect()
246    }
247
248    /// Update parameters using gradients
249    pub fn update_parameters(&mut self, learning_rate: f64) {
250        for param in self.parameters.values_mut() {
251            if param.requires_grad {
252                param.value -= learning_rate * param.gradient;
253            }
254        }
255    }
256}
257
258/// Quantum-aware automatic differentiation
259pub struct QuantumAutoDiff {
260    /// Base autodiff engine
261    autodiff: AutoDiff,
262    /// Circuit executor (placeholder)
263    executor: Box<dyn Fn(&[f64]) -> f64>,
264}
265
266impl QuantumAutoDiff {
267    /// Create a new quantum autodiff engine
268    pub fn new<F>(executor: F) -> Self
269    where
270        F: Fn(&[f64]) -> f64 + 'static,
271    {
272        Self {
273            autodiff: AutoDiff::new(),
274            executor: Box::new(executor),
275        }
276    }
277
278    /// Compute gradients using parameter shift rule
279    pub fn parameter_shift_gradients(&self, params: &[f64], shift: f64) -> Result<Vec<f64>> {
280        let mut gradients = vec![0.0; params.len()];
281
282        for (i, _) in params.iter().enumerate() {
283            // Shift parameter positively
284            let mut params_plus = params.to_vec();
285            params_plus[i] += shift;
286            let val_plus = (self.executor)(&params_plus);
287
288            // Shift parameter negatively
289            let mut params_minus = params.to_vec();
290            params_minus[i] -= shift;
291            let val_minus = (self.executor)(&params_minus);
292
293            // Parameter shift rule gradient
294            gradients[i] = (val_plus - val_minus) / (2.0 * shift.sin());
295        }
296
297        Ok(gradients)
298    }
299
300    /// Compute natural gradients using quantum Fisher information
301    pub fn natural_gradients(
302        &self,
303        params: &[f64],
304        gradients: &[f64],
305        regularization: f64,
306    ) -> Result<Vec<f64>> {
307        let n = params.len();
308        let mut fisher = Array2::<f64>::zeros((n, n));
309
310        // Compute quantum Fisher information matrix
311        for i in 0..n {
312            for j in 0..n {
313                fisher[[i, j]] = self.compute_fisher_element(params, i, j)?;
314            }
315        }
316
317        // Add regularization
318        for i in 0..n {
319            fisher[[i, i]] += regularization;
320        }
321
322        // Solve F * nat_grad = grad
323        self.solve_linear_system(&fisher, gradients)
324    }
325
326    /// Compute element of quantum Fisher information matrix using 4-point parameter-shift QFIM formula.
327    ///
328    /// F_ij = (E(θ+π/2·e_i+π/2·e_j) - E(θ+π/2·e_i-π/2·e_j)
329    ///        - E(θ-π/2·e_i+π/2·e_j) + E(θ-π/2·e_i-π/2·e_j)) / 4
330    fn compute_fisher_element(&self, params: &[f64], i: usize, j: usize) -> Result<f64> {
331        let shift = PI / 2.0;
332
333        let mut p_pp = params.to_vec();
334        let mut p_pm = params.to_vec();
335        let mut p_mp = params.to_vec();
336        let mut p_mm = params.to_vec();
337
338        p_pp[i] += shift;
339        p_pp[j] += shift;
340
341        p_pm[i] += shift;
342        p_pm[j] -= shift;
343
344        p_mp[i] -= shift;
345        p_mp[j] += shift;
346
347        p_mm[i] -= shift;
348        p_mm[j] -= shift;
349
350        let e_pp = (self.executor)(&p_pp);
351        let e_pm = (self.executor)(&p_pm);
352        let e_mp = (self.executor)(&p_mp);
353        let e_mm = (self.executor)(&p_mm);
354
355        Ok((e_pp - e_pm - e_mp + e_mm) / 4.0)
356    }
357
358    /// Solve linear system A·x = b using Gaussian elimination with partial pivoting.
359    ///
360    /// Returns `Ok(x)` on success, `Err(NumericalError)` if the matrix is singular
361    /// (i.e., |pivot| < 1e-12 at any elimination step).
362    fn solve_linear_system(&self, matrix: &Array2<f64>, rhs: &[f64]) -> Result<Vec<f64>> {
363        let n = rhs.len();
364        if matrix.nrows() != n || matrix.ncols() != n {
365            return Err(MLError::DimensionMismatch(format!(
366                "Matrix ({} x {}) incompatible with rhs length {}",
367                matrix.nrows(),
368                matrix.ncols(),
369                n
370            )));
371        }
372
373        // Build augmented matrix [A | b]
374        let mut a: Vec<Vec<f64>> = (0..n)
375            .map(|i| {
376                let mut row: Vec<f64> = (0..n).map(|j| matrix[[i, j]]).collect();
377                row.push(rhs[i]);
378                row
379            })
380            .collect();
381
382        // Forward elimination with partial pivoting
383        for k in 0..n {
384            // Find pivot row: row with max |a[row][k]| for row >= k
385            let mut max_val = a[k][k].abs();
386            let mut max_idx = k;
387            for row in (k + 1)..n {
388                let val = a[row][k].abs();
389                if val > max_val {
390                    max_val = val;
391                    max_idx = row;
392                }
393            }
394
395            if max_val < 1e-12 {
396                return Err(MLError::NumericalError(format!(
397                    "Singular matrix: |pivot| = {:.2e} < 1e-12 at column {}",
398                    max_val, k
399                )));
400            }
401
402            // Swap rows k and max_idx
403            if max_idx != k {
404                a.swap(k, max_idx);
405            }
406
407            let pivot = a[k][k];
408
409            // Eliminate below pivot
410            for i in (k + 1)..n {
411                let factor = a[i][k] / pivot;
412                for col in k..=n {
413                    let sub = factor * a[k][col];
414                    a[i][col] -= sub;
415                }
416            }
417        }
418
419        // Back substitution
420        let mut x = vec![0.0_f64; n];
421        for i in (0..n).rev() {
422            let mut sum = a[i][n]; // rhs column
423            for j in (i + 1)..n {
424                sum -= a[i][j] * x[j];
425            }
426            x[i] = sum / a[i][i];
427        }
428
429        Ok(x)
430    }
431}
432
/// Gradient tape for recording operations
///
/// Values are computed eagerly as operations are recorded; [`GradientTape::gradient`]
/// later replays the recorded operations in reverse to obtain derivatives.
#[derive(Debug, Clone)]
pub struct GradientTape {
    /// Recorded operations, in recording order
    operations: Vec<Operation>,
    /// Variable values (latest value stored under each name)
    variables: HashMap<String, f64>,
}

/// Recorded operation
#[derive(Debug, Clone)]
enum Operation {
    /// Variable assignment
    Assign { var: String, value: f64 },
    /// Addition
    Add {
        result: String,
        left: String,
        right: String,
    },
    /// Multiplication
    ///
    /// The operand values are captured when the product is recorded, so the
    /// local derivatives stay correct even if a name is re-recorded later.
    Mul {
        result: String,
        left: String,
        right: String,
        left_value: f64,
        right_value: f64,
    },
    /// Quantum operation
    Quantum { result: String, params: Vec<String> },
}

impl Default for GradientTape {
    fn default() -> Self {
        Self::new()
    }
}

impl GradientTape {
    /// Create a new gradient tape
    pub fn new() -> Self {
        Self {
            operations: Vec::new(),
            variables: HashMap::new(),
        }
    }

    /// Record a variable
    ///
    /// Stores `value` under `name` (overwriting any earlier value with that
    /// name) and returns the name so it can be passed to `add`/`mul`.
    pub fn variable(&mut self, name: impl Into<String>, value: f64) -> String {
        let name = name.into();
        self.variables.insert(name.clone(), value);
        self.operations.push(Operation::Assign {
            var: name.clone(),
            value,
        });
        name
    }

    /// Record addition
    ///
    /// Returns the generated name (`tmp_<index>`) of the result.
    ///
    /// # Panics
    ///
    /// Panics if `left` or `right` has not been recorded on this tape.
    pub fn add(&mut self, left: &str, right: &str) -> String {
        let result = format!("tmp_{}", self.operations.len());
        let left_val = self.value_of(left);
        let right_val = self.value_of(right);
        self.variables.insert(result.clone(), left_val + right_val);
        self.operations.push(Operation::Add {
            result: result.clone(),
            left: left.to_string(),
            right: right.to_string(),
        });
        result
    }

    /// Record multiplication
    ///
    /// Returns the generated name (`tmp_<index>`) of the result.
    ///
    /// # Panics
    ///
    /// Panics if `left` or `right` has not been recorded on this tape.
    pub fn mul(&mut self, left: &str, right: &str) -> String {
        let result = format!("tmp_{}", self.operations.len());
        let left_value = self.value_of(left);
        let right_value = self.value_of(right);
        self.variables
            .insert(result.clone(), left_value * right_value);
        self.operations.push(Operation::Mul {
            result: result.clone(),
            left: left.to_string(),
            right: right.to_string(),
            left_value,
            right_value,
        });
        result
    }

    /// Compute gradients
    ///
    /// Seeds `d output / d output = 1` and walks the tape backwards,
    /// accumulating contributions by variable name. The returned map has one
    /// entry per requested input; inputs the output does not depend on get
    /// `0.0`.
    pub fn gradient(&self, output: &str, inputs: &[&str]) -> HashMap<String, f64> {
        let mut gradients: HashMap<String, f64> = HashMap::new();

        // Initialize output gradient
        gradients.insert(output.to_string(), 1.0);

        // Backward pass through operations
        for op in self.operations.iter().rev() {
            match op {
                Operation::Add {
                    result,
                    left,
                    right,
                } => {
                    if let Some(&grad) = gradients.get(result) {
                        *gradients.entry(left.clone()).or_insert(0.0) += grad;
                        *gradients.entry(right.clone()).or_insert(0.0) += grad;
                    }
                }
                Operation::Mul {
                    result,
                    left,
                    right,
                    left_value,
                    right_value,
                } => {
                    if let Some(&grad) = gradients.get(result) {
                        // Product rule with the operand values seen at record
                        // time, not whatever the names hold now.
                        *gradients.entry(left.clone()).or_insert(0.0) += grad * right_value;
                        *gradients.entry(right.clone()).or_insert(0.0) += grad * left_value;
                    }
                }
                // Assignments are leaves; quantum operations carry no
                // derivative rule on this tape.
                Operation::Assign { .. } | Operation::Quantum { .. } => {}
            }
        }

        // Extract gradients for requested inputs
        inputs
            .iter()
            .map(|&input| {
                (
                    input.to_string(),
                    gradients.get(input).copied().unwrap_or(0.0),
                )
            })
            .collect()
    }

    /// Looks up the current value of a recorded variable, panicking with the
    /// offending name when it is unknown.
    fn value_of(&self, name: &str) -> f64 {
        match self.variables.get(name) {
            Some(&value) => value,
            None => panic!("GradientTape: unknown variable `{}`", name),
        }
    }
}
559
/// Gradient-based optimizers operating on name-keyed parameter maps
pub mod optimizers {
    use super::*;

    /// Interface shared by every optimizer in this module
    pub trait Optimizer {
        /// Applies one update to `params` using `gradients`; both maps are
        /// keyed by parameter name
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>);

        /// Discards any state accumulated across steps
        fn reset(&mut self);
    }

    /// Stochastic gradient descent with momentum
    pub struct SGD {
        learning_rate: f64,
        momentum: f64,
        velocities: HashMap<String, f64>,
    }

    impl SGD {
        /// Creates an SGD optimizer with empty velocity state
        pub fn new(learning_rate: f64, momentum: f64) -> Self {
            let velocities = HashMap::new();
            Self {
                learning_rate,
                momentum,
                velocities,
            }
        }
    }

    impl Optimizer for SGD {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            for (key, &g) in gradients {
                // velocity <- momentum * velocity - learning_rate * gradient
                let slot = self.velocities.entry(key.clone()).or_insert(0.0);
                let updated = self.momentum * *slot - self.learning_rate * g;
                *slot = updated;

                // A gradient without a matching parameter only updates the
                // velocity state.
                if let Some(target) = params.get_mut(key) {
                    *target += updated;
                }
            }
        }

        fn reset(&mut self) {
            self.velocities.clear();
        }
    }

    /// Adam optimizer
    pub struct Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
        t: usize,
        m: HashMap<String, f64>,
        v: HashMap<String, f64>,
    }

    impl Adam {
        /// Creates an Adam optimizer with `beta1 = 0.9`, `beta2 = 0.999`,
        /// `epsilon = 1e-8` and empty moment estimates
        pub fn new(learning_rate: f64) -> Self {
            Self {
                learning_rate,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
                t: 0,
                m: HashMap::new(),
                v: HashMap::new(),
            }
        }
    }

    impl Optimizer for Adam {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            self.t += 1;
            let step_count = self.t as f64;

            // The bias-correction denominators depend only on the step count.
            let first_correction = 1.0 - self.beta1.powf(step_count);
            let second_correction = 1.0 - self.beta2.powf(step_count);

            for (key, &g) in gradients {
                let first = self.m.entry(key.clone()).or_insert(0.0);
                let second = self.v.entry(key.clone()).or_insert(0.0);

                // Exponential moving averages of the gradient and its square
                *first = self.beta1 * *first + (1.0 - self.beta1) * g;
                *second = self.beta2 * *second + (1.0 - self.beta2) * g * g;

                // Bias-corrected estimates
                let m_hat = *first / first_correction;
                let v_hat = *second / second_correction;

                if let Some(target) = params.get_mut(key) {
                    *target -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
                }
            }
        }

        fn reset(&mut self) {
            self.m.clear();
            self.v.clear();
            self.t = 0;
        }
    }

    /// Quantum Natural Gradient
    pub struct QNG {
        learning_rate: f64,
        /// Stored at construction; `step` does not read it
        regularization: f64,
    }

    impl QNG {
        /// Creates a QNG optimizer from a learning rate and a regularization
        /// strength
        pub fn new(learning_rate: f64, regularization: f64) -> Self {
            Self {
                learning_rate,
                regularization,
            }
        }
    }

    impl Optimizer for QNG {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            // Simplified - would compute natural gradient; for now this is a
            // plain gradient-descent update.
            for (key, &g) in gradients {
                if let Some(target) = params.get_mut(key) {
                    *target -= self.learning_rate * g;
                }
            }
        }

        fn reset(&mut self) {}
    }
}
691
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward and backward pass through `z = x * y`.
    #[test]
    fn test_autodiff_basic() {
        let mut autodiff = AutoDiff::new();

        // Register parameters
        autodiff.register_parameter(DifferentiableParam::new("x", 2.0));
        autodiff.register_parameter(DifferentiableParam::new("y", 3.0));

        // Build computation graph: z = x * y
        let graph = ComputationNode::Mul(
            Box::new(ComputationNode::Parameter("x".to_string())),
            Box::new(ComputationNode::Parameter("y".to_string())),
        );
        autodiff.set_graph(graph);

        // Forward pass
        let result = autodiff.forward().expect("forward pass should succeed");
        assert_eq!(result, 6.0);

        // Backward pass
        autodiff
            .backward(1.0)
            .expect("backward pass should succeed");
        let gradients = autodiff.gradients();

        assert_eq!(gradients["x"], 3.0); // dz/dx = y
        assert_eq!(gradients["y"], 2.0); // dz/dy = x
    }

    /// Product rule on the tape, then the chain rule through an addition.
    #[test]
    fn test_gradient_tape() {
        let mut tape = GradientTape::new();

        let x = tape.variable("x", 2.0);
        let y = tape.variable("y", 3.0);
        let z = tape.mul(&x, &y);

        let gradients = tape.gradient(&z, &[&x, &y]);

        assert_eq!(gradients[&x], 3.0);
        assert_eq!(gradients[&y], 2.0);

        // w = x * y + x  =>  dw/dx = y + 1, dw/dy = x
        let w = tape.add(&z, &x);
        let gradients = tape.gradient(&w, &[&x, &y]);

        assert_eq!(gradients[&x], 4.0);
        assert_eq!(gradients[&y], 2.0);
    }

    /// Single-step behaviour of SGD and Adam.
    #[test]
    fn test_optimizers() {
        use optimizers::*;

        let mut params = HashMap::new();
        params.insert("x".to_string(), 5.0);

        let mut gradients = HashMap::new();
        gradients.insert("x".to_string(), 2.0);

        // Test SGD
        let mut sgd = SGD::new(0.1, 0.0);
        sgd.step(&mut params, &gradients);
        assert!((params["x"] - 4.8).abs() < 1e-6);

        // Test Adam
        params.insert("x".to_string(), 5.0);
        let mut adam = Adam::new(0.1);
        adam.step(&mut params, &gradients);
        assert!(params["x"] < 5.0); // Should decrease
        // After bias correction the first Adam step is ~learning_rate in size.
        assert!((params["x"] - 4.9).abs() < 1e-6);
    }

    /// The parameter-shift rule must reproduce the analytic derivatives of
    /// `cos(p0) + sin(p1)`.
    #[test]
    fn test_parameter_shift() {
        let executor = |params: &[f64]| -> f64 { params[0].cos() + params[1].sin() };

        let qad = QuantumAutoDiff::new(executor);
        let params = vec![PI / 4.0, PI / 3.0];

        let gradients = qad
            .parameter_shift_gradients(&params, PI / 2.0)
            .expect("parameter shift gradients should succeed");
        assert_eq!(gradients.len(), 2);

        // d/dp0 cos(p0) = -sin(p0), d/dp1 sin(p1) = cos(p1)
        assert!((gradients[0] + (PI / 4.0).sin()).abs() < 1e-10);
        assert!((gradients[1] - (PI / 3.0).cos()).abs() < 1e-10);
    }
}