use std::collections::HashMap;
use crate::error::{SymEngineError, SymEngineResult};
use crate::expr::{ExprLang, Expression};
/// A structural pattern over symbolic [`Expression`]s.
///
/// Leaf variants (`Wildcard`, `Constant`, `Symbol`, `Zero`, `One`) are tested directly against
/// an expression. Compound variants mirror one operator of the rendered S-expression and carry
/// sub-patterns for its operands.
#[derive(Clone, Debug)]
pub enum Pattern {
    /// Matches any expression and binds it under this name. When the same name occurs more than
    /// once in a pattern, every occurrence must bind an equal expression.
    Wildcard(String),
    /// Matches an expression whose numeric value is (approximately) this constant.
    Constant(f64),
    /// Matches the symbol with exactly this name.
    Symbol(String),
    /// Matches an expression for which `is_zero` holds.
    Zero,
    /// Matches an expression for which `is_one` holds.
    One,
    /// Sum, rendered with head `+`.
    Add(Box<Pattern>, Box<Pattern>),
    /// Product, rendered with head `*`.
    Mul(Box<Pattern>, Box<Pattern>),
    /// Power (base, exponent), rendered with head `^`.
    Pow(Box<Pattern>, Box<Pattern>),
    /// Negation, rendered with head `neg`.
    Neg(Box<Pattern>),
    /// Sine, rendered with head `sin`.
    Sin(Box<Pattern>),
    /// Cosine, rendered with head `cos`.
    Cos(Box<Pattern>),
    /// Exponential, rendered with head `exp`.
    Exp(Box<Pattern>),
    /// Logarithm, rendered with head `log`.
    Log(Box<Pattern>),
    /// Commutator, rendered with head `comm`.
    Commutator(Box<Pattern>, Box<Pattern>),
    /// Anticommutator, rendered with head `anticomm`.
    Anticommutator(Box<Pattern>, Box<Pattern>),
    /// Tensor product, rendered with head `tensor`.
    TensorProduct(Box<Pattern>, Box<Pattern>),
    /// Hermitian adjoint, rendered with head `dagger`.
    Dagger(Box<Pattern>),
}

// `add`, `mul` and `neg` intentionally share names with the `std::ops` methods: they are
// associated constructors for the matching variants, not operator implementations.
#[allow(clippy::should_implement_trait)]
impl Pattern {
    /// Builds a [`Pattern::Wildcard`] that binds under `name`.
    #[must_use]
    pub fn wildcard(name: &str) -> Self {
        Self::Wildcard(name.to_string())
    }

    /// Builds a [`Pattern::Symbol`] matching the symbol `name`.
    #[must_use]
    pub fn symbol(name: &str) -> Self {
        Self::Symbol(name.to_string())
    }

    /// Builds a [`Pattern::Constant`] matching `value`.
    #[must_use]
    pub const fn constant(value: f64) -> Self {
        Self::Constant(value)
    }

    /// Builds a [`Pattern::Add`] of `left` and `right`.
    #[must_use]
    pub fn add(left: Self, right: Self) -> Self {
        Self::Add(Box::new(left), Box::new(right))
    }

    /// Builds a [`Pattern::Mul`] of `left` and `right`.
    #[must_use]
    pub fn mul(left: Self, right: Self) -> Self {
        Self::Mul(Box::new(left), Box::new(right))
    }

    /// Builds a [`Pattern::Pow`] raising `base` to `exp`.
    #[must_use]
    pub fn pow(base: Self, exp: Self) -> Self {
        Self::Pow(Box::new(base), Box::new(exp))
    }

    /// Builds a [`Pattern::Neg`] around `arg`.
    #[must_use]
    pub fn neg(arg: Self) -> Self {
        Self::Neg(Box::new(arg))
    }

    /// Builds a [`Pattern::Sin`] around `arg`.
    #[must_use]
    pub fn sin(arg: Self) -> Self {
        Self::Sin(Box::new(arg))
    }

    /// Builds a [`Pattern::Cos`] around `arg`.
    #[must_use]
    pub fn cos(arg: Self) -> Self {
        Self::Cos(Box::new(arg))
    }

