qudit_tensor/network/
network.rs

1use std::{
2    collections::{BTreeSet, HashMap},
3    sync::{Arc, Mutex},
4};
5
6use qudit_expr::{ExpressionCache, GenerationShape};
7
8use super::TensorId;
9use super::index::IndexDirection;
10use super::index::IndexId;
11use super::index::IndexSize;
12use super::index::NetworkIndex;
13use super::index::TensorIndex;
14use super::index::WeightedIndex;
15use super::path::ContractionPath;
16use super::tensor::QuditTensor;
17use crate::tree::TTGTTree;
18
/// A network edge: a network index paired with the set of tensors attached to it.
pub type NetworkEdge = (NetworkIndex, BTreeSet<TensorId>);
20
/// A network of qudit tensors connected by shared indices.
pub struct QuditTensorNetwork {
    // The tensors participating in the network, indexed by `TensorId`.
    tensors: Vec<QuditTensor>,
    // Shared cache of the symbolic expressions backing the tensors.
    expressions: Arc<Mutex<ExpressionCache>>,
    // For each tensor, maps its local index positions to network `IndexId`s.
    local_to_network_index_map: Vec<Vec<IndexId>>,
    // Every network index (edge) with the tensors attached to it; an
    // index's `IndexId` is its position in this vector.
    indices: Vec<NetworkEdge>,
}
27
28// TODO: handle multiple disjoint (potentially empty) subnetworks
29// TODO: handle partial trace
30impl QuditTensorNetwork {
31    pub fn new(
32        tensors: Vec<QuditTensor>,
33        expressions: Arc<Mutex<ExpressionCache>>,
34        local_to_network_index_map: Vec<Vec<IndexId>>,
35        indices: Vec<NetworkEdge>,
36    ) -> Self {
37        for (_, edge) in indices.iter() {
38            if edge.is_empty() {
39                panic!(
40                    "Index not attached to any tensor detected. Empty indices, must have explicit identity/copy tensors attached before final network construction."
41                );
42            }
43        }
44
45        QuditTensorNetwork {
46            tensors,
47            expressions,
48            local_to_network_index_map,
49            indices,
50        }
51    }
52
53    // fn get_num_outputs(&self) -> usize {
54    //     todo!()
55    // }
56
    /// Total number of distinct network indices (edges) in the network.
    fn num_indices(&self) -> usize {
        self.indices.len()
    }
60
61    #[allow(dead_code)] // Left for documentation purposes
62    fn index_id(&self, idx: &NetworkIndex) -> Option<IndexId> {
63        self.indices.iter().position(|x| &x.0 == idx)
64    }
65
66    fn index_size(&self, idx_id: IndexId) -> Option<IndexSize> {
67        if idx_id >= self.num_indices() {
68            return None;
69        }
70
71        // Safety: checked bounds
72        unsafe { Some(self.index_size_unchecked(idx_id)) }
73    }
74
    /// Size of the index with id `idx_id`, without the caller-facing bounds
    /// check performed by [`Self::index_size`].
    ///
    /// # Safety
    ///
    /// `idx_id` must be less than `self.num_indices()`. (Note the slice
    /// indexing below still panics rather than causing UB on a bad id, but
    /// callers are expected to uphold the bound.)
    unsafe fn index_size_unchecked(&self, idx_id: IndexId) -> IndexSize {
        match &self.indices[idx_id].0 {
            NetworkIndex::Output(tidx) => tidx.index_size(),
            NetworkIndex::Contracted(con) => con.index_size(),
        }
    }
81
82    #[allow(dead_code)] // Left for documentation purposes
83    fn get_output_indices(&self) -> Vec<TensorIndex> {
84        self.indices
85            .iter()
86            .filter_map(|x| match &x.0 {
87                NetworkIndex::Output(idx) => Some(idx),
88                NetworkIndex::Contracted(_) => None,
89            })
90            .copied()
91            .collect()
92    }
93
94    #[allow(dead_code)] // Left for documentation purposes
95    fn get_output_shape(&self) -> GenerationShape {
96        // Calculate dimension totals for each direction
97        let mut total_batch_dim = None;
98        let mut total_output_dim = None;
99        let mut total_input_dim = None;
100        for idx in self.get_output_indices() {
101            match idx.direction() {
102                IndexDirection::Derivative => {
103                    panic!("Derivatives should not be explicit in networks.")
104                }
105                IndexDirection::Batch => {
106                    if let Some(value) = total_batch_dim.as_mut() {
107                        *value *= idx.index_size();
108                    } else {
109                        total_batch_dim = Some(idx.index_size());
110                    }
111                }
112                IndexDirection::Output => {
113                    if let Some(value) = total_output_dim.as_mut() {
114                        *value *= idx.index_size();
115                    } else {
116                        total_output_dim = Some(idx.index_size());
117                    }
118                }
119                IndexDirection::Input => {
120                    if let Some(value) = total_input_dim.as_mut() {
121                        *value *= idx.index_size();
122                    } else {
123                        total_input_dim = Some(idx.index_size());
124                    }
125                }
126            }
127        }
128
129        match (total_batch_dim, total_output_dim, total_input_dim) {
130            (None, None, None) => GenerationShape::Scalar,
131            (Some(nbatches), None, None) => GenerationShape::Vector(nbatches),
132            (None, Some(nrows), None) => GenerationShape::Matrix(nrows, 1), // Ket
133            (None, None, Some(ncols)) => GenerationShape::Vector(ncols),    // Bra
134            (Some(nbatches), Some(nrows), None) => GenerationShape::Tensor3D(nbatches, nrows, 1),
135            (Some(nbatches), None, Some(ncols)) => GenerationShape::Matrix(nbatches, ncols),
136            (None, Some(nrows), Some(ncols)) => GenerationShape::Matrix(nrows, ncols),
137            (Some(nmats), Some(nrows), Some(ncols)) => {
138                GenerationShape::Tensor3D(nmats, nrows, ncols)
139            }
140        }
141    }
142
143    #[allow(dead_code)] // Left for documentation purposes
144    fn get_tensor_unique_network_indices(&self, tensor_id: TensorId) -> BTreeSet<NetworkIndex> {
145        self.local_to_network_index_map[tensor_id]
146            .iter()
147            .map(|&idx_id| self.indices[idx_id].0)
148            .collect()
149    }
150
151    fn get_tensor_unique_flat_indices(&self, tensor_id: TensorId) -> BTreeSet<WeightedIndex> {
152        self.local_to_network_index_map[tensor_id]
153            .iter()
154            .map(|&idx_id| {
155                (
156                    idx_id,
157                    self.index_size(idx_id)
158                        .expect("Index id unexpectedly not found"),
159                )
160            })
161            .collect()
162    }
163
164    fn get_tensor_output_index_ids(&self, tensor_id: TensorId) -> BTreeSet<IndexId> {
165        self.local_to_network_index_map[tensor_id]
166            .iter()
167            .filter(|&idx_id| self.indices[*idx_id].0.is_output())
168            .copied()
169            .collect()
170    }
171
172    // fn convert_network_to_flat_index(&self, idx: &NetworkIndex) -> WeightedIndex {
173    //     match idx {
174    //         NetworkIndex::Output(idx) => (*idx, self.output_indices[*idx].index_size()),
175    //         NetworkIndex::Contracted(idx) => (*idx + self.get_num_outputs(), self.contractions[*idx].total_dimension),
176    //     }
177    // }
178
179    fn get_neighbors(&self, tensor: TensorId) -> BTreeSet<TensorId> {
180        let mut neighbors = BTreeSet::new();
181        for idx_id in &self.local_to_network_index_map[tensor] {
182            neighbors.extend(self.indices[*idx_id].1.iter());
183        }
184        neighbors
185    }
186
187    fn get_subnetworks(&self) -> Vec<Vec<TensorId>> {
188        let mut subnetworks: Vec<Vec<TensorId>> = Vec::new();
189        let mut visited = vec![false; self.tensors.len()];
190
191        for current_tensor_id in 0..self.tensors.len() {
192            if visited[current_tensor_id] {
193                continue;
194            }
195
196            let mut current_subnetwork = Vec::new();
197            let mut queue = vec![current_tensor_id];
198
199            while let Some(tensor_id) = queue.pop() {
200                if visited[tensor_id] {
201                    continue;
202                }
203                visited[tensor_id] = true;
204                current_subnetwork.push(tensor_id);
205
206                for neighbor in self.get_neighbors(tensor_id) {
207                    if !visited[neighbor] {
208                        queue.push(neighbor);
209                    }
210                }
211            }
212
213            subnetworks.push(current_subnetwork);
214        }
215        subnetworks
216    }
217
218    pub fn solve_for_path(&self) -> ContractionPath {
219        let mut disjoint_paths = Vec::new();
220
221        for subgraph in self.get_subnetworks() {
222            let input = self.build_trivial_contraction_paths(subgraph);
223            let path = if input.len() < 7 {
224                ContractionPath::solve_optimal_simple(input)
225            } else {
226                ContractionPath::solve_greedy_simple(input)
227            };
228            disjoint_paths.push(path);
229        }
230
231        ContractionPath::solve_by_size_simple(disjoint_paths)
232        // pick smallest two and contract (TODO: add new operation to TTGT tree method to
233        // determine function. If contracted indices.len() == 0 then just KRON (need batch kron
234        // too)).
235    }
236
237    fn build_trivial_contraction_paths(&self, subnetwork: Vec<TensorId>) -> Vec<ContractionPath> {
238        subnetwork
239            .iter()
240            .map(|&tensor_id| {
241                let flat_indices = self.get_tensor_unique_flat_indices(tensor_id);
242                let output_indices = self.get_tensor_output_index_ids(tensor_id);
243                ContractionPath::trivial(tensor_id, flat_indices, output_indices)
244            })
245            .collect()
246    }
247
    /// Lower a `ContractionPath` into a TTGT evaluation tree over this
    /// network's tensors.
    ///
    /// The path is consumed in postfix order: a tensor id pushes a leaf
    /// node (after tracing any self-loops and grouping local indices by
    /// network index), while the sentinel `usize::MAX` pops two subtrees
    /// and joins them with a pairwise contraction. A final transpose
    /// reorders the surviving indices into the network's output order.
    ///
    /// # Panics
    ///
    /// Panics if the path does not reduce to exactly one tree, if a looped
    /// index does not occur exactly twice on its tensor, or if a non-output
    /// index survives to the final output ordering.
    pub fn path_to_ttgt_tree(&self, path: ContractionPath) -> TTGTTree {
        // Postfix evaluation stack of partially-built subtrees.
        let mut tree_stack: Vec<TTGTTree> = Vec::new();

        for path_element in path.path.iter() {
            if *path_element == usize::MAX {
                // Contraction marker: pop two subtrees and contract them.
                let left = tree_stack.pop().unwrap();
                let right = tree_stack.pop().unwrap();

                let left_network_index_ids: Vec<IndexId> =
                    left.indices().iter().map(|&idx| idx.index_id()).collect();
                let right_network_index_ids: Vec<IndexId> =
                    right.indices().iter().map(|&idx| idx.index_id()).collect();
                // println!("Contracting");
                // println!("Left network ids: {:?}", left_network_index_ids);
                // println!("Right network ids: {:?}", right_network_index_ids);
                // println!("");

                // Indices appearing on both sides of the contraction.
                let intersection: Vec<IndexId> = left_network_index_ids
                    .iter()
                    .filter(|&id| right_network_index_ids.contains(id))
                    .copied()
                    .collect();

                // Shared indices appear in a contraction on both sides, but are not summed over.
                // These are realized as indices that are output to the network that appear on
                // in both left and right index sets.
                let shared_ids: Vec<IndexId> = intersection
                    .iter()
                    .filter(|&id| self.indices[*id].0.is_output())
                    .copied()
                    .collect();

                // The remaining common indices are actually summed over.
                let contraction_ids: Vec<IndexId> = intersection
                    .into_iter()
                    .filter(|id| !shared_ids.contains(id))
                    .collect();

                tree_stack.push(left.contract(right, shared_ids, contraction_ids));
            } else {
                // Leaf: format tensor self.tensors[*path_element].
                // Its cache id is base_id.
                // It has these indices: [5, 1, 0, 5, 1, 2] (5 contracted, 1 traced)
                // First trace over 1: [5, 0, 5, 2]
                // let traced_id = self.expression_cache.trace(base_id, vec![(1, 4)]);
                // Then permute-reshape: [5, 0, 2]
                // let permuted_id = self.expression_cache.permute_reshape(traced_id, (0, 2, 1, 3), [4, 2, 2])
                // tree_stack.push(New Leaf Node (permuted_id, TensorIndices(...))

                let QuditTensor {
                    expression: expr_id,
                    indices,
                    param_info,
                } = &self.tensors[*path_element];
                // [5, 1, 0, 5, 1, 2] (5 contracted, 1 traced)
                let mut network_idx_ids = self.local_to_network_index_map[*path_element].clone();
                // println!("");
                // println!("New Leaf {} (id={:?}), with network ids {network_idx_ids:?}", self.expressions.borrow().base_name(*expr_id), *expr_id);

                // Perform partial traces if necessary:
                // find any indices that appear twice in `indices` and are
                // connected only to this tensor (edge set of size 1).
                let mut looped_index_map: HashMap<IndexId, Vec<usize>> = HashMap::new();
                for (local_idx, &network_idx_id) in network_idx_ids.iter().enumerate() {
                    let index_edge = &self.indices[network_idx_id];
                    if !index_edge.0.is_output() && index_edge.1.len() == 1 {
                        // This edge is looped
                        looped_index_map
                            .entry(network_idx_id)
                            .or_default()
                            .push(local_idx);
                    }
                }
                // looped_index_map = {1 : (1, 4)}

                // Assert that each looped index vector is exactly length 2 and convert them to pairs
                let mut to_remove = Vec::with_capacity(looped_index_map.len() * 2);
                let looped_index_pairs: Vec<(usize, usize)> = looped_index_map
                    .into_iter()
                    .map(|(index_id, local_indices)| {
                        assert_eq!(
                            local_indices.len(),
                            2,
                            "Looped index {:?} did not have exactly two occurrences. It had {}.",
                            index_id,
                            local_indices.len()
                        );
                        to_remove.extend(local_indices.clone());
                        (local_indices[0], local_indices[1])
                    })
                    .collect();

                // Drop traced-out local positions from highest to lowest so
                // earlier removals don't shift later ones.
                to_remove.sort();
                for traced_local_index in to_remove.iter().rev() {
                    network_idx_ids.remove(*traced_local_index);
                }
                // network_idx_ids = [5, 0, 5, 2]

                // Apply the trace through the expression cache only if any
                // looped pairs were found; otherwise reuse the base ids.
                let (traced_id, traced_indices) = if looped_index_pairs.is_empty() {
                    (*expr_id, indices.clone())
                } else {
                    let mut guard = self.expressions.lock().unwrap();
                    let id = guard.trace(*expr_id, looped_index_pairs);
                    let indices = guard.indices(id);
                    (id, indices)
                };
                // traced_indices = ((0, output), (1, output), (2, input), (3, input))

                // need to argsort indices so local indices that correspond to the same network
                // index are consecutive
                let perm = {
                    let mut argsorted_indices = (0..network_idx_ids.len()).collect::<Vec<_>>();
                    argsorted_indices.sort_by_key(|&i| network_idx_ids[i]);
                    argsorted_indices
                };

                // For now, set generation shape to a vector as the first time this tensor
                // is used (either in a contraction, or in output ordering) the tensor indices
                // will be reshaped again.
                let traced_nelems = self.expressions.lock().unwrap().num_elements(traced_id);
                let new_shape = GenerationShape::Vector(traced_nelems);
                let tranposed_id = self.expressions.lock().unwrap().permute_reshape(
                    traced_id,
                    perm.clone(),
                    new_shape,
                );

                // group (redimension) indices together that have the same network id
                let (new_node_indices, tensor_to_expr_position_map) = {
                    let mut new_node_indices = Vec::new();
                    let mut tensor_to_expr_position_map = Vec::new();

                    if perm.is_empty() {
                        // If there are no indices after tracing (e.g., a scalar result),
                        // the list of new node indices should be empty.
                    } else {
                        // Initialize accumulator for the first group of indices
                        let mut index_size_acm = 1;
                        let mut prev_network_idx_id = network_idx_ids[perm[0]];
                        let mut current_group = vec![];

                        // Iterate through the permuted local indices to group by network index ID
                        for i in 0..perm.len() {
                            let curr_local_idx = perm[i];
                            let curr_network_idx_id = network_idx_ids[curr_local_idx];
                            let curr_index_size = traced_indices[curr_local_idx].index_size();

                            if curr_network_idx_id == prev_network_idx_id {
                                // If the current network index ID is the same as the previous, accumulate its size
                                index_size_acm *= curr_index_size;
                                current_group.push(i);
                            } else {
                                // If a new network index ID is encountered, push the accumulated
                                // TensorIndex for the previous group, then start a new group.
                                new_node_indices.push(TensorIndex::new(
                                    IndexDirection::Input,
                                    prev_network_idx_id,
                                    index_size_acm,
                                ));
                                tensor_to_expr_position_map.push(current_group.clone());
                                // Start a new group with the current index's size and ID
                                current_group = vec![i];
                                index_size_acm = curr_index_size;
                                prev_network_idx_id = curr_network_idx_id;
                            }
                        }
                        // After the loop, push the last accumulated group
                        new_node_indices.push(TensorIndex::new(
                            IndexDirection::Input,
                            prev_network_idx_id,
                            index_size_acm,
                        ));
                        tensor_to_expr_position_map.push(current_group.clone());
                    }
                    (new_node_indices, tensor_to_expr_position_map)
                };

                // println!("Leaf node has indices: {new_node_indices:?}");
                tree_stack.push(TTGTTree::leaf(
                    self.expressions.clone(),
                    tranposed_id,
                    param_info.clone(),
                    new_node_indices,
                    tensor_to_expr_position_map,
                ));
                // println!("");
            }
        }
        // A well-formed postfix path reduces to exactly one tree.
        if tree_stack.len() != 1 {
            panic!("Tree stack should have exactly one element.");
        }

        let tree = tree_stack.pop().unwrap();

        // Perform final Transpose: order the surviving indices by their
        // network edge ordering.
        let mut goal_index_order = tree.indices();
        goal_index_order.sort_by_key(|x| &self.indices[x.index_id()]);

        // Permutation mapping the tree's current index order to the goal order.
        let final_transpose = goal_index_order
            .iter()
            .map(|i| {
                tree.indices()
                    .iter()
                    .position(|x| x.index_id() == i.index_id())
                    .unwrap()
            })
            .collect::<Vec<_>>();

        // Restore each surviving index's original output direction.
        let final_redirection = goal_index_order
            .iter()
            .map(|i| {
                if let NetworkIndex::Output(tidx) = self.indices[i.index_id()].0 {
                    tidx.direction()
                } else {
                    panic!("Non output index made it to final network output.");
                }
            })
            .collect();

        // println!("Current direction: {:?}", tree.indices().iter().map(|idx| idx.direction()).collect::<Vec<_>>());
        // println!("Final transpose: {:?} redirection: {:?}", final_transpose, final_redirection);
        tree.transpose(final_transpose, final_redirection)
    }
469}