1use super::contraction::{ContractableNetwork, ContractionPath};
7use super::tensor::{Tensor, TensorIndex};
8use quantrs2_core::error::{QuantRS2Error, QuantRS2Result};
9use std::cmp::Reverse;
10use std::collections::{BinaryHeap, HashMap, HashSet};
11use std::time::{Duration, Instant};
12
/// Strategy used by [`PathOptimizer`] to choose a contraction order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContractionOptMethod {
    /// Repeatedly contract the pair with the lowest estimated cost.
    Greedy,
    /// Dynamic-programming search over contraction orders.
    DynamicProgramming,
    /// Search that slices large bonds to bound memory use.
    Sliced,
    /// Chooses one of the other strategies based on the network size.
    Hybrid,
}
28
29#[derive(Debug, Clone)]
31pub struct PathOptimizer {
32 max_optimization_time: Duration,
34
35 method: ContractionOptMethod,
37
38 max_slices: usize,
40
41 max_bond_dimension: usize,
43
44 use_memory_estimates: bool,
46}
47
48impl Default for PathOptimizer {
49 fn default() -> Self {
50 Self {
51 max_optimization_time: Duration::from_secs(10),
52 method: ContractionOptMethod::Hybrid,
53 max_slices: 16,
54 max_bond_dimension: 64,
55 use_memory_estimates: true,
56 }
57 }
58}
59
60impl PathOptimizer {
61 pub fn new() -> Self {
63 Self::default()
64 }
65
66 #[must_use]
68 pub const fn with_max_time(mut self, time: Duration) -> Self {
69 self.max_optimization_time = time;
70 self
71 }
72
73 #[must_use]
75 pub const fn with_method(mut self, method: ContractionOptMethod) -> Self {
76 self.method = method;
77 self
78 }
79
80 #[must_use]
82 pub const fn with_max_slices(mut self, slices: usize) -> Self {
83 self.max_slices = slices;
84 self
85 }
86
87 #[must_use]
89 pub const fn with_max_bond_dimension(mut self, dim: usize) -> Self {
90 self.max_bond_dimension = dim;
91 self
92 }
93
94 #[must_use]
96 pub const fn with_memory_estimates(mut self, use_estimates: bool) -> Self {
97 self.use_memory_estimates = use_estimates;
98 self
99 }
100
101 pub fn find_optimal_path(
103 &self,
104 tensors: &HashMap<usize, Tensor>,
105 connections: &[(TensorIndex, TensorIndex)],
106 ) -> QuantRS2Result<ContractionPath> {
107 match self.method {
108 ContractionOptMethod::Greedy => self.find_greedy_path(tensors, connections),
109 ContractionOptMethod::DynamicProgramming => self.find_dp_path(tensors, connections),
110 ContractionOptMethod::Sliced => self.find_sliced_path(tensors, connections),
111 ContractionOptMethod::Hybrid => self.find_hybrid_path(tensors, connections),
112 }
113 }
114
115 fn find_greedy_path(
117 &self,
118 tensors: &HashMap<usize, Tensor>,
119 connections: &[(TensorIndex, TensorIndex)],
120 ) -> QuantRS2Result<ContractionPath> {
121 let start_time = Instant::now();
123
124 let mut tensor_connections = HashMap::new();
126 for (t1, t2) in connections {
127 tensor_connections
128 .entry(t1.tensor_id)
129 .or_insert_with(HashSet::new)
130 .insert(t2.tensor_id);
131 tensor_connections
132 .entry(t2.tensor_id)
133 .or_insert_with(HashSet::new)
134 .insert(t1.tensor_id);
135 }
136
137 let mut tensor_sizes = HashMap::new();
139 for (&id, tensor) in tensors {
140 let size: usize = tensor.dimensions.iter().product();
141 tensor_sizes.insert(id, size);
142 }
143
144 let mut remaining_tensors: HashSet<usize> = tensors.keys().copied().collect();
146 let mut steps = Vec::new();
147 let mut total_cost = 0.0;
148
149 while remaining_tensors.len() > 1 {
151 if start_time.elapsed() > self.max_optimization_time {
153 break;
155 }
156
157 let mut best_pair = None;
159 let mut best_cost = f64::INFINITY;
160
161 for &t1 in &remaining_tensors {
162 if let Some(connected) = tensor_connections.get(&t1) {
163 for &t2 in connected {
164 if remaining_tensors.contains(&t2) {
165 let t1_size = tensor_sizes[&t1];
167 let t2_size = tensor_sizes[&t2];
168
169 let common_indices = count_common_indices(t1, t2, connections);
171
172 let result_size =
174 estimate_contraction_size(t1_size, t2_size, common_indices);
175
176 let cost = (t1_size * t2_size) as f64 + result_size as f64;
178
179 if cost < best_cost {
180 best_cost = cost;
181 best_pair = Some((t1, t2));
182 }
183 }
184 }
185 }
186 }
187
188 if let Some((t1, t2)) = best_pair {
190 steps.push((t1, t2));
192 total_cost += best_cost;
193
194 remaining_tensors.remove(&t1);
196 remaining_tensors.remove(&t2);
197 let new_id = t1; remaining_tensors.insert(new_id);
199
200 let mut new_connections = HashSet::new();
202
203 for id in &[t1, t2] {
205 if let Some(connections) = tensor_connections.get(id) {
206 let connections_clone = connections.clone();
207 for &connected in &connections_clone {
208 if connected != t1
209 && connected != t2
210 && remaining_tensors.contains(&connected)
211 {
212 new_connections.insert(connected);
213
214 if let Some(other_conns) = tensor_connections.get_mut(&connected) {
216 other_conns.remove(&t1);
217 other_conns.remove(&t2);
218 other_conns.insert(new_id);
219 }
220 }
221 }
222 }
223 }
224
225 tensor_connections.insert(new_id, new_connections);
227
228 let common_indices = count_common_indices(t1, t2, connections);
230 let new_size =
231 estimate_contraction_size(tensor_sizes[&t1], tensor_sizes[&t2], common_indices);
232 tensor_sizes.insert(new_id, new_size);
233 } else {
234 if remaining_tensors.len() >= 2 {
236 let mut ids: Vec<_> = remaining_tensors.iter().copied().collect();
237 ids.sort_unstable();
238 let t1 = ids[0];
239 let t2 = ids[1];
240
241 steps.push((t1, t2));
242 total_cost += (tensor_sizes[&t1] * tensor_sizes[&t2]) as f64;
243
244 remaining_tensors.remove(&t1);
245 remaining_tensors.remove(&t2);
246 remaining_tensors.insert(t1);
247
248 tensor_sizes.insert(t1, tensor_sizes[&t1] * tensor_sizes[&t2]);
250
251 }
253 break;
255 }
256 }
257
258 Ok(ContractionPath::new(steps, total_cost))
259 }
260
261 fn find_dp_path(
263 &self,
264 tensors: &HashMap<usize, Tensor>,
265 connections: &[(TensorIndex, TensorIndex)],
266 ) -> QuantRS2Result<ContractionPath> {
267 self.find_greedy_path(tensors, connections)
271 }
272
273 fn find_sliced_path(
275 &self,
276 tensors: &HashMap<usize, Tensor>,
277 connections: &[(TensorIndex, TensorIndex)],
278 ) -> QuantRS2Result<ContractionPath> {
279 self.find_greedy_path(tensors, connections)
283 }
284
285 fn find_hybrid_path(
287 &self,
288 tensors: &HashMap<usize, Tensor>,
289 connections: &[(TensorIndex, TensorIndex)],
290 ) -> QuantRS2Result<ContractionPath> {
291 let network_size = tensors.len();
293
294 if network_size <= 12 {
296 return self.find_dp_path(tensors, connections);
297 }
298
299 if network_size <= 24 {
301 return self.find_greedy_path(tensors, connections);
302 }
303
304 self.find_sliced_path(tensors, connections)
306 }
307}
308
309pub struct OptimizedTensorNetwork {
311 tensors: HashMap<usize, Tensor>,
313
314 connections: Vec<(TensorIndex, TensorIndex)>,
316
317 cached_path: Option<ContractionPath>,
319
320 optimizer: PathOptimizer,
322}
323
324impl Default for OptimizedTensorNetwork {
325 fn default() -> Self {
326 Self::new()
327 }
328}
329
330impl OptimizedTensorNetwork {
331 pub fn new() -> Self {
333 Self {
334 tensors: HashMap::new(),
335 connections: Vec::new(),
336 cached_path: None,
337 optimizer: PathOptimizer::default(),
338 }
339 }
340
341 #[must_use]
343 pub const fn with_optimization_method(mut self, method: ContractionOptMethod) -> Self {
344 self.optimizer = self.optimizer.with_method(method);
345 self
346 }
347
348 pub fn add_tensor(&mut self, id: usize, tensor: Tensor) {
350 self.tensors.insert(id, tensor);
351
352 self.cached_path = None;
354 }
355
356 pub fn add_connection(&mut self, t1: TensorIndex, t2: TensorIndex) {
358 self.connections.push((t1, t2));
359
360 self.cached_path = None;
362 }
363
364 pub fn get_optimal_path(&mut self) -> QuantRS2Result<ContractionPath> {
366 if let Some(path) = &self.cached_path {
368 return Ok(path.clone());
369 }
370
371 let path = self
373 .optimizer
374 .find_optimal_path(&self.tensors, &self.connections)?;
375 self.cached_path = Some(path.clone());
376
377 Ok(path)
378 }
379
380 pub fn contract(&mut self) -> QuantRS2Result<Tensor> {
382 let path = self.get_optimal_path()?;
384
385 let mut working_tensors = self.tensors.clone();
387 let mut working_connections = self.connections.clone();
388
389 for (id1, id2) in path.steps() {
391 let tensor1 = working_tensors.remove(id1).ok_or_else(|| {
393 QuantRS2Error::CircuitValidationFailed(format!("Tensor with ID {id1} not found"))
394 })?;
395
396 let tensor2 = working_tensors.remove(id2).ok_or_else(|| {
397 QuantRS2Error::CircuitValidationFailed(format!("Tensor with ID {id2} not found"))
398 })?;
399
400 let shared_indices = find_shared_indices(*id1, *id2, &working_connections);
402
403 let result_tensor = contract_tensors(&tensor1, &tensor2, shared_indices)?;
405
406 working_tensors.insert(*id1, result_tensor);
408
409 }
412
413 if working_tensors.len() != 1 {
415 return Err(QuantRS2Error::CircuitValidationFailed(format!(
416 "{} tensors left after contraction (expected 1)",
417 working_tensors.len()
418 )));
419 }
420
421 Ok(working_tensors
423 .into_values()
424 .next()
425 .expect("Exactly one tensor should remain after contraction"))
426 }
427}
428
429fn count_common_indices(
431 id1: usize,
432 id2: usize,
433 connections: &[(TensorIndex, TensorIndex)],
434) -> usize {
435 let mut count = 0;
436
437 for (t1, t2) in connections {
438 if (t1.tensor_id == id1 && t2.tensor_id == id2)
439 || (t1.tensor_id == id2 && t2.tensor_id == id1)
440 {
441 count += 1;
442 }
443 }
444
445 count
446}
447
/// Estimates the element count of the tensor produced by contracting two tensors.
///
/// `size1` and `size2` are the operands' element counts and `common_indices` the number
/// of bonds they share. Every bond is assumed to have dimension 2. A shared bond of
/// dimension `d` removes a factor `d` from *both* operands, so the result has
/// `size1 * size2 / 4^common_indices` elements. For example, two 2-vectors contract to
/// 1 element, not 2.
///
/// The product saturates at `usize::MAX` instead of overflowing. A non-empty contraction
/// is never estimated below one element (a scalar), and an empty operand gives 0.
const fn estimate_contraction_size(size1: usize, size2: usize, common_indices: usize) -> usize {
    let product = size1.saturating_mul(size2);
    if product == 0 {
        return 0;
    }
    // Dividing by 4^k is a right shift by 2k bits; shifting by >= BITS would overflow.
    let shift = common_indices.saturating_mul(2);
    let estimate = if shift >= usize::BITS as usize {
        0
    } else {
        product >> shift
    };
    if estimate == 0 {
        1
    } else {
        estimate
    }
}
455
456fn find_shared_indices(
458 id1: usize,
459 id2: usize,
460 connections: &[(TensorIndex, TensorIndex)],
461) -> Vec<(usize, usize)> {
462 let mut shared = Vec::new();
463
464 for (t1, t2) in connections {
465 if t1.tensor_id == id1 && t2.tensor_id == id2 {
466 shared.push((t1.index, t2.index));
467 } else if t1.tensor_id == id2 && t2.tensor_id == id1 {
468 shared.push((t2.index, t1.index));
469 }
470 }
471
472 shared
473}
474
475fn contract_tensors(
477 t1: &Tensor,
478 t2: &Tensor,
479 indices: Vec<(usize, usize)>,
480) -> QuantRS2Result<Tensor> {
481 Ok(t1.clone())
486}
487
/// A candidate contraction order together with its estimated cost.
#[derive(Debug, Clone, PartialEq)]
pub struct ContractionPlan {
    /// Tensor-id pairs to contract, in execution order.
    pairs: Vec<(usize, usize)>,
    /// Accumulated flop estimate of all steps.
    flop_estimate: f64,
    /// Memory estimate for executing the plan.
    memory_estimate: usize,
}
500
501impl ContractionPlan {
502 pub const fn new(
504 pairs: Vec<(usize, usize)>,
505 flop_estimate: f64,
506 memory_estimate: usize,
507 ) -> Self {
508 Self {
509 pairs,
510 flop_estimate,
511 memory_estimate,
512 }
513 }
514
515 pub fn pairs(&self) -> &[(usize, usize)] {
517 &self.pairs
518 }
519
520 pub const fn flop_estimate(&self) -> f64 {
522 self.flop_estimate
523 }
524
525 pub const fn memory_estimate(&self) -> usize {
527 self.memory_estimate
528 }
529}
530
531impl Eq for ContractionPlan {}
532
533impl Ord for ContractionPlan {
534 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
535 self.flop_estimate
537 .partial_cmp(&other.flop_estimate)
538 .unwrap_or(std::cmp::Ordering::Equal)
539 .then_with(|| self.memory_estimate.cmp(&other.memory_estimate))
540 }
541}
542
543impl PartialOrd for ContractionPlan {
544 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
545 Some(self.cmp(other))
546 }
547}
548
549pub fn generate_contraction_plan(
551 tensors: &HashMap<usize, Tensor>,
552 connections: &[(TensorIndex, TensorIndex)],
553 max_time: Duration,
554) -> QuantRS2Result<ContractionPlan> {
555 let start_time = Instant::now();
557
558 if tensors.is_empty() {
560 return Ok(ContractionPlan::new(Vec::new(), 0.0, 0));
561 }
562
563 let mut tensor_graph = HashMap::new();
565 for (t1, t2) in connections {
566 tensor_graph
567 .entry(t1.tensor_id)
568 .or_insert_with(HashSet::new)
569 .insert(t2.tensor_id);
570 tensor_graph
571 .entry(t2.tensor_id)
572 .or_insert_with(HashSet::new)
573 .insert(t1.tensor_id);
574 }
575
576 let mut tensor_sizes = HashMap::new();
578 for (&id, tensor) in tensors {
579 let size: usize = tensor.dimensions.iter().product();
580 tensor_sizes.insert(id, size);
581 }
582
583 let mut plan_queue = BinaryHeap::new();
585
586 let mut candidate_pairs = Vec::new();
588 for &id1 in tensors.keys() {
589 if let Some(connected) = tensor_graph.get(&id1) {
590 for &id2 in connected {
591 if id1 < id2 {
592 let cost = tensor_sizes[&id1] * tensor_sizes[&id2];
594 candidate_pairs.push((cost, id1, id2));
595 }
596 }
597 }
598 }
599
600 candidate_pairs.sort_by_key(|&(cost, _, _)| cost);
602
603 for (cost, id1, id2) in candidate_pairs.iter().take(5) {
605 let pairs = vec![(*id1, *id2)];
606 plan_queue.push(Reverse(ContractionPlan::new(
607 pairs,
608 *cost as f64,
609 std::cmp::max(tensor_sizes[id1], tensor_sizes[id2]),
610 )));
611 }
612
613 if plan_queue.is_empty() {
615 return Ok(ContractionPlan::new(Vec::new(), 0.0, 0));
616 }
617
618 let mut best_plan = plan_queue
620 .peek()
621 .expect("Plan queue should not be empty at this point")
622 .0
623 .clone();
624
625 while !plan_queue.is_empty() && start_time.elapsed() < max_time {
627 let current_plan = plan_queue
629 .pop()
630 .expect("Plan queue verified non-empty in loop condition")
631 .0;
632
633 if current_plan.pairs.len() == tensors.len() - 1 {
635 if current_plan.flop_estimate < best_plan.flop_estimate {
636 best_plan = current_plan;
637 }
638 continue;
639 }
640
641 let mut remaining = tensors.keys().copied().collect::<HashSet<_>>();
643 let mut current_graph = tensor_graph.clone();
644 let mut current_sizes = tensor_sizes.clone();
645
646 for &(id1, id2) in ¤t_plan.pairs {
647 remaining.remove(&id1);
649 remaining.remove(&id2);
650
651 remaining.insert(id1);
653
654 }
657
658 let mut candidates = Vec::new();
660 for &id1 in &remaining {
661 if let Some(connected) = current_graph.get(&id1) {
662 for &id2 in connected {
663 if remaining.contains(&id2) && id1 < id2 {
664 let cost = current_sizes[&id1] * current_sizes[&id2];
665 candidates.push((cost, id1, id2));
666 }
667 }
668 }
669 }
670
671 candidates.sort_by_key(|&(cost, _, _)| cost);
673
674 for (cost, id1, id2) in candidates.iter().take(3) {
676 let mut new_pairs = current_plan.pairs.clone();
677 new_pairs.push((*id1, *id2));
678
679 let new_flops = current_plan.flop_estimate + *cost as f64;
680 let new_memory = std::cmp::max(current_plan.memory_estimate, *cost);
681
682 plan_queue.push(Reverse(ContractionPlan::new(
683 new_pairs, new_flops, new_memory,
684 )));
685 }
686 }
687
688 Ok(best_plan)
689}