causal_hub/models/graphs/
directed.rs1use std::collections::VecDeque;
2
3use ndarray::prelude::*;
4use serde::{
5 Deserialize, Deserializer, Serialize, Serializer,
6 de::{MapAccess, Visitor},
7 ser::SerializeMap,
8};
9
10use crate::{
11 impl_json_io,
12 models::{Graph, Labelled},
13 set,
14 types::{Labels, Set},
15};
16
17#[derive(Clone, Debug, Eq, PartialEq)]
19pub struct DiGraph {
20 labels: Labels,
21 adjacency_matrix: Array2<bool>,
22}
23
24impl DiGraph {
25 pub fn parents(&self, x: &Set<usize>) -> Set<usize> {
40 x.iter().for_each(|&v| {
42 assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
43 });
44
45 let mut parents: Set<_> = x
47 .into_iter()
48 .flat_map(|&v| {
49 self.adjacency_matrix
50 .column(v)
51 .into_iter()
52 .enumerate()
53 .filter_map(|(y, &has_edge)| if has_edge { Some(y) } else { None })
54 })
55 .collect();
56
57 parents.sort();
59
60 parents
62 }
63
64 pub fn ancestors(&self, x: &Set<usize>) -> Set<usize> {
79 x.iter().for_each(|&v| {
81 assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
82 });
83
84 let mut stack = VecDeque::new();
86 let mut visited = set![];
87
88 stack.extend(x);
90
91 while let Some(y) = stack.pop_back() {
93 for z in self.parents(&set![y]) {
95 if !visited.contains(&z) {
97 visited.insert(z);
99 stack.push_back(z);
101 }
102 }
103 }
104
105 visited.sort();
107
108 visited
110 }
111
112 pub fn children(&self, x: &Set<usize>) -> Set<usize> {
127 x.iter().for_each(|&v| {
129 assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
130 });
131
132 let mut children: Set<_> = x
134 .into_iter()
135 .flat_map(|&v| {
136 self.adjacency_matrix
137 .row(v)
138 .into_iter()
139 .enumerate()
140 .filter_map(|(y, &has_edge)| if has_edge { Some(y) } else { None })
141 })
142 .collect();
143
144 children.sort();
146
147 children
149 }
150
151 pub fn descendants(&self, x: &Set<usize>) -> Set<usize> {
166 x.iter().for_each(|&v| {
168 assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
169 });
170
171 let mut stack = VecDeque::new();
173 let mut visited = set![];
174
175 stack.extend(x);
177
178 while let Some(y) = stack.pop_back() {
180 for z in self.children(&set![y]) {
182 if !visited.contains(&z) {
184 visited.insert(z);
186 stack.push_back(z);
188 }
189 }
190 }
191
192 visited.sort();
194
195 visited
197 }
198}
199
200impl Labelled for DiGraph {
201 fn labels(&self) -> &Labels {
202 &self.labels
203 }
204}
205
206impl Graph for DiGraph {
207 fn empty<I, V>(labels: I) -> Self
208 where
209 I: IntoIterator<Item = V>,
210 V: AsRef<str>,
211 {
212 let mut n = 0;
214 let mut labels: Labels = labels
216 .into_iter()
217 .inspect(|_| n += 1)
218 .map(|x| x.as_ref().to_owned())
219 .collect();
220
221 assert_eq!(labels.len(), n, "Labels must be unique.");
223
224 labels.sort();
226
227 let adjacency_matrix: Array2<_> = Array::from_elem((n, n), false);
229
230 debug_assert!(labels.iter().is_sorted(), "Vertices labels must be sorted.");
232
233 Self {
234 labels,
235 adjacency_matrix,
236 }
237 }
238
239 fn complete<I, V>(labels: I) -> Self
240 where
241 I: IntoIterator<Item = V>,
242 V: AsRef<str>,
243 {
244 let mut n = 0;
246 let mut labels: Labels = labels
248 .into_iter()
249 .inspect(|_| n += 1)
250 .map(|x| x.as_ref().to_owned())
251 .collect();
252
253 assert_eq!(labels.len(), n, "Labels must be unique.");
255
256 labels.sort();
258
259 let mut adjacency_matrix: Array2<_> = Array::from_elem((n, n), true);
261 adjacency_matrix.diag_mut().fill(false);
263
264 debug_assert!(labels.iter().is_sorted(), "Vertices labels must be sorted.");
266
267 Self {
268 labels,
269 adjacency_matrix,
270 }
271 }
272
273 fn vertices(&self) -> Set<usize> {
274 (0..self.labels.len()).collect()
275 }
276
277 fn has_vertex(&self, x: usize) -> bool {
278 x < self.labels.len()
280 }
281
282 fn edges(&self) -> Set<(usize, usize)> {
283 self.adjacency_matrix
285 .indexed_iter()
286 .filter_map(|((x, y), &has_edge)| if has_edge { Some((x, y)) } else { None })
287 .collect()
288 }
289
290 fn has_edge(&self, x: usize, y: usize) -> bool {
291 assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
293 assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
294
295 self.adjacency_matrix[[x, y]]
296 }
297
298 fn add_edge(&mut self, x: usize, y: usize) -> bool {
299 assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
301 assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
302
303 if self.adjacency_matrix[[x, y]] {
305 return false;
306 }
307
308 self.adjacency_matrix[[x, y]] = true;
310
311 true
312 }
313
314 fn del_edge(&mut self, x: usize, y: usize) -> bool {
315 assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
317 assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
318
319 if !self.adjacency_matrix[[x, y]] {
321 return false;
322 }
323
324 self.adjacency_matrix[[x, y]] = false;
326
327 true
328 }
329
330 fn from_adjacency_matrix(mut labels: Labels, mut adjacency_matrix: Array2<bool>) -> Self {
331 assert_eq!(
333 labels.len(),
334 adjacency_matrix.nrows(),
335 "Number of labels must match the number of rows in the adjacency matrix."
336 );
337 assert_eq!(
339 adjacency_matrix.nrows(),
340 adjacency_matrix.ncols(),
341 "Adjacency matrix must be square."
342 );
343
344 if !labels.is_sorted() {
346 let mut indices: Vec<usize> = (0..labels.len()).collect();
348 indices.sort_by_key(|&i| &labels[i]);
350 labels.sort();
352 let mut new_adjacency_matrix = adjacency_matrix.clone();
354 for (i, &j) in indices.iter().enumerate() {
356 new_adjacency_matrix
357 .row_mut(i)
358 .assign(&adjacency_matrix.row(j));
359 }
360 adjacency_matrix = new_adjacency_matrix;
362 let mut new_adjacency_matrix = adjacency_matrix.clone();
364 for (i, &j) in indices.iter().enumerate() {
366 new_adjacency_matrix
367 .column_mut(i)
368 .assign(&adjacency_matrix.column(j));
369 }
370 adjacency_matrix = new_adjacency_matrix;
372 }
373
374 Self {
376 labels,
377 adjacency_matrix,
378 }
379 }
380
381 #[inline]
382 fn to_adjacency_matrix(&self) -> Array2<bool> {
383 self.adjacency_matrix.clone()
384 }
385}
386
387impl Serialize for DiGraph {
388 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
389 where
390 S: Serializer,
391 {
392 let edges: Vec<_> = self
394 .edges()
395 .into_iter()
396 .map(|(x, y)| {
397 (
398 self.index_to_label(x).to_owned(),
399 self.index_to_label(y).to_owned(),
400 )
401 })
402 .collect();
403
404 let mut map = serializer.serialize_map(Some(3))?;
406
407 map.serialize_entry("labels", &self.labels)?;
409 map.serialize_entry("edges", &edges)?;
411 map.serialize_entry("type", "digraph")?;
413
414 map.end()
416 }
417}
418
419impl<'de> Deserialize<'de> for DiGraph {
420 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
421 where
422 D: Deserializer<'de>,
423 {
424 #[derive(Deserialize)]
425 #[serde(field_identifier, rename_all = "snake_case")]
426 enum Field {
427 Labels,
428 Edges,
429 Type,
430 }
431
432 struct DiGraphVisitor;
433
434 impl<'de> Visitor<'de> for DiGraphVisitor {
435 type Value = DiGraph;
436
437 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
438 formatter.write_str("struct DiGraph")
439 }
440
441 fn visit_map<V>(self, mut map: V) -> Result<DiGraph, V::Error>
442 where
443 V: MapAccess<'de>,
444 {
445 use serde::de::Error as E;
446
447 let mut labels = None;
449 let mut edges = None;
450 let mut type_ = None;
451
452 while let Some(key) = map.next_key()? {
454 match key {
455 Field::Labels => {
456 if labels.is_some() {
457 return Err(E::duplicate_field("labels"));
458 }
459 labels = Some(map.next_value()?);
460 }
461 Field::Edges => {
462 if edges.is_some() {
463 return Err(E::duplicate_field("edges"));
464 }
465 edges = Some(map.next_value()?);
466 }
467 Field::Type => {
468 if type_.is_some() {
469 return Err(E::duplicate_field("type"));
470 }
471 type_ = Some(map.next_value()?);
472 }
473 }
474 }
475
476 let labels = labels.ok_or_else(|| E::missing_field("labels"))?;
478 let edges = edges.ok_or_else(|| E::missing_field("edges"))?;
479
480 let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
482 assert_eq!(type_, "digraph", "Invalid type for DiGraph.");
483
484 let labels: Labels = labels;
486 let edges: Vec<(String, String)> = edges;
487 let shape = (labels.len(), labels.len());
488 let mut adjacency_matrix = Array2::from_elem(shape, false);
489 for (x, y) in edges {
490 let x = labels
491 .get_index_of(&x)
492 .ok_or_else(|| E::custom(format!("Vertex `{x}` label does not exist")))?;
493 let y = labels
494 .get_index_of(&y)
495 .ok_or_else(|| E::custom(format!("Vertex `{y}` label does not exist")))?;
496 adjacency_matrix[(x, y)] = true;
497 }
498
499 Ok(DiGraph::from_adjacency_matrix(labels, adjacency_matrix))
500 }
501 }
502
503 const FIELDS: &[&str] = &["labels", "edges", "type"];
504
505 deserializer.deserialize_struct("DiGraph", FIELDS, DiGraphVisitor)
506 }
507}
508
509impl_json_io!(DiGraph);