1use std::{
2 collections::{BTreeSet, HashMap},
3 sync::{Arc, Mutex},
4};
5
6use qudit_expr::{ExpressionCache, GenerationShape};
7
8use super::TensorId;
9use super::index::IndexDirection;
10use super::index::IndexId;
11use super::index::IndexSize;
12use super::index::NetworkIndex;
13use super::index::TensorIndex;
14use super::index::WeightedIndex;
15use super::path::ContractionPath;
16use super::tensor::QuditTensor;
17use crate::tree::TTGTTree;
18
/// An edge of the network: a network index paired with the set of tensor
/// ids the index is attached to.
pub type NetworkEdge = (NetworkIndex, BTreeSet<TensorId>);
20
/// A tensor network over qudit tensors sharing one expression cache.
pub struct QuditTensorNetwork {
    // The tensors participating in the network; `TensorId`s index into this.
    tensors: Vec<QuditTensor>,
    // Shared, mutex-guarded cache of symbolic expressions used when tracing,
    // permuting, and reshaping tensor expressions.
    expressions: Arc<Mutex<ExpressionCache>>,
    // For each tensor (by id), maps its local index positions to network-wide
    // index ids (positions into `indices`).
    local_to_network_index_map: Vec<Vec<IndexId>>,
    // Every network index together with the tensors attached to it.
    indices: Vec<NetworkEdge>,
}
27
28impl QuditTensorNetwork {
31 pub fn new(
32 tensors: Vec<QuditTensor>,
33 expressions: Arc<Mutex<ExpressionCache>>,
34 local_to_network_index_map: Vec<Vec<IndexId>>,
35 indices: Vec<NetworkEdge>,
36 ) -> Self {
37 for (_, edge) in indices.iter() {
38 if edge.is_empty() {
39 panic!(
40 "Index not attached to any tensor detected. Empty indices, must have explicit identity/copy tensors attached before final network construction."
41 );
42 }
43 }
44
45 QuditTensorNetwork {
46 tensors,
47 expressions,
48 local_to_network_index_map,
49 indices,
50 }
51 }
52
53 fn num_indices(&self) -> usize {
58 self.indices.len()
59 }
60
61 #[allow(dead_code)] fn index_id(&self, idx: &NetworkIndex) -> Option<IndexId> {
63 self.indices.iter().position(|x| &x.0 == idx)
64 }
65
66 fn index_size(&self, idx_id: IndexId) -> Option<IndexSize> {
67 if idx_id >= self.num_indices() {
68 return None;
69 }
70
71 unsafe { Some(self.index_size_unchecked(idx_id)) }
73 }
74
75 unsafe fn index_size_unchecked(&self, idx_id: IndexId) -> IndexSize {
76 match &self.indices[idx_id].0 {
77 NetworkIndex::Output(tidx) => tidx.index_size(),
78 NetworkIndex::Contracted(con) => con.index_size(),
79 }
80 }
81
82 #[allow(dead_code)] fn get_output_indices(&self) -> Vec<TensorIndex> {
84 self.indices
85 .iter()
86 .filter_map(|x| match &x.0 {
87 NetworkIndex::Output(idx) => Some(idx),
88 NetworkIndex::Contracted(_) => None,
89 })
90 .copied()
91 .collect()
92 }
93
94 #[allow(dead_code)] fn get_output_shape(&self) -> GenerationShape {
96 let mut total_batch_dim = None;
98 let mut total_output_dim = None;
99 let mut total_input_dim = None;
100 for idx in self.get_output_indices() {
101 match idx.direction() {
102 IndexDirection::Derivative => {
103 panic!("Derivatives should not be explicit in networks.")
104 }
105 IndexDirection::Batch => {
106 if let Some(value) = total_batch_dim.as_mut() {
107 *value *= idx.index_size();
108 } else {
109 total_batch_dim = Some(idx.index_size());
110 }
111 }
112 IndexDirection::Output => {
113 if let Some(value) = total_output_dim.as_mut() {
114 *value *= idx.index_size();
115 } else {
116 total_output_dim = Some(idx.index_size());
117 }
118 }
119 IndexDirection::Input => {
120 if let Some(value) = total_input_dim.as_mut() {
121 *value *= idx.index_size();
122 } else {
123 total_input_dim = Some(idx.index_size());
124 }
125 }
126 }
127 }
128
129 match (total_batch_dim, total_output_dim, total_input_dim) {
130 (None, None, None) => GenerationShape::Scalar,
131 (Some(nbatches), None, None) => GenerationShape::Vector(nbatches),
132 (None, Some(nrows), None) => GenerationShape::Matrix(nrows, 1), (None, None, Some(ncols)) => GenerationShape::Vector(ncols), (Some(nbatches), Some(nrows), None) => GenerationShape::Tensor3D(nbatches, nrows, 1),
135 (Some(nbatches), None, Some(ncols)) => GenerationShape::Matrix(nbatches, ncols),
136 (None, Some(nrows), Some(ncols)) => GenerationShape::Matrix(nrows, ncols),
137 (Some(nmats), Some(nrows), Some(ncols)) => {
138 GenerationShape::Tensor3D(nmats, nrows, ncols)
139 }
140 }
141 }
142
143 #[allow(dead_code)] fn get_tensor_unique_network_indices(&self, tensor_id: TensorId) -> BTreeSet<NetworkIndex> {
145 self.local_to_network_index_map[tensor_id]
146 .iter()
147 .map(|&idx_id| self.indices[idx_id].0)
148 .collect()
149 }
150
151 fn get_tensor_unique_flat_indices(&self, tensor_id: TensorId) -> BTreeSet<WeightedIndex> {
152 self.local_to_network_index_map[tensor_id]
153 .iter()
154 .map(|&idx_id| {
155 (
156 idx_id,
157 self.index_size(idx_id)
158 .expect("Index id unexpectedly not found"),
159 )
160 })
161 .collect()
162 }
163
164 fn get_tensor_output_index_ids(&self, tensor_id: TensorId) -> BTreeSet<IndexId> {
165 self.local_to_network_index_map[tensor_id]
166 .iter()
167 .filter(|&idx_id| self.indices[*idx_id].0.is_output())
168 .copied()
169 .collect()
170 }
171
172 fn get_neighbors(&self, tensor: TensorId) -> BTreeSet<TensorId> {
180 let mut neighbors = BTreeSet::new();
181 for idx_id in &self.local_to_network_index_map[tensor] {
182 neighbors.extend(self.indices[*idx_id].1.iter());
183 }
184 neighbors
185 }
186
187 fn get_subnetworks(&self) -> Vec<Vec<TensorId>> {
188 let mut subnetworks: Vec<Vec<TensorId>> = Vec::new();
189 let mut visited = vec![false; self.tensors.len()];
190
191 for current_tensor_id in 0..self.tensors.len() {
192 if visited[current_tensor_id] {
193 continue;
194 }
195
196 let mut current_subnetwork = Vec::new();
197 let mut queue = vec![current_tensor_id];
198
199 while let Some(tensor_id) = queue.pop() {
200 if visited[tensor_id] {
201 continue;
202 }
203 visited[tensor_id] = true;
204 current_subnetwork.push(tensor_id);
205
206 for neighbor in self.get_neighbors(tensor_id) {
207 if !visited[neighbor] {
208 queue.push(neighbor);
209 }
210 }
211 }
212
213 subnetworks.push(current_subnetwork);
214 }
215 subnetworks
216 }
217
218 pub fn solve_for_path(&self) -> ContractionPath {
219 let mut disjoint_paths = Vec::new();
220
221 for subgraph in self.get_subnetworks() {
222 let input = self.build_trivial_contraction_paths(subgraph);
223 let path = if input.len() < 7 {
224 ContractionPath::solve_optimal_simple(input)
225 } else {
226 ContractionPath::solve_greedy_simple(input)
227 };
228 disjoint_paths.push(path);
229 }
230
231 ContractionPath::solve_by_size_simple(disjoint_paths)
232 }
236
237 fn build_trivial_contraction_paths(&self, subnetwork: Vec<TensorId>) -> Vec<ContractionPath> {
238 subnetwork
239 .iter()
240 .map(|&tensor_id| {
241 let flat_indices = self.get_tensor_unique_flat_indices(tensor_id);
242 let output_indices = self.get_tensor_output_index_ids(tensor_id);
243 ContractionPath::trivial(tensor_id, flat_indices, output_indices)
244 })
245 .collect()
246 }
247
    /// Lowers a solved contraction path into a `TTGTTree` for execution.
    ///
    /// The path is consumed in postfix order: a `usize::MAX` entry pops the
    /// two most recent subtrees and contracts them; any other entry is a
    /// tensor id, whose tensor is prepared (self-loops traced out, indices
    /// permuted into network order and fused) and pushed as a leaf. The
    /// finished tree is finally transposed into the network's canonical
    /// output order.
    ///
    /// # Panics
    ///
    /// Panics if the path does not reduce to exactly one tree, if a traced
    /// (looped) index does not occur exactly twice on its tensor, or if a
    /// non-output index survives to the final network output.
    pub fn path_to_ttgt_tree(&self, path: ContractionPath) -> TTGTTree {
        let mut tree_stack: Vec<TTGTTree> = Vec::new();

        for path_element in path.path.iter() {
            if *path_element == usize::MAX {
                // Contraction marker: combine the two most recent subtrees.
                let left = tree_stack.pop().unwrap();
                let right = tree_stack.pop().unwrap();

                let left_network_index_ids: Vec<IndexId> =
                    left.indices().iter().map(|&idx| idx.index_id()).collect();
                let right_network_index_ids: Vec<IndexId> =
                    right.indices().iter().map(|&idx| idx.index_id()).collect();
                // Indices appearing on both sides either stay shared (open
                // outputs) or get summed over (contracted).
                let intersection: Vec<IndexId> = left_network_index_ids
                    .iter()
                    .filter(|&id| right_network_index_ids.contains(id))
                    .copied()
                    .collect();

                // Common indices that are open outputs remain as shared
                // indices of the contraction...
                let shared_ids: Vec<IndexId> = intersection
                    .iter()
                    .filter(|&id| self.indices[*id].0.is_output())
                    .copied()
                    .collect();

                // ...while the remaining common indices are contracted away.
                let contraction_ids: Vec<IndexId> = intersection
                    .into_iter()
                    .filter(|id| !shared_ids.contains(id))
                    .collect();

                tree_stack.push(left.contract(right, shared_ids, contraction_ids));
            } else {
                // Leaf entry: prepare the tensor with this id.
                let QuditTensor {
                    expression: expr_id,
                    indices,
                    param_info,
                } = &self.tensors[*path_element];
                let mut network_idx_ids = self.local_to_network_index_map[*path_element].clone();
                // Find self-loops: a non-output index attached to only this
                // tensor occurs (asserted below) twice locally and must be
                // traced out of the expression.
                let mut looped_index_map: HashMap<IndexId, Vec<usize>> = HashMap::new();
                for (local_idx, &network_idx_id) in network_idx_ids.iter().enumerate() {
                    let index_edge = &self.indices[network_idx_id];
                    if !index_edge.0.is_output() && index_edge.1.len() == 1 {
                        looped_index_map
                            .entry(network_idx_id)
                            .or_default()
                            .push(local_idx);
                    }
                }
                // Each loop contributes two local positions to remove.
                let mut to_remove = Vec::with_capacity(looped_index_map.len() * 2);
                let looped_index_pairs: Vec<(usize, usize)> = looped_index_map
                    .into_iter()
                    .map(|(index_id, local_indices)| {
                        assert_eq!(
                            local_indices.len(),
                            2,
                            "Looped index {:?} did not have exactly two occurrences. It had {}.",
                            index_id,
                            local_indices.len()
                        );
                        to_remove.extend(local_indices.clone());
                        (local_indices[0], local_indices[1])
                    })
                    .collect();

                // Remove traced local positions from the back so earlier
                // removals do not shift the later positions.
                to_remove.sort();
                for traced_local_index in to_remove.iter().rev() {
                    network_idx_ids.remove(*traced_local_index);
                }
                // Apply the trace to the cached expression when any loops
                // were found; otherwise keep the original expression.
                let (traced_id, traced_indices) = if looped_index_pairs.is_empty() {
                    (*expr_id, indices.clone())
                } else {
                    let mut guard = self.expressions.lock().unwrap();
                    let id = guard.trace(*expr_id, looped_index_pairs);
                    let indices = guard.indices(id);
                    (id, indices)
                };
                // Argsort of the surviving local positions by network index
                // id: the permutation that puts this tensor's indices into
                // network order.
                let perm = {
                    let mut argsorted_indices = (0..network_idx_ids.len()).collect::<Vec<_>>();
                    argsorted_indices.sort_by_key(|&i| network_idx_ids[i]);
                    argsorted_indices
                };

                // Permute the traced expression into that order and flatten
                // it to a vector shape. Each lock is released at the end of
                // its statement.
                let traced_nelems = self.expressions.lock().unwrap().num_elements(traced_id);
                let new_shape = GenerationShape::Vector(traced_nelems);
                let tranposed_id = self.expressions.lock().unwrap().permute_reshape(
                    traced_id,
                    perm.clone(),
                    new_shape,
                );

                // Fuse runs of consecutive (post-permutation) local indices
                // that map to the same network index into single node
                // indices whose size is the product of the run; also record
                // which expression positions each fused index covers.
                let (new_node_indices, tensor_to_expr_position_map) = {
                    let mut new_node_indices = Vec::new();
                    let mut tensor_to_expr_position_map = Vec::new();

                    if perm.is_empty() {
                        // Scalar leaf: nothing survives the trace.
                    } else {
                        let mut index_size_acm = 1;
                        let mut prev_network_idx_id = network_idx_ids[perm[0]];
                        let mut current_group = vec![];

                        for i in 0..perm.len() {
                            let curr_local_idx = perm[i];
                            let curr_network_idx_id = network_idx_ids[curr_local_idx];
                            let curr_index_size = traced_indices[curr_local_idx].index_size();

                            if curr_network_idx_id == prev_network_idx_id {
                                // Same network index: extend the current run.
                                index_size_acm *= curr_index_size;
                                current_group.push(i);
                            } else {
                                // Network index changed: close the run.
                                new_node_indices.push(TensorIndex::new(
                                    IndexDirection::Input,
                                    prev_network_idx_id,
                                    index_size_acm,
                                ));
                                tensor_to_expr_position_map.push(current_group.clone());
                                current_group = vec![i];
                                index_size_acm = curr_index_size;
                                prev_network_idx_id = curr_network_idx_id;
                            }
                        }
                        // Close the final run.
                        new_node_indices.push(TensorIndex::new(
                            IndexDirection::Input,
                            prev_network_idx_id,
                            index_size_acm,
                        ));
                        tensor_to_expr_position_map.push(current_group.clone());
                    }
                    (new_node_indices, tensor_to_expr_position_map)
                };

                tree_stack.push(TTGTTree::leaf(
                    self.expressions.clone(),
                    tranposed_id,
                    param_info.clone(),
                    new_node_indices,
                    tensor_to_expr_position_map,
                ));
            }
        }
        if tree_stack.len() != 1 {
            panic!("Tree stack should have exactly one element.");
        }

        let tree = tree_stack.pop().unwrap();

        // Canonical output order: sort the tree's indices by their network
        // edge ordering.
        let mut goal_index_order = tree.indices();
        goal_index_order.sort_by_key(|x| &self.indices[x.index_id()]);

        // Permutation taking the tree's current index order to the goal.
        let final_transpose = goal_index_order
            .iter()
            .map(|i| {
                tree.indices()
                    .iter()
                    .position(|x| x.index_id() == i.index_id())
                    .unwrap()
            })
            .collect::<Vec<_>>();

        // Restore each surviving index's original output direction.
        let final_redirection = goal_index_order
            .iter()
            .map(|i| {
                if let NetworkIndex::Output(tidx) = self.indices[i.index_id()].0 {
                    tidx.direction()
                } else {
                    panic!("Non output index made it to final network output.");
                }
            })
            .collect();

        tree.transpose(final_transpose, final_redirection)
    }
469}