causal_hub/models/bayesian_network/gaussian/
model.rs

1use approx::{AbsDiffEq, RelativeEq};
2use serde::{
3    Deserialize, Deserializer, Serialize, Serializer,
4    de::{MapAccess, Visitor},
5    ser::SerializeMap,
6};
7
8use crate::{
9    datasets::{GaussEv, GaussSample, GaussTable},
10    impl_json_io,
11    inference::TopologicalOrder,
12    models::{BN, CPD, DiGraph, GaussCPD, Graph, Labelled},
13    set,
14    types::{Labels, Map},
15};
16
17/// A Gaussian Bayesian network.
18#[derive(Clone, Debug)]
19pub struct GaussBN {
20    /// The name of the model.
21    name: Option<String>,
22    /// The description of the model.
23    description: Option<String>,
24    /// The labels of the variables.
25    labels: Labels,
26    /// The graph of the model.
27    graph: DiGraph,
28    /// The parameters of the model.
29    cpds: Map<String, GaussCPD>,
30    /// The topological order of the graph.
31    topological_order: Vec<usize>,
32}
33
34impl PartialEq for GaussBN {
35    fn eq(&self, other: &Self) -> bool {
36        self.labels.eq(&other.labels)
37            && self.graph.eq(&other.graph)
38            && self.topological_order.eq(&other.topological_order)
39            && self.cpds.eq(&other.cpds)
40    }
41}
42
43impl AbsDiffEq for GaussBN {
44    type Epsilon = f64;
45
46    fn default_epsilon() -> Self::Epsilon {
47        Self::Epsilon::default_epsilon()
48    }
49
50    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
51        self.labels.eq(&other.labels)
52            && self.graph.eq(&other.graph)
53            && self.topological_order.eq(&other.topological_order)
54            && self
55                .cpds
56                .iter()
57                .zip(&other.cpds)
58                .all(|((label, cpd), (other_label, other_cpd))| {
59                    label.eq(other_label) && cpd.abs_diff_eq(other_cpd, epsilon)
60                })
61    }
62}
63
64impl RelativeEq for GaussBN {
65    fn default_max_relative() -> Self::Epsilon {
66        Self::Epsilon::default_max_relative()
67    }
68
69    fn relative_eq(
70        &self,
71        other: &Self,
72        epsilon: Self::Epsilon,
73        max_relative: Self::Epsilon,
74    ) -> bool {
75        self.labels.eq(&other.labels)
76            && self.graph.eq(&other.graph)
77            && self.topological_order.eq(&other.topological_order)
78            && self
79                .cpds
80                .iter()
81                .zip(&other.cpds)
82                .all(|((label, cpd), (other_label, other_cpd))| {
83                    label.eq(other_label) && cpd.relative_eq(other_cpd, epsilon, max_relative)
84                })
85    }
86}
87
88impl Labelled for GaussBN {
89    #[inline]
90    fn labels(&self) -> &Labels {
91        &self.labels
92    }
93}
94
95impl BN for GaussBN {
96    type CPD = GaussCPD;
97    type Evidence = GaussEv;
98    type Sample = GaussSample;
99    type Samples = GaussTable;
100
101    fn new<I>(graph: DiGraph, cpds: I) -> Self
102    where
103        I: IntoIterator<Item = Self::CPD>,
104    {
105        // Collect the CPDs into a map.
106        let mut cpds: Map<_, _> = cpds
107            .into_iter()
108            // Assert CPD contains exactly one label.
109            // TODO: Refactor code and remove this assumption.
110            .inspect(|x| {
111                assert_eq!(x.labels().len(), 1, "CPD must contain exactly one label.");
112            })
113            .map(|x| (x.labels()[0].to_owned(), x))
114            .collect();
115        // Sort the CPDs by their labels.
116        cpds.sort_keys();
117
118        // Assert same number of graph labels and CPDs.
119        assert!(
120            graph.labels().iter().eq(cpds.keys()),
121            "Graph labels and distributions labels must be the same."
122        );
123
124        // Get the labels of the variables.
125        let labels: Labels = graph.labels().clone();
126
127        // Check if all vertices have the same labels as their parents.
128        graph.vertices().iter().for_each(|&i| {
129            // Get the parents of the vertex.
130            let pa_i = graph.parents(&set![i]).into_iter();
131            let pa_i: &Labels = &pa_i.map(|j| labels[j].to_owned()).collect();
132            // Get the conditioning labels of the CPD.
133            let pa_j = cpds[&labels[i]].conditioning_labels();
134            // Assert they are the same.
135            assert_eq!(
136                pa_i, pa_j,
137                "Graph parents labels and CPD conditioning labels must be the same:\n\
138                \t expected:    {:?} ,\n\
139                \t found:       {:?} .",
140                pa_i, pa_j
141            );
142        });
143
144        // Assert the graph is acyclic.
145        let topological_order = graph.topological_order().expect("Graph must be acyclic.");
146
147        Self {
148            name: None,
149            description: None,
150            labels,
151            graph,
152            cpds,
153            topological_order,
154        }
155    }
156
157    #[inline]
158    fn name(&self) -> Option<&str> {
159        self.name.as_deref()
160    }
161
162    #[inline]
163    fn description(&self) -> Option<&str> {
164        self.description.as_deref()
165    }
166
167    #[inline]
168    fn graph(&self) -> &DiGraph {
169        &self.graph
170    }
171
172    #[inline]
173    fn cpds(&self) -> &Map<String, Self::CPD> {
174        &self.cpds
175    }
176
177    #[inline]
178    fn parameters_size(&self) -> usize {
179        self.cpds.iter().map(|(_, x)| x.parameters_size()).sum()
180    }
181
182    #[inline]
183    fn topological_order(&self) -> &[usize] {
184        &self.topological_order
185    }
186
187    fn with_optionals<I>(
188        name: Option<String>,
189        description: Option<String>,
190        graph: DiGraph,
191        cpds: I,
192    ) -> Self
193    where
194        I: IntoIterator<Item = Self::CPD>,
195    {
196        // Assert name is not empty string.
197        if let Some(name) = &name {
198            assert!(!name.is_empty(), "Name cannot be an empty string.");
199        }
200        // Assert description is not empty string.
201        if let Some(description) = &description {
202            assert!(
203                !description.is_empty(),
204                "Description cannot be an empty string."
205            );
206        }
207
208        // Construct the BN.
209        let mut bn = Self::new(graph, cpds);
210
211        // Set the optional fields.
212        bn.name = name;
213        bn.description = description;
214
215        bn
216    }
217}
218
219impl Serialize for GaussBN {
220    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
221    where
222        S: Serializer,
223    {
224        // Count the number of fields.
225        let mut size = 3;
226        // Add optional fields, if any.
227        size += self.name.is_some() as usize;
228        size += self.description.is_some() as usize;
229        // Allocate the map.
230        let mut map = serializer.serialize_map(Some(size))?;
231
232        // Serialize the name, if any.
233        if let Some(name) = &self.name {
234            map.serialize_entry("name", name)?;
235        }
236        // Serialize the description, if any.
237        if let Some(description) = &self.description {
238            map.serialize_entry("description", description)?;
239        }
240
241        // Serialize the graph.
242        map.serialize_entry("graph", &self.graph)?;
243
244        // Convert the CPDs to a flat format.
245        let cpds: Vec<_> = self.cpds.values().cloned().collect();
246        // Serialize CPDs.
247        map.serialize_entry("cpds", &cpds)?;
248
249        // Serialize type.
250        map.serialize_entry("type", "gaussbn")?;
251
252        // Finalize the map.
253        map.end()
254    }
255}
256
257impl<'de> Deserialize<'de> for GaussBN {
258    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
259    where
260        D: Deserializer<'de>,
261    {
262        #[derive(Deserialize)]
263        #[serde(field_identifier, rename_all = "snake_case")]
264        enum Field {
265            Name,
266            Description,
267            Graph,
268            Cpds,
269            Type,
270        }
271
272        struct GaussBNVisitor;
273
274        impl<'de> Visitor<'de> for GaussBNVisitor {
275            type Value = GaussBN;
276
277            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
278                formatter.write_str("struct GaussBN")
279            }
280
281            fn visit_map<V>(self, mut map: V) -> Result<GaussBN, V::Error>
282            where
283                V: MapAccess<'de>,
284            {
285                use serde::de::Error as E;
286
287                // Allocate fields
288                let mut name = None;
289                let mut description = None;
290                let mut graph = None;
291                let mut cpds = None;
292                let mut type_ = None;
293
294                // Parse the map.
295                while let Some(key) = map.next_key()? {
296                    match key {
297                        Field::Name => {
298                            if name.is_some() {
299                                return Err(E::duplicate_field("name"));
300                            }
301                            name = Some(map.next_value()?);
302                        }
303                        Field::Description => {
304                            if description.is_some() {
305                                return Err(E::duplicate_field("description"));
306                            }
307                            description = Some(map.next_value()?);
308                        }
309                        Field::Graph => {
310                            if graph.is_some() {
311                                return Err(E::duplicate_field("graph"));
312                            }
313                            graph = Some(map.next_value()?);
314                        }
315                        Field::Cpds => {
316                            if cpds.is_some() {
317                                return Err(E::duplicate_field("cpds"));
318                            }
319                            cpds = Some(map.next_value()?);
320                        }
321                        Field::Type => {
322                            if type_.is_some() {
323                                return Err(E::duplicate_field("type"));
324                            }
325                            type_ = Some(map.next_value()?);
326                        }
327                    }
328                }
329
330                // Check required fields.
331                let graph = graph.ok_or_else(|| E::missing_field("graph"))?;
332                let cpds = cpds.ok_or_else(|| E::missing_field("cpds"))?;
333
334                // Assert type is correct.
335                let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
336                assert_eq!(type_, "gaussbn", "Invalid type for GaussBN.");
337
338                // Set helper types.
339                let cpds: Vec<_> = cpds;
340
341                Ok(GaussBN::with_optionals(name, description, graph, cpds))
342            }
343        }
344
345        const FIELDS: &[&str] = &["name", "description", "graph", "cpds", "type"];
346
347        deserializer.deserialize_struct("GaussBN", FIELDS, GaussBNVisitor)
348    }
349}
350
// Implement `JsonIO` for `GaussBN` through the crate's `impl_json_io!` macro.
// NOTE(review): the macro is defined elsewhere; presumably it builds on the
// `Serialize`/`Deserialize` implementations above — confirm against its definition.
impl_json_io!(GaussBN);