    /// Builds a [`Pattern::Exp`] around `arg`.
    #[must_use]
    pub fn exp(arg: Self) -> Self {
        Self::Exp(Box::new(arg))
    }

    /// Builds a [`Pattern::Log`] around `arg`.
    #[must_use]
    pub fn log(arg: Self) -> Self {
        Self::Log(Box::new(arg))
    }

    /// Builds a [`Pattern::Commutator`] of `a` and `b`.
    #[must_use]
    pub fn commutator(a: Self, b: Self) -> Self {
        Self::Commutator(Box::new(a), Box::new(b))
    }

    /// Builds a [`Pattern::Anticommutator`] of `a` and `b`.
    #[must_use]
    pub fn anticommutator(a: Self, b: Self) -> Self {
        Self::Anticommutator(Box::new(a), Box::new(b))
    }

    /// Builds a [`Pattern::TensorProduct`] of `a` and `b`.
    #[must_use]
    pub fn tensor(a: Self, b: Self) -> Self {
        Self::TensorProduct(Box::new(a), Box::new(b))
    }

    /// Builds a [`Pattern::Dagger`] around `a`.
    #[must_use]
    pub fn dagger(a: Self) -> Self {
        Self::Dagger(Box::new(a))
    }
}
/// Wildcard bindings produced by a successful match, keyed by wildcard name.
pub type Captures = HashMap<String, Expression>;

/// Matches `pattern` against `expr`.
///
/// Returns the wildcard bindings when the match succeeds and `None` otherwise; any bindings
/// collected during a failed attempt are discarded.
pub fn match_pattern(pattern: &Pattern, expr: &Expression) -> Option<Captures> {
    let mut bindings = Captures::new();
    let matched = match_pattern_rec(pattern, expr, &mut bindings);
    matched.then_some(bindings)
}
/// Recursive matcher: decides leaf patterns directly and hands operator patterns to
/// `match_compound_pattern`.
///
/// A wildcard binds on its first occurrence; a later occurrence of the same name only matches
/// an expression equal to the earlier binding. When this returns `false`, `captures` may still
/// hold bindings made before the mismatch was found.
// The wildcard arm inserts into `captures` only when the lookup misses; the `map_or_else`
// rewrite suggested by this lint cannot express that while the lookup's borrow is live.
#[allow(clippy::option_if_let_else)]
fn match_pattern_rec(pattern: &Pattern, expr: &Expression, captures: &mut Captures) -> bool {
    // Tolerance for `Pattern::Constant`: absolute for |x| <= 1, relative above that.
    const CONSTANT_TOLERANCE: f64 = 1e-15;

    match pattern {
        Pattern::Wildcard(name) => {
            if let Some(existing) = captures.get(name) {
                existing == expr
            } else {
                captures.insert(name.clone(), expr.clone());
                true
            }
        }
        Pattern::Constant(value) => expr.to_f64().is_some_and(|v| {
            // Exact equality first so infinities match themselves (`inf - inf` is NaN).
            // Otherwise scale the tolerance by the larger magnitude, floored at 1.0, so small
            // values keep a 1e-15 absolute tolerance while large values are not forced to be
            // bit-identical.
            let scale = v.abs().max(value.abs()).max(1.0);
            v == *value || (v - value).abs() < CONSTANT_TOLERANCE * scale
        }),
        Pattern::Symbol(name) => expr.as_symbol() == Some(name.as_str()),
        Pattern::Zero => expr.is_zero(),
        Pattern::One => expr.is_one(),
        // Listed explicitly (no catch-all) so a new variant forces a decision here.
        Pattern::Add(..)
        | Pattern::Mul(..)
        | Pattern::Pow(..)
        | Pattern::Neg(_)
        | Pattern::Sin(_)
        | Pattern::Cos(_)
        | Pattern::Exp(_)
        | Pattern::Log(_)
        | Pattern::Commutator(..)
        | Pattern::Anticommutator(..)
        | Pattern::TensorProduct(..)
        | Pattern::Dagger(_) => match_compound_pattern(pattern, expr, captures),
    }
}
/// Matches an operator `pattern` against `expr`, recording wildcard bindings in `captures`.
///
/// Each compound variant corresponds to one S-expression head:
///
/// * unary: `neg`, `sin`, `cos`, `exp`, `log`, `dagger`;
/// * binary: `+`, `*`, `^`, `comm`, `anticomm`, `tensor`.
///
/// The rendered `expr` (its `Display` output) must start with that head. Its operands, as
/// returned by `extract_unary_arg` / `extract_binary_args`, must then match the sub-patterns,
/// left to right. Leaf patterns are delegated back to `match_pattern_rec` rather than panicking.
///
/// NOTE(review): both extractors currently always return `None`, so no compound pattern can
/// match until they are implemented.
fn match_compound_pattern(pattern: &Pattern, expr: &Expression, captures: &mut Captures) -> bool {
    match pattern {
        Pattern::Neg(inner) => match_unary_application(inner, expr, "neg", captures),
        Pattern::Sin(inner) => match_unary_application(inner, expr, "sin", captures),
        Pattern::Cos(inner) => match_unary_application(inner, expr, "cos", captures),
        Pattern::Exp(inner) => match_unary_application(inner, expr, "exp", captures),
        Pattern::Log(inner) => match_unary_application(inner, expr, "log", captures),
        Pattern::Dagger(inner) => match_unary_application(inner, expr, "dagger", captures),
        Pattern::Add(left, right) => match_binary_application(left, right, expr, "+", captures),
        Pattern::Mul(left, right) => match_binary_application(left, right, expr, "*", captures),
        Pattern::Pow(base, exp) => match_binary_application(base, exp, expr, "^", captures),
        Pattern::Commutator(a, b) => match_binary_application(a, b, expr, "comm", captures),
        Pattern::Anticommutator(a, b) => {
            match_binary_application(a, b, expr, "anticomm", captures)
        }
        Pattern::TensorProduct(a, b) => match_binary_application(a, b, expr, "tensor", captures),
        // Leaves are fully handled by `match_pattern_rec`, which never routes them back here.
        Pattern::Wildcard(_)
        | Pattern::Constant(_)
        | Pattern::Symbol(_)
        | Pattern::Zero
        | Pattern::One => match_pattern_rec(pattern, expr, captures),
    }
}

