// qasmsim 1.2.0
//
// A QASM interpreter and quantum simulator in Rust.
use std::collections::HashMap;

use crate::grammar::ast;

/// Evaluates parsed QASM arithmetic expressions down to a single `f64`.
///
/// Identifier nodes are resolved against a borrowed symbol table that maps
/// names to numeric values; the solver never mutates that table.
#[derive(Debug, Clone, PartialEq)]
pub struct ExpressionSolver<'bindings>(&'bindings HashMap<String, f64>);

impl<'bindings> ExpressionSolver<'bindings> {
    /// Builds a solver that resolves identifiers through `symbol_table`.
    pub fn new(symbol_table: &'bindings HashMap<String, f64>) -> Self {
        Self(symbol_table)
    }

    /// Recursively evaluates `expression`.
    ///
    /// For binary operations the left operand is evaluated before the right
    /// one, and evaluation stops at the first failure.
    ///
    /// # Errors
    ///
    /// Returns `Err` holding the identifier's name when an identifier is not
    /// bound in the symbol table.
    pub fn solve(&self, expression: &ast::Expression) -> Result<f64, String> {
        let value = match expression {
            ast::Expression::Pi => std::f64::consts::PI,
            ast::Expression::Int(int) => *int as f64,
            ast::Expression::Real(real) => *real,
            ast::Expression::Minus(operand) => {
                let inner = self.solve(operand)?;
                -inner
            }
            ast::Expression::Op(op_code, left, right) => {
                // Order matters for error reporting: the left side fails first.
                let lhs = self.solve(left)?;
                let rhs = self.solve(right)?;
                Self::apply_op(op_code, lhs, rhs)
            }
            ast::Expression::Function(func_code, argument) => {
                let arg = self.solve(argument)?;
                Self::apply_func(func_code, arg)
            }
            ast::Expression::Id(name) => self
                .0
                .get(name)
                .copied()
                .ok_or_else::<String, _>(|| name.into())?,
        };
        Ok(value)
    }

    /// Applies the binary operator `op_code` to two already-evaluated operands.
    fn apply_op(op_code: &ast::OpCode, lhs: f64, rhs: f64) -> f64 {
        match op_code {
            ast::OpCode::Add => lhs + rhs,
            ast::OpCode::Sub => lhs - rhs,
            ast::OpCode::Mul => lhs * rhs,
            ast::OpCode::Div => lhs / rhs,
            ast::OpCode::Pow => lhs.powf(rhs),
        }
    }

    /// Applies the unary built-in function `func_code` to an evaluated argument.
    fn apply_func(func_code: &ast::FuncCode, arg: f64) -> f64 {
        match func_code {
            ast::FuncCode::Sin => arg.sin(),
            ast::FuncCode::Cos => arg.cos(),
            ast::FuncCode::Tan => arg.tan(),
            ast::FuncCode::Exp => arg.exp(),
            ast::FuncCode::Ln => arg.ln(),
            ast::FuncCode::Sqrt => arg.sqrt(),
        }
    }
}

#[cfg(test)]
mod test {
    use std::f64::consts::PI;
    use std::iter::FromIterator;

    use super::*;
    use crate::grammar::ast::*;

    // The floating-point tests below compare with `assert_eq!` deliberately:
    // every expected value is written with the same `f64` operations, applied
    // in the same order as the solver applies them, so exact equality is the
    // expected outcome rather than an approximation.

    #[test]
    #[allow(clippy::float_cmp)] // exact equality intended; see module comment
    fn test_expression_solver() {
        let expression = Expression::Op(
            OpCode::Add,
            Box::new(Expression::Minus(Box::new(Expression::Pi))),
            Box::new(Expression::Op(
                OpCode::Div,
                Box::new(Expression::Op(
                    OpCode::Mul,
                    Box::new(Expression::Op(
                        OpCode::Sub,
                        Box::new(Expression::Real(1.0)),
                        Box::new(Expression::Op(
                            OpCode::Pow,
                            Box::new(Expression::Real(2.0)),
                            Box::new(Expression::Real(3.0)),
                        )),
                    )),
                    Box::new(Expression::Real(4.0)),
                )),
                Box::new(Expression::Real(5.0)),
            )),
        );
        let empty = HashMap::new();
        let solver = ExpressionSolver::new(&empty);
        let result = solver.solve(&expression).expect("get value of expression");
        assert_eq!(result, -PI + (1.0 - 2.0_f64.powf(3.0)) * 4.0 / 5.0);
    }

