causal_hub/models/bayesian_network/categorical/
model.rs

1use approx::{AbsDiffEq, RelativeEq};
2use ndarray::prelude::*;
3use serde::{
4    Deserialize, Deserializer, Serialize, Serializer,
5    de::{MapAccess, Visitor},
6    ser::SerializeMap,
7};
8
9use crate::{
10    datasets::{CatEv, CatSample, CatTable},
11    impl_json_io,
12    inference::TopologicalOrder,
13    io::{BifIO, BifParser},
14    models::{BN, CPD, CatCPD, DiGraph, Graph, Labelled},
15    set,
16    types::{Labels, Map, States},
17};
18
19/// A categorical Bayesian network.
20#[derive(Clone, Debug)]
21pub struct CatBN {
22    /// The name of the model.
23    name: Option<String>,
24    /// The description of the model.
25    description: Option<String>,
26    /// The labels of the variables.
27    labels: Labels,
28    /// The states of the variables.
29    states: States,
30    /// The shape of the variables.
31    shape: Array1<usize>,
32    /// The graph of the model.
33    graph: DiGraph,
34    /// The parameters of the model.
35    cpds: Map<String, CatCPD>,
36    /// The topological order of the graph.
37    topological_order: Vec<usize>,
38}
39
40impl CatBN {
41    /// Returns the states of the variables.
42    ///
43    /// # Returns
44    ///
45    /// A reference to the states of the variables.
46    ///
47    #[inline]
48    pub const fn states(&self) -> &States {
49        &self.states
50    }
51
52    /// Returns the shape of the variables.
53    ///
54    /// # Returns
55    ///
56    /// A reference to the shape of the variables.
57    ///
58    #[inline]
59    pub fn shape(&self) -> &Array1<usize> {
60        &self.shape
61    }
62}
63
64impl PartialEq for CatBN {
65    fn eq(&self, other: &Self) -> bool {
66        self.labels.eq(&other.labels)
67            && self.states.eq(&other.states)
68            && self.shape.eq(&other.shape)
69            && self.graph.eq(&other.graph)
70            && self.topological_order.eq(&other.topological_order)
71            && self.cpds.eq(&other.cpds)
72    }
73}
74
75impl AbsDiffEq for CatBN {
76    type Epsilon = f64;
77
78    fn default_epsilon() -> Self::Epsilon {
79        Self::Epsilon::default_epsilon()
80    }
81
82    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
83        self.labels.eq(&other.labels)
84            && self.states.eq(&other.states)
85            && self.shape.eq(&other.shape)
86            && self.graph.eq(&other.graph)
87            && self.topological_order.eq(&other.topological_order)
88            && self
89                .cpds
90                .iter()
91                .zip(&other.cpds)
92                .all(|((label, cpd), (other_label, other_cpd))| {
93                    label.eq(other_label) && cpd.abs_diff_eq(other_cpd, epsilon)
94                })
95    }
96}
97
98impl RelativeEq for CatBN {
99    fn default_max_relative() -> Self::Epsilon {
100        Self::Epsilon::default_max_relative()
101    }
102
103    fn relative_eq(
104        &self,
105        other: &Self,
106        epsilon: Self::Epsilon,
107        max_relative: Self::Epsilon,
108    ) -> bool {
109        self.labels.eq(&other.labels)
110            && self.states.eq(&other.states)
111            && self.shape.eq(&other.shape)
112            && self.graph.eq(&other.graph)
113            && self.topological_order.eq(&other.topological_order)
114            && self
115                .cpds
116                .iter()
117                .zip(&other.cpds)
118                .all(|((label, cpd), (other_label, other_cpd))| {
119                    label.eq(other_label) && cpd.relative_eq(other_cpd, epsilon, max_relative)
120                })
121    }
122}
123
124impl Labelled for CatBN {
125    #[inline]
126    fn labels(&self) -> &Labels {
127        &self.labels
128    }
129}
130
131impl BN for CatBN {
132    type CPD = CatCPD;
133    type Evidence = CatEv;
134    type Sample = CatSample;
135    type Samples = CatTable;
136
137    fn new<I>(graph: DiGraph, cpds: I) -> Self
138    where
139        I: IntoIterator<Item = Self::CPD>,
140    {
141        // Collect the CPDs into a map.
142        let mut cpds: Map<_, _> = cpds
143            .into_iter()
144            // Assert CPD contains exactly one label.
145            // TODO: Refactor code and remove this assumption.
146            .inspect(|x| {
147                assert_eq!(x.labels().len(), 1, "CPD must contain exactly one label.");
148            })
149            .map(|x| (x.labels()[0].to_owned(), x))
150            .collect();
151        // Sort the CPDs by their labels.
152        cpds.sort_keys();
153
154        // Assert same number of graph labels and CPDs.
155        assert!(
156            graph.labels().iter().eq(cpds.keys()),
157            "Graph labels and distributions labels must be the same."
158        );
159
160        // Allocate the states of the variables.
161        let mut states: States = Default::default();
162        // Insert the states of the variables into the map to check if they are the same.
163        for cpd in cpds.values() {
164            cpd.states()
165                .iter()
166                .chain(cpd.conditioning_states())
167                .for_each(|(l, s)| {
168                    // Check if the states are already in the map.
169                    if let Some(existing_states) = states.get(l) {
170                        // Check if the states are the same.
171                        assert_eq!(
172                            existing_states, s,
173                            "States of `{l}` must be the same across CPDs.",
174                        );
175                    } else {
176                        // Insert the states into the map.
177                        states.insert(l.to_owned(), s.clone());
178                    }
179                });
180        }
181        // Sort the states of the variables.
182        states.sort_keys();
183
184        // Get the labels of the variables.
185        let labels: Labels = states.keys().cloned().collect();
186        // Get the shape of the variables.
187        let shape: Array1<usize> = states.values().map(|s| s.len()).collect();
188
189        // Check if all vertices have the same labels as their parents.
190        graph.vertices().iter().for_each(|&i| {
191            // Get the parents of the vertex.
192            let pa_i = graph.parents(&set![i]).into_iter();
193            let pa_i: &Labels = &pa_i.map(|j| labels[j].to_owned()).collect();
194            // Get the conditioning labels of the CPD.
195            let pa_j = cpds[&labels[i]].conditioning_labels();
196            // Assert they are the same.
197            assert_eq!(
198                pa_i, pa_j,
199                "Graph parents labels and CPD conditioning labels must be the same:\n\
200                \t expected:    {:?} ,\n\
201                \t found:       {:?} .",
202                pa_i, pa_j
203            );
204        });
205
206        // Assert the graph is acyclic.
207        let topological_order = graph.topological_order().expect("Graph must be acyclic.");
208
209        Self {
210            name: None,
211            description: None,
212            labels,
213            states,
214            shape,
215            graph,
216            cpds,
217            topological_order,
218        }
219    }
220
221    #[inline]
222    fn name(&self) -> Option<&str> {
223        self.name.as_deref()
224    }
225
226    #[inline]
227    fn description(&self) -> Option<&str> {
228        self.description.as_deref()
229    }
230
231    #[inline]
232    fn graph(&self) -> &DiGraph {
233        &self.graph
234    }
235
236    #[inline]
237    fn cpds(&self) -> &Map<String, Self::CPD> {
238        &self.cpds
239    }
240
241    #[inline]
242    fn parameters_size(&self) -> usize {
243        self.cpds.iter().map(|(_, x)| x.parameters_size()).sum()
244    }
245
246    #[inline]
247    fn topological_order(&self) -> &[usize] {
248        &self.topological_order
249    }
250
251    fn with_optionals<I>(
252        name: Option<String>,
253        description: Option<String>,
254        graph: DiGraph,
255        cpds: I,
256    ) -> Self
257    where
258        I: IntoIterator<Item = Self::CPD>,
259    {
260        // Assert name is not empty string.
261        if let Some(name) = &name {
262            assert!(!name.is_empty(), "Name cannot be an empty string.");
263        }
264        // Assert description is not empty string.
265        if let Some(description) = &description {
266            assert!(
267                !description.is_empty(),
268                "Description cannot be an empty string."
269            );
270        }
271
272        // Construct the BN.
273        let mut bn = Self::new(graph, cpds);
274
275        // Set the optional fields.
276        bn.name = name;
277        bn.description = description;
278
279        bn
280    }
281}
282
283impl Serialize for CatBN {
284    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
285    where
286        S: Serializer,
287    {
288        // Count the elements to serialize.
289        let mut size = 3;
290        size += self.name.is_some() as usize;
291        size += self.description.is_some() as usize;
292
293        // Allocate the map.
294        let mut map = serializer.serialize_map(Some(size))?;
295
296        // Serialize name, if any.
297        if let Some(name) = &self.name {
298            map.serialize_entry("name", name)?;
299        }
300        // Serialize description, if any.
301        if let Some(description) = &self.description {
302            map.serialize_entry("description", description)?;
303        }
304        // Serialize graph.
305        map.serialize_entry("graph", &self.graph)?;
306
307        // Convert the CPDs to a flat format.
308        let cpds: Vec<_> = self.cpds.values().cloned().collect();
309        // Serialize CPDs.
310        map.serialize_entry("cpds", &cpds)?;
311
312        // Serialize type.
313        map.serialize_entry("type", "catbn")?;
314
315        // Finalize the map.
316        map.end()
317    }
318}
319
320impl<'de> Deserialize<'de> for CatBN {
321    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
322    where
323        D: Deserializer<'de>,
324    {
325        #[derive(Deserialize)]
326        #[serde(field_identifier, rename_all = "snake_case")]
327        enum Field {
328            Name,
329            Description,
330            Graph,
331            Cpds,
332            Type,
333        }
334
335        struct CatBNVisitor;
336
337        impl<'de> Visitor<'de> for CatBNVisitor {
338            type Value = CatBN;
339
340            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
341                formatter.write_str("struct CatBN")
342            }
343
344            fn visit_map<V>(self, mut map: V) -> Result<CatBN, V::Error>
345            where
346                V: MapAccess<'de>,
347            {
348                use serde::de::Error as E;
349
350                // Allocate fields
351                let mut name = None;
352                let mut description = None;
353                let mut graph = None;
354                let mut cpds = None;
355                let mut type_ = None;
356
357                // Parse the map.
358                while let Some(key) = map.next_key()? {
359                    match key {
360                        Field::Name => {
361                            if name.is_some() {
362                                return Err(E::duplicate_field("name"));
363                            }
364                            name = Some(map.next_value()?);
365                        }
366                        Field::Description => {
367                            if description.is_some() {
368                                return Err(E::duplicate_field("description"));
369                            }
370                            description = Some(map.next_value()?);
371                        }
372                        Field::Graph => {
373                            if graph.is_some() {
374                                return Err(E::duplicate_field("graph"));
375                            }
376                            graph = Some(map.next_value()?);
377                        }
378                        Field::Cpds => {
379                            if cpds.is_some() {
380                                return Err(E::duplicate_field("cpds"));
381                            }
382                            cpds = Some(map.next_value()?);
383                        }
384                        Field::Type => {
385                            if type_.is_some() {
386                                return Err(E::duplicate_field("type"));
387                            }
388                            type_ = Some(map.next_value()?);
389                        }
390                    }
391                }
392
393                // Check required fields.
394                let graph = graph.ok_or_else(|| E::missing_field("graph"))?;
395                let cpds = cpds.ok_or_else(|| E::missing_field("cpds"))?;
396
397                // Assert type is correct.
398                let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
399                assert_eq!(type_, "catbn", "Invalid type for CatBN.");
400
401                // Set helper types.
402                let cpds: Vec<_> = cpds;
403
404                Ok(CatBN::with_optionals(name, description, graph, cpds))
405            }
406        }
407
408        const FIELDS: &[&str] = &["name", "description", "graph", "cpds", "type"];
409
410        deserializer.deserialize_struct("CatBN", FIELDS, CatBNVisitor)
411    }
412}
413
414// Implement `JsonIO` for `CatBN`.
415impl_json_io!(CatBN);
416
417impl BifIO for CatBN {
418    fn from_bif(bif: &str) -> Self {
419        BifParser::parse_str(bif)
420    }
421
422    fn to_bif(&self) -> String {
423        todo!() // FIXME:
424    }
425
426    fn read_bif(path: &str) -> Self {
427        Self::from_bif(&std::fs::read_to_string(path).expect("Failed to read BIF file."))
428    }
429
430    fn write_bif(&self, path: &str) {
431        std::fs::write(path, self.to_bif()).expect("Failed to write BIF file.");
432    }
433}