/// Returns `true` when `rendered` starts with the S-expression head `"(" + op + " "`.
fn has_sexpr_head(rendered: &str, op: &str) -> bool {
    rendered
        .strip_prefix('(')
        .and_then(|rest| rest.strip_prefix(op))
        .is_some_and(|rest| rest.starts_with(' '))
}

/// Matches `inner` against the single operand of `expr` when `expr` is an application of `op`.
fn match_unary_application(
    inner: &Pattern,
    expr: &Expression,
    op: &str,
    captures: &mut Captures,
) -> bool {
    has_sexpr_head(&expr.to_string(), op)
        && extract_unary_arg(expr, op).is_some_and(|arg| match_pattern_rec(inner, &arg, captures))
}

/// Matches `left`/`right` against the two operands of `expr` when `expr` is an application of
/// `op`; the right operand is only tried after the left one matched.
fn match_binary_application(
    left: &Pattern,
    right: &Pattern,
    expr: &Expression,
    op: &str,
    captures: &mut Captures,
) -> bool {
    has_sexpr_head(&expr.to_string(), op)
        && extract_binary_args(expr, op).is_some_and(|(left_expr, right_expr)| {
            match_pattern_rec(left, &left_expr, captures)
                && match_pattern_rec(right, &right_expr, captures)
        })
}
/// Intended to return the single operand of `_expr` when it is an application of the unary
/// operator `_op`.
///
/// Stub: always returns `None`, so unary compound patterns never match yet.
const fn extract_unary_arg(_expr: &Expression, _op: &str) -> Option<Expression> {
    Option::None
}

