quantrs2_symengine_pure/pattern/
mod.rs1use std::collections::HashMap;
7
8use crate::error::{SymEngineError, SymEngineResult};
9use crate::expr::{ExprLang, Expression};
10
/// A structural pattern that can be matched against an [`Expression`].
///
/// Leaf variants test atoms directly. Compound variants correspond to the
/// S-expression operators in an expression's `Display` rendering and match
/// their operands recursively. Wildcards bind whatever they match by name; a
/// name that appears more than once must bind equal expressions each time.
#[derive(Clone, Debug)]
pub enum Pattern {
    /// Matches any expression and binds it under this name.
    Wildcard(String),
    /// Matches an expression whose numeric value equals this constant, up to
    /// a small floating-point tolerance.
    Constant(f64),
    /// Matches the symbol with exactly this name.
    Symbol(String),
    /// Matches an expression for which `Expression::is_zero` holds.
    Zero,
    /// Matches an expression for which `Expression::is_one` holds.
    One,
    /// Sum, rendered `(+ a b)`.
    Add(Box<Self>, Box<Self>),
    /// Product, rendered `(* a b)`.
    Mul(Box<Self>, Box<Self>),
    /// Power, rendered `(^ base exp)`.
    Pow(Box<Self>, Box<Self>),
    /// Negation, rendered `(neg a)`.
    Neg(Box<Self>),
    /// Sine, rendered `(sin a)`.
    Sin(Box<Self>),
    /// Cosine, rendered `(cos a)`.
    Cos(Box<Self>),
    /// Exponential, rendered `(exp a)`.
    Exp(Box<Self>),
    /// Logarithm, rendered `(log a)`.
    Log(Box<Self>),
    /// Commutator, rendered `(comm a b)`.
    Commutator(Box<Self>, Box<Self>),
    /// Anticommutator, rendered `(anticomm a b)`.
    Anticommutator(Box<Self>, Box<Self>),
    /// Tensor product, rendered `(tensor a b)`.
    TensorProduct(Box<Self>, Box<Self>),
    /// Adjoint ("dagger"), rendered `(dagger a)`.
    Dagger(Box<Self>),
}

// `add`, `mul` and `neg` share names with `std::ops` trait methods, but they
// are plain pattern constructors rather than operator overloads.
#[allow(clippy::should_implement_trait)]
impl Pattern {
    /// Creates a [`Pattern::Wildcard`] that binds its match under `name`.
    #[must_use]
    pub fn wildcard(name: &str) -> Self {
        Self::Wildcard(name.to_string())
    }

    /// Creates a [`Pattern::Symbol`] matching the symbol called `name`.
    #[must_use]
    pub fn symbol(name: &str) -> Self {
        Self::Symbol(name.to_string())
    }

    /// Creates a [`Pattern::Constant`] matching the numeric `value`.
    #[must_use]
    pub const fn constant(value: f64) -> Self {
        Self::Constant(value)
    }

    /// Creates a [`Pattern::Add`] pattern `left + right`.
    #[must_use]
    pub fn add(left: Self, right: Self) -> Self {
        Self::Add(Box::new(left), Box::new(right))
    }

    /// Creates a [`Pattern::Mul`] pattern `left * right`.
    #[must_use]
    pub fn mul(left: Self, right: Self) -> Self {
        Self::Mul(Box::new(left), Box::new(right))
    }

    /// Creates a [`Pattern::Pow`] pattern `base ^ exp`.
    #[must_use]
    pub fn pow(base: Self, exp: Self) -> Self {
        Self::Pow(Box::new(base), Box::new(exp))
    }

    /// Creates a [`Pattern::Neg`] pattern `-arg`.
    #[must_use]
    pub fn neg(arg: Self) -> Self {
        Self::Neg(Box::new(arg))
    }

    /// Creates a [`Pattern::Sin`] pattern `sin(arg)`.
    #[must_use]
    pub fn sin(arg: Self) -> Self {
        Self::Sin(Box::new(arg))
    }

    /// Creates a [`Pattern::Cos`] pattern `cos(arg)`.
    #[must_use]
    pub fn cos(arg: Self) -> Self {
        Self::Cos(Box::new(arg))
    }

