causal_hub/estimators/parameters/
mod.rs

1mod bayesian;
2pub use bayesian::*;
3
4mod expectation_maximization;
5pub use expectation_maximization::*;
6
7mod maximum_likelihood;
8pub use maximum_likelihood::*;
9
10mod sufficient_statistics;
11pub use sufficient_statistics::*;
12
13mod raw;
14pub use raw::*;
15use rayon::prelude::*;
16
17use crate::{
18    models::{BN, CIM, CPD, CTBN, DiGraph, Graph},
19    set,
20    types::Set,
21};
22
23/// A trait for sufficient statistics estimators.
24pub trait CSSEstimator<T> {
25    /// Fits the estimator to the dataset and returns the conditional sufficient statistics.
26    ///
27    /// # Arguments
28    ///
29    /// * `x` - The variable to fit the estimator to.
30    /// * `z` - The variables to condition on.
31    ///
32    /// # Returns
33    ///
34    /// The sufficient statistics.
35    ///
36    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
37}
38
39/// A trait for sufficient statistics estimators in parallel.
40pub trait ParCSSEstimator<T> {
41    /// Fits the estimator to the dataset and returns the conditional sufficient statistics in parallel.
42    ///
43    /// # Arguments
44    ///
45    /// * `x` - The variable to fit the estimator to.
46    /// * `z` - The variables to condition on.
47    ///
48    /// # Returns
49    ///
50    /// The sufficient statistics.
51    ///
52    fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
53}
54
55/// A trait for conditional probability distribution estimators.
56pub trait CPDEstimator<T>
57where
58    T: CPD,
59{
60    /// Fits the estimator to the dataset and returns a CPD.
61    ///
62    /// # Arguments
63    ///
64    /// * `x` - The variable to fit the estimator to.
65    /// * `z` - The variables to condition on.
66    ///
67    /// # Returns
68    ///
69    /// The estimated CPD.
70    ///
71    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
72}
73
74/// A trait for conditional probability distribution estimators in parallel.
75pub trait ParCPDEstimator<T>
76where
77    T: CPD,
78{
79    /// Fits the estimator to the dataset and returns a CPD in parallel.
80    ///
81    /// # Arguments
82    ///
83    /// * `x` - The variable to fit the estimator to.
84    /// * `z` - The variables to condition on.
85    ///
86    /// # Returns
87    ///
88    /// The estimated CPD.
89    ///
90    fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
91}
92
93/// A trait for Bayesian network estimators.
94pub trait BNEstimator<T> {
95    /// Fits the estimator to the dataset and returns a Bayesian network.
96    ///
97    /// # Arguments
98    ///
99    /// * `graph` - The graph to fit the estimator to.
100    ///
101    /// # Returns
102    ///
103    /// The estimated Bayesian network.
104    ///
105    fn fit(&self, graph: DiGraph) -> T;
106}
107
108/// Blanket implement for all BN estimators with a corresponding CPD estimator.
109impl<T, E> BNEstimator<T> for E
110where
111    T: BN,
112    T::CPD: CPD,
113    E: CPDEstimator<T::CPD>,
114{
115    fn fit(&self, graph: DiGraph) -> T {
116        // Fit the parameters of the distribution using the estimator.
117        let cpds: Vec<_> = graph
118            .vertices()
119            .into_iter()
120            .map(|i| {
121                let i = set![i];
122                self.fit(&i, &graph.parents(&i))
123            })
124            .collect();
125        // Construct the BN with the graph and the parameters.
126        T::new(graph, cpds)
127    }
128}
129
130/// A trait for parallel Bayesian network estimators.
131pub trait ParBNEstimator<T> {
132    /// Fits the estimator to the dataset and returns a Bayesian network in parallel.
133    ///
134    /// # Arguments
135    ///
136    /// * `graph` - The graph to fit the estimator to.
137    ///
138    /// # Returns
139    ///
140    /// The estimated Bayesian network.
141    ///
142    fn par_fit(&self, graph: DiGraph) -> T;
143}
144
145/// Blanket implement for all BN estimators with a corresponding CPD estimator.
146impl<T, E> ParBNEstimator<T> for E
147where
148    T: BN,
149    T::CPD: CPD + Send,
150    E: ParCPDEstimator<T::CPD> + Sync,
151{
152    fn par_fit(&self, graph: DiGraph) -> T {
153        // Fit the parameters of the distribution using the estimator.
154        let cpds: Vec<_> = graph
155            .vertices()
156            .into_par_iter()
157            .map(|i| {
158                let i = set![i];
159                self.par_fit(&i, &graph.parents(&i))
160            })
161            .collect();
162        // Construct the BN with the graph and the parameters.
163        T::new(graph, cpds)
164    }
165}
166
167/// A trait for conditional intensity matrix estimators.
168pub trait CIMEstimator<T>
169where
170    T: CIM,
171{
172    /// Fits the estimator to the dataset and returns a CIM.
173    ///
174    /// # Arguments
175    ///
176    /// * `x` - The variable to fit the estimator to.
177    /// * `z` - The variables to condition on.
178    ///
179    /// # Returns
180    ///
181    /// The estimated CIM.
182    ///
183    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
184}
185
186/// A trait for conditional intensity matrix estimators in parallel.
187pub trait ParCIMEstimator<T>
188where
189    T: CIM,
190{
191    /// Fits the estimator to the dataset and returns a CIM in parallel.
192    ///
193    /// # Arguments
194    ///
195    /// * `x` - The variable to fit the estimator to.
196    /// * `z` - The variables to condition on.
197    ///
198    /// # Returns
199    ///
200    /// The estimated CIM.
201    ///
202    fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
203}
204
205/// A trait for CTBN estimators.
206pub trait CTBNEstimator<T> {
207    /// Fits the estimator to the trajectory and returns a CTBN.
208    ///
209    /// # Arguments
210    ///
211    /// * `graph` - The graph to fit the estimator to.
212    ///
213    /// # Returns
214    ///
215    /// The estimated CTBN.
216    ///
217    fn fit(&self, graph: DiGraph) -> T;
218}
219
220/// Blanket implement for all CTBN estimators with a corresponding CIM estimator.
221impl<T, E> CTBNEstimator<T> for E
222where
223    T: CTBN,
224    T::CIM: CIM,
225    E: CIMEstimator<T::CIM>,
226{
227    fn fit(&self, graph: DiGraph) -> T {
228        // Fit the parameters of the distribution using the estimator.
229        let cims: Vec<_> = graph
230            .vertices()
231            .into_iter()
232            .map(|i| {
233                let i = set![i];
234                self.fit(&i, &graph.parents(&i))
235            })
236            .collect();
237        // Construct the CTBN with the graph and the parameters.
238        T::new(graph, cims)
239    }
240}
241
242/// A trait for parallel CTBN estimators.
243pub trait ParCTBNEstimator<T> {
244    /// Fits the estimator to the trajectory and returns a CTBN in parallel.
245    ///
246    /// # Arguments
247    ///
248    /// * `graph` - The graph to fit the estimator to.
249    ///
250    /// # Returns
251    ///
252    /// The estimated CTBN.
253    ///
254    fn par_fit(&self, graph: DiGraph) -> T;
255}
256
257/// Blanket implement for all CTBN estimators with a corresponding CIM estimator.
258impl<T, E> ParCTBNEstimator<T> for E
259where
260    T: CTBN,
261    T::CIM: CIM + Send,
262    E: ParCIMEstimator<T::CIM> + Sync,
263{
264    fn par_fit(&self, graph: DiGraph) -> T {
265        // Fit the parameters of the distribution using the estimator.
266        let cims: Vec<_> = graph
267            .vertices()
268            .into_par_iter()
269            .map(|i| {
270                let i = set![i];
271                self.par_fit(&i, &graph.parents(&i))
272            })
273            .collect();
274        // Construct the CTBN with the graph and the parameters.
275        T::new(graph, cims)
276    }
277}