/// Intended to return the two operands of `_expr` when it is an application of the binary
/// operator `_op`.
///
/// Stub: always returns `None`, so binary compound patterns never match yet.
const fn extract_binary_args(_expr: &Expression, _op: &str) -> Option<(Expression, Expression)> {
    Option::None
}
/// Placeholder for recognizing rotation-gate expressions and returning two component
/// expressions from them.
///
/// NOTE(review): the intended shape (presumably an `(exp …)` form) and the meaning of the
/// returned pair are not defined yet — TODO confirm. Currently always returns `None`.
pub const fn is_rotation_gate(expr: &Expression) -> Option<(Expression, Expression)> {
    // Kept named for API readability; unused until recognition is implemented.
    let _ = expr;
    None
}
/// Heuristic check for expressions that are Hermitian by form.
///
/// Accepts every expression for which `is_number` holds, plus the symbols `sigma_x`,
/// `sigma_y`, `sigma_z`, `X`, `Y`, `Z` and `I`. Everything else is reported as not Hermitian,
/// even when it is.
///
/// NOTE(review): `I` is accepted here as the identity operator, but `is_pure_imaginary` and
/// `is_unit_complex_form` in this file treat `I` as the imaginary unit, which is not Hermitian.
/// Confirm which meaning the expression printer uses. Also confirm that `is_number` excludes
/// non-real values.
pub fn is_hermitian_form(expr: &Expression) -> bool {
    const HERMITIAN_SYMBOLS: [&str; 7] = ["sigma_x", "sigma_y", "sigma_z", "X", "Y", "Z", "I"];
    expr.is_number() || expr.as_symbol().is_some_and(|name| HERMITIAN_SYMBOLS.contains(&name))
}
/// Placeholder for recognizing projector expressions.
///
/// Stub: always returns `false`.
pub const fn is_projector_form(expr: &Expression) -> bool {
    // Kept named for API readability; explicitly discarded to avoid an unused-variable warning.
    let _ = expr;
    false
}
/// Heuristically decides whether `expr` is purely imaginary, treating the symbol `I` as the
/// imaginary unit.
///
/// The decision is made on the S-expression rendering of `expr` (its `Display` output). An
/// expression is accepted when it is:
///
/// * the bare symbol `I`,
/// * `(neg e)` where `e` is accepted, or
/// * a product `(* f1 f2 …)` in which exactly one top-level factor is accepted and no other
///   factor mentions `I`.
///
/// Factors other than the imaginary one are assumed to be real and are not checked, so
/// `(* x I)` is accepted for any `x`. Sums such as `(+ 3 (* 2 I))` and products such as
/// `(* I I)` are rejected.
pub fn is_pure_imaginary(expr: &Expression) -> bool {
    sexpr_is_pure_imaginary(expr.to_string().trim())
}

/// Recursive worker for [`is_pure_imaginary`] on one rendered S-expression.
fn sexpr_is_pure_imaginary(s: &str) -> bool {
    if s == "I" {
        return true;
    }
    if let Some(body) = sexpr_body(s, "neg") {
        let operands = sexpr_operands(body);
        return matches!(operands.as_deref(), Some([arg]) if sexpr_is_pure_imaginary(arg));
    }
    let Some(factors) = sexpr_body(s, "*").and_then(sexpr_operands) else {
        return false;
    };
    // Every accepted factor contains the token `I`, so "exactly one factor mentions `I`"
    // guarantees the remaining factors are free of `I`.
    factors.len() >= 2
        && factors.iter().filter(|f| sexpr_is_pure_imaginary(f)).count() == 1
        && factors.iter().filter(|f| sexpr_mentions_i(f)).count() == 1
}

/// Returns the text between `"(" + head + " "` and the final `)` of `s`, if `s` has that shape.
fn sexpr_body<'a>(s: &'a str, head: &str) -> Option<&'a str> {
    s.strip_prefix('(')?
        .strip_prefix(head)?
        .strip_prefix(' ')?
        .strip_suffix(')')
}

