qudit_core/quantum/
unitary.rs

1//! Implements a struct for unitary matrices and associated methods for the Openqudit library.
2
3use std::fmt::Debug;
4use std::fmt::Formatter;
5use std::ops::Deref;
6use std::ops::Index;
7use std::ops::Mul;
8use std::ops::Sub;
9
10use coe::{coerce_static, is_same};
11use faer::mat::AsMatRef;
12use faer::unzip;
13use faer::zip;
14use num_traits::Float;
15
16use crate::ComplexScalar;
17use crate::QuditPermutation;
18use crate::QuditSystem;
19use crate::Radices;
20use crate::RealScalar;
21use crate::bitwidth::BitWidthConvertible;
22use crate::c32;
23use crate::c64;
24use faer::Mat;
25use faer::MatMut;
26use faer::MatRef;
27
/// A unitary matrix over a qudit system.
///
/// This is a thin wrapper around a matrix that ensures it is unitary.
/// The unitarity check is performed by [UnitaryMatrix::new];
/// [UnitaryMatrix::new_unchecked] stores the matrix without verifying it.
#[derive(Clone)]
pub struct UnitaryMatrix<C: ComplexScalar> {
    /// The radices of the qudit system this matrix is associated with.
    radices: Radices,
    /// The wrapped matrix.
    matrix: Mat<C>,
}
36
37impl<C: ComplexScalar> UnitaryMatrix<C> {
38    /// Create a new unitary matrix.
39    ///
40    /// # Arguments
41    ///
42    /// * `radices` - The radices of the qudit system.
43    ///
44    /// * `matrix` - The matrix to wrap.
45    ///
46    /// # Panics
47    ///
48    /// Panics if the matrix is not unitary.
49    ///
50    /// # Example
51    ///
52    /// ```
53    /// use faer::Mat;
54    /// use qudit_core::UnitaryMatrix;
55    /// use qudit_core::c64;
56    /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::new([2, 2], Mat::<c64>::identity(4, 4));
57    /// ```
58    ///
59    /// # See Also
60    ///
61    /// * [UnitaryMatrix::is_unitary] - Check if a matrix is a unitary.
62    /// * [UnitaryMatrix::new_unchecked] - Create a unitary without checking unitary conditions.
63    /// * [UnitaryMatrix::identity] - Create a unitary identity matrix.
64    /// * [UnitaryMatrix::random] - Create a random unitary matrix.
65    #[inline(always)]
66    #[track_caller]
67    pub fn new<T: Into<Radices>>(radices: T, matrix: Mat<C>) -> Self {
68        assert!(Self::is_unitary(&matrix));
69        Self {
70            matrix,
71            radices: radices.into(),
72        }
73    }
74
75    /// Create a new unitary matrix without checking if it is unitary.
76    ///
77    /// # Arguments
78    ///
79    /// * `radices` - The radices of the qudit system.
80    ///
81    /// * `matrix` - The matrix to wrap.
82    ///
83    /// # Safety
84    ///
85    /// The caller must ensure that the provided matrix is unitary.
86    ///
87    /// # Example
88    ///
89    /// ```
90    /// use faer::Mat;
91    /// use qudit_core::UnitaryMatrix;
92    /// use qudit_core::c64;
93    /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::new_unchecked([2, 2], Mat::identity(4, 4));
94    /// ```
95    ///
96    /// # See Also
97    ///
98    /// * [UnitaryMatrix::new] - Create a unitary matrix.
99    /// * [UnitaryMatrix::identity] - Create a unitary identity matrix.
100    /// * [UnitaryMatrix::random] - Create a random unitary matrix.
101    /// * [UnitaryMatrix::is_unitary] - Check if a matrix is a unitary.
102    #[inline(always)]
103    #[track_caller]
104    pub fn new_unchecked<T: Into<Radices>>(radices: T, matrix: Mat<C>) -> Self {
105        Self {
106            matrix,
107            radices: radices.into(),
108        }
109    }
110
111    /// Create a new identity unitary matrix for a given qudit system.
112    ///
113    /// # Arguments
114    ///
115    /// * `radices` - The radices of the qudit system.
116    ///
117    /// # Returns
118    ///
119    /// A new unitary matrix that is the identity.
120    ///
121    /// # Example
122    ///
123    /// ```
124    /// use faer::Mat;
125    /// use qudit_core::UnitaryMatrix;
126    /// use qudit_core::c64;
127    /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::identity([2, 2]);
128    /// assert_eq!(unitary, UnitaryMatrix::new([2, 2], Mat::identity(4, 4)));
129    /// ```
130    ///
131    /// # See Also
132    ///
133    /// * [UnitaryMatrix::new] - Create a unitary matrix.
134    /// * [UnitaryMatrix::random] - Create a random unitary matrix.
135    pub fn identity<T: Into<Radices>>(radices: T) -> Self {
136        let radices = radices.into();
137        let dim = radices.dimension();
138        Self::new(radices, Mat::identity(dim, dim))
139    }
140
    /// Generate a random Unitary from the Haar distribution.
    ///
    /// Reference:
    /// - <https://arxiv.org/pdf/math-ph/0609050v2.pdf>
    ///
    /// # Arguments
    ///
    /// * `radices` - The radices of the qudit system.
    ///
    /// # Returns
    ///
    /// A new unitary matrix that is random.
    ///
    /// # Panics
    ///
    /// The result is passed through [UnitaryMatrix::new], which panics if the
    /// generated matrix does not pass [UnitaryMatrix::is_unitary].
    ///
    /// # Example
    ///
    /// ```
    /// use qudit_core::c64;
    /// use qudit_core::UnitaryMatrix;
    /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::random([2, 2]);
    /// assert!(UnitaryMatrix::is_unitary(&unitary));
    /// ```
    ///
    /// # See Also
    ///
    /// * [UnitaryMatrix::new] - Create a unitary matrix.
    /// * [UnitaryMatrix::identity] - Create a unitary identity matrix.
    pub fn random<T: Into<Radices>>(radices: T) -> Self {
        let radices = radices.into();
        let n = radices.dimension();
        // Fill an n-by-n matrix with independently drawn random entries, each
        // divided by sqrt(2).
        // NOTE(review): presumably `standard_random` draws from a standard
        // complex normal distribution — confirm against `ComplexScalar`.
        let standard: Mat<C> =
            Mat::from_fn(n, n, |_, _| C::standard_random() / C::from_real(2.0).sqrt());
        // QR-decompose the random matrix; Q is the starting point for the result.
        let qr = standard.qr();
        let r = qr.R();
        let mut q = qr.compute_Q();
        // Multiply column j of Q by the j-th diagonal entry of R divided by
        // its absolute value.
        for j in 0..n {
            let r = r[(j, j)];
            // A zero diagonal entry has no phase to extract; leave the column
            // unchanged by multiplying with one.
            let r = if r == C::zero() {
                C::one()
            } else {
                r / r.abs()
            };

            zip!(q.as_mut().col_mut(j)).for_each(|unzip!(q)| {
                *q *= r;
            });
        }
        UnitaryMatrix::new(radices, q)
    }
189
190    /// Check if a matrix is unitary.
191    ///
192    /// A matrix is unitary if it satisfies the following condition:
193    /// ```math
194    /// U U^\dagger = U^\dagger U = I
195    /// ```
196    ///
197    /// Where `U` is the matrix, `U^\dagger` is the dagger (conjugate-transpose)
198    /// of `U`, and `I` is the identity matrix of the same size.
199    ///
200    /// # Arguments
201    ///
202    /// * `mat` - The matrix to check.
203    ///
204    /// # Returns
205    ///
206    /// `true` if the matrix is unitary, `false` otherwise.
207    ///
208    /// # Example
209    ///
210    /// ```
211    /// use qudit_core::c64;
212    /// use faer::mat;
213    /// use faer::Mat;
214    /// use qudit_core::UnitaryMatrix;
215    /// let mat: Mat<c64> = Mat::identity(2, 2);
216    /// assert!(UnitaryMatrix::is_unitary(&mat));
217    ///
218    /// let mat = Mat::from_fn(2, 2, |_, _| c64::new(1.0, 1.0));
219    /// assert!(!UnitaryMatrix::is_unitary(&mat));
220    /// ```
221    ///
222    /// # Notes
223    ///
224    /// The function checks the l2 norm or frobenius norm of the difference
225    /// between the product of the matrix and its adjoint and the identity.
226    /// Due to floating point errors, the norm is checked against a threshold
227    /// defined by the `THRESHOLD` constant in the `ComplexScalar` trait.
228    ///
229    /// # See Also
230    ///
231    /// * [ComplexScalar] - The floating point number type used for the matrix.
232    /// * [RealScalar::is_close] - The threshold used to check if a matrix is unitary.
233    pub fn is_unitary(mat: impl AsMatRef<T = C, Rows = usize, Cols = usize>) -> bool {
234        let mat_ref = mat.as_mat_ref();
235
236        if mat_ref.nrows() != mat_ref.ncols() {
237            return false;
238        }
239
240        let id: Mat<C> = Mat::identity(mat_ref.nrows(), mat_ref.ncols());
241        let product = mat_ref * mat_ref.adjoint().to_owned();
242        let error = product - id;
243        C::R::is_close(error.norm_l2(), 0.0)
244    }
245
246    /// Global-phase-agnostic, psuedo-metric over the space of unitaries.
247    ///
248    /// This is based on the hilbert-schmidt inner product. It is defined as:
249    ///
250    /// ```math
251    /// \sqrt{1 - \big(\frac{|\text{tr}(A B^\dagger)|}{\text{dim}(A)}\big)^2}
252    /// ```
253    ///
254    /// Where `A` and `B` are the unitaries, `B^\dagger` is the conjugate transpose
255    /// of `B`, `|\text{tr}(A B^\dagger)|` is the absolute value of the trace of the
256    /// product of `A` and `B^\dagger`, and `dim(A)` is the dimension of `A`.
257    ///
258    /// # Arguments
259    ///
260    /// * `x` - The other unitary matrix.
261    ///
262    /// # Returns
263    ///
264    /// The distance between the two unitaries.
265    ///
266    /// # Panics
267    ///
268    /// Panics if the two unitaries have different dimensions.
269    ///
270    /// # Example
271    ///
272    /// ```
273    /// use qudit_core::c64;
274    /// use faer::mat;
275    /// use faer::Mat;
276    /// use qudit_core::UnitaryMatrix;
277    /// use qudit_core::ComplexScalar;
278    /// let u1: UnitaryMatrix<c64> = UnitaryMatrix::identity([2, 2]);
279    /// let u2 = UnitaryMatrix::identity([2, 2]);
280    /// let u3 = UnitaryMatrix::random([2, 2]);
281    /// assert_eq!(u1.get_distance_from(&u2), 0.0);
282    /// assert!(u1.get_distance_from(&u3) > 0.0);
283    /// ```
284    ///
285    /// # See Also
286    ///
287    /// * [RealScalar::is_close] - The threshold used to check if a matrix is unitary.
288    pub fn get_distance_from(&self, x: impl AsMatRef<T = C, Rows = usize, Cols = usize>) -> C::R {
289        let mat_ref = x.as_mat_ref();
290
291        if mat_ref.nrows() != self.nrows() || mat_ref.ncols() != self.ncols() {
292            panic!("Unitary and matrix must have same shape.");
293        }
294
295        let mut acc = C::ZERO;
296        zip!(self.matrix.as_ref(), mat_ref).for_each(|unzip!(a, b)| {
297            acc += *a * b.conj();
298        });
299        let num = acc.abs();
300        let dem = C::R::from64(self.dimension() as f64);
301        if num > dem {
302            // This shouldn't happen but can due to floating point errors.
303            // If it does, we correct it to zero.
304            C::R::from64(0.0)
305        } else {
306            (C::R::from64(1.0) - (num / dem).powi(2i32)).sqrt()
307        }
308    }
309
310    /// Permute the unitary matrix according to a qudit system permutation.
311    ///
312    /// # Arguments
313    ///
314    /// * `perm` - The permutation to apply.
315    ///
316    /// # Returns
317    ///
318    /// A newly allocated unitary matrix that is the result of applying the
319    /// permutation to the original unitary matrix.
320    ///
321    /// # Panics
322    ///
323    /// Panics if there is a radix mismatch between the unitary matrix and
324    /// the permutation.
325    ///
326    /// # Example
327    ///
328    /// ```
329    /// use qudit_core::c64;
330    /// use faer::mat;
331    /// use faer::Mat;
332    /// use qudit_core::UnitaryMatrix;
333    /// use qudit_core::QuditPermutation;
334    /// use num_traits::{One, Zero};
335    /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::identity([2, 2]);
336    /// let perm = QuditPermutation::new([2, 2], &vec![1, 0]);
337    /// let permuted = unitary.permute(&perm);
338    /// let mat = mat![
339    ///     [c64::one(), c64::zero(), c64::zero(), c64::zero()],
340    ///     [c64::zero(), c64::zero(), c64::one(), c64::zero()],
341    ///     [c64::zero(), c64::one(), c64::zero(), c64::zero()],
342    ///     [c64::zero(), c64::zero(), c64::zero(), c64::one()],
343    /// ];
344    /// assert_eq!(permuted, UnitaryMatrix::new([2, 2], Mat::identity(4, 4)));
345    /// ```
346    pub fn permute(&self, perm: &QuditPermutation) -> UnitaryMatrix<C> {
347        assert_eq!(perm.radices(), self.radices());
348        UnitaryMatrix::new(perm.permuted_radices(), perm.apply(&self.matrix))
349    }
350
351    /// Conjugate the unitary matrix.
352    ///
353    /// # Returns
354    ///
355    /// A newly allocated unitary matrix that is the conjugate of the original
356    /// unitary matrix.
357    ///
358    /// # Example
359    /// ```
360    /// use qudit_core::c64;
361    /// use faer::mat;
362    /// use faer::Mat;
363    /// use qudit_core::UnitaryMatrix;
364    /// use num_traits::Zero;
365    /// let y_mat = mat![
366    ///    [c64::zero(), c64::new(0.0, -1.0)],
367    ///    [c64::new(0.0, 1.0), c64::zero()],
368    /// ];
369    /// let unitary = UnitaryMatrix::new([2], y_mat);
370    /// let conjugate = unitary.conjugate();
371    /// let y_mat_conjugate = mat![
372    ///     [c64::zero(), c64::new(0.0, 1.0)],
373    ///     [c64::new(0.0, -1.0), c64::zero()],
374    /// ];
375    /// assert_eq!(conjugate, UnitaryMatrix::new([2], y_mat_conjugate));
376    /// ```
377    pub fn conjugate(&self) -> UnitaryMatrix<C> {
378        Self::new(self.radices.clone(), self.matrix.conjugate().to_owned())
379    }
380
381    /// Transpose the unitary matrix.
382    ///
383    /// # Returns
384    ///
385    /// A newly allocated unitary matrix that is the transpose of the original
386    /// unitary matrix.
387    ///
388    /// # Example
389    ///
390    /// ```
391    /// use qudit_core::c64;
392    /// use faer::mat;
393    /// use faer::Mat;
394    /// use qudit_core::UnitaryMatrix;
395    /// use num_traits::Zero;
396    /// let y_mat = mat![
397    ///   [c64::zero(), c64::new(0.0, -1.0)],
398    ///   [c64::new(0.0, 1.0), c64::zero()],
399    /// ];
400    /// let unitary = UnitaryMatrix::new([2], y_mat);
401    /// let transpose = unitary.transpose();
402    /// let y_mat_transpose = mat![
403    ///     [c64::zero(), c64::new(0.0, 1.0)],
404    ///     [c64::new(0.0, -1.0), c64::zero()],
405    /// ];
406    /// assert_eq!(transpose, UnitaryMatrix::new([2], y_mat_transpose));
407    /// ```
408    pub fn transpose(&self) -> UnitaryMatrix<C> {
409        Self::new(self.radices.clone(), self.matrix.transpose().to_owned())
410    }
411
412    /// Adjoint or dagger the unitary matrix.
413    ///
414    /// # Returns
415    ///
416    /// A newly allocated unitary matrix that is the adjoint of the original
417    /// unitary matrix.
418    ///
419    /// # Example
420    ///
421    /// ```
422    /// use qudit_core::c64;
423    /// use faer::mat;
424    /// use faer::Mat;
425    /// use qudit_core::UnitaryMatrix;
426    /// use num_traits::Zero;
427    /// let y_mat: Mat<c64> = mat![
428    ///     [c64::zero(), c64::new(0.0, -1.0)],
429    ///     [c64::new(0.0, 1.0), c64::zero()],
430    /// ];
431    /// let unitary = UnitaryMatrix::new([2], y_mat);
432    /// let dagger: UnitaryMatrix<c64> = unitary.dagger();
433    /// let y_mat_adjoint = mat![
434    ///     [c64::zero(), c64::new(0.0, -1.0)],
435    ///     [c64::new(0.0, 1.0), c64::zero()],
436    /// ];
437    /// assert_eq!(dagger, UnitaryMatrix::new([2], y_mat_adjoint));
438    /// assert_eq!(dagger.dagger(), unitary);
439    /// assert_eq!(dagger.dot(&unitary), Mat::<c64>::identity(2, 2));
440    /// assert_eq!(unitary.dot(&dagger), Mat::<c64>::identity(2, 2));
441    /// ```
442    pub fn dagger(&self) -> Self {
443        Self::new(self.radices.clone(), self.matrix.adjoint().to_owned())
444    }
445
446    /// Adjoint or dagger the unitary matrix (Alias for [Self::dagger]).
447    pub fn adjoint(&self) -> Self {
448        self.dagger()
449    }
450
451    /// Multiply the unitary matrix by another matrix.
452    ///
453    /// # Arguments
454    ///
455    /// * `rhs` - The matrix to multiply by.
456    ///
457    /// # Returns
458    ///
459    /// A newly allocated unitary matrix that is the result of multiplying the
460    /// original unitary matrix by the other matrix.
461    ///
462    /// # Panics
463    ///
464    /// Panics if the two matrices have different dimensions.
465    ///
466    /// # Example
467    ///
468    /// ```
469    /// use qudit_core::c64;
470    /// use faer::mat;
471    /// use faer::Mat;
472    /// use qudit_core::UnitaryMatrix;
473    /// use num_traits::Zero;
474    /// let y_mat = mat![
475    ///    [c64::zero(), c64::new(0.0, -1.0)],
476    ///    [c64::new(0.0, 1.0), c64::zero()],
477    /// ];
478    /// let unitary = UnitaryMatrix::new([2], y_mat.clone());
479    /// let result = unitary.dot(&unitary);
480    /// assert_eq!(result, UnitaryMatrix::new([2], y_mat.clone() * y_mat));
481    /// ```
482    ///
483    /// # See Also
484    ///
485    /// * [crate::accel::matmul_unchecked] - The accelerated version of the matrix multiplication.
486    pub fn dot(&self, rhs: impl AsMatRef<T = C, Rows = usize, Cols = usize>) -> Self {
487        Self::new(
488            self.radices.clone(),
489            self.matrix.as_ref() * rhs.as_mat_ref(),
490        )
491    }
492
493    /// Kronecker product the unitary matrix with another matrix.
494    ///
495    /// # Arguments
496    ///
497    /// * `rhs` - The matrix to kronecker product with.
498    ///
499    /// # Returns
500    ///
501    /// A newly allocated unitary matrix that is the result of kronecker producting
502    /// the original unitary matrix with the other matrix.
503    ///
504    /// # Example
505    ///
506    /// ```
507    /// use qudit_core::c64;
508    /// use faer::mat;
509    /// use faer::Mat;
510    /// use qudit_core::UnitaryMatrix;
511    /// use num_traits::Zero;
512    /// let y_mat = mat![
513    ///   [c64::zero(), c64::new(0.0, -1.0)],
514    ///   [c64::new(0.0, 1.0), c64::zero()],
515    /// ];
516    /// let unitary = UnitaryMatrix::new([2], y_mat.clone());
517    /// let result = unitary.kron(&unitary);
518    /// let y_mat_kron = mat![
519    ///    [c64::zero(), c64::zero(), c64::zero(), c64::new(-1.0, 0.0)],
520    ///    [c64::zero(), c64::zero(), c64::new(1.0, 0.0), c64::zero()],
521    ///    [c64::zero(), c64::new(1.0, 0.0), c64::zero(), c64::zero()],
522    ///    [c64::new(-1.0, 0.0), c64::zero(), c64::zero(), c64::zero()],
523    /// ];
524    /// assert_eq!(result, UnitaryMatrix::new([2, 2], y_mat_kron));
525    /// ```
526    ///
527    /// # See Also
528    ///
529    /// * [Mat::kron] - The method used to perform the kronecker product.
530    /// * [crate::accel::kron] - The accelerated version of the kronecker product.
531    pub fn kron(&self, rhs: &UnitaryMatrix<C>) -> Self {
532        let mut dst = Mat::zeros(self.nrows() * rhs.nrows(), self.ncols() * rhs.ncols());
533        faer::linalg::kron::kron(dst.as_mut(), self.matrix.as_ref(), rhs.as_mat_ref());
534        Self::new(self.radices.concat(&rhs.radices()), dst)
535    }
536}
537
538impl<C: ComplexScalar> QuditSystem for UnitaryMatrix<C> {
539    #[inline(always)]
540    fn radices(&self) -> Radices {
541        self.radices.clone()
542    }
543
544    #[inline(always)]
545    fn num_qudits(&self) -> usize {
546        self.radices.len()
547    }
548
549    #[inline(always)]
550    fn dimension(&self) -> usize {
551        self.matrix.nrows()
552    }
553}
554
555impl<C: ComplexScalar> Deref for UnitaryMatrix<C> {
556    type Target = Mat<C>;
557
558    #[inline(always)]
559    fn deref(&self) -> &Self::Target {
560        &self.matrix
561    }
562}
563
564impl<C: ComplexScalar> AsMatRef for UnitaryMatrix<C> {
565    type T = C;
566    type Rows = usize;
567    type Cols = usize;
568    type Owned = Mat<C>;
569
570    #[inline(always)]
571    fn as_mat_ref(&self) -> MatRef<'_, C> {
572        self.matrix.as_ref()
573    }
574}
575
576impl<C: ComplexScalar> Index<(usize, usize)> for UnitaryMatrix<C> {
577    type Output = C;
578
579    #[inline(always)]
580    fn index(&self, index: (usize, usize)) -> &Self::Output {
581        &self.matrix[index]
582    }
583}
584
585impl<C: ComplexScalar> Sub<UnitaryMatrix<C>> for UnitaryMatrix<C> {
586    type Output = Mat<C>;
587
588    fn sub(self, rhs: Self) -> Self::Output {
589        self.matrix - rhs.matrix
590    }
591}
592
593impl<C: ComplexScalar> Mul<UnitaryMatrix<C>> for UnitaryMatrix<C> {
594    type Output = UnitaryMatrix<C>;
595
596    fn mul(self, rhs: Self) -> Self::Output {
597        let output = Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
598            self[(i, j)] * rhs[(i, j)]
599        });
600        UnitaryMatrix::new(self.radices, output)
601    }
602}
603
604impl<C: ComplexScalar> Mul<&UnitaryMatrix<C>> for Mat<C> {
605    type Output = Mat<C>;
606
607    fn mul(self, rhs: &UnitaryMatrix<C>) -> Self::Output {
608        Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
609            self[(i, j)] * rhs[(i, j)]
610        })
611    }
612}
613
614impl<C: ComplexScalar> Mul<UnitaryMatrix<C>> for Mat<C> {
615    type Output = Mat<C>;
616
617    fn mul(self, rhs: UnitaryMatrix<C>) -> Self::Output {
618        Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
619            self[(i, j)] * rhs[(i, j)]
620        })
621    }
622}
623
624impl<C: ComplexScalar> Mul<&UnitaryMatrix<C>> for &Mat<C> {
625    type Output = Mat<C>;
626
627    fn mul(self, rhs: &UnitaryMatrix<C>) -> Self::Output {
628        Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
629            self[(i, j)] * rhs[(i, j)]
630        })
631    }
632}
633
634impl<C: ComplexScalar> Mul<UnitaryMatrix<C>> for &Mat<C> {
635    type Output = Mat<C>;
636
637    fn mul(self, rhs: UnitaryMatrix<C>) -> Self::Output {
638        Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
639            self[(i, j)] * rhs[(i, j)]
640        })
641    }
642}
643
644impl<C: ComplexScalar> Debug for UnitaryMatrix<C> {
645    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
646        // TODO: print radices and unitary's complex numbers more cleanly
647        write!(f, "Unitary({:?})", self.matrix)
648    }
649}
650
651impl<C: ComplexScalar> PartialEq<UnitaryMatrix<C>> for UnitaryMatrix<C> {
652    fn eq(&self, other: &UnitaryMatrix<C>) -> bool {
653        self.matrix == other.matrix
654    }
655}
656
657impl<C: ComplexScalar> PartialEq<Mat<C>> for UnitaryMatrix<C> {
658    fn eq(&self, other: &Mat<C>) -> bool {
659        self.matrix == *other
660    }
661}
662
663impl<C: ComplexScalar> PartialEq<Mat<i32>> for UnitaryMatrix<C> {
664    fn eq(&self, other: &Mat<i32>) -> bool {
665        self.matrix.nrows() == other.nrows()
666            && self.matrix.ncols() == other.ncols()
667            && self.matrix.col_iter().zip(other.col_iter()).all(|(a, b)| {
668                a.iter()
669                    .zip(b.iter())
670                    .all(|(a, b)| *a == C::from(*b).unwrap())
671            })
672    }
673}
674
// Marker impl: declares the `PartialEq<UnitaryMatrix<C>>` comparison above
// (which compares the wrapped matrices) to be a full equivalence relation.
// NOTE(review): the entries look like floating-point complex numbers; confirm
// that reflexivity is acceptable if a matrix can ever contain NaN.
impl<C: ComplexScalar> Eq for UnitaryMatrix<C> {}
676
677impl<C: ComplexScalar> From<UnitaryMatrix<C>> for Mat<C> {
678    fn from(unitary: UnitaryMatrix<C>) -> Self {
679        unitary.matrix
680    }
681}
682
683impl<'a, C: ComplexScalar> From<&'a UnitaryMatrix<C>> for MatRef<'a, C> {
684    fn from(unitary: &'a UnitaryMatrix<C>) -> Self {
685        unitary.matrix.as_ref()
686    }
687}
688
689impl<'a, C: ComplexScalar> From<&'a mut UnitaryMatrix<C>> for MatRef<'a, C> {
690    fn from(unitary: &'a mut UnitaryMatrix<C>) -> Self {
691        unitary.matrix.as_ref()
692    }
693}
694
695impl<'a, C: ComplexScalar> From<&'a mut UnitaryMatrix<C>> for MatMut<'a, C> {
696    fn from(unitary: &'a mut UnitaryMatrix<C>) -> Self {
697        unitary.matrix.as_mut()
698    }
699}
700
701impl<C: ComplexScalar> BitWidthConvertible for UnitaryMatrix<C> {
702    type Width32 = UnitaryMatrix<c32>;
703    type Width64 = UnitaryMatrix<c64>;
704
705    fn to32(self) -> Self::Width32 {
706        if is_same::<c32, C>() {
707            coerce_static(self)
708        } else {
709            let matrix = Mat::from_fn(self.matrix.nrows(), self.matrix.ncols(), |i, j| {
710                self.matrix[(i, j)].to32()
711            });
712            UnitaryMatrix::new(self.radices, matrix)
713        }
714    }
715
716    fn to64(self) -> Self::Width64 {
717        if is_same::<c64, C>() {
718            coerce_static(self)
719        } else {
720            let matrix = Mat::from_fn(self.matrix.nrows(), self.matrix.ncols(), |i, j| {
721                self.matrix[(i, j)].to64()
722            });
723            UnitaryMatrix::new(self.radices, matrix)
724        }
725    }
726
727    fn from32(unitary: Self::Width32) -> Self {
728        if is_same::<c32, C>() {
729            coerce_static(unitary)
730        } else {
731            let matrix = Mat::from_fn(unitary.matrix.nrows(), unitary.matrix.ncols(), |i, j| {
732                C::from32(unitary.matrix[(i, j)])
733            });
734            UnitaryMatrix::new(unitary.radices, matrix)
735        }
736    }
737
738    fn from64(unitary: Self::Width64) -> Self {
739        if is_same::<c64, C>() {
740            coerce_static(unitary)
741        } else {
742            let matrix = Mat::from_fn(unitary.matrix.nrows(), unitary.matrix.ncols(), |i, j| {
743                C::from64(unitary.matrix[(i, j)])
744            });
745            UnitaryMatrix::new(unitary.radices, matrix)
746        }
747    }
748}
749
#[cfg(test)]
mod test {
    use super::*;

