use derive_more::Display;
use hugr::ops::OpType;
use hugr::std_extensions::arithmetic::float_ops::FloatOps;
use itertools::Itertools;
use pest::iterators::{Pair, Pairs};
use pest::pratt_parser::PrattParser;
use pest::Parser;
use pest_derive::Parser;
/// A pytket circuit parameter, decoded from its string representation.
///
/// Inputs that can be expressed as a tree of supported operations are decoded
/// into nested [`PytketParam::Operation`]s. Anything the parser does not
/// understand is kept verbatim as a [`PytketParam::Sympy`] string.
#[derive(Debug, Display, Clone, PartialEq)]
pub enum PytketParam<'a> {
    /// A constant numeric value.
    #[display("{_0}")]
    Constant(f64),
    /// A free symbol, referenced by its name. Displayed in double quotes.
    #[display("\"{name}\"")]
    InputVariable { name: &'a str },
    /// Source text that was not decoded, kept as-is.
    #[display("Sympy(\"{_0}\")")]
    Sympy(&'a str),
    /// An operation applied to a list of sub-parameters.
    ///
    /// Displayed as `op(arg1, arg2, ...)`.
    #[display("{op}({})", args.iter().join(", "))]
    Operation {
        op: OpType,
        args: Vec<PytketParam<'a>>,
    },
}
/// Parses a pytket parameter string into a [`PytketParam`].
///
/// If `param` does not match the parameter grammar at all, the whole input is
/// returned unchanged as [`PytketParam::Sympy`]. Sub-expressions the grammar
/// accepts but the decoder does not support (e.g. unknown function calls)
/// are likewise embedded as `Sympy` leaves.
#[inline]
pub fn parse_pytket_param(param: &str) -> PytketParam<'_> {
    match ParamParser::parse(Rule::parameter, param) {
        Err(_) => PytketParam::Sympy(param),
        Ok(mut top_level) => {
            let root = top_level
                .next()
                .expect("The `parameter` rule can only be matched once.");
            parse_infix_ops(root.into_inner())
        }
    }
}
/// pest parser for pytket parameter expressions.
///
/// The derive reads the grammar file below and generates the [`Rule`] enum
/// that the parsing functions in this module match on.
#[derive(Parser)]
#[grammar = "serialize/pytket/decoder/param.pest"]
struct ParamParser;
lazy_static::lazy_static! {
    /// Pratt parser that resolves the precedence and associativity of the
    /// binary operators in a parameter expression.
    static ref PRATT_PARSER: PrattParser<Rule> = {
        use pest::pratt_parser::{Assoc::*, Op};
        use Rule::*;
        // Operator groups registered later bind tighter: `+ -` < `* /` < `**`.
        PrattParser::new()
            .op(Op::infix(add, Left) | Op::infix(subtract, Left))
            .op(Op::infix(multiply, Left) | Op::infix(divide, Left))
            // Exponentiation is right-associative, as in Python/sympy:
            // `2**3**2` means `2**(3**2)`, not `(2**3)**2`.
            .op(Op::infix(power, Right))
    };
}
fn parse_infix_ops(pairs: Pairs<'_, Rule>) -> PytketParam<'_> {
PRATT_PARSER
.map_primary(|primary| parse_term(primary))
.map_infix(|lhs, op, rhs| {
let op = match op.as_rule() {
Rule::add => FloatOps::fadd,
Rule::subtract => FloatOps::fsub,
Rule::multiply => FloatOps::fmul,
Rule::divide => FloatOps::fdiv,
Rule::power => FloatOps::fpow,
rule => unreachable!("Expr::parse expected infix operation, found {:?}", rule),
}
.into();
PytketParam::Operation {
op,
args: vec![lhs, rhs],
}
})
.parse(pairs)
}
/// Decodes a single primary term of a parameter expression.
///
/// Handles numeric literals, identifiers, function calls, nested `expr`
/// pairs (presumably parenthesised sub-expressions; the grammar is not shown
/// here), unary negation, and implicit multiplication such as `42f64`.
///
/// # Panics
///
/// On any rule that is not a term. The grammar is expected to make this
/// impossible.
fn parse_term(pair: Pair<'_, Rule>) -> PytketParam<'_> {
    let rule = pair.as_rule();
    match rule {
        Rule::num => parse_number(pair),
        Rule::ident => PytketParam::InputVariable {
            name: pair.as_str(),
        },
        Rule::function_call => parse_function_call(pair),
        Rule::expr => parse_infix_ops(pair.into_inner()),
        Rule::unary_minus => {
            let operand = pair.into_inner().next().unwrap();
            PytketParam::Operation {
                op: FloatOps::fneg.into(),
                args: vec![parse_term(operand)],
            }
        }
        Rule::implicit_multiply => {
            // Two adjacent factors, decoded left to right.
            let mut factors = pair.into_inner().map(parse_term);
            let lhs = factors.next().unwrap();
            let rhs = factors.next().unwrap();
            PytketParam::Operation {
                op: FloatOps::fmul.into(),
                args: vec![lhs, rhs],
            }
        }
        other => unreachable!("Term::parse expected a term, found {:?}", other),
    }
}
/// Converts a `num` pair into a [`PytketParam::Constant`].
///
/// # Panics
///
/// If the matched text is rejected by Rust's `f64` parser. That would mean the
/// grammar's `num` rule accepts a literal Rust cannot parse.
fn parse_number(pair: Pair<'_, Rule>) -> PytketParam<'_> {
    let num = pair.as_str();
    match num.parse::<f64>() {
        Ok(value) => PytketParam::Constant(value),
        Err(_) => panic!("`num` rule matched invalid number \"{num}\""),
    }
}
fn parse_function_call(pair: Pair<'_, Rule>) -> PytketParam<'_> {
let pair_str = pair.as_str();
let mut args = pair.into_inner();
let name = args
.next()
.expect("Function call must have a name")
.as_str();
let op = match name {
"max" => FloatOps::fmax.into(),
"min" => FloatOps::fmin.into(),
"abs" => FloatOps::fabs.into(),
"floor" => FloatOps::ffloor.into(),
"ceil" => FloatOps::fceil.into(),
"round" => FloatOps::fround.into(),
_ => return PytketParam::Sympy(pair_str),
};
let args = args.map(|arg| parse_term(arg)).collect::<Vec<_>>();
PytketParam::Operation { op, args }
}
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Shorthand for a [`PytketParam::Constant`].
    fn num(value: f64) -> PytketParam<'static> {
        PytketParam::Constant(value)
    }

    /// Shorthand for a [`PytketParam::InputVariable`].
    fn var(name: &str) -> PytketParam<'_> {
        PytketParam::InputVariable { name }
    }

    /// Shorthand for a [`PytketParam::Operation`] built from a float operation.
    fn op<'a>(float_op: FloatOps, args: Vec<PytketParam<'a>>) -> PytketParam<'a> {
        PytketParam::Operation {
            op: float_op.into(),
            args,
        }
    }

    #[rstest]
    #[case::int("42", num(42.0))]
    #[case::float("42.37", num(42.37))]
    #[case::float_pointless("37.", num(37.))]
    #[case::exp("42e4", num(42e4))]
    #[case::neg("-42.55", num(-42.55))]
    #[case::parens("(42)", num(42.))]
    #[case::var("f64", var("f64"))]
    #[case::add("42 + f64", op(FloatOps::fadd, vec![num(42.), var("f64")]))]
    #[case::sub("42 - 2", op(FloatOps::fsub, vec![num(42.), num(2.)]))]
    #[case::product_implicit("42 f64", op(FloatOps::fmul, vec![num(42.), var("f64")]))]
    #[case::product_implicit2("42f64", op(FloatOps::fmul, vec![num(42.), var("f64")]))]
    #[case::product_implicit3("42 e4", op(FloatOps::fmul, vec![num(42.), var("e4")]))]
    #[case::max("max(42, f64)", op(FloatOps::fmax, vec![num(42.), var("f64")]))]
    #[case::minus("-f64", op(FloatOps::fneg, vec![var("f64")]))]
    #[case::unknown("unknown_op(42, f64)", PytketParam::Sympy("unknown_op(42, f64)"))]
    #[case::unknown_no_params("unknown_op()", PytketParam::Sympy("unknown_op()"))]
    #[case::nested(
        "max(42, unknown_op(37))",
        op(FloatOps::fmax, vec![num(42.), PytketParam::Sympy("unknown_op(37)")])
    )]
    #[case::precedence(
        "5-2/3x+4**6",
        op(FloatOps::fadd, vec![
            op(FloatOps::fsub, vec![
                num(5.),
                op(FloatOps::fdiv, vec![
                    num(2.),
                    op(FloatOps::fmul, vec![num(3.), var("x")]),
                ]),
            ]),
            op(FloatOps::fpow, vec![num(4.), num(6.)]),
        ])
    )]
    #[case::associativity(
        "1-2-3+4",
        op(FloatOps::fadd, vec![
            op(FloatOps::fsub, vec![
                op(FloatOps::fsub, vec![num(1.), num(2.)]),
                num(3.),
            ]),
            num(4.),
        ])
    )]
    fn parse_param(#[case] param: &str, #[case] expected: PytketParam) {
        let parsed = parse_pytket_param(param);
        assert!(
            parsed == expected,
            "Incorrect parameter parsing\n\texpression: \"{param}\"\n\tparsed: {parsed}\n\texpected: {expected}"
        );
    }
}