qudit_expr/expressions/
tensor.rs

1use std::collections::BTreeSet;
2use std::collections::HashMap;
3use std::ops::Deref;
4use std::ops::DerefMut;
5
6use super::ComplexExpression;
7use crate::Expression;
8use crate::GenerationShape;
9use crate::expressions::JittableExpression;
10use crate::expressions::NamedExpression;
11use crate::index::IndexDirection;
12use crate::index::IndexSize;
13use crate::index::TensorIndex;
14use crate::qgl::Expression as CiscExpression;
15use crate::qgl::parse_qobj;
16
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct TensorExpression {
19    indices: Vec<TensorIndex>,
20    inner: NamedExpression,
21}
22
23impl JittableExpression for TensorExpression {
24    fn generation_shape(&self) -> GenerationShape {
25        self.indices().into()
26    }
27}
28
29impl AsRef<NamedExpression> for TensorExpression {
30    fn as_ref(&self) -> &NamedExpression {
31        &self.inner
32    }
33}
34
35impl Deref for TensorExpression {
36    type Target = NamedExpression;
37
38    fn deref(&self) -> &Self::Target {
39        &self.inner
40    }
41}
42
43impl DerefMut for TensorExpression {
44    fn deref_mut(&mut self) -> &mut Self::Target {
45        &mut self.inner
46    }
47}
48
49impl TensorExpression {
50    pub fn new<T: AsRef<str>>(input: T) -> Self {
51        let qdef = match parse_qobj(input.as_ref()) {
52            Ok(qdef) => qdef,
53            Err(e) => panic!("Parsing Error: {}", e),
54        };
55
56        let indices = qdef.get_tensor_indices();
57        let name = qdef.name;
58        let variables = qdef.variables;
59        let element_wise = qdef.body.into_element_wise();
60        let body: Vec<ComplexExpression> = match element_wise {
61            CiscExpression::Vector(vec) => vec.into_iter().map(ComplexExpression::new).collect(),
62            CiscExpression::Matrix(mat) => mat
63                .into_iter()
64                .flat_map(|row| {
65                    row.into_iter()
66                        .map(ComplexExpression::new)
67                        .collect::<Vec<_>>()
68                })
69                .collect(),
70            CiscExpression::Tensor(tensor) => tensor
71                .into_iter()
72                .flat_map(|row| {
73                    row.into_iter()
74                        .flat_map(|col| {
75                            col.into_iter()
76                                .map(ComplexExpression::new)
77                                .collect::<Vec<_>>()
78                        })
79                        .collect::<Vec<_>>()
80                })
81                .collect(),
82            _ => panic!("Tensor body must be a vector"),
83        };
84
85        TensorExpression {
86            indices,
87            inner: NamedExpression::new(name, variables, body),
88        }
89    }
90
91    pub fn from_raw(indices: Vec<TensorIndex>, inner: NamedExpression) -> Self {
92        TensorExpression { indices, inner }
93    }
94
95    pub fn indices(&self) -> &[TensorIndex] {
96        &self.indices
97    }
98
99    pub fn dimensions(&self) -> Vec<IndexSize> {
100        self.indices.iter().map(|idx| idx.index_size()).collect()
101    }
102
103    pub fn rank(&self) -> usize {
104        self.indices.len()
105    }
106
107    pub fn generation_shape(&self) -> GenerationShape {
108        self.indices().into()
109    }
110
111    /// Calculates the strides for each dimension of the tensor.
112    ///
113    /// The stride for a dimension is the number of elements one must skip in the
114    /// flattened array to move to the next element along that dimension.
115    /// The strides are calculated such that the last dimension has a stride of 1,
116    /// the second to last dimension has a stride equal to the size of the last dimension,
117    /// and so on.
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// use qudit_expr::TensorExpression;
123    ///
124    /// let mut tensor = TensorExpression::new("example() {[
125    ///         [
126    ///             [ 1, 0, 0, 0 ],
127    ///             [ 0, 1, 0, 0 ],
128    ///             [ 0, 0, 1, 0 ],
129    ///             [ 0, 0, 0, 1 ],
130    ///         ], [
131    ///             [ 1, 0, 0, 0 ],
132    ///             [ 0, 1, 0, 0 ],
133    ///             [ 0, 0, 0, 1 ],
134    ///             [ 0, 0, 1, 0 ],
135    ///         ]
136    ///     ]}");
137    ///
138    /// assert_eq!(tensor.tensor_strides(), vec![16, 8, 4, 2, 1]);
139    /// ```
140    pub fn tensor_strides(&self) -> Vec<usize> {
141        let mut strides = Vec::with_capacity(self.indices.len());
142        let mut current_stride = 1;
143        for &index in self.indices.iter().rev() {
144            strides.push(current_stride);
145            current_stride *= index.index_size();
146        }
147        strides.reverse();
148        strides
149    }
150
151    pub fn reindex(&mut self, new_indices: Vec<TensorIndex>) -> &mut Self {
152        assert_eq!(
153            new_indices
154                .iter()
155                .map(|idx| idx.index_size())
156                .product::<usize>(),
157            self.num_elements(),
158            "Product of new dimensions must match the total number of elements in the tensor body."
159        );
160
161        // Assert that all indices are lined up correctly (Derv | Batch | Output | Input)
162        let mut last_direction = IndexDirection::Derivative;
163        for index in &new_indices {
164            let current_direction = index.direction();
165            if current_direction < last_direction {
166                panic!(
167                    "New indices are not ordered correctly. Expected order: Derv, Batch, Output, Input."
168                );
169            }
170            last_direction = current_direction;
171        }
172
173        self.indices = new_indices;
174        self
175    }
176
177    // Really a fused reshape and permutation
178    pub fn permute(&mut self, perm: &[usize], redirection: Vec<IndexDirection>) -> &mut Self {
179        assert_eq!(perm.len(), self.rank());
180
181        // Store original strides and dimensions for body permutation before `self.indices` is modified
182        let original_strides = self.tensor_strides();
183        let original_dimensions = self.dimensions();
184
185        //Reorder the TensorIndex objects based on `perm`
186        let reordered_indices: Vec<TensorIndex> = perm
187            .iter()
188            .enumerate()
189            .map(|(id, &p_id)| {
190                TensorIndex::new(
191                    redirection[id],
192                    self.indices[p_id].index_id(),
193                    self.indices[p_id].index_size(),
194                )
195            })
196            .collect();
197
198        // Update self.indices with the newly reordered and direction-assigned indices
199        self.reindex(reordered_indices);
200
201        // Get the new strides based on the updated `self.indices`
202        let new_strides = self.tensor_strides();
203
204        // Permute elements in the body based on the new index order.
205        let mut elem_perm: Vec<usize> = Vec::with_capacity(self.num_elements());
206        for i in 0..self.num_elements() {
207            let mut original_coordinate: Vec<usize> = Vec::with_capacity(self.rank());
208            let mut temp_i = i;
209            for d_idx in 0..self.rank() {
210                original_coordinate
211                    .push((temp_i / original_strides[d_idx]) % original_dimensions[d_idx]);
212                temp_i %= original_strides[d_idx]; // Update temp_i for next dimension
213            }
214
215            // Map original coordinate components to their new positions according to `perm`.
216            // If `perm[j]` is `k`, it means the `j`-th dimension in the new order
217            // corresponds to the `k`-th dimension in the original order.
218            let mut permuted_coordinate: Vec<usize> = vec![0; self.rank()];
219            for j in 0..self.rank() {
220                permuted_coordinate[j] = original_coordinate[perm[j]];
221            }
222
223            // Calculate new linear index using the permuted coordinate and new strides
224            let mut new_linear_idx = 0;
225            for d_idx in 0..self.rank() {
226                new_linear_idx += permuted_coordinate[d_idx] * new_strides[d_idx];
227            }
228            elem_perm.push(new_linear_idx);
229        }
230
231        self.apply_element_permutation(&elem_perm);
232        self
233    }
234
235    pub fn stack_with_identity(&self, positions: &[usize], new_dim: usize) -> TensorExpression {
236        // Assertions for input validity
237        let (nrows, ncols) = match self.generation_shape() {
238            GenerationShape::Matrix(r, c) => (r, c),
239            _ => panic!(
240                "TensorExpression must be a square matrix to use stack_with_identity, got {:?}",
241                self.generation_shape()
242            ),
243        };
244        assert_eq!(
245            nrows, ncols,
246            "TensorExpression must be a square matrix for stack_with_identity"
247        );
248        // assert_eq!(nrows, self.dimensions.dimension(), "Matrix dimension must match qudit dimensions for stack_with_identity");
249        assert!(
250            positions.len() <= new_dim,
251            "Cannot place tensor in more locations than length of new dimension."
252        );
253
254        // Ensure positions are unique
255        let mut sorted_positions = positions.to_vec();
256        sorted_positions.sort_unstable();
257        assert!(
258            sorted_positions.iter().collect::<BTreeSet<_>>().len() == sorted_positions.len(),
259            "Positions must be unique"
260        );
261
262        // Construct identity expression
263        let mut identity = Vec::with_capacity(nrows * ncols);
264        for i in 0..nrows {
265            for j in 0..ncols {
266                if i == j {
267                    identity.push(ComplexExpression::one());
268                } else {
269                    identity.push(ComplexExpression::zero());
270                }
271            }
272        }
273
274        // construct larger tensor
275        let mut expressions = Vec::with_capacity(nrows * ncols * new_dim);
276        for i in 0..new_dim {
277            if positions.contains(&i) {
278                expressions.extend(self.elements().iter().cloned());
279            } else {
280                expressions.extend(identity.iter().cloned());
281            }
282        }
283
284        let new_indices =
285            [TensorIndex::new(IndexDirection::Batch, 0, new_dim)]
286                .into_iter()
287                .chain(self.indices().iter().map(|idx| {
288                    TensorIndex::new(idx.direction(), idx.index_id() + 1, idx.index_size())
289                }))
290                .collect();
291
292        TensorExpression {
293            indices: new_indices,
294            inner: NamedExpression::new(
295                format!("Stacked_{}", self.name()),
296                self.variables().to_owned(),
297                expressions,
298            ),
299        }
300    }
301
302    pub fn partial_trace(&self, dimension_pairs: &[(usize, usize)]) -> TensorExpression {
303        if dimension_pairs.is_empty() {
304            return self.clone();
305        }
306
307        let in_dims = self.indices();
308        let num_dims = in_dims.len();
309
310        // 1. Validate dimension_pairs and identify dimensions to keep/trace
311        let mut traced_dim_indices = std::collections::HashSet::new();
312        for &(d1, d2) in dimension_pairs {
313            if d1 >= num_dims || d2 >= num_dims {
314                panic!(
315                    "Dimension index out of bounds: ({}, {}) for dimensions {:?}",
316                    d1, d2, in_dims
317                );
318            }
319            if in_dims[d1] != in_dims[d2] {
320                panic!(
321                    "Dimensions being traced must have the same size: D{} ({}) != D{} ({})",
322                    d1, in_dims[d1], d2, in_dims[d2]
323                );
324            }
325            if !traced_dim_indices.insert(d1) {
326                panic!("Dimension {} appears more than once as a trace source.", d1);
327            }
328            if !traced_dim_indices.insert(d2) {
329                panic!("Dimension {} appears more than once as a trace target.", d2);
330            }
331        }
332
333        let remaining_dims_indices: Vec<usize> = (0..num_dims)
334            .filter(|&i| !traced_dim_indices.contains(&i))
335            .collect();
336
337        let out_dims: Vec<usize> = remaining_dims_indices
338            .iter()
339            .map(|&i| in_dims[i].index_size())
340            .collect();
341
342        let new_body_len = out_dims.iter().product::<usize>();
343        let mut new_body = vec![ComplexExpression::zero(); new_body_len];
344
345        let mut in_strides = self.tensor_strides();
346        in_strides.insert(0, self.num_elements());
347
348        // Calculate output strides for new linear index conversion (row-major)
349        let out_strides: Vec<usize> = {
350            let mut strides = vec![0; out_dims.len()];
351            strides[out_dims.len() - 1] = 1; // Stride for the last dimension is 1
352            for i in (0..out_dims.len() - 1).rev() {
353                strides[i] = strides[i + 1] * out_dims[i + 1];
354            }
355            strides
356        };
357
358        // 3. Iterate through original tensor elements
359        for (i, expr) in self.elements().iter().enumerate() {
360            let original_coordinate: Vec<usize> = (0..in_dims.len())
361                .map(|d| (i % in_strides[d + 1]) / in_strides[d])
362                .rev()
363                .collect();
364
365            if dimension_pairs
366                .iter()
367                .all(|(d1, d2)| original_coordinate[*d1] == original_coordinate[*d2])
368            {
369                let mut new_linear_idx = 0;
370                for (&idx, stride) in remaining_dims_indices.iter().zip(&out_strides) {
371                    new_linear_idx += original_coordinate[idx] * stride;
372                }
373
374                new_body[new_linear_idx] += expr;
375            }
376        }
377
378        let new_indices = remaining_dims_indices.iter().map(|x| in_dims[*x]).collect();
379
380        TensorExpression {
381            indices: new_indices,
382            inner: NamedExpression::new(
383                format!("PartialTraced_{}", self.name()),
384                self.variables().to_owned(),
385                new_body,
386            ),
387        }
388    }
389
390    /// replace all variables with values, new variables given in variables input
391    pub fn substitute_parameters<S: AsRef<str>, C: AsRef<Expression>>(
392        &self,
393        variables: &[S],
394        values: &[C],
395    ) -> Self {
396        let sub_map: HashMap<_, _> = self
397            .variables()
398            .iter()
399            .zip(values.iter())
400            .map(|(k, v)| (Expression::Variable(k.to_string()), v.as_ref()))
401            .collect();
402
403        let mut new_body = vec![];
404        for expr in self.elements().iter() {
405            let mut new_expr = None;
406            for (var, value) in &sub_map {
407                match new_expr {
408                    None => new_expr = Some(expr.substitute(var, value)),
409                    Some(ref mut e) => *e = e.substitute(var, value),
410                }
411            }
412            match new_expr {
413                None => new_body.push(expr.clone()),
414                Some(e) => new_body.push(e),
415            }
416        }
417
418        let new_variables = variables.iter().map(|s| s.as_ref().to_string()).collect();
419
420        TensorExpression {
421            indices: self.indices.clone(),
422            inner: NamedExpression::new(format!("{}_subbed", self.name()), new_variables, new_body),
423        }
424    }
425
426    pub fn destruct(
427        self,
428    ) -> (
429        String,
430        Vec<String>,
431        Vec<ComplexExpression>,
432        Vec<TensorIndex>,
433    ) {
434        let Self { inner, indices } = self;
435        let (name, variables, body) = inner.destruct();
436        (name, variables, body, indices)
437    }
438}
439
440// impl<C: ComplexScalar> From<UnitaryMatrix<C>> for TensorExpression {
441//     fn from(utry: UnitaryMatrix<C>) -> Self {
442//         UnitaryExpression::from(utry).to_tensor_expression()
443//     }
444// }
445
446impl From<TensorExpression> for NamedExpression {
447    fn from(value: TensorExpression) -> Self {
448        value.inner
449    }
450}