    impl<C: ComplexScalar> UnitaryMatrix<C> {
        /// Asserts that this unitary is close to another matrix.
        ///
        /// # Arguments
        ///
        /// * `x` - The other matrix to compare to.
        ///
        /// # Panics
        ///
        /// Panics if the distance between the two is not close to zero, or if
        /// `get_distance_from` panics.
        ///
        /// # Notes
        ///
        /// This is a global-phase-agnostic check. It uses the
        /// `get_distance_from` method to calculate the distance
        /// between the two matrices.
        ///
        /// # See Also
        ///
        /// * [UnitaryMatrix::get_distance_from] - The method used to calculate the distance.
        pub fn assert_close_to(&self, x: impl AsMatRef<T = C, Rows = usize, Cols = usize>) {
            let distance = self.get_distance_from(x);
            let is_close = C::R::is_close(distance, 0.0);
            assert!(is_close, "Distance between unitaries is {:?}", distance)
        }
    }
}
780
#[cfg(feature = "python")]
mod python {
    use super::*;
    use crate::PyRegistrar;
    use pyo3::{exceptions::PyTypeError, prelude::*};

    use numpy::{PyArray2, PyReadonlyArray2, PyUntypedArrayMethods};

    /// Python wrapper for UnitaryMatrix
    #[pyclass(name = "UnitaryMatrix")]
    #[derive(Clone)]
    pub struct PyUnitaryMatrix {
        inner: UnitaryMatrix<c64>,
    }

