causal_hub/models/continuous_time_bayesian_network/categorical/
model.rs

1use approx::{AbsDiffEq, RelativeEq};
2use ndarray::prelude::*;
3use serde::{
4    Deserialize, Deserializer, Serialize, Serializer,
5    de::{MapAccess, Visitor},
6    ser::SerializeMap,
7};
8
9use crate::{
10    datasets::{CatSample, CatTrj, CatTrjs},
11    impl_json_io,
12    models::{BN, CIM, CTBN, CatBN, CatCIM, CatCPD, DiGraph, Graph, Labelled},
13    set,
14    types::{Labels, Map, Set, States},
15};
16
17/// A categorical continuous time Bayesian network.
18#[derive(Clone, Debug)]
19pub struct CatCTBN {
20    /// The name of the model.
21    name: Option<String>,
22    /// The description of the model.
23    description: Option<String>,
24    /// The labels of the variables.
25    labels: Labels,
26    /// The states of the variables.
27    states: States,
28    /// The shape of the variables.
29    shape: Array1<usize>,
30    /// The initial distribution.
31    initial_distribution: CatBN,
32    /// The underlying graph.
33    graph: DiGraph,
34    /// The conditional intensity matrices.
35    cims: Map<String, CatCIM>,
36}
37
38impl CatCTBN {
39    /// Returns the name of the model, if any.
40    ///
41    /// # Returns
42    ///
43    /// The name of the model, if it exists.
44    ///
45    #[inline]
46    pub fn name(&self) -> Option<&str> {
47        self.name.as_deref()
48    }
49
50    /// Returns the description of the model, if any.
51    ///
52    /// # Returns
53    ///
54    /// The description of the model, if it exists.
55    ///
56    #[inline]
57    pub fn description(&self) -> Option<&str> {
58        self.description.as_deref()
59    }
60
61    /// Returns the states of the variables.
62    ///
63    /// # Returns
64    ///
65    /// A reference to the states of the variables.
66    ///
67    #[inline]
68    pub const fn states(&self) -> &States {
69        self.initial_distribution.states()
70    }
71}
72
73impl PartialEq for CatCTBN {
74    fn eq(&self, other: &Self) -> bool {
75        self.labels.eq(&other.labels)
76            && self.states.eq(&other.states)
77            && self.shape.eq(&other.shape)
78            && self.initial_distribution.eq(&other.initial_distribution)
79            && self.graph.eq(&other.graph)
80            && self.cims.eq(&other.cims)
81    }
82}
83
84impl AbsDiffEq for CatCTBN {
85    type Epsilon = f64;
86
87    fn default_epsilon() -> Self::Epsilon {
88        Self::Epsilon::default_epsilon()
89    }
90
91    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
92        self.labels.eq(&other.labels)
93            && self.states.eq(&other.states)
94            && self.shape.eq(&other.shape)
95            && self.initial_distribution.eq(&other.initial_distribution)
96            && self.graph.eq(&other.graph)
97            && self
98                .cims
99                .iter()
100                .zip(&other.cims)
101                .all(|((label, cpd), (other_label, other_cpd))| {
102                    label.eq(other_label) && cpd.abs_diff_eq(other_cpd, epsilon)
103                })
104    }
105}
106
107impl RelativeEq for CatCTBN {
108    fn default_max_relative() -> Self::Epsilon {
109        Self::Epsilon::default_max_relative()
110    }
111
112    fn relative_eq(
113        &self,
114        other: &Self,
115        epsilon: Self::Epsilon,
116        max_relative: Self::Epsilon,
117    ) -> bool {
118        self.labels.eq(&other.labels)
119            && self.states.eq(&other.states)
120            && self.shape.eq(&other.shape)
121            && self.initial_distribution.eq(&other.initial_distribution)
122            && self.graph.eq(&other.graph)
123            && self
124                .cims
125                .iter()
126                .zip(&other.cims)
127                .all(|((label, cpd), (other_label, other_cpd))| {
128                    label.eq(other_label) && cpd.relative_eq(other_cpd, epsilon, max_relative)
129                })
130    }
131}
132
133impl Labelled for CatCTBN {
134    #[inline]
135    fn labels(&self) -> &Labels {
136        &self.labels
137    }
138}
139
140impl CTBN for CatCTBN {
141    type CIM = CatCIM;
142    type InitialDistribution = CatBN;
143    type Event = (f64, CatSample);
144    type Trajectory = CatTrj;
145    type Trajectories = CatTrjs;
146
147    fn new<I>(graph: DiGraph, cims: I) -> Self
148    where
149        I: IntoIterator<Item = Self::CIM>,
150    {
151        // Collect the CPDs into a map.
152        let mut cims: Map<_, _> = cims
153            .into_iter()
154            // Assert CIM contains exactly one label.
155            // TODO: Refactor code and remove this assumption.
156            .inspect(|x| {
157                assert_eq!(x.labels().len(), 1, "CPD must contain exactly one label.");
158            })
159            .map(|x| (x.labels()[0].to_owned(), x))
160            .collect();
161        // Sort the CPDs by their labels.
162        cims.sort_keys();
163
164        // Allocate the states of the variables.
165        let mut states: States = Default::default();
166        // Insert the states of the variables into the map to check if they are the same.
167        for cim in cims.values() {
168            cim.states()
169                .iter()
170                .chain(cim.conditioning_states())
171                .for_each(|(l, s)| {
172                    // Check if the states are already in the map.
173                    if let Some(existing_states) = states.get(l) {
174                        // Check if the states are the same.
175                        assert_eq!(
176                            existing_states, s,
177                            "States of `{l}` must be the same across CIMs.",
178                        );
179                    } else {
180                        // Insert the states into the map.
181                        states.insert(l.to_owned(), s.clone());
182                    }
183                });
184        }
185        // Sort the states of the variables.
186        states.sort_keys();
187
188        // Get the labels of the variables.
189        let labels: Labels = states.keys().cloned().collect();
190        // Get the shape of the variables.
191        let shape = Array::from_iter(states.values().map(Set::len));
192
193        // Assert same number of graph labels and CIMs.
194        assert!(
195            graph.labels().iter().eq(cims.keys()),
196            "Graph labels and distributions labels must be the same."
197        );
198
199        // Check if all vertices have the same labels as their parents.
200        graph.vertices().iter().for_each(|&i| {
201            // Get the parents of the vertex.
202            let pa_i = graph.parents(&set![i]).into_iter();
203            let pa_i: &Labels = &pa_i.map(|j| labels[j].to_owned()).collect();
204            // Get the conditioning labels of the CIM.
205            let pa_j = cims[&labels[i]].conditioning_labels();
206            // Assert they are the same.
207            assert_eq!(
208                pa_i, pa_j,
209                "Graph parents labels and CIM conditioning labels must be the same:\n\
210                \t expected:    {:?} ,\n\
211                \t found:       {:?} .",
212                pa_i, pa_j
213            );
214        });
215
216        // Initialize an empty graph for the uniform initial distribution.
217        let initial_graph = DiGraph::empty(graph.labels());
218        // Initialize the CPDs as uniform distributions.
219        let initial_cpds = cims.values().map(|cim| {
220            // Get label and states of the CIM.
221            let states = cim.states().clone();
222            // Set empty conditioning states.
223            let conditioning_states = States::default();
224            // Set uniform parameters.
225            let alpha = cim.shape().product();
226            let parameters = Array::from_vec(vec![1. / alpha as f64; alpha]);
227            let parameters = parameters.insert_axis(Axis(0));
228            // Construct the CPD.
229            CatCPD::new(states, conditioning_states, parameters)
230        });
231        // Initialize a uniform initial distribution.
232        let initial_distribution = CatBN::new(initial_graph, initial_cpds);
233
234        Self {
235            name: None,
236            description: None,
237            labels,
238            states,
239            shape,
240            initial_distribution,
241            graph,
242            cims,
243        }
244    }
245
246    fn initial_distribution(&self) -> &Self::InitialDistribution {
247        &self.initial_distribution
248    }
249
250    fn graph(&self) -> &DiGraph {
251        &self.graph
252    }
253
254    fn cims(&self) -> &Map<String, Self::CIM> {
255        &self.cims
256    }
257
258    fn parameters_size(&self) -> usize {
259        // Parameters size of the initial distribution.
260        self.initial_distribution.parameters_size()
261            // Parameters size of the CIMs.
262            + self
263                .cims
264                .values()
265                .map(|x| x.parameters_size())
266                .sum::<usize>()
267    }
268
269    fn with_optionals<I>(
270        name: Option<String>,
271        description: Option<String>,
272        initial_distribution: Self::InitialDistribution,
273        graph: DiGraph,
274        cims: I,
275    ) -> Self
276    where
277        I: IntoIterator<Item = Self::CIM>,
278    {
279        // Assert name is not empty string.
280        if let Some(name) = &name {
281            assert!(!name.is_empty(), "Name cannot be an empty string.");
282        }
283        // Assert description is not empty string.
284        if let Some(description) = &description {
285            assert!(
286                !description.is_empty(),
287                "Description cannot be an empty string."
288            );
289        }
290
291        // Construct the categorical CTBN.
292        let mut ctbn = Self::new(graph, cims);
293
294        // Assert the initial distribution has same labels.
295        assert!(
296            initial_distribution.labels().eq(ctbn.labels()),
297            "Initial distribution labels must be the same as the CIMs labels."
298        );
299        // Assert the initial distribution has same states.
300        assert!(
301            initial_distribution
302                .cpds()
303                .into_iter()
304                .zip(ctbn.cims())
305                .all(|((_, cpd), (_, cim))| cpd.states().eq(cim.states())),
306            "Initial distribution states must be the same as the CIMs states."
307        );
308
309        // Set the optional fields.
310        ctbn.name = name;
311        ctbn.description = description;
312        ctbn.initial_distribution = initial_distribution;
313
314        ctbn
315    }
316}
317
318impl Serialize for CatCTBN {
319    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
320    where
321        S: Serializer,
322    {
323        // Count the elements to serialize.
324        let mut size = 4;
325        size += self.name.is_some() as usize;
326        size += self.description.is_some() as usize;
327
328        // Allocate the map.
329        let mut map = serializer.serialize_map(Some(size))?;
330
331        // Convert the CIMs to a flat format.
332        let cims: Vec<_> = self.cims.values().cloned().collect();
333
334        // Serialize name, if any.
335        if let Some(name) = &self.name {
336            map.serialize_entry("name", name)?;
337        }
338        // Serialize description, if any.
339        if let Some(description) = &self.description {
340            map.serialize_entry("description", description)?;
341        }
342        // Serialize initial distribution.
343        map.serialize_entry("initial_distribution", &self.initial_distribution)?;
344        // Serialize graph.
345        map.serialize_entry("graph", &self.graph)?;
346        // Serialize CIMs.
347        map.serialize_entry("cims", &cims)?;
348        // Serialize type.
349        map.serialize_entry("type", "catctbn")?;
350
351        // Finalize the map serialization.
352        map.end()
353    }
354}
355
356impl<'de> Deserialize<'de> for CatCTBN {
357    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
358    where
359        D: Deserializer<'de>,
360    {
361        #[derive(Deserialize)]
362        #[serde(field_identifier, rename_all = "snake_case")]
363        enum Field {
364            Name,
365            Description,
366            InitialDistribution,
367            Graph,
368            Cims,
369            Type,
370        }
371
372        struct CatCTBNVisitor;
373
374        impl<'de> Visitor<'de> for CatCTBNVisitor {
375            type Value = CatCTBN;
376
377            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
378                formatter.write_str("struct CatCTBN")
379            }
380
381            fn visit_map<V>(self, mut map: V) -> Result<CatCTBN, V::Error>
382            where
383                V: MapAccess<'de>,
384            {
385                use serde::de::Error as E;
386
387                // Allocate fields
388                let mut name = None;
389                let mut description = None;
390                let mut initial_distribution = None;
391                let mut graph = None;
392                let mut cims = None;
393                let mut type_ = None;
394
395                // Parse the map.
396                while let Some(key) = map.next_key()? {
397                    match key {
398                        Field::Name => {
399                            if name.is_some() {
400                                return Err(E::duplicate_field("name"));
401                            }
402                            name = Some(map.next_value()?);
403                        }
404                        Field::Description => {
405                            if description.is_some() {
406                                return Err(E::duplicate_field("description"));
407                            }
408                            description = Some(map.next_value()?);
409                        }
410                        Field::InitialDistribution => {
411                            if initial_distribution.is_some() {
412                                return Err(E::duplicate_field("initial_distribution"));
413                            }
414                            initial_distribution = Some(map.next_value()?);
415                        }
416                        Field::Graph => {
417                            if graph.is_some() {
418                                return Err(E::duplicate_field("graph"));
419                            }
420                            graph = Some(map.next_value()?);
421                        }
422                        Field::Cims => {
423                            if cims.is_some() {
424                                return Err(E::duplicate_field("cims"));
425                            }
426                            cims = Some(map.next_value()?);
427                        }
428                        Field::Type => {
429                            if type_.is_some() {
430                                return Err(E::duplicate_field("type"));
431                            }
432                            type_ = Some(map.next_value()?);
433                        }
434                    }
435                }
436
437                // Check required fields.
438                let initial_distribution =
439                    initial_distribution.ok_or_else(|| E::missing_field("initial_distribution"))?;
440                let graph = graph.ok_or_else(|| E::missing_field("graph"))?;
441                let cims = cims.ok_or_else(|| E::missing_field("cims"))?;
442
443                // Assert type is correct.
444                let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
445                assert_eq!(type_, "catctbn", "Invalid type for CatCTBN.");
446
447                // Set helper types.
448                let cims: Vec<_> = cims;
449
450                Ok(CatCTBN::with_optionals(
451                    name,
452                    description,
453                    initial_distribution,
454                    graph,
455                    cims,
456                ))
457            }
458        }
459
460        const FIELDS: &[&str] = &[
461            "name",
462            "description",
463            "initial_distribution",
464            "graph",
465            "cims",
466            "type",
467        ];
468
469        deserializer.deserialize_struct("CatCTBN", FIELDS, CatCTBNVisitor)
470    }
471}
472
473// Implement `JsonIO` for `CatCTBN`.
474impl_json_io!(CatCTBN);