causal_hub/models/mod.rs
1mod bayesian_network;
2use std::ops::{DivAssign, MulAssign};
3
4use approx::{AbsDiffEq, RelativeEq};
5pub use bayesian_network::*;
6
7mod continuous_time_bayesian_network;
8pub use continuous_time_bayesian_network::*;
9use rand::Rng;
10
11mod graphs;
12use std::fmt::Debug;
13
14pub use graphs::*;
15
16use crate::types::{Labels, Set};
17
18/// A trait for models with labelled variables.
19pub trait Labelled {
20 /// Returns the labels of the variables.
21 ///
22 /// # Returns
23 ///
24 /// A reference to the labels.
25 ///
26 fn labels(&self) -> &Labels;
27
28 /// Return the variable index for a given label.
29 ///
30 /// # Arguments
31 ///
32 /// * `x` - The label of the variable.
33 ///
34 /// # Panics
35 ///
36 /// * If the label is not in the map.
37 ///
38 /// # Returns
39 ///
40 /// The index of the variable.
41 ///
42 #[inline]
43 fn label_to_index(&self, x: &str) -> usize {
44 self.labels()
45 .get_index_of(x)
46 .unwrap_or_else(|| panic!("Variable `{x}` label does not exist."))
47 }
48
49 /// Return the label for a given variable index.
50 ///
51 /// # Arguments
52 ///
53 /// * `x` - The index of the variable.
54 ///
55 /// # Panics
56 ///
57 /// * If the index is out of bounds.
58 ///
59 /// # Returns
60 ///
61 /// The label of the variable.
62 ///
63 #[inline]
64 fn index_to_label(&self, x: usize) -> &str {
65 self.labels()
66 .get_index(x)
67 .unwrap_or_else(|| panic!("Variable `{x}` is out of bounds."))
68 }
69
70 /// Maps an index from this model to another model with the same label.
71 ///
72 /// # Arguments
73 ///
74 /// * `x` - The index in this model.
75 /// * `other` - The labels of the other model.
76 ///
77 /// # Panics
78 ///
79 /// * If the index is out of bounds.
80 /// * If the label does not exist in the other model.
81 ///
82 /// # Returns
83 ///
84 /// The index in the other model.
85 ///
86 #[inline]
87 fn index_to(&self, x: usize, other: &Labels) -> usize {
88 // Get the label of the variable in this model.
89 let label = self.index_to_label(x);
90 // Get the index of the variable in the other model.
91 other.get_index_of(label).unwrap_or_else(|| {
92 panic!("Variable `{label}` label does not exist in the other model.")
93 })
94 }
95
96 /// Maps a set of indices from this model to another model with the same labels.
97 ///
98 /// # Arguments
99 ///
100 /// * `x` - The set of indices in this model.
101 /// * `other` - The labels of the other model.
102 ///
103 /// # Panics
104 ///
105 /// * If any index is out of bounds.
106 /// * If any label does not exist in the other model.
107 ///
108 /// # Returns
109 ///
110 /// The set of indices in the other model.
111 ///
112 #[inline]
113 fn indices_to(&self, x: &Set<usize>, other: &Labels) -> Set<usize> {
114 x.iter().map(|&x| self.index_to(x, other)).collect()
115 }
116
117 /// Maps an index from another model to this model with the same label.
118 ///
119 /// # Arguments
120 ///
121 /// * `x` - The index in the other model.
122 /// * `other` - The labels of the other model.
123 ///
124 /// # Panics
125 ///
126 /// * If the index is out of bounds.
127 /// * If the label does not exist in this model.
128 ///
129 /// # Returns
130 ///
131 /// The index in this model.
132 ///
133 #[inline]
134 fn index_from(&self, x: usize, other: &Labels) -> usize {
135 // Get the label of the variable in the other model.
136 let label = other
137 .get_index(x)
138 .unwrap_or_else(|| panic!("Variable `{x}` is out of bounds in the other model."));
139 // Get the index of the variable in this model.
140 self.labels()
141 .get_index_of(label)
142 .unwrap_or_else(|| panic!("Variable `{label}` label does not exist."))
143 }
144
145 /// Maps a set of indices from another model to this model with the same labels.
146 ///
147 /// # Arguments
148 ///
149 /// * `x` - The set of indices in the other model.
150 /// * `other` - The labels of the other model.
151 ///
152 /// # Panics
153 ///
154 /// * If any index is out of bounds.
155 /// * If any label does not exist in this model.
156 ///
157 /// # Returns
158 ///
159 /// The set of indices in this model.
160 ///
161 #[inline]
162 fn indices_from(&self, x: &Set<usize>, other: &Labels) -> Set<usize> {
163 x.iter().map(|&x| self.index_from(x, other)).collect()
164 }
165}
166
167/// A trait for conditional probability distributions.
168pub trait CPD: Clone + Debug + Labelled + PartialEq + AbsDiffEq + RelativeEq {
169 /// The type of the support.
170 type Support;
171 /// The type of the parameters.
172 type Parameters;
173 /// The type of the sufficient statistics.
174 type Statistics;
175
176 /// Returns the labels of the conditioned variables.
177 ///
178 /// # Returns
179 ///
180 /// A reference to the conditioning labels.
181 ///
182 fn conditioning_labels(&self) -> &Labels;
183
184 /// Returns the parameters.
185 ///
186 /// # Returns
187 ///
188 /// A reference to the parameters.
189 ///
190 fn parameters(&self) -> &Self::Parameters;
191
192 /// Returns the parameters size.
193 ///
194 /// # Returns
195 ///
196 /// The parameters size.
197 ///
198 fn parameters_size(&self) -> usize;
199
200 /// Returns the sufficient statistics, if any.
201 ///
202 /// # Returns
203 ///
204 /// An option containing a reference to the sufficient statistics.
205 ///
206 fn sample_statistics(&self) -> Option<&Self::Statistics>;
207
208 /// Returns the log-likelihood of the fitted dataset, if any.
209 ///
210 /// # Returns
211 ///
212 /// An option containing the log-likelihood.
213 ///
214 fn sample_log_likelihood(&self) -> Option<f64>;
215
216 /// Returns the value of probability (mass or density) function for P(X = x | Z = z).
217 ///
218 /// # Arguments
219 ///
220 /// * `x` - The value of the conditioned variables.
221 /// * `z` - The value of the conditioning variables.
222 ///
223 /// # Returns
224 ///
225 /// The probability P(X = x | Z = z).
226 ///
227 fn pf(&self, x: &Self::Support, z: &Self::Support) -> f64;
228
229 /// Samples from the conditional distribution P(X | Z = z).
230 ///
231 /// # Arguments
232 ///
233 /// * `rng` - A mutable reference to a random number generator.
234 /// * `z` - The value of the conditioning variables.
235 ///
236 /// # Returns
237 ///
238 /// A sample from P(X | Z = z).
239 ///
240 fn sample<R: Rng>(&self, rng: &mut R, z: &Self::Support) -> Self::Support;
241}
242
243/// A trait for conditional intensity matrices.
244pub trait CIM: Clone + Debug + Labelled + PartialEq + AbsDiffEq + RelativeEq {
245 /// The type of the support.
246 type Support;
247 /// The type of the parameters.
248 type Parameters;
249 /// The type of the sufficient statistics.
250 type Statistics;
251
252 /// Returns the labels of the conditioned variables.
253 ///
254 /// # Returns
255 ///
256 /// A reference to the conditioning labels.
257 ///
258 fn conditioning_labels(&self) -> &Labels;
259
260 /// Returns the parameters.
261 ///
262 /// # Returns
263 ///
264 /// A reference to the parameters.
265 ///
266 fn parameters(&self) -> &Self::Parameters;
267
268 /// Returns the parameters size.
269 ///
270 /// # Returns
271 ///
272 /// The parameters size.
273 ///
274 fn parameters_size(&self) -> usize;
275
276 /// Returns the sufficient statistics, if any.
277 ///
278 /// # Returns
279 ///
280 /// An option containing a reference to the sufficient statistics.
281 ///
282 fn sample_statistics(&self) -> Option<&Self::Statistics>;
283
284 /// Returns the log-likelihood of the fitted dataset, if any.
285 ///
286 /// # Returns
287 ///
288 /// An option containing the log-likelihood.
289 ///
290 fn sample_log_likelihood(&self) -> Option<f64>;
291}
292
293/// A trait for potential functions.
294pub trait Phi:
295 Clone
296 + Debug
297 + Labelled
298 + PartialEq
299 + AbsDiffEq
300 + RelativeEq
301 + for<'a> MulAssign<&'a Self>
302 + for<'a> DivAssign<&'a Self>
303{
304 /// The type of the CPD.
305 type CPD;
306 /// The type of the parameters.
307 type Parameters;
308 /// The type of the evidence.
309 type Evidence;
310
311 /// Returns the parameters.
312 ///
313 /// # Returns
314 ///
315 /// A reference to the parameters.
316 ///
317 fn parameters(&self) -> &Self::Parameters;
318
319 /// Returns the parameters size.
320 ///
321 /// # Returns
322 ///
323 /// The parameters size.
324 ///
325 fn parameters_size(&self) -> usize;
326
327 /// Conditions the potential on a set of variables.
328 ///
329 /// # Arguments
330 ///
331 /// * `e` - A map from variable indices to their observed states.
332 ///
333 /// # Returns
334 ///
335 /// A new potential instance.
336 ///
337 fn condition(&self, e: &Self::Evidence) -> Self;
338
339 /// Marginalizes the potential over a set of variables.
340 ///
341 /// # Arguments
342 ///
343 /// * `x` - A set of variable indices to marginalize over.
344 ///
345 /// # Returns
346 ///
347 /// A new potential instance.
348 ///
349 fn marginalize(&self, x: &Set<usize>) -> Self;
350
351 /// Normalizes the potential.
352 ///
353 /// # Returns
354 ///
355 /// The normalized potential.
356 ///
357 fn normalize(&self) -> Self;
358
359 /// Converts a CPD P(X | Z) to a potential \phi(X \cup Z).
360 ///
361 /// # Arguments
362 ///
363 /// * `cpd` - The CPD to convert.
364 ///
365 /// # Returns
366 ///
367 /// The corresponding potential.
368 ///
369 fn from_cpd(cpd: Self::CPD) -> Self;
370
371 /// Converts a potential \phi(X \cup Z) to a CPD P(X | Z).
372 ///
373 /// # Arguments
374 ///
375 /// * `x` - The set of variables.
376 /// * `z` - The set of conditioning variables.
377 ///
378 /// # Returns
379 ///
380 /// The corresponding CPD.
381 ///
382 fn into_cpd(self, x: &Set<usize>, z: &Set<usize>) -> Self::CPD;
383}