qudit_expr/
cache.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, Mutex};
3
4use qudit_core::RealScalar;
5
6use crate::TensorExpression;
7use crate::analysis::simplify_expressions_iter;
8use crate::codegen::CompilableUnit;
9use crate::expressions::{ExpressionBody, NamedExpression};
10use crate::index::{IndexSize, TensorIndex};
11use crate::{
12    ComplexExpression, DifferentiationLevel, Expression, GRADIENT, GenerationShape, HESSIAN,
13    Module, ModuleBuilder, WriteFunc,
14};
15
/// Identifier for a cached tensor expression (base or derived).
pub type ExpressionId = u64;
17
/// An expression body together with lazily computed simplified and
/// differentiated forms, filled in by [`CachedExpressionBody::prepare`].
pub struct CachedExpressionBody {
    // The unmodified complex expression elements as originally inserted.
    original: ExpressionBody,
    // Simplified real/imag expressions for plain evaluation; `None` until prepared.
    func: Option<Vec<Expression>>,
    // Simplified expressions including first-order partials; `None` until prepared.
    grad: Option<Vec<Expression>>,
    // Simplified expressions including second-order partials; `None` until prepared.
    hess: Option<Vec<Expression>>,
}
24
25impl CachedExpressionBody {
26    pub fn new(original: impl Into<ExpressionBody>) -> CachedExpressionBody {
27        let original = original.into();
28        CachedExpressionBody {
29            original,
30            func: None,
31            grad: None,
32            hess: None,
33        }
34    }
35
36    pub fn num_elements(&self) -> usize {
37        self.original.num_elements()
38    }
39
40    // Simplify and differentiate to prepare expression to evaluate up to diff_level.
41    pub fn prepare(&mut self, variables: &[String], diff_level: DifferentiationLevel) -> bool {
42        let mut has_changed = false;
43
44        if self.func.is_none() {
45            self.func = Some(simplify_expressions_iter(
46                self.original.iter().flat_map(|c| [&c.real, &c.imag]),
47            ));
48
49            has_changed = true;
50        }
51
52        if diff_level >= GRADIENT && self.grad.is_none() {
53            let mut grad_exprs = vec![];
54            for variable in variables {
55                for expr in self.original.iter() {
56                    let grad_expr = expr.differentiate(variable);
57                    grad_exprs.push(grad_expr);
58                }
59            }
60
61            self.grad = Some(simplify_expressions_iter(
62                self.original
63                    .iter()
64                    .chain(grad_exprs.iter())
65                    .flat_map(|c| [&c.real, &c.imag]),
66            ));
67
68            has_changed = true;
69        }
70
71        if diff_level >= HESSIAN && self.hess.is_some() {
72            let mut grad_exprs = vec![];
73            for variable in variables {
74                for expr in self.original.iter() {
75                    let grad_expr = expr.differentiate(variable);
76                    grad_exprs.push(grad_expr);
77                }
78            }
79
80            let mut hess_exprs = vec![];
81            for variable in variables {
82                for expr in grad_exprs.iter() {
83                    let hess_expr = expr.differentiate(variable);
84                    hess_exprs.push(hess_expr);
85                }
86            }
87
88            self.hess = Some(simplify_expressions_iter(
89                self.original
90                    .iter()
91                    .chain(grad_exprs.iter())
92                    .chain(hess_exprs.iter())
93                    .flat_map(|c| [&c.real, &c.imag]),
94            ));
95
96            has_changed = true;
97        }
98
99        has_changed
100    }
101}
102
103impl<B: Into<ExpressionBody>> From<B> for CachedExpressionBody {
104    fn from(value: B) -> Self {
105        CachedExpressionBody::new(value)
106    }
107}
108
/// A cached base tensor expression plus the derived (permuted/reshaped)
/// variants registered against it.
pub struct CachedTensorExpression {
    // Name of the base expression.
    name: String,
    // Free variables the expression is parameterized over.
    variables: Vec<String>,
    // Tensor indices of the base expression.
    indices: Vec<TensorIndex>,
    // Elements plus lazily prepared simplified/differentiated forms.
    expressions: CachedExpressionBody,
    // Derived id -> (index permutation, generation shape) modifiers.
    // The base id itself maps to the identity permutation.
    id_lookup: BTreeMap<ExpressionId, (Vec<usize>, GenerationShape)>,
    // Id under which this base expression is registered in the cache.
    base_id: ExpressionId,
}
117
118impl CachedTensorExpression {
119    pub fn new<E: Into<TensorExpression>>(expr: E, base_id: ExpressionId) -> Self {
120        let expr = expr.into();
121        let mut id_lookup = BTreeMap::new();
122        id_lookup.insert(base_id, ((0..expr.rank()).collect(), expr.indices().into()));
123        let (name, variables, body, indices) = expr.destruct();
124        CachedTensorExpression {
125            name,
126            variables,
127            indices,
128            expressions: body.into(),
129            id_lookup,
130            base_id,
131        }
132    }
133
134    // TODO: deduplicate code with tensorexpression
135    pub fn indices(&self) -> &[TensorIndex] {
136        &self.indices
137    }
138
139    // TODO: deduplicate code with tensorexpression
140    pub fn rank(&self) -> usize {
141        self.indices.len()
142    }
143
144    pub fn num_params(&self) -> usize {
145        self.variables.len()
146    }
147
148    // TODO: deduplicate code with tensorexpression
149    pub fn dimensions(&self) -> Vec<IndexSize> {
150        self.indices.iter().map(|idx| idx.index_size()).collect()
151    }
152
153    // TODO: deduplicate code with tensorexpression
154    pub fn tensor_strides(&self) -> Vec<usize> {
155        let mut strides = Vec::with_capacity(self.indices.len());
156        let mut current_stride = 1;
157        for &index in self.indices.iter().rev() {
158            strides.push(current_stride);
159            current_stride *= index.index_size();
160        }
161        strides.reverse();
162        strides
163    }
164
165    pub fn elements(&self) -> &[ComplexExpression] {
166        &self.expressions.original
167    }
168
169    // Simplify and differentiate to prepare expression to evaluate up to diff_level.
170    pub fn prepare(&mut self, diff_level: DifferentiationLevel) -> bool {
171        self.expressions.prepare(&self.variables, diff_level)
172    }
173
174    fn add_to_builder<'a, R: RealScalar>(
175        &'a self,
176        mut builder: ModuleBuilder<'a, R>,
177    ) -> ModuleBuilder<'a, R> {
178        if self.expressions.func.is_some() {
179            // println!("Adding {} function to module", self.name.clone() + "_" + "1");
180            let unit = CompilableUnit::new(
181                &(format!("expr_{}_{}", self.base_id, "1")),
182                self.expressions.func.as_ref().unwrap(),
183                self.variables.clone(),
184                self.expressions.original.len() * 2,
185            );
186
187            builder = builder.add_unit(unit);
188        }
189
190        if self.expressions.grad.is_some() {
191            // println!("Adding {} function to module", self.name.clone() + "_" + "2");
192            let unit = CompilableUnit::new(
193                &(format!("expr_{}_{}", self.base_id, "2")),
194                self.expressions.grad.as_ref().unwrap(),
195                self.variables.clone(),
196                self.expressions.original.len() * 2,
197            );
198
199            builder = builder.add_unit(unit);
200        }
201
202        if self.expressions.hess.is_some() {
203            // println!("Adding {} function to module", self.name.clone() + "_" + "3");
204            let unit = CompilableUnit::new(
205                &(format!("expr_{}_{}", self.base_id, "3")),
206                self.expressions.grad.as_ref().unwrap(),
207                self.variables.clone(),
208                self.expressions.original.len() * 2,
209            );
210
211            builder = builder.add_unit(unit);
212        }
213
214        builder
215    }
216
217    pub fn num_elements(&self) -> usize {
218        self.expressions.num_elements()
219    }
220
221    pub fn form_expression(&self) -> TensorExpression {
222        let named = NamedExpression::new(
223            self.name.clone(),
224            self.variables.clone(),
225            self.expressions.original.clone(),
226        );
227        TensorExpression::from_raw(self.indices.clone(), named)
228    }
229
230    pub fn form_modified_indices(
231        &self,
232        modifiers: &(Vec<usize>, GenerationShape),
233    ) -> Vec<TensorIndex> {
234        let perm_index_sizes: Vec<usize> = modifiers
235            .0
236            .iter()
237            .map(|p| self.indices[*p].index_size())
238            .collect();
239        let redirection = modifiers.1.calculate_directions(&perm_index_sizes);
240        perm_index_sizes
241            .iter()
242            .zip(redirection.iter())
243            .enumerate()
244            .map(|(id, (s, d))| TensorIndex::new(*d, id, *s))
245            .collect()
246    }
247
248    pub fn form_modified_expression(
249        &self,
250        modifiers: &(Vec<usize>, GenerationShape),
251    ) -> TensorExpression {
252        let perm_index_sizes: Vec<usize> = modifiers
253            .0
254            .iter()
255            .map(|p| self.indices[*p].index_size())
256            .collect();
257        let redirection = modifiers.1.calculate_directions(&perm_index_sizes);
258        let mut expression = self.form_expression();
259        let new_indices = perm_index_sizes
260            .iter()
261            .zip(redirection.iter())
262            .enumerate()
263            .map(|(id, (s, d))| TensorIndex::new(*d, id, *s))
264            .collect();
265        expression.permute(&modifiers.0, redirection);
266        expression.reindex(new_indices);
267        expression
268    }
269}
270
/// A deduplicating store of tensor expressions and their compiled evaluation
/// modules.
///
/// Base expressions are stored once; derived expressions (permutations and
/// reshapes of a base) are lightweight modifier entries that share the base's
/// data. Compiled modules are dropped whenever the cached expression set
/// changes and rebuilt on demand.
#[derive(Default)]
pub struct ExpressionCache {
    // Base expression id -> cached expression data (including derived variants).
    expressions: BTreeMap<ExpressionId, CachedTensorExpression>,
    id_lookup: BTreeMap<ExpressionId, ExpressionId>, // Maps derived expressions to their base id
    name_lookup: BTreeMap<String, Vec<ExpressionId>>, // Name to base expression ids
    // Compiled module for f32 evaluation; `None` until `compile::<f32>` runs.
    module32: Option<Module<f32>>,
    // Compiled module for f64 evaluation; `None` until `compile::<f64>` runs.
    module64: Option<Module<f64>>,
    // Monotonically increasing source of fresh expression ids.
    id_counter: ExpressionId,
}
280
281impl ExpressionCache {
282    pub fn new() -> Self {
283        Self {
284            expressions: BTreeMap::new(),
285            id_lookup: BTreeMap::new(),
286            name_lookup: BTreeMap::new(),
287            module32: None,
288            module64: None,
289            id_counter: 0,
290        }
291    }
292
293    pub fn new_shared() -> Arc<Mutex<Self>> {
294        Arc::new(Self::new().into())
295    }
296
297    pub fn get(&self, expr_id: ExpressionId) -> Option<TensorExpression> {
298        // TODO: do better.
299        let base_id = self
300            .id_lookup
301            .get(&expr_id)
302            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
303        let cexpr = self
304            .expressions
305            .get(base_id)
306            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
307        let modifiers = cexpr
308            .id_lookup
309            .get(&expr_id)
310            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
311        Some(cexpr.form_modified_expression(modifiers))
312    }
313
314    pub fn lookup(&self, expr: impl AsRef<NamedExpression>) -> Option<ExpressionId> {
315        let expr = expr.as_ref();
316        if let Some(ids) = self.name_lookup.get(expr.name()) {
317            for id in ids {
318                let cexpr = self
319                    .expressions
320                    .get(id)
321                    .expect("Expected since just looked up id.");
322                if expr == &cexpr.expressions.original {
323                    return Some(*id);
324                }
325            }
326        }
327        None
328    }
329
330    pub fn contains(&self, expr: impl AsRef<NamedExpression>) -> bool {
331        self.lookup(expr).is_some()
332    }
333
334    fn get_new_id(&mut self) -> ExpressionId {
335        let new_id = self.id_counter;
336        self.id_counter += 1;
337        new_id
338    }
339
340    fn uncompile(&mut self) {
341        self.module32 = None;
342        self.module64 = None;
343    }
344
345    pub fn remove(&mut self, expr_id: ExpressionId) {
346        let mut modified_flag = false;
347        if let Some(base_id) = self.id_lookup.get(&expr_id) {
348            // Only actually remove anything if its a base expression
349            // TODO: This might be incorrect as derived expressions might
350            // still exist. Maybe, just mark this one as removed, and remove
351            // it when it's id_lookup is empty?
352            if *base_id == expr_id {
353                modified_flag = true;
354                let cexpr = self.expressions.remove(base_id).unwrap();
355                self.id_lookup.remove(&expr_id);
356                let name_vec = self.name_lookup.get_mut(&cexpr.name).unwrap();
357                name_vec.swap_remove(name_vec.iter().position(|x| *x == expr_id).unwrap());
358            }
359        }
360
361        if modified_flag {
362            self.uncompile();
363        }
364    }
365
366    pub fn insert(&mut self, expr: impl Into<TensorExpression>) -> ExpressionId {
367        let expr: TensorExpression = expr.into();
368        if let Some(id) = self.lookup(&expr) {
369            return id;
370        }
371
372        self.uncompile();
373
374        let id = self.get_new_id();
375        self.name_lookup.insert(expr.name().to_owned(), vec![id]);
376        self.expressions
377            .insert(id, CachedTensorExpression::new(expr, id));
378        self.id_lookup.insert(id, id);
379        id
380    }
381
382    pub fn indices(&self, expr_id: ExpressionId) -> Vec<TensorIndex> {
383        let base_id = self
384            .id_lookup
385            .get(&expr_id)
386            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
387        let cexpr = self
388            .expressions
389            .get(base_id)
390            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
391        let modifiers = cexpr
392            .id_lookup
393            .get(&expr_id)
394            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
395        cexpr.form_modified_indices(modifiers)
396    }
397
398    pub fn num_elements(&self, expr_id: ExpressionId) -> usize {
399        let base_id = self
400            .id_lookup
401            .get(&expr_id)
402            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
403        let cexpr = self
404            .expressions
405            .get(base_id)
406            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
407        cexpr.num_elements()
408    }
409
410    pub fn generation_shape(&self, expr_id: ExpressionId) -> GenerationShape {
411        let base_id = self
412            .id_lookup
413            .get(&expr_id)
414            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
415        let cexpr = self
416            .expressions
417            .get(base_id)
418            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
419        let modifiers = cexpr
420            .id_lookup
421            .get(&expr_id)
422            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
423        modifiers.1
424    }
425
426    pub fn base_name(&self, expr_id: ExpressionId) -> String {
427        let base_id = self
428            .id_lookup
429            .get(&expr_id)
430            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
431        let cexpr = self
432            .expressions
433            .get(base_id)
434            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
435
436        cexpr.name.clone()
437    }
438
439    pub fn name(&self, expr_id: ExpressionId) -> String {
440        let base_id = self
441            .id_lookup
442            .get(&expr_id)
443            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
444        let cexpr = self
445            .expressions
446            .get(base_id)
447            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
448        let modifiers = cexpr
449            .id_lookup
450            .get(&expr_id)
451            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
452
453        let base_name = cexpr.name.clone();
454        let permutation = modifiers
455            .0
456            .iter()
457            .map(|&i| i.to_string())
458            .collect::<Vec<String>>()
459            .join("_");
460        let shape = modifiers
461            .1
462            .to_vec()
463            .iter()
464            .map(|&i| i.to_string())
465            .collect::<Vec<String>>()
466            .join("_");
467        format!("{base_name}_perm{permutation}_shape{shape}")
468    }
469
470    pub fn num_params(&self, expr_id: ExpressionId) -> usize {
471        let base_id = self
472            .id_lookup
473            .get(&expr_id)
474            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
475        let cexpr = self
476            .expressions
477            .get(base_id)
478            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
479        cexpr.num_params()
480    }
481
482    pub fn trace(&mut self, expr_id: ExpressionId, pairs: Vec<(usize, usize)>) -> ExpressionId {
483        // expr_id -> base_id -> base_expr -> derived_name -> base_name -> check if base_name_tracedXYZ exists
484        let old_name = self.name(expr_id);
485        let traced = pairs
486            .iter()
487            .map(|(x, y)| format!("{x}_{y}"))
488            .collect::<Vec<String>>()
489            .join("_");
490        let traced_name = format!("traced{traced}_{old_name}");
491
492        let base_id = self
493            .id_lookup
494            .get(&expr_id)
495            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
496        let cexpr = self
497            .expressions
498            .get(base_id)
499            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
500        let modifiers = cexpr
501            .id_lookup
502            .get(&expr_id)
503            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
504
505        let old_expr: TensorExpression = cexpr.form_modified_expression(modifiers);
506
507        let mut traced_expr = old_expr.partial_trace(&pairs);
508        traced_expr.set_name(traced_name.clone());
509
510        if let Some(ids) = self.name_lookup.get(&traced_name) {
511            for id in ids {
512                let cexpr_other = self
513                    .expressions
514                    .get(id)
515                    .unwrap_or_else(|| panic!("Failed to {id} in cache."));
516
517                if traced_expr.elements() == cexpr_other.elements() {
518                    return *id;
519                }
520            }
521        }
522
523        self.uncompile();
524        self.insert(traced_expr)
525    }
526
527    pub fn permute_reshape(
528        &mut self,
529        expr_id: ExpressionId,
530        perm: Vec<usize>,
531        reshape: GenerationShape,
532    ) -> ExpressionId {
533        let base_id = *self
534            .id_lookup
535            .get(&expr_id)
536            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
537        let cexpr = self
538            .expressions
539            .get(&base_id)
540            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
541        let modifiers = cexpr
542            .id_lookup
543            .get(&expr_id)
544            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
545
546        // assert reshape num elements is correct
547        let composed_perm: Vec<usize> = perm.iter().map(|&idx| modifiers.0[idx]).collect(); // TODO: check if correct ordering
548
549        let new_val = (composed_perm, reshape);
550        for (id, val) in cexpr.id_lookup.iter() {
551            if *val == new_val {
552                return *id;
553            }
554        }
555
556        let new_id = self.get_new_id();
557        let cexpr = self
558            .expressions
559            .get_mut(&base_id)
560            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
561        cexpr.id_lookup.insert(new_id, new_val);
562        self.id_lookup.insert(new_id, base_id);
563        new_id
564    }
565
566    fn _get_module_mut<R: RealScalar>(&mut self) -> &mut Option<Module<R>> {
567        if std::any::TypeId::of::<R>() == std::any::TypeId::of::<f32>() {
568            // Safety: Just checked exact type id
569            unsafe { std::mem::transmute(&mut self.module32) }
570        } else if std::any::TypeId::of::<R>() == std::any::TypeId::of::<f64>() {
571            // Safety: Just checked exact type id
572            unsafe { std::mem::transmute(&mut self.module64) }
573        } else {
574            unreachable!()
575        }
576    }
577
578    fn _get_module<R: RealScalar>(&self) -> &Option<Module<R>> {
579        if std::any::TypeId::of::<R>() == std::any::TypeId::of::<f32>() {
580            // Safety: Just checked exact type id
581            unsafe { std::mem::transmute(&self.module32) }
582        } else if std::any::TypeId::of::<R>() == std::any::TypeId::of::<f64>() {
583            // Safety: Just checked exact type id
584            unsafe { std::mem::transmute(&self.module64) }
585        } else {
586            unreachable!()
587        }
588    }
589
590    pub fn is_compiled<R: RealScalar>(&self) -> bool {
591        self._get_module::<R>().is_some()
592    }
593
594    // Simplify and differentiate to prepare expression to evaluate up to diff_level.
595    pub fn prepare(&mut self, diff_level: DifferentiationLevel) {
596        let mut should_uncompile = false;
597        for (_, cexpr) in self.expressions.iter_mut() {
598            if cexpr.prepare(diff_level) {
599                should_uncompile = true;
600            }
601        }
602
603        if should_uncompile {
604            self.uncompile();
605        }
606    }
607
608    pub fn compile<R: RealScalar>(&mut self) {
609        let mut module_builder: ModuleBuilder<R> = ModuleBuilder::new("cache");
610        // For each base expression
611        for (_, cexpr) in self.expressions.iter() {
612            module_builder = cexpr.add_to_builder(module_builder);
613        }
614        *self._get_module_mut() = Some(module_builder.build());
615    }
616
617    // TODO: explore a lock feature: compile and lock -> returns something that needs to be passed
618    // into get_fn; that provides guarantees that when the lock goes out of scope all the write
619    // functions are also out of scope.
620
621    pub fn get_fn<R: RealScalar>(
622        &mut self,
623        expr_id: ExpressionId,
624        diff_level: DifferentiationLevel,
625    ) -> WriteFunc<R> {
626        if !self.is_compiled::<R>() {
627            self.compile::<R>();
628        }
629
630        let module = self
631            ._get_module()
632            .as_ref()
633            .expect("Module should exist due to previous compilation.");
634
635        let base_id = self
636            .id_lookup
637            .get(&expr_id)
638            .expect("Unexpected expression id.");
639
640        // TODO: ensure diff_level function exists
641        module
642            .get_function(&format!("expr_{}_{}", base_id, diff_level))
643            .expect("Error retrieving compiled expression.")
644    }
645
646    pub fn get_output_map<R: RealScalar>(
647        &self,
648        expr_id: ExpressionId,
649        row_stride: u64,
650        col_stride: u64,
651        mat_stride: u64,
652    ) -> Vec<u64> {
653        let base_id = self
654            .id_lookup
655            .get(&expr_id)
656            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
657        let cexpr = self
658            .expressions
659            .get(base_id)
660            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
661        let modifiers = cexpr
662            .id_lookup
663            .get(&expr_id)
664            .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
665
666        let (perm, shape) = modifiers;
667
668        // TODO: deduplicate code with tensorexpression.permute; extract calc_elem_perm ->
669        // Vec<TensorIndex>
670        let original_strides = cexpr.tensor_strides();
671        let original_dimensions = cexpr.dimensions();
672
673        let new_strides = {
674            let mut strides = Vec::with_capacity(cexpr.rank());
675            let mut current_stride = 1;
676            for perm_index in perm.iter().map(|p| cexpr.indices()[*p]).rev() {
677                strides.push(current_stride);
678                current_stride *= perm_index.index_size();
679            }
680            strides.reverse();
681            strides
682        };
683
684        let mut complex_elem_perm = Vec::with_capacity(cexpr.num_elements());
685        for i in 0..cexpr.num_elements() {
686            let mut original_coordinate: Vec<usize> = Vec::with_capacity(cexpr.rank());
687            let mut temp_i = i;
688            for d_idx in 0..cexpr.rank() {
689                original_coordinate
690                    .push((temp_i / original_strides[d_idx]) % original_dimensions[d_idx]);
691                temp_i %= original_strides[d_idx]; // Update temp_i for next dimension
692            }
693
694            // Map original coordinate components to their new positions according to `perm`.
695            // If `perm[j]` is `k`, it means the `j`-th dimension in the new order
696            // corresponds to the `k`-th dimension in the original order.
697            let mut permuted_coordinate: Vec<usize> = vec![0; cexpr.rank()];
698            for j in 0..cexpr.rank() {
699                permuted_coordinate[j] = original_coordinate[perm[j]];
700            }
701
702            // Calculate new linear index using the permuted coordinate and new strides
703            let mut new_linear_idx = 0;
704            for d_idx in 0..cexpr.rank() {
705                new_linear_idx += permuted_coordinate[d_idx] * new_strides[d_idx];
706            }
707            complex_elem_perm.push(new_linear_idx);
708        }
709
710        let ncols = shape.ncols() as u64;
711        let nrows = shape.nrows() as u64;
712
713        let num_real_elements = cexpr.num_elements() * 2;
714        let mut map = Vec::with_capacity(num_real_elements);
715
716        for real_idx in 0..(num_real_elements as u64) {
717            let complex_idx = real_idx / 2;
718            let complex_perm_idx = complex_elem_perm[complex_idx as usize] as u64;
719            let imag_offset = real_idx % 2;
720            let mat_idx = complex_perm_idx / (nrows * ncols);
721            let row_idx = (complex_perm_idx % (nrows * ncols)) / ncols;
722            let col_idx = complex_perm_idx % ncols;
723            map.push(
724                2 * (mat_idx * mat_stride + row_idx * row_stride + col_idx * col_stride)
725                    + imag_offset,
726            )
727        }
728
729        map
730    }
731}