/// Splits a list body into its top-level, whitespace-separated operands.
///
/// Returns `None` when the parentheses in `body` are unbalanced.
fn sexpr_operands(body: &str) -> Option<Vec<&str>> {
    let mut operands = Vec::new();
    let mut depth: usize = 0;
    let mut start: Option<usize> = None;
    for (i, c) in body.char_indices() {
        if c.is_whitespace() && depth == 0 {
            if let Some(begin) = start.take() {
                operands.push(&body[begin..i]);
            }
            continue;
        }
        start.get_or_insert(i);
        match c {
            '(' => depth += 1,
            ')' => depth = depth.checked_sub(1)?,
            _ => {}
        }
    }
    if depth != 0 {
        return None;
    }
    if let Some(begin) = start {
        operands.push(&body[begin..]);
    }
    Some(operands)
}

/// Returns `true` when the token `I` occurs anywhere in the S-expression `s`.
fn sexpr_mentions_i(s: &str) -> bool {
    s.split(|c: char| c == '(' || c == ')' || c.is_whitespace())
        .any(|token| token == "I")
}
/// Heuristic check for unit-modulus complex exponentials.
///
/// Returns `true` when the rendered expression starts with `(exp (* I ` or
/// `(exp (* (neg I) `, i.e. an exponential whose exponent is a product led by `I` or `-I`.
/// The other factor is not inspected; it is implicitly assumed to be real.
pub fn is_unit_complex_form(expr: &Expression) -> bool {
    const UNIT_PHASE_PREFIXES: [&str; 2] = ["(exp (* I ", "(exp (* (neg I) "];
    let rendered = expr.to_string();
    UNIT_PHASE_PREFIXES
        .iter()
        .any(|&prefix| rendered.starts_with(prefix))
}
/// Gate classification produced by [`recognize_gate_pattern`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuantumGatePattern {
    /// Pauli X (symbols `X`, `sigma_x`, `pauli_x`).
    PauliX,
    /// Pauli Y (symbols `Y`, `sigma_y`, `pauli_y`).
    PauliY,
    /// Pauli Z (symbols `Z`, `sigma_z`, `pauli_z`).
    PauliZ,
    /// Hadamard (symbols `H`, `hadamard`).
    Hadamard,
    /// S gate (symbols `S`, `s_gate`).
    SGate,
    /// T gate (symbols `T`, `t_gate`).
    TGate,
    /// Rotation about the x axis; the payload is presumably the angle — TODO confirm.
    Rx(Expression),
    /// Rotation about the y axis; the payload is presumably the angle — TODO confirm.
    Ry(Expression),
    /// Rotation about the z axis; the payload is presumably the angle — TODO confirm.
    Rz(Expression),
    /// General three-parameter rotation (presumably Euler/U3 angles — TODO confirm).
    Rotation(Expression, Expression, Expression),
    /// Not recognized as any of the above.
    Unknown,
}
/// Classifies a bare gate symbol by name.
///
/// Only symbols are recognized (Pauli, Hadamard, S and T gate names, in several spellings).
/// Any other symbol, and every non-symbol expression, yields [`QuantumGatePattern::Unknown`].
pub fn recognize_gate_pattern(expr: &Expression) -> QuantumGatePattern {
    match expr.as_symbol() {
        Some("X" | "sigma_x" | "pauli_x") => QuantumGatePattern::PauliX,
        Some("Y" | "sigma_y" | "pauli_y") => QuantumGatePattern::PauliY,
        Some("Z" | "sigma_z" | "pauli_z") => QuantumGatePattern::PauliZ,
        Some("H" | "hadamard") => QuantumGatePattern::Hadamard,
        Some("S" | "s_gate") => QuantumGatePattern::SGate,
        Some("T" | "t_gate") => QuantumGatePattern::TGate,
        _ => QuantumGatePattern::Unknown,
    }
}
/// Building blocks of variational quantum circuits.
///
/// NOTE(review): this enum is not constructed anywhere in this file; the field meanings below
/// are inferred from their names — confirm.
#[derive(Debug, Clone)]
pub enum VariationalPattern {
    /// Single-qubit rotation about `axis` (presumably `'x'`, `'y'` or `'z'`) by `param`.
    SingleRotation { axis: char, param: Expression },
    /// Layer of entangling gates with its parameters.
    EntanglingLayer { params: Vec<Expression> },
    /// VQE ansatz with its parameters.
    VqeAnsatz { params: Vec<Expression> },
    /// QAOA mixer layer with angle `beta`.
    QaoaMixer { beta: Expression },
    /// QAOA cost layer with angle `gamma`.
    QaoaCost { gamma: Expression },
}
/// Returns `true` when `expr` is a symbol whose name starts with `theta`, `phi` or `lambda`.
///
/// These prefixes are the naming convention this module uses for VQE parameters.
pub fn is_vqe_parameter(expr: &Expression) -> bool {
    const VQE_PREFIXES: [&str; 3] = ["theta", "phi", "lambda"];
    let Some(name) = expr.as_symbol() else {
        return false;
    };
    VQE_PREFIXES.iter().any(|&prefix| name.starts_with(prefix))
}
/// Returns `true` when `expr` is a symbol whose name starts with `beta` or `gamma`.
///
/// These prefixes are the naming convention this module uses for QAOA parameters.
pub fn is_qaoa_parameter(expr: &Expression) -> bool {
    const QAOA_PREFIXES: [&str; 2] = ["beta", "gamma"];
    let Some(name) = expr.as_symbol() else {
        return false;
    };
    QAOA_PREFIXES.iter().any(|&prefix| name.starts_with(prefix))
}
#[cfg(test)]
mod tests {
    // Unit tests for leaf-pattern matching, gate recognition and parameter-name heuristics.
    use super::*;

