causal_hub/models/graphs/
undirected.rs1use ndarray::prelude::*;
2use serde::{
3 Deserialize, Deserializer, Serialize, Serializer,
4 de::{MapAccess, Visitor},
5 ser::SerializeMap,
6};
7
8use crate::{
9 impl_json_io,
10 models::{Graph, Labelled},
11 types::{Labels, Set},
12};
13
14#[derive(Clone, Debug)]
16pub struct UnGraph {
17 labels: Labels,
18 adjacency_matrix: Array2<bool>,
19}
20
21impl UnGraph {
22 pub fn neighbors(&self, x: &Set<usize>) -> Set<usize> {
37 x.iter().for_each(|&v| {
39 assert!(v < self.labels.len(), "Vertex `{v}` is out of bounds");
40 });
41
42 let mut neighbors: Set<_> = x
44 .into_iter()
45 .flat_map(|&v| {
46 self.adjacency_matrix
47 .row(v)
48 .into_iter()
49 .enumerate()
50 .filter_map(|(y, &has_edge)| if has_edge { Some(y) } else { None })
51 })
52 .collect();
53
54 neighbors.sort();
56
57 neighbors
59 }
60}
61
62impl Labelled for UnGraph {
63 fn labels(&self) -> &Labels {
64 &self.labels
65 }
66}
67
68impl Graph for UnGraph {
69 fn empty<I, V>(labels: I) -> Self
70 where
71 I: IntoIterator<Item = V>,
72 V: AsRef<str>,
73 {
74 let mut n = 0;
76 let mut labels: Labels = labels
78 .into_iter()
79 .inspect(|_| n += 1)
80 .map(|x| x.as_ref().to_owned())
81 .collect();
82
83 assert_eq!(labels.len(), n, "Labels must be unique.");
85
86 labels.sort();
88
89 let adjacency_matrix: Array2<_> = Array::from_elem((n, n), false);
91
92 debug_assert!(labels.iter().is_sorted(), "Vertices labels must be sorted.");
94
95 Self {
96 labels,
97 adjacency_matrix,
98 }
99 }
100
101 fn complete<I, V>(labels: I) -> Self
102 where
103 I: IntoIterator<Item = V>,
104 V: AsRef<str>,
105 {
106 let mut n = 0;
108 let mut labels: Labels = labels
110 .into_iter()
111 .inspect(|_| n += 1)
112 .map(|x| x.as_ref().to_owned())
113 .collect();
114
115 assert_eq!(labels.len(), n, "Labels must be unique.");
117
118 labels.sort();
120
121 let mut adjacency_matrix: Array2<_> = Array::from_elem((n, n), true);
123 adjacency_matrix.diag_mut().fill(false);
125
126 debug_assert!(labels.iter().is_sorted(), "Vertices labels must be sorted.");
128
129 Self {
130 labels,
131 adjacency_matrix,
132 }
133 }
134
135 fn vertices(&self) -> Set<usize> {
136 (0..self.labels.len()).collect()
137 }
138
139 fn has_vertex(&self, x: usize) -> bool {
140 x < self.labels.len()
142 }
143
144 fn edges(&self) -> Set<(usize, usize)> {
145 self.adjacency_matrix
147 .indexed_iter()
148 .filter_map(|((x, y), &has_edge)| {
149 if has_edge && x <= y {
151 Some((x, y))
152 } else {
153 None
154 }
155 })
156 .collect()
157 }
158
159 fn has_edge(&self, x: usize, y: usize) -> bool {
160 assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
162 assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
163
164 self.adjacency_matrix[[x, y]]
165 }
166
167 fn add_edge(&mut self, x: usize, y: usize) -> bool {
168 assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
170 assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
171
172 if self.adjacency_matrix[[x, y]] {
174 return false;
175 }
176
177 self.adjacency_matrix[[x, y]] = true;
179 self.adjacency_matrix[[y, x]] = true;
180
181 true
182 }
183
184 fn del_edge(&mut self, x: usize, y: usize) -> bool {
185 assert!(x < self.labels.len(), "Vertex `{x}` is out of bounds");
187 assert!(y < self.labels.len(), "Vertex `{y}` is out of bounds");
188
189 if !self.adjacency_matrix[[x, y]] {
191 return false;
192 }
193
194 self.adjacency_matrix[[x, y]] = false;
196 self.adjacency_matrix[[y, x]] = false;
197
198 true
199 }
200
201 fn from_adjacency_matrix(mut labels: Labels, mut adjacency_matrix: Array2<bool>) -> Self {
202 assert_eq!(
204 labels.len(),
205 adjacency_matrix.nrows(),
206 "Number of labels must match the number of rows in the adjacency matrix."
207 );
208 assert_eq!(
210 adjacency_matrix.nrows(),
211 adjacency_matrix.ncols(),
212 "Adjacency matrix must be square."
213 );
214 assert_eq!(
216 adjacency_matrix,
217 adjacency_matrix.t(),
218 "Adjacency matrix must be symmetric."
219 );
220
221 if !labels.is_sorted() {
223 let mut indices: Vec<usize> = (0..labels.len()).collect();
225 indices.sort_by_key(|&i| &labels[i]);
227 labels.sort();
229 let mut new_adjacency_matrix = adjacency_matrix.clone();
231 for (i, &j) in indices.iter().enumerate() {
233 new_adjacency_matrix
234 .row_mut(i)
235 .assign(&adjacency_matrix.row(j));
236 }
237 adjacency_matrix = new_adjacency_matrix;
239 let mut new_adjacency_matrix = adjacency_matrix.clone();
241 for (i, &j) in indices.iter().enumerate() {
243 new_adjacency_matrix
244 .column_mut(i)
245 .assign(&adjacency_matrix.column(j));
246 }
247 adjacency_matrix = new_adjacency_matrix;
249 }
250
251 Self {
253 labels,
254 adjacency_matrix,
255 }
256 }
257
258 #[inline]
259 fn to_adjacency_matrix(&self) -> Array2<bool> {
260 self.adjacency_matrix.clone()
261 }
262}
263
264impl Serialize for UnGraph {
265 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
266 where
267 S: Serializer,
268 {
269 let edges: Vec<_> = self
271 .edges()
272 .into_iter()
273 .map(|(x, y)| {
274 (
275 self.index_to_label(x).to_owned(),
276 self.index_to_label(y).to_owned(),
277 )
278 })
279 .collect();
280
281 let mut map = serializer.serialize_map(Some(3))?;
283
284 map.serialize_entry("labels", &self.labels)?;
286 map.serialize_entry("edges", &edges)?;
288 map.serialize_entry("type", "ungraph")?;
290
291 map.end()
293 }
294}
295
296impl<'de> Deserialize<'de> for UnGraph {
297 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
298 where
299 D: Deserializer<'de>,
300 {
301 #[derive(Deserialize)]
302 #[serde(field_identifier, rename_all = "snake_case")]
303 enum Field {
304 Labels,
305 Edges,
306 Type,
307 }
308
309 struct UnGraphVisitor;
310
311 impl<'de> Visitor<'de> for UnGraphVisitor {
312 type Value = UnGraph;
313
314 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
315 formatter.write_str("struct UnGraph")
316 }
317
318 fn visit_map<V>(self, mut map: V) -> Result<UnGraph, V::Error>
319 where
320 V: MapAccess<'de>,
321 {
322 use serde::de::Error as E;
323
324 let mut labels = None;
326 let mut edges = None;
327 let mut type_ = None;
328
329 while let Some(key) = map.next_key()? {
331 match key {
332 Field::Labels => {
333 if labels.is_some() {
334 return Err(E::duplicate_field("labels"));
335 }
336 labels = Some(map.next_value()?);
337 }
338 Field::Edges => {
339 if edges.is_some() {
340 return Err(E::duplicate_field("edges"));
341 }
342 edges = Some(map.next_value()?);
343 }
344 Field::Type => {
345 if type_.is_some() {
346 return Err(E::duplicate_field("type"));
347 }
348 type_ = Some(map.next_value()?);
349 }
350 }
351 }
352
353 let labels = labels.ok_or_else(|| E::missing_field("labels"))?;
355 let edges = edges.ok_or_else(|| E::missing_field("edges"))?;
356
357 let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
359 assert_eq!(type_, "ungraph", "Invalid type for UnGraph.");
360
361 let labels: Labels = labels;
363 let edges: Vec<(String, String)> = edges;
364 let shape = (labels.len(), labels.len());
365 let mut adjacency_matrix = Array2::from_elem(shape, false);
366 for (x, y) in edges {
367 let x = labels
368 .get_index_of(&x)
369 .ok_or_else(|| E::custom(format!("Vertex `{x}` label does not exist")))?;
370 let y = labels
371 .get_index_of(&y)
372 .ok_or_else(|| E::custom(format!("Vertex `{y}` label does not exist")))?;
373 adjacency_matrix[(x, y)] = true;
374 }
375
376 Ok(UnGraph::from_adjacency_matrix(labels, adjacency_matrix))
377 }
378 }
379
380 const FIELDS: &[&str] = &["labels", "edges", "type"];
381
382 deserializer.deserialize_struct("UnGraph", FIELDS, UnGraphVisitor)
383 }
384}
385
// NOTE(review): presumably derives JSON read/write helpers for `UnGraph` from its
// `Serialize`/`Deserialize` impls — confirm against the `impl_json_io!` definition.
impl_json_io!(UnGraph);