ket/
graph.rs

1// SPDX-FileCopyrightText: 2024 Evandro Chagas Ribeiro da Rosa <evandro@quantuloop.com>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5use crate::ir::qubit::Qubit;
6use std::{collections::VecDeque, marker::PhantomData};
7
8#[derive(Debug, Clone, Default)]
9pub(crate) struct GraphMatrix<Q> {
10    graph: Vec<Vec<Option<i64>>>,
11    n: usize,
12    distance: Option<Vec<Vec<i64>>>,
13
14    qubit_type: PhantomData<Q>,
15}
16
17impl<Q> GraphMatrix<Q>
18where
19    Q: Qubit + From<usize> + Clone + Copy + Sync + PartialEq,
20{
21    pub fn new(n: usize) -> Self {
22        let mut graph = vec![];
23        for i in 1..n {
24            graph.push(vec![None; i]);
25        }
26        Self {
27            graph,
28            n,
29            distance: None,
30            qubit_type: PhantomData,
31        }
32    }
33
34    fn add_node(&mut self) {
35        if self.n != 0 {
36            self.graph.push(vec![None; self.n]);
37        }
38        self.n += 1;
39    }
40
41    pub fn edge(&self, i: Q, j: Q) -> Option<i64> {
42        let (i, j) = (i.index(), j.index());
43
44        if self.n <= i || self.n <= j {
45            return None;
46        }
47
48        if i == j {
49            return Some(0);
50        }
51
52        let (i, j) = if i > j { (i, j) } else { (j, i) };
53        let i = i - 1;
54        self.graph[i][j]
55    }
56
57    pub fn neighbors(&self, node: Q) -> Vec<(Q, i64)> {
58        let node = node.index();
59        (0..self.n)
60            .filter_map(|j| {
61                if node != j {
62                    self.edge(node.into(), j.into())
63                        .map(|value| (j.into(), value))
64                } else {
65                    None
66                }
67            })
68            .collect()
69    }
70
71    pub fn set_edge(&mut self, i: Q, j: Q, value: i64) {
72        let (i, j) = (i.index(), j.index());
73
74        while self.n <= i || self.n <= j {
75            self.add_node();
76        }
77
78        if i == j {
79            return;
80        }
81
82        let (i, j) = if i > j { (i, j) } else { (j, i) };
83        let i = i - 1;
84        self.graph[i][j] = Some(value);
85    }
86
87    pub fn dist(&self, i: Q, j: Q) -> i64 {
88        let (i, j) = (i.index(), j.index());
89
90        if let Some(distance) = &self.distance {
91            if i == j {
92                return 0;
93            }
94
95            let (i, j) = if i > j { (i, j) } else { (j, i) };
96            let i = i - 1;
97            distance[i][j]
98        } else {
99            panic!("Calculate distance before")
100        }
101    }
102
103    fn set_dist_min(&mut self, i: Q, j: Q, value: i64) {
104        let (i, j) = (i.index(), j.index());
105
106        if self.distance.is_none() {
107            panic!("Cannot set distance without distance matrix");
108        }
109        if i == j {
110            return;
111        }
112
113        let (i, j) = if i > j { (i, j) } else { (j, i) };
114        let i = i - 1;
115        let value = std::cmp::min(self.distance.as_ref().unwrap()[i][j], value);
116        self.distance.as_mut().unwrap()[i][j] = value;
117    }
118
119    pub fn calculate_distance(&mut self) {
120        if self.distance.is_some() {
121            return;
122        }
123
124        self.distance = Some(
125            self.graph
126                .iter()
127                .map(|row| {
128                    row.iter()
129                        .map(|value| if value.is_some() { 1 } else { u32::MAX as i64 })
130                        .collect::<Vec<i64>>()
131                })
132                .collect(),
133        );
134
135        for k in 0..self.n {
136            for i in 0..self.n {
137                for j in i..self.n {
138                    let dist = self.dist(i.into(), k.into()) + self.dist(k.into(), j.into());
139                    self.set_dist_min(i.into(), j.into(), dist);
140                }
141            }
142        }
143    }
144
145    pub fn get_center(&self) -> Q {
146        let max_distance: Vec<i64> = (0..self.n)
147            .map(|i| {
148                (0..self.n)
149                    .map(|j| self.dist(i.into(), j.into()))
150                    .max()
151                    .unwrap()
152            })
153            .collect();
154
155        let min = max_distance.iter().min().unwrap();
156
157        let mut center_list: Vec<(usize, i64)> = max_distance
158            .iter()
159            .enumerate()
160            .filter_map(|(node, value)| {
161                if value <= min {
162                    Some((
163                        node,
164                        (0..self.n)
165                            .map(|other| self.dist(node.into(), other.into()))
166                            .sum(),
167                    ))
168                } else {
169                    None
170                }
171            })
172            .collect();
173
174        center_list.sort_by_key(|(_, dist)| *dist);
175
176        center_list[0].0.into()
177    }
178
179    pub fn breadth_first_search(&self, start: Q) -> VecDeque<Q> {
180        let mut visited: Vec<Q> = vec![];
181        let mut queue = VecDeque::from([start]);
182
183        while let Some(front) = queue.pop_front() {
184            visited.push(front);
185            let neighbors = self.neighbors(front);
186            let mut neighbors: Vec<_> = neighbors
187                .iter()
188                .filter_map(|(index, value)| {
189                    if !visited.contains(index) && !queue.contains(index) {
190                        Some((*index, *value))
191                    } else {
192                        None
193                    }
194                })
195                .collect();
196
197            neighbors.sort_by_key(|(_, value)| *value);
198
199            for (next, _) in neighbors {
200                queue.push_back(next);
201            }
202        }
203
204        VecDeque::from(visited)
205    }
206
207    pub fn degree(&self, node: Q) -> usize {
208        self.neighbors(node).len()
209    }
210}