// reCTBN/structure_learning/score_function.rs

1//! Module for score based algorithms containing score functions algorithms like Log Likelihood, BIC, etc...
2
3use std::collections::BTreeSet;
4
5use ndarray::prelude::*;
6use statrs::function::gamma;
7
8use crate::{parameter_learning, params, process, tools};
9use log::debug;
10
/// It defines the required methods for a decomposable ScoreFunction functor over a `NetworkProcess`
pub trait ScoreFunction: Sync {
    /// Compute the score function for a node and its parent set given a dataset.
    ///
    /// # Arguments
    ///
    /// * `net`: `NetworkProcess` object.
    /// * `node`: Node target for the decomposable score.
    /// * `parent_set`: parent set of the `node`.
    /// * `dataset`: instantiation of the `struct tools::Dataset` containing the
    ///              observations used to compute the score.
    ///
    /// # Return
    ///
    /// * A `f64` representing the score of the node given the dataset.
    fn call<T>(
        &self,
        net: &T,
        node: usize,
        parent_set: &BTreeSet<usize>,
        dataset: &tools::Dataset,
    ) -> f64
    where
        T: process::NetworkProcess;
}
35
/// LogLikelihood (marginal log-likelihood score) for a `NetworkProcess`
pub struct LogLikelihood {
    // Pseudo count: imaginary number of transitions acting as a Bayesian prior.
    alpha: usize,
    // Pseudo residence time (Bayesian prior); must be >= 0.0 (enforced in `new`).
    tau: f64,
}
41
impl LogLikelihood {
    /// Create a `struct LogLikelihood`
    ///
    /// # Arguments
    ///
    /// * `alpha`: pseudo count (imaginary number of transitions)
    /// * `tau`: pseudo residence time (imaginary residence time)
    ///
    /// # Panics
    ///
    /// Panics if `tau` is negative.
    pub fn new(alpha: usize, tau: f64) -> LogLikelihood {
        //Tau must be >=0.0 since it plays the role of a (pseudo) residence time
        if tau < 0.0 {
            panic!("tau must be >=0.0");
        }
        LogLikelihood { alpha, tau }
    }

    /// Compute the marginal log-likelihood of `node` given `parent_set` over `dataset`.
    ///
    /// # Return
    ///
    /// * A tuple `(score, M)`: the log-likelihood and the 3-dimensional
    ///   transition-count matrix `M`, returned so that callers (e.g. `BIC`)
    ///   can reuse the sufficient statistics without recomputing them.
    fn compute_score<T>(
        &self,
        net: &T,
        node: usize,
        parent_set: &BTreeSet<usize>,
        dataset: &tools::Dataset,
    ) -> (f64, Array3<usize>)
    where
        T: process::NetworkProcess,
    {
        //Identify the type of node used
        match &net.get_node(node) {
            params::Params::DiscreteStatesContinousTime(_params) => {
                //Compute the sufficient statistics M (number of transitions) and T (residence
                //time)
                let (M, T) =
                    parameter_learning::sufficient_statistics(net, dataset, node, parent_set);

                //Scale alpha by the length of M's first axis (presumably the
                //number of parent-set configurations — TODO confirm) so total
                //prior mass stays constant across candidate parent sets
                let alpha = self.alpha as f64 / M.shape()[0] as f64;
                //Scale tau in the same way
                let tau = self.tau / M.shape()[0] as f64;

                //Log marginal likelihood of the exit rates q.
                //M.sum_axis(Axis(2)) collapses the last axis of M, giving the
                //total outgoing transition count paired element-wise (zip) with
                //the corresponding residence time in T.
                //NOTE(review): tau == 0.0 (allowed by `new`) makes f64::ln(tau)
                //-inf here — confirm this is intended.
                let log_ll_q: f64 = M
                    .sum_axis(Axis(2))
                    .iter()
                    .zip(T.iter())
                    .map(|(m, t)| {
                        gamma::ln_gamma(alpha + *m as f64 + 1.0) + (alpha + 1.0) * f64::ln(tau)
                            - gamma::ln_gamma(alpha + 1.0)
                            - (alpha + *m as f64 + 1.0) * f64::ln(tau + t)
                    })
                    .sum();

                //Log marginal likelihood of the transition distributions theta:
                //a ratio of log-gamma terms (Dirichlet-multinomial form) for
                //each row `y` of transition counts, summed over both leading
                //axes of M.
                let log_ll_theta: f64 = M
                    .outer_iter()
                    .map(|x| {
                        x.outer_iter()
                            .map(|y| {
                                gamma::ln_gamma(alpha) - gamma::ln_gamma(alpha + y.sum() as f64)
                                    + y.iter()
                                        .map(|z| {
                                            gamma::ln_gamma(alpha + *z as f64)
                                                - gamma::ln_gamma(alpha)
                                        })
                                        .sum::<f64>()
                            })
                            .sum::<f64>()
                    })
                    .sum();
                //The score is decomposable: the two contributions simply add up
                (log_ll_theta + log_ll_q, M)
            }
        }
    }
}
114
115impl ScoreFunction for LogLikelihood {
116    fn call<T>(
117        &self,
118        net: &T,
119        node: usize,
120        parent_set: &BTreeSet<usize>,
121        dataset: &tools::Dataset,
122    ) -> f64
123    where
124        T: process::NetworkProcess,
125    {
126        let score = self.compute_score(net, node, parent_set, dataset).0;
127        debug!(
128            "Node: {} - Parentset: {:?} - score: {}",
129            node, parent_set, score
130        );
131        score
132    }
133}
134
/// BIC (Bayesian Information Criterion) score for a `NetworkProcess`
pub struct BIC {
    // Underlying marginal log-likelihood scorer; BIC subtracts a
    // model-complexity penalty from its value.
    ll: LogLikelihood,
}
139
140impl BIC {
141    /// Create a `struct BIC`
142    ///
143    /// # Arguments
144    ///
145    /// * `alpha`: pseudo count (immaginary  number of transitions)
146    /// * `tau`: pseudo residence time (immaginary residence time)
147    pub fn new(alpha: usize, tau: f64) -> BIC {
148        BIC {
149            ll: LogLikelihood::new(alpha, tau),
150        }
151    }
152}
153
154impl ScoreFunction for BIC {
155    fn call<T>(
156        &self,
157        net: &T,
158        node: usize,
159        parent_set: &BTreeSet<usize>,
160        dataset: &tools::Dataset,
161    ) -> f64
162    where
163        T: process::NetworkProcess,
164    {
165        //Compute the log-likelihood
166        let (ll, M) = self.ll.compute_score(net, node, parent_set, dataset);
167        //Compute the number of parameters
168        let n_parameters = M.shape()[0] * M.shape()[1] * (M.shape()[2] - 1);
169        //TODO: Optimize this
170        //Compute the sample size
171        let sample_size: usize = dataset
172            .get_trajectories()
173            .iter()
174            .map(|x| x.get_time().len() - 1)
175            .sum();
176        //Compute BIC
177        let score = ll - f64::ln(sample_size as f64) / 2.0 * n_parameters as f64;
178        debug!(
179            "Node: {} - Parentset: {:?} - score: {}",
180            node, parent_set, score
181        );
182        score
183    }
184}