// causal_hub/models/graphs/directed.rs

1use std::collections::VecDeque;
2
3use ndarray::prelude::*;
4use serde::{
5    Deserialize, Deserializer, Serialize, Serializer,
6    de::{MapAccess, Visitor},
7    ser::SerializeMap,
8};
9
10use crate::{
11    impl_json_io,
12    models::{Graph, Labelled},
13    set,
14    types::{Labels, Set},
15};
16
17/// A struct representing a directed graph using an adjacency matrix.
18#[derive(Clone, Debug, Eq, PartialEq)]
19pub struct DiGraph {
20    labels: Labels,
21    adjacency_matrix: Array2<bool>,
22}
23
24impl DiGraph {
25    /// Returns the parents of a set of vertices.
26    ///
27    /// # Arguments
28    ///
29    /// * `x` - The set of vertices for which to find the parents.
30    ///
31    /// # Panics
32    ///
33    /// * If any vertex is out of bounds.
34    ///
35    /// # Returns
36    ///
37    /// The parents of the vertices.
38    ///
39    pub fn parents(&self, x: &Set<usize>) -> Set<usize> {
40        // Assert the vertices are within bounds.
41        x.iter().for_each(|&v| {
42            assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
43        });
44
45        // Iterate over all vertices and filter the ones that are parents.
46        let mut parents: Set<_> = x
47            .into_iter()
48            .flat_map(|&v| {
49                self.adjacency_matrix
50                    .column(v)
51                    .into_iter()
52                    .enumerate()
53                    .filter_map(|(y, &has_edge)| if has_edge { Some(y) } else { None })
54            })
55            .collect();
56
57        // Sort the parents.
58        parents.sort();
59
60        // Return the parents.
61        parents
62    }
63
64    /// Returns the ancestors of a set of vertices.
65    ///
66    /// # Arguments
67    ///
68    /// * `x` - The set of vertices for which to find the ancestors.
69    ///
70    /// # Panics
71    ///
72    /// * If any vertex is out of bounds.
73    ///
74    /// # Returns
75    ///
76    /// The ancestors of the vertices.
77    ///
78    pub fn ancestors(&self, x: &Set<usize>) -> Set<usize> {
79        // Assert the vertices are within bounds.
80        x.iter().for_each(|&v| {
81            assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
82        });
83
84        // Initialize a stack and a visited set.
85        let mut stack = VecDeque::new();
86        let mut visited = set![];
87
88        // Start with the given vertices.
89        stack.extend(x);
90
91        // While there are vertices to visit ...
92        while let Some(y) = stack.pop_back() {
93            // For each incoming edge ...
94            for z in self.parents(&set![y]) {
95                // If there is an edge from z to y and z has not been visited ...
96                if !visited.contains(&z) {
97                    // Mark z as visited.
98                    visited.insert(z);
99                    // Add z to the stack to visit its ancestors.
100                    stack.push_back(z);
101                }
102            }
103        }
104
105        // Sort the visited set.
106        visited.sort();
107
108        // Return the visited set.
109        visited
110    }
111
112    /// Returns the children of a set of vertices.
113    ///
114    /// # Arguments
115    ///
116    /// * `x` - The set of vertices for which to find the children.
117    ///
118    /// # Panics
119    ///
120    /// * If any vertex is out of bounds.
121    ///
122    /// # Returns
123    ///
124    /// The children of the vertices.
125    ///
126    pub fn children(&self, x: &Set<usize>) -> Set<usize> {
127        // Check if the vertices are within bounds.
128        x.iter().for_each(|&v| {
129            assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
130        });
131
132        // Iterate over all vertices and filter the ones that are children.
133        let mut children: Set<_> = x
134            .into_iter()
135            .flat_map(|&v| {
136                self.adjacency_matrix
137                    .row(v)
138                    .into_iter()
139                    .enumerate()
140                    .filter_map(|(y, &has_edge)| if has_edge { Some(y) } else { None })
141            })
142            .collect();
143
144        // Sort the children.
145        children.sort();
146
147        // Return the children.
148        children
149    }
150
151    /// Returns the descendants of a set of vertices.
152    ///
153    /// # Arguments
154    ///
155    /// * `x` - The set of vertices for which to find the descendants.
156    ///
157    /// # Panics
158    ///
159    /// * If any vertex is out of bounds.
160    ///
161    /// # Returns
162    ///
163    /// The descendants of the vertices.
164    ///
165    pub fn descendants(&self, x: &Set<usize>) -> Set<usize> {
166        // Assert the vertices are within bounds.
167        x.iter().for_each(|&v| {
168            assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
169        });
170
171        // Initialize a stack and a visited set.
172        let mut stack = VecDeque::new();
173        let mut visited = set![];
174
175        // Start with the given vertices.
176        stack.extend(x);
177
178        // While there are vertices to visit ...
179        while let Some(y) = stack.pop_back() {
180            // For each outgoing edge ...
181            for z in self.children(&set![y]) {
182                // If z has not been visited ...
183                if !visited.contains(&z) {
184                    // Mark z as visited.
185                    visited.insert(z);
186                    // Add z to the stack to visit its descendants.
187                    stack.push_back(z);
188                }
189            }
190        }
191
192        // Sort the visited set.
193        visited.sort();
194
195        // Return the visited set.
196        visited
197    }
198}
199
200impl Labelled for DiGraph {
201    fn labels(&self) -> &Labels {
202        &self.labels
203    }
204}
205
206impl Graph for DiGraph {
207    fn empty<I, V>(labels: I) -> Self
208    where
209        I: IntoIterator<Item = V>,
210        V: AsRef<str>,
211    {
212        // Initialize labels counter.
213        let mut n = 0;
214        // Collect the labels.
215        let mut labels: Labels = labels
216            .into_iter()
217            .inspect(|_| n += 1)
218            .map(|x| x.as_ref().to_owned())
219            .collect();
220
221        // Assert no duplicate labels.
222        assert_eq!(labels.len(), n, "Labels must be unique.");
223
224        // Sort the labels.
225        labels.sort();
226
227        // Initialize the adjacency matrix with `false` values.
228        let adjacency_matrix: Array2<_> = Array::from_elem((n, n), false);
229
230        // Debug assert to check the sorting of the labels.
231        debug_assert!(labels.iter().is_sorted(), "Vertices labels must be sorted.");
232
233        Self {
234            labels,
235            adjacency_matrix,
236        }
237    }
238
239    fn complete<I, V>(labels: I) -> Self
240    where
241        I: IntoIterator<Item = V>,
242        V: AsRef<str>,
243    {
244        // Initialize labels counter.
245        let mut n = 0;
246        // Collect the labels.
247        let mut labels: Labels = labels
248            .into_iter()
249            .inspect(|_| n += 1)
250            .map(|x| x.as_ref().to_owned())
251            .collect();
252
253        // Assert no duplicate labels.
254        assert_eq!(labels.len(), n, "Labels must be unique.");
255
256        // Sort the labels.
257        labels.sort();
258
259        // Initialize the adjacency matrix with `true` values.
260        let mut adjacency_matrix: Array2<_> = Array::from_elem((n, n), true);
261        // Set the diagonal to `false` to avoid self-loops.
262        adjacency_matrix.diag_mut().fill(false);
263
264        // Debug assert to check the sorting of the labels.
265        debug_assert!(labels.iter().is_sorted(), "Vertices labels must be sorted.");
266
267        Self {
268            labels,
269            adjacency_matrix,
270        }
271    }
272
273    fn vertices(&self) -> Set<usize> {
274        (0..self.labels.len()).collect()
275    }
276
277    fn has_vertex(&self, x: usize) -> bool {
278        // Check if the vertex is within bounds.
279        x < self.labels.len()
280    }
281
282    fn edges(&self) -> Set<(usize, usize)> {
283        // Iterate over the adjacency matrix and collect the edges.
284        self.adjacency_matrix
285            .indexed_iter()
286            .filter_map(|((x, y), &has_edge)| if has_edge { Some((x, y)) } else { None })
287            .collect()
288    }
289
290    fn has_edge(&self, x: usize, y: usize) -> bool {
291        // Check if the vertices are within bounds.
292        assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
293        assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
294
295        self.adjacency_matrix[[x, y]]
296    }
297
298    fn add_edge(&mut self, x: usize, y: usize) -> bool {
299        // Check if the vertices are within bounds.
300        assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
301        assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
302
303        // Check if the edge already exists.
304        if self.adjacency_matrix[[x, y]] {
305            return false;
306        }
307
308        // Add the edge.
309        self.adjacency_matrix[[x, y]] = true;
310
311        true
312    }
313
314    fn del_edge(&mut self, x: usize, y: usize) -> bool {
315        // Check if the vertices are within bounds.
316        assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
317        assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
318
319        // Check if the edge exists.
320        if !self.adjacency_matrix[[x, y]] {
321            return false;
322        }
323
324        // Delete the edge.
325        self.adjacency_matrix[[x, y]] = false;
326
327        true
328    }
329
330    fn from_adjacency_matrix(mut labels: Labels, mut adjacency_matrix: Array2<bool>) -> Self {
331        // Assert labels and adjacency matrix dimensions match.
332        assert_eq!(
333            labels.len(),
334            adjacency_matrix.nrows(),
335            "Number of labels must match the number of rows in the adjacency matrix."
336        );
337        // Assert adjacency matrix must be square.
338        assert_eq!(
339            adjacency_matrix.nrows(),
340            adjacency_matrix.ncols(),
341            "Adjacency matrix must be square."
342        );
343
344        // Check if the labels are sorted.
345        if !labels.is_sorted() {
346            // Allocate the sorted indices.
347            let mut indices: Vec<usize> = (0..labels.len()).collect();
348            // Sort the indices based on the labels.
349            indices.sort_by_key(|&i| &labels[i]);
350            // Sort the labels.
351            labels.sort();
352            // Allocate a new adjacency matrix.
353            let mut new_adjacency_matrix = adjacency_matrix.clone();
354            // Fill the rows.
355            for (i, &j) in indices.iter().enumerate() {
356                new_adjacency_matrix
357                    .row_mut(i)
358                    .assign(&adjacency_matrix.row(j));
359            }
360            // Update the adjacency matrix.
361            adjacency_matrix = new_adjacency_matrix;
362            // Allocate a new adjacency matrix.
363            let mut new_adjacency_matrix = adjacency_matrix.clone();
364            // Fill the columns.
365            for (i, &j) in indices.iter().enumerate() {
366                new_adjacency_matrix
367                    .column_mut(i)
368                    .assign(&adjacency_matrix.column(j));
369            }
370            // Update the adjacency matrix.
371            adjacency_matrix = new_adjacency_matrix;
372        }
373
374        // Create a new graph instance.
375        Self {
376            labels,
377            adjacency_matrix,
378        }
379    }
380
381    #[inline]
382    fn to_adjacency_matrix(&self) -> Array2<bool> {
383        self.adjacency_matrix.clone()
384    }
385}
386
387impl Serialize for DiGraph {
388    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
389    where
390        S: Serializer,
391    {
392        // Convert adjacency matrix to a flat format.
393        let edges: Vec<_> = self
394            .edges()
395            .into_iter()
396            .map(|(x, y)| {
397                (
398                    self.index_to_label(x).to_owned(),
399                    self.index_to_label(y).to_owned(),
400                )
401            })
402            .collect();
403
404        // Allocate the map.
405        let mut map = serializer.serialize_map(Some(3))?;
406
407        // Serialize labels.
408        map.serialize_entry("labels", &self.labels)?;
409        // Serialize edges.
410        map.serialize_entry("edges", &edges)?;
411        // Serialize type.
412        map.serialize_entry("type", "digraph")?;
413
414        // Finalize the map serialization.
415        map.end()
416    }
417}
418
419impl<'de> Deserialize<'de> for DiGraph {
420    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
421    where
422        D: Deserializer<'de>,
423    {
424        #[derive(Deserialize)]
425        #[serde(field_identifier, rename_all = "snake_case")]
426        enum Field {
427            Labels,
428            Edges,
429            Type,
430        }
431
432        struct DiGraphVisitor;
433
434        impl<'de> Visitor<'de> for DiGraphVisitor {
435            type Value = DiGraph;
436
437            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
438                formatter.write_str("struct DiGraph")
439            }
440
441            fn visit_map<V>(self, mut map: V) -> Result<DiGraph, V::Error>
442            where
443                V: MapAccess<'de>,
444            {
445                use serde::de::Error as E;
446
447                // Allocate fields
448                let mut labels = None;
449                let mut edges = None;
450                let mut type_ = None;
451
452                // Parse the map.
453                while let Some(key) = map.next_key()? {
454                    match key {
455                        Field::Labels => {
456                            if labels.is_some() {
457                                return Err(E::duplicate_field("labels"));
458                            }
459                            labels = Some(map.next_value()?);
460                        }
461                        Field::Edges => {
462                            if edges.is_some() {
463                                return Err(E::duplicate_field("edges"));
464                            }
465                            edges = Some(map.next_value()?);
466                        }
467                        Field::Type => {
468                            if type_.is_some() {
469                                return Err(E::duplicate_field("type"));
470                            }
471                            type_ = Some(map.next_value()?);
472                        }
473                    }
474                }
475
476                // Check required fields.
477                let labels = labels.ok_or_else(|| E::missing_field("labels"))?;
478                let edges = edges.ok_or_else(|| E::missing_field("edges"))?;
479
480                // Assert type is correct.
481                let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
482                assert_eq!(type_, "digraph", "Invalid type for DiGraph.");
483
484                // Convert edges to an adjacency matrix.
485                let labels: Labels = labels;
486                let edges: Vec<(String, String)> = edges;
487                let shape = (labels.len(), labels.len());
488                let mut adjacency_matrix = Array2::from_elem(shape, false);
489                for (x, y) in edges {
490                    let x = labels
491                        .get_index_of(&x)
492                        .ok_or_else(|| E::custom(format!("Vertex `{x}` label does not exist")))?;
493                    let y = labels
494                        .get_index_of(&y)
495                        .ok_or_else(|| E::custom(format!("Vertex `{y}` label does not exist")))?;
496                    adjacency_matrix[(x, y)] = true;
497                }
498
499                Ok(DiGraph::from_adjacency_matrix(labels, adjacency_matrix))
500            }
501        }
502
503        const FIELDS: &[&str] = &["labels", "edges", "type"];
504
505        deserializer.deserialize_struct("DiGraph", FIELDS, DiGraphVisitor)
506    }
507}
508
// Implement `JsonIO` for `DiGraph` via the crate's `impl_json_io!` macro.
// NOTE(review): presumably the generated impl builds on the `Serialize` /
// `Deserialize` impls above — confirm against the macro definition.
impl_json_io!(DiGraph);