1use std::collections::BTreeSet;
2use std::collections::HashMap;
3use std::ops::Deref;
4use std::ops::DerefMut;
5
6use super::ComplexExpression;
7use crate::Expression;
8use crate::GenerationShape;
9use crate::expressions::JittableExpression;
10use crate::expressions::NamedExpression;
11use crate::index::IndexDirection;
12use crate::index::IndexSize;
13use crate::index::TensorIndex;
14use crate::qgl::Expression as CiscExpression;
15use crate::qgl::parse_qobj;
16
/// A named, JIT-compilable expression together with the tensor indices that
/// describe its shape: a flat element list laid out row-major over `indices`
/// (see `tensor_strides`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorExpression {
    /// Per-axis index metadata (direction, id, size), in storage order.
    indices: Vec<TensorIndex>,
    /// The wrapped named expression holding the name, variables, and flat body.
    inner: NamedExpression,
}
22
impl JittableExpression for TensorExpression {
    /// The code-generation shape implied by this tensor's index list.
    fn generation_shape(&self) -> GenerationShape {
        self.indices().into()
    }
}
28
impl AsRef<NamedExpression> for TensorExpression {
    // Cheap borrow of the wrapped expression.
    fn as_ref(&self) -> &NamedExpression {
        &self.inner
    }
}
34
/// Forwards `NamedExpression` methods (e.g. `num_elements`, `elements`,
/// `name`, `variables`) directly on `TensorExpression`.
impl Deref for TensorExpression {
    type Target = NamedExpression;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
42
/// Mutable counterpart of `Deref`; enables mutating calls such as
/// `apply_element_permutation` on the inner expression.
impl DerefMut for TensorExpression {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
48
49impl TensorExpression {
50 pub fn new<T: AsRef<str>>(input: T) -> Self {
51 let qdef = match parse_qobj(input.as_ref()) {
52 Ok(qdef) => qdef,
53 Err(e) => panic!("Parsing Error: {}", e),
54 };
55
56 let indices = qdef.get_tensor_indices();
57 let name = qdef.name;
58 let variables = qdef.variables;
59 let element_wise = qdef.body.into_element_wise();
60 let body: Vec<ComplexExpression> = match element_wise {
61 CiscExpression::Vector(vec) => vec.into_iter().map(ComplexExpression::new).collect(),
62 CiscExpression::Matrix(mat) => mat
63 .into_iter()
64 .flat_map(|row| {
65 row.into_iter()
66 .map(ComplexExpression::new)
67 .collect::<Vec<_>>()
68 })
69 .collect(),
70 CiscExpression::Tensor(tensor) => tensor
71 .into_iter()
72 .flat_map(|row| {
73 row.into_iter()
74 .flat_map(|col| {
75 col.into_iter()
76 .map(ComplexExpression::new)
77 .collect::<Vec<_>>()
78 })
79 .collect::<Vec<_>>()
80 })
81 .collect(),
82 _ => panic!("Tensor body must be a vector"),
83 };
84
85 TensorExpression {
86 indices,
87 inner: NamedExpression::new(name, variables, body),
88 }
89 }
90
91 pub fn from_raw(indices: Vec<TensorIndex>, inner: NamedExpression) -> Self {
92 TensorExpression { indices, inner }
93 }
94
95 pub fn indices(&self) -> &[TensorIndex] {
96 &self.indices
97 }
98
99 pub fn dimensions(&self) -> Vec<IndexSize> {
100 self.indices.iter().map(|idx| idx.index_size()).collect()
101 }
102
103 pub fn rank(&self) -> usize {
104 self.indices.len()
105 }
106
107 pub fn generation_shape(&self) -> GenerationShape {
108 self.indices().into()
109 }
110
111 pub fn tensor_strides(&self) -> Vec<usize> {
141 let mut strides = Vec::with_capacity(self.indices.len());
142 let mut current_stride = 1;
143 for &index in self.indices.iter().rev() {
144 strides.push(current_stride);
145 current_stride *= index.index_size();
146 }
147 strides.reverse();
148 strides
149 }
150
151 pub fn reindex(&mut self, new_indices: Vec<TensorIndex>) -> &mut Self {
152 assert_eq!(
153 new_indices
154 .iter()
155 .map(|idx| idx.index_size())
156 .product::<usize>(),
157 self.num_elements(),
158 "Product of new dimensions must match the total number of elements in the tensor body."
159 );
160
161 let mut last_direction = IndexDirection::Derivative;
163 for index in &new_indices {
164 let current_direction = index.direction();
165 if current_direction < last_direction {
166 panic!(
167 "New indices are not ordered correctly. Expected order: Derv, Batch, Output, Input."
168 );
169 }
170 last_direction = current_direction;
171 }
172
173 self.indices = new_indices;
174 self
175 }
176
177 pub fn permute(&mut self, perm: &[usize], redirection: Vec<IndexDirection>) -> &mut Self {
179 assert_eq!(perm.len(), self.rank());
180
181 let original_strides = self.tensor_strides();
183 let original_dimensions = self.dimensions();
184
185 let reordered_indices: Vec<TensorIndex> = perm
187 .iter()
188 .enumerate()
189 .map(|(id, &p_id)| {
190 TensorIndex::new(
191 redirection[id],
192 self.indices[p_id].index_id(),
193 self.indices[p_id].index_size(),
194 )
195 })
196 .collect();
197
198 self.reindex(reordered_indices);
200
201 let new_strides = self.tensor_strides();
203
204 let mut elem_perm: Vec<usize> = Vec::with_capacity(self.num_elements());
206 for i in 0..self.num_elements() {
207 let mut original_coordinate: Vec<usize> = Vec::with_capacity(self.rank());
208 let mut temp_i = i;
209 for d_idx in 0..self.rank() {
210 original_coordinate
211 .push((temp_i / original_strides[d_idx]) % original_dimensions[d_idx]);
212 temp_i %= original_strides[d_idx]; }
214
215 let mut permuted_coordinate: Vec<usize> = vec![0; self.rank()];
219 for j in 0..self.rank() {
220 permuted_coordinate[j] = original_coordinate[perm[j]];
221 }
222
223 let mut new_linear_idx = 0;
225 for d_idx in 0..self.rank() {
226 new_linear_idx += permuted_coordinate[d_idx] * new_strides[d_idx];
227 }
228 elem_perm.push(new_linear_idx);
229 }
230
231 self.apply_element_permutation(&elem_perm);
232 self
233 }
234
235 pub fn stack_with_identity(&self, positions: &[usize], new_dim: usize) -> TensorExpression {
236 let (nrows, ncols) = match self.generation_shape() {
238 GenerationShape::Matrix(r, c) => (r, c),
239 _ => panic!(
240 "TensorExpression must be a square matrix to use stack_with_identity, got {:?}",
241 self.generation_shape()
242 ),
243 };
244 assert_eq!(
245 nrows, ncols,
246 "TensorExpression must be a square matrix for stack_with_identity"
247 );
248 assert!(
250 positions.len() <= new_dim,
251 "Cannot place tensor in more locations than length of new dimension."
252 );
253
254 let mut sorted_positions = positions.to_vec();
256 sorted_positions.sort_unstable();
257 assert!(
258 sorted_positions.iter().collect::<BTreeSet<_>>().len() == sorted_positions.len(),
259 "Positions must be unique"
260 );
261
262 let mut identity = Vec::with_capacity(nrows * ncols);
264 for i in 0..nrows {
265 for j in 0..ncols {
266 if i == j {
267 identity.push(ComplexExpression::one());
268 } else {
269 identity.push(ComplexExpression::zero());
270 }
271 }
272 }
273
274 let mut expressions = Vec::with_capacity(nrows * ncols * new_dim);
276 for i in 0..new_dim {
277 if positions.contains(&i) {
278 expressions.extend(self.elements().iter().cloned());
279 } else {
280 expressions.extend(identity.iter().cloned());
281 }
282 }
283
284 let new_indices =
285 [TensorIndex::new(IndexDirection::Batch, 0, new_dim)]
286 .into_iter()
287 .chain(self.indices().iter().map(|idx| {
288 TensorIndex::new(idx.direction(), idx.index_id() + 1, idx.index_size())
289 }))
290 .collect();
291
292 TensorExpression {
293 indices: new_indices,
294 inner: NamedExpression::new(
295 format!("Stacked_{}", self.name()),
296 self.variables().to_owned(),
297 expressions,
298 ),
299 }
300 }
301
302 pub fn partial_trace(&self, dimension_pairs: &[(usize, usize)]) -> TensorExpression {
303 if dimension_pairs.is_empty() {
304 return self.clone();
305 }
306
307 let in_dims = self.indices();
308 let num_dims = in_dims.len();
309
310 let mut traced_dim_indices = std::collections::HashSet::new();
312 for &(d1, d2) in dimension_pairs {
313 if d1 >= num_dims || d2 >= num_dims {
314 panic!(
315 "Dimension index out of bounds: ({}, {}) for dimensions {:?}",
316 d1, d2, in_dims
317 );
318 }
319 if in_dims[d1] != in_dims[d2] {
320 panic!(
321 "Dimensions being traced must have the same size: D{} ({}) != D{} ({})",
322 d1, in_dims[d1], d2, in_dims[d2]
323 );
324 }
325 if !traced_dim_indices.insert(d1) {
326 panic!("Dimension {} appears more than once as a trace source.", d1);
327 }
328 if !traced_dim_indices.insert(d2) {
329 panic!("Dimension {} appears more than once as a trace target.", d2);
330 }
331 }
332
333 let remaining_dims_indices: Vec<usize> = (0..num_dims)
334 .filter(|&i| !traced_dim_indices.contains(&i))
335 .collect();
336
337 let out_dims: Vec<usize> = remaining_dims_indices
338 .iter()
339 .map(|&i| in_dims[i].index_size())
340 .collect();
341
342 let new_body_len = out_dims.iter().product::<usize>();
343 let mut new_body = vec![ComplexExpression::zero(); new_body_len];
344
345 let mut in_strides = self.tensor_strides();
346 in_strides.insert(0, self.num_elements());
347
348 let out_strides: Vec<usize> = {
350 let mut strides = vec![0; out_dims.len()];
351 strides[out_dims.len() - 1] = 1; for i in (0..out_dims.len() - 1).rev() {
353 strides[i] = strides[i + 1] * out_dims[i + 1];
354 }
355 strides
356 };
357
358 for (i, expr) in self.elements().iter().enumerate() {
360 let original_coordinate: Vec<usize> = (0..in_dims.len())
361 .map(|d| (i % in_strides[d + 1]) / in_strides[d])
362 .rev()
363 .collect();
364
365 if dimension_pairs
366 .iter()
367 .all(|(d1, d2)| original_coordinate[*d1] == original_coordinate[*d2])
368 {
369 let mut new_linear_idx = 0;
370 for (&idx, stride) in remaining_dims_indices.iter().zip(&out_strides) {
371 new_linear_idx += original_coordinate[idx] * stride;
372 }
373
374 new_body[new_linear_idx] += expr;
375 }
376 }
377
378 let new_indices = remaining_dims_indices.iter().map(|x| in_dims[*x]).collect();
379
380 TensorExpression {
381 indices: new_indices,
382 inner: NamedExpression::new(
383 format!("PartialTraced_{}", self.name()),
384 self.variables().to_owned(),
385 new_body,
386 ),
387 }
388 }
389
    /// Substitutes expressions for this tensor's parameters, returning a new
    /// `TensorExpression` named `"{name}_subbed"`.
    ///
    /// NOTE(review): the substitution map zips `self.variables()` with
    /// `values` — the `variables` argument is only used as the NEW variable
    /// list of the result. This reads as a reparametrization (each of self's
    /// variables is replaced by an expression over the new `variables`);
    /// confirm callers do not expect `variables` to name the substituted
    /// symbols. Also note `zip` silently truncates if lengths differ.
    ///
    /// NOTE(review): `sub_map` is a `HashMap`, so substitutions are applied
    /// in unspecified order; if any `value` itself mentions a substituted
    /// variable, the result may depend on iteration order — verify values
    /// are independent of the substituted variables.
    pub fn substitute_parameters<S: AsRef<str>, C: AsRef<Expression>>(
        &self,
        variables: &[S],
        values: &[C],
    ) -> Self {
        // Map each existing variable (wrapped as an Expression) to its
        // replacement expression.
        let sub_map: HashMap<_, _> = self
            .variables()
            .iter()
            .zip(values.iter())
            .map(|(k, v)| (Expression::Variable(k.to_string()), v.as_ref()))
            .collect();

        let mut new_body = vec![];
        for expr in self.elements().iter() {
            // Apply every substitution in turn; `None` means untouched so far.
            let mut new_expr = None;
            for (var, value) in &sub_map {
                match new_expr {
                    None => new_expr = Some(expr.substitute(var, value)),
                    Some(ref mut e) => *e = e.substitute(var, value),
                }
            }
            match new_expr {
                None => new_body.push(expr.clone()),
                Some(e) => new_body.push(e),
            }
        }

        // The result is parameterized by the caller-supplied variable names.
        let new_variables = variables.iter().map(|s| s.as_ref().to_string()).collect();

        TensorExpression {
            indices: self.indices.clone(),
            inner: NamedExpression::new(format!("{}_subbed", self.name()), new_variables, new_body),
        }
    }
425
426 pub fn destruct(
427 self,
428 ) -> (
429 String,
430 Vec<String>,
431 Vec<ComplexExpression>,
432 Vec<TensorIndex>,
433 ) {
434 let Self { inner, indices } = self;
435 let (name, variables, body) = inner.destruct();
436 (name, variables, body, indices)
437 }
438}
439
440impl From<TensorExpression> for NamedExpression {
447 fn from(value: TensorExpression) -> Self {
448 value.inner
449 }
450}