quantrs2_ml/autodiff.rs

//! Automatic differentiation for quantum machine learning.
//!
//! This module provides SciRS2-style automatic differentiation capabilities
//! for computing gradients of quantum circuits and variational algorithms.
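//!
//! # Example
//!
//! A minimal sketch of the intended workflow, assuming this module is reachable
//! as `quantrs2_ml::autodiff` (the exact public path may differ):
//!
//! ```ignore
//! use quantrs2_ml::autodiff::{AutoDiff, ComputationNode, DifferentiableParam};
//!
//! // f(theta) = sin(theta) * 2.0
//! let mut ad = AutoDiff::new();
//! ad.register_parameter(DifferentiableParam::new("theta", 0.5));
//! ad.set_graph(ComputationNode::Mul(
//!     Box::new(ComputationNode::Sin(Box::new(ComputationNode::Parameter(
//!         "theta".to_string(),
//!     )))),
//!     Box::new(ComputationNode::Constant(2.0)),
//! ));
//!
//! let value = ad.forward().expect("forward");   // 2 * sin(0.5)
//! ad.backward(1.0).expect("backward");          // df/dtheta = 2 * cos(0.5)
//! let gradients = ad.gradients();
//! ad.update_parameters(0.01);                   // one gradient-descent step
//! ```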

use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::f64::consts::PI;

use crate::error::{MLError, Result};
use quantrs2_circuit::prelude::*;
use quantrs2_core::gate::GateOp;

/// Differentiable parameter in a quantum circuit
#[derive(Debug, Clone)]
pub struct DifferentiableParam {
    /// Parameter name/ID
    pub name: String,
    /// Current value
    pub value: f64,
    /// Gradient accumulator
    pub gradient: f64,
    /// Whether this parameter requires gradient
    pub requires_grad: bool,
}

impl DifferentiableParam {
    /// Create a new differentiable parameter
    pub fn new(name: impl Into<String>, value: f64) -> Self {
        Self {
            name: name.into(),
            value,
            gradient: 0.0,
            requires_grad: true,
        }
    }

    /// Create a constant (non-differentiable) parameter
    pub fn constant(name: impl Into<String>, value: f64) -> Self {
        Self {
            name: name.into(),
            value,
            gradient: 0.0,
            requires_grad: false,
        }
    }
}

/// Computation graph node for automatic differentiation
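///
/// Graphs are built by nesting boxed nodes; for example, `sin(theta) + 1.0`
/// (an illustrative expression, assuming the path `quantrs2_ml::autodiff`):
///
/// ```ignore
/// use quantrs2_ml::autodiff::ComputationNode;
///
/// let node = ComputationNode::Add(
///     Box::new(ComputationNode::Sin(Box::new(ComputationNode::Parameter(
///         "theta".to_string(),
///     )))),
///     Box::new(ComputationNode::Constant(1.0)),
/// );
/// ```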
#[derive(Debug, Clone)]
pub enum ComputationNode {
    /// Input parameter
    Parameter(String),
    /// Constant value
    Constant(f64),
    /// Addition operation
    Add(Box<ComputationNode>, Box<ComputationNode>),
    /// Multiplication operation
    Mul(Box<ComputationNode>, Box<ComputationNode>),
    /// Sine function
    Sin(Box<ComputationNode>),
    /// Cosine function
    Cos(Box<ComputationNode>),
    /// Exponential function
    Exp(Box<ComputationNode>),
    /// Quantum expectation value
    Expectation {
        circuit_params: Vec<String>,
        observable: String,
    },
}

/// Automatic differentiation engine
pub struct AutoDiff {
    /// Parameters registry
    parameters: HashMap<String, DifferentiableParam>,
    /// Computation graph
    graph: Option<ComputationNode>,
    /// Cached forward values
    forward_cache: HashMap<String, f64>,
}

impl AutoDiff {
    /// Create a new AutoDiff engine
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
            graph: None,
            forward_cache: HashMap::new(),
        }
    }

    /// Register a parameter
    pub fn register_parameter(&mut self, param: DifferentiableParam) {
        self.parameters.insert(param.name.clone(), param);
    }

    /// Set computation graph
    pub fn set_graph(&mut self, graph: ComputationNode) {
        self.graph = Some(graph);
    }

    /// Forward pass - compute value
    pub fn forward(&mut self) -> Result<f64> {
        self.forward_cache.clear();

        if let Some(graph) = self.graph.clone() {
            self.evaluate_node(&graph)
        } else {
            Err(MLError::InvalidConfiguration(
                "No computation graph set".to_string(),
            ))
        }
    }

    /// Backward pass - compute gradients
    pub fn backward(&mut self, loss_gradient: f64) -> Result<()> {
        // Reset gradients
        for param in self.parameters.values_mut() {
            param.gradient = 0.0;
        }

        if let Some(graph) = self.graph.clone() {
            self.backpropagate(&graph, loss_gradient)?;
        }

        Ok(())
    }

    /// Evaluate a computation node
    fn evaluate_node(&mut self, node: &ComputationNode) -> Result<f64> {
        match node {
            ComputationNode::Parameter(name) => {
                self.parameters.get(name).map(|p| p.value).ok_or_else(|| {
                    MLError::InvalidConfiguration(format!("Unknown parameter: {}", name))
                })
            }
            ComputationNode::Constant(value) => Ok(*value),
            ComputationNode::Add(left, right) => {
                let l = self.evaluate_node(left)?;
                let r = self.evaluate_node(right)?;
                Ok(l + r)
            }
            ComputationNode::Mul(left, right) => {
                let l = self.evaluate_node(left)?;
                let r = self.evaluate_node(right)?;
                Ok(l * r)
            }
            ComputationNode::Sin(inner) => {
                let x = self.evaluate_node(inner)?;
                Ok(x.sin())
            }
            ComputationNode::Cos(inner) => {
                let x = self.evaluate_node(inner)?;
                Ok(x.cos())
            }
            ComputationNode::Exp(inner) => {
                let x = self.evaluate_node(inner)?;
                Ok(x.exp())
            }
            ComputationNode::Expectation { circuit_params, .. } => {
                // Simplified - would evaluate the expectation value of the
                // observable for the parameterized circuit
                let mut sum = 0.0;
                for param_name in circuit_params {
                    if let Some(param) = self.parameters.get(param_name) {
                        sum += param.value;
                    }
                }
                Ok(sum.cos()) // Placeholder
            }
        }
    }

    /// Backpropagate gradients through the graph
    fn backpropagate(&mut self, node: &ComputationNode, grad: f64) -> Result<()> {
        match node {
            ComputationNode::Parameter(name) => {
                if let Some(param) = self.parameters.get_mut(name) {
                    if param.requires_grad {
                        param.gradient += grad;
                    }
                }
            }
            ComputationNode::Constant(_) => {
                // No gradient for constants
            }
            ComputationNode::Add(left, right) => {
                // For addition, the gradient passes through unchanged to both operands
                self.backpropagate(left, grad)?;
                self.backpropagate(right, grad)?;
            }
            ComputationNode::Mul(left, right) => {
                // Product rule
                let l_val = self.evaluate_node(left)?;
                let r_val = self.evaluate_node(right)?;
                self.backpropagate(left, grad * r_val)?;
                self.backpropagate(right, grad * l_val)?;
            }
            ComputationNode::Sin(inner) => {
                // d/dx sin(x) = cos(x)
                let x = self.evaluate_node(inner)?;
                self.backpropagate(inner, grad * x.cos())?;
            }
            ComputationNode::Cos(inner) => {
                // d/dx cos(x) = -sin(x)
                let x = self.evaluate_node(inner)?;
                self.backpropagate(inner, grad * (-x.sin()))?;
            }
            ComputationNode::Exp(inner) => {
                // d/dx exp(x) = exp(x)
                let x = self.evaluate_node(inner)?;
                self.backpropagate(inner, grad * x.exp())?;
            }
            ComputationNode::Expectation { circuit_params, .. } => {
                // Use parameter shift rule for quantum gradients
                for param_name in circuit_params {
                    let shift_grad = self.parameter_shift_gradient(param_name, PI / 2.0)?;
                    if let Some(param) = self.parameters.get_mut(param_name) {
                        if param.requires_grad {
                            param.gradient += grad * shift_grad;
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Compute gradient using parameter shift rule
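    ///
    /// For a gate generated by a Pauli operator, `U(theta) = exp(-i * theta * P / 2)`,
    /// the rule reads `d<H>/dtheta = (<H>(theta + pi/2) - <H>(theta - pi/2)) / 2`.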
    fn parameter_shift_gradient(&self, _param_name: &str, _shift: f64) -> Result<f64> {
        // Simplified parameter shift rule
        // In practice, would evaluate the circuit at ±shift around the parameter
        Ok(0.5) // Placeholder
    }

    /// Get all gradients
    pub fn gradients(&self) -> HashMap<String, f64> {
        self.parameters
            .iter()
            .filter(|(_, p)| p.requires_grad)
            .map(|(name, param)| (name.clone(), param.gradient))
            .collect()
    }

    /// Update parameters using gradients
    pub fn update_parameters(&mut self, learning_rate: f64) {
        for param in self.parameters.values_mut() {
            if param.requires_grad {
                param.value -= learning_rate * param.gradient;
            }
        }
    }
}

/// Quantum-aware automatic differentiation
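///
/// # Example
///
/// A sketch using a plain closure as a stand-in for a circuit executor, assuming
/// the path `quantrs2_ml::autodiff`:
///
/// ```ignore
/// use std::f64::consts::PI;
/// use quantrs2_ml::autodiff::QuantumAutoDiff;
///
/// let qad = QuantumAutoDiff::new(|params: &[f64]| params[0].cos());
/// let grads = qad
///     .parameter_shift_gradients(&[PI / 4.0], PI / 2.0)
///     .expect("gradients");
/// // Natural gradients precondition with a (regularized) Fisher matrix.
/// let nat = qad
///     .natural_gradients(&[PI / 4.0], &grads, 1e-3)
///     .expect("natural gradients");
/// ```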
pub struct QuantumAutoDiff {
    /// Base autodiff engine
    autodiff: AutoDiff,
    /// Circuit executor (placeholder)
    executor: Box<dyn Fn(&[f64]) -> f64>,
}

impl QuantumAutoDiff {
    /// Create a new quantum autodiff engine
    pub fn new<F>(executor: F) -> Self
    where
        F: Fn(&[f64]) -> f64 + 'static,
    {
        Self {
            autodiff: AutoDiff::new(),
            executor: Box::new(executor),
        }
    }

    /// Compute gradients using parameter shift rule
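    ///
    /// Uses the shifted-evaluation form
    /// `g_i = (f(theta + s * e_i) - f(theta - s * e_i)) / (2 * sin(s))`,
    /// which reduces to a plain difference divided by 2 when `s = pi / 2`.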
    pub fn parameter_shift_gradients(&self, params: &[f64], shift: f64) -> Result<Vec<f64>> {
        let mut gradients = vec![0.0; params.len()];

        for (i, _) in params.iter().enumerate() {
            // Shift parameter positively
            let mut params_plus = params.to_vec();
            params_plus[i] += shift;
            let val_plus = (self.executor)(&params_plus);

            // Shift parameter negatively
            let mut params_minus = params.to_vec();
            params_minus[i] -= shift;
            let val_minus = (self.executor)(&params_minus);

            // Parameter shift rule gradient
            gradients[i] = (val_plus - val_minus) / (2.0 * shift.sin());
        }

        Ok(gradients)
    }

    /// Compute natural gradients using quantum Fisher information
    pub fn natural_gradients(
        &self,
        params: &[f64],
        gradients: &[f64],
        regularization: f64,
    ) -> Result<Vec<f64>> {
        let n = params.len();
        let mut fisher = Array2::<f64>::zeros((n, n));

        // Compute quantum Fisher information matrix
        for i in 0..n {
            for j in 0..n {
                fisher[[i, j]] = self.compute_fisher_element(params, i, j)?;
            }
        }

        // Add regularization
        for i in 0..n {
            fisher[[i, i]] += regularization;
        }

        // Solve F * nat_grad = grad
        self.solve_linear_system(&fisher, gradients)
    }

    /// Compute an element of the quantum Fisher information matrix
    fn compute_fisher_element(&self, _params: &[f64], i: usize, j: usize) -> Result<f64> {
        // Simplified - would compute Re[<∂ψ/∂θᵢ|∂ψ/∂θⱼ> - <∂ψ/∂θᵢ|ψ><ψ|∂ψ/∂θⱼ>]
        if i == j {
            Ok(1.0 + 0.1 * fastrand::f64())
        } else {
            Ok(0.1 * fastrand::f64())
        }
    }

    /// Solve linear system (simplified)
    fn solve_linear_system(&self, _matrix: &Array2<f64>, rhs: &[f64]) -> Result<Vec<f64>> {
        // Simplified - returns the right-hand side unchanged; a full implementation
        // would solve the system with a proper linear solver
        Ok(rhs.to_vec())
    }
}

/// Gradient tape for recording operations
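///
/// # Example
///
/// An illustrative sketch, assuming the path `quantrs2_ml::autodiff`:
///
/// ```ignore
/// use quantrs2_ml::autodiff::GradientTape;
///
/// let mut tape = GradientTape::new();
/// let x = tape.variable("x", 2.0);
/// let y = tape.variable("y", 3.0);
/// let s = tape.add(&x, &y);   // s = x + y
/// let z = tape.mul(&s, &x);   // z = (x + y) * x
///
/// let grads = tape.gradient(&z, &[&x, &y]);
/// assert_eq!(grads[&x], 7.0); // dz/dx = 2x + y
/// assert_eq!(grads[&y], 2.0); // dz/dy = x
/// ```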
#[derive(Debug, Clone)]
pub struct GradientTape {
    /// Recorded operations
    operations: Vec<Operation>,
    /// Variable values
    variables: HashMap<String, f64>,
}

/// Recorded operation
#[derive(Debug, Clone)]
enum Operation {
    /// Variable assignment
    Assign { var: String, value: f64 },
    /// Addition
    Add {
        result: String,
        left: String,
        right: String,
    },
    /// Multiplication
    Mul {
        result: String,
        left: String,
        right: String,
    },
    /// Quantum operation
    Quantum { result: String, params: Vec<String> },
}

impl GradientTape {
    /// Create a new gradient tape
    pub fn new() -> Self {
        Self {
            operations: Vec::new(),
            variables: HashMap::new(),
        }
    }

    /// Record a variable
    pub fn variable(&mut self, name: impl Into<String>, value: f64) -> String {
        let name = name.into();
        self.variables.insert(name.clone(), value);
        self.operations.push(Operation::Assign {
            var: name.clone(),
            value,
        });
        name
    }

    /// Record addition
    pub fn add(&mut self, left: &str, right: &str) -> String {
        let result = format!("tmp_{}", self.operations.len());
        let left_val = self.variables[left];
        let right_val = self.variables[right];
        self.variables.insert(result.clone(), left_val + right_val);
        self.operations.push(Operation::Add {
            result: result.clone(),
            left: left.to_string(),
            right: right.to_string(),
        });
        result
    }

    /// Record multiplication
    pub fn mul(&mut self, left: &str, right: &str) -> String {
        let result = format!("tmp_{}", self.operations.len());
        let left_val = self.variables[left];
        let right_val = self.variables[right];
        self.variables.insert(result.clone(), left_val * right_val);
        self.operations.push(Operation::Mul {
            result: result.clone(),
            left: left.to_string(),
            right: right.to_string(),
        });
        result
    }

    /// Compute gradients
    pub fn gradient(&self, output: &str, inputs: &[&str]) -> HashMap<String, f64> {
        let mut gradients: HashMap<String, f64> = HashMap::new();

        // Initialize output gradient
        gradients.insert(output.to_string(), 1.0);

        // Backward pass through operations
        for op in self.operations.iter().rev() {
            match op {
                Operation::Add {
                    result,
                    left,
                    right,
                } => {
                    if let Some(&grad) = gradients.get(result) {
                        *gradients.entry(left.clone()).or_insert(0.0) += grad;
                        *gradients.entry(right.clone()).or_insert(0.0) += grad;
                    }
                }
                Operation::Mul {
                    result,
                    left,
                    right,
                } => {
                    if let Some(&grad) = gradients.get(result) {
                        let left_val = self.variables[left];
                        let right_val = self.variables[right];
                        *gradients.entry(left.clone()).or_insert(0.0) += grad * right_val;
                        *gradients.entry(right.clone()).or_insert(0.0) += grad * left_val;
                    }
                }
                _ => {}
            }
        }

        // Extract gradients for requested inputs
        inputs
            .iter()
            .map(|&input| {
                (
                    input.to_string(),
                    gradients.get(input).copied().unwrap_or(0.0),
                )
            })
            .collect()
    }
}

/// Optimizers for gradient-based training
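///
/// # Example
///
/// A small sketch minimizing `f(x) = x^2` (gradient `2x`) with `SGD`, assuming
/// the path `quantrs2_ml::autodiff::optimizers`:
///
/// ```ignore
/// use std::collections::HashMap;
/// use quantrs2_ml::autodiff::optimizers::{Optimizer, SGD};
///
/// let mut params = HashMap::from([("x".to_string(), 5.0_f64)]);
/// let mut sgd = SGD::new(0.1, 0.0);
/// for _ in 0..100 {
///     let grads = HashMap::from([("x".to_string(), 2.0 * params["x"])]);
///     sgd.step(&mut params, &grads);
/// }
/// assert!(params["x"].abs() < 1e-6);
/// ```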
pub mod optimizers {
    use super::*;

    /// Base optimizer trait
    pub trait Optimizer {
        /// Update parameters given gradients
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>);

        /// Reset optimizer state
        fn reset(&mut self);
    }

    /// Stochastic Gradient Descent
    pub struct SGD {
        learning_rate: f64,
        momentum: f64,
        velocities: HashMap<String, f64>,
    }

    impl SGD {
        pub fn new(learning_rate: f64, momentum: f64) -> Self {
            Self {
                learning_rate,
                momentum,
                velocities: HashMap::new(),
            }
        }
    }

    impl Optimizer for SGD {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            for (name, grad) in gradients {
                let velocity = self.velocities.entry(name.clone()).or_insert(0.0);
                *velocity = self.momentum * *velocity - self.learning_rate * grad;

                if let Some(param) = params.get_mut(name) {
                    *param += *velocity;
                }
            }
        }

        fn reset(&mut self) {
            self.velocities.clear();
        }
    }

    /// Adam optimizer
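    ///
    /// Follows the standard update: `m = b1*m + (1 - b1)*g`, `v = b2*v + (1 - b2)*g^2`,
    /// then `theta -= lr * m_hat / (sqrt(v_hat) + eps)` with bias-corrected
    /// `m_hat = m / (1 - b1^t)` and `v_hat = v / (1 - b2^t)`.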
    pub struct Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
        t: usize,
        m: HashMap<String, f64>,
        v: HashMap<String, f64>,
    }

    impl Adam {
        pub fn new(learning_rate: f64) -> Self {
            Self {
                learning_rate,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
                t: 0,
                m: HashMap::new(),
                v: HashMap::new(),
            }
        }
    }

    impl Optimizer for Adam {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            self.t += 1;
            let t = self.t as f64;

            for (name, grad) in gradients {
                let m_t = self.m.entry(name.clone()).or_insert(0.0);
                let v_t = self.v.entry(name.clone()).or_insert(0.0);

                // Update biased moments
                *m_t = self.beta1 * *m_t + (1.0 - self.beta1) * grad;
                *v_t = self.beta2 * *v_t + (1.0 - self.beta2) * grad * grad;

                // Bias correction
                let m_hat = *m_t / (1.0 - self.beta1.powf(t));
                let v_hat = *v_t / (1.0 - self.beta2.powf(t));

                // Update parameters
                if let Some(param) = params.get_mut(name) {
                    *param -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
                }
            }
        }

        fn reset(&mut self) {
            self.t = 0;
            self.m.clear();
            self.v.clear();
        }
    }

    /// Quantum Natural Gradient
    pub struct QNG {
        learning_rate: f64,
        regularization: f64,
    }

    impl QNG {
        pub fn new(learning_rate: f64, regularization: f64) -> Self {
            Self {
                learning_rate,
                regularization,
            }
        }
    }

    impl Optimizer for QNG {
        fn step(&mut self, params: &mut HashMap<String, f64>, gradients: &HashMap<String, f64>) {
            // Simplified - falls back to plain gradient descent; a full implementation
            // would precondition the gradient with the (regularized) quantum Fisher information
            for (name, grad) in gradients {
                if let Some(param) = params.get_mut(name) {
                    *param -= self.learning_rate * grad;
                }
            }
        }

        fn reset(&mut self) {}
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_autodiff_basic() {
        let mut autodiff = AutoDiff::new();

        // Register parameters
        autodiff.register_parameter(DifferentiableParam::new("x", 2.0));
        autodiff.register_parameter(DifferentiableParam::new("y", 3.0));

        // Build computation graph: z = x * y
        let graph = ComputationNode::Mul(
            Box::new(ComputationNode::Parameter("x".to_string())),
            Box::new(ComputationNode::Parameter("y".to_string())),
        );
        autodiff.set_graph(graph);

        // Forward pass
        let result = autodiff.forward().expect("forward pass should succeed");
        assert_eq!(result, 6.0);

        // Backward pass
        autodiff
            .backward(1.0)
            .expect("backward pass should succeed");
        let gradients = autodiff.gradients();

        assert_eq!(gradients["x"], 3.0); // dz/dx = y
        assert_eq!(gradients["y"], 2.0); // dz/dy = x
    }
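
    // Additional illustrative test (not part of the original file): exercises the
    // Sin and Add nodes and checks the analytic gradient cos(x).
    #[test]
    fn test_autodiff_sin_add() {
        let mut autodiff = AutoDiff::new();
        autodiff.register_parameter(DifferentiableParam::new("x", 2.0));

        // z = sin(x) + 1
        let graph = ComputationNode::Add(
            Box::new(ComputationNode::Sin(Box::new(ComputationNode::Parameter(
                "x".to_string(),
            )))),
            Box::new(ComputationNode::Constant(1.0)),
        );
        autodiff.set_graph(graph);

        let value = autodiff.forward().expect("forward pass should succeed");
        assert!((value - (2.0_f64.sin() + 1.0)).abs() < 1e-12);

        autodiff
            .backward(1.0)
            .expect("backward pass should succeed");
        let gradients = autodiff.gradients();
        assert!((gradients["x"] - 2.0_f64.cos()).abs() < 1e-12); // dz/dx = cos(x)
    }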

    #[test]
    fn test_gradient_tape() {
        let mut tape = GradientTape::new();

        let x = tape.variable("x", 2.0);
        let y = tape.variable("y", 3.0);
        let z = tape.mul(&x, &y);

        let gradients = tape.gradient(&z, &[&x, &y]);

        assert_eq!(gradients[&x], 3.0);
        assert_eq!(gradients[&y], 2.0);
    }

    #[test]
    fn test_optimizers() {
        use optimizers::*;

        let mut params = HashMap::new();
        params.insert("x".to_string(), 5.0);

        let mut gradients = HashMap::new();
        gradients.insert("x".to_string(), 2.0);

        // Test SGD
        let mut sgd = SGD::new(0.1, 0.0);
        sgd.step(&mut params, &gradients);
        assert!((params["x"] - 4.8).abs() < 1e-6);

        // Test Adam
        params.insert("x".to_string(), 5.0);
        let mut adam = Adam::new(0.1);
        adam.step(&mut params, &gradients);
        assert!(params["x"] < 5.0); // Should decrease
    }

    #[test]
    fn test_parameter_shift() {
        let executor = |params: &[f64]| -> f64 { params[0].cos() + params[1].sin() };

        let qad = QuantumAutoDiff::new(executor);
        let params = vec![PI / 4.0, PI / 3.0];

        let gradients = qad
            .parameter_shift_gradients(&params, PI / 2.0)
            .expect("parameter shift gradients should succeed");
        assert_eq!(gradients.len(), 2);
    }
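
    // Additional illustrative test (not part of the original file): a few
    // gradient-descent steps on f(theta) = theta^2 via update_parameters.
    #[test]
    fn test_update_parameters_descent() {
        let mut autodiff = AutoDiff::new();
        autodiff.register_parameter(DifferentiableParam::new("theta", 1.0));

        // f(theta) = theta * theta
        let graph = ComputationNode::Mul(
            Box::new(ComputationNode::Parameter("theta".to_string())),
            Box::new(ComputationNode::Parameter("theta".to_string())),
        );
        autodiff.set_graph(graph);

        let initial = autodiff.forward().expect("forward pass should succeed");
        for _ in 0..50 {
            autodiff
                .backward(1.0)
                .expect("backward pass should succeed");
            autodiff.update_parameters(0.1); // theta <- theta - 0.1 * 2 * theta
        }
        let final_value = autodiff.forward().expect("forward pass should succeed");
        assert!(final_value < initial);
        assert!(final_value < 1e-6);
    }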
}