    /// Creates a [`Pattern::Exp`] pattern `exp(arg)`.
    #[must_use]
    pub fn exp(arg: Self) -> Self {
        Self::Exp(Box::new(arg))
    }

    /// Creates a [`Pattern::Log`] pattern `log(arg)`.
    #[must_use]
    pub fn log(arg: Self) -> Self {
        Self::Log(Box::new(arg))
    }

    /// Creates a [`Pattern::Commutator`] pattern `[a, b]`.
    #[must_use]
    pub fn commutator(a: Self, b: Self) -> Self {
        Self::Commutator(Box::new(a), Box::new(b))
    }

    /// Creates a [`Pattern::Anticommutator`] pattern `{a, b}`.
    #[must_use]
    pub fn anticommutator(a: Self, b: Self) -> Self {
        Self::Anticommutator(Box::new(a), Box::new(b))
    }

    /// Creates a [`Pattern::TensorProduct`] pattern `a ⊗ b`.
    #[must_use]
    pub fn tensor(a: Self, b: Self) -> Self {
        Self::TensorProduct(Box::new(a), Box::new(b))
    }

    /// Creates a [`Pattern::Dagger`] pattern for the adjoint of `a`.
    #[must_use]
    pub fn dagger(a: Self) -> Self {
        Self::Dagger(Box::new(a))
    }
}
124
125pub type Captures = HashMap<String, Expression>;
127
128pub fn match_pattern(pattern: &Pattern, expr: &Expression) -> Option<Captures> {
130 let mut captures = Captures::new();
131 if match_pattern_rec(pattern, expr, &mut captures) {
132 Some(captures)
133 } else {
134 None
135 }
136}
137
138#[allow(clippy::option_if_let_else)]
140fn match_pattern_rec(pattern: &Pattern, expr: &Expression, captures: &mut Captures) -> bool {
141 match pattern {
142 Pattern::Wildcard(name) => {
143 if let Some(existing) = captures.get(name) {
145 existing == expr
147 } else {
148 captures.insert(name.clone(), expr.clone());
149 true
150 }
151 }
152
153 Pattern::Constant(value) => {
154 if let Some(v) = expr.to_f64() {
155 (v - value).abs() < 1e-15
156 } else {
157 false
158 }
159 }
160
161 Pattern::Symbol(name) => expr.as_symbol() == Some(name.as_str()),
162
163 Pattern::Zero => expr.is_zero(),
164
165 Pattern::One => expr.is_one(),
166
167 _ => match_compound_pattern(pattern, expr, captures),
171 }
172}
173
174fn match_compound_pattern(pattern: &Pattern, expr: &Expression, captures: &mut Captures) -> bool {
176 let expr_str = expr.to_string();
178
179 match pattern {
180 Pattern::Neg(inner) => {
181 if expr_str.starts_with("(neg ") {
182 let inner_expr = extract_unary_arg(expr, "neg");
185 if let Some(inner_expr) = inner_expr {
186 return match_pattern_rec(inner, &inner_expr, captures);
187 }
188 }
189 false
190 }
191
192 Pattern::Sin(inner) => {
193 if expr_str.starts_with("(sin ") {
194 if let Some(inner_expr) = extract_unary_arg(expr, "sin") {
195 return match_pattern_rec(inner, &inner_expr, captures);
196 }
197 }
198 false
199 }
200
201 Pattern::Cos(inner) => {
202 if expr_str.starts_with("(cos ") {
203 if let Some(inner_expr) = extract_unary_arg(expr, "cos") {
204 return match_pattern_rec(inner, &inner_expr, captures);
205 }
206 }
207 false
208 }
209
210 Pattern::Exp(inner) => {
211 if expr_str.starts_with("(exp ") {
212 if let Some(inner_expr) = extract_unary_arg(expr, "exp") {
213 return match_pattern_rec(inner, &inner_expr, captures);
214 }
215 }
216 false
217 }
218
219 Pattern::Log(inner) => {
220 if expr_str.starts_with("(log ") {
221 if let Some(inner_expr) = extract_unary_arg(expr, "log") {
222 return match_pattern_rec(inner, &inner_expr, captures);
223 }
224 }
225 false
226 }
227
228 Pattern::Dagger(inner) => {
229 if expr_str.starts_with("(dagger ") {
230 if let Some(inner_expr) = extract_unary_arg(expr, "dagger") {
231 return match_pattern_rec(inner, &inner_expr, captures);
232 }
233 }
234 false
235 }
236
237 Pattern::Add(left, right) => {
239 if expr_str.starts_with("(+ ") {
240 if let Some((left_expr, right_expr)) = extract_binary_args(expr, "+") {
241 return match_pattern_rec(left, &left_expr, captures)
242 && match_pattern_rec(right, &right_expr, captures);
243 }
244 }
245 false
246 }
247
248 Pattern::Mul(left, right) => {
249 if expr_str.starts_with("(* ") {
250 if let Some((left_expr, right_expr)) = extract_binary_args(expr, "*") {
251 return match_pattern_rec(left, &left_expr, captures)
252 && match_pattern_rec(right, &right_expr, captures);
253 }
254 }
255 false
256 }
257
258 Pattern::Pow(base, exp) => {
259 if expr_str.starts_with("(^ ") {
260 if let Some((base_expr, exp_expr)) = extract_binary_args(expr, "^") {
261 return match_pattern_rec(base, &base_expr, captures)
262 && match_pattern_rec(exp, &exp_expr, captures);
263 }
264 }
265 false
266 }
267
268 Pattern::Commutator(a, b) => {
269 if expr_str.starts_with("(comm ") {
270 if let Some((a_expr, b_expr)) = extract_binary_args(expr, "comm") {
271 return match_pattern_rec(a, &a_expr, captures)
272 && match_pattern_rec(b, &b_expr, captures);
273 }
274 }
275 false
276 }
277
278 Pattern::Anticommutator(a, b) => {
279 if expr_str.starts_with("(anticomm ") {
280 if let Some((a_expr, b_expr)) = extract_binary_args(expr, "anticomm") {
281 return match_pattern_rec(a, &a_expr, captures)
282 && match_pattern_rec(b, &b_expr, captures);
283 }
284 }
285 false
286 }
287
288 Pattern::TensorProduct(a, b) => {
289 if expr_str.starts_with("(tensor ") {
290 if let Some((a_expr, b_expr)) = extract_binary_args(expr, "tensor") {
291 return match_pattern_rec(a, &a_expr, captures)
292 && match_pattern_rec(b, &b_expr, captures);
293 }
294 }
295 false
296 }
297
298 Pattern::Wildcard(_)
300 | Pattern::Constant(_)
301 | Pattern::Symbol(_)
302 | Pattern::Zero
303 | Pattern::One => unreachable!(),
304 }
305}
306
307const fn extract_unary_arg(_expr: &Expression, _op: &str) -> Option<Expression> {
309 None
312}
313
314const fn extract_binary_args(_expr: &Expression, _op: &str) -> Option<(Expression, Expression)> {
316 None
319}
320
321pub fn is_rotation_gate(expr: &Expression) -> Option<(Expression, Expression)> {
328 let s = expr.to_string();
331 if s.starts_with("(exp ") {
332 return None;
335 }
336 None
337}
338
339pub fn is_hermitian_form(expr: &Expression) -> bool {
341 if expr.is_number() {
344 return true;
345 }
346 expr.as_symbol().is_some_and(|sym| {
348 matches!(
349 sym,
350 "sigma_x" | "sigma_y" | "sigma_z" | "X" | "Y" | "Z" | "I"
351 )
352 })
353}
354
355pub const fn is_projector_form(expr: &Expression) -> bool {
357 false
360}
361
362pub fn is_pure_imaginary(expr: &Expression) -> bool {
364 let s = expr.to_string();
365 s.contains("(* ") && s.contains(" I)") || s.contains("(* I ")
366}
367
368pub fn is_unit_complex_form(expr: &Expression) -> bool {
370 let s = expr.to_string();
371 s.starts_with("(exp (* I ") || s.starts_with("(exp (* (neg I) ")
373}
374
375#[derive(Debug, Clone, PartialEq, Eq)]
377pub enum QuantumGatePattern {
378 PauliX,
380 PauliY,
382 PauliZ,
384 Hadamard,
386 SGate,
388 TGate,
390 Rx(Expression),
392 Ry(Expression),
394 Rz(Expression),
396 Rotation(Expression, Expression, Expression), Unknown,
400}
401
402pub fn recognize_gate_pattern(expr: &Expression) -> QuantumGatePattern {
404 if let Some(sym) = expr.as_symbol() {
405 match sym {
406 "X" | "sigma_x" | "pauli_x" => return QuantumGatePattern::PauliX,
407 "Y" | "sigma_y" | "pauli_y" => return QuantumGatePattern::PauliY,
408 "Z" | "sigma_z" | "pauli_z" => return QuantumGatePattern::PauliZ,
409 "H" | "hadamard" => return QuantumGatePattern::Hadamard,
410 "S" | "s_gate" => return QuantumGatePattern::SGate,
411 "T" | "t_gate" => return QuantumGatePattern::TGate,
412 _ => {}
413 }
414 }
415 QuantumGatePattern::Unknown
416}
417
418#[derive(Debug, Clone)]
420pub enum VariationalPattern {
421 SingleRotation {
423 axis: char, param: Expression,
425 },
426 EntanglingLayer { params: Vec<Expression> },
428 VqeAnsatz { params: Vec<Expression> },
430 QaoaMixer { beta: Expression },
432 QaoaCost { gamma: Expression },
434}
435
436pub fn is_vqe_parameter(expr: &Expression) -> bool {
438 expr.as_symbol().is_some_and(|sym| {
439 sym.starts_with("theta") || sym.starts_with("phi") || sym.starts_with("lambda")
440 })
441}
442
443pub fn is_qaoa_parameter(expr: &Expression) -> bool {
445 expr.as_symbol()
446 .is_some_and(|sym| sym.starts_with("beta") || sym.starts_with("gamma"))
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wildcard_pattern() {
        // A lone wildcard matches anything and binds it under its own name.
        let captures = match_pattern(&Pattern::wildcard("a"), &Expression::symbol("x"))
            .expect("should match");
        let bound = captures.get("a").expect("has a");
        assert_eq!(bound.as_symbol(), Some("x"));
    }

    #[test]
    fn test_symbol_pattern() {
        let pattern = Pattern::symbol("x");
        assert!(match_pattern(&pattern, &Expression::symbol("x")).is_some());
        assert!(match_pattern(&pattern, &Expression::symbol("y")).is_none());
    }

    #[test]
    fn test_constant_pattern() {
        let value = Expression::float_unchecked(2.5);
        for (target, should_match) in [(2.5, true), (3.0, false)] {
            let matched = match_pattern(&Pattern::constant(target), &value).is_some();
            assert_eq!(matched, should_match, "constant {target}");
        }
    }

    #[test]
    fn test_zero_one_patterns() {
        let zero = Expression::zero();
        let one = Expression::one();
        let cases = [
            (Pattern::Zero, &zero, true),
            (Pattern::One, &one, true),
            (Pattern::Zero, &one, false),
            (Pattern::One, &zero, false),
        ];
        for (pattern, expr, expected) in &cases {
            assert_eq!(match_pattern(pattern, expr).is_some(), *expected, "{pattern:?}");
        }
    }

    #[test]
    fn test_gate_recognition() {
        let cases = [
            ("X", QuantumGatePattern::PauliX),
            ("sigma_y", QuantumGatePattern::PauliY),
            ("H", QuantumGatePattern::Hadamard),
        ];
        for (name, expected) in cases {
            assert_eq!(recognize_gate_pattern(&Expression::symbol(name)), expected);
        }
    }

    #[test]
    fn test_hermitian_recognition() {
        assert!(is_hermitian_form(&Expression::symbol("X")));
        assert!(is_hermitian_form(&Expression::float_unchecked(2.5)));
    }

    #[test]
    fn test_vqe_parameter_recognition() {
        assert!(is_vqe_parameter(&Expression::symbol("theta_1")));
        assert!(!is_vqe_parameter(&Expression::symbol("x")));
    }

    #[test]
    fn test_qaoa_parameter_recognition() {
        for name in ["beta_0", "gamma_1"] {
            assert!(is_qaoa_parameter(&Expression::symbol(name)), "{name}");
        }
        assert!(!is_qaoa_parameter(&Expression::symbol("x")));
    }
}