    #[pymethods]
    impl PyUnitaryMatrix {
        /// Create a new unitary matrix from a numpy array
        #[new]
        #[pyo3(signature = (radices, matrix))]
        fn new(radices: Radices, matrix: PyReadonlyArray2<c64>) -> PyResult<Self> {
            let array = matrix.as_array();
            let (nrows, ncols) = array.dim();

            // Copy the numpy data entry by entry into a faer matrix.
            let mat = Mat::from_fn(nrows, ncols, |i, j| array[[i, j]]);

            // Report a non-unitary input as an error instead of panicking.
            if !UnitaryMatrix::is_unitary(&mat) {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "Matrix is not unitary",
                ));
            }

            // The matrix was validated just above, so the second check that
            // `UnitaryMatrix::new` would perform is skipped.
            Ok(Self {
                inner: UnitaryMatrix::new_unchecked(radices, mat),
            })
        }

        /// Convert to numpy array
        fn to_array<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<c64>> {
            let rows: Vec<Vec<c64>> = (0..self.inner.nrows())
                .map(|i| {
                    (0..self.inner.ncols())
                        .map(|j| self.inner[(i, j)])
                        .collect()
                })
                .collect();
            PyArray2::from_vec2(py, &rows)
                .expect("every row is built with the same number of columns")
        }
    }

    // Conversion traits
    impl From<UnitaryMatrix<c64>> for PyUnitaryMatrix {
        fn from(inner: UnitaryMatrix<c64>) -> Self {
            Self { inner }
        }
    }

    impl From<PyUnitaryMatrix> for UnitaryMatrix<c64> {
        fn from(py_unitary: PyUnitaryMatrix) -> Self {
            py_unitary.inner
        }
    }

    impl<'py> IntoPyObject<'py> for UnitaryMatrix<c64> {
        type Target = PyUnitaryMatrix;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            Bound::new(py, PyUnitaryMatrix::from(self))
        }
    }

    impl<'a, 'py> FromPyObject<'a, 'py> for UnitaryMatrix<c64> {
        type Error = PyErr;

        // Accepts either an existing `UnitaryMatrix` Python object or a 2-D
        // complex numpy array; anything else is a type error.
        fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            if let Ok(py_unitary) = obj.extract::<PyRef<PyUnitaryMatrix>>() {
                Ok(py_unitary.inner.clone())
            } else if let Ok(py_unitary) = obj.extract::<PyReadonlyArray2<c64>>() {
                // Only the matrix is available here, so the radices are
                // guessed from its row count.
                let shape = py_unitary.shape();
                let dimension = shape[0];
                let radices = Radices::guess(dimension);
                // Propagate the constructor's error (e.g. a non-unitary
                // array) to Python rather than panicking on it.
                Ok(PyUnitaryMatrix::new(radices, py_unitary)?.inner)
            } else {
                Err(PyTypeError::new_err(
                    "Invalid type for unitary matrix extraction.",
                ))
            }
        }
    }

    /// Registers the PyUnitaryMatrix with the Python Module.
    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
        parent_module.add_class::<PyUnitaryMatrix>()?;
        Ok(())
    }
    inventory::submit! { PyRegistrar { func: register } }
}