qudit_expr/expressions/
unitary.rs

1use std::collections::{BTreeSet, HashMap};
2use std::ops::{Deref, DerefMut};
3
4use faer::Mat;
5use qudit_core::ComplexScalar;
6use qudit_core::QuditSystem;
7use qudit_core::Radices;
8use qudit_core::UnitaryMatrix;
9
10use crate::{ComplexExpression, UnitarySystemExpression};
11use crate::{
12    GenerationShape, TensorExpression,
13    expressions::JittableExpression,
14    index::{IndexDirection, TensorIndex},
15};
16
17use super::NamedExpression;
18
19#[derive(PartialEq, Eq, Debug, Clone)]
20pub struct UnitaryExpression {
21    inner: NamedExpression,
22    radices: Radices,
23}
24
25// pub trait ExpressionContainer {
26//     pub fn elements(&self) -> &[ComplexExpression];
27//     pub fn elements_mut(&mut self) -> &mut [ComplexExpression];
28
29//     pub fn conjugate(&mut self) {
30//         todo!()
31//     }
32// }
33
34// pub trait MatrixExpression: ExpressionContainer {
35//     pub fn nrows(&self) -> usize;
36//     pub fn ncols(&self) -> usize;
37
38//     pub fn dagger(&mut self) {
39//         todo!()
40//     }
41
42//     pub fn transpose(&mut self) {
43//         todo!()
44//     }
45// }
46
47impl UnitaryExpression {
48    pub fn new<T: AsRef<str>>(input: T) -> Self {
49        TensorExpression::new(input).try_into().unwrap()
50    }
51
52    // TODO: Change ToRadices by implementing appropriate Froms
53    pub fn identity<S: Into<String>, T: Into<Radices>>(name: S, radices: T) -> Self {
54        let radices = radices.into();
55        let dim = radices.dimension();
56        let mut body = Vec::with_capacity(dim * dim);
57        for i in 0..dim {
58            for j in 0..dim {
59                if i == j {
60                    body.push(ComplexExpression::one());
61                } else {
62                    body.push(ComplexExpression::zero());
63                }
64            }
65        }
66        let inner = NamedExpression::new(name, vec![], body);
67
68        UnitaryExpression { inner, radices }
69    }
70
71    pub fn set_radices(&mut self, new_radices: Radices) {
72        assert_eq!(self.radices.dimension(), new_radices.dimension());
73        self.radices = new_radices;
74    }
75
76    pub fn transpose(&mut self) {
77        let dim = self.dimension();
78        for i in 1..dim {
79            for j in i..dim {
80                let (head, tail) = self.split_at_mut(j * dim + i);
81                std::mem::swap(&mut head[i * dim + j], &mut tail[0]);
82            }
83        }
84    }
85
86    pub fn dagger(&mut self) {
87        self.conjugate();
88        self.transpose();
89    }
90
91    pub fn classically_multiplex(
92        expressions: &[&UnitaryExpression],
93        new_dim_radices: &[usize],
94    ) -> UnitarySystemExpression {
95        let new_dim = new_dim_radices.iter().product();
96        assert!(
97            expressions.len() == new_dim,
98            "Cannot multiplex a number of expressions not equal to the length of new dimension."
99        );
100        assert!(!expressions.is_empty());
101        for expression in expressions {
102            assert_eq!(
103                expression.radices(),
104                expressions[0].radices(),
105                "All expressions must have equal radices."
106            );
107        }
108
109        // construct larger tensor
110        let mut new_expressions =
111            Vec::with_capacity(expressions[0].dimension() * expressions[0].dimension() * new_dim);
112        let mut var_map_number = 0;
113        let mut new_variables = Vec::new();
114        for i in 0..new_dim {
115            // TODO: double clone
116            let mut renamed_expression = expressions[i].clone();
117            renamed_expression.alpha_rename(Some(var_map_number));
118            new_expressions.extend(renamed_expression.elements().iter().cloned());
119            for _ in 0..expressions[i].variables().len() {
120                new_variables.push(format!("alpha_{}", var_map_number));
121                var_map_number += 1;
122            }
123        }
124
125        let new_indices = new_dim_radices
126            .iter()
127            .map(|r| (IndexDirection::Batch, (*r)))
128            .chain(
129                expressions[0]
130                    .radices()
131                    .iter()
132                    .map(|r| (IndexDirection::Output, usize::from(*r))),
133            )
134            .chain(
135                expressions[0]
136                    .radices()
137                    .iter()
138                    .map(|r| (IndexDirection::Input, usize::from(*r))),
139            )
140            .enumerate()
141            .map(|(id, (dir, size))| TensorIndex::new(dir, id, size))
142            .collect();
143
144        let name = {
145            // if all expressions have same name: multiplex_name
146            // else: multiplex_name_name_...
147            if expressions
148                .iter()
149                .all(|e| e.name() == expressions[0].name())
150            {
151                format!("Multiplexed_{}", expressions[0].name())
152            } else {
153                "Multiplexed_".to_string()
154                    + &expressions
155                        .iter()
156                        .map(|e| e.name())
157                        .collect::<Vec<_>>()
158                        .join("_")
159            }
160        };
161
162        let inner = NamedExpression::new(name, new_variables, new_expressions);
163        TensorExpression::from_raw(new_indices, inner)
164            .try_into()
165            .unwrap()
166    }
167
168    // TODO: better API for user-facing thoughts
169    pub fn classically_control(
170        &self,
171        positions: &[usize],
172        new_dim_radices: &[usize],
173    ) -> UnitarySystemExpression {
174        let new_dim = new_dim_radices.iter().product();
175        assert!(
176            positions.len() <= new_dim,
177            "Cannot place unitary in more locations than length of new dimension."
178        );
179
180        // Ensure positions are unique
181        let mut sorted_positions = positions.to_vec();
182        sorted_positions.sort_unstable();
183        assert!(
184            sorted_positions.iter().collect::<BTreeSet<_>>().len() == sorted_positions.len(),
185            "Positions must be unique"
186        );
187
188        // Construct identity expression
189        let mut identity = Vec::with_capacity(self.dimension() * self.dimension());
190        for i in 0..self.dimension() {
191            for j in 0..self.dimension() {
192                if i == j {
193                    identity.push(ComplexExpression::one());
194                } else {
195                    identity.push(ComplexExpression::zero());
196                }
197            }
198        }
199
200        // construct larger tensor
201        let mut expressions = Vec::with_capacity(self.dimension() * self.dimension() * new_dim);
202        for i in 0..new_dim {
203            if positions.contains(&i) {
204                expressions.extend(self.elements().iter().cloned());
205            } else {
206                expressions.extend(identity.iter().cloned());
207            }
208        }
209
210        let new_indices = new_dim_radices
211            .iter()
212            .map(|r| (IndexDirection::Batch, (*r)))
213            .chain(
214                self.radices
215                    .iter()
216                    .map(|r| (IndexDirection::Output, usize::from(*r))),
217            )
218            .chain(
219                self.radices
220                    .iter()
221                    .map(|r| (IndexDirection::Input, usize::from(*r))),
222            )
223            .enumerate()
224            .map(|(id, (dir, size))| TensorIndex::new(dir, id, size))
225            .collect();
226
227        let inner = NamedExpression::new(
228            format!("Stacked_{}", self.name()),
229            self.variables().to_owned(),
230            expressions,
231        );
232        TensorExpression::from_raw(new_indices, inner)
233            .try_into()
234            .unwrap()
235    }
236
237    pub fn embed(
238        &mut self,
239        sub_matrix: UnitaryExpression,
240        top_left_row_idx: usize,
241        top_left_col_idx: usize,
242    ) {
243        let nrows = self.dimension();
244        let ncols = self.dimension();
245        let sub_nrows = sub_matrix.dimension();
246        let sub_ncols = sub_matrix.dimension();
247
248        if top_left_row_idx + sub_nrows > nrows || top_left_col_idx + sub_ncols > ncols {
249            panic!("Embedding matrix is too large");
250        }
251
252        for i in 0..sub_nrows {
253            for j in 0..sub_ncols {
254                // TODO: remove clone by building into_iter for sub_matrix
255                let sub_expr = sub_matrix[i * sub_ncols + j].clone();
256                self[(top_left_row_idx + i) * ncols + (top_left_col_idx + j)] = sub_expr;
257            }
258        }
259
260        // Update variables: collect all unique variables from self and sub_matrix
261        let mut new_variables: Vec<String> = self.variables().to_vec();
262        for var in sub_matrix.variables().iter() {
263            if !new_variables.contains(var) {
264                new_variables.push(var.clone());
265            }
266        }
267        self.set_variables(new_variables);
268    }
269
270    pub fn otimes<U: AsRef<UnitaryExpression>>(&self, other: U) -> Self {
271        let other = other.as_ref();
272        let mut variables = Vec::new();
273        let mut var_map_self = HashMap::new();
274        let mut i = 0;
275        for var in self.variables().iter() {
276            variables.push(format!("x{}", i));
277            var_map_self.insert(var.clone(), format!("x{}", i));
278            i += 1;
279        }
280        let mut var_map_other = HashMap::new();
281        for var in other.variables().iter() {
282            variables.push(format!("x{}", i));
283            var_map_other.insert(var.clone(), format!("x{}", i));
284            i += 1;
285        }
286
287        let lhs_nrows = self.dimension();
288        let lhs_ncols = self.dimension();
289        let rhs_nrows = other.dimension();
290        let rhs_ncols = other.dimension();
291
292        let mut out_body = Vec::with_capacity(lhs_nrows * lhs_ncols * rhs_ncols * rhs_nrows);
293
294        for i in 0..lhs_nrows {
295            for j in 0..rhs_nrows {
296                for k in 0..lhs_ncols {
297                    for l in 0..rhs_ncols {
298                        out_body.push(
299                            self[i * lhs_ncols + k].map_var_names(&var_map_self)
300                                * other[j * rhs_ncols + l].map_var_names(&var_map_other),
301                        );
302                    }
303                }
304            }
305        }
306
307        let inner = NamedExpression::new(
308            format!("{} ⊗ {}", self.name(), other.name()),
309            variables,
310            out_body,
311        );
312
313        UnitaryExpression {
314            inner,
315            radices: self.radices.concat(&other.radices),
316        }
317    }
318
319    pub fn dot<U: AsRef<UnitaryExpression>>(&self, other: U) -> Self {
320        let other = other.as_ref();
321        let lhs_nrows = self.dimension();
322        let lhs_ncols = self.dimension();
323        let rhs_nrows = other.dimension();
324        let rhs_ncols = other.dimension();
325
326        if lhs_ncols != rhs_nrows {
327            panic!(
328                "Matrix dimensions do not match for dot product: {} != {}",
329                lhs_ncols, rhs_nrows
330            );
331        }
332
333        let mut variables = Vec::new();
334        let mut var_map_self = HashMap::new();
335        let mut i = 0;
336        for var in self.variables().iter() {
337            variables.push(format!("x{}", i));
338            var_map_self.insert(var.clone(), format!("x{}", i));
339            i += 1;
340        }
341        let mut var_map_other = HashMap::new();
342        for var in other.variables().iter() {
343            variables.push(format!("x{}", i));
344            var_map_other.insert(var.clone(), format!("x{}", i));
345            i += 1;
346        }
347
348        let mut out_body = Vec::with_capacity(lhs_nrows * rhs_ncols);
349
350        for i in 0..lhs_nrows {
351            for j in 0..rhs_ncols {
352                let mut sum = self[i * lhs_ncols].map_var_names(&var_map_self)
353                    * other[j].map_var_names(&var_map_other);
354                for k in 1..lhs_ncols {
355                    sum += self[i * lhs_ncols + k].map_var_names(&var_map_self)
356                        * other[k * rhs_ncols + j].map_var_names(&var_map_other);
357                }
358                out_body.push(sum);
359            }
360        }
361
362        let inner = NamedExpression::new(
363            format!("{} ⋅ {}", self.name(), other.name()),
364            variables,
365            out_body,
366        );
367
368        UnitaryExpression {
369            inner,
370            radices: self.radices.clone(),
371        }
372    }
373
374    pub fn get_arg_map<C: ComplexScalar>(&self, args: &[C::R]) -> HashMap<&str, C::R> {
375        self.variables()
376            .iter()
377            .zip(args.iter())
378            .map(|(a, b)| (a.as_str(), *b))
379            .collect()
380    }
381
382    pub fn eval<C: ComplexScalar>(&self, args: &[C::R]) -> UnitaryMatrix<C> {
383        let arg_map = self.get_arg_map::<C>(args);
384        let dim = self.radices.dimension();
385        let mut mat = Mat::zeros(dim, dim);
386        for i in 0..dim {
387            for j in 0..dim {
388                *mat.get_mut(i, j) = self[i * self.dimension() + j].eval(&arg_map);
389            }
390        }
391        UnitaryMatrix::new(self.radices.clone(), mat)
392    }
393}
394
395impl JittableExpression for UnitaryExpression {
396    fn generation_shape(&self) -> GenerationShape {
397        GenerationShape::Matrix(self.radices.dimension(), self.radices.dimension())
398    }
399}
400
401impl AsRef<UnitaryExpression> for UnitaryExpression {
402    fn as_ref(&self) -> &UnitaryExpression {
403        self
404    }
405}
406
407impl AsRef<NamedExpression> for UnitaryExpression {
408    fn as_ref(&self) -> &NamedExpression {
409        &self.inner
410    }
411}
412
413impl From<UnitaryExpression> for NamedExpression {
414    fn from(value: UnitaryExpression) -> Self {
415        value.inner
416    }
417}
418
419impl Deref for UnitaryExpression {
420    type Target = NamedExpression;
421
422    fn deref(&self) -> &Self::Target {
423        &self.inner
424    }
425}
426
427impl DerefMut for UnitaryExpression {
428    fn deref_mut(&mut self) -> &mut Self::Target {
429        &mut self.inner
430    }
431}
432
433impl From<UnitaryExpression> for TensorExpression {
434    fn from(value: UnitaryExpression) -> Self {
435        let UnitaryExpression { inner, radices } = value;
436        let indices = radices
437            .iter()
438            .map(|r| (IndexDirection::Output, usize::from(*r)))
439            .chain(
440                radices
441                    .iter()
442                    .map(|r| (IndexDirection::Input, usize::from(*r))),
443            )
444            .enumerate()
445            .map(|(i, (d, r))| TensorIndex::new(d, i, r))
446            .collect();
447        TensorExpression::from_raw(indices, inner)
448    }
449}
450
451impl TryFrom<TensorExpression> for UnitaryExpression {
452    // TODO: Come up with proper error handling
453    type Error = String;
454
455    fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
456        let mut input_radices = vec![];
457        let mut output_radices = vec![];
458        for idx in value.indices() {
459            match idx.direction() {
460                IndexDirection::Input => {
461                    input_radices.push(idx.index_size());
462                }
463                IndexDirection::Output => {
464                    output_radices.push(idx.index_size());
465                }
466                _ => {
467                    return Err(String::from(
468                        "Cannot convert a tensor with non-input, non-output indices to an isometry.",
469                    ));
470                }
471            }
472        }
473
474        if input_radices != output_radices {
475            return Err(String::from(
476                "Non-square matrix tensor cannot be converted to a unitary.",
477            ));
478        }
479
480        Ok(UnitaryExpression {
481            inner: value.into(),
482            radices: input_radices.into(),
483        })
484    }
485}
486
487impl<C: ComplexScalar> From<UnitaryMatrix<C>> for UnitaryExpression {
488    fn from(value: UnitaryMatrix<C>) -> Self {
489        let mut body = vec![Vec::with_capacity(value.ncols()); value.nrows()];
490        for col in value.col_iter() {
491            for (row_id, elem) in col.iter().enumerate() {
492                body[row_id].push(ComplexExpression::from(*elem));
493            }
494        }
495        let mut flat_body = Vec::with_capacity(value.ncols() * value.nrows());
496        for row in body.into_iter() {
497            for elem in row.into_iter() {
498                flat_body.push(elem);
499            }
500        }
501        let inner = NamedExpression::new("Constant", vec![], flat_body);
502
503        UnitaryExpression {
504            inner,
505            radices: value.radices(),
506        }
507    }
508}
509
510impl QuditSystem for UnitaryExpression {
511    fn radices(&self) -> Radices {
512        self.radices.clone()
513    }
514}
515
516#[cfg(feature = "python")]
517mod python {
518    use super::*;
519    use crate::python::PyExpressionRegistrar;
520    use ndarray::ArrayViewMut2;
521    use numpy::PyArray2;
522    use numpy::PyArrayMethods;
523    use pyo3::prelude::*;
524    use pyo3::types::PyTuple;
525    use qudit_core::Radix;
526    use qudit_core::c64;
527
528    #[pyclass]
529    #[pyo3(name = "UnitaryExpression")]
530    pub struct PyUnitaryExpression {
531        expr: UnitaryExpression,
532    }
533
534    #[pymethods]
535    impl PyUnitaryExpression {
536        #[new]
537        fn new(expr: String) -> Self {
538            Self {
539                expr: UnitaryExpression::new(expr),
540            }
541        }
542
543        #[staticmethod]
544        fn identity(name: String, radices: Vec<usize>) -> Self {
545            Self {
546                expr: UnitaryExpression::identity(name, radices),
547            }
548        }
549
550        #[pyo3(signature = (*args))]
551        fn __call__<'py>(&self, args: &Bound<'py, PyTuple>) -> PyResult<Bound<'py, PyArray2<c64>>> {
552            let py = args.py();
553            let args: Vec<f64> = args.extract()?;
554            let unitary = self.expr.eval(&args);
555            let py_array: Bound<'py, PyArray2<c64>> =
556                PyArray2::zeros(py, (unitary.dimension(), unitary.dimension()), false);
557
558            {
559                let mut readwrite = py_array.readwrite();
560                let mut py_array_view: ArrayViewMut2<c64> = readwrite.as_array_mut();
561
562                for (j, col) in unitary.col_iter().enumerate() {
563                    for (i, val) in col.iter().enumerate() {
564                        py_array_view[[i, j]] = *val;
565                    }
566                }
567            }
568
569            Ok(py_array)
570        }
571
572        fn num_params(&self) -> usize {
573            self.expr.num_params()
574        }
575
576        fn name(&self) -> String {
577            self.expr.name().to_string()
578        }
579
580        fn radices(&self) -> Vec<Radix> {
581            self.expr.radices().to_vec()
582        }
583
584        fn dimension(&self) -> usize {
585            self.expr.dimension()
586        }
587
588        fn transpose(&mut self) {
589            self.expr.transpose();
590        }
591
592        fn dagger(&mut self) {
593            self.expr.dagger();
594        }
595
596        fn otimes(&self, other: &PyUnitaryExpression) -> Self {
597            Self {
598                expr: self.expr.otimes(&other.expr),
599            }
600        }
601
602        fn dot(&self, other: &PyUnitaryExpression) -> Self {
603            Self {
604                expr: self.expr.dot(&other.expr),
605            }
606        }
607
608        fn embed(
609            &mut self,
610            sub_matrix: &PyUnitaryExpression,
611            top_left_row_idx: usize,
612            top_left_col_idx: usize,
613        ) {
614            self.expr
615                .embed(sub_matrix.expr.clone(), top_left_row_idx, top_left_col_idx);
616        }
617
618        // fn classically_control(&self, positions: Vec<usize>, new_dim_radices: Vec<usize>) -> crate::python::tensor::PyTensorExpression {
619        //     let result = self.expr.classically_control(&positions, &new_dim_radices);
620        //     result.into()
621        // }
622
623        fn __repr__(&self) -> String {
624            format!(
625                "UnitaryExpression(name='{}', radices={:?}, params={})",
626                self.expr.name(),
627                self.expr.radices().to_vec(),
628                self.expr.num_params()
629            )
630        }
631
632        fn __matmul__(&self, other: UnitaryExpression) -> PyUnitaryExpression {
633            Self {
634                expr: self.expr.dot(&other),
635            }
636        }
637    }
638
639    impl From<UnitaryExpression> for PyUnitaryExpression {
640        fn from(value: UnitaryExpression) -> Self {
641            PyUnitaryExpression { expr: value }
642        }
643    }
644
645    impl From<PyUnitaryExpression> for UnitaryExpression {
646        fn from(value: PyUnitaryExpression) -> Self {
647            value.expr
648        }
649    }
650
651    impl<'py> IntoPyObject<'py> for UnitaryExpression {
652        type Target = <PyUnitaryExpression as IntoPyObject<'py>>::Target;
653        type Output = Bound<'py, Self::Target>;
654        type Error = PyErr;
655
656        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
657            let py_expr = PyUnitaryExpression::from(self);
658            Bound::new(py, py_expr)
659        }
660    }
661
662    impl<'a, 'py> FromPyObject<'a, 'py> for UnitaryExpression {
663        type Error = PyErr;
664
665        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
666            let py_expr: PyRef<PyUnitaryExpression> = ob.extract()?;
667            Ok(py_expr.expr.clone())
668        }
669    }
670
671    /// Registers the UnitaryExpression class with the Python module.
672    fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
673        parent_module.add_class::<PyUnitaryExpression>()?;
674        Ok(())
675    }
676    inventory::submit!(PyExpressionRegistrar { func: register });
677}