causal_hub/estimators/parameters/mod.rs
1mod bayesian;
2pub use bayesian::*;
3
4mod expectation_maximization;
5pub use expectation_maximization::*;
6
7mod maximum_likelihood;
8pub use maximum_likelihood::*;
9
10mod sufficient_statistics;
11pub use sufficient_statistics::*;
12
13mod raw;
14pub use raw::*;
15use rayon::prelude::*;
16
17use crate::{
18 models::{BN, CIM, CPD, CTBN, DiGraph, Graph},
19 set,
20 types::Set,
21};
22
23/// A trait for sufficient statistics estimators.
24pub trait CSSEstimator<T> {
25 /// Fits the estimator to the dataset and returns the conditional sufficient statistics.
26 ///
27 /// # Arguments
28 ///
29 /// * `x` - The variable to fit the estimator to.
30 /// * `z` - The variables to condition on.
31 ///
32 /// # Returns
33 ///
34 /// The sufficient statistics.
35 ///
36 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
37}
38
39/// A trait for sufficient statistics estimators in parallel.
40pub trait ParCSSEstimator<T> {
41 /// Fits the estimator to the dataset and returns the conditional sufficient statistics in parallel.
42 ///
43 /// # Arguments
44 ///
45 /// * `x` - The variable to fit the estimator to.
46 /// * `z` - The variables to condition on.
47 ///
48 /// # Returns
49 ///
50 /// The sufficient statistics.
51 ///
52 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
53}
54
55/// A trait for conditional probability distribution estimators.
56pub trait CPDEstimator<T>
57where
58 T: CPD,
59{
60 /// Fits the estimator to the dataset and returns a CPD.
61 ///
62 /// # Arguments
63 ///
64 /// * `x` - The variable to fit the estimator to.
65 /// * `z` - The variables to condition on.
66 ///
67 /// # Returns
68 ///
69 /// The estimated CPD.
70 ///
71 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
72}
73
74/// A trait for conditional probability distribution estimators in parallel.
75pub trait ParCPDEstimator<T>
76where
77 T: CPD,
78{
79 /// Fits the estimator to the dataset and returns a CPD in parallel.
80 ///
81 /// # Arguments
82 ///
83 /// * `x` - The variable to fit the estimator to.
84 /// * `z` - The variables to condition on.
85 ///
86 /// # Returns
87 ///
88 /// The estimated CPD.
89 ///
90 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
91}
92
93/// A trait for Bayesian network estimators.
94pub trait BNEstimator<T> {
95 /// Fits the estimator to the dataset and returns a Bayesian network.
96 ///
97 /// # Arguments
98 ///
99 /// * `graph` - The graph to fit the estimator to.
100 ///
101 /// # Returns
102 ///
103 /// The estimated Bayesian network.
104 ///
105 fn fit(&self, graph: DiGraph) -> T;
106}
107
108/// Blanket implement for all BN estimators with a corresponding CPD estimator.
109impl<T, E> BNEstimator<T> for E
110where
111 T: BN,
112 T::CPD: CPD,
113 E: CPDEstimator<T::CPD>,
114{
115 fn fit(&self, graph: DiGraph) -> T {
116 // Fit the parameters of the distribution using the estimator.
117 let cpds: Vec<_> = graph
118 .vertices()
119 .into_iter()
120 .map(|i| {
121 let i = set![i];
122 self.fit(&i, &graph.parents(&i))
123 })
124 .collect();
125 // Construct the BN with the graph and the parameters.
126 T::new(graph, cpds)
127 }
128}
129
130/// A trait for parallel Bayesian network estimators.
131pub trait ParBNEstimator<T> {
132 /// Fits the estimator to the dataset and returns a Bayesian network in parallel.
133 ///
134 /// # Arguments
135 ///
136 /// * `graph` - The graph to fit the estimator to.
137 ///
138 /// # Returns
139 ///
140 /// The estimated Bayesian network.
141 ///
142 fn par_fit(&self, graph: DiGraph) -> T;
143}
144
145/// Blanket implement for all BN estimators with a corresponding CPD estimator.
146impl<T, E> ParBNEstimator<T> for E
147where
148 T: BN,
149 T::CPD: CPD + Send,
150 E: ParCPDEstimator<T::CPD> + Sync,
151{
152 fn par_fit(&self, graph: DiGraph) -> T {
153 // Fit the parameters of the distribution using the estimator.
154 let cpds: Vec<_> = graph
155 .vertices()
156 .into_par_iter()
157 .map(|i| {
158 let i = set![i];
159 self.par_fit(&i, &graph.parents(&i))
160 })
161 .collect();
162 // Construct the BN with the graph and the parameters.
163 T::new(graph, cpds)
164 }
165}
166
167/// A trait for conditional intensity matrix estimators.
168pub trait CIMEstimator<T>
169where
170 T: CIM,
171{
172 /// Fits the estimator to the dataset and returns a CIM.
173 ///
174 /// # Arguments
175 ///
176 /// * `x` - The variable to fit the estimator to.
177 /// * `z` - The variables to condition on.
178 ///
179 /// # Returns
180 ///
181 /// The estimated CIM.
182 ///
183 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
184}
185
186/// A trait for conditional intensity matrix estimators in parallel.
187pub trait ParCIMEstimator<T>
188where
189 T: CIM,
190{
191 /// Fits the estimator to the dataset and returns a CIM in parallel.
192 ///
193 /// # Arguments
194 ///
195 /// * `x` - The variable to fit the estimator to.
196 /// * `z` - The variables to condition on.
197 ///
198 /// # Returns
199 ///
200 /// The estimated CIM.
201 ///
202 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> T;
203}
204
205/// A trait for CTBN estimators.
206pub trait CTBNEstimator<T> {
207 /// Fits the estimator to the trajectory and returns a CTBN.
208 ///
209 /// # Arguments
210 ///
211 /// * `graph` - The graph to fit the estimator to.
212 ///
213 /// # Returns
214 ///
215 /// The estimated CTBN.
216 ///
217 fn fit(&self, graph: DiGraph) -> T;
218}
219
220/// Blanket implement for all CTBN estimators with a corresponding CIM estimator.
221impl<T, E> CTBNEstimator<T> for E
222where
223 T: CTBN,
224 T::CIM: CIM,
225 E: CIMEstimator<T::CIM>,
226{
227 fn fit(&self, graph: DiGraph) -> T {
228 // Fit the parameters of the distribution using the estimator.
229 let cims: Vec<_> = graph
230 .vertices()
231 .into_iter()
232 .map(|i| {
233 let i = set![i];
234 self.fit(&i, &graph.parents(&i))
235 })
236 .collect();
237 // Construct the CTBN with the graph and the parameters.
238 T::new(graph, cims)
239 }
240}
241
242/// A trait for parallel CTBN estimators.
243pub trait ParCTBNEstimator<T> {
244 /// Fits the estimator to the trajectory and returns a CTBN in parallel.
245 ///
246 /// # Arguments
247 ///
248 /// * `graph` - The graph to fit the estimator to.
249 ///
250 /// # Returns
251 ///
252 /// The estimated CTBN.
253 ///
254 fn par_fit(&self, graph: DiGraph) -> T;
255}
256
257/// Blanket implement for all CTBN estimators with a corresponding CIM estimator.
258impl<T, E> ParCTBNEstimator<T> for E
259where
260 T: CTBN,
261 T::CIM: CIM + Send,
262 E: ParCIMEstimator<T::CIM> + Sync,
263{
264 fn par_fit(&self, graph: DiGraph) -> T {
265 // Fit the parameters of the distribution using the estimator.
266 let cims: Vec<_> = graph
267 .vertices()
268 .into_par_iter()
269 .map(|i| {
270 let i = set![i];
271 self.par_fit(&i, &graph.parents(&i))
272 })
273 .collect();
274 // Construct the CTBN with the graph and the parameters.
275 T::new(graph, cims)
276 }
277}