    #[test]
    fn test_wildcard_pattern() {
        let outcome = match_pattern(&Pattern::wildcard("a"), &Expression::symbol("x"));
        assert!(outcome.is_some());
        let bound = outcome.expect("should match");
        assert!(bound.contains_key("a"));
        assert_eq!(bound.get("a").expect("has a").as_symbol(), Some("x"));
    }

    #[test]
    fn test_symbol_pattern() {
        let pattern = Pattern::symbol("x");
        assert!(match_pattern(&pattern, &Expression::symbol("x")).is_some());
        assert!(match_pattern(&pattern, &Expression::symbol("y")).is_none());
    }

    #[test]
    fn test_constant_pattern() {
        let value = Expression::float_unchecked(2.5);
        assert!(match_pattern(&Pattern::constant(2.5), &value).is_some());
        assert!(match_pattern(&Pattern::constant(3.0), &value).is_none());
    }

    #[test]
    fn test_zero_one_patterns() {
        let zero = Expression::zero();
        let one = Expression::one();
        let cases = [
            (Pattern::Zero, &zero, true),
            (Pattern::One, &one, true),
            (Pattern::Zero, &one, false),
            (Pattern::One, &zero, false),
        ];
        for (pattern, expr, should_match) in cases {
            assert_eq!(match_pattern(&pattern, expr).is_some(), should_match);
        }
    }

    #[test]
    fn test_gate_recognition() {
        let expectations = [
            ("X", QuantumGatePattern::PauliX),
            ("sigma_y", QuantumGatePattern::PauliY),
            ("H", QuantumGatePattern::Hadamard),
        ];
        for (name, expected) in expectations {
            assert_eq!(recognize_gate_pattern(&Expression::symbol(name)), expected);
        }
    }

    #[test]
    fn test_hermitian_recognition() {
        assert!(is_hermitian_form(&Expression::symbol("X")));
        assert!(is_hermitian_form(&Expression::float_unchecked(2.5)));
    }

    #[test]
    fn test_vqe_parameter_recognition() {
        assert!(is_vqe_parameter(&Expression::symbol("theta_1")));
        assert!(!is_vqe_parameter(&Expression::symbol("x")));
    }

    #[test]
    fn test_qaoa_parameter_recognition() {
        for name in ["beta_0", "gamma_1"] {
            assert!(is_qaoa_parameter(&Expression::symbol(name)));
        }
        assert!(!is_qaoa_parameter(&Expression::symbol("x")));
    }
}