causal_hub/models/continuous_time_bayesian_network/categorical/
model.rs1use approx::{AbsDiffEq, RelativeEq};
2use ndarray::prelude::*;
3use serde::{
4 Deserialize, Deserializer, Serialize, Serializer,
5 de::{MapAccess, Visitor},
6 ser::SerializeMap,
7};
8
9use crate::{
10 datasets::{CatSample, CatTrj, CatTrjs},
11 impl_json_io,
12 models::{BN, CIM, CTBN, CatBN, CatCIM, CatCPD, DiGraph, Graph, Labelled},
13 set,
14 types::{Labels, Map, Set, States},
15};
16
17#[derive(Clone, Debug)]
19pub struct CatCTBN {
20 name: Option<String>,
22 description: Option<String>,
24 labels: Labels,
26 states: States,
28 shape: Array1<usize>,
30 initial_distribution: CatBN,
32 graph: DiGraph,
34 cims: Map<String, CatCIM>,
36}
37
38impl CatCTBN {
39 #[inline]
46 pub fn name(&self) -> Option<&str> {
47 self.name.as_deref()
48 }
49
50 #[inline]
57 pub fn description(&self) -> Option<&str> {
58 self.description.as_deref()
59 }
60
61 #[inline]
68 pub const fn states(&self) -> &States {
69 self.initial_distribution.states()
70 }
71}
72
73impl PartialEq for CatCTBN {
74 fn eq(&self, other: &Self) -> bool {
75 self.labels.eq(&other.labels)
76 && self.states.eq(&other.states)
77 && self.shape.eq(&other.shape)
78 && self.initial_distribution.eq(&other.initial_distribution)
79 && self.graph.eq(&other.graph)
80 && self.cims.eq(&other.cims)
81 }
82}
83
84impl AbsDiffEq for CatCTBN {
85 type Epsilon = f64;
86
87 fn default_epsilon() -> Self::Epsilon {
88 Self::Epsilon::default_epsilon()
89 }
90
91 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
92 self.labels.eq(&other.labels)
93 && self.states.eq(&other.states)
94 && self.shape.eq(&other.shape)
95 && self.initial_distribution.eq(&other.initial_distribution)
96 && self.graph.eq(&other.graph)
97 && self
98 .cims
99 .iter()
100 .zip(&other.cims)
101 .all(|((label, cpd), (other_label, other_cpd))| {
102 label.eq(other_label) && cpd.abs_diff_eq(other_cpd, epsilon)
103 })
104 }
105}
106
107impl RelativeEq for CatCTBN {
108 fn default_max_relative() -> Self::Epsilon {
109 Self::Epsilon::default_max_relative()
110 }
111
112 fn relative_eq(
113 &self,
114 other: &Self,
115 epsilon: Self::Epsilon,
116 max_relative: Self::Epsilon,
117 ) -> bool {
118 self.labels.eq(&other.labels)
119 && self.states.eq(&other.states)
120 && self.shape.eq(&other.shape)
121 && self.initial_distribution.eq(&other.initial_distribution)
122 && self.graph.eq(&other.graph)
123 && self
124 .cims
125 .iter()
126 .zip(&other.cims)
127 .all(|((label, cpd), (other_label, other_cpd))| {
128 label.eq(other_label) && cpd.relative_eq(other_cpd, epsilon, max_relative)
129 })
130 }
131}
132
133impl Labelled for CatCTBN {
134 #[inline]
135 fn labels(&self) -> &Labels {
136 &self.labels
137 }
138}
139
140impl CTBN for CatCTBN {
141 type CIM = CatCIM;
142 type InitialDistribution = CatBN;
143 type Event = (f64, CatSample);
144 type Trajectory = CatTrj;
145 type Trajectories = CatTrjs;
146
147 fn new<I>(graph: DiGraph, cims: I) -> Self
148 where
149 I: IntoIterator<Item = Self::CIM>,
150 {
151 let mut cims: Map<_, _> = cims
153 .into_iter()
154 .inspect(|x| {
157 assert_eq!(x.labels().len(), 1, "CPD must contain exactly one label.");
158 })
159 .map(|x| (x.labels()[0].to_owned(), x))
160 .collect();
161 cims.sort_keys();
163
164 let mut states: States = Default::default();
166 for cim in cims.values() {
168 cim.states()
169 .iter()
170 .chain(cim.conditioning_states())
171 .for_each(|(l, s)| {
172 if let Some(existing_states) = states.get(l) {
174 assert_eq!(
176 existing_states, s,
177 "States of `{l}` must be the same across CIMs.",
178 );
179 } else {
180 states.insert(l.to_owned(), s.clone());
182 }
183 });
184 }
185 states.sort_keys();
187
188 let labels: Labels = states.keys().cloned().collect();
190 let shape = Array::from_iter(states.values().map(Set::len));
192
193 assert!(
195 graph.labels().iter().eq(cims.keys()),
196 "Graph labels and distributions labels must be the same."
197 );
198
199 graph.vertices().iter().for_each(|&i| {
201 let pa_i = graph.parents(&set![i]).into_iter();
203 let pa_i: &Labels = &pa_i.map(|j| labels[j].to_owned()).collect();
204 let pa_j = cims[&labels[i]].conditioning_labels();
206 assert_eq!(
208 pa_i, pa_j,
209 "Graph parents labels and CIM conditioning labels must be the same:\n\
210 \t expected: {:?} ,\n\
211 \t found: {:?} .",
212 pa_i, pa_j
213 );
214 });
215
216 let initial_graph = DiGraph::empty(graph.labels());
218 let initial_cpds = cims.values().map(|cim| {
220 let states = cim.states().clone();
222 let conditioning_states = States::default();
224 let alpha = cim.shape().product();
226 let parameters = Array::from_vec(vec![1. / alpha as f64; alpha]);
227 let parameters = parameters.insert_axis(Axis(0));
228 CatCPD::new(states, conditioning_states, parameters)
230 });
231 let initial_distribution = CatBN::new(initial_graph, initial_cpds);
233
234 Self {
235 name: None,
236 description: None,
237 labels,
238 states,
239 shape,
240 initial_distribution,
241 graph,
242 cims,
243 }
244 }
245
246 fn initial_distribution(&self) -> &Self::InitialDistribution {
247 &self.initial_distribution
248 }
249
250 fn graph(&self) -> &DiGraph {
251 &self.graph
252 }
253
254 fn cims(&self) -> &Map<String, Self::CIM> {
255 &self.cims
256 }
257
258 fn parameters_size(&self) -> usize {
259 self.initial_distribution.parameters_size()
261 + self
263 .cims
264 .values()
265 .map(|x| x.parameters_size())
266 .sum::<usize>()
267 }
268
269 fn with_optionals<I>(
270 name: Option<String>,
271 description: Option<String>,
272 initial_distribution: Self::InitialDistribution,
273 graph: DiGraph,
274 cims: I,
275 ) -> Self
276 where
277 I: IntoIterator<Item = Self::CIM>,
278 {
279 if let Some(name) = &name {
281 assert!(!name.is_empty(), "Name cannot be an empty string.");
282 }
283 if let Some(description) = &description {
285 assert!(
286 !description.is_empty(),
287 "Description cannot be an empty string."
288 );
289 }
290
291 let mut ctbn = Self::new(graph, cims);
293
294 assert!(
296 initial_distribution.labels().eq(ctbn.labels()),
297 "Initial distribution labels must be the same as the CIMs labels."
298 );
299 assert!(
301 initial_distribution
302 .cpds()
303 .into_iter()
304 .zip(ctbn.cims())
305 .all(|((_, cpd), (_, cim))| cpd.states().eq(cim.states())),
306 "Initial distribution states must be the same as the CIMs states."
307 );
308
309 ctbn.name = name;
311 ctbn.description = description;
312 ctbn.initial_distribution = initial_distribution;
313
314 ctbn
315 }
316}
317
318impl Serialize for CatCTBN {
319 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
320 where
321 S: Serializer,
322 {
323 let mut size = 4;
325 size += self.name.is_some() as usize;
326 size += self.description.is_some() as usize;
327
328 let mut map = serializer.serialize_map(Some(size))?;
330
331 let cims: Vec<_> = self.cims.values().cloned().collect();
333
334 if let Some(name) = &self.name {
336 map.serialize_entry("name", name)?;
337 }
338 if let Some(description) = &self.description {
340 map.serialize_entry("description", description)?;
341 }
342 map.serialize_entry("initial_distribution", &self.initial_distribution)?;
344 map.serialize_entry("graph", &self.graph)?;
346 map.serialize_entry("cims", &cims)?;
348 map.serialize_entry("type", "catctbn")?;
350
351 map.end()
353 }
354}
355
356impl<'de> Deserialize<'de> for CatCTBN {
357 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
358 where
359 D: Deserializer<'de>,
360 {
361 #[derive(Deserialize)]
362 #[serde(field_identifier, rename_all = "snake_case")]
363 enum Field {
364 Name,
365 Description,
366 InitialDistribution,
367 Graph,
368 Cims,
369 Type,
370 }
371
372 struct CatCTBNVisitor;
373
374 impl<'de> Visitor<'de> for CatCTBNVisitor {
375 type Value = CatCTBN;
376
377 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
378 formatter.write_str("struct CatCTBN")
379 }
380
381 fn visit_map<V>(self, mut map: V) -> Result<CatCTBN, V::Error>
382 where
383 V: MapAccess<'de>,
384 {
385 use serde::de::Error as E;
386
387 let mut name = None;
389 let mut description = None;
390 let mut initial_distribution = None;
391 let mut graph = None;
392 let mut cims = None;
393 let mut type_ = None;
394
395 while let Some(key) = map.next_key()? {
397 match key {
398 Field::Name => {
399 if name.is_some() {
400 return Err(E::duplicate_field("name"));
401 }
402 name = Some(map.next_value()?);
403 }
404 Field::Description => {
405 if description.is_some() {
406 return Err(E::duplicate_field("description"));
407 }
408 description = Some(map.next_value()?);
409 }
410 Field::InitialDistribution => {
411 if initial_distribution.is_some() {
412 return Err(E::duplicate_field("initial_distribution"));
413 }
414 initial_distribution = Some(map.next_value()?);
415 }
416 Field::Graph => {
417 if graph.is_some() {
418 return Err(E::duplicate_field("graph"));
419 }
420 graph = Some(map.next_value()?);
421 }
422 Field::Cims => {
423 if cims.is_some() {
424 return Err(E::duplicate_field("cims"));
425 }
426 cims = Some(map.next_value()?);
427 }
428 Field::Type => {
429 if type_.is_some() {
430 return Err(E::duplicate_field("type"));
431 }
432 type_ = Some(map.next_value()?);
433 }
434 }
435 }
436
437 let initial_distribution =
439 initial_distribution.ok_or_else(|| E::missing_field("initial_distribution"))?;
440 let graph = graph.ok_or_else(|| E::missing_field("graph"))?;
441 let cims = cims.ok_or_else(|| E::missing_field("cims"))?;
442
443 let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
445 assert_eq!(type_, "catctbn", "Invalid type for CatCTBN.");
446
447 let cims: Vec<_> = cims;
449
450 Ok(CatCTBN::with_optionals(
451 name,
452 description,
453 initial_distribution,
454 graph,
455 cims,
456 ))
457 }
458 }
459
460 const FIELDS: &[&str] = &[
461 "name",
462 "description",
463 "initial_distribution",
464 "graph",
465 "cims",
466 "type",
467 ];
468
469 deserializer.deserialize_struct("CatCTBN", FIELDS, CatCTBNVisitor)
470 }
471}
472
473impl_json_io!(CatCTBN);