use std::collections::HashMap;
use crate::error::{SymEngineError, SymEngineResult};
use crate::expr::Expression;
/// Evaluates the symbolic gradient of `expr` numerically.
///
/// The gradient is taken symbolically with respect to each expression in
/// `params` (via `Expression::gradient`). Each partial derivative is then
/// evaluated with the variable bindings in `values`. The result is in the
/// order produced by `Expression::gradient`.
///
/// # Errors
///
/// Returns the first error produced while evaluating a partial derivative.
/// Evaluation stops at that point.
pub fn gradient_at(
    expr: &Expression,
    params: &[Expression],
    values: &HashMap<String, f64>,
) -> SymEngineResult<Vec<f64>> {
    let partials = expr.gradient(params);
    let mut evaluated = Vec::new();
    for partial in partials.iter() {
        evaluated.push(partial.eval(values)?);
    }
    Ok(evaluated)
}
/// Evaluates the symbolic Hessian of `expr` numerically.
///
/// The second-derivative matrix is built symbolically with respect to
/// `params` (via `Expression::hessian`). Every entry is then evaluated with
/// the bindings in `values`, and the row/column layout is preserved.
///
/// # Errors
///
/// Returns the first error produced while evaluating an entry. Rows are
/// processed in order, and entries left to right within each row.
pub fn hessian_at(
    expr: &Expression,
    params: &[Expression],
    values: &HashMap<String, f64>,
) -> SymEngineResult<Vec<Vec<f64>>> {
    let second_derivatives = expr.hessian(params);
    let mut matrix = Vec::new();
    for row in second_derivatives.iter() {
        let evaluated_row = row
            .iter()
            .map(|entry| entry.eval(values))
            .collect::<SymEngineResult<Vec<f64>>>()?;
        matrix.push(evaluated_row);
    }
    Ok(matrix)
}
/// Gradient estimator implementing the parameter-shift rule.
///
/// Each partial derivative is estimated as
/// `(E(theta + s*e_i) - E(theta - s*e_i)) / (2 * sin(s))`, where `s` is
/// [`shift`](Self::shift) and `e_i` is the i-th unit vector.
///
/// For any `s` with `sin(s) != 0`, the estimate is exact when the energy
/// depends on `theta_i` as `a * sin(theta_i + b) + c`. This is because
/// `sin(x + s) - sin(x - s) = 2 * cos(x) * sin(s)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ParameterShiftRule {
    /// Shift `s` applied to each parameter in turn. It is fed to `sin`, so it
    /// is interpreted in radians.
    pub shift: f64,
}

impl ParameterShiftRule {
    /// Creates a rule with the standard shift of `pi / 2`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            shift: std::f64::consts::FRAC_PI_2,
        }
    }

    /// Creates a rule with a custom `shift`.
    ///
    /// `shift` should satisfy `sin(shift) != 0`. With `shift == 0.0` the
    /// denominator in [`compute_gradient`](Self::compute_gradient) is zero and
    /// the results are NaN or infinite. Near other multiples of pi the
    /// denominator is tiny, so the results are dominated by rounding noise.
    #[must_use]
    pub const fn with_shift(shift: f64) -> Self {
        Self { shift }
    }

    /// Estimates the gradient of `energy_fn` at the point `params`.
    ///
    /// `energy_fn` is called exactly `2 * params.len()` times. Each call
    /// receives a slice of the same length as `params` in which at most one
    /// entry differs from `params`.
    ///
    /// Returns one partial derivative per entry of `params`. An empty
    /// `params` yields an empty vector without calling `energy_fn`.
    pub fn compute_gradient<F>(&self, energy_fn: F, params: &[f64]) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let s = self.shift;
        // The same denominator applies to every parameter, so compute it once.
        let denominator = 2.0 * s.sin();
        // One scratch buffer is reused for every evaluation, instead of
        // allocating two fresh copies of `params` per parameter.
        let mut shifted = params.to_vec();
        params
            .iter()
            .enumerate()
            .map(|(i, &original)| {
                shifted[i] = original + s;
                let e_plus = energy_fn(&shifted[..]);
                shifted[i] = original - s;
                let e_minus = energy_fn(&shifted[..]);
                // Restore from the untouched input rather than by arithmetic,
                // so later parameters see exactly the original point.
                shifted[i] = original;
                (e_plus - e_minus) / denominator
            })
            .collect()
    }
}

