causal_hub/models/graphs/undirected.rs

1use ndarray::prelude::*;
2use serde::{
3    Deserialize, Deserializer, Serialize, Serializer,
4    de::{MapAccess, Visitor},
5    ser::SerializeMap,
6};
7
8use crate::{
9    impl_json_io,
10    models::{Graph, Labelled},
11    types::{Labels, Set},
12};
13
14/// A struct representing an undirected graph using an adjacency matrix.
15#[derive(Clone, Debug)]
16pub struct UnGraph {
17    labels: Labels,
18    adjacency_matrix: Array2<bool>,
19}
20
21impl UnGraph {
22    /// Returns the neighbors of a vertex.
23    ///
24    /// # Arguments
25    ///
26    /// * `x` - The vertex for which to find the neighbors.
27    ///
28    /// # Panics
29    ///
30    /// * If the vertex is out of bounds.
31    ///
32    /// # Returns
33    ///
34    /// The neighbors of the vertex.
35    ///
36    pub fn neighbors(&self, x: &Set<usize>) -> Set<usize> {
37        // Check if the vertices are within bounds.
38        x.iter().for_each(|&v| {
39            assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
40        });
41
42        // Iterate over all vertices and filter the ones that are neighbors.
43        let mut neighbors: Set<_> = x
44            .into_iter()
45            .flat_map(|&v| {
46                self.adjacency_matrix
47                    .row(v)
48                    .into_iter()
49                    .enumerate()
50                    .filter_map(|(y, &has_edge)| if has_edge { Some(y) } else { None })
51            })
52            .collect();
53
54        // Sort the neighbors.
55        neighbors.sort();
56
57        // Return the neighbors.
58        neighbors
59    }
60}
61
62impl Labelled for UnGraph {
63    fn labels(&self) -> &Labels {
64        &self.labels
65    }
66}
67
68impl Graph for UnGraph {
69    fn empty<I, V>(labels: I) -> Self
70    where
71        I: IntoIterator<Item = V>,
72        V: AsRef<str>,
73    {
74        // Initialize labels counter.
75        let mut n = 0;
76        // Collect the labels.
77        let mut labels: Labels = labels
78            .into_iter()
79            .inspect(|_| n += 1)
80            .map(|x| x.as_ref().to_owned())
81            .collect();
82
83        // Assert no duplicate labels.
84        assert_eq!(labels.len(), n, "Labels must be unique.");
85
86        // Sort the labels.
87        labels.sort();
88
89        // Initialize the adjacency matrix with `false` values.
90        let adjacency_matrix: Array2<_> = Array::from_elem((n, n), false);
91
92        // Debug assert to check the sorting of the labels.
93        debug_assert!(labels.iter().is_sorted(), "Vertices labels must be sorted.");
94
95        Self {
96            labels,
97            adjacency_matrix,
98        }
99    }
100
101    fn complete<I, V>(labels: I) -> Self
102    where
103        I: IntoIterator<Item = V>,
104        V: AsRef<str>,
105    {
106        // Initialize labels counter.
107        let mut n = 0;
108        // Collect the labels.
109        let mut labels: Labels = labels
110            .into_iter()
111            .inspect(|_| n += 1)
112            .map(|x| x.as_ref().to_owned())
113            .collect();
114
115        // Assert no duplicate labels.
116        assert_eq!(labels.len(), n, "Labels must be unique.");
117
118        // Sort the labels.
119        labels.sort();
120
121        // Initialize the adjacency matrix with `true` values.
122        let mut adjacency_matrix: Array2<_> = Array::from_elem((n, n), true);
123        // Set the diagonal to `false` to avoid self-loops.
124        adjacency_matrix.diag_mut().fill(false);
125
126        // Debug assert to check the sorting of the labels.
127        debug_assert!(labels.iter().is_sorted(), "Vertices labels must be sorted.");
128
129        Self {
130            labels,
131            adjacency_matrix,
132        }
133    }
134
135    fn vertices(&self) -> Set<usize> {
136        (0..self.labels.len()).collect()
137    }
138
139    fn has_vertex(&self, x: usize) -> bool {
140        // Check if the vertex is within bounds.
141        x < self.labels.len()
142    }
143
144    fn edges(&self) -> Set<(usize, usize)> {
145        // Iterate over the adjacency matrix and collect the edges.
146        self.adjacency_matrix
147            .indexed_iter()
148            .filter_map(|((x, y), &has_edge)| {
149                // Since the graph is undirected, we only need to check one direction.
150                if has_edge && x <= y {
151                    Some((x, y))
152                } else {
153                    None
154                }
155            })
156            .collect()
157    }
158
159    fn has_edge(&self, x: usize, y: usize) -> bool {
160        // Check if the vertices are within bounds.
161        assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
162        assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
163
164        self.adjacency_matrix[[x, y]]
165    }
166
167    fn add_edge(&mut self, x: usize, y: usize) -> bool {
168        // Check if the vertices are within bounds.
169        assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
170        assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
171
172        // Check if the edge already exists.
173        if self.adjacency_matrix[[x, y]] {
174            return false;
175        }
176
177        // Add the edge.
178        self.adjacency_matrix[[x, y]] = true;
179        self.adjacency_matrix[[y, x]] = true;
180
181        true
182    }
183
184    fn del_edge(&mut self, x: usize, y: usize) -> bool {
185        // Check if the vertices are within bounds.
186        assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
187        assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
188
189        // Check if the edge exists.
190        if !self.adjacency_matrix[[x, y]] {
191            return false;
192        }
193
194        // Delete the edge.
195        self.adjacency_matrix[[x, y]] = false;
196        self.adjacency_matrix[[y, x]] = false;
197
198        true
199    }
200
201    fn from_adjacency_matrix(mut labels: Labels, mut adjacency_matrix: Array2<bool>) -> Self {
202        // Assert labels and adjacency matrix dimensions match.
203        assert_eq!(
204            labels.len(),
205            adjacency_matrix.nrows(),
206            "Number of labels must match the number of rows in the adjacency matrix."
207        );
208        // Assert adjacency matrix must be square.
209        assert_eq!(
210            adjacency_matrix.nrows(),
211            adjacency_matrix.ncols(),
212            "Adjacency matrix must be square."
213        );
214        // Assert the adjacency matrix is symmetric.
215        assert_eq!(
216            adjacency_matrix,
217            adjacency_matrix.t(),
218            "Adjacency matrix must be symmetric."
219        );
220
221        // Check if the labels are sorted.
222        if !labels.is_sorted() {
223            // Allocate the sorted indices.
224            let mut indices: Vec<usize> = (0..labels.len()).collect();
225            // Sort the indices based on the labels.
226            indices.sort_by_key(|&i| &labels[i]);
227            // Sort the labels.
228            labels.sort();
229            // Allocate a new adjacency matrix.
230            let mut new_adjacency_matrix = adjacency_matrix.clone();
231            // Fill the rows.
232            for (i, &j) in indices.iter().enumerate() {
233                new_adjacency_matrix
234                    .row_mut(i)
235                    .assign(&adjacency_matrix.row(j));
236            }
237            // Update the adjacency matrix.
238            adjacency_matrix = new_adjacency_matrix;
239            // Allocate a new adjacency matrix.
240            let mut new_adjacency_matrix = adjacency_matrix.clone();
241            // Fill the columns.
242            for (i, &j) in indices.iter().enumerate() {
243                new_adjacency_matrix
244                    .column_mut(i)
245                    .assign(&adjacency_matrix.column(j));
246            }
247            // Update the adjacency matrix.
248            adjacency_matrix = new_adjacency_matrix;
249        }
250
251        // Create a new graph instance.
252        Self {
253            labels,
254            adjacency_matrix,
255        }
256    }
257
258    #[inline]
259    fn to_adjacency_matrix(&self) -> Array2<bool> {
260        self.adjacency_matrix.clone()
261    }
262}
263
264impl Serialize for UnGraph {
265    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
266    where
267        S: Serializer,
268    {
269        // Convert adjacency matrix to a flat format.
270        let edges: Vec<_> = self
271            .edges()
272            .into_iter()
273            .map(|(x, y)| {
274                (
275                    self.index_to_label(x).to_owned(),
276                    self.index_to_label(y).to_owned(),
277                )
278            })
279            .collect();
280
281        // Allocate the map.
282        let mut map = serializer.serialize_map(Some(3))?;
283
284        // Serialize labels.
285        map.serialize_entry("labels", &self.labels)?;
286        // Serialize edges.
287        map.serialize_entry("edges", &edges)?;
288        // Serialize type.
289        map.serialize_entry("type", "ungraph")?;
290
291        // Finalize the map serialization.
292        map.end()
293    }
294}
295
296impl<'de> Deserialize<'de> for UnGraph {
297    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
298    where
299        D: Deserializer<'de>,
300    {
301        #[derive(Deserialize)]
302        #[serde(field_identifier, rename_all = "snake_case")]
303        enum Field {
304            Labels,
305            Edges,
306            Type,
307        }
308
309        struct UnGraphVisitor;
310
311        impl<'de> Visitor<'de> for UnGraphVisitor {
312            type Value = UnGraph;
313
314            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
315                formatter.write_str("struct UnGraph")
316            }
317
318            fn visit_map<V>(self, mut map: V) -> Result<UnGraph, V::Error>
319            where
320                V: MapAccess<'de>,
321            {
322                use serde::de::Error as E;
323
324                // Allocate fields
325                let mut labels = None;
326                let mut edges = None;
327                let mut type_ = None;
328
329                // Parse the map.
330                while let Some(key) = map.next_key()? {
331                    match key {
332                        Field::Labels => {
333                            if labels.is_some() {
334                                return Err(E::duplicate_field("labels"));
335                            }
336                            labels = Some(map.next_value()?);
337                        }
338                        Field::Edges => {
339                            if edges.is_some() {
340                                return Err(E::duplicate_field("edges"));
341                            }
342                            edges = Some(map.next_value()?);
343                        }
344                        Field::Type => {
345                            if type_.is_some() {
346                                return Err(E::duplicate_field("type"));
347                            }
348                            type_ = Some(map.next_value()?);
349                        }
350                    }
351                }
352
353                // Check required fields.
354                let labels = labels.ok_or_else(|| E::missing_field("labels"))?;
355                let edges = edges.ok_or_else(|| E::missing_field("edges"))?;
356
357                // Assert type is correct.
358                let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
359                assert_eq!(type_, "ungraph", "Invalid type for UnGraph.");
360
361                // Convert edges to an adjacency matrix.
362                let labels: Labels = labels;
363                let edges: Vec<(String, String)> = edges;
364                let shape = (labels.len(), labels.len());
365                let mut adjacency_matrix = Array2::from_elem(shape, false);
366                for (x, y) in edges {
367                    let x = labels
368                        .get_index_of(&x)
369                        .ok_or_else(|| E::custom(format!("Vertex `{x}` label does not exist")))?;
370                    let y = labels
371                        .get_index_of(&y)
372                        .ok_or_else(|| E::custom(format!("Vertex `{y}` label does not exist")))?;
373                    adjacency_matrix[(x, y)] = true;
374                }
375
376                Ok(UnGraph::from_adjacency_matrix(labels, adjacency_matrix))
377            }
378        }
379
380        const FIELDS: &[&str] = &["labels", "edges", "type"];
381
382        deserializer.deserialize_struct("UnGraph", FIELDS, UnGraphVisitor)
383    }
384}
385
386// Implement `JsonIO` for `UnGraph`.
387impl_json_io!(UnGraph);