1use std::collections::{BTreeSet, HashMap};
2use std::ops::{Deref, DerefMut};
3
4use faer::Mat;
5use qudit_core::ComplexScalar;
6use qudit_core::QuditSystem;
7use qudit_core::Radices;
8use qudit_core::UnitaryMatrix;
9
10use crate::{ComplexExpression, UnitarySystemExpression};
11use crate::{
12 GenerationShape, TensorExpression,
13 expressions::JittableExpression,
14 index::{IndexDirection, TensorIndex},
15};
16
17use super::NamedExpression;
18
19#[derive(PartialEq, Eq, Debug, Clone)]
20pub struct UnitaryExpression {
21 inner: NamedExpression,
22 radices: Radices,
23}
24
25impl UnitaryExpression {
48 pub fn new<T: AsRef<str>>(input: T) -> Self {
49 TensorExpression::new(input).try_into().unwrap()
50 }
51
52 pub fn identity<S: Into<String>, T: Into<Radices>>(name: S, radices: T) -> Self {
54 let radices = radices.into();
55 let dim = radices.dimension();
56 let mut body = Vec::with_capacity(dim * dim);
57 for i in 0..dim {
58 for j in 0..dim {
59 if i == j {
60 body.push(ComplexExpression::one());
61 } else {
62 body.push(ComplexExpression::zero());
63 }
64 }
65 }
66 let inner = NamedExpression::new(name, vec![], body);
67
68 UnitaryExpression { inner, radices }
69 }
70
71 pub fn set_radices(&mut self, new_radices: Radices) {
72 assert_eq!(self.radices.dimension(), new_radices.dimension());
73 self.radices = new_radices;
74 }
75
76 pub fn transpose(&mut self) {
77 let dim = self.dimension();
78 for i in 1..dim {
79 for j in i..dim {
80 let (head, tail) = self.split_at_mut(j * dim + i);
81 std::mem::swap(&mut head[i * dim + j], &mut tail[0]);
82 }
83 }
84 }
85
86 pub fn dagger(&mut self) {
87 self.conjugate();
88 self.transpose();
89 }
90
91 pub fn classically_multiplex(
92 expressions: &[&UnitaryExpression],
93 new_dim_radices: &[usize],
94 ) -> UnitarySystemExpression {
95 let new_dim = new_dim_radices.iter().product();
96 assert!(
97 expressions.len() == new_dim,
98 "Cannot multiplex a number of expressions not equal to the length of new dimension."
99 );
100 assert!(!expressions.is_empty());
101 for expression in expressions {
102 assert_eq!(
103 expression.radices(),
104 expressions[0].radices(),
105 "All expressions must have equal radices."
106 );
107 }
108
109 let mut new_expressions =
111 Vec::with_capacity(expressions[0].dimension() * expressions[0].dimension() * new_dim);
112 let mut var_map_number = 0;
113 let mut new_variables = Vec::new();
114 for i in 0..new_dim {
115 let mut renamed_expression = expressions[i].clone();
117 renamed_expression.alpha_rename(Some(var_map_number));
118 new_expressions.extend(renamed_expression.elements().iter().cloned());
119 for _ in 0..expressions[i].variables().len() {
120 new_variables.push(format!("alpha_{}", var_map_number));
121 var_map_number += 1;
122 }
123 }
124
125 let new_indices = new_dim_radices
126 .iter()
127 .map(|r| (IndexDirection::Batch, (*r)))
128 .chain(
129 expressions[0]
130 .radices()
131 .iter()
132 .map(|r| (IndexDirection::Output, usize::from(*r))),
133 )
134 .chain(
135 expressions[0]
136 .radices()
137 .iter()
138 .map(|r| (IndexDirection::Input, usize::from(*r))),
139 )
140 .enumerate()
141 .map(|(id, (dir, size))| TensorIndex::new(dir, id, size))
142 .collect();
143
144 let name = {
145 if expressions
148 .iter()
149 .all(|e| e.name() == expressions[0].name())
150 {
151 format!("Multiplexed_{}", expressions[0].name())
152 } else {
153 "Multiplexed_".to_string()
154 + &expressions
155 .iter()
156 .map(|e| e.name())
157 .collect::<Vec<_>>()
158 .join("_")
159 }
160 };
161
162 let inner = NamedExpression::new(name, new_variables, new_expressions);
163 TensorExpression::from_raw(new_indices, inner)
164 .try_into()
165 .unwrap()
166 }
167
168 pub fn classically_control(
170 &self,
171 positions: &[usize],
172 new_dim_radices: &[usize],
173 ) -> UnitarySystemExpression {
174 let new_dim = new_dim_radices.iter().product();
175 assert!(
176 positions.len() <= new_dim,
177 "Cannot place unitary in more locations than length of new dimension."
178 );
179
180 let mut sorted_positions = positions.to_vec();
182 sorted_positions.sort_unstable();
183 assert!(
184 sorted_positions.iter().collect::<BTreeSet<_>>().len() == sorted_positions.len(),
185 "Positions must be unique"
186 );
187
188 let mut identity = Vec::with_capacity(self.dimension() * self.dimension());
190 for i in 0..self.dimension() {
191 for j in 0..self.dimension() {
192 if i == j {
193 identity.push(ComplexExpression::one());
194 } else {
195 identity.push(ComplexExpression::zero());
196 }
197 }
198 }
199
200 let mut expressions = Vec::with_capacity(self.dimension() * self.dimension() * new_dim);
202 for i in 0..new_dim {
203 if positions.contains(&i) {
204 expressions.extend(self.elements().iter().cloned());
205 } else {
206 expressions.extend(identity.iter().cloned());
207 }
208 }
209
210 let new_indices = new_dim_radices
211 .iter()
212 .map(|r| (IndexDirection::Batch, (*r)))
213 .chain(
214 self.radices
215 .iter()
216 .map(|r| (IndexDirection::Output, usize::from(*r))),
217 )
218 .chain(
219 self.radices
220 .iter()
221 .map(|r| (IndexDirection::Input, usize::from(*r))),
222 )
223 .enumerate()
224 .map(|(id, (dir, size))| TensorIndex::new(dir, id, size))
225 .collect();
226
227 let inner = NamedExpression::new(
228 format!("Stacked_{}", self.name()),
229 self.variables().to_owned(),
230 expressions,
231 );
232 TensorExpression::from_raw(new_indices, inner)
233 .try_into()
234 .unwrap()
235 }
236
237 pub fn embed(
238 &mut self,
239 sub_matrix: UnitaryExpression,
240 top_left_row_idx: usize,
241 top_left_col_idx: usize,
242 ) {
243 let nrows = self.dimension();
244 let ncols = self.dimension();
245 let sub_nrows = sub_matrix.dimension();
246 let sub_ncols = sub_matrix.dimension();
247
248 if top_left_row_idx + sub_nrows > nrows || top_left_col_idx + sub_ncols > ncols {
249 panic!("Embedding matrix is too large");
250 }
251
252 for i in 0..sub_nrows {
253 for j in 0..sub_ncols {
254 let sub_expr = sub_matrix[i * sub_ncols + j].clone();
256 self[(top_left_row_idx + i) * ncols + (top_left_col_idx + j)] = sub_expr;
257 }
258 }
259
260 let mut new_variables: Vec<String> = self.variables().to_vec();
262 for var in sub_matrix.variables().iter() {
263 if !new_variables.contains(var) {
264 new_variables.push(var.clone());
265 }
266 }
267 self.set_variables(new_variables);
268 }
269
270 pub fn otimes<U: AsRef<UnitaryExpression>>(&self, other: U) -> Self {
271 let other = other.as_ref();
272 let mut variables = Vec::new();
273 let mut var_map_self = HashMap::new();
274 let mut i = 0;
275 for var in self.variables().iter() {
276 variables.push(format!("x{}", i));
277 var_map_self.insert(var.clone(), format!("x{}", i));
278 i += 1;
279 }
280 let mut var_map_other = HashMap::new();
281 for var in other.variables().iter() {
282 variables.push(format!("x{}", i));
283 var_map_other.insert(var.clone(), format!("x{}", i));
284 i += 1;
285 }
286
287 let lhs_nrows = self.dimension();
288 let lhs_ncols = self.dimension();
289 let rhs_nrows = other.dimension();
290 let rhs_ncols = other.dimension();
291
292 let mut out_body = Vec::with_capacity(lhs_nrows * lhs_ncols * rhs_ncols * rhs_nrows);
293
294 for i in 0..lhs_nrows {
295 for j in 0..rhs_nrows {
296 for k in 0..lhs_ncols {
297 for l in 0..rhs_ncols {
298 out_body.push(
299 self[i * lhs_ncols + k].map_var_names(&var_map_self)
300 * other[j * rhs_ncols + l].map_var_names(&var_map_other),
301 );
302 }
303 }
304 }
305 }
306
307 let inner = NamedExpression::new(
308 format!("{} ⊗ {}", self.name(), other.name()),
309 variables,
310 out_body,
311 );
312
313 UnitaryExpression {
314 inner,
315 radices: self.radices.concat(&other.radices),
316 }
317 }
318
319 pub fn dot<U: AsRef<UnitaryExpression>>(&self, other: U) -> Self {
320 let other = other.as_ref();
321 let lhs_nrows = self.dimension();
322 let lhs_ncols = self.dimension();
323 let rhs_nrows = other.dimension();
324 let rhs_ncols = other.dimension();
325
326 if lhs_ncols != rhs_nrows {
327 panic!(
328 "Matrix dimensions do not match for dot product: {} != {}",
329 lhs_ncols, rhs_nrows
330 );
331 }
332
333 let mut variables = Vec::new();
334 let mut var_map_self = HashMap::new();
335 let mut i = 0;
336 for var in self.variables().iter() {
337 variables.push(format!("x{}", i));
338 var_map_self.insert(var.clone(), format!("x{}", i));
339 i += 1;
340 }
341 let mut var_map_other = HashMap::new();
342 for var in other.variables().iter() {
343 variables.push(format!("x{}", i));
344 var_map_other.insert(var.clone(), format!("x{}", i));
345 i += 1;
346 }
347
348 let mut out_body = Vec::with_capacity(lhs_nrows * rhs_ncols);
349
350 for i in 0..lhs_nrows {
351 for j in 0..rhs_ncols {
352 let mut sum = self[i * lhs_ncols].map_var_names(&var_map_self)
353 * other[j].map_var_names(&var_map_other);
354 for k in 1..lhs_ncols {
355 sum += self[i * lhs_ncols + k].map_var_names(&var_map_self)
356 * other[k * rhs_ncols + j].map_var_names(&var_map_other);
357 }
358 out_body.push(sum);
359 }
360 }
361
362 let inner = NamedExpression::new(
363 format!("{} ⋅ {}", self.name(), other.name()),
364 variables,
365 out_body,
366 );
367
368 UnitaryExpression {
369 inner,
370 radices: self.radices.clone(),
371 }
372 }
373
374 pub fn get_arg_map<C: ComplexScalar>(&self, args: &[C::R]) -> HashMap<&str, C::R> {
375 self.variables()
376 .iter()
377 .zip(args.iter())
378 .map(|(a, b)| (a.as_str(), *b))
379 .collect()
380 }
381
382 pub fn eval<C: ComplexScalar>(&self, args: &[C::R]) -> UnitaryMatrix<C> {
383 let arg_map = self.get_arg_map::<C>(args);
384 let dim = self.radices.dimension();
385 let mut mat = Mat::zeros(dim, dim);
386 for i in 0..dim {
387 for j in 0..dim {
388 *mat.get_mut(i, j) = self[i * self.dimension() + j].eval(&arg_map);
389 }
390 }
391 UnitaryMatrix::new(self.radices.clone(), mat)
392 }
393}
394
395impl JittableExpression for UnitaryExpression {
396 fn generation_shape(&self) -> GenerationShape {
397 GenerationShape::Matrix(self.radices.dimension(), self.radices.dimension())
398 }
399}
400
401impl AsRef<UnitaryExpression> for UnitaryExpression {
402 fn as_ref(&self) -> &UnitaryExpression {
403 self
404 }
405}
406
407impl AsRef<NamedExpression> for UnitaryExpression {
408 fn as_ref(&self) -> &NamedExpression {
409 &self.inner
410 }
411}
412
413impl From<UnitaryExpression> for NamedExpression {
414 fn from(value: UnitaryExpression) -> Self {
415 value.inner
416 }
417}
418
419impl Deref for UnitaryExpression {
420 type Target = NamedExpression;
421
422 fn deref(&self) -> &Self::Target {
423 &self.inner
424 }
425}
426
427impl DerefMut for UnitaryExpression {
428 fn deref_mut(&mut self) -> &mut Self::Target {
429 &mut self.inner
430 }
431}
432
433impl From<UnitaryExpression> for TensorExpression {
434 fn from(value: UnitaryExpression) -> Self {
435 let UnitaryExpression { inner, radices } = value;
436 let indices = radices
437 .iter()
438 .map(|r| (IndexDirection::Output, usize::from(*r)))
439 .chain(
440 radices
441 .iter()
442 .map(|r| (IndexDirection::Input, usize::from(*r))),
443 )
444 .enumerate()
445 .map(|(i, (d, r))| TensorIndex::new(d, i, r))
446 .collect();
447 TensorExpression::from_raw(indices, inner)
448 }
449}
450
451impl TryFrom<TensorExpression> for UnitaryExpression {
452 type Error = String;
454
455 fn try_from(value: TensorExpression) -> Result<Self, Self::Error> {
456 let mut input_radices = vec![];
457 let mut output_radices = vec![];
458 for idx in value.indices() {
459 match idx.direction() {
460 IndexDirection::Input => {
461 input_radices.push(idx.index_size());
462 }
463 IndexDirection::Output => {
464 output_radices.push(idx.index_size());
465 }
466 _ => {
467 return Err(String::from(
468 "Cannot convert a tensor with non-input, non-output indices to an isometry.",
469 ));
470 }
471 }
472 }
473
474 if input_radices != output_radices {
475 return Err(String::from(
476 "Non-square matrix tensor cannot be converted to a unitary.",
477 ));
478 }
479
480 Ok(UnitaryExpression {
481 inner: value.into(),
482 radices: input_radices.into(),
483 })
484 }
485}
486
487impl<C: ComplexScalar> From<UnitaryMatrix<C>> for UnitaryExpression {
488 fn from(value: UnitaryMatrix<C>) -> Self {
489 let mut body = vec![Vec::with_capacity(value.ncols()); value.nrows()];
490 for col in value.col_iter() {
491 for (row_id, elem) in col.iter().enumerate() {
492 body[row_id].push(ComplexExpression::from(*elem));
493 }
494 }
495 let mut flat_body = Vec::with_capacity(value.ncols() * value.nrows());
496 for row in body.into_iter() {
497 for elem in row.into_iter() {
498 flat_body.push(elem);
499 }
500 }
501 let inner = NamedExpression::new("Constant", vec![], flat_body);
502
503 UnitaryExpression {
504 inner,
505 radices: value.radices(),
506 }
507 }
508}
509
510impl QuditSystem for UnitaryExpression {
511 fn radices(&self) -> Radices {
512 self.radices.clone()
513 }
514}
515
516#[cfg(feature = "python")]
517mod python {
518 use super::*;
519 use crate::python::PyExpressionRegistrar;
520 use ndarray::ArrayViewMut2;
521 use numpy::PyArray2;
522 use numpy::PyArrayMethods;
523 use pyo3::prelude::*;
524 use pyo3::types::PyTuple;
525 use qudit_core::Radix;
526 use qudit_core::c64;
527
528 #[pyclass]
529 #[pyo3(name = "UnitaryExpression")]
530 pub struct PyUnitaryExpression {
531 expr: UnitaryExpression,
532 }
533
534 #[pymethods]
535 impl PyUnitaryExpression {
536 #[new]
537 fn new(expr: String) -> Self {
538 Self {
539 expr: UnitaryExpression::new(expr),
540 }
541 }
542
543 #[staticmethod]
544 fn identity(name: String, radices: Vec<usize>) -> Self {
545 Self {
546 expr: UnitaryExpression::identity(name, radices),
547 }
548 }
549
550 #[pyo3(signature = (*args))]
551 fn __call__<'py>(&self, args: &Bound<'py, PyTuple>) -> PyResult<Bound<'py, PyArray2<c64>>> {
552 let py = args.py();
553 let args: Vec<f64> = args.extract()?;
554 let unitary = self.expr.eval(&args);
555 let py_array: Bound<'py, PyArray2<c64>> =
556 PyArray2::zeros(py, (unitary.dimension(), unitary.dimension()), false);
557
558 {
559 let mut readwrite = py_array.readwrite();
560 let mut py_array_view: ArrayViewMut2<c64> = readwrite.as_array_mut();
561
562 for (j, col) in unitary.col_iter().enumerate() {
563 for (i, val) in col.iter().enumerate() {
564 py_array_view[[i, j]] = *val;
565 }
566 }
567 }
568
569 Ok(py_array)
570 }
571
572 fn num_params(&self) -> usize {
573 self.expr.num_params()
574 }
575
576 fn name(&self) -> String {
577 self.expr.name().to_string()
578 }
579
580 fn radices(&self) -> Vec<Radix> {
581 self.expr.radices().to_vec()
582 }
583
584 fn dimension(&self) -> usize {
585 self.expr.dimension()
586 }
587
588 fn transpose(&mut self) {
589 self.expr.transpose();
590 }
591
592 fn dagger(&mut self) {
593 self.expr.dagger();
594 }
595
596 fn otimes(&self, other: &PyUnitaryExpression) -> Self {
597 Self {
598 expr: self.expr.otimes(&other.expr),
599 }
600 }
601
602 fn dot(&self, other: &PyUnitaryExpression) -> Self {
603 Self {
604 expr: self.expr.dot(&other.expr),
605 }
606 }
607
608 fn embed(
609 &mut self,
610 sub_matrix: &PyUnitaryExpression,
611 top_left_row_idx: usize,
612 top_left_col_idx: usize,
613 ) {
614 self.expr
615 .embed(sub_matrix.expr.clone(), top_left_row_idx, top_left_col_idx);
616 }
617
618 fn __repr__(&self) -> String {
624 format!(
625 "UnitaryExpression(name='{}', radices={:?}, params={})",
626 self.expr.name(),
627 self.expr.radices().to_vec(),
628 self.expr.num_params()
629 )
630 }
631
632 fn __matmul__(&self, other: UnitaryExpression) -> PyUnitaryExpression {
633 Self {
634 expr: self.expr.dot(&other),
635 }
636 }
637 }
638
639 impl From<UnitaryExpression> for PyUnitaryExpression {
640 fn from(value: UnitaryExpression) -> Self {
641 PyUnitaryExpression { expr: value }
642 }
643 }
644
645 impl From<PyUnitaryExpression> for UnitaryExpression {
646 fn from(value: PyUnitaryExpression) -> Self {
647 value.expr
648 }
649 }
650
651 impl<'py> IntoPyObject<'py> for UnitaryExpression {
652 type Target = <PyUnitaryExpression as IntoPyObject<'py>>::Target;
653 type Output = Bound<'py, Self::Target>;
654 type Error = PyErr;
655
656 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
657 let py_expr = PyUnitaryExpression::from(self);
658 Bound::new(py, py_expr)
659 }
660 }
661
662 impl<'a, 'py> FromPyObject<'a, 'py> for UnitaryExpression {
663 type Error = PyErr;
664
665 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
666 let py_expr: PyRef<PyUnitaryExpression> = ob.extract()?;
667 Ok(py_expr.expr.clone())
668 }
669 }
670
671 fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
673 parent_module.add_class::<PyUnitaryExpression>()?;
674 Ok(())
675 }
676 inventory::submit!(PyExpressionRegistrar { func: register });
677}