impl Default for ParameterShiftRule {
    /// Equivalent to [`ParameterShiftRule::new`] (shift of `pi / 2`).
    fn default() -> Self {
        Self::new()
    }
}
/// Plain gradient-descent optimizer over a symbolic energy expression.
///
/// Gradients are taken symbolically with respect to `params` and evaluated at
/// the current variable bindings. Each [`step`](Self::step) moves every named
/// parameter against its partial derivative, scaled by `learning_rate`.
pub struct VqeOptimizer {
    /// Symbolic energy to minimise.
    pub energy: Expression,
    /// Expressions to differentiate with respect to. Only those for which
    /// `as_symbol` yields a name are updated by `step`.
    pub params: Vec<Expression>,
    /// Factor multiplied into each partial derivative when updating.
    pub learning_rate: f64,
}

impl VqeOptimizer {
    /// Bundles an energy expression, its parameters, and a step size.
    // The lint suppression is carried over unchanged; its rationale is not
    // recorded in this file.
    #[allow(clippy::missing_const_for_fn)]
    pub fn new(energy: Expression, params: Vec<Expression>, learning_rate: f64) -> Self {
        Self { energy, params, learning_rate }
    }

    /// Evaluates the gradient of `energy` with respect to `params` at `values`.
    ///
    /// # Errors
    ///
    /// Propagates any error from evaluating a partial derivative.
    pub fn compute_gradient(&self, values: &HashMap<String, f64>) -> SymEngineResult<Vec<f64>> {
        gradient_at(&self.energy, &self.params, values)
    }

    /// Performs one gradient-descent update of `values` in place.
    ///
    /// Returns the energy evaluated at the updated point. Parameters that are
    /// not plain symbols, or whose name is absent from `values`, are left
    /// untouched.
    ///
    /// # Errors
    ///
    /// Returns the error from computing the gradient, in which case `values`
    /// is not modified. Also returns any error from evaluating the energy after
    /// the update.
    pub fn step(&self, values: &mut HashMap<String, f64>) -> SymEngineResult<f64> {
        let grads = self.compute_gradient(values)?;
        let lr = self.learning_rate;
        self.params
            .iter()
            .zip(grads.iter())
            .filter_map(|(param, g)| param.as_symbol().map(|name| (name, g)))
            .for_each(|(name, g)| {
                if let Some(slot) = values.get_mut(name) {
                    *slot -= lr * g;
                }
            });
        self.energy.eval(values)
    }
}
/// Unit tests for the gradient helpers, the parameter-shift rule and the
/// VQE optimizer.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `theta * theta` with `theta = 3.0` bound in the returned map.
    fn theta_squared() -> (Expression, Vec<Expression>, HashMap<String, f64>) {
        let theta = Expression::symbol("theta");
        let energy = theta.clone() * theta.clone();
        let mut values = HashMap::new();
        values.insert("theta".to_string(), 3.0);
        (energy, vec![theta], values)
    }

    #[test]
    fn test_gradient_at() {
        let (energy, params, values) = theta_squared();
        let grad = gradient_at(&energy, &params, &values).expect("should compute gradient");
        assert!((grad[0] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_hessian_at() {
        let (energy, params, values) = theta_squared();
        let hess = hessian_at(&energy, &params, &values).expect("should compute hessian");
        assert_eq!(hess.len(), 1);
        assert_eq!(hess[0].len(), 1);
        assert!((hess[0][0] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_vqe_step_updates_and_returns_energy() {
        let (energy, params, mut values) = theta_squared();
        let optimizer = VqeOptimizer::new(energy, params, 0.1);
        let e = optimizer.step(&mut values).expect("step should succeed");
        // theta <- 3.0 - 0.1 * 6.0 = 2.4, and the energy is reported at the new point.
        let updated = values["theta"];
        assert!((updated - 2.4).abs() < 1e-9);
        assert!((e - updated * updated).abs() < 1e-9);
    }

    #[test]
    fn test_parameter_shift_rule() {
        let psr = ParameterShiftRule::new();
        let gradient = psr.compute_gradient(|params| params[0].sin(), &[0.0]);
        assert!((gradient[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_parameter_shift_rule_multi_param_custom_shift() {
        let psr = ParameterShiftRule::with_shift(0.4);
        let (x, y) = (0.5_f64, -1.1_f64);
        let gradient = psr.compute_gradient(|p| p[0].sin() * p[1].cos(), &[x, y]);
        assert!((gradient[0] - x.cos() * y.cos()).abs() < 1e-9);
        assert!((gradient[1] + x.sin() * y.sin()).abs() < 1e-9);
    }
}