quantrs2_symengine_pure/pattern/
mod.rs

1//! Pattern matching for quantum expressions.
2//!
3//! This module provides utilities for recognizing and extracting
4//! common patterns in quantum computing expressions.
5
6use std::collections::HashMap;
7
8use crate::error::{SymEngineError, SymEngineResult};
9use crate::expr::{ExprLang, Expression};
10
/// Structural template that [`match_pattern`] compares against an expression.
///
/// Leaf variants (`Wildcard`, `Constant`, `Symbol`, `Zero`, `One`) are tested
/// directly against the expression; the remaining variants describe an
/// operator applied to one or two sub-patterns.
#[derive(Debug, Clone)]
pub enum Pattern {
    /// Matches any expression and records it under the given capture name.
    /// A name that is already bound only matches an equal expression.
    Wildcard(String),
    /// Matches an expression whose numeric value equals this constant
    /// (absolute tolerance `1e-15`).
    Constant(f64),
    /// Matches the symbol with exactly this name.
    Symbol(String),
    /// Matches an expression that reports itself as zero.
    Zero,
    /// Matches an expression that reports itself as one.
    One,
    /// Sum of two sub-patterns (left, right).
    Add(Box<Self>, Box<Self>),
    /// Product of two sub-patterns (left, right).
    Mul(Box<Self>, Box<Self>),
    /// Power: (base, exponent).
    Pow(Box<Self>, Box<Self>),
    /// Negation of a sub-pattern.
    Neg(Box<Self>),
    /// Sine of a sub-pattern.
    Sin(Box<Self>),
    /// Cosine of a sub-pattern.
    Cos(Box<Self>),
    /// Exponential of a sub-pattern.
    Exp(Box<Self>),
    /// Logarithm of a sub-pattern.
    Log(Box<Self>),
    /// Commutator `[A, B]` of two sub-patterns.
    Commutator(Box<Self>, Box<Self>),
    /// Anticommutator `{A, B}` of two sub-patterns.
    Anticommutator(Box<Self>, Box<Self>),
    /// Tensor product `A ⊗ B` of two sub-patterns.
    TensorProduct(Box<Self>, Box<Self>),
    /// Hermitian conjugate (dagger) of a sub-pattern.
    Dagger(Box<Self>),
}
49
50#[allow(clippy::should_implement_trait)]
51impl Pattern {
52    /// Create a wildcard pattern with the given name
53    #[must_use]
54    pub fn wildcard(name: &str) -> Self {
55        Self::Wildcard(name.to_string())
56    }
57
58    /// Create a symbol pattern
59    #[must_use]
60    pub fn symbol(name: &str) -> Self {
61        Self::Symbol(name.to_string())
62    }
63
64    /// Create a constant pattern
65    #[must_use]
66    pub const fn constant(value: f64) -> Self {
67        Self::Constant(value)
68    }
69
70    /// Create an addition pattern
71    #[must_use]
72    pub fn add(left: Self, right: Self) -> Self {
73        Self::Add(Box::new(left), Box::new(right))
74    }
75
76    /// Create a multiplication pattern
77    #[must_use]
78    pub fn mul(left: Self, right: Self) -> Self {
79        Self::Mul(Box::new(left), Box::new(right))
80    }
81
82    /// Create a power pattern
83    #[must_use]
84    pub fn pow(base: Self, exp: Self) -> Self {
85        Self::Pow(Box::new(base), Box::new(exp))
86    }
87
88    /// Create a sine pattern
89    #[must_use]
90    pub fn sin(arg: Self) -> Self {
91        Self::Sin(Box::new(arg))
92    }
93
94    /// Create a cosine pattern
95    #[must_use]
96    pub fn cos(arg: Self) -> Self {
97        Self::Cos(Box::new(arg))
98    }
99
100    /// Create a commutator pattern [A, B]
101    #[must_use]
102    pub fn commutator(a: Self, b: Self) -> Self {
103        Self::Commutator(Box::new(a), Box::new(b))
104    }
105
106    /// Create an anticommutator pattern {A, B}
107    #[must_use]
108    pub fn anticommutator(a: Self, b: Self) -> Self {
109        Self::Anticommutator(Box::new(a), Box::new(b))
110    }
111
112    /// Create a tensor product pattern A ⊗ B
113    #[must_use]
114    pub fn tensor(a: Self, b: Self) -> Self {
115        Self::TensorProduct(Box::new(a), Box::new(b))
116    }
117
118    /// Create a dagger pattern A†
119    #[must_use]
120    pub fn dagger(a: Self) -> Self {
121        Self::Dagger(Box::new(a))
122    }
123}
124
125/// Result of pattern matching - captured expressions
126pub type Captures = HashMap<String, Expression>;
127
128/// Match a pattern against an expression
129pub fn match_pattern(pattern: &Pattern, expr: &Expression) -> Option<Captures> {
130    let mut captures = Captures::new();
131    if match_pattern_rec(pattern, expr, &mut captures) {
132        Some(captures)
133    } else {
134        None
135    }
136}
137
138/// Recursive pattern matching helper
139#[allow(clippy::option_if_let_else)]
140fn match_pattern_rec(pattern: &Pattern, expr: &Expression, captures: &mut Captures) -> bool {
141    match pattern {
142        Pattern::Wildcard(name) => {
143            // Check if already captured with different value
144            if let Some(existing) = captures.get(name) {
145                // Must match the same expression
146                existing == expr
147            } else {
148                captures.insert(name.clone(), expr.clone());
149                true
150            }
151        }
152
153        Pattern::Constant(value) => {
154            if let Some(v) = expr.to_f64() {
155                (v - value).abs() < 1e-15
156            } else {
157                false
158            }
159        }
160
161        Pattern::Symbol(name) => expr.as_symbol() == Some(name.as_str()),
162
163        Pattern::Zero => expr.is_zero(),
164
165        Pattern::One => expr.is_one(),
166
167        // For compound patterns, we need to access the internal structure
168        // This requires parsing the expression representation
169        // For now, use string-based matching as a simple implementation
170        _ => match_compound_pattern(pattern, expr, captures),
171    }
172}
173
174/// Match compound patterns by expression structure
175fn match_compound_pattern(pattern: &Pattern, expr: &Expression, captures: &mut Captures) -> bool {
176    // Get the expression string for structural analysis
177    let expr_str = expr.to_string();
178
179    match pattern {
180        Pattern::Neg(inner) => {
181            if expr_str.starts_with("(neg ") {
182                // For negation patterns, we'd need to extract the inner expression
183                // This is a simplified implementation
184                let inner_expr = extract_unary_arg(expr, "neg");
185                if let Some(inner_expr) = inner_expr {
186                    return match_pattern_rec(inner, &inner_expr, captures);
187                }
188            }
189            false
190        }
191
192        Pattern::Sin(inner) => {
193            if expr_str.starts_with("(sin ") {
194                if let Some(inner_expr) = extract_unary_arg(expr, "sin") {
195                    return match_pattern_rec(inner, &inner_expr, captures);
196                }
197            }
198            false
199        }
200
201        Pattern::Cos(inner) => {
202            if expr_str.starts_with("(cos ") {
203                if let Some(inner_expr) = extract_unary_arg(expr, "cos") {
204                    return match_pattern_rec(inner, &inner_expr, captures);
205                }
206            }
207            false
208        }
209
210        Pattern::Exp(inner) => {
211            if expr_str.starts_with("(exp ") {
212                if let Some(inner_expr) = extract_unary_arg(expr, "exp") {
213                    return match_pattern_rec(inner, &inner_expr, captures);
214                }
215            }
216            false
217        }
218
219        Pattern::Log(inner) => {
220            if expr_str.starts_with("(log ") {
221                if let Some(inner_expr) = extract_unary_arg(expr, "log") {
222                    return match_pattern_rec(inner, &inner_expr, captures);
223                }
224            }
225            false
226        }
227
228        Pattern::Dagger(inner) => {
229            if expr_str.starts_with("(dagger ") {
230                if let Some(inner_expr) = extract_unary_arg(expr, "dagger") {
231                    return match_pattern_rec(inner, &inner_expr, captures);
232                }
233            }
234            false
235        }
236
237        // Binary patterns - simplified implementation
238        Pattern::Add(left, right) => {
239            if expr_str.starts_with("(+ ") {
240                if let Some((left_expr, right_expr)) = extract_binary_args(expr, "+") {
241                    return match_pattern_rec(left, &left_expr, captures)
242                        && match_pattern_rec(right, &right_expr, captures);
243                }
244            }
245            false
246        }
247
248        Pattern::Mul(left, right) => {
249            if expr_str.starts_with("(* ") {
250                if let Some((left_expr, right_expr)) = extract_binary_args(expr, "*") {
251                    return match_pattern_rec(left, &left_expr, captures)
252                        && match_pattern_rec(right, &right_expr, captures);
253                }
254            }
255            false
256        }
257
258        Pattern::Pow(base, exp) => {
259            if expr_str.starts_with("(^ ") {
260                if let Some((base_expr, exp_expr)) = extract_binary_args(expr, "^") {
261                    return match_pattern_rec(base, &base_expr, captures)
262                        && match_pattern_rec(exp, &exp_expr, captures);
263                }
264            }
265            false
266        }
267
268        Pattern::Commutator(a, b) => {
269            if expr_str.starts_with("(comm ") {
270                if let Some((a_expr, b_expr)) = extract_binary_args(expr, "comm") {
271                    return match_pattern_rec(a, &a_expr, captures)
272                        && match_pattern_rec(b, &b_expr, captures);
273                }
274            }
275            false
276        }
277
278        Pattern::Anticommutator(a, b) => {
279            if expr_str.starts_with("(anticomm ") {
280                if let Some((a_expr, b_expr)) = extract_binary_args(expr, "anticomm") {
281                    return match_pattern_rec(a, &a_expr, captures)
282                        && match_pattern_rec(b, &b_expr, captures);
283                }
284            }
285            false
286        }
287
288        Pattern::TensorProduct(a, b) => {
289            if expr_str.starts_with("(tensor ") {
290                if let Some((a_expr, b_expr)) = extract_binary_args(expr, "tensor") {
291                    return match_pattern_rec(a, &a_expr, captures)
292                        && match_pattern_rec(b, &b_expr, captures);
293                }
294            }
295            false
296        }
297
298        // These are handled in the main match
299        Pattern::Wildcard(_)
300        | Pattern::Constant(_)
301        | Pattern::Symbol(_)
302        | Pattern::Zero
303        | Pattern::One => unreachable!(),
304    }
305}
306
307/// Extract unary argument from expression (simplified)
308const fn extract_unary_arg(_expr: &Expression, _op: &str) -> Option<Expression> {
309    // In a full implementation, this would parse the RecExpr structure
310    // For now, return None as this requires deeper integration
311    None
312}
313
314/// Extract binary arguments from expression (simplified)
315const fn extract_binary_args(_expr: &Expression, _op: &str) -> Option<(Expression, Expression)> {
316    // In a full implementation, this would parse the RecExpr structure
317    // For now, return None as this requires deeper integration
318    None
319}
320
321// =========================================================================
322// Common Quantum Pattern Recognizers
323// =========================================================================
324
325/// Check if an expression is a rotation gate form: exp(-i * θ * G / 2)
326/// Returns the angle and generator if matched
327pub fn is_rotation_gate(expr: &Expression) -> Option<(Expression, Expression)> {
328    // Pattern: exp(* (neg (* ?i ?theta)) ?generator)
329    // This is a simplified check
330    let s = expr.to_string();
331    if s.starts_with("(exp ") {
332        // Could be a rotation gate form
333        // For a full implementation, we'd need to check the structure
334        return None;
335    }
336    None
337}
338
339/// Check if an expression represents a Hermitian operator (A = A†)
340pub fn is_hermitian_form(expr: &Expression) -> bool {
341    // Simple check: if it's a symbol, it could be Hermitian
342    // Real numbers are Hermitian
343    if expr.is_number() {
344        return true;
345    }
346    // Pauli matrices are Hermitian
347    expr.as_symbol().is_some_and(|sym| {
348        matches!(
349            sym,
350            "sigma_x" | "sigma_y" | "sigma_z" | "X" | "Y" | "Z" | "I"
351        )
352    })
353}
354
355/// Check if an expression is a projector (P² = P)
356pub const fn is_projector_form(expr: &Expression) -> bool {
357    // |ψ⟩⟨ψ| form is a projector
358    // This would require more sophisticated pattern matching
359    false
360}
361
362/// Check if an expression is a pure imaginary number (i * real)
363pub fn is_pure_imaginary(expr: &Expression) -> bool {
364    let s = expr.to_string();
365    s.contains("(* ") && s.contains(" I)") || s.contains("(* I ")
366}
367
368/// Check if an expression is a unit complex number (|z| = 1)
369pub fn is_unit_complex_form(expr: &Expression) -> bool {
370    let s = expr.to_string();
371    // exp(i * θ) has |exp(i*θ)| = 1
372    s.starts_with("(exp (* I ") || s.starts_with("(exp (* (neg I) ")
373}
374
375/// Recognize common quantum gate patterns
376#[derive(Debug, Clone, PartialEq, Eq)]
377pub enum QuantumGatePattern {
378    /// Pauli X gate
379    PauliX,
380    /// Pauli Y gate
381    PauliY,
382    /// Pauli Z gate
383    PauliZ,
384    /// Hadamard gate
385    Hadamard,
386    /// S gate (phase gate)
387    SGate,
388    /// T gate
389    TGate,
390    /// Rx rotation with angle
391    Rx(Expression),
392    /// Ry rotation with angle
393    Ry(Expression),
394    /// Rz rotation with angle
395    Rz(Expression),
396    /// General rotation
397    Rotation(Expression, Expression, Expression), // θ, φ, λ
398    /// Unknown gate
399    Unknown,
400}
401
402/// Try to recognize a quantum gate from its matrix expression
403pub fn recognize_gate_pattern(expr: &Expression) -> QuantumGatePattern {
404    if let Some(sym) = expr.as_symbol() {
405        match sym {
406            "X" | "sigma_x" | "pauli_x" => return QuantumGatePattern::PauliX,
407            "Y" | "sigma_y" | "pauli_y" => return QuantumGatePattern::PauliY,
408            "Z" | "sigma_z" | "pauli_z" => return QuantumGatePattern::PauliZ,
409            "H" | "hadamard" => return QuantumGatePattern::Hadamard,
410            "S" | "s_gate" => return QuantumGatePattern::SGate,
411            "T" | "t_gate" => return QuantumGatePattern::TGate,
412            _ => {}
413        }
414    }
415    QuantumGatePattern::Unknown
416}
417
418/// Recognize variational quantum circuit parameter patterns
419#[derive(Debug, Clone)]
420pub enum VariationalPattern {
421    /// Single parameter rotation
422    SingleRotation {
423        axis: char, // 'x', 'y', or 'z'
424        param: Expression,
425    },
426    /// Parametric entangling layer
427    EntanglingLayer { params: Vec<Expression> },
428    /// VQE ansatz pattern
429    VqeAnsatz { params: Vec<Expression> },
430    /// QAOA pattern
431    QaoaMixer { beta: Expression },
432    /// QAOA cost pattern
433    QaoaCost { gamma: Expression },
434}
435
436/// Check if expression matches a VQE parameter pattern
437pub fn is_vqe_parameter(expr: &Expression) -> bool {
438    expr.as_symbol().is_some_and(|sym| {
439        sym.starts_with("theta") || sym.starts_with("phi") || sym.starts_with("lambda")
440    })
441}
442
443/// Check if expression matches a QAOA parameter
444pub fn is_qaoa_parameter(expr: &Expression) -> bool {
445    expr.as_symbol()
446        .is_some_and(|sym| sym.starts_with("beta") || sym.starts_with("gamma"))
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare wildcard accepts any expression and records it under its name.
    #[test]
    fn test_wildcard_pattern() {
        let target = Expression::symbol("x");
        let outcome = match_pattern(&Pattern::wildcard("a"), &target);
        assert!(outcome.is_some());

        let bindings = outcome.expect("should match");
        assert!(bindings.contains_key("a"));

        let bound = bindings.get("a").expect("has a");
        assert_eq!(bound.as_symbol(), Some("x"));
    }

    /// A symbol pattern matches only the symbol with the same name.
    #[test]
    fn test_symbol_pattern() {
        let pattern = Pattern::symbol("x");
        let same_name = Expression::symbol("x");
        let other_name = Expression::symbol("y");

        assert!(match_pattern(&pattern, &same_name).is_some());
        assert!(match_pattern(&pattern, &other_name).is_none());
    }

    /// A constant pattern matches an equal number and rejects a different one.
    #[test]
    fn test_constant_pattern() {
        let expr = Expression::float_unchecked(2.5);

        let equal = Pattern::constant(2.5);
        assert!(match_pattern(&equal, &expr).is_some());

        let different = Pattern::constant(3.0);
        assert!(match_pattern(&different, &expr).is_none());
    }

    /// `Zero` and `One` each match their own value and not the other's.
    #[test]
    fn test_zero_one_patterns() {
        let zero = Expression::zero();
        let one = Expression::one();
        let matched = |pattern: &Pattern, expr: &Expression| match_pattern(pattern, expr).is_some();

        assert!(matched(&Pattern::Zero, &zero));
        assert!(matched(&Pattern::One, &one));
        assert!(!matched(&Pattern::Zero, &one));
        assert!(!matched(&Pattern::One, &zero));
    }

    /// Gate symbols are classified by name.
    #[test]
    fn test_gate_recognition() {
        let recognized = |name: &str| recognize_gate_pattern(&Expression::symbol(name));

        assert_eq!(recognized("X"), QuantumGatePattern::PauliX);
        assert_eq!(recognized("sigma_y"), QuantumGatePattern::PauliY);
        assert_eq!(recognized("H"), QuantumGatePattern::Hadamard);
    }

    /// Pauli symbols and plain numbers count as Hermitian.
    #[test]
    fn test_hermitian_recognition() {
        let pauli_x = Expression::symbol("X");
        let number = Expression::float_unchecked(2.5);

        assert!(is_hermitian_form(&pauli_x));
        assert!(is_hermitian_form(&number));
    }

    /// Symbols with a VQE-style prefix are accepted; others are not.
    #[test]
    fn test_vqe_parameter_recognition() {
        let theta = Expression::symbol("theta_1");
        let unrelated = Expression::symbol("x");

        assert!(is_vqe_parameter(&theta));
        assert!(!is_vqe_parameter(&unrelated));
    }

    /// Symbols with a QAOA-style prefix are accepted; others are not.
    #[test]
    fn test_qaoa_parameter_recognition() {
        let beta = Expression::symbol("beta_0");
        let gamma = Expression::symbol("gamma_1");
        let unrelated = Expression::symbol("x");

        assert!(is_qaoa_parameter(&beta));
        assert!(is_qaoa_parameter(&gamma));
        assert!(!is_qaoa_parameter(&unrelated));
    }
}