Skip to main content

quantrs2_sim/advanced_variational_algorithms/
parametershiftgradient_traits.rs

//! # ParameterShiftGradient - Trait Implementations
//!
//! This module contains the trait implementations for `ParameterShiftGradient`.
//!
//! ## Implemented Traits
//!
//! - `GradientCalculator`
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

11use crate::circuit_interfaces::{InterfaceCircuit, InterfaceGate, InterfaceGateType};
12use crate::error::{Result, SimulatorError};
13use scirs2_core::random::prelude::*;
14
15use super::functions::{CostFunction, GradientCalculator};
16use super::types::ParameterShiftGradient;
17
18impl GradientCalculator for ParameterShiftGradient {
19    fn calculate_gradient(
20        &self,
21        parameters: &[f64],
22        cost_function: &dyn CostFunction,
23        circuit: &InterfaceCircuit,
24    ) -> Result<Vec<f64>> {
25        let shift = std::f64::consts::PI / 2.0;
26        let mut gradient = Vec::with_capacity(parameters.len());
27        for i in 0..parameters.len() {
28            let mut params_plus = parameters.to_vec();
29            let mut params_minus = parameters.to_vec();
30            params_plus[i] += shift;
31            params_minus[i] -= shift;
32            let cost_plus = cost_function.evaluate(&params_plus, circuit)?;
33            let cost_minus = cost_function.evaluate(&params_minus, circuit)?;
34            let grad = (cost_plus - cost_minus) / 2.0;
35            gradient.push(grad);
36        }
37        Ok(gradient)
38    }
39    fn method_name(&self) -> &'static str {
40        "ParameterShift"
41    }
42}