causal_hub/models/mod.rs

1mod bayesian_network;
2use std::ops::{DivAssign, MulAssign};
3
4use approx::{AbsDiffEq, RelativeEq};
5pub use bayesian_network::*;
6
7mod continuous_time_bayesian_network;
8pub use continuous_time_bayesian_network::*;
9use rand::Rng;
10
11mod graphs;
12use std::fmt::Debug;
13
14pub use graphs::*;
15
16use crate::types::{Labels, Set};
17
18/// A trait for models with labelled variables.
19pub trait Labelled {
20    /// Returns the labels of the variables.
21    ///
22    /// # Returns
23    ///
24    /// A reference to the labels.
25    ///
26    fn labels(&self) -> &Labels;
27
28    /// Return the variable index for a given label.
29    ///
30    /// # Arguments
31    ///
32    /// * `x` - The label of the variable.
33    ///
34    /// # Panics
35    ///
36    /// * If the label is not in the map.
37    ///
38    /// # Returns
39    ///
40    /// The index of the variable.
41    ///
42    #[inline]
43    fn label_to_index(&self, x: &str) -> usize {
44        self.labels()
45            .get_index_of(x)
46            .unwrap_or_else(|| panic!("Variable `{x}` label does not exist."))
47    }
48
49    /// Return the label for a given variable index.
50    ///
51    /// # Arguments
52    ///
53    /// * `x` - The index of the variable.
54    ///
55    /// # Panics
56    ///
57    /// * If the index is out of bounds.
58    ///
59    /// # Returns
60    ///
61    /// The label of the variable.
62    ///
63    #[inline]
64    fn index_to_label(&self, x: usize) -> &str {
65        self.labels()
66            .get_index(x)
67            .unwrap_or_else(|| panic!("Variable `{x}` is out of bounds."))
68    }
69
70    /// Maps an index from this model to another model with the same label.
71    ///
72    /// # Arguments
73    ///
74    /// * `x` - The index in this model.
75    /// * `other` - The labels of the other model.
76    ///
77    /// # Panics
78    ///
79    /// * If the index is out of bounds.
80    /// * If the label does not exist in the other model.
81    ///
82    /// # Returns
83    ///
84    /// The index in the other model.
85    ///
86    #[inline]
87    fn index_to(&self, x: usize, other: &Labels) -> usize {
88        // Get the label of the variable in this model.
89        let label = self.index_to_label(x);
90        // Get the index of the variable in the other model.
91        other.get_index_of(label).unwrap_or_else(|| {
92            panic!("Variable `{label}` label does not exist in the other model.")
93        })
94    }
95
96    /// Maps a set of indices from this model to another model with the same labels.
97    ///
98    /// # Arguments
99    ///
100    /// * `x` - The set of indices in this model.
101    /// * `other` - The labels of the other model.
102    ///
103    /// # Panics
104    ///
105    /// * If any index is out of bounds.
106    /// * If any label does not exist in the other model.
107    ///
108    /// # Returns
109    ///
110    /// The set of indices in the other model.
111    ///
112    #[inline]
113    fn indices_to(&self, x: &Set<usize>, other: &Labels) -> Set<usize> {
114        x.iter().map(|&x| self.index_to(x, other)).collect()
115    }
116
117    /// Maps an index from another model to this model with the same label.
118    ///
119    /// # Arguments
120    ///
121    /// * `x` - The index in the other model.
122    /// * `other` - The labels of the other model.
123    ///
124    /// # Panics
125    ///
126    /// * If the index is out of bounds.
127    /// * If the label does not exist in this model.
128    ///
129    /// # Returns
130    ///
131    /// The index in this model.
132    ///
133    #[inline]
134    fn index_from(&self, x: usize, other: &Labels) -> usize {
135        // Get the label of the variable in the other model.
136        let label = other
137            .get_index(x)
138            .unwrap_or_else(|| panic!("Variable `{x}` is out of bounds in the other model."));
139        // Get the index of the variable in this model.
140        self.labels()
141            .get_index_of(label)
142            .unwrap_or_else(|| panic!("Variable `{label}` label does not exist."))
143    }
144
145    /// Maps a set of indices from another model to this model with the same labels.
146    ///
147    /// # Arguments
148    ///
149    /// * `x` - The set of indices in the other model.
150    /// * `other` - The labels of the other model.
151    ///
152    /// # Panics
153    ///
154    /// * If any index is out of bounds.
155    /// * If any label does not exist in this model.
156    ///
157    /// # Returns
158    ///
159    /// The set of indices in this model.
160    ///
161    #[inline]
162    fn indices_from(&self, x: &Set<usize>, other: &Labels) -> Set<usize> {
163        x.iter().map(|&x| self.index_from(x, other)).collect()
164    }
165}
166
167/// A trait for conditional probability distributions.
168pub trait CPD: Clone + Debug + Labelled + PartialEq + AbsDiffEq + RelativeEq {
169    /// The type of the support.
170    type Support;
171    /// The type of the parameters.
172    type Parameters;
173    /// The type of the sufficient statistics.
174    type Statistics;
175
176    /// Returns the labels of the conditioned variables.
177    ///
178    /// # Returns
179    ///
180    /// A reference to the conditioning labels.
181    ///
182    fn conditioning_labels(&self) -> &Labels;
183
184    /// Returns the parameters.
185    ///
186    /// # Returns
187    ///
188    /// A reference to the parameters.
189    ///
190    fn parameters(&self) -> &Self::Parameters;
191
192    /// Returns the parameters size.
193    ///
194    /// # Returns
195    ///
196    /// The parameters size.
197    ///
198    fn parameters_size(&self) -> usize;
199
200    /// Returns the sufficient statistics, if any.
201    ///
202    /// # Returns
203    ///
204    /// An option containing a reference to the sufficient statistics.
205    ///
206    fn sample_statistics(&self) -> Option<&Self::Statistics>;
207
208    /// Returns the log-likelihood of the fitted dataset, if any.
209    ///
210    /// # Returns
211    ///
212    /// An option containing the log-likelihood.
213    ///
214    fn sample_log_likelihood(&self) -> Option<f64>;
215
216    /// Returns the value of probability (mass or density) function for P(X = x | Z = z).
217    ///
218    /// # Arguments
219    ///
220    /// * `x` - The value of the conditioned variables.
221    /// * `z` - The value of the conditioning variables.
222    ///
223    /// # Returns
224    ///
225    /// The probability P(X = x | Z = z).
226    ///
227    fn pf(&self, x: &Self::Support, z: &Self::Support) -> f64;
228
229    /// Samples from the conditional distribution P(X | Z = z).
230    ///
231    /// # Arguments
232    ///
233    /// * `rng` - A mutable reference to a random number generator.
234    /// * `z` - The value of the conditioning variables.
235    ///
236    /// # Returns
237    ///
238    /// A sample from P(X | Z = z).
239    ///
240    fn sample<R: Rng>(&self, rng: &mut R, z: &Self::Support) -> Self::Support;
241}
242
243/// A trait for conditional intensity matrices.
244pub trait CIM: Clone + Debug + Labelled + PartialEq + AbsDiffEq + RelativeEq {
245    /// The type of the support.
246    type Support;
247    /// The type of the parameters.
248    type Parameters;
249    /// The type of the sufficient statistics.
250    type Statistics;
251
252    /// Returns the labels of the conditioned variables.
253    ///
254    /// # Returns
255    ///
256    /// A reference to the conditioning labels.
257    ///
258    fn conditioning_labels(&self) -> &Labels;
259
260    /// Returns the parameters.
261    ///
262    /// # Returns
263    ///
264    /// A reference to the parameters.
265    ///
266    fn parameters(&self) -> &Self::Parameters;
267
268    /// Returns the parameters size.
269    ///
270    /// # Returns
271    ///
272    /// The parameters size.
273    ///
274    fn parameters_size(&self) -> usize;
275
276    /// Returns the sufficient statistics, if any.
277    ///
278    /// # Returns
279    ///
280    /// An option containing a reference to the sufficient statistics.
281    ///
282    fn sample_statistics(&self) -> Option<&Self::Statistics>;
283
284    /// Returns the log-likelihood of the fitted dataset, if any.
285    ///
286    /// # Returns
287    ///
288    /// An option containing the log-likelihood.
289    ///
290    fn sample_log_likelihood(&self) -> Option<f64>;
291}
292
293/// A trait for potential functions.
294pub trait Phi:
295    Clone
296    + Debug
297    + Labelled
298    + PartialEq
299    + AbsDiffEq
300    + RelativeEq
301    + for<'a> MulAssign<&'a Self>
302    + for<'a> DivAssign<&'a Self>
303{
304    /// The type of the CPD.
305    type CPD;
306    /// The type of the parameters.
307    type Parameters;
308    /// The type of the evidence.
309    type Evidence;
310
311    /// Returns the parameters.
312    ///
313    /// # Returns
314    ///
315    /// A reference to the parameters.
316    ///
317    fn parameters(&self) -> &Self::Parameters;
318
319    /// Returns the parameters size.
320    ///
321    /// # Returns
322    ///
323    /// The parameters size.
324    ///
325    fn parameters_size(&self) -> usize;
326
327    /// Conditions the potential on a set of variables.
328    ///
329    /// # Arguments
330    ///
331    /// * `e` - A map from variable indices to their observed states.
332    ///
333    /// # Returns
334    ///
335    /// A new potential instance.
336    ///
337    fn condition(&self, e: &Self::Evidence) -> Self;
338
339    /// Marginalizes the potential over a set of variables.
340    ///
341    /// # Arguments
342    ///
343    /// * `x` - A set of variable indices to marginalize over.
344    ///
345    /// # Returns
346    ///
347    /// A new potential instance.
348    ///
349    fn marginalize(&self, x: &Set<usize>) -> Self;
350
351    /// Normalizes the potential.
352    ///
353    /// # Returns
354    ///
355    /// The normalized potential.
356    ///
357    fn normalize(&self) -> Self;
358
359    /// Converts a CPD P(X | Z) to a potential \phi(X \cup Z).
360    ///
361    /// # Arguments
362    ///
363    /// * `cpd` - The CPD to convert.
364    ///
365    /// # Returns
366    ///
367    /// The corresponding potential.
368    ///
369    fn from_cpd(cpd: Self::CPD) -> Self;
370
371    /// Converts a potential \phi(X \cup Z) to a CPD P(X | Z).
372    ///
373    /// # Arguments
374    ///
375    /// * `x` - The set of variables.
376    /// * `z` - The set of conditioning variables.
377    ///
378    /// # Returns
379    ///
380    /// The corresponding CPD.
381    ///
382    fn into_cpd(self, x: &Set<usize>, z: &Set<usize>) -> Self::CPD;
383}