causal_hub/models/bayesian_network/categorical/
model.rs1use approx::{AbsDiffEq, RelativeEq};
2use ndarray::prelude::*;
3use serde::{
4 Deserialize, Deserializer, Serialize, Serializer,
5 de::{MapAccess, Visitor},
6 ser::SerializeMap,
7};
8
9use crate::{
10 datasets::{CatEv, CatSample, CatTable},
11 impl_json_io,
12 inference::TopologicalOrder,
13 io::{BifIO, BifParser},
14 models::{BN, CPD, CatCPD, DiGraph, Graph, Labelled},
15 set,
16 types::{Labels, Map, States},
17};
18
19#[derive(Clone, Debug)]
21pub struct CatBN {
22 name: Option<String>,
24 description: Option<String>,
26 labels: Labels,
28 states: States,
30 shape: Array1<usize>,
32 graph: DiGraph,
34 cpds: Map<String, CatCPD>,
36 topological_order: Vec<usize>,
38}
39
40impl CatBN {
41 #[inline]
48 pub const fn states(&self) -> &States {
49 &self.states
50 }
51
52 #[inline]
59 pub fn shape(&self) -> &Array1<usize> {
60 &self.shape
61 }
62}
63
64impl PartialEq for CatBN {
65 fn eq(&self, other: &Self) -> bool {
66 self.labels.eq(&other.labels)
67 && self.states.eq(&other.states)
68 && self.shape.eq(&other.shape)
69 && self.graph.eq(&other.graph)
70 && self.topological_order.eq(&other.topological_order)
71 && self.cpds.eq(&other.cpds)
72 }
73}
74
75impl AbsDiffEq for CatBN {
76 type Epsilon = f64;
77
78 fn default_epsilon() -> Self::Epsilon {
79 Self::Epsilon::default_epsilon()
80 }
81
82 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
83 self.labels.eq(&other.labels)
84 && self.states.eq(&other.states)
85 && self.shape.eq(&other.shape)
86 && self.graph.eq(&other.graph)
87 && self.topological_order.eq(&other.topological_order)
88 && self
89 .cpds
90 .iter()
91 .zip(&other.cpds)
92 .all(|((label, cpd), (other_label, other_cpd))| {
93 label.eq(other_label) && cpd.abs_diff_eq(other_cpd, epsilon)
94 })
95 }
96}
97
98impl RelativeEq for CatBN {
99 fn default_max_relative() -> Self::Epsilon {
100 Self::Epsilon::default_max_relative()
101 }
102
103 fn relative_eq(
104 &self,
105 other: &Self,
106 epsilon: Self::Epsilon,
107 max_relative: Self::Epsilon,
108 ) -> bool {
109 self.labels.eq(&other.labels)
110 && self.states.eq(&other.states)
111 && self.shape.eq(&other.shape)
112 && self.graph.eq(&other.graph)
113 && self.topological_order.eq(&other.topological_order)
114 && self
115 .cpds
116 .iter()
117 .zip(&other.cpds)
118 .all(|((label, cpd), (other_label, other_cpd))| {
119 label.eq(other_label) && cpd.relative_eq(other_cpd, epsilon, max_relative)
120 })
121 }
122}
123
124impl Labelled for CatBN {
125 #[inline]
126 fn labels(&self) -> &Labels {
127 &self.labels
128 }
129}
130
131impl BN for CatBN {
132 type CPD = CatCPD;
133 type Evidence = CatEv;
134 type Sample = CatSample;
135 type Samples = CatTable;
136
137 fn new<I>(graph: DiGraph, cpds: I) -> Self
138 where
139 I: IntoIterator<Item = Self::CPD>,
140 {
141 let mut cpds: Map<_, _> = cpds
143 .into_iter()
144 .inspect(|x| {
147 assert_eq!(x.labels().len(), 1, "CPD must contain exactly one label.");
148 })
149 .map(|x| (x.labels()[0].to_owned(), x))
150 .collect();
151 cpds.sort_keys();
153
154 assert!(
156 graph.labels().iter().eq(cpds.keys()),
157 "Graph labels and distributions labels must be the same."
158 );
159
160 let mut states: States = Default::default();
162 for cpd in cpds.values() {
164 cpd.states()
165 .iter()
166 .chain(cpd.conditioning_states())
167 .for_each(|(l, s)| {
168 if let Some(existing_states) = states.get(l) {
170 assert_eq!(
172 existing_states, s,
173 "States of `{l}` must be the same across CPDs.",
174 );
175 } else {
176 states.insert(l.to_owned(), s.clone());
178 }
179 });
180 }
181 states.sort_keys();
183
184 let labels: Labels = states.keys().cloned().collect();
186 let shape: Array1<usize> = states.values().map(|s| s.len()).collect();
188
189 graph.vertices().iter().for_each(|&i| {
191 let pa_i = graph.parents(&set![i]).into_iter();
193 let pa_i: &Labels = &pa_i.map(|j| labels[j].to_owned()).collect();
194 let pa_j = cpds[&labels[i]].conditioning_labels();
196 assert_eq!(
198 pa_i, pa_j,
199 "Graph parents labels and CPD conditioning labels must be the same:\n\
200 \t expected: {:?} ,\n\
201 \t found: {:?} .",
202 pa_i, pa_j
203 );
204 });
205
206 let topological_order = graph.topological_order().expect("Graph must be acyclic.");
208
209 Self {
210 name: None,
211 description: None,
212 labels,
213 states,
214 shape,
215 graph,
216 cpds,
217 topological_order,
218 }
219 }
220
221 #[inline]
222 fn name(&self) -> Option<&str> {
223 self.name.as_deref()
224 }
225
226 #[inline]
227 fn description(&self) -> Option<&str> {
228 self.description.as_deref()
229 }
230
231 #[inline]
232 fn graph(&self) -> &DiGraph {
233 &self.graph
234 }
235
236 #[inline]
237 fn cpds(&self) -> &Map<String, Self::CPD> {
238 &self.cpds
239 }
240
241 #[inline]
242 fn parameters_size(&self) -> usize {
243 self.cpds.iter().map(|(_, x)| x.parameters_size()).sum()
244 }
245
246 #[inline]
247 fn topological_order(&self) -> &[usize] {
248 &self.topological_order
249 }
250
251 fn with_optionals<I>(
252 name: Option<String>,
253 description: Option<String>,
254 graph: DiGraph,
255 cpds: I,
256 ) -> Self
257 where
258 I: IntoIterator<Item = Self::CPD>,
259 {
260 if let Some(name) = &name {
262 assert!(!name.is_empty(), "Name cannot be an empty string.");
263 }
264 if let Some(description) = &description {
266 assert!(
267 !description.is_empty(),
268 "Description cannot be an empty string."
269 );
270 }
271
272 let mut bn = Self::new(graph, cpds);
274
275 bn.name = name;
277 bn.description = description;
278
279 bn
280 }
281}
282
283impl Serialize for CatBN {
284 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
285 where
286 S: Serializer,
287 {
288 let mut size = 3;
290 size += self.name.is_some() as usize;
291 size += self.description.is_some() as usize;
292
293 let mut map = serializer.serialize_map(Some(size))?;
295
296 if let Some(name) = &self.name {
298 map.serialize_entry("name", name)?;
299 }
300 if let Some(description) = &self.description {
302 map.serialize_entry("description", description)?;
303 }
304 map.serialize_entry("graph", &self.graph)?;
306
307 let cpds: Vec<_> = self.cpds.values().cloned().collect();
309 map.serialize_entry("cpds", &cpds)?;
311
312 map.serialize_entry("type", "catbn")?;
314
315 map.end()
317 }
318}
319
320impl<'de> Deserialize<'de> for CatBN {
321 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
322 where
323 D: Deserializer<'de>,
324 {
325 #[derive(Deserialize)]
326 #[serde(field_identifier, rename_all = "snake_case")]
327 enum Field {
328 Name,
329 Description,
330 Graph,
331 Cpds,
332 Type,
333 }
334
335 struct CatBNVisitor;
336
337 impl<'de> Visitor<'de> for CatBNVisitor {
338 type Value = CatBN;
339
340 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
341 formatter.write_str("struct CatBN")
342 }
343
344 fn visit_map<V>(self, mut map: V) -> Result<CatBN, V::Error>
345 where
346 V: MapAccess<'de>,
347 {
348 use serde::de::Error as E;
349
350 let mut name = None;
352 let mut description = None;
353 let mut graph = None;
354 let mut cpds = None;
355 let mut type_ = None;
356
357 while let Some(key) = map.next_key()? {
359 match key {
360 Field::Name => {
361 if name.is_some() {
362 return Err(E::duplicate_field("name"));
363 }
364 name = Some(map.next_value()?);
365 }
366 Field::Description => {
367 if description.is_some() {
368 return Err(E::duplicate_field("description"));
369 }
370 description = Some(map.next_value()?);
371 }
372 Field::Graph => {
373 if graph.is_some() {
374 return Err(E::duplicate_field("graph"));
375 }
376 graph = Some(map.next_value()?);
377 }
378 Field::Cpds => {
379 if cpds.is_some() {
380 return Err(E::duplicate_field("cpds"));
381 }
382 cpds = Some(map.next_value()?);
383 }
384 Field::Type => {
385 if type_.is_some() {
386 return Err(E::duplicate_field("type"));
387 }
388 type_ = Some(map.next_value()?);
389 }
390 }
391 }
392
393 let graph = graph.ok_or_else(|| E::missing_field("graph"))?;
395 let cpds = cpds.ok_or_else(|| E::missing_field("cpds"))?;
396
397 let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
399 assert_eq!(type_, "catbn", "Invalid type for CatBN.");
400
401 let cpds: Vec<_> = cpds;
403
404 Ok(CatBN::with_optionals(name, description, graph, cpds))
405 }
406 }
407
408 const FIELDS: &[&str] = &["name", "description", "graph", "cpds", "type"];
409
410 deserializer.deserialize_struct("CatBN", FIELDS, CatBNVisitor)
411 }
412}
413
// Presumably generates JSON I/O helpers for `CatBN` on top of its `Serialize` /
// `Deserialize` impls — TODO confirm against the `impl_json_io` macro definition.
impl_json_io!(CatBN);
416
417impl BifIO for CatBN {
418 fn from_bif(bif: &str) -> Self {
419 BifParser::parse_str(bif)
420 }
421
422 fn to_bif(&self) -> String {
423 todo!() }
425
426 fn read_bif(path: &str) -> Self {
427 Self::from_bif(&std::fs::read_to_string(path).expect("Failed to read BIF file."))
428 }
429
430 fn write_bif(&self, path: &str) {
431 std::fs::write(path, self.to_bif()).expect("Failed to write BIF file.");
432 }
433}