    #[test]
    #[allow(clippy::float_cmp)] // exact equality intended; see module comment
    fn test_expression_solver_with_functions() {
        let expression = Expression::Function(
            FuncCode::Sqrt,
            Box::new(Expression::Function(
                FuncCode::Ln,
                Box::new(Expression::Function(
                    FuncCode::Exp,
                    Box::new(Expression::Function(
                        FuncCode::Tan,
                        Box::new(Expression::Function(
                            FuncCode::Cos,
                            Box::new(Expression::Function(
                                FuncCode::Sin,
                                Box::new(Expression::Real(1.0)),
                            )),
                        )),
                    )),
                )),
            )),
        );
        let empty = HashMap::new();
        let solver = ExpressionSolver::new(&empty);
        let result = solver.solve(&expression).expect("get value of expression");
        assert_eq!(result, 1.0_f64.sin().cos().tan().exp().ln().sqrt());
    }

    #[test]
    #[allow(clippy::float_cmp)] // exact equality intended; see module comment
    fn test_expression_solver_with_symbol_substitution() {
        let expression = Expression::Op(
            OpCode::Add,
            Box::new(Expression::Id("some_name".into())),
            Box::new(Expression::Op(
                OpCode::Div,
                Box::new(Expression::Op(
                    OpCode::Mul,
                    Box::new(Expression::Op(
                        OpCode::Sub,
                        Box::new(Expression::Real(1.0)),
                        Box::new(Expression::Real(2.0)),
                    )),
                    Box::new(Expression::Real(3.0)),
                )),
                Box::new(Expression::Real(4.0)),
            )),
        );
        let bindings = HashMap::from_iter(vec![("some_name".into(), 1.0)]);
        let solver = ExpressionSolver::new(&bindings);
        let result = solver.solve(&expression).expect("get value of expression");
        assert_eq!(result, 1.0 + (1.0 - 2.0) * 3.0 / 4.0);
    }

    #[test]
    #[allow(clippy::float_cmp)] // 3 * 0.5 is exactly representable as 1.5
    fn test_expression_solver_with_integer_literals() {
        let expression = Expression::Op(
            OpCode::Mul,
            Box::new(Expression::Int(3)),
            Box::new(Expression::Real(0.5)),
        );
        let empty = HashMap::new();
        let solver = ExpressionSolver::new(&empty);
        let result = solver.solve(&expression).expect("get value of expression");
        assert_eq!(result, 1.5);
    }

    #[test]
    fn test_expression_solver_fails_at_symbol_substitution() {
        let expression = Expression::Op(
            OpCode::Add,
            Box::new(Expression::Id("some_name".into())),
            Box::new(Expression::Real(1.0)),
        );
        let empty_bindings = HashMap::new();
        let solver = ExpressionSolver::new(&empty_bindings);
        let error = solver
            .solve(&expression)
            .expect_err("fails at replacing `some_name`");
        assert_eq!(error, String::from("some_name"));
    }

    #[test]
    fn test_expression_solver_propagates_error_from_nested_symbol() {
        // The unbound symbol sits under a unary minus and a function call;
        // the error must surface unchanged through both levels.
        let expression = Expression::Minus(Box::new(Expression::Function(
            FuncCode::Cos,
            Box::new(Expression::Id("theta".into())),
        )));
        let empty_bindings = HashMap::new();
        let solver = ExpressionSolver::new(&empty_bindings);
        let error = solver
            .solve(&expression)
            .expect_err("fails at replacing `theta`");
        assert_eq!(error, String::from("theta"));
    }

    #[test]
    fn test_expression_solver_reports_leftmost_unbound_symbol() {
        // Both operands are unbound; the left operand is evaluated first, so
        // its name is the one reported.
        let expression = Expression::Op(
            OpCode::Sub,
            Box::new(Expression::Id("first".into())),
            Box::new(Expression::Id("second".into())),
        );
        let empty_bindings = HashMap::new();
        let solver = ExpressionSolver::new(&empty_bindings);
        let error = solver
            .solve(&expression)
            .expect_err("fails at replacing `first`");
        assert_eq!(error, String::from("first"));
    }
}