qudit_tensor/network/builder.rs

1use std::collections::BTreeMap;
2use std::collections::BTreeSet;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::Mutex;
6
7use crate::network::network::NetworkEdge;
8use qudit_core::ComplexScalar;
9use qudit_core::ParamInfo;
10use qudit_core::Radices;
11use qudit_core::UnitaryMatrix;
12use qudit_expr::ExpressionCache;
13use qudit_expr::ExpressionId;
14use qudit_expr::TensorExpression;
15use qudit_expr::UnitaryExpression;
16
17use super::index::ContractionIndex;
18use super::index::IndexDirection;
19use super::index::IndexId;
20use super::index::IndexSize;
21use super::index::NetworkIndex;
22use super::index::TensorIndex;
23use super::network::QuditTensorNetwork;
24use super::tensor::QuditTensor;
25
/// State of one end (front or rear) of a qudit wire while the network is being built.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum Wire {
    /// No tensor has been attached to this end of the wire yet.
    Empty,
    /// This end of the wire has been terminated; no tensor may connect to it anymore.
    Closed,
    /// This end of the wire is an open index of an existing tensor.
    Connected(usize, usize), // node_id, local_index_id
}

impl Wire {
    /// `true` exactly when the wire end is [`Wire::Empty`].
    pub fn is_empty(&self) -> bool {
        matches!(self, Wire::Empty)
    }

    /// `true` when the wire end is still open, i.e. it is anything other than [`Wire::Closed`].
    pub fn is_active(&self) -> bool {
        !matches!(self, Wire::Closed)
    }
}
50
/// Where a tensor's local index attaches in the network under construction.
///
/// These builder-level identifiers are translated into final network index ids when the
/// builder is turned into a network.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum NetworkBuilderIndex {
    /// Open output index at the front of the given qudit wire.
    Front(usize),
    /// Open input index at the rear of the given qudit wire.
    Rear(usize),
    /// Batch index shared by every tensor that uses the same name.
    Batch(String),
    /// Contracted index, identified by its position in the builder's contraction list.
    Contraction(usize),
}
58
59pub struct QuditCircuitTensorNetworkBuilder {
60    tensors: Vec<QuditTensor>,
61    local_to_network_index_map: Vec<Vec<NetworkBuilderIndex>>,
62    radices: Radices,
63
64    expressions: Arc<Mutex<ExpressionCache>>,
65
66    /// Pointer to front (left in math/right in circuit diagram) of the network for each qudit.
67    front: Vec<Wire>,
68
69    /// Pointer to rear (right in math/left in circuit diagram) of the network for each qudit.
70    rear: Vec<Wire>,
71    batch_indices: HashMap<String, IndexSize>,
72    contracted_indices: Vec<ContractionIndex>,
73}
74
75impl QuditCircuitTensorNetworkBuilder {
76    pub fn new(radices: Radices, expressions: Option<Arc<Mutex<ExpressionCache>>>) -> Self {
77        let expressions = match expressions {
78            Some(cache) => cache,
79            None => ExpressionCache::new_shared(),
80        };
81
82        QuditCircuitTensorNetworkBuilder {
83            tensors: vec![],
84            local_to_network_index_map: vec![],
85            front: vec![Wire::Empty; radices.len()],
86            rear: vec![Wire::Empty; radices.len()],
87            batch_indices: HashMap::new(),
88            radices,
89            expressions,
90            contracted_indices: vec![],
91        }
92    }
93
94    /// Output indices stick out from the front of the Circuit Tensor Network.
95    ///
96    /// These correspond to wires exiting a circuit in a normal circuit diagram.
97    pub fn open_output_indices(&self) -> Vec<usize> {
98        self.front
99            .iter()
100            .enumerate()
101            .filter(|(_, wire)| wire.is_active())
102            .map(|(id, _)| id)
103            .collect()
104    }
105
106    /// Input indices stick out from the rear of the Circuit Tensor Network.
107    ///
108    /// These correspond to wires entering a circuit in a normal circuit diagram.
109    pub fn open_input_indices(&self) -> Vec<usize> {
110        self.rear
111            .iter()
112            .enumerate()
113            .filter(|(_, wire)| wire.is_active())
114            .map(|(id, _)| id)
115            .collect()
116    }
117
118    pub fn num_open_output_indices(&self) -> usize {
119        self.front.iter().filter(|wire| wire.is_active()).count()
120    }
121
122    pub fn num_open_input_indices(&self) -> usize {
123        self.rear.iter().filter(|wire| wire.is_active()).count()
124    }
125
126    pub fn expression_get(&mut self, expression: TensorExpression) -> ExpressionId {
127        let result = { self.expressions.lock().unwrap().lookup(&expression) };
128        match result {
129            None => self.expressions.lock().unwrap().insert(expression),
130            Some(id) => id,
131        }
132    }
133
134    pub fn prepend_expression(
135        mut self,
136        expression: TensorExpression,
137        param_info: ParamInfo,
138        input_index_map: Vec<usize>,
139        output_index_map: Vec<usize>,
140        batch_index_map: Vec<String>,
141    ) -> Self {
142        let indices = expression.indices().to_owned();
143        let id = self.expression_get(expression);
144        let tensor = QuditTensor::new(indices, id, param_info);
145        self.prepend(tensor, input_index_map, output_index_map, batch_index_map)
146    }
147
148    /// Prepend a tensor onto the circuit network.
149    ///
150    /// # Arguments
151    ///
152    /// * `tensor` - The tensor to prepend
153    ///
154    /// * `input_qudit_map` - An array of qudit ids. It maps the tensors input indices to the
155    ///   qudits at the front of the network that will be connected. `input_qudit_map[i] ==
156    ///   qudit_id` implies that `tensor.input_indices()[i]` will be connected to the
157    ///   `front[qudit_id]` edge.
158    ///
159    /// * `output_qudit_map` - An array of qudit ids. It maps the tensors output indices to
160    ///   the qudits at the front of the network which will become the new open edges for
161    ///   the network front. `output_qudit_map[i] == qudit_id` implies that
162    ///   `tensor.output_tensor_indices()[i]` will be the open edge on qudit_id after the
163    ///   operation.
164    ///
165    /// * `batch_index_map` - An array of strings. It provides names for the tensors batch
166    ///   indices. All batch indices in the network with the same name identify the same
167    ///   network indices. These indices can appear on both sides of a pairwise contraction
168    ///   without being contracted over.
169    ///
170    /// # Panics
171    ///
172    /// - If the length of an index map doesn't match the number of indices the tensor has in that
173    ///   direction.
174    ///
175    /// - If any of the `qudit_ids` referenced by the index maps are invalid or out of bounds.
176    ///
177    /// - If the size of a tensor's index doesn't match the radix of the qudit it's mapped to.
178    ///
179    /// - If a batch index with the same name exists in the network, they must have the same
180    ///   dimension.
181    pub fn prepend(
182        mut self,
183        tensor: QuditTensor,
184        input_index_map: Vec<usize>,
185        output_index_map: Vec<usize>,
186        batch_index_map: Vec<String>,
187    ) -> Self {
188        // Check error conditions
189        let batch_tensor_indices = tensor.batch_indices();
190        let output_tensor_indices = tensor.output_indices();
191        let input_tensor_indices = tensor.input_indices();
192        let batch_tensor_index_sizes = tensor.batch_sizes();
193        let output_tensor_index_sizes = tensor.output_sizes();
194        let input_tensor_index_sizes = tensor.input_sizes();
195
196        if batch_tensor_indices.len() != batch_index_map.len() {
197            panic!("Batch tensor indices and batch qudit map lengths do not match");
198        }
199
200        if output_tensor_indices.len() != output_index_map.len() {
201            panic!("Output tensor indices and output qudit map lengths do not match");
202        }
203
204        if input_tensor_indices.len() != input_index_map.len() {
205            panic!("Input tensor indices and input qudit map lengths do not match");
206        }
207
208        for (i, qudit_id) in output_index_map.iter().enumerate() {
209            if *qudit_id >= self.radices.len() {
210                panic!("Qudit id {qudit_id} is out of bounds from tensor's output map");
211            }
212            assert_eq!(
213                self.radices[*qudit_id], output_tensor_index_sizes[i],
214                "Tensor index size doesn't match mapped qudit radix.",
215            );
216        }
217
218        for (i, qudit_id) in input_index_map.iter().enumerate() {
219            if *qudit_id >= self.radices.len() {
220                panic!("Qudit id {qudit_id} is out of bounds from tensor's input map");
221            }
222            assert_eq!(
223                self.radices[*qudit_id], input_tensor_index_sizes[i],
224                "Tensor index size doesn't match mapped qudit radix.",
225            );
226        }
227
228        let num_qudits_involved = input_index_map
229            .iter()
230            .chain(output_index_map.iter())
231            .copied()
232            .collect::<BTreeSet<_>>()
233            .len();
234        if num_qudits_involved > input_index_map.len()
235            && num_qudits_involved > output_index_map.len()
236        {
237            // Every wire involved needs to do one of three things:
238            //  1. Pass through the tensor
239            //  2. Start at this tensor
240            //  3. End at this tensor
241            // This check among others enforces these constraints
242            panic!("Invalid input and output index map for QuditCircuitTensorNetworkBuilder.");
243        }
244
245        // Permute tensor and index maps so that indices are sequential
246        // For example, a CNOT applied to (1, 0) will be permuted,
247        // so that a permuted cnot is applied to (0, 1). This makes further processing easier.
248
249        let argsorted_output_index_map = {
250            let mut argsorted_indices = (0..output_index_map.len()).collect::<Vec<_>>();
251            argsorted_indices.sort_by_key(|&i| output_index_map[i]);
252            argsorted_indices
253        };
254        let argsorted_input_index_map = {
255            let mut argsorted_indices = (0..input_index_map.len()).collect::<Vec<_>>();
256            argsorted_indices.sort_by_key(|&i| input_index_map[i]);
257            argsorted_indices
258        };
259        let perm = (0..batch_index_map.len())
260            .chain(
261                argsorted_output_index_map
262                    .iter()
263                    .cloned()
264                    .map(|idx| idx + batch_index_map.len()),
265            )
266            .chain(
267                argsorted_input_index_map
268                    .iter()
269                    .cloned()
270                    .map(|idx| idx + output_index_map.len() + batch_index_map.len()),
271            )
272            .collect::<Vec<_>>();
273        let sequential_expr_id = self.expressions.lock().unwrap().permute_reshape(
274            tensor.expression,
275            perm,
276            tensor.shape(),
277        );
278        let sequential_indices = self.expressions.lock().unwrap().indices(sequential_expr_id);
279
280        let tensor = QuditTensor::new(sequential_indices, sequential_expr_id, tensor.param_info);
281        let output_index_map = argsorted_output_index_map
282            .into_iter()
283            .map(|i| output_index_map[i])
284            .collect::<Vec<_>>();
285        let input_index_map = argsorted_input_index_map
286            .into_iter()
287            .map(|i| input_index_map[i])
288            .collect::<Vec<_>>();
289
290        // Add Tensor to network
291        let tensor_id = self.tensors.len();
292        let mut tensor_local_to_network_map = vec![None; tensor.num_indices()];
293        self.tensors.push(tensor);
294
295        // Handle Input Indices
296        let mut new_contraction_ids: BTreeMap<usize, usize> = BTreeMap::new();
297        for (tensor_input_idx_id, qudit_id) in input_index_map.iter().enumerate() {
298            let local_index_id = input_tensor_indices[tensor_input_idx_id];
299            match self.front[*qudit_id] {
300                Wire::Empty => {
301                    self.rear[*qudit_id] = Wire::Connected(tensor_id, local_index_id);
302                    tensor_local_to_network_map[local_index_id] =
303                        Some(NetworkBuilderIndex::Rear(*qudit_id));
304                }
305                Wire::Closed => {
306                    panic!("Cannot contract tensor index with a closed qudit.");
307                }
308                Wire::Connected(existing_tensor_id, existing_local_index_id) => {
309                    // Record a new or update existing contraction between existing_tensor_id and
310                    // tensor_id.
311
312                    let contraction_id = *new_contraction_ids
313                        .entry(existing_tensor_id)
314                        .or_insert_with(|| {
315                            let id = self.contracted_indices.len();
316                            self.contracted_indices.push(ContractionIndex {
317                                left_id: tensor_id,
318                                right_id: existing_tensor_id,
319                                total_dimension: 1,
320                            });
321                            id
322                        });
323
324                    self.contracted_indices[contraction_id].total_dimension *=
325                        usize::from(self.radices[*qudit_id]);
326                    self.local_to_network_index_map[existing_tensor_id][existing_local_index_id] =
327                        NetworkBuilderIndex::Contraction(contraction_id);
328                    tensor_local_to_network_map[local_index_id] =
329                        Some(NetworkBuilderIndex::Contraction(contraction_id));
330                }
331            }
332
333            if !output_index_map.contains(qudit_id) {
334                // The network wire becomes inactive/closed if this tensor is the first one on it
335                // without a corresponding open (output) edge leaving it.
336                self.front[*qudit_id] = Wire::Closed;
337            }
338        }
339
340        // Handle Output Indices
341        for (tensor_output_idx_id, qudit_id) in output_index_map.iter().enumerate() {
342            let local_index_id = output_tensor_indices[tensor_output_idx_id];
343            // Either needs to start here or pass through
344            if !input_index_map.contains(qudit_id) {
345                match self.front[*qudit_id] {
346                    Wire::Empty => {
347                        // Make the start of this wire this tensor by closing rear
348                        self.rear[*qudit_id] = Wire::Closed;
349                    }
350                    Wire::Closed => {}
351                    Wire::Connected(_, _) => {
352                        panic!(
353                            "Cannot map a tensor output qudit over an active edge without connecting on the input side."
354                        );
355                    }
356                }
357            }
358            tensor_local_to_network_map[local_index_id] =
359                Some(NetworkBuilderIndex::Front(*qudit_id));
360            self.front[*qudit_id] = Wire::Connected(tensor_id, local_index_id);
361        }
362
363        // Handle Batch Indices
364        for (tensor_batch_idx_id, batch_idx_name) in batch_index_map.into_iter().enumerate() {
365            let local_index_id = batch_tensor_indices[tensor_batch_idx_id];
366            let batch_tensor_index_size = batch_tensor_index_sizes[tensor_batch_idx_id];
367
368            match self.batch_indices.get(&batch_idx_name) {
369                Some(index_size) => {
370                    assert_eq!(batch_tensor_index_size, *index_size);
371                }
372                None => {
373                    self.batch_indices
374                        .insert(batch_idx_name.clone(), batch_tensor_index_size);
375                }
376            }
377
378            tensor_local_to_network_map[local_index_id] =
379                Some(NetworkBuilderIndex::Batch(batch_idx_name));
380        }
381
382        // Finalize tensor addition
383        self.local_to_network_index_map.push(
384            tensor_local_to_network_map
385                .into_iter()
386                .map(|idx| match idx {
387                    Some(idx) => idx,
388                    None => panic!("Failed to map local tensor index to network index."),
389                })
390                .collect(),
391        );
392
393        self
394    }
395
396    // pub fn append(
397    //     self,
398    //     tensor: QuditTensor,
399    //     left_qudit_map: Vec<usize>,
400    //     right_qudit_map: Vec<usize>,
401    //     batch_index_map: Vec<String>,
402    // ) -> Self {
403    //     todo!()
404    // }
405
406    pub fn prepend_unitary<C: ComplexScalar>(
407        mut self,
408        utry: UnitaryMatrix<C>,
409        qudits: Vec<usize>,
410    ) -> Self {
411        let expr: TensorExpression = UnitaryExpression::from(utry).into();
412        let indices = expr.indices().to_owned();
413        let id = self.expression_get(expr);
414        self.prepend(
415            QuditTensor::new(indices, id, ParamInfo::empty()),
416            qudits.clone(),
417            qudits,
418            vec![],
419        )
420    }
421
422    pub fn trace_wire(mut self, front_qudit: usize, rear_qudit: usize) -> Self {
423        assert_eq!(self.radices[front_qudit], self.radices[rear_qudit]);
424        assert!(self.front[front_qudit].is_active() && self.rear[rear_qudit].is_active());
425
426        if self.front[front_qudit].is_empty() {
427            let identity =
428                UnitaryExpression::identity("Identity", [self.radices[front_qudit]]).into();
429            self = self.prepend_expression(
430                identity,
431                ParamInfo::empty(),
432                [front_qudit].into(),
433                [front_qudit].into(),
434                vec![],
435            );
436        }
437
438        if self.rear[rear_qudit].is_empty() {
439            let identity =
440                UnitaryExpression::identity("Identity", [self.radices[rear_qudit]]).into();
441            self = self.prepend_expression(
442                identity,
443                ParamInfo::empty(),
444                [rear_qudit].into(),
445                [rear_qudit].into(),
446                vec![],
447            );
448        }
449
450        match (&self.front[front_qudit], &self.rear[rear_qudit]) {
451            (Wire::Connected(tid_f, local_id_f), Wire::Connected(tid_r, local_id_r)) => {
452                debug_assert_eq!(
453                    self.local_to_network_index_map[*tid_f][*local_id_f],
454                    NetworkBuilderIndex::Front(front_qudit)
455                );
456                debug_assert_eq!(
457                    self.local_to_network_index_map[*tid_r][*local_id_r],
458                    NetworkBuilderIndex::Rear(rear_qudit)
459                );
460
461                // Find an existing contraction between tid_f and tid_r or create new one
462                let contraction_id = {
463                    let mut contraction_id = None;
464                    for net_index in &self.local_to_network_index_map[*tid_f] {
465                        if let NetworkBuilderIndex::Contraction(cid) = net_index {
466                            let contraction = &self.contracted_indices[*cid];
467                            if contraction.left_id == *tid_r || contraction.right_id == *tid_r {
468                                contraction_id = Some(*cid);
469                                break;
470                            }
471                        }
472                    }
473                    contraction_id.unwrap_or_else(|| {
474                        let cid = self.contracted_indices.len();
475                        self.contracted_indices.push(ContractionIndex {
476                            left_id: *tid_f,
477                            right_id: *tid_r,
478                            total_dimension: self.radices[front_qudit].into(),
479                        });
480                        cid
481                    })
482                };
483
484                self.local_to_network_index_map[*tid_f][*local_id_f] =
485                    NetworkBuilderIndex::Contraction(contraction_id);
486                self.local_to_network_index_map[*tid_r][*local_id_r] =
487                    NetworkBuilderIndex::Contraction(contraction_id);
488            }
489            _ => panic!("Cannot connect a closed wire to another wire."),
490        }
491
492        self.front[front_qudit] = Wire::Closed;
493        self.rear[rear_qudit] = Wire::Closed;
494        self
495    }
496
497    pub fn trace_all_open_wires(mut self) -> Self {
498        assert_eq!(
499            self.num_open_input_indices(),
500            self.num_open_output_indices()
501        );
502        for (f, r) in self
503            .open_output_indices()
504            .into_iter()
505            .zip(self.open_input_indices().into_iter())
506        {
507            self = self.trace_wire(f, r);
508        }
509        self
510    }
511
512    pub fn build(self) -> QuditTensorNetwork {
513        let QuditCircuitTensorNetworkBuilder {
514            mut tensors,
515            mut local_to_network_index_map,
516            expressions,
517            front,
518            rear,
519            batch_indices,
520            contracted_indices,
521            ..
522        } = self;
523
524        // Build network indices, while building map from builder to network
525        let mut indices = Vec::new();
526        let mut builder_to_network_map = HashMap::new();
527
528        let sorted_batch_indices = {
529            let mut as_vec: Vec<(String, IndexSize)> = batch_indices.into_iter().collect();
530            as_vec.sort();
531            as_vec
532        };
533
534        for (batch_idx_name, batch_idx_size) in sorted_batch_indices.into_iter() {
535            let index_id = indices.len();
536            indices.push(NetworkIndex::Output(TensorIndex::new(
537                IndexDirection::Batch,
538                index_id,
539                batch_idx_size,
540            )));
541            builder_to_network_map.insert(NetworkBuilderIndex::Batch(batch_idx_name), index_id);
542        }
543
544        for (qudit_id, wire) in front.into_iter().enumerate() {
545            if wire.is_empty() {
546                // Cannot have empty indices in network, so we need to explicitly add identity.
547                let identity_expression: TensorExpression =
548                    UnitaryExpression::identity("Identity", [self.radices[qudit_id]]).into();
549                let identity_indices = identity_expression.indices().to_owned();
550                let lookup_temp = expressions.lock().unwrap().lookup(&identity_expression);
551                let identity_expr_id = match lookup_temp {
552                    None => expressions.lock().unwrap().insert(identity_expression),
553                    Some(id) => id,
554                };
555                let identity_tensor =
556                    QuditTensor::new(identity_indices, identity_expr_id, ParamInfo::empty());
557                tensors.push(identity_tensor);
558                local_to_network_index_map.push(vec![
559                    NetworkBuilderIndex::Front(qudit_id),
560                    NetworkBuilderIndex::Rear(qudit_id),
561                ]);
562            }
563
564            if wire.is_active() {
565                let index_id = indices.len();
566                indices.push(NetworkIndex::Output(TensorIndex::new(
567                    IndexDirection::Output,
568                    index_id,
569                    self.radices[qudit_id].into(),
570                )));
571                builder_to_network_map.insert(NetworkBuilderIndex::Front(qudit_id), index_id);
572            }
573        }
574
575        for (qudit_id, wire) in rear.into_iter().enumerate() {
576            if wire.is_active() {
577                let index_id = indices.len();
578                indices.push(NetworkIndex::Output(TensorIndex::new(
579                    IndexDirection::Input,
580                    index_id,
581                    self.radices[qudit_id].into(),
582                )));
583                builder_to_network_map.insert(NetworkBuilderIndex::Rear(qudit_id), index_id);
584            }
585        }
586
587        for (cidx_id, contraction_index) in contracted_indices.into_iter().enumerate() {
588            let index_id = indices.len();
589            indices.push(NetworkIndex::Contracted(contraction_index));
590            builder_to_network_map.insert(NetworkBuilderIndex::Contraction(cidx_id), index_id);
591        }
592
593        let mut index_edges: Vec<NetworkEdge> =
594            indices.into_iter().map(|x| (x, BTreeSet::new())).collect();
595
596        let new_index_map = local_to_network_index_map
597            .into_iter()
598            .enumerate()
599            .map(|(tid, tidx_map)| {
600                tidx_map
601                    .into_iter()
602                    .map(|index| {
603                        let network_index = builder_to_network_map[&index];
604                        index_edges[network_index].1.insert(tid);
605                        network_index
606                    })
607                    .collect::<Vec<IndexId>>()
608            })
609            .collect::<Vec<Vec<IndexId>>>();
610
611        QuditTensorNetwork::new(tensors, expressions, new_index_map, index_edges)
612    }
613}