quantrs2_sim/tensor_network/
contraction.rs1use super::tensor::Tensor;
7use quantrs2_core::error::QuantRS2Result;
8use std::collections::{HashMap, HashSet};
9
10pub trait ContractableNetwork {
12 fn contract_tensors(&mut self, tensor_id1: usize, tensor_id2: usize) -> QuantRS2Result<usize>;
14
15 fn optimize_contraction_order(&mut self) -> QuantRS2Result<()>;
17}
18
/// An ordered list of pairwise contractions together with its estimated cost.
#[derive(Debug, Clone, PartialEq)]
pub struct ContractionPath {
    /// Pairs of tensor ids, in the order in which they are to be contracted.
    steps: Vec<(usize, usize)>,

    /// Estimated total cost of performing every step.
    estimated_cost: f64,
}

impl ContractionPath {
    /// Creates a path from its contraction `steps` and their `estimated_cost`.
    #[must_use]
    pub const fn new(steps: Vec<(usize, usize)>, estimated_cost: f64) -> Self {
        Self {
            steps,
            estimated_cost,
        }
    }

    /// Returns the contraction steps in execution order.
    #[must_use]
    pub fn steps(&self) -> &[(usize, usize)] {
        &self.steps
    }

    /// Returns the estimated total cost of the path.
    #[must_use]
    pub const fn estimated_cost(&self) -> f64 {
        self.estimated_cost
    }
}
48
49pub fn calculate_greedy_contraction_path(
56 tensors: &HashMap<usize, Tensor>,
57 connections: &[(super::tensor::TensorIndex, super::tensor::TensorIndex)],
58) -> QuantRS2Result<ContractionPath> {
59 let mut tensor_connections = HashMap::new();
61 for (t1, t2) in connections {
62 tensor_connections
63 .entry(t1.tensor_id)
64 .or_insert_with(HashSet::new)
65 .insert(t2.tensor_id);
66 tensor_connections
67 .entry(t2.tensor_id)
68 .or_insert_with(HashSet::new)
69 .insert(t1.tensor_id);
70 }
71
72 let mut tensor_dims = HashMap::new();
74 for (&id, tensor) in tensors {
75 tensor_dims.insert(id, tensor.dimensions.iter().product::<usize>());
76 }
77
78 let mut remaining_tensors: HashSet<usize> = tensors.keys().copied().collect();
81 let mut steps = Vec::new();
82 let mut total_cost = 0.0;
83
84 while remaining_tensors.len() > 1 {
85 let mut best_cost = f64::INFINITY;
86 let mut best_pair = None;
87
88 for &t1 in &remaining_tensors {
90 if let Some(connected) = tensor_connections.get(&t1) {
91 for &t2 in connected {
92 if remaining_tensors.contains(&t2) {
93 let combined_dim = tensor_dims[&t1] * tensor_dims[&t2];
95 let cost = combined_dim as f64;
96
97 if cost < best_cost {
98 best_cost = cost;
99 best_pair = Some((t1, t2));
100 }
101 }
102 }
103 }
104 }
105
106 if let Some((t1, t2)) = best_pair {
108 steps.push((t1, t2));
110 total_cost += best_cost;
111
112 remaining_tensors.remove(&t1);
114 remaining_tensors.remove(&t2);
115
116 let new_id = t1; remaining_tensors.insert(new_id);
119
120 let mut new_connections = HashSet::new();
122
123 let mut t1_connected_tensors = Vec::new();
126 if let Some(t1_connections) = tensor_connections.get(&t1) {
127 for &connected_tensor in t1_connections {
128 if connected_tensor != t2 && remaining_tensors.contains(&connected_tensor) {
129 t1_connected_tensors.push(connected_tensor);
130 new_connections.insert(connected_tensor);
131 }
132 }
133 }
134
135 for connected_tensor in t1_connected_tensors {
137 if let Some(other_connections) = tensor_connections.get_mut(&connected_tensor) {
138 other_connections.remove(&t1);
139 other_connections.remove(&t2);
140 other_connections.insert(new_id);
141 }
142 }
143
144 let mut t2_connected_tensors = Vec::new();
147 if let Some(t2_connections) = tensor_connections.get(&t2) {
148 for &connected_tensor in t2_connections {
149 if connected_tensor != t1 && remaining_tensors.contains(&connected_tensor) {
150 t2_connected_tensors.push(connected_tensor);
151 new_connections.insert(connected_tensor);
152 }
153 }
154 }
155
156 for connected_tensor in t2_connected_tensors {
158 if let Some(other_connections) = tensor_connections.get_mut(&connected_tensor) {
159 other_connections.remove(&t1);
160 other_connections.remove(&t2);
161 other_connections.insert(new_id);
162 }
163 }
164
165 tensor_connections.insert(new_id, new_connections);
167
168 tensor_dims.insert(new_id, (tensor_dims[&t1] * tensor_dims[&t2]) / 2);
171 } else {
172 let mut remaining_vec: Vec<_> = remaining_tensors.iter().copied().collect();
174 remaining_vec.sort_unstable();
175
176 if remaining_vec.len() >= 2 {
177 let t1 = remaining_vec[0];
178 let t2 = remaining_vec[1];
179
180 steps.push((t1, t2));
181 total_cost += (tensor_dims[&t1] * tensor_dims[&t2]) as f64;
182
183 remaining_tensors.remove(&t1);
184 remaining_tensors.remove(&t2);
185 remaining_tensors.insert(t1);
186
187 tensor_dims.insert(t1, (tensor_dims[&t1] * tensor_dims[&t2]) / 2);
189 } else {
190 break;
192 }
193 }
194 }
195
196 Ok(ContractionPath::new(steps, total_cost))
197}
198
199pub fn calculate_optimal_contraction_path(
204 tensors: &HashMap<usize, Tensor>,
205 connections: &[(super::tensor::TensorIndex, super::tensor::TensorIndex)],
206) -> QuantRS2Result<ContractionPath> {
207 if let Some(path) = identify_circuit_structure(tensors, connections) {
210 return Ok(path);
211 }
212
213 calculate_greedy_contraction_path(tensors, connections)
215}
216
217fn identify_circuit_structure(
224 tensors: &HashMap<usize, Tensor>,
225 connections: &[(super::tensor::TensorIndex, super::tensor::TensorIndex)],
226) -> Option<ContractionPath> {
227 let mut tensor_connections = HashMap::new();
229 for (t1, t2) in connections {
230 tensor_connections
231 .entry(t1.tensor_id)
232 .or_insert_with(HashSet::new)
233 .insert(t2.tensor_id);
234 tensor_connections
235 .entry(t2.tensor_id)
236 .or_insert_with(HashSet::new)
237 .insert(t1.tensor_id);
238 }
239
240 let mut tensor_ids: Vec<usize> = tensors.keys().copied().collect();
242 tensor_ids.sort_unstable();
243
244 if is_linear_circuit(&tensor_connections, &tensor_ids) {
248 let mut steps = Vec::new();
250 let mut cost = 0.0;
251
252 let ordered_tensors = order_linear_circuit(&tensor_connections, &tensor_ids);
254
255 for ids in ordered_tensors.windows(2) {
257 steps.push((ids[0], ids[1]));
258 cost += 16.0; }
260
261 return Some(ContractionPath::new(steps, cost));
262 }
263
264 if is_star_circuit(&tensor_connections, &tensor_ids) {
268 let mut steps = Vec::new();
270 let mut cost = 0.0;
271
272 let central = find_central_tensor(&tensor_connections);
274
275 let leaf_tensors: Vec<_> = tensor_ids
277 .iter()
278 .filter(|&&id| {
279 id != central
280 && tensor_connections
281 .get(&id)
282 .is_some_and(|conns| conns.contains(¢ral))
283 })
284 .copied()
285 .collect();
286
287 for leaf in leaf_tensors {
288 steps.push((central, leaf));
289 cost += 16.0; }
291
292 return Some(ContractionPath::new(steps, cost));
293 }
294
295 if is_qft_circuit(&tensor_connections, tensors) {
298 return Some(optimize_qft_circuit(&tensor_connections, tensors));
299 }
300
301 if is_qaoa_circuit(&tensor_connections, tensors) {
304 return Some(optimize_qaoa_circuit(&tensor_connections, tensors));
305 }
306
307 None
309}
310
311fn is_qft_circuit(
313 tensor_connections: &HashMap<usize, HashSet<usize>>,
314 tensors: &HashMap<usize, Tensor>,
315) -> bool {
316 let mut hadamard_count = 0;
321 let mut controlled_phase_count = 0;
322 let mut swap_count = 0;
323
324 for tensor in tensors.values() {
326 if tensor.rank == 2 {
328 hadamard_count += 1;
329 } else if tensor.rank == 4 {
330 if tensor.dimensions == vec![2, 2, 2, 2] {
332 controlled_phase_count += 1;
335 }
336
337 if is_swap_like_tensor(tensor) {
339 swap_count += 1;
340 }
341 }
342 }
343
344 hadamard_count > 0 && controlled_phase_count > 0 && hadamard_count >= controlled_phase_count / 2
350}
351
352fn is_swap_like_tensor(tensor: &Tensor) -> bool {
354 tensor.rank == 4 && tensor.dimensions == vec![2, 2, 2, 2]
357}
358
359fn optimize_qft_circuit(
361 tensor_connections: &HashMap<usize, HashSet<usize>>,
362 tensors: &HashMap<usize, Tensor>,
363) -> ContractionPath {
364 let mut ordered_tensors: Vec<usize> = Vec::new();
369 let mut tensor_ids: Vec<usize> = tensors.keys().copied().collect();
370 tensor_ids.sort_unstable();
371
372 let mut steps = Vec::new();
375 let mut cost = 0.0;
376
377 let mut layers = identify_qft_layers(tensor_connections, &tensor_ids);
382
383 for layer in layers {
385 for i in 0..layer.len().saturating_sub(1) {
387 steps.push((layer[i], layer[i + 1]));
388 cost += 16.0; }
390 }
391
392 if steps.is_empty() {
394 for i in 0..tensor_ids.len().saturating_sub(1) {
395 steps.push((tensor_ids[i], tensor_ids[i + 1]));
396 cost += 16.0;
397 }
398 }
399
400 ContractionPath::new(steps, cost)
401}
402
/// Groups `tensor_ids` into layers by connection degree.
///
/// The degree of a tensor is the number of entries in its adjacency set
/// (zero when it has none). Layers are returned from highest to lowest
/// degree; inside a layer the ids keep the order they have in `tensor_ids`.
fn identify_qft_layers(
    tensor_connections: &HashMap<usize, HashSet<usize>>,
    tensor_ids: &[usize],
) -> Vec<Vec<usize>> {
    let mut degree_groups: HashMap<usize, Vec<usize>> = HashMap::new();
    for &id in tensor_ids {
        let degree = tensor_connections.get(&id).map_or(0, HashSet::len);
        degree_groups.entry(degree).or_default().push(id);
    }

    let mut degrees: Vec<usize> = degree_groups.keys().copied().collect();
    degrees.sort_unstable_by(|a, b| b.cmp(a));

    // Each group is moved out of the map rather than cloned.
    degrees
        .into_iter()
        .filter_map(|degree| degree_groups.remove(&degree))
        .collect()
}
436
437fn is_qaoa_circuit(
439 tensor_connections: &HashMap<usize, HashSet<usize>>,
440 tensors: &HashMap<usize, Tensor>,
441) -> bool {
442 let mut x_rotation_count = 0;
447 let mut zz_interaction_count = 0;
448
449 for tensor in tensors.values() {
451 if tensor.rank == 2 {
453 x_rotation_count += 1; }
455 else if tensor.rank == 4 {
457 zz_interaction_count += 1; }
459 }
460
461 x_rotation_count > 0 && zz_interaction_count > 0
464}
465
466fn optimize_qaoa_circuit(
468 tensor_connections: &HashMap<usize, HashSet<usize>>,
469 tensors: &HashMap<usize, Tensor>,
470) -> ContractionPath {
471 let mut tensor_ids: Vec<usize> = tensors.keys().copied().collect();
476 tensor_ids.sort_by(|a, b| {
477 if let (Some(tensor_a), Some(tensor_b)) = (tensors.get(a), tensors.get(b)) {
478 tensor_b.rank.cmp(&tensor_a.rank) } else {
480 std::cmp::Ordering::Equal
481 }
482 });
483
484 let mut rank_groups: HashMap<usize, Vec<usize>> = HashMap::new();
486
487 for &id in &tensor_ids {
488 if let Some(tensor) = tensors.get(&id) {
489 rank_groups.entry(tensor.rank).or_default().push(id);
490 }
491 }
492
493 let mut steps = Vec::new();
495 let mut cost = 0.0;
496
497 if let Some(two_qubit_gates) = rank_groups.get(&4) {
499 for (i, &id1) in two_qubit_gates.iter().enumerate() {
500 for &id2 in two_qubit_gates.iter().skip(i + 1) {
501 if tensor_connections
503 .get(&id1)
504 .is_some_and(|conns| conns.contains(&id2))
505 {
506 steps.push((id1, id2));
507 cost += 64.0; }
509 }
510 }
511 }
512
513 if let Some(single_qubit_gates) = rank_groups.get(&2) {
515 for (i, &id1) in single_qubit_gates.iter().enumerate() {
516 for &id2 in single_qubit_gates.iter().skip(i + 1) {
517 if tensor_connections
519 .get(&id1)
520 .is_some_and(|conns| conns.contains(&id2))
521 {
522 steps.push((id1, id2));
523 cost += 16.0; }
525 }
526 }
527 }
528
529 if steps.is_empty() {
532 for i in 0..tensor_ids.len().saturating_sub(1) {
533 steps.push((tensor_ids[i], tensor_ids[i + 1]));
534 cost += 16.0; }
536 }
537
538 ContractionPath::new(steps, cost)
539}
540
/// Reports whether the tensors in `tensor_ids` form a single open chain.
///
/// A chain needs exactly two tensors of degree one, degree two for every
/// other tensor, and every tensor must be reachable by walking from one
/// endpoint. The reachability walk rejects networks that merely have the
/// right degree counts, such as a short chain next to a separate cycle;
/// isolated tensors (degree zero) are rejected as well.
fn is_linear_circuit(
    tensor_connections: &HashMap<usize, HashSet<usize>>,
    tensor_ids: &[usize],
) -> bool {
    let mut first_endpoint = None;
    let mut num_endpoints = 0usize;

    for &id in tensor_ids {
        match tensor_connections.get(&id).map_or(0, HashSet::len) {
            1 => {
                num_endpoints += 1;
                if first_endpoint.is_none() {
                    first_endpoint = Some(id);
                }
            }
            2 => {}
            // Degree 0 (isolated) or above 2 (branching) cannot be a chain.
            _ => return false,
        }
    }

    let Some(start) = first_endpoint else {
        return false;
    };
    if num_endpoints != 2 {
        return false;
    }

    // Walk from one endpoint; a genuine chain visits every tensor once.
    let mut visited = HashSet::with_capacity(tensor_ids.len());
    visited.insert(start);
    let mut cursor = start;
    while let Some(next) = tensor_connections
        .get(&cursor)
        .and_then(|neighbours| neighbours.iter().copied().find(|id| !visited.contains(id)))
    {
        visited.insert(next);
        cursor = next;
    }

    visited.len() == tensor_ids.len() && tensor_ids.iter().all(|id| visited.contains(id))
}
564
/// Orders the tensors of a chain from one end to the other.
///
/// The walk starts at the first id in `tensor_ids` that has exactly one
/// connection and repeatedly moves to a not-yet-visited neighbour. When the
/// walk does not cover as many tensors as `tensor_ids` holds (including the
/// case where no endpoint exists), `tensor_ids` is returned unchanged.
fn order_linear_circuit(
    tensor_connections: &HashMap<usize, HashSet<usize>>,
    tensor_ids: &[usize],
) -> Vec<usize> {
    let endpoint = tensor_ids.iter().copied().find(|id| {
        tensor_connections
            .get(id)
            .is_some_and(|neighbours| neighbours.len() == 1)
    });

    let mut chain = Vec::new();
    if let Some(first) = endpoint {
        let mut seen = HashSet::new();
        seen.insert(first);
        chain.push(first);

        let mut cursor = first;
        loop {
            let step = tensor_connections
                .get(&cursor)
                .and_then(|neighbours| neighbours.iter().copied().find(|id| !seen.contains(id)));
            match step {
                Some(following) => {
                    chain.push(following);
                    seen.insert(following);
                    cursor = following;
                }
                // Dead end: either the far endpoint or a tensor without entry.
                None => break,
            }
        }
    }

    if chain.len() == tensor_ids.len() {
        chain
    } else {
        tensor_ids.to_vec()
    }
}
617
/// Reports whether the network has a star shape: exactly one tensor with more
/// than two connections and more than two tensors with exactly one
/// connection.
///
/// Hubs are counted per tensor, so two hubs that happen to share the same
/// degree are correctly seen as two hubs and the network is rejected.
/// Tensors of degree zero or two do not influence the result.
fn is_star_circuit(
    tensor_connections: &HashMap<usize, HashSet<usize>>,
    tensor_ids: &[usize],
) -> bool {
    let mut hubs = 0usize;
    let mut leaves = 0usize;

    for id in tensor_ids {
        match tensor_connections.get(id).map_or(0, HashSet::len) {
            1 => leaves += 1,
            degree if degree > 2 => hubs += 1,
            _ => {}
        }
    }

    hubs == 1 && leaves > 2
}
639
/// Returns the id of the tensor with the most connections.
///
/// When several tensors share the highest degree, the smallest id wins, so
/// the answer does not depend on hash iteration order. `0` is returned when
/// the map is empty or no tensor has any connection.
fn find_central_tensor(tensor_connections: &HashMap<usize, HashSet<usize>>) -> usize {
    tensor_connections
        .iter()
        .filter(|(_, neighbours)| !neighbours.is_empty())
        // `Reverse` makes the maximum prefer the smaller id on equal degree.
        .map(|(&id, neighbours)| (neighbours.len(), std::cmp::Reverse(id)))
        .max()
        .map_or(0, |(_, std::cmp::Reverse(id))| id)
}
655
656pub fn contract_network_along_path(
658 tensors: &mut HashMap<usize, Tensor>,
659 connections: &mut Vec<(super::tensor::TensorIndex, super::tensor::TensorIndex)>,
660 path: &ContractionPath,
661 next_id: &mut usize,
662) -> QuantRS2Result<Tensor> {
663 if let Some(tensor) = tensors.values().next() {
668 Ok(tensor.clone())
669 } else {
670 Ok(Tensor::qubit_zero())
671 }
672}