causal_hub/estimators/structures/continuous_time_hill_climbing.rs
1use itertools::Itertools;
2use rayon::iter::{IntoParallelIterator, ParallelIterator};
3
4use crate::{
5 estimators::{CIMEstimator, PK},
6 models::{CIM, CatCIM, DiGraph, Graph, Labelled},
7 set,
8 types::{Labels, Set},
9};
10
11/// A trait for scoring criteria used in score-based structure learning.
12pub trait ScoringCriterion {
13 /// Computes the score for a given variable and its conditioning set.
14 ///
15 /// # Arguments
16 ///
17 /// * `x` - The variable to score.
18 /// * `z` - The conditioning set.
19 ///
20 /// # Returns
21 ///
22 /// The computed score.
23 ///
24 fn call(&self, x: &Set<usize>, z: &Set<usize>) -> f64;
25}
26
/// The Bayesian Information Criterion (BIC).
///
/// Borrows an estimator, which is used to fit the model being scored.
pub struct BIC<'a, E> {
    /// The estimator used to fit the models being scored.
    estimator: &'a E,
}
31
32impl<'a, E> BIC<'a, E> {
33 /// Creates a new BIC instance.
34 ///
35 /// # Arguments
36 ///
37 /// * `estimator` - A reference to the estimator.
38 ///
39 /// # Returns
40 ///
41 /// A new `BIC` instance.
42 ///
43 #[inline]
44 pub const fn new(estimator: &'a E) -> Self {
45 Self { estimator }
46 }
47}
48
49impl<'a, E> Labelled for BIC<'a, E>
50where
51 E: Labelled,
52{
53 #[inline]
54 fn labels(&self) -> &Labels {
55 self.estimator.labels()
56 }
57}
58
59impl<E> ScoringCriterion for BIC<'_, E>
60where
61 E: CIMEstimator<CatCIM>,
62{
63 #[inline]
64 fn call(&self, x: &Set<usize>, z: &Set<usize>) -> f64 {
65 // Compute the intensity matrices for the sets.
66 let q_xz = self.estimator.fit(x, z);
67 // Get the sample size.
68 let n = q_xz
69 .sample_statistics()
70 .map(|s| s.sample_size())
71 .expect("Failed to get the sample size.");
72 // Get the log-likelihood.
73 let ll = q_xz
74 .sample_log_likelihood()
75 .expect("Failed to compute the log-likelihood.");
76 // Get the number of parameters.
77 let k = q_xz.parameters_size() as f64;
78
79 // Compute the BIC.
80 ll - 0.5 * k * f64::ln(n)
81 }
82}
83
84/// The hill climbing algorithm for structure learning in CTBNs.
85#[derive(Clone, Debug)]
86pub struct CTHC<'a, S> {
87 initial_graph: &'a DiGraph,
88 score: &'a S,
89 max_parents: Option<usize>,
90 prior_knowledge: Option<&'a PK>,
91}
92
93impl<'a, S> CTHC<'a, S>
94where
95 S: ScoringCriterion + Labelled,
96{
97 /// Creates a new continuous time hill climbing instance.
98 ///
99 /// # Arguments
100 ///
101 /// * `initial_graph` - The initial directed graph.
102 /// * `score` - The scoring criterion to use.
103 ///
104 /// # Returns
105 ///
106 /// A new `ContinuousTimeHillClimbing` instance.
107 ///
108 #[inline]
109 pub fn new(initial_graph: &'a DiGraph, score: &'a S) -> Self {
110 // Assert labels of the initial graph and the estimator are the same.
111 assert_eq!(
112 initial_graph.labels(),
113 score.labels(),
114 "Labels of initial graph and estimator must be the same: \n\
115 \t expected: {:?}, \n\
116 \t found: {:?}.",
117 initial_graph.labels(),
118 score.labels()
119 );
120
121 Self {
122 initial_graph,
123 score,
124 max_parents: None,
125 prior_knowledge: None,
126 }
127 }
128
129 /// Sets the maximum number of parents for each vertex.
130 ///
131 /// # Arguments
132 ///
133 /// * `max_parents` - The maximum number of parents for each vertex.
134 ///
135 /// # Returns
136 ///
137 /// A mutable reference to the current instance.
138 ///
139 #[inline]
140 pub const fn with_max_parents(mut self, max_parents: usize) -> Self {
141 self.max_parents = Some(max_parents);
142 self
143 }
144
145 /// Sets the prior knowledge for the algorithm.
146 ///
147 /// # Arguments
148 ///
149 /// * `prior_knowledge` - The prior knowledge to use.
150 ///
151 /// # Returns
152 ///
153 /// A mutable reference to the current instance.
154 ///
155 #[inline]
156 pub fn with_prior_knowledge(mut self, prior_knowledge: &'a PK) -> Self {
157 // Assert labels of prior knowledge and initial graph are the same.
158 assert_eq!(
159 self.initial_graph.labels(),
160 prior_knowledge.labels(),
161 "Labels of initial graph and prior knowledge must be the same: \n\
162 \t expected: {:?}, \n\
163 \t found: {:?}.",
164 self.initial_graph.labels(),
165 prior_knowledge.labels()
166 );
167 // Assert prior knowledge is consistent with initial graph.
168 self.initial_graph
169 .vertices()
170 .into_iter()
171 .permutations(2)
172 .for_each(|edge| {
173 // Get the edge indices.
174 let (i, j) = (edge[0], edge[1]);
175 // Assert edge must be either present and not forbidden ...
176 if self.initial_graph.has_edge(i, j) {
177 assert!(
178 !prior_knowledge.is_forbidden(i, j),
179 "Initial graph contains forbidden edge ({i}, {j})."
180 );
181 // ... or absent and not required.
182 } else {
183 assert!(
184 !prior_knowledge.is_required(i, j),
185 "Initial graph does not contain required edge ({i}, {j})."
186 );
187 }
188 });
189 // Set prior knowledge.
190 self.prior_knowledge = Some(prior_knowledge);
191 self
192 }
193
194 /// Execute the CTHC algorithm.
195 ///
196 /// # Returns
197 ///
198 /// The fitted graph.
199 ///
200 pub fn fit(&self) -> DiGraph {
201 // Clone the initial graph.
202 let mut graph = DiGraph::empty(self.initial_graph.labels());
203
204 // For each vertex in the graph ...
205 for i in self.initial_graph.vertices() {
206 // Initialize the previous score to negative infinity.
207 let mut prev_score = f64::NEG_INFINITY;
208
209 // Set the initial parent set as the current parent set.
210 let mut curr_pa = self.initial_graph.parents(&set![i]);
211 // Compute the score of the current parent set.
212 let mut curr_score = self.score.call(&set![i], &curr_pa);
213
214 // While the score of the current parent set is higher than the previous score ...
215 while prev_score < curr_score {
216 // Set the previous score to the score of the current parent set.
217 prev_score = curr_score;
218
219 // Get the candidate parent sets by adding ...
220 let poss_pa = {
221 // Clone the current parent set.
222 [curr_pa.clone()].into_iter().filter(|curr_pa|
223 // Check if maximum parents has been reached.
224 if let Some(max_parents) = self.max_parents {
225 curr_pa.len() < max_parents
226 } else {
227 true
228 }
229 ).flat_map(|curr_pa| {
230 // Get the vertices that are not in the current parent set.
231 self.initial_graph
232 .vertices()
233 .into_iter()
234 .filter_map(move |j| {
235 if i != j {
236 // If the vertex is not in the current parent set ...
237 if let Err(p_j) = curr_pa.binary_search(&j) {
238 // Clone the current parent set.
239 let mut curr_pa = curr_pa.clone();
240 // Insert the vertex in order.
241 curr_pa.shift_insert(p_j, j);
242 // Return it as a candidate for addition.
243 return Some(curr_pa);
244 }
245 }
246 // Otherwise, the vertex is already present.
247 None
248 })
249 })
250 }
251 // ... or removing vertices.
252 .chain({
253 // Clone the current parent set.
254 let curr_pa = curr_pa.clone();
255 // Get the size of the candidate subset, avoid underflow.
256 let k = curr_pa.len().saturating_sub(1);
257 // Generate all the k-sized subsets.
258 curr_pa.into_iter().combinations(k).map(Set::from_iter)
259 });
260
261 // For each candidate parent sets ...
262 for next_pa in poss_pa {
263 // Compute the score of the candidate parent set.
264 let next_score = self.score.call(&set![i], &next_pa);
265 // If the score of the candidate parent set is higher ...
266 if curr_score < next_score {
267 // Update the current parent set to the candidate parent set.
268 curr_pa = next_pa;
269 // Update the score of the current parent set.
270 curr_score = next_score;
271 }
272 }
273 }
274
275 // Set the current parent set.
276 for j in curr_pa {
277 // Add an edge from vertex `j` to vertex `i`.
278 graph.add_edge(j, i);
279 }
280 }
281
282 // Return the final graph.
283 graph
284 }
285}
286
287impl<'a, S> CTHC<'a, S>
288where
289 S: ScoringCriterion + Sync,
290{
291 /// Execute the CTHC algorithm in parallel.
292 ///
293 /// # Returns
294 ///
295 /// The fitted graph.
296 ///
297 pub fn par_fit(&self) -> DiGraph {
298 // For each vertex in the graph ...
299 let parents: Vec<_> = self
300 .initial_graph
301 .vertices()
302 .into_par_iter()
303 .map(|i| {
304 // Initialize the previous score to negative infinity.
305 let mut prev_score = f64::NEG_INFINITY;
306
307 // Set the initial parent set as the current parent set.
308 let mut curr_pa = self.initial_graph.parents(&set![i]);
309 // Compute the score of the current parent set.
310 let mut curr_score = self.score.call(&set![i], &curr_pa);
311
312 // While the score of the current parent set is higher than the previous score ...
313 while prev_score < curr_score {
314 // Set the previous score to the score of the current parent set.
315 prev_score = curr_score;
316
317 // Get the candidate parent sets by adding ...
318 let poss_pa: Vec<_> = {
319 // Clone the current parent set.
320 [curr_pa.clone()].into_iter().filter(|curr_pa|
321 // Check if maximum parents has been reached.
322 if let Some(max_parents) = self.max_parents {
323 curr_pa.len() < max_parents
324 } else {
325 true
326 }
327 ).flat_map(|curr_pa| {
328 // Get the vertices that are not in the current parent set.
329 self.initial_graph
330 .vertices()
331 .into_iter()
332 .filter_map(move |j| {
333 if i != j {
334 // If the vertex is not in the current parent set ...
335 if let Err(p_j) = curr_pa.binary_search(&j) {
336 // Clone the current parent set.
337 let mut curr_pa = curr_pa.clone();
338 // Insert the vertex in order.
339 curr_pa.shift_insert(p_j, j);
340 // Return it as a candidate for addition.
341 return Some(curr_pa);
342 }
343 }
344 // Otherwise, the vertex is already present.
345 None
346 })
347 })
348 }
349 // ... or removing vertices.
350 .chain({
351 // Clone the current parent set.
352 let curr_pa = curr_pa.clone();
353 // Get the size of the candidate subset, avoid underflow.
354 let k = curr_pa.len().saturating_sub(1);
355 // Generate all the k-sized subsets.
356 curr_pa.into_iter().combinations(k).map(Set::from_iter)
357 })
358 // Collect to allow for parallel iteration.
359 .collect();
360
361 // For each candidate parent sets ...
362 if let Some((next_score, next_pa)) = poss_pa
363 .into_par_iter()
364 // Compute the score of the candidate parent set in parallel.
365 .map(|next_pa| (self.score.call(&set![i], &next_pa), next_pa))
366 // Get the one with the highest score in parallel.
367 .max_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap())
368 {
369 // If the score of the candidate parent set is higher ...
370 if curr_score < next_score {
371 // Update the current parent set to the candidate parent set.
372 curr_pa = next_pa;
373 // Update the score of the current parent set.
374 curr_score = next_score;
375 }
376 }
377 }
378
379 // Return the current parent set.
380 curr_pa
381 })
382 .collect();
383
384 // Clone the initial graph.
385 let mut graph = DiGraph::empty(self.initial_graph.labels());
386
387 // Set the current parent set.
388 for (i, curr_pa) in parents.into_iter().enumerate() {
389 for j in curr_pa {
390 // Add an edge from vertex `j` to vertex `i`.
391 graph.add_edge(j, i);
392 }
393 }
394
395 // Return the final graph.
396 graph
397 }
398}