causal_hub/models/bayesian_network/gaussian/
model.rs1use approx::{AbsDiffEq, RelativeEq};
2use serde::{
3 Deserialize, Deserializer, Serialize, Serializer,
4 de::{MapAccess, Visitor},
5 ser::SerializeMap,
6};
7
8use crate::{
9 datasets::{GaussEv, GaussSample, GaussTable},
10 impl_json_io,
11 inference::TopologicalOrder,
12 models::{BN, CPD, DiGraph, GaussCPD, Graph, Labelled},
13 set,
14 types::{Labels, Map},
15};
16
17#[derive(Clone, Debug)]
19pub struct GaussBN {
20 name: Option<String>,
22 description: Option<String>,
24 labels: Labels,
26 graph: DiGraph,
28 cpds: Map<String, GaussCPD>,
30 topological_order: Vec<usize>,
32}
33
34impl PartialEq for GaussBN {
35 fn eq(&self, other: &Self) -> bool {
36 self.labels.eq(&other.labels)
37 && self.graph.eq(&other.graph)
38 && self.topological_order.eq(&other.topological_order)
39 && self.cpds.eq(&other.cpds)
40 }
41}
42
43impl AbsDiffEq for GaussBN {
44 type Epsilon = f64;
45
46 fn default_epsilon() -> Self::Epsilon {
47 Self::Epsilon::default_epsilon()
48 }
49
50 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
51 self.labels.eq(&other.labels)
52 && self.graph.eq(&other.graph)
53 && self.topological_order.eq(&other.topological_order)
54 && self
55 .cpds
56 .iter()
57 .zip(&other.cpds)
58 .all(|((label, cpd), (other_label, other_cpd))| {
59 label.eq(other_label) && cpd.abs_diff_eq(other_cpd, epsilon)
60 })
61 }
62}
63
64impl RelativeEq for GaussBN {
65 fn default_max_relative() -> Self::Epsilon {
66 Self::Epsilon::default_max_relative()
67 }
68
69 fn relative_eq(
70 &self,
71 other: &Self,
72 epsilon: Self::Epsilon,
73 max_relative: Self::Epsilon,
74 ) -> bool {
75 self.labels.eq(&other.labels)
76 && self.graph.eq(&other.graph)
77 && self.topological_order.eq(&other.topological_order)
78 && self
79 .cpds
80 .iter()
81 .zip(&other.cpds)
82 .all(|((label, cpd), (other_label, other_cpd))| {
83 label.eq(other_label) && cpd.relative_eq(other_cpd, epsilon, max_relative)
84 })
85 }
86}
87
88impl Labelled for GaussBN {
89 #[inline]
90 fn labels(&self) -> &Labels {
91 &self.labels
92 }
93}
94
95impl BN for GaussBN {
96 type CPD = GaussCPD;
97 type Evidence = GaussEv;
98 type Sample = GaussSample;
99 type Samples = GaussTable;
100
101 fn new<I>(graph: DiGraph, cpds: I) -> Self
102 where
103 I: IntoIterator<Item = Self::CPD>,
104 {
105 let mut cpds: Map<_, _> = cpds
107 .into_iter()
108 .inspect(|x| {
111 assert_eq!(x.labels().len(), 1, "CPD must contain exactly one label.");
112 })
113 .map(|x| (x.labels()[0].to_owned(), x))
114 .collect();
115 cpds.sort_keys();
117
118 assert!(
120 graph.labels().iter().eq(cpds.keys()),
121 "Graph labels and distributions labels must be the same."
122 );
123
124 let labels: Labels = graph.labels().clone();
126
127 graph.vertices().iter().for_each(|&i| {
129 let pa_i = graph.parents(&set![i]).into_iter();
131 let pa_i: &Labels = &pa_i.map(|j| labels[j].to_owned()).collect();
132 let pa_j = cpds[&labels[i]].conditioning_labels();
134 assert_eq!(
136 pa_i, pa_j,
137 "Graph parents labels and CPD conditioning labels must be the same:\n\
138 \t expected: {:?} ,\n\
139 \t found: {:?} .",
140 pa_i, pa_j
141 );
142 });
143
144 let topological_order = graph.topological_order().expect("Graph must be acyclic.");
146
147 Self {
148 name: None,
149 description: None,
150 labels,
151 graph,
152 cpds,
153 topological_order,
154 }
155 }
156
157 #[inline]
158 fn name(&self) -> Option<&str> {
159 self.name.as_deref()
160 }
161
162 #[inline]
163 fn description(&self) -> Option<&str> {
164 self.description.as_deref()
165 }
166
167 #[inline]
168 fn graph(&self) -> &DiGraph {
169 &self.graph
170 }
171
172 #[inline]
173 fn cpds(&self) -> &Map<String, Self::CPD> {
174 &self.cpds
175 }
176
177 #[inline]
178 fn parameters_size(&self) -> usize {
179 self.cpds.iter().map(|(_, x)| x.parameters_size()).sum()
180 }
181
182 #[inline]
183 fn topological_order(&self) -> &[usize] {
184 &self.topological_order
185 }
186
187 fn with_optionals<I>(
188 name: Option<String>,
189 description: Option<String>,
190 graph: DiGraph,
191 cpds: I,
192 ) -> Self
193 where
194 I: IntoIterator<Item = Self::CPD>,
195 {
196 if let Some(name) = &name {
198 assert!(!name.is_empty(), "Name cannot be an empty string.");
199 }
200 if let Some(description) = &description {
202 assert!(
203 !description.is_empty(),
204 "Description cannot be an empty string."
205 );
206 }
207
208 let mut bn = Self::new(graph, cpds);
210
211 bn.name = name;
213 bn.description = description;
214
215 bn
216 }
217}
218
219impl Serialize for GaussBN {
220 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
221 where
222 S: Serializer,
223 {
224 let mut size = 3;
226 size += self.name.is_some() as usize;
228 size += self.description.is_some() as usize;
229 let mut map = serializer.serialize_map(Some(size))?;
231
232 if let Some(name) = &self.name {
234 map.serialize_entry("name", name)?;
235 }
236 if let Some(description) = &self.description {
238 map.serialize_entry("description", description)?;
239 }
240
241 map.serialize_entry("graph", &self.graph)?;
243
244 let cpds: Vec<_> = self.cpds.values().cloned().collect();
246 map.serialize_entry("cpds", &cpds)?;
248
249 map.serialize_entry("type", "gaussbn")?;
251
252 map.end()
254 }
255}
256
257impl<'de> Deserialize<'de> for GaussBN {
258 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
259 where
260 D: Deserializer<'de>,
261 {
262 #[derive(Deserialize)]
263 #[serde(field_identifier, rename_all = "snake_case")]
264 enum Field {
265 Name,
266 Description,
267 Graph,
268 Cpds,
269 Type,
270 }
271
272 struct GaussBNVisitor;
273
274 impl<'de> Visitor<'de> for GaussBNVisitor {
275 type Value = GaussBN;
276
277 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
278 formatter.write_str("struct GaussBN")
279 }
280
281 fn visit_map<V>(self, mut map: V) -> Result<GaussBN, V::Error>
282 where
283 V: MapAccess<'de>,
284 {
285 use serde::de::Error as E;
286
287 let mut name = None;
289 let mut description = None;
290 let mut graph = None;
291 let mut cpds = None;
292 let mut type_ = None;
293
294 while let Some(key) = map.next_key()? {
296 match key {
297 Field::Name => {
298 if name.is_some() {
299 return Err(E::duplicate_field("name"));
300 }
301 name = Some(map.next_value()?);
302 }
303 Field::Description => {
304 if description.is_some() {
305 return Err(E::duplicate_field("description"));
306 }
307 description = Some(map.next_value()?);
308 }
309 Field::Graph => {
310 if graph.is_some() {
311 return Err(E::duplicate_field("graph"));
312 }
313 graph = Some(map.next_value()?);
314 }
315 Field::Cpds => {
316 if cpds.is_some() {
317 return Err(E::duplicate_field("cpds"));
318 }
319 cpds = Some(map.next_value()?);
320 }
321 Field::Type => {
322 if type_.is_some() {
323 return Err(E::duplicate_field("type"));
324 }
325 type_ = Some(map.next_value()?);
326 }
327 }
328 }
329
330 let graph = graph.ok_or_else(|| E::missing_field("graph"))?;
332 let cpds = cpds.ok_or_else(|| E::missing_field("cpds"))?;
333
334 let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
336 assert_eq!(type_, "gaussbn", "Invalid type for GaussBN.");
337
338 let cpds: Vec<_> = cpds;
340
341 Ok(GaussBN::with_optionals(name, description, graph, cpds))
342 }
343 }
344
345 const FIELDS: &[&str] = &["name", "description", "graph", "cpds", "type"];
346
347 deserializer.deserialize_struct("GaussBN", FIELDS, GaussBNVisitor)
348 }
349}
350
// Crate macro invocation for `GaussBN`. Judging by its name, it presumably
// provides JSON read/write helpers on top of the `Serialize`/`Deserialize`
// impls in this file — NOTE(review): confirm against the `impl_json_io!`
// definition.
impl_json_io!(GaussBN);