causal_hub/estimators/structures/
continuous_time_hill_climbing.rs

1use itertools::Itertools;
2use rayon::iter::{IntoParallelIterator, ParallelIterator};
3
4use crate::{
5    estimators::{CIMEstimator, PK},
6    models::{CIM, CatCIM, DiGraph, Graph, Labelled},
7    set,
8    types::{Labels, Set},
9};
10
11/// A trait for scoring criteria used in score-based structure learning.
12pub trait ScoringCriterion {
13    /// Computes the score for a given variable and its conditioning set.
14    ///
15    /// # Arguments
16    ///
17    /// * `x` - The variable to score.
18    /// * `z` - The conditioning set.
19    ///
20    /// # Returns
21    ///
22    /// The computed score.
23    ///
24    fn call(&self, x: &Set<usize>, z: &Set<usize>) -> f64;
25}
26
/// The Bayesian Information Criterion (BIC).
///
/// The criterion only borrows its estimator, so cloning it is cheap. The
/// derived `Clone` and `Debug` impls are available whenever `E` implements
/// them.
#[derive(Clone, Debug)]
pub struct BIC<'a, E> {
    /// The borrowed estimator.
    estimator: &'a E,
}
31
32impl<'a, E> BIC<'a, E> {
33    /// Creates a new BIC instance.
34    ///
35    /// # Arguments
36    ///
37    /// * `estimator` - A reference to the estimator.
38    ///
39    /// # Returns
40    ///
41    /// A new `BIC` instance.
42    ///
43    #[inline]
44    pub const fn new(estimator: &'a E) -> Self {
45        Self { estimator }
46    }
47}
48
49impl<'a, E> Labelled for BIC<'a, E>
50where
51    E: Labelled,
52{
53    #[inline]
54    fn labels(&self) -> &Labels {
55        self.estimator.labels()
56    }
57}
58
59impl<E> ScoringCriterion for BIC<'_, E>
60where
61    E: CIMEstimator<CatCIM>,
62{
63    #[inline]
64    fn call(&self, x: &Set<usize>, z: &Set<usize>) -> f64 {
65        // Compute the intensity matrices for the sets.
66        let q_xz = self.estimator.fit(x, z);
67        // Get the sample size.
68        let n = q_xz
69            .sample_statistics()
70            .map(|s| s.sample_size())
71            .expect("Failed to get the sample size.");
72        // Get the log-likelihood.
73        let ll = q_xz
74            .sample_log_likelihood()
75            .expect("Failed to compute the log-likelihood.");
76        // Get the number of parameters.
77        let k = q_xz.parameters_size() as f64;
78
79        // Compute the BIC.
80        ll - 0.5 * k * f64::ln(n)
81    }
82}
83
84/// The hill climbing algorithm for structure learning in CTBNs.
85#[derive(Clone, Debug)]
86pub struct CTHC<'a, S> {
87    initial_graph: &'a DiGraph,
88    score: &'a S,
89    max_parents: Option<usize>,
90    prior_knowledge: Option<&'a PK>,
91}
92
93impl<'a, S> CTHC<'a, S>
94where
95    S: ScoringCriterion + Labelled,
96{
97    /// Creates a new continuous time hill climbing instance.
98    ///
99    /// # Arguments
100    ///
101    /// * `initial_graph` - The initial directed graph.
102    /// * `score` - The scoring criterion to use.
103    ///
104    /// # Returns
105    ///
106    /// A new `ContinuousTimeHillClimbing` instance.
107    ///
108    #[inline]
109    pub fn new(initial_graph: &'a DiGraph, score: &'a S) -> Self {
110        // Assert labels of the initial graph and the estimator are the same.
111        assert_eq!(
112            initial_graph.labels(),
113            score.labels(),
114            "Labels of initial graph and estimator must be the same: \n\
115            \t expected:    {:?}, \n\
116            \t found:       {:?}.",
117            initial_graph.labels(),
118            score.labels()
119        );
120
121        Self {
122            initial_graph,
123            score,
124            max_parents: None,
125            prior_knowledge: None,
126        }
127    }
128
129    /// Sets the maximum number of parents for each vertex.
130    ///
131    /// # Arguments
132    ///
133    /// * `max_parents` - The maximum number of parents for each vertex.
134    ///
135    /// # Returns
136    ///
137    /// A mutable reference to the current instance.
138    ///
139    #[inline]
140    pub const fn with_max_parents(mut self, max_parents: usize) -> Self {
141        self.max_parents = Some(max_parents);
142        self
143    }
144
145    /// Sets the prior knowledge for the algorithm.
146    ///
147    /// # Arguments
148    ///
149    /// * `prior_knowledge` - The prior knowledge to use.
150    ///
151    /// # Returns
152    ///
153    /// A mutable reference to the current instance.
154    ///
155    #[inline]
156    pub fn with_prior_knowledge(mut self, prior_knowledge: &'a PK) -> Self {
157        // Assert labels of prior knowledge and initial graph are the same.
158        assert_eq!(
159            self.initial_graph.labels(),
160            prior_knowledge.labels(),
161            "Labels of initial graph and prior knowledge must be the same: \n\
162            \t expected:    {:?}, \n\
163            \t found:       {:?}.",
164            self.initial_graph.labels(),
165            prior_knowledge.labels()
166        );
167        // Assert prior knowledge is consistent with initial graph.
168        self.initial_graph
169            .vertices()
170            .into_iter()
171            .permutations(2)
172            .for_each(|edge| {
173                // Get the edge indices.
174                let (i, j) = (edge[0], edge[1]);
175                // Assert edge must be either present and not forbidden ...
176                if self.initial_graph.has_edge(i, j) {
177                    assert!(
178                        !prior_knowledge.is_forbidden(i, j),
179                        "Initial graph contains forbidden edge ({i}, {j})."
180                    );
181                // ... or absent and not required.
182                } else {
183                    assert!(
184                        !prior_knowledge.is_required(i, j),
185                        "Initial graph does not contain required edge ({i}, {j})."
186                    );
187                }
188            });
189        // Set prior knowledge.
190        self.prior_knowledge = Some(prior_knowledge);
191        self
192    }
193
194    /// Execute the CTHC algorithm.
195    ///
196    /// # Returns
197    ///
198    /// The fitted graph.
199    ///
200    pub fn fit(&self) -> DiGraph {
201        // Clone the initial graph.
202        let mut graph = DiGraph::empty(self.initial_graph.labels());
203
204        // For each vertex in the graph ...
205        for i in self.initial_graph.vertices() {
206            // Initialize the previous score to negative infinity.
207            let mut prev_score = f64::NEG_INFINITY;
208
209            // Set the initial parent set as the current parent set.
210            let mut curr_pa = self.initial_graph.parents(&set![i]);
211            // Compute the score of the current parent set.
212            let mut curr_score = self.score.call(&set![i], &curr_pa);
213
214            // While the score of the current parent set is higher than the previous score ...
215            while prev_score < curr_score {
216                // Set the previous score to the score of the current parent set.
217                prev_score = curr_score;
218
219                // Get the candidate parent sets by adding ...
220                let poss_pa = {
221                    // Clone the current parent set.
222                    [curr_pa.clone()].into_iter().filter(|curr_pa|
223                        // Check if maximum parents has been reached.
224                        if let Some(max_parents) = self.max_parents {
225                            curr_pa.len() < max_parents
226                        } else {
227                            true
228                        }
229                    ).flat_map(|curr_pa| {
230                        // Get the vertices that are not in the current parent set.
231                        self.initial_graph
232                            .vertices()
233                            .into_iter()
234                            .filter_map(move |j| {
235                                if i != j {
236                                    // If the vertex is not in the current parent set ...
237                                    if let Err(p_j) = curr_pa.binary_search(&j) {
238                                        // Clone the current parent set.
239                                        let mut curr_pa = curr_pa.clone();
240                                        // Insert the vertex in order.
241                                        curr_pa.shift_insert(p_j, j);
242                                        // Return it as a candidate for addition.
243                                        return Some(curr_pa);
244                                    }
245                                }
246                                // Otherwise, the vertex is already present.
247                                None
248                            })
249                    })
250                }
251                // ... or removing vertices.
252                .chain({
253                    // Clone the current parent set.
254                    let curr_pa = curr_pa.clone();
255                    // Get the size of the candidate subset, avoid underflow.
256                    let k = curr_pa.len().saturating_sub(1);
257                    // Generate all the k-sized subsets.
258                    curr_pa.into_iter().combinations(k).map(Set::from_iter)
259                });
260
261                // For each candidate parent sets ...
262                for next_pa in poss_pa {
263                    // Compute the score of the candidate parent set.
264                    let next_score = self.score.call(&set![i], &next_pa);
265                    // If the score of the candidate parent set is higher ...
266                    if curr_score < next_score {
267                        // Update the current parent set to the candidate parent set.
268                        curr_pa = next_pa;
269                        // Update the score of the current parent set.
270                        curr_score = next_score;
271                    }
272                }
273            }
274
275            // Set the current parent set.
276            for j in curr_pa {
277                // Add an edge from vertex `j` to vertex `i`.
278                graph.add_edge(j, i);
279            }
280        }
281
282        // Return the final graph.
283        graph
284    }
285}
286
287impl<'a, S> CTHC<'a, S>
288where
289    S: ScoringCriterion + Sync,
290{
291    /// Execute the CTHC algorithm in parallel.
292    ///
293    /// # Returns
294    ///
295    /// The fitted graph.
296    ///
297    pub fn par_fit(&self) -> DiGraph {
298        // For each vertex in the graph ...
299        let parents: Vec<_> = self
300            .initial_graph
301            .vertices()
302            .into_par_iter()
303            .map(|i| {
304                // Initialize the previous score to negative infinity.
305                let mut prev_score = f64::NEG_INFINITY;
306
307                // Set the initial parent set as the current parent set.
308                let mut curr_pa = self.initial_graph.parents(&set![i]);
309                // Compute the score of the current parent set.
310                let mut curr_score = self.score.call(&set![i], &curr_pa);
311
312                // While the score of the current parent set is higher than the previous score ...
313                while prev_score < curr_score {
314                    // Set the previous score to the score of the current parent set.
315                    prev_score = curr_score;
316
317                    // Get the candidate parent sets by adding ...
318                    let poss_pa: Vec<_> = {
319                        // Clone the current parent set.
320                        [curr_pa.clone()].into_iter().filter(|curr_pa|
321                            // Check if maximum parents has been reached.
322                            if let Some(max_parents) = self.max_parents {
323                                curr_pa.len() < max_parents
324                            } else {
325                                true
326                            }
327                        ).flat_map(|curr_pa| {
328                            // Get the vertices that are not in the current parent set.
329                            self.initial_graph
330                                .vertices()
331                                .into_iter()
332                                .filter_map(move |j| {
333                                    if i != j {
334                                        // If the vertex is not in the current parent set ...
335                                        if let Err(p_j) = curr_pa.binary_search(&j) {
336                                            // Clone the current parent set.
337                                            let mut curr_pa = curr_pa.clone();
338                                            // Insert the vertex in order.
339                                            curr_pa.shift_insert(p_j, j);
340                                            // Return it as a candidate for addition.
341                                            return Some(curr_pa);
342                                        }
343                                    }
344                                    // Otherwise, the vertex is already present.
345                                    None
346                                })
347                        })
348                    }
349                    // ... or removing vertices.
350                    .chain({
351                        // Clone the current parent set.
352                        let curr_pa = curr_pa.clone();
353                        // Get the size of the candidate subset, avoid underflow.
354                        let k = curr_pa.len().saturating_sub(1);
355                        // Generate all the k-sized subsets.
356                        curr_pa.into_iter().combinations(k).map(Set::from_iter)
357                    })
358                    // Collect to allow for parallel iteration.
359                    .collect();
360
361                    // For each candidate parent sets ...
362                    if let Some((next_score, next_pa)) = poss_pa
363                        .into_par_iter()
364                        // Compute the score of the candidate parent set in parallel.
365                        .map(|next_pa| (self.score.call(&set![i], &next_pa), next_pa))
366                        // Get the one with the highest score in parallel.
367                        .max_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap())
368                    {
369                        // If the score of the candidate parent set is higher ...
370                        if curr_score < next_score {
371                            // Update the current parent set to the candidate parent set.
372                            curr_pa = next_pa;
373                            // Update the score of the current parent set.
374                            curr_score = next_score;
375                        }
376                    }
377                }
378
379                // Return the current parent set.
380                curr_pa
381            })
382            .collect();
383
384        // Clone the initial graph.
385        let mut graph = DiGraph::empty(self.initial_graph.labels());
386
387        // Set the current parent set.
388        for (i, curr_pa) in parents.into_iter().enumerate() {
389            for j in curr_pa {
390                // Add an edge from vertex `j` to vertex `i`.
391                graph.add_edge(j, i);
392            }
393        }
394
395        // Return the final graph.
396        graph
397    }
398}