1use std::collections::BTreeMap;
2use std::sync::{Arc, Mutex};
3
4use qudit_core::RealScalar;
5
6use crate::TensorExpression;
7use crate::analysis::simplify_expressions_iter;
8use crate::codegen::CompilableUnit;
9use crate::expressions::{ExpressionBody, NamedExpression};
10use crate::index::{IndexSize, TensorIndex};
11use crate::{
12 ComplexExpression, DifferentiationLevel, Expression, GRADIENT, GenerationShape, HESSIAN,
13 Module, ModuleBuilder, WriteFunc,
14};
15
16pub type ExpressionId = u64;
17
18pub struct CachedExpressionBody {
19 original: ExpressionBody,
20 func: Option<Vec<Expression>>,
21 grad: Option<Vec<Expression>>,
22 hess: Option<Vec<Expression>>,
23}
24
25impl CachedExpressionBody {
26 pub fn new(original: impl Into<ExpressionBody>) -> CachedExpressionBody {
27 let original = original.into();
28 CachedExpressionBody {
29 original,
30 func: None,
31 grad: None,
32 hess: None,
33 }
34 }
35
36 pub fn num_elements(&self) -> usize {
37 self.original.num_elements()
38 }
39
40 pub fn prepare(&mut self, variables: &[String], diff_level: DifferentiationLevel) -> bool {
42 let mut has_changed = false;
43
44 if self.func.is_none() {
45 self.func = Some(simplify_expressions_iter(
46 self.original.iter().flat_map(|c| [&c.real, &c.imag]),
47 ));
48
49 has_changed = true;
50 }
51
52 if diff_level >= GRADIENT && self.grad.is_none() {
53 let mut grad_exprs = vec![];
54 for variable in variables {
55 for expr in self.original.iter() {
56 let grad_expr = expr.differentiate(variable);
57 grad_exprs.push(grad_expr);
58 }
59 }
60
61 self.grad = Some(simplify_expressions_iter(
62 self.original
63 .iter()
64 .chain(grad_exprs.iter())
65 .flat_map(|c| [&c.real, &c.imag]),
66 ));
67
68 has_changed = true;
69 }
70
71 if diff_level >= HESSIAN && self.hess.is_some() {
72 let mut grad_exprs = vec![];
73 for variable in variables {
74 for expr in self.original.iter() {
75 let grad_expr = expr.differentiate(variable);
76 grad_exprs.push(grad_expr);
77 }
78 }
79
80 let mut hess_exprs = vec![];
81 for variable in variables {
82 for expr in grad_exprs.iter() {
83 let hess_expr = expr.differentiate(variable);
84 hess_exprs.push(hess_expr);
85 }
86 }
87
88 self.hess = Some(simplify_expressions_iter(
89 self.original
90 .iter()
91 .chain(grad_exprs.iter())
92 .chain(hess_exprs.iter())
93 .flat_map(|c| [&c.real, &c.imag]),
94 ));
95
96 has_changed = true;
97 }
98
99 has_changed
100 }
101}
102
103impl<B: Into<ExpressionBody>> From<B> for CachedExpressionBody {
104 fn from(value: B) -> Self {
105 CachedExpressionBody::new(value)
106 }
107}
108
109pub struct CachedTensorExpression {
110 name: String,
111 variables: Vec<String>,
112 indices: Vec<TensorIndex>,
113 expressions: CachedExpressionBody,
114 id_lookup: BTreeMap<ExpressionId, (Vec<usize>, GenerationShape)>,
115 base_id: ExpressionId,
116}
117
118impl CachedTensorExpression {
119 pub fn new<E: Into<TensorExpression>>(expr: E, base_id: ExpressionId) -> Self {
120 let expr = expr.into();
121 let mut id_lookup = BTreeMap::new();
122 id_lookup.insert(base_id, ((0..expr.rank()).collect(), expr.indices().into()));
123 let (name, variables, body, indices) = expr.destruct();
124 CachedTensorExpression {
125 name,
126 variables,
127 indices,
128 expressions: body.into(),
129 id_lookup,
130 base_id,
131 }
132 }
133
134 pub fn indices(&self) -> &[TensorIndex] {
136 &self.indices
137 }
138
139 pub fn rank(&self) -> usize {
141 self.indices.len()
142 }
143
144 pub fn num_params(&self) -> usize {
145 self.variables.len()
146 }
147
148 pub fn dimensions(&self) -> Vec<IndexSize> {
150 self.indices.iter().map(|idx| idx.index_size()).collect()
151 }
152
153 pub fn tensor_strides(&self) -> Vec<usize> {
155 let mut strides = Vec::with_capacity(self.indices.len());
156 let mut current_stride = 1;
157 for &index in self.indices.iter().rev() {
158 strides.push(current_stride);
159 current_stride *= index.index_size();
160 }
161 strides.reverse();
162 strides
163 }
164
165 pub fn elements(&self) -> &[ComplexExpression] {
166 &self.expressions.original
167 }
168
169 pub fn prepare(&mut self, diff_level: DifferentiationLevel) -> bool {
171 self.expressions.prepare(&self.variables, diff_level)
172 }
173
174 fn add_to_builder<'a, R: RealScalar>(
175 &'a self,
176 mut builder: ModuleBuilder<'a, R>,
177 ) -> ModuleBuilder<'a, R> {
178 if self.expressions.func.is_some() {
179 let unit = CompilableUnit::new(
181 &(format!("expr_{}_{}", self.base_id, "1")),
182 self.expressions.func.as_ref().unwrap(),
183 self.variables.clone(),
184 self.expressions.original.len() * 2,
185 );
186
187 builder = builder.add_unit(unit);
188 }
189
190 if self.expressions.grad.is_some() {
191 let unit = CompilableUnit::new(
193 &(format!("expr_{}_{}", self.base_id, "2")),
194 self.expressions.grad.as_ref().unwrap(),
195 self.variables.clone(),
196 self.expressions.original.len() * 2,
197 );
198
199 builder = builder.add_unit(unit);
200 }
201
202 if self.expressions.hess.is_some() {
203 let unit = CompilableUnit::new(
205 &(format!("expr_{}_{}", self.base_id, "3")),
206 self.expressions.grad.as_ref().unwrap(),
207 self.variables.clone(),
208 self.expressions.original.len() * 2,
209 );
210
211 builder = builder.add_unit(unit);
212 }
213
214 builder
215 }
216
217 pub fn num_elements(&self) -> usize {
218 self.expressions.num_elements()
219 }
220
221 pub fn form_expression(&self) -> TensorExpression {
222 let named = NamedExpression::new(
223 self.name.clone(),
224 self.variables.clone(),
225 self.expressions.original.clone(),
226 );
227 TensorExpression::from_raw(self.indices.clone(), named)
228 }
229
230 pub fn form_modified_indices(
231 &self,
232 modifiers: &(Vec<usize>, GenerationShape),
233 ) -> Vec<TensorIndex> {
234 let perm_index_sizes: Vec<usize> = modifiers
235 .0
236 .iter()
237 .map(|p| self.indices[*p].index_size())
238 .collect();
239 let redirection = modifiers.1.calculate_directions(&perm_index_sizes);
240 perm_index_sizes
241 .iter()
242 .zip(redirection.iter())
243 .enumerate()
244 .map(|(id, (s, d))| TensorIndex::new(*d, id, *s))
245 .collect()
246 }
247
248 pub fn form_modified_expression(
249 &self,
250 modifiers: &(Vec<usize>, GenerationShape),
251 ) -> TensorExpression {
252 let perm_index_sizes: Vec<usize> = modifiers
253 .0
254 .iter()
255 .map(|p| self.indices[*p].index_size())
256 .collect();
257 let redirection = modifiers.1.calculate_directions(&perm_index_sizes);
258 let mut expression = self.form_expression();
259 let new_indices = perm_index_sizes
260 .iter()
261 .zip(redirection.iter())
262 .enumerate()
263 .map(|(id, (s, d))| TensorIndex::new(*d, id, *s))
264 .collect();
265 expression.permute(&modifiers.0, redirection);
266 expression.reindex(new_indices);
267 expression
268 }
269}
270
271#[derive(Default)]
272pub struct ExpressionCache {
273 expressions: BTreeMap<ExpressionId, CachedTensorExpression>,
274 id_lookup: BTreeMap<ExpressionId, ExpressionId>, name_lookup: BTreeMap<String, Vec<ExpressionId>>, module32: Option<Module<f32>>,
277 module64: Option<Module<f64>>,
278 id_counter: ExpressionId,
279}
280
281impl ExpressionCache {
282 pub fn new() -> Self {
283 Self {
284 expressions: BTreeMap::new(),
285 id_lookup: BTreeMap::new(),
286 name_lookup: BTreeMap::new(),
287 module32: None,
288 module64: None,
289 id_counter: 0,
290 }
291 }
292
293 pub fn new_shared() -> Arc<Mutex<Self>> {
294 Arc::new(Self::new().into())
295 }
296
297 pub fn get(&self, expr_id: ExpressionId) -> Option<TensorExpression> {
298 let base_id = self
300 .id_lookup
301 .get(&expr_id)
302 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
303 let cexpr = self
304 .expressions
305 .get(base_id)
306 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
307 let modifiers = cexpr
308 .id_lookup
309 .get(&expr_id)
310 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
311 Some(cexpr.form_modified_expression(modifiers))
312 }
313
314 pub fn lookup(&self, expr: impl AsRef<NamedExpression>) -> Option<ExpressionId> {
315 let expr = expr.as_ref();
316 if let Some(ids) = self.name_lookup.get(expr.name()) {
317 for id in ids {
318 let cexpr = self
319 .expressions
320 .get(id)
321 .expect("Expected since just looked up id.");
322 if expr == &cexpr.expressions.original {
323 return Some(*id);
324 }
325 }
326 }
327 None
328 }
329
330 pub fn contains(&self, expr: impl AsRef<NamedExpression>) -> bool {
331 self.lookup(expr).is_some()
332 }
333
334 fn get_new_id(&mut self) -> ExpressionId {
335 let new_id = self.id_counter;
336 self.id_counter += 1;
337 new_id
338 }
339
340 fn uncompile(&mut self) {
341 self.module32 = None;
342 self.module64 = None;
343 }
344
345 pub fn remove(&mut self, expr_id: ExpressionId) {
346 let mut modified_flag = false;
347 if let Some(base_id) = self.id_lookup.get(&expr_id) {
348 if *base_id == expr_id {
353 modified_flag = true;
354 let cexpr = self.expressions.remove(base_id).unwrap();
355 self.id_lookup.remove(&expr_id);
356 let name_vec = self.name_lookup.get_mut(&cexpr.name).unwrap();
357 name_vec.swap_remove(name_vec.iter().position(|x| *x == expr_id).unwrap());
358 }
359 }
360
361 if modified_flag {
362 self.uncompile();
363 }
364 }
365
366 pub fn insert(&mut self, expr: impl Into<TensorExpression>) -> ExpressionId {
367 let expr: TensorExpression = expr.into();
368 if let Some(id) = self.lookup(&expr) {
369 return id;
370 }
371
372 self.uncompile();
373
374 let id = self.get_new_id();
375 self.name_lookup.insert(expr.name().to_owned(), vec![id]);
376 self.expressions
377 .insert(id, CachedTensorExpression::new(expr, id));
378 self.id_lookup.insert(id, id);
379 id
380 }
381
382 pub fn indices(&self, expr_id: ExpressionId) -> Vec<TensorIndex> {
383 let base_id = self
384 .id_lookup
385 .get(&expr_id)
386 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
387 let cexpr = self
388 .expressions
389 .get(base_id)
390 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
391 let modifiers = cexpr
392 .id_lookup
393 .get(&expr_id)
394 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
395 cexpr.form_modified_indices(modifiers)
396 }
397
398 pub fn num_elements(&self, expr_id: ExpressionId) -> usize {
399 let base_id = self
400 .id_lookup
401 .get(&expr_id)
402 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
403 let cexpr = self
404 .expressions
405 .get(base_id)
406 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
407 cexpr.num_elements()
408 }
409
410 pub fn generation_shape(&self, expr_id: ExpressionId) -> GenerationShape {
411 let base_id = self
412 .id_lookup
413 .get(&expr_id)
414 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
415 let cexpr = self
416 .expressions
417 .get(base_id)
418 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
419 let modifiers = cexpr
420 .id_lookup
421 .get(&expr_id)
422 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
423 modifiers.1
424 }
425
426 pub fn base_name(&self, expr_id: ExpressionId) -> String {
427 let base_id = self
428 .id_lookup
429 .get(&expr_id)
430 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
431 let cexpr = self
432 .expressions
433 .get(base_id)
434 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
435
436 cexpr.name.clone()
437 }
438
439 pub fn name(&self, expr_id: ExpressionId) -> String {
440 let base_id = self
441 .id_lookup
442 .get(&expr_id)
443 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
444 let cexpr = self
445 .expressions
446 .get(base_id)
447 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
448 let modifiers = cexpr
449 .id_lookup
450 .get(&expr_id)
451 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
452
453 let base_name = cexpr.name.clone();
454 let permutation = modifiers
455 .0
456 .iter()
457 .map(|&i| i.to_string())
458 .collect::<Vec<String>>()
459 .join("_");
460 let shape = modifiers
461 .1
462 .to_vec()
463 .iter()
464 .map(|&i| i.to_string())
465 .collect::<Vec<String>>()
466 .join("_");
467 format!("{base_name}_perm{permutation}_shape{shape}")
468 }
469
470 pub fn num_params(&self, expr_id: ExpressionId) -> usize {
471 let base_id = self
472 .id_lookup
473 .get(&expr_id)
474 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
475 let cexpr = self
476 .expressions
477 .get(base_id)
478 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
479 cexpr.num_params()
480 }
481
482 pub fn trace(&mut self, expr_id: ExpressionId, pairs: Vec<(usize, usize)>) -> ExpressionId {
483 let old_name = self.name(expr_id);
485 let traced = pairs
486 .iter()
487 .map(|(x, y)| format!("{x}_{y}"))
488 .collect::<Vec<String>>()
489 .join("_");
490 let traced_name = format!("traced{traced}_{old_name}");
491
492 let base_id = self
493 .id_lookup
494 .get(&expr_id)
495 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
496 let cexpr = self
497 .expressions
498 .get(base_id)
499 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
500 let modifiers = cexpr
501 .id_lookup
502 .get(&expr_id)
503 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
504
505 let old_expr: TensorExpression = cexpr.form_modified_expression(modifiers);
506
507 let mut traced_expr = old_expr.partial_trace(&pairs);
508 traced_expr.set_name(traced_name.clone());
509
510 if let Some(ids) = self.name_lookup.get(&traced_name) {
511 for id in ids {
512 let cexpr_other = self
513 .expressions
514 .get(id)
515 .unwrap_or_else(|| panic!("Failed to {id} in cache."));
516
517 if traced_expr.elements() == cexpr_other.elements() {
518 return *id;
519 }
520 }
521 }
522
523 self.uncompile();
524 self.insert(traced_expr)
525 }
526
527 pub fn permute_reshape(
528 &mut self,
529 expr_id: ExpressionId,
530 perm: Vec<usize>,
531 reshape: GenerationShape,
532 ) -> ExpressionId {
533 let base_id = *self
534 .id_lookup
535 .get(&expr_id)
536 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
537 let cexpr = self
538 .expressions
539 .get(&base_id)
540 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
541 let modifiers = cexpr
542 .id_lookup
543 .get(&expr_id)
544 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
545
546 let composed_perm: Vec<usize> = perm.iter().map(|&idx| modifiers.0[idx]).collect(); let new_val = (composed_perm, reshape);
550 for (id, val) in cexpr.id_lookup.iter() {
551 if *val == new_val {
552 return *id;
553 }
554 }
555
556 let new_id = self.get_new_id();
557 let cexpr = self
558 .expressions
559 .get_mut(&base_id)
560 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
561 cexpr.id_lookup.insert(new_id, new_val);
562 self.id_lookup.insert(new_id, base_id);
563 new_id
564 }
565
566 fn _get_module_mut<R: RealScalar>(&mut self) -> &mut Option<Module<R>> {
567 if std::any::TypeId::of::<R>() == std::any::TypeId::of::<f32>() {
568 unsafe { std::mem::transmute(&mut self.module32) }
570 } else if std::any::TypeId::of::<R>() == std::any::TypeId::of::<f64>() {
571 unsafe { std::mem::transmute(&mut self.module64) }
573 } else {
574 unreachable!()
575 }
576 }
577
578 fn _get_module<R: RealScalar>(&self) -> &Option<Module<R>> {
579 if std::any::TypeId::of::<R>() == std::any::TypeId::of::<f32>() {
580 unsafe { std::mem::transmute(&self.module32) }
582 } else if std::any::TypeId::of::<R>() == std::any::TypeId::of::<f64>() {
583 unsafe { std::mem::transmute(&self.module64) }
585 } else {
586 unreachable!()
587 }
588 }
589
590 pub fn is_compiled<R: RealScalar>(&self) -> bool {
591 self._get_module::<R>().is_some()
592 }
593
594 pub fn prepare(&mut self, diff_level: DifferentiationLevel) {
596 let mut should_uncompile = false;
597 for (_, cexpr) in self.expressions.iter_mut() {
598 if cexpr.prepare(diff_level) {
599 should_uncompile = true;
600 }
601 }
602
603 if should_uncompile {
604 self.uncompile();
605 }
606 }
607
608 pub fn compile<R: RealScalar>(&mut self) {
609 let mut module_builder: ModuleBuilder<R> = ModuleBuilder::new("cache");
610 for (_, cexpr) in self.expressions.iter() {
612 module_builder = cexpr.add_to_builder(module_builder);
613 }
614 *self._get_module_mut() = Some(module_builder.build());
615 }
616
617 pub fn get_fn<R: RealScalar>(
622 &mut self,
623 expr_id: ExpressionId,
624 diff_level: DifferentiationLevel,
625 ) -> WriteFunc<R> {
626 if !self.is_compiled::<R>() {
627 self.compile::<R>();
628 }
629
630 let module = self
631 ._get_module()
632 .as_ref()
633 .expect("Module should exist due to previous compilation.");
634
635 let base_id = self
636 .id_lookup
637 .get(&expr_id)
638 .expect("Unexpected expression id.");
639
640 module
642 .get_function(&format!("expr_{}_{}", base_id, diff_level))
643 .expect("Error retrieving compiled expression.")
644 }
645
646 pub fn get_output_map<R: RealScalar>(
647 &self,
648 expr_id: ExpressionId,
649 row_stride: u64,
650 col_stride: u64,
651 mat_stride: u64,
652 ) -> Vec<u64> {
653 let base_id = self
654 .id_lookup
655 .get(&expr_id)
656 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
657 let cexpr = self
658 .expressions
659 .get(base_id)
660 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
661 let modifiers = cexpr
662 .id_lookup
663 .get(&expr_id)
664 .unwrap_or_else(|| panic!("Failed to {expr_id} in cache."));
665
666 let (perm, shape) = modifiers;
667
668 let original_strides = cexpr.tensor_strides();
671 let original_dimensions = cexpr.dimensions();
672
673 let new_strides = {
674 let mut strides = Vec::with_capacity(cexpr.rank());
675 let mut current_stride = 1;
676 for perm_index in perm.iter().map(|p| cexpr.indices()[*p]).rev() {
677 strides.push(current_stride);
678 current_stride *= perm_index.index_size();
679 }
680 strides.reverse();
681 strides
682 };
683
684 let mut complex_elem_perm = Vec::with_capacity(cexpr.num_elements());
685 for i in 0..cexpr.num_elements() {
686 let mut original_coordinate: Vec<usize> = Vec::with_capacity(cexpr.rank());
687 let mut temp_i = i;
688 for d_idx in 0..cexpr.rank() {
689 original_coordinate
690 .push((temp_i / original_strides[d_idx]) % original_dimensions[d_idx]);
691 temp_i %= original_strides[d_idx]; }
693
694 let mut permuted_coordinate: Vec<usize> = vec![0; cexpr.rank()];
698 for j in 0..cexpr.rank() {
699 permuted_coordinate[j] = original_coordinate[perm[j]];
700 }
701
702 let mut new_linear_idx = 0;
704 for d_idx in 0..cexpr.rank() {
705 new_linear_idx += permuted_coordinate[d_idx] * new_strides[d_idx];
706 }
707 complex_elem_perm.push(new_linear_idx);
708 }
709
710 let ncols = shape.ncols() as u64;
711 let nrows = shape.nrows() as u64;
712
713 let num_real_elements = cexpr.num_elements() * 2;
714 let mut map = Vec::with_capacity(num_real_elements);
715
716 for real_idx in 0..(num_real_elements as u64) {
717 let complex_idx = real_idx / 2;
718 let complex_perm_idx = complex_elem_perm[complex_idx as usize] as u64;
719 let imag_offset = real_idx % 2;
720 let mat_idx = complex_perm_idx / (nrows * ncols);
721 let row_idx = (complex_perm_idx % (nrows * ncols)) / ncols;
722 let col_idx = complex_perm_idx % ncols;
723 map.push(
724 2 * (mat_idx * mat_stride + row_idx * row_stride + col_idx * col_stride)
725 + imag_offset,
726 )
727 }
728
729 map
730 }
731}