1use std::collections::BTreeMap;
2use std::collections::BTreeSet;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::Mutex;
6
7use crate::network::network::NetworkEdge;
8use qudit_core::ComplexScalar;
9use qudit_core::ParamInfo;
10use qudit_core::Radices;
11use qudit_core::UnitaryMatrix;
12use qudit_expr::ExpressionCache;
13use qudit_expr::ExpressionId;
14use qudit_expr::TensorExpression;
15use qudit_expr::UnitaryExpression;
16
17use super::index::ContractionIndex;
18use super::index::IndexDirection;
19use super::index::IndexId;
20use super::index::IndexSize;
21use super::index::NetworkIndex;
22use super::index::TensorIndex;
23use super::network::QuditTensorNetwork;
24use super::tensor::QuditTensor;
25
/// Boundary state of a single qudit wire, tracked separately for the front
/// (output) side and the rear (input) side of the network under construction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Wire {
    /// Nothing has been attached to this wire yet.
    Empty,
    /// The wire has been closed off and cannot be connected again.
    Closed,
    /// The wire ends at local index `.1` of the tensor with id `.0`.
    Connected(usize, usize),
}

impl Wire {
    /// `true` exactly when the wire is `Wire::Empty`.
    pub fn is_empty(&self) -> bool {
        matches!(self, Wire::Empty)
    }

    /// `true` for every state except `Wire::Closed`, i.e. the wire is still
    /// open regardless of whether a tensor index is attached to it.
    pub fn is_active(&self) -> bool {
        !matches!(self, Wire::Closed)
    }
}
50
/// Builder-side label for a network index; `build` resolves each label to a
/// concrete network index id.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum NetworkBuilderIndex {
    /// Open output (front) index on the given qudit.
    Front(usize),
    /// Open input (rear) index on the given qudit.
    Rear(usize),
    /// Batch index, identified by its name.
    Batch(String),
    /// Contracted index, identified by its position in the builder's
    /// contraction list.
    Contraction(usize),
}
58
59pub struct QuditCircuitTensorNetworkBuilder {
60 tensors: Vec<QuditTensor>,
61 local_to_network_index_map: Vec<Vec<NetworkBuilderIndex>>,
62 radices: Radices,
63
64 expressions: Arc<Mutex<ExpressionCache>>,
65
66 front: Vec<Wire>,
68
69 rear: Vec<Wire>,
71 batch_indices: HashMap<String, IndexSize>,
72 contracted_indices: Vec<ContractionIndex>,
73}
74
75impl QuditCircuitTensorNetworkBuilder {
76 pub fn new(radices: Radices, expressions: Option<Arc<Mutex<ExpressionCache>>>) -> Self {
77 let expressions = match expressions {
78 Some(cache) => cache,
79 None => ExpressionCache::new_shared(),
80 };
81
82 QuditCircuitTensorNetworkBuilder {
83 tensors: vec![],
84 local_to_network_index_map: vec![],
85 front: vec![Wire::Empty; radices.len()],
86 rear: vec![Wire::Empty; radices.len()],
87 batch_indices: HashMap::new(),
88 radices,
89 expressions,
90 contracted_indices: vec![],
91 }
92 }
93
94 pub fn open_output_indices(&self) -> Vec<usize> {
98 self.front
99 .iter()
100 .enumerate()
101 .filter(|(_, wire)| wire.is_active())
102 .map(|(id, _)| id)
103 .collect()
104 }
105
106 pub fn open_input_indices(&self) -> Vec<usize> {
110 self.rear
111 .iter()
112 .enumerate()
113 .filter(|(_, wire)| wire.is_active())
114 .map(|(id, _)| id)
115 .collect()
116 }
117
118 pub fn num_open_output_indices(&self) -> usize {
119 self.front.iter().filter(|wire| wire.is_active()).count()
120 }
121
122 pub fn num_open_input_indices(&self) -> usize {
123 self.rear.iter().filter(|wire| wire.is_active()).count()
124 }
125
126 pub fn expression_get(&mut self, expression: TensorExpression) -> ExpressionId {
127 let result = { self.expressions.lock().unwrap().lookup(&expression) };
128 match result {
129 None => self.expressions.lock().unwrap().insert(expression),
130 Some(id) => id,
131 }
132 }
133
134 pub fn prepend_expression(
135 mut self,
136 expression: TensorExpression,
137 param_info: ParamInfo,
138 input_index_map: Vec<usize>,
139 output_index_map: Vec<usize>,
140 batch_index_map: Vec<String>,
141 ) -> Self {
142 let indices = expression.indices().to_owned();
143 let id = self.expression_get(expression);
144 let tensor = QuditTensor::new(indices, id, param_info);
145 self.prepend(tensor, input_index_map, output_index_map, batch_index_map)
146 }
147
148 pub fn prepend(
182 mut self,
183 tensor: QuditTensor,
184 input_index_map: Vec<usize>,
185 output_index_map: Vec<usize>,
186 batch_index_map: Vec<String>,
187 ) -> Self {
188 let batch_tensor_indices = tensor.batch_indices();
190 let output_tensor_indices = tensor.output_indices();
191 let input_tensor_indices = tensor.input_indices();
192 let batch_tensor_index_sizes = tensor.batch_sizes();
193 let output_tensor_index_sizes = tensor.output_sizes();
194 let input_tensor_index_sizes = tensor.input_sizes();
195
196 if batch_tensor_indices.len() != batch_index_map.len() {
197 panic!("Batch tensor indices and batch qudit map lengths do not match");
198 }
199
200 if output_tensor_indices.len() != output_index_map.len() {
201 panic!("Output tensor indices and output qudit map lengths do not match");
202 }
203
204 if input_tensor_indices.len() != input_index_map.len() {
205 panic!("Input tensor indices and input qudit map lengths do not match");
206 }
207
208 for (i, qudit_id) in output_index_map.iter().enumerate() {
209 if *qudit_id >= self.radices.len() {
210 panic!("Qudit id {qudit_id} is out of bounds from tensor's output map");
211 }
212 assert_eq!(
213 self.radices[*qudit_id], output_tensor_index_sizes[i],
214 "Tensor index size doesn't match mapped qudit radix.",
215 );
216 }
217
218 for (i, qudit_id) in input_index_map.iter().enumerate() {
219 if *qudit_id >= self.radices.len() {
220 panic!("Qudit id {qudit_id} is out of bounds from tensor's input map");
221 }
222 assert_eq!(
223 self.radices[*qudit_id], input_tensor_index_sizes[i],
224 "Tensor index size doesn't match mapped qudit radix.",
225 );
226 }
227
228 let num_qudits_involved = input_index_map
229 .iter()
230 .chain(output_index_map.iter())
231 .copied()
232 .collect::<BTreeSet<_>>()
233 .len();
234 if num_qudits_involved > input_index_map.len()
235 && num_qudits_involved > output_index_map.len()
236 {
237 panic!("Invalid input and output index map for QuditCircuitTensorNetworkBuilder.");
243 }
244
245 let argsorted_output_index_map = {
250 let mut argsorted_indices = (0..output_index_map.len()).collect::<Vec<_>>();
251 argsorted_indices.sort_by_key(|&i| output_index_map[i]);
252 argsorted_indices
253 };
254 let argsorted_input_index_map = {
255 let mut argsorted_indices = (0..input_index_map.len()).collect::<Vec<_>>();
256 argsorted_indices.sort_by_key(|&i| input_index_map[i]);
257 argsorted_indices
258 };
259 let perm = (0..batch_index_map.len())
260 .chain(
261 argsorted_output_index_map
262 .iter()
263 .cloned()
264 .map(|idx| idx + batch_index_map.len()),
265 )
266 .chain(
267 argsorted_input_index_map
268 .iter()
269 .cloned()
270 .map(|idx| idx + output_index_map.len() + batch_index_map.len()),
271 )
272 .collect::<Vec<_>>();
273 let sequential_expr_id = self.expressions.lock().unwrap().permute_reshape(
274 tensor.expression,
275 perm,
276 tensor.shape(),
277 );
278 let sequential_indices = self.expressions.lock().unwrap().indices(sequential_expr_id);
279
280 let tensor = QuditTensor::new(sequential_indices, sequential_expr_id, tensor.param_info);
281 let output_index_map = argsorted_output_index_map
282 .into_iter()
283 .map(|i| output_index_map[i])
284 .collect::<Vec<_>>();
285 let input_index_map = argsorted_input_index_map
286 .into_iter()
287 .map(|i| input_index_map[i])
288 .collect::<Vec<_>>();
289
290 let tensor_id = self.tensors.len();
292 let mut tensor_local_to_network_map = vec![None; tensor.num_indices()];
293 self.tensors.push(tensor);
294
295 let mut new_contraction_ids: BTreeMap<usize, usize> = BTreeMap::new();
297 for (tensor_input_idx_id, qudit_id) in input_index_map.iter().enumerate() {
298 let local_index_id = input_tensor_indices[tensor_input_idx_id];
299 match self.front[*qudit_id] {
300 Wire::Empty => {
301 self.rear[*qudit_id] = Wire::Connected(tensor_id, local_index_id);
302 tensor_local_to_network_map[local_index_id] =
303 Some(NetworkBuilderIndex::Rear(*qudit_id));
304 }
305 Wire::Closed => {
306 panic!("Cannot contract tensor index with a closed qudit.");
307 }
308 Wire::Connected(existing_tensor_id, existing_local_index_id) => {
309 let contraction_id = *new_contraction_ids
313 .entry(existing_tensor_id)
314 .or_insert_with(|| {
315 let id = self.contracted_indices.len();
316 self.contracted_indices.push(ContractionIndex {
317 left_id: tensor_id,
318 right_id: existing_tensor_id,
319 total_dimension: 1,
320 });
321 id
322 });
323
324 self.contracted_indices[contraction_id].total_dimension *=
325 usize::from(self.radices[*qudit_id]);
326 self.local_to_network_index_map[existing_tensor_id][existing_local_index_id] =
327 NetworkBuilderIndex::Contraction(contraction_id);
328 tensor_local_to_network_map[local_index_id] =
329 Some(NetworkBuilderIndex::Contraction(contraction_id));
330 }
331 }
332
333 if !output_index_map.contains(qudit_id) {
334 self.front[*qudit_id] = Wire::Closed;
337 }
338 }
339
340 for (tensor_output_idx_id, qudit_id) in output_index_map.iter().enumerate() {
342 let local_index_id = output_tensor_indices[tensor_output_idx_id];
343 if !input_index_map.contains(qudit_id) {
345 match self.front[*qudit_id] {
346 Wire::Empty => {
347 self.rear[*qudit_id] = Wire::Closed;
349 }
350 Wire::Closed => {}
351 Wire::Connected(_, _) => {
352 panic!(
353 "Cannot map a tensor output qudit over an active edge without connecting on the input side."
354 );
355 }
356 }
357 }
358 tensor_local_to_network_map[local_index_id] =
359 Some(NetworkBuilderIndex::Front(*qudit_id));
360 self.front[*qudit_id] = Wire::Connected(tensor_id, local_index_id);
361 }
362
363 for (tensor_batch_idx_id, batch_idx_name) in batch_index_map.into_iter().enumerate() {
365 let local_index_id = batch_tensor_indices[tensor_batch_idx_id];
366 let batch_tensor_index_size = batch_tensor_index_sizes[tensor_batch_idx_id];
367
368 match self.batch_indices.get(&batch_idx_name) {
369 Some(index_size) => {
370 assert_eq!(batch_tensor_index_size, *index_size);
371 }
372 None => {
373 self.batch_indices
374 .insert(batch_idx_name.clone(), batch_tensor_index_size);
375 }
376 }
377
378 tensor_local_to_network_map[local_index_id] =
379 Some(NetworkBuilderIndex::Batch(batch_idx_name));
380 }
381
382 self.local_to_network_index_map.push(
384 tensor_local_to_network_map
385 .into_iter()
386 .map(|idx| match idx {
387 Some(idx) => idx,
388 None => panic!("Failed to map local tensor index to network index."),
389 })
390 .collect(),
391 );
392
393 self
394 }
395
396 pub fn prepend_unitary<C: ComplexScalar>(
407 mut self,
408 utry: UnitaryMatrix<C>,
409 qudits: Vec<usize>,
410 ) -> Self {
411 let expr: TensorExpression = UnitaryExpression::from(utry).into();
412 let indices = expr.indices().to_owned();
413 let id = self.expression_get(expr);
414 self.prepend(
415 QuditTensor::new(indices, id, ParamInfo::empty()),
416 qudits.clone(),
417 qudits,
418 vec![],
419 )
420 }
421
422 pub fn trace_wire(mut self, front_qudit: usize, rear_qudit: usize) -> Self {
423 assert_eq!(self.radices[front_qudit], self.radices[rear_qudit]);
424 assert!(self.front[front_qudit].is_active() && self.rear[rear_qudit].is_active());
425
426 if self.front[front_qudit].is_empty() {
427 let identity =
428 UnitaryExpression::identity("Identity", [self.radices[front_qudit]]).into();
429 self = self.prepend_expression(
430 identity,
431 ParamInfo::empty(),
432 [front_qudit].into(),
433 [front_qudit].into(),
434 vec![],
435 );
436 }
437
438 if self.rear[rear_qudit].is_empty() {
439 let identity =
440 UnitaryExpression::identity("Identity", [self.radices[rear_qudit]]).into();
441 self = self.prepend_expression(
442 identity,
443 ParamInfo::empty(),
444 [rear_qudit].into(),
445 [rear_qudit].into(),
446 vec![],
447 );
448 }
449
450 match (&self.front[front_qudit], &self.rear[rear_qudit]) {
451 (Wire::Connected(tid_f, local_id_f), Wire::Connected(tid_r, local_id_r)) => {
452 debug_assert_eq!(
453 self.local_to_network_index_map[*tid_f][*local_id_f],
454 NetworkBuilderIndex::Front(front_qudit)
455 );
456 debug_assert_eq!(
457 self.local_to_network_index_map[*tid_r][*local_id_r],
458 NetworkBuilderIndex::Rear(rear_qudit)
459 );
460
461 let contraction_id = {
463 let mut contraction_id = None;
464 for net_index in &self.local_to_network_index_map[*tid_f] {
465 if let NetworkBuilderIndex::Contraction(cid) = net_index {
466 let contraction = &self.contracted_indices[*cid];
467 if contraction.left_id == *tid_r || contraction.right_id == *tid_r {
468 contraction_id = Some(*cid);
469 break;
470 }
471 }
472 }
473 contraction_id.unwrap_or_else(|| {
474 let cid = self.contracted_indices.len();
475 self.contracted_indices.push(ContractionIndex {
476 left_id: *tid_f,
477 right_id: *tid_r,
478 total_dimension: self.radices[front_qudit].into(),
479 });
480 cid
481 })
482 };
483
484 self.local_to_network_index_map[*tid_f][*local_id_f] =
485 NetworkBuilderIndex::Contraction(contraction_id);
486 self.local_to_network_index_map[*tid_r][*local_id_r] =
487 NetworkBuilderIndex::Contraction(contraction_id);
488 }
489 _ => panic!("Cannot connect a closed wire to another wire."),
490 }
491
492 self.front[front_qudit] = Wire::Closed;
493 self.rear[rear_qudit] = Wire::Closed;
494 self
495 }
496
497 pub fn trace_all_open_wires(mut self) -> Self {
498 assert_eq!(
499 self.num_open_input_indices(),
500 self.num_open_output_indices()
501 );
502 for (f, r) in self
503 .open_output_indices()
504 .into_iter()
505 .zip(self.open_input_indices().into_iter())
506 {
507 self = self.trace_wire(f, r);
508 }
509 self
510 }
511
512 pub fn build(self) -> QuditTensorNetwork {
513 let QuditCircuitTensorNetworkBuilder {
514 mut tensors,
515 mut local_to_network_index_map,
516 expressions,
517 front,
518 rear,
519 batch_indices,
520 contracted_indices,
521 ..
522 } = self;
523
524 let mut indices = Vec::new();
526 let mut builder_to_network_map = HashMap::new();
527
528 let sorted_batch_indices = {
529 let mut as_vec: Vec<(String, IndexSize)> = batch_indices.into_iter().collect();
530 as_vec.sort();
531 as_vec
532 };
533
534 for (batch_idx_name, batch_idx_size) in sorted_batch_indices.into_iter() {
535 let index_id = indices.len();
536 indices.push(NetworkIndex::Output(TensorIndex::new(
537 IndexDirection::Batch,
538 index_id,
539 batch_idx_size,
540 )));
541 builder_to_network_map.insert(NetworkBuilderIndex::Batch(batch_idx_name), index_id);
542 }
543
544 for (qudit_id, wire) in front.into_iter().enumerate() {
545 if wire.is_empty() {
546 let identity_expression: TensorExpression =
548 UnitaryExpression::identity("Identity", [self.radices[qudit_id]]).into();
549 let identity_indices = identity_expression.indices().to_owned();
550 let lookup_temp = expressions.lock().unwrap().lookup(&identity_expression);
551 let identity_expr_id = match lookup_temp {
552 None => expressions.lock().unwrap().insert(identity_expression),
553 Some(id) => id,
554 };
555 let identity_tensor =
556 QuditTensor::new(identity_indices, identity_expr_id, ParamInfo::empty());
557 tensors.push(identity_tensor);
558 local_to_network_index_map.push(vec![
559 NetworkBuilderIndex::Front(qudit_id),
560 NetworkBuilderIndex::Rear(qudit_id),
561 ]);
562 }
563
564 if wire.is_active() {
565 let index_id = indices.len();
566 indices.push(NetworkIndex::Output(TensorIndex::new(
567 IndexDirection::Output,
568 index_id,
569 self.radices[qudit_id].into(),
570 )));
571 builder_to_network_map.insert(NetworkBuilderIndex::Front(qudit_id), index_id);
572 }
573 }
574
575 for (qudit_id, wire) in rear.into_iter().enumerate() {
576 if wire.is_active() {
577 let index_id = indices.len();
578 indices.push(NetworkIndex::Output(TensorIndex::new(
579 IndexDirection::Input,
580 index_id,
581 self.radices[qudit_id].into(),
582 )));
583 builder_to_network_map.insert(NetworkBuilderIndex::Rear(qudit_id), index_id);
584 }
585 }
586
587 for (cidx_id, contraction_index) in contracted_indices.into_iter().enumerate() {
588 let index_id = indices.len();
589 indices.push(NetworkIndex::Contracted(contraction_index));
590 builder_to_network_map.insert(NetworkBuilderIndex::Contraction(cidx_id), index_id);
591 }
592
593 let mut index_edges: Vec<NetworkEdge> =
594 indices.into_iter().map(|x| (x, BTreeSet::new())).collect();
595
596 let new_index_map = local_to_network_index_map
597 .into_iter()
598 .enumerate()
599 .map(|(tid, tidx_map)| {
600 tidx_map
601 .into_iter()
602 .map(|index| {
603 let network_index = builder_to_network_map[&index];
604 index_edges[network_index].1.insert(tid);
605 network_index
606 })
607 .collect::<Vec<IndexId>>()
608 })
609 .collect::<Vec<Vec<IndexId>>>();
610
611 QuditTensorNetwork::new(tensors, expressions, new_index_map, index_edges)
612 }
613}