// qudit_expr/expressions/complex.rs

1use std::collections::HashMap;
2
3use qudit_core::ComplexScalar;
4
5use crate::expressions::Constant;
6use crate::expressions::Expression;
7use crate::qgl::Expression as CiscExpression;
8use crate::qgl::parse_scalar;
9
10#[derive(Clone, PartialEq, Eq, Hash)]
11pub struct ComplexExpression {
12    pub real: Expression,
13    pub imag: Expression,
14}
15
16impl ComplexExpression {
17    pub fn from_string(input: impl AsRef<str>) -> Self {
18        ComplexExpression::new(
19            parse_scalar(input.as_ref()).unwrap_or_else(|e| panic!("Invalid input string: {e}")),
20        )
21    }
22
23    pub fn from_real_64(input: f64) -> Self {
24        ComplexExpression {
25            real: Expression::from_float_64(input),
26            imag: Expression::zero(),
27        }
28    }
29
30    pub fn from_real_32(input: f32) -> Self {
31        ComplexExpression {
32            real: Expression::from_float_32(input),
33            imag: Expression::zero(),
34        }
35    }
36
37    pub fn new(cisc_expr: CiscExpression) -> Self {
38        match cisc_expr {
39            CiscExpression::Number(num) => ComplexExpression {
40                real: Expression::Constant(
41                    Constant::from_float(num.parse::<f64>().unwrap()).unwrap(),
42                ),
43                imag: Expression::zero(),
44            },
45            CiscExpression::Variable(var) => {
46                if var == "i" {
47                    ComplexExpression {
48                        real: Expression::zero(),
49                        imag: Expression::one(),
50                    }
51                } else if var == "π" || var == "pi" {
52                    ComplexExpression {
53                        real: Expression::Pi,
54                        imag: Expression::zero(),
55                    }
56                } else {
57                    ComplexExpression {
58                        real: Expression::Variable(var),
59                        imag: Expression::zero(),
60                    }
61                }
62            }
63            CiscExpression::Unary { op, expr } => {
64                let risc_expr = ComplexExpression::new(*expr);
65                match op {
66                    '~' => ComplexExpression {
67                        real: Expression::Neg(Box::new(risc_expr.real)),
68                        imag: Expression::Neg(Box::new(risc_expr.imag)),
69                    },
70                    _ => panic!("Invalid unary operator: {}", op),
71                }
72            }
73            CiscExpression::Binary { op, lhs, rhs } => {
74                let risc_lhs = ComplexExpression::new(*lhs);
75                let risc_rhs = ComplexExpression::new(*rhs);
76                match op {
77                    '+' => risc_lhs + risc_rhs,
78                    '-' => risc_lhs - risc_rhs,
79                    '*' => risc_lhs * risc_rhs,
80                    '/' => risc_lhs / risc_rhs,
81                    '^' => {
82                        if risc_lhs.is_e() {
83                            assert!(risc_rhs.is_imag(), "Exponential power must be imaginary");
84                            ComplexExpression {
85                                real: Expression::Cos(Box::new(risc_rhs.imag.clone())),
86                                imag: Expression::Sin(Box::new(risc_rhs.imag)),
87                            }
88                        } else {
89                            assert!(risc_lhs.is_real(), "Power base must be real");
90                            assert!(risc_rhs.is_real(), "Power exponent must be real");
91                            ComplexExpression {
92                                real: Expression::Pow(
93                                    Box::new(risc_lhs.real),
94                                    Box::new(risc_rhs.real),
95                                ),
96                                imag: Expression::zero(),
97                            }
98                        }
99                    }
100                    _ => panic!("Invalid binary operator: {}", op),
101                }
102            }
103            CiscExpression::Call { fn_name, args } => match fn_name.as_str() {
104                "sqrt" => {
105                    let risc_arg = ComplexExpression::new(args[0].clone());
106                    assert!(args.len() == 1, "sqrt function takes exactly one argument");
107                    assert!(
108                        risc_arg.is_real(),
109                        "sqrt function is only supported for real numbers"
110                    );
111                    ComplexExpression {
112                        real: Expression::Sqrt(Box::new(risc_arg.real)),
113                        imag: Expression::zero(),
114                    }
115                }
116                "sin" => {
117                    let risc_arg = ComplexExpression::new(args[0].clone());
118                    assert!(args.len() == 1, "sin function takes exactly one argument");
119                    assert!(
120                        risc_arg.is_real(),
121                        "sin function is only supported for real numbers"
122                    );
123                    ComplexExpression {
124                        real: Expression::Sin(Box::new(risc_arg.real)),
125                        imag: Expression::zero(),
126                    }
127                }
128                "cos" => {
129                    let risc_arg = ComplexExpression::new(args[0].clone());
130                    assert!(args.len() == 1, "cos function takes exactly one argument");
131                    assert!(
132                        risc_arg.is_real(),
133                        "cos function is only supported for real numbers"
134                    );
135                    ComplexExpression {
136                        real: Expression::Cos(Box::new(risc_arg.real)),
137                        imag: Expression::zero(),
138                    }
139                }
140                "tan" => {
141                    let risc_arg = ComplexExpression::new(args[0].clone());
142                    assert!(args.len() == 1, "tan function takes exactly one argument");
143                    assert!(
144                        risc_arg.is_real(),
145                        "tan function is only supported for real numbers"
146                    );
147                    ComplexExpression {
148                        real: Expression::Div(
149                            Box::new(Expression::Sin(Box::new(risc_arg.real.clone()))),
150                            Box::new(Expression::Cos(Box::new(risc_arg.real))),
151                        ),
152                        imag: Expression::zero(),
153                    }
154                }
155                "csc" => {
156                    let risc_arg = ComplexExpression::new(args[0].clone());
157                    assert!(args.len() == 1, "csc function takes exactly one argument");
158                    assert!(
159                        risc_arg.is_real(),
160                        "csc function is only supported for real numbers"
161                    );
162                    ComplexExpression {
163                        real: Expression::Div(
164                            Box::new(Expression::one()),
165                            Box::new(Expression::Sin(Box::new(risc_arg.real))),
166                        ),
167                        imag: Expression::zero(),
168                    }
169                }
170                "sec" => {
171                    let risc_arg = ComplexExpression::new(args[0].clone());
172                    assert!(args.len() == 1, "sec function takes exactly one argument");
173                    assert!(
174                        risc_arg.is_real(),
175                        "sec function is only supported for real numbers"
176                    );
177                    ComplexExpression {
178                        real: Expression::Div(
179                            Box::new(Expression::one()),
180                            Box::new(Expression::Cos(Box::new(risc_arg.real))),
181                        ),
182                        imag: Expression::zero(),
183                    }
184                }
185                "cot" => {
186                    let risc_arg = ComplexExpression::new(args[0].clone());
187                    assert!(args.len() == 1, "cot function takes exactly one argument");
188                    assert!(
189                        risc_arg.is_real(),
190                        "cot function is only supported for real numbers"
191                    );
192                    ComplexExpression {
193                        real: Expression::Div(
194                            Box::new(Expression::Cos(Box::new(risc_arg.real.clone()))),
195                            Box::new(Expression::Sin(Box::new(risc_arg.real))),
196                        ),
197                        imag: Expression::zero(),
198                    }
199                }
200                _ => panic!("Invalid function name: {}", fn_name),
201            },
202            _ => panic!("Unexpected expression during complex conversion"),
203        }
204    }
205
206    pub fn one() -> Self {
207        ComplexExpression {
208            real: Expression::one(),
209            imag: Expression::zero(),
210        }
211    }
212
213    pub fn zero() -> Self {
214        ComplexExpression {
215            real: Expression::zero(),
216            imag: Expression::zero(),
217        }
218    }
219
220    pub fn i() -> Self {
221        ComplexExpression {
222            real: Expression::zero(),
223            imag: Expression::one(),
224        }
225    }
226
227    pub fn pi() -> Self {
228        ComplexExpression {
229            real: Expression::Pi,
230            imag: Expression::zero(),
231        }
232    }
233
234    pub fn is_real(&self) -> bool {
235        self.imag.is_zero()
236    }
237
238    pub fn is_imag(&self) -> bool {
239        self.real.is_zero() && !self.imag.is_zero()
240    }
241
242    pub fn is_cplx(&self) -> bool {
243        !self.real.is_zero() && !self.imag.is_zero()
244    }
245
246    pub fn is_real_fast(&self) -> bool {
247        self.imag.is_zero_fast()
248    }
249
250    pub fn is_imag_fast(&self) -> bool {
251        self.real.is_zero_fast() && !self.imag.is_zero_fast()
252    }
253
254    pub fn is_cplx_fast(&self) -> bool {
255        !self.real.is_zero_fast() && !self.imag.is_zero_fast()
256    }
257
258    pub fn is_e(&self) -> bool {
259        match &self.real {
260            Expression::Variable(var) => var == "e",
261            _ => false,
262        }
263    }
264
265    pub fn is_zero(&self) -> bool {
266        self.real.is_zero() && self.imag.is_zero()
267    }
268
269    pub fn is_one(&self) -> bool {
270        self.real.is_one() && self.imag.is_zero()
271    }
272
273    pub fn is_zero_fast(&self) -> bool {
274        self.real.is_zero_fast() && self.imag.is_zero_fast()
275    }
276
277    pub fn is_one_fast(&self) -> bool {
278        self.real.is_one_fast() && self.imag.is_zero_fast()
279    }
280
281    pub fn eval<C: ComplexScalar>(&self, args: &HashMap<&str, C::R>) -> C {
282        C::new(self.real.eval(args), self.imag.eval(args))
283    }
284
285    pub fn map_var_names(&self, var_map: &HashMap<String, String>) -> Self {
286        ComplexExpression {
287            real: self.real.map_var_names(var_map),
288            imag: self.imag.map_var_names(var_map),
289        }
290    }
291
292    pub fn conjugate(&self) -> Self {
293        ComplexExpression {
294            real: self.real.clone(),
295            imag: Expression::Neg(Box::new(self.imag.clone())),
296        }
297    }
298
299    pub fn conjugate_in_place(&mut self) {
300        self.imag = Expression::Neg(Box::new(self.imag.clone()));
301    }
302
303    pub fn differentiate(&self, wrt: &str) -> Self {
304        ComplexExpression {
305            real: self.real.differentiate(wrt),
306            imag: self.imag.differentiate(wrt),
307        }
308    }
309
310    pub fn simplify(&self) -> Self {
311        ComplexExpression {
312            real: self.real.simplify(),
313            imag: self.imag.simplify(),
314        }
315    }
316
317    pub fn get_ancestors(&self, variable: &str) -> Vec<Expression> {
318        let mut ancestors = self.real.get_ancestors(variable);
319        let im_ancestors = self.imag.get_ancestors(variable);
320        if ancestors.is_empty() {
321            return im_ancestors;
322        }
323        if !im_ancestors.is_empty() {
324            ancestors.retain(|x| im_ancestors.contains(x));
325        }
326        ancestors
327    }
328
329    pub fn substitute<S: AsRef<Expression>, T: AsRef<Expression>>(
330        &self,
331        original: S,
332        substitution: T,
333    ) -> Self {
334        // TODO: Allow ComplexExpression substitution
335        ComplexExpression {
336            real: self
337                .real
338                .substitute(original.as_ref(), substitution.as_ref()),
339            imag: self
340                .imag
341                .substitute(original.as_ref(), substitution.as_ref()),
342        }
343    }
344
345    pub fn rename_variable<S: AsRef<str>, T: AsRef<str>>(&self, original: S, new: T) -> Self {
346        ComplexExpression {
347            real: self.real.rename_variable(original.as_ref(), new.as_ref()),
348            imag: self.imag.rename_variable(original.as_ref(), new.as_ref()),
349        }
350    }
351
352    pub fn get_unique_variables(&self) -> Vec<String> {
353        let mut real = self.real.get_unique_variables();
354        for i in self.imag.get_unique_variables().into_iter() {
355            if !real.contains(&i) {
356                real.push(i)
357            }
358        }
359        real
360    }
361
362    pub fn is_parameterized(&self) -> bool {
363        self.real.is_parameterized() || self.imag.is_parameterized()
364    }
365}
366
367impl std::ops::Mul<ComplexExpression> for ComplexExpression {
368    type Output = Self;
369
370    fn mul(self, rhs: Self) -> Self {
371        &self * &rhs
372    }
373}
374
375impl std::ops::Mul<&ComplexExpression> for ComplexExpression {
376    type Output = ComplexExpression;
377
378    fn mul(self, rhs: &ComplexExpression) -> ComplexExpression {
379        &self * rhs
380    }
381}
382
383impl std::ops::Mul<ComplexExpression> for &ComplexExpression {
384    type Output = ComplexExpression;
385
386    fn mul(self, rhs: ComplexExpression) -> ComplexExpression {
387        self * &rhs
388    }
389}
390
391impl std::ops::Mul<&ComplexExpression> for &ComplexExpression {
392    type Output = ComplexExpression;
393
394    fn mul(self, rhs: &ComplexExpression) -> ComplexExpression {
395        ComplexExpression {
396            real: &self.real * &rhs.real - &self.imag * &rhs.imag,
397            imag: &self.real * &rhs.imag + &self.imag * &rhs.real,
398        }
399    }
400}
401
402impl std::ops::Add<ComplexExpression> for ComplexExpression {
403    type Output = Self;
404
405    fn add(self, rhs: Self) -> Self {
406        &self + &rhs
407    }
408}
409
410impl std::ops::Add<&ComplexExpression> for ComplexExpression {
411    type Output = ComplexExpression;
412
413    fn add(self, rhs: &ComplexExpression) -> ComplexExpression {
414        &self + rhs
415    }
416}
417
418impl std::ops::Add<ComplexExpression> for &ComplexExpression {
419    type Output = ComplexExpression;
420
421    fn add(self, rhs: ComplexExpression) -> ComplexExpression {
422        self + &rhs
423    }
424}
425
426impl std::ops::Add<&ComplexExpression> for &ComplexExpression {
427    type Output = ComplexExpression;
428
429    fn add(self, rhs: &ComplexExpression) -> ComplexExpression {
430        ComplexExpression {
431            real: &self.real + &rhs.real,
432            imag: &self.imag + &rhs.imag,
433        }
434    }
435}
436
437impl std::ops::Sub<ComplexExpression> for ComplexExpression {
438    type Output = Self;
439
440    fn sub(self, rhs: Self) -> Self {
441        &self - &rhs
442    }
443}
444
445impl std::ops::Sub<&ComplexExpression> for ComplexExpression {
446    type Output = ComplexExpression;
447
448    fn sub(self, rhs: &ComplexExpression) -> ComplexExpression {
449        &self - rhs
450    }
451}
452
453impl std::ops::Sub<ComplexExpression> for &ComplexExpression {
454    type Output = ComplexExpression;
455
456    fn sub(self, rhs: ComplexExpression) -> ComplexExpression {
457        self - &rhs
458    }
459}
460
461impl std::ops::Sub<&ComplexExpression> for &ComplexExpression {
462    type Output = ComplexExpression;
463
464    fn sub(self, rhs: &ComplexExpression) -> ComplexExpression {
465        ComplexExpression {
466            real: &self.real - &rhs.real,
467            imag: &self.imag - &rhs.imag,
468        }
469    }
470}
471
472impl std::ops::Neg for ComplexExpression {
473    type Output = Self;
474
475    fn neg(self) -> Self {
476        -&self
477    }
478}
479
480impl std::ops::Neg for &ComplexExpression {
481    type Output = ComplexExpression;
482
483    fn neg(self) -> ComplexExpression {
484        ComplexExpression {
485            real: -&self.real,
486            imag: -&self.imag,
487        }
488    }
489}
490
491impl std::ops::Div<ComplexExpression> for ComplexExpression {
492    type Output = Self;
493
494    fn div(self, rhs: Self) -> Self {
495        &self / &rhs
496    }
497}
498
499impl std::ops::Div<&ComplexExpression> for ComplexExpression {
500    type Output = ComplexExpression;
501
502    fn div(self, rhs: &ComplexExpression) -> ComplexExpression {
503        &self / rhs
504    }
505}
506
507impl std::ops::Div<ComplexExpression> for &ComplexExpression {
508    type Output = ComplexExpression;
509
510    fn div(self, rhs: ComplexExpression) -> ComplexExpression {
511        self / &rhs
512    }
513}
514
515impl std::ops::Div<&ComplexExpression> for &ComplexExpression {
516    type Output = ComplexExpression;
517
518    fn div(self, rhs: &ComplexExpression) -> ComplexExpression {
519        let dem = &rhs.real * &rhs.real + &rhs.imag * &rhs.imag;
520        ComplexExpression {
521            real: (&self.real * &rhs.real + &self.imag * &rhs.imag) / &dem,
522            imag: (&self.imag * &rhs.real - &self.real * &rhs.imag) / &dem,
523        }
524    }
525}
526
527impl std::ops::AddAssign for ComplexExpression {
528    fn add_assign(&mut self, rhs: Self) {
529        *self = &*self + &rhs;
530    }
531}
532
533impl std::ops::AddAssign<&ComplexExpression> for ComplexExpression {
534    fn add_assign(&mut self, rhs: &Self) {
535        *self = &*self + rhs;
536    }
537}
538
539impl std::ops::SubAssign for ComplexExpression {
540    fn sub_assign(&mut self, rhs: Self) {
541        *self = &*self - &rhs;
542    }
543}
544
545impl std::ops::SubAssign<&ComplexExpression> for ComplexExpression {
546    fn sub_assign(&mut self, rhs: &Self) {
547        *self = &*self - rhs;
548    }
549}
550
551impl std::ops::MulAssign for ComplexExpression {
552    fn mul_assign(&mut self, rhs: Self) {
553        *self = &*self * &rhs;
554    }
555}
556
557impl std::ops::MulAssign<&ComplexExpression> for ComplexExpression {
558    fn mul_assign(&mut self, rhs: &Self) {
559        *self = &*self * rhs;
560    }
561}
562
563impl std::ops::DivAssign for ComplexExpression {
564    fn div_assign(&mut self, rhs: Self) {
565        *self = &*self / &rhs;
566    }
567}
568
569impl std::ops::DivAssign<&ComplexExpression> for ComplexExpression {
570    fn div_assign(&mut self, rhs: &Self) {
571        *self = &*self / rhs;
572    }
573}
574
575impl std::fmt::Debug for ComplexExpression {
576    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
577        f.debug_struct("ComplexExpression")
578            .field("real", &self.real)
579            .field("imag", &self.imag)
580            .finish()
581    }
582}
583
584impl<C: ComplexScalar> From<C> for ComplexExpression {
585    fn from(value: C) -> Self {
586        ComplexExpression {
587            real: value.re().into(),
588            imag: value.im().into(),
589        }
590    }
591}