qudit_core/quantum/unitary.rs
1//! Implements a struct for unitary matrices and associated methods for the Openqudit library.
2
3use std::fmt::Debug;
4use std::fmt::Formatter;
5use std::ops::Deref;
6use std::ops::Index;
7use std::ops::Mul;
8use std::ops::Sub;
9
10use coe::{coerce_static, is_same};
11use faer::mat::AsMatRef;
12use faer::unzip;
13use faer::zip;
14use num_traits::Float;
15
16use crate::ComplexScalar;
17use crate::QuditPermutation;
18use crate::QuditSystem;
19use crate::Radices;
20use crate::RealScalar;
21use crate::bitwidth::BitWidthConvertible;
22use crate::c32;
23use crate::c64;
24use faer::Mat;
25use faer::MatMut;
26use faer::MatRef;
27
28/// A unitary matrix over a qudit system.
29///
30/// This is a thin wrapper around a matrix that ensures it is unitary.
31#[derive(Clone)]
32pub struct UnitaryMatrix<C: ComplexScalar> {
33 radices: Radices,
34 matrix: Mat<C>,
35}
36
37impl<C: ComplexScalar> UnitaryMatrix<C> {
38 /// Create a new unitary matrix.
39 ///
40 /// # Arguments
41 ///
42 /// * `radices` - The radices of the qudit system.
43 ///
44 /// * `matrix` - The matrix to wrap.
45 ///
46 /// # Panics
47 ///
48 /// Panics if the matrix is not unitary.
49 ///
50 /// # Example
51 ///
52 /// ```
53 /// use faer::Mat;
54 /// use qudit_core::UnitaryMatrix;
55 /// use qudit_core::c64;
56 /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::new([2, 2], Mat::<c64>::identity(4, 4));
57 /// ```
58 ///
59 /// # See Also
60 ///
61 /// * [UnitaryMatrix::is_unitary] - Check if a matrix is a unitary.
62 /// * [UnitaryMatrix::new_unchecked] - Create a unitary without checking unitary conditions.
63 /// * [UnitaryMatrix::identity] - Create a unitary identity matrix.
64 /// * [UnitaryMatrix::random] - Create a random unitary matrix.
65 #[inline(always)]
66 #[track_caller]
67 pub fn new<T: Into<Radices>>(radices: T, matrix: Mat<C>) -> Self {
68 assert!(Self::is_unitary(&matrix));
69 Self {
70 matrix,
71 radices: radices.into(),
72 }
73 }
74
75 /// Create a new unitary matrix without checking if it is unitary.
76 ///
77 /// # Arguments
78 ///
79 /// * `radices` - The radices of the qudit system.
80 ///
81 /// * `matrix` - The matrix to wrap.
82 ///
83 /// # Safety
84 ///
85 /// The caller must ensure that the provided matrix is unitary.
86 ///
87 /// # Example
88 ///
89 /// ```
90 /// use faer::Mat;
91 /// use qudit_core::UnitaryMatrix;
92 /// use qudit_core::c64;
93 /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::new_unchecked([2, 2], Mat::identity(4, 4));
94 /// ```
95 ///
96 /// # See Also
97 ///
98 /// * [UnitaryMatrix::new] - Create a unitary matrix.
99 /// * [UnitaryMatrix::identity] - Create a unitary identity matrix.
100 /// * [UnitaryMatrix::random] - Create a random unitary matrix.
101 /// * [UnitaryMatrix::is_unitary] - Check if a matrix is a unitary.
102 #[inline(always)]
103 #[track_caller]
104 pub fn new_unchecked<T: Into<Radices>>(radices: T, matrix: Mat<C>) -> Self {
105 Self {
106 matrix,
107 radices: radices.into(),
108 }
109 }
110
111 /// Create a new identity unitary matrix for a given qudit system.
112 ///
113 /// # Arguments
114 ///
115 /// * `radices` - The radices of the qudit system.
116 ///
117 /// # Returns
118 ///
119 /// A new unitary matrix that is the identity.
120 ///
121 /// # Example
122 ///
123 /// ```
124 /// use faer::Mat;
125 /// use qudit_core::UnitaryMatrix;
126 /// use qudit_core::c64;
127 /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::identity([2, 2]);
128 /// assert_eq!(unitary, UnitaryMatrix::new([2, 2], Mat::identity(4, 4)));
129 /// ```
130 ///
131 /// # See Also
132 ///
133 /// * [UnitaryMatrix::new] - Create a unitary matrix.
134 /// * [UnitaryMatrix::random] - Create a random unitary matrix.
135 pub fn identity<T: Into<Radices>>(radices: T) -> Self {
136 let radices = radices.into();
137 let dim = radices.dimension();
138 Self::new(radices, Mat::identity(dim, dim))
139 }
140
141 /// Generate a random Unitary from the haar distribution.
142 ///
143 /// Reference:
144 /// - <https://arxiv.org/pdf/math-ph/0609050v2.pdf>
145 ///
146 /// # Arguments
147 ///
148 /// * `radices` - The radices of the qudit system.
149 ///
150 /// # Returns
151 ///
152 /// A new unitary matrix that is random.
153 ///
154 /// # Example
155 ///
156 /// ```
157 /// use qudit_core::c64;
158 /// use qudit_core::UnitaryMatrix;
159 /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::random([2, 2]);
160 /// assert!(UnitaryMatrix::is_unitary(&unitary));
161 /// ```
162 ///
163 /// # See Also
164 ///
165 /// * [UnitaryMatrix::new] - Create a unitary matrix.
166 /// * [UnitaryMatrix::identity] - Create a unitary identity matrix.
167 pub fn random<T: Into<Radices>>(radices: T) -> Self {
168 let radices = radices.into();
169 let n = radices.dimension();
170 let standard: Mat<C> =
171 Mat::from_fn(n, n, |_, _| C::standard_random() / C::from_real(2.0).sqrt());
172 let qr = standard.qr();
173 let r = qr.R();
174 let mut q = qr.compute_Q();
175 for j in 0..n {
176 let r = r[(j, j)];
177 let r = if r == C::zero() {
178 C::one()
179 } else {
180 r / r.abs()
181 };
182
183 zip!(q.as_mut().col_mut(j)).for_each(|unzip!(q)| {
184 *q *= r;
185 });
186 }
187 UnitaryMatrix::new(radices, q)
188 }
189
190 /// Check if a matrix is unitary.
191 ///
192 /// A matrix is unitary if it satisfies the following condition:
193 /// ```math
194 /// U U^\dagger = U^\dagger U = I
195 /// ```
196 ///
197 /// Where `U` is the matrix, `U^\dagger` is the dagger (conjugate-transpose)
198 /// of `U`, and `I` is the identity matrix of the same size.
199 ///
200 /// # Arguments
201 ///
202 /// * `mat` - The matrix to check.
203 ///
204 /// # Returns
205 ///
206 /// `true` if the matrix is unitary, `false` otherwise.
207 ///
208 /// # Example
209 ///
210 /// ```
211 /// use qudit_core::c64;
212 /// use faer::mat;
213 /// use faer::Mat;
214 /// use qudit_core::UnitaryMatrix;
215 /// let mat: Mat<c64> = Mat::identity(2, 2);
216 /// assert!(UnitaryMatrix::is_unitary(&mat));
217 ///
218 /// let mat = Mat::from_fn(2, 2, |_, _| c64::new(1.0, 1.0));
219 /// assert!(!UnitaryMatrix::is_unitary(&mat));
220 /// ```
221 ///
222 /// # Notes
223 ///
224 /// The function checks the l2 norm or frobenius norm of the difference
225 /// between the product of the matrix and its adjoint and the identity.
226 /// Due to floating point errors, the norm is checked against a threshold
227 /// defined by the `THRESHOLD` constant in the `ComplexScalar` trait.
228 ///
229 /// # See Also
230 ///
231 /// * [ComplexScalar] - The floating point number type used for the matrix.
232 /// * [RealScalar::is_close] - The threshold used to check if a matrix is unitary.
233 pub fn is_unitary(mat: impl AsMatRef<T = C, Rows = usize, Cols = usize>) -> bool {
234 let mat_ref = mat.as_mat_ref();
235
236 if mat_ref.nrows() != mat_ref.ncols() {
237 return false;
238 }
239
240 let id: Mat<C> = Mat::identity(mat_ref.nrows(), mat_ref.ncols());
241 let product = mat_ref * mat_ref.adjoint().to_owned();
242 let error = product - id;
243 C::R::is_close(error.norm_l2(), 0.0)
244 }
245
246 /// Global-phase-agnostic, psuedo-metric over the space of unitaries.
247 ///
248 /// This is based on the hilbert-schmidt inner product. It is defined as:
249 ///
250 /// ```math
251 /// \sqrt{1 - \big(\frac{|\text{tr}(A B^\dagger)|}{\text{dim}(A)}\big)^2}
252 /// ```
253 ///
254 /// Where `A` and `B` are the unitaries, `B^\dagger` is the conjugate transpose
255 /// of `B`, `|\text{tr}(A B^\dagger)|` is the absolute value of the trace of the
256 /// product of `A` and `B^\dagger`, and `dim(A)` is the dimension of `A`.
257 ///
258 /// # Arguments
259 ///
260 /// * `x` - The other unitary matrix.
261 ///
262 /// # Returns
263 ///
264 /// The distance between the two unitaries.
265 ///
266 /// # Panics
267 ///
268 /// Panics if the two unitaries have different dimensions.
269 ///
270 /// # Example
271 ///
272 /// ```
273 /// use qudit_core::c64;
274 /// use faer::mat;
275 /// use faer::Mat;
276 /// use qudit_core::UnitaryMatrix;
277 /// use qudit_core::ComplexScalar;
278 /// let u1: UnitaryMatrix<c64> = UnitaryMatrix::identity([2, 2]);
279 /// let u2 = UnitaryMatrix::identity([2, 2]);
280 /// let u3 = UnitaryMatrix::random([2, 2]);
281 /// assert_eq!(u1.get_distance_from(&u2), 0.0);
282 /// assert!(u1.get_distance_from(&u3) > 0.0);
283 /// ```
284 ///
285 /// # See Also
286 ///
287 /// * [RealScalar::is_close] - The threshold used to check if a matrix is unitary.
288 pub fn get_distance_from(&self, x: impl AsMatRef<T = C, Rows = usize, Cols = usize>) -> C::R {
289 let mat_ref = x.as_mat_ref();
290
291 if mat_ref.nrows() != self.nrows() || mat_ref.ncols() != self.ncols() {
292 panic!("Unitary and matrix must have same shape.");
293 }
294
295 let mut acc = C::ZERO;
296 zip!(self.matrix.as_ref(), mat_ref).for_each(|unzip!(a, b)| {
297 acc += *a * b.conj();
298 });
299 let num = acc.abs();
300 let dem = C::R::from64(self.dimension() as f64);
301 if num > dem {
302 // This shouldn't happen but can due to floating point errors.
303 // If it does, we correct it to zero.
304 C::R::from64(0.0)
305 } else {
306 (C::R::from64(1.0) - (num / dem).powi(2i32)).sqrt()
307 }
308 }
309
310 /// Permute the unitary matrix according to a qudit system permutation.
311 ///
312 /// # Arguments
313 ///
314 /// * `perm` - The permutation to apply.
315 ///
316 /// # Returns
317 ///
318 /// A newly allocated unitary matrix that is the result of applying the
319 /// permutation to the original unitary matrix.
320 ///
321 /// # Panics
322 ///
323 /// Panics if there is a radix mismatch between the unitary matrix and
324 /// the permutation.
325 ///
326 /// # Example
327 ///
328 /// ```
329 /// use qudit_core::c64;
330 /// use faer::mat;
331 /// use faer::Mat;
332 /// use qudit_core::UnitaryMatrix;
333 /// use qudit_core::QuditPermutation;
334 /// use num_traits::{One, Zero};
335 /// let unitary: UnitaryMatrix<c64> = UnitaryMatrix::identity([2, 2]);
336 /// let perm = QuditPermutation::new([2, 2], &vec![1, 0]);
337 /// let permuted = unitary.permute(&perm);
338 /// let mat = mat![
339 /// [c64::one(), c64::zero(), c64::zero(), c64::zero()],
340 /// [c64::zero(), c64::zero(), c64::one(), c64::zero()],
341 /// [c64::zero(), c64::one(), c64::zero(), c64::zero()],
342 /// [c64::zero(), c64::zero(), c64::zero(), c64::one()],
343 /// ];
344 /// assert_eq!(permuted, UnitaryMatrix::new([2, 2], Mat::identity(4, 4)));
345 /// ```
346 pub fn permute(&self, perm: &QuditPermutation) -> UnitaryMatrix<C> {
347 assert_eq!(perm.radices(), self.radices());
348 UnitaryMatrix::new(perm.permuted_radices(), perm.apply(&self.matrix))
349 }
350
351 /// Conjugate the unitary matrix.
352 ///
353 /// # Returns
354 ///
355 /// A newly allocated unitary matrix that is the conjugate of the original
356 /// unitary matrix.
357 ///
358 /// # Example
359 /// ```
360 /// use qudit_core::c64;
361 /// use faer::mat;
362 /// use faer::Mat;
363 /// use qudit_core::UnitaryMatrix;
364 /// use num_traits::Zero;
365 /// let y_mat = mat![
366 /// [c64::zero(), c64::new(0.0, -1.0)],
367 /// [c64::new(0.0, 1.0), c64::zero()],
368 /// ];
369 /// let unitary = UnitaryMatrix::new([2], y_mat);
370 /// let conjugate = unitary.conjugate();
371 /// let y_mat_conjugate = mat![
372 /// [c64::zero(), c64::new(0.0, 1.0)],
373 /// [c64::new(0.0, -1.0), c64::zero()],
374 /// ];
375 /// assert_eq!(conjugate, UnitaryMatrix::new([2], y_mat_conjugate));
376 /// ```
377 pub fn conjugate(&self) -> UnitaryMatrix<C> {
378 Self::new(self.radices.clone(), self.matrix.conjugate().to_owned())
379 }
380
381 /// Transpose the unitary matrix.
382 ///
383 /// # Returns
384 ///
385 /// A newly allocated unitary matrix that is the transpose of the original
386 /// unitary matrix.
387 ///
388 /// # Example
389 ///
390 /// ```
391 /// use qudit_core::c64;
392 /// use faer::mat;
393 /// use faer::Mat;
394 /// use qudit_core::UnitaryMatrix;
395 /// use num_traits::Zero;
396 /// let y_mat = mat![
397 /// [c64::zero(), c64::new(0.0, -1.0)],
398 /// [c64::new(0.0, 1.0), c64::zero()],
399 /// ];
400 /// let unitary = UnitaryMatrix::new([2], y_mat);
401 /// let transpose = unitary.transpose();
402 /// let y_mat_transpose = mat![
403 /// [c64::zero(), c64::new(0.0, 1.0)],
404 /// [c64::new(0.0, -1.0), c64::zero()],
405 /// ];
406 /// assert_eq!(transpose, UnitaryMatrix::new([2], y_mat_transpose));
407 /// ```
408 pub fn transpose(&self) -> UnitaryMatrix<C> {
409 Self::new(self.radices.clone(), self.matrix.transpose().to_owned())
410 }
411
412 /// Adjoint or dagger the unitary matrix.
413 ///
414 /// # Returns
415 ///
416 /// A newly allocated unitary matrix that is the adjoint of the original
417 /// unitary matrix.
418 ///
419 /// # Example
420 ///
421 /// ```
422 /// use qudit_core::c64;
423 /// use faer::mat;
424 /// use faer::Mat;
425 /// use qudit_core::UnitaryMatrix;
426 /// use num_traits::Zero;
427 /// let y_mat: Mat<c64> = mat![
428 /// [c64::zero(), c64::new(0.0, -1.0)],
429 /// [c64::new(0.0, 1.0), c64::zero()],
430 /// ];
431 /// let unitary = UnitaryMatrix::new([2], y_mat);
432 /// let dagger: UnitaryMatrix<c64> = unitary.dagger();
433 /// let y_mat_adjoint = mat![
434 /// [c64::zero(), c64::new(0.0, -1.0)],
435 /// [c64::new(0.0, 1.0), c64::zero()],
436 /// ];
437 /// assert_eq!(dagger, UnitaryMatrix::new([2], y_mat_adjoint));
438 /// assert_eq!(dagger.dagger(), unitary);
439 /// assert_eq!(dagger.dot(&unitary), Mat::<c64>::identity(2, 2));
440 /// assert_eq!(unitary.dot(&dagger), Mat::<c64>::identity(2, 2));
441 /// ```
442 pub fn dagger(&self) -> Self {
443 Self::new(self.radices.clone(), self.matrix.adjoint().to_owned())
444 }
445
446 /// Adjoint or dagger the unitary matrix (Alias for [Self::dagger]).
447 pub fn adjoint(&self) -> Self {
448 self.dagger()
449 }
450
451 /// Multiply the unitary matrix by another matrix.
452 ///
453 /// # Arguments
454 ///
455 /// * `rhs` - The matrix to multiply by.
456 ///
457 /// # Returns
458 ///
459 /// A newly allocated unitary matrix that is the result of multiplying the
460 /// original unitary matrix by the other matrix.
461 ///
462 /// # Panics
463 ///
464 /// Panics if the two matrices have different dimensions.
465 ///
466 /// # Example
467 ///
468 /// ```
469 /// use qudit_core::c64;
470 /// use faer::mat;
471 /// use faer::Mat;
472 /// use qudit_core::UnitaryMatrix;
473 /// use num_traits::Zero;
474 /// let y_mat = mat![
475 /// [c64::zero(), c64::new(0.0, -1.0)],
476 /// [c64::new(0.0, 1.0), c64::zero()],
477 /// ];
478 /// let unitary = UnitaryMatrix::new([2], y_mat.clone());
479 /// let result = unitary.dot(&unitary);
480 /// assert_eq!(result, UnitaryMatrix::new([2], y_mat.clone() * y_mat));
481 /// ```
482 ///
483 /// # See Also
484 ///
485 /// * [crate::accel::matmul_unchecked] - The accelerated version of the matrix multiplication.
486 pub fn dot(&self, rhs: impl AsMatRef<T = C, Rows = usize, Cols = usize>) -> Self {
487 Self::new(
488 self.radices.clone(),
489 self.matrix.as_ref() * rhs.as_mat_ref(),
490 )
491 }
492
493 /// Kronecker product the unitary matrix with another matrix.
494 ///
495 /// # Arguments
496 ///
497 /// * `rhs` - The matrix to kronecker product with.
498 ///
499 /// # Returns
500 ///
501 /// A newly allocated unitary matrix that is the result of kronecker producting
502 /// the original unitary matrix with the other matrix.
503 ///
504 /// # Example
505 ///
506 /// ```
507 /// use qudit_core::c64;
508 /// use faer::mat;
509 /// use faer::Mat;
510 /// use qudit_core::UnitaryMatrix;
511 /// use num_traits::Zero;
512 /// let y_mat = mat![
513 /// [c64::zero(), c64::new(0.0, -1.0)],
514 /// [c64::new(0.0, 1.0), c64::zero()],
515 /// ];
516 /// let unitary = UnitaryMatrix::new([2], y_mat.clone());
517 /// let result = unitary.kron(&unitary);
518 /// let y_mat_kron = mat![
519 /// [c64::zero(), c64::zero(), c64::zero(), c64::new(-1.0, 0.0)],
520 /// [c64::zero(), c64::zero(), c64::new(1.0, 0.0), c64::zero()],
521 /// [c64::zero(), c64::new(1.0, 0.0), c64::zero(), c64::zero()],
522 /// [c64::new(-1.0, 0.0), c64::zero(), c64::zero(), c64::zero()],
523 /// ];
524 /// assert_eq!(result, UnitaryMatrix::new([2, 2], y_mat_kron));
525 /// ```
526 ///
527 /// # See Also
528 ///
529 /// * [Mat::kron] - The method used to perform the kronecker product.
530 /// * [crate::accel::kron] - The accelerated version of the kronecker product.
531 pub fn kron(&self, rhs: &UnitaryMatrix<C>) -> Self {
532 let mut dst = Mat::zeros(self.nrows() * rhs.nrows(), self.ncols() * rhs.ncols());
533 faer::linalg::kron::kron(dst.as_mut(), self.matrix.as_ref(), rhs.as_mat_ref());
534 Self::new(self.radices.concat(&rhs.radices()), dst)
535 }
536}
537
538impl<C: ComplexScalar> QuditSystem for UnitaryMatrix<C> {
539 #[inline(always)]
540 fn radices(&self) -> Radices {
541 self.radices.clone()
542 }
543
544 #[inline(always)]
545 fn num_qudits(&self) -> usize {
546 self.radices.len()
547 }
548
549 #[inline(always)]
550 fn dimension(&self) -> usize {
551 self.matrix.nrows()
552 }
553}
554
555impl<C: ComplexScalar> Deref for UnitaryMatrix<C> {
556 type Target = Mat<C>;
557
558 #[inline(always)]
559 fn deref(&self) -> &Self::Target {
560 &self.matrix
561 }
562}
563
564impl<C: ComplexScalar> AsMatRef for UnitaryMatrix<C> {
565 type T = C;
566 type Rows = usize;
567 type Cols = usize;
568 type Owned = Mat<C>;
569
570 #[inline(always)]
571 fn as_mat_ref(&self) -> MatRef<'_, C> {
572 self.matrix.as_ref()
573 }
574}
575
576impl<C: ComplexScalar> Index<(usize, usize)> for UnitaryMatrix<C> {
577 type Output = C;
578
579 #[inline(always)]
580 fn index(&self, index: (usize, usize)) -> &Self::Output {
581 &self.matrix[index]
582 }
583}
584
585impl<C: ComplexScalar> Sub<UnitaryMatrix<C>> for UnitaryMatrix<C> {
586 type Output = Mat<C>;
587
588 fn sub(self, rhs: Self) -> Self::Output {
589 self.matrix - rhs.matrix
590 }
591}
592
593impl<C: ComplexScalar> Mul<UnitaryMatrix<C>> for UnitaryMatrix<C> {
594 type Output = UnitaryMatrix<C>;
595
596 fn mul(self, rhs: Self) -> Self::Output {
597 let output = Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
598 self[(i, j)] * rhs[(i, j)]
599 });
600 UnitaryMatrix::new(self.radices, output)
601 }
602}
603
604impl<C: ComplexScalar> Mul<&UnitaryMatrix<C>> for Mat<C> {
605 type Output = Mat<C>;
606
607 fn mul(self, rhs: &UnitaryMatrix<C>) -> Self::Output {
608 Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
609 self[(i, j)] * rhs[(i, j)]
610 })
611 }
612}
613
614impl<C: ComplexScalar> Mul<UnitaryMatrix<C>> for Mat<C> {
615 type Output = Mat<C>;
616
617 fn mul(self, rhs: UnitaryMatrix<C>) -> Self::Output {
618 Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
619 self[(i, j)] * rhs[(i, j)]
620 })
621 }
622}
623
624impl<C: ComplexScalar> Mul<&UnitaryMatrix<C>> for &Mat<C> {
625 type Output = Mat<C>;
626
627 fn mul(self, rhs: &UnitaryMatrix<C>) -> Self::Output {
628 Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
629 self[(i, j)] * rhs[(i, j)]
630 })
631 }
632}
633
634impl<C: ComplexScalar> Mul<UnitaryMatrix<C>> for &Mat<C> {
635 type Output = Mat<C>;
636
637 fn mul(self, rhs: UnitaryMatrix<C>) -> Self::Output {
638 Mat::from_fn(self.nrows(), self.ncols(), |i, j| {
639 self[(i, j)] * rhs[(i, j)]
640 })
641 }
642}
643
644impl<C: ComplexScalar> Debug for UnitaryMatrix<C> {
645 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
646 // TODO: print radices and unitary's complex numbers more cleanly
647 write!(f, "Unitary({:?})", self.matrix)
648 }
649}
650
651impl<C: ComplexScalar> PartialEq<UnitaryMatrix<C>> for UnitaryMatrix<C> {
652 fn eq(&self, other: &UnitaryMatrix<C>) -> bool {
653 self.matrix == other.matrix
654 }
655}
656
657impl<C: ComplexScalar> PartialEq<Mat<C>> for UnitaryMatrix<C> {
658 fn eq(&self, other: &Mat<C>) -> bool {
659 self.matrix == *other
660 }
661}
662
663impl<C: ComplexScalar> PartialEq<Mat<i32>> for UnitaryMatrix<C> {
664 fn eq(&self, other: &Mat<i32>) -> bool {
665 self.matrix.nrows() == other.nrows()
666 && self.matrix.ncols() == other.ncols()
667 && self.matrix.col_iter().zip(other.col_iter()).all(|(a, b)| {
668 a.iter()
669 .zip(b.iter())
670 .all(|(a, b)| *a == C::from(*b).unwrap())
671 })
672 }
673}
674
675impl<C: ComplexScalar> Eq for UnitaryMatrix<C> {}
676
677impl<C: ComplexScalar> From<UnitaryMatrix<C>> for Mat<C> {
678 fn from(unitary: UnitaryMatrix<C>) -> Self {
679 unitary.matrix
680 }
681}
682
683impl<'a, C: ComplexScalar> From<&'a UnitaryMatrix<C>> for MatRef<'a, C> {
684 fn from(unitary: &'a UnitaryMatrix<C>) -> Self {
685 unitary.matrix.as_ref()
686 }
687}
688
689impl<'a, C: ComplexScalar> From<&'a mut UnitaryMatrix<C>> for MatRef<'a, C> {
690 fn from(unitary: &'a mut UnitaryMatrix<C>) -> Self {
691 unitary.matrix.as_ref()
692 }
693}
694
695impl<'a, C: ComplexScalar> From<&'a mut UnitaryMatrix<C>> for MatMut<'a, C> {
696 fn from(unitary: &'a mut UnitaryMatrix<C>) -> Self {
697 unitary.matrix.as_mut()
698 }
699}
700
701impl<C: ComplexScalar> BitWidthConvertible for UnitaryMatrix<C> {
702 type Width32 = UnitaryMatrix<c32>;
703 type Width64 = UnitaryMatrix<c64>;
704
705 fn to32(self) -> Self::Width32 {
706 if is_same::<c32, C>() {
707 coerce_static(self)
708 } else {
709 let matrix = Mat::from_fn(self.matrix.nrows(), self.matrix.ncols(), |i, j| {
710 self.matrix[(i, j)].to32()
711 });
712 UnitaryMatrix::new(self.radices, matrix)
713 }
714 }
715
716 fn to64(self) -> Self::Width64 {
717 if is_same::<c64, C>() {
718 coerce_static(self)
719 } else {
720 let matrix = Mat::from_fn(self.matrix.nrows(), self.matrix.ncols(), |i, j| {
721 self.matrix[(i, j)].to64()
722 });
723 UnitaryMatrix::new(self.radices, matrix)
724 }
725 }
726
727 fn from32(unitary: Self::Width32) -> Self {
728 if is_same::<c32, C>() {
729 coerce_static(unitary)
730 } else {
731 let matrix = Mat::from_fn(unitary.matrix.nrows(), unitary.matrix.ncols(), |i, j| {
732 C::from32(unitary.matrix[(i, j)])
733 });
734 UnitaryMatrix::new(unitary.radices, matrix)
735 }
736 }
737
738 fn from64(unitary: Self::Width64) -> Self {
739 if is_same::<c64, C>() {
740 coerce_static(unitary)
741 } else {
742 let matrix = Mat::from_fn(unitary.matrix.nrows(), unitary.matrix.ncols(), |i, j| {
743 C::from64(unitary.matrix[(i, j)])
744 });
745 UnitaryMatrix::new(unitary.radices, matrix)
746 }
747 }
748}
749
#[cfg(test)]
mod test {
    use super::*;

    impl<C: ComplexScalar> UnitaryMatrix<C> {
        /// Asserts that this unitary equals `x` up to a global phase.
        ///
        /// # Arguments
        ///
        /// * `x` - The other matrix to compare to.
        ///
        /// # Panics
        ///
        /// Panics when the distance reported by
        /// [UnitaryMatrix::get_distance_from] is not close to zero according
        /// to `RealScalar::is_close`, or when the two shapes differ.
        ///
        /// # Notes
        ///
        /// `#[track_caller]` makes a failed assertion report the calling
        /// test's location instead of this helper's.
        ///
        /// # See Also
        ///
        /// * [UnitaryMatrix::get_distance_from] - The method used to calculate the distance.
        #[track_caller]
        pub fn assert_close_to(&self, x: impl AsMatRef<T = C, Rows = usize, Cols = usize>) {
            let dist = self.get_distance_from(x);
            assert!(
                C::R::is_close(dist, 0.0),
                "Distance between unitaries is {:?}",
                dist
            )
        }
    }
}
780
#[cfg(feature = "python")]
mod python {
    use super::*;
    use crate::PyRegistrar;
    use pyo3::{exceptions::PyTypeError, prelude::*};

    use numpy::{PyArray2, PyReadonlyArray2, PyUntypedArrayMethods};

    /// Python wrapper for UnitaryMatrix
    #[pyclass(name = "UnitaryMatrix")]
    #[derive(Clone)]
    pub struct PyUnitaryMatrix {
        inner: UnitaryMatrix<c64>,
    }

    #[pymethods]
    impl PyUnitaryMatrix {
        /// Create a new unitary matrix from a numpy array.
        ///
        /// Raises `ValueError` if the matrix is not unitary, or if its
        /// dimension does not match the dimension described by `radices`.
        #[new]
        #[pyo3(signature = (radices, matrix))]
        fn new(radices: Radices, matrix: PyReadonlyArray2<c64>) -> PyResult<Self> {
            let array = matrix.as_array();
            let (nrows, ncols) = array.dim();

            let mat = Mat::from_fn(nrows, ncols, |i, j| array[[i, j]]);

            // Validate here so bad input surfaces as a Python exception
            // instead of a Rust panic.
            if !UnitaryMatrix::is_unitary(&mat) {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "Matrix is not unitary",
                ));
            }

            if radices.dimension() != nrows {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "Matrix dimension does not match radices",
                ));
            }

            // Both invariants were verified above, so skip re-validating.
            Ok(Self {
                inner: UnitaryMatrix::new_unchecked(radices, mat),
            })
        }

        /// Convert to numpy array
        fn to_array<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<c64>> {
            let rows: Vec<Vec<c64>> = (0..self.inner.nrows())
                .map(|i| {
                    (0..self.inner.ncols())
                        .map(|j| self.inner[(i, j)])
                        .collect()
                })
                .collect();
            // Every row has exactly `ncols` entries, so this cannot fail.
            PyArray2::from_vec2(py, &rows).expect("matrix rows have equal length")
        }
    }

    // Conversion traits
    impl From<UnitaryMatrix<c64>> for PyUnitaryMatrix {
        fn from(inner: UnitaryMatrix<c64>) -> Self {
            Self { inner }
        }
    }

    impl From<PyUnitaryMatrix> for UnitaryMatrix<c64> {
        fn from(py_unitary: PyUnitaryMatrix) -> Self {
            py_unitary.inner
        }
    }

    impl<'py> IntoPyObject<'py> for UnitaryMatrix<c64> {
        type Target = PyUnitaryMatrix;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            Bound::new(py, PyUnitaryMatrix::from(self))
        }
    }

    impl<'a, 'py> FromPyObject<'a, 'py> for UnitaryMatrix<c64> {
        type Error = PyErr;

        /// Accepts either a wrapped `UnitaryMatrix` or a 2-D complex numpy
        /// array. For arrays, radices are guessed from the leading dimension,
        /// and validation failures are raised as Python exceptions.
        fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            if let Ok(py_unitary) = obj.extract::<PyRef<PyUnitaryMatrix>>() {
                Ok(py_unitary.inner.clone())
            } else if let Ok(py_unitary) = obj.extract::<PyReadonlyArray2<c64>>() {
                let dimension = py_unitary.shape()[0];
                let radices = Radices::guess(dimension);
                // Propagate a non-unitary input as a `ValueError` instead of
                // panicking on user-supplied data.
                Ok(PyUnitaryMatrix::new(radices, py_unitary)?.inner)
            } else {
                Err(PyTypeError::new_err(
                    "Invalid type for unitary matrix extraction.",
                ))
            }
        }
    }

    /// Registers the PyUnitaryMatrix with the Python Module.
    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
        parent_module.add_class::<PyUnitaryMatrix>()?;
        Ok(())
    }
    inventory::submit! { PyRegistrar { func: register } }
}