causal_hub/io/bif/
parser.rs

1use itertools::Itertools;
2use ndarray::prelude::*;
3use pest::{Parser, iterators::Pair};
4use pest_derive::Parser;
5
6use crate::{
7    models::{BN, CPD, CatBN, CatCPD, DiGraph, Graph, Labelled},
8    types::{Map, States},
9};
10
/// Intermediate representation of a whole BIF document.
///
/// Assembled by `build_ast` from the parse tree and consumed by
/// `BifParser::parse_str`.
#[derive(Debug)]
struct Network {
    /// Network name, taken from the first inner pair of the `network` rule.
    pub name: String,
    /// Properties declared inside the `network` block.
    pub properties: Vec<Property>,
    /// One item per `variable` pair found in the file, in document order.
    pub variables: Vec<Variable>,
    /// One item per `probability` pair found in the file, in document order.
    pub probabilities: Vec<Probability>,
}
18
/// A single property, stored as the raw text of its first two inner pairs.
#[derive(Debug)]
struct Property {
    /// Text of the first inner pair of the `property` rule.
    pub key: String,
    /// Text of the second inner pair of the `property` rule.
    pub value: String,
}
24
/// A `variable` declaration: its label and the names of its states.
#[derive(Debug)]
struct Variable {
    /// Label of the variable.
    pub label: String,
    /// State names, in the order they appear in the declaration.
    pub states: Vec<String>,
    /// Properties attached to the variable; collected by `parse_variable`
    /// but not read anywhere else in this file.
    pub _properties: Vec<Property>,
}
31
/// A `probability` block, as produced by `parse_probability`.
///
/// `parse_probability` fills at most one of `table` and `entries`.
#[derive(Debug)]
struct Probability {
    /// Label of the variable the block describes.
    pub label: String,
    /// Labels of the conditioning (parent) variables; empty when there are none.
    pub parents: Vec<String>,
    /// Flat CPT: the values of a bare number list, when the block has one.
    pub table: Option<Vec<f64>>,
    /// Conditional CPT: one `(parent state values, probabilities)` pair per entry.
    pub entries: Option<Vec<(Vec<String>, Vec<f64>)>>,
}
39
/// Parser for Bayesian Interchange Format (BIF) documents.
///
/// The parsing code and the `Rule` enum used throughout this file are
/// derived from the grammar file named in the `grammar` attribute.
#[derive(Parser)]
#[grammar = "src/io/bif/bif.pest"]
pub struct BifParser;
44
45impl BifParser {
46    /// Read a BIF string and returns a `Network` object.
47    pub fn parse_str(bif: &str) -> CatBN {
48        let network = Self::parse(Rule::file, bif)
49            .expect("Failed to parse BIF file.")
50            .map(build_ast)
51            .next()
52            .expect("Failed to parse BIF file.");
53        // Get network properties.
54        let properties: Map<_, _> = network
55            .properties
56            .into_iter()
57            .map(|p| (p.key, p.value))
58            .collect();
59        // Get network name and description.
60        let name = Some(network.name);
61        let description = properties.get("description").cloned();
62        // Construct states.
63        let states: States = network
64            .variables
65            .into_iter()
66            .map(|v| (v.label, v.states.into_iter().collect()))
67            .collect();
68        // Construct CPDs.
69        let cpds: Vec<_> = network
70            .probabilities
71            .into_iter()
72            .map(|p| {
73                // Get the variable of the CPD.
74                let variable = States::from_iter([(
75                    p.label.clone(),
76                    states
77                        .get(&p.label)
78                        .expect("Failed to get variable states.")
79                        .clone(),
80                )]);
81                // Get the conditioning variables of the CPD.
82                let conditioning_variables: States = p
83                    .parents
84                    .iter()
85                    .map(|x| {
86                        let states = states.get(x).expect("Failed to get variable states.");
87                        (x.to_string(), states.iter().cloned().collect())
88                    })
89                    .collect();
90                // Map the probability values.
91                let parameters = match (p.table, p.entries) {
92                    (Some(table), None) => Array1::from_vec(table).insert_axis(Axis(0)),
93                    (None, Some(entries)) => {
94                        // Align the probability values with the states.
95                        let entries: Map<_, _> = entries.into_iter().collect();
96                        // Align the entries with the states.
97                        let entries: Vec<_> = conditioning_variables
98                            .iter()
99                            .map(|(_, states)| states)
100                            .cloned()
101                            .multi_cartesian_product()
102                            .map(|states| &entries[&states])
103                            .collect();
104                        // Get the shape of the parameters.
105                        let shape = (entries.len(), entries[0].len());
106                        // Collect the parameters.
107                        let parameters: Array1<_> =
108                            entries.into_iter().flatten().copied().collect();
109                        // Reshape the parameters.
110                        parameters
111                            .into_shape_with_order(shape)
112                            .expect("Failed to reshape parameters.")
113                    }
114                    _ => unreachable!(),
115                };
116                // Normalize the parameters so that they sum exactly to 1 by row.
117                let parameters = &parameters / parameters.sum_axis(Axis(1)).insert_axis(Axis(1));
118                // Construct the CPD.
119                CatCPD::new(variable, conditioning_variables, parameters)
120            })
121            .collect();
122
123        // Construct the graph.
124        let mut graph = DiGraph::empty(states.keys());
125        cpds.iter().for_each(|p| {
126            // Assert the CPD has a single variable in the BIF file.
127            assert_eq!(p.labels().len(), 1);
128            // Get child index.
129            let x = &p.labels()[0];
130            let x = graph
131                .labels()
132                .get_index_of(x)
133                .unwrap_or_else(|| panic!("Failed to get index of label '{x}'."));
134            // Get parent indices.
135            p.conditioning_labels().into_iter().for_each(|z| {
136                // Get parent index.
137                let z = graph
138                    .labels()
139                    .get_index_of(z)
140                    .unwrap_or_else(|| panic!("Failed to get index of label '{z}'."));
141                // Add edge from parent to child.
142                graph.add_edge(z, x);
143            });
144        });
145
146        // Construct the Bayesian network.
147        CatBN::with_optionals(name, description, graph, cpds)
148    }
149}
150
151fn build_ast(pair: Pair<Rule>) -> Network {
152    assert_eq!(pair.as_rule(), Rule::file);
153
154    let mut name = String::new();
155    let mut properties = vec![];
156    let mut variables = vec![];
157    let mut probabilities = vec![];
158
159    for item in pair.into_inner() {
160        match item.as_rule() {
161            Rule::network => {
162                let mut inner = item.into_inner();
163                name = inner.next().unwrap().as_str().to_string();
164                for p in inner {
165                    if p.as_rule() == Rule::property {
166                        properties.push(parse_property(p));
167                    }
168                }
169            }
170            Rule::variable => variables.push(parse_variable(item)),
171            Rule::probability => probabilities.push(parse_probability(item)),
172            _ => {}
173        }
174    }
175
176    Network {
177        name,
178        properties,
179        variables,
180        probabilities,
181    }
182}
183
184fn parse_property(pair: Pair<Rule>) -> Property {
185    let mut inner = pair.into_inner();
186    let key = inner.next().unwrap().as_str().to_string();
187    let value = inner.next().unwrap().as_str().to_string();
188
189    Property { key, value }
190}
191
192fn parse_variable(pair: Pair<Rule>) -> Variable {
193    let mut inner = pair.into_inner();
194    let label = inner.next().unwrap().as_str().to_string();
195
196    // Skip 'type discrete [n] { values } ;'
197    inner.next(); // n
198    let values_pair = inner.next().unwrap(); // { values }
199    let states = values_pair
200        .into_inner()
201        .map(|v| v.as_str().to_string())
202        .collect();
203
204    inner.next(); // ';'
205
206    let properties = inner
207        .filter(|p| p.as_rule() == Rule::property)
208        .map(parse_property)
209        .collect();
210
211    Variable {
212        label,
213        states,
214        _properties: properties,
215    }
216}
217
218fn parse_probability(pair: Pair<Rule>) -> Probability {
219    let mut inner = pair.into_inner();
220    let label = inner.next().unwrap().as_str().to_string();
221
222    let mut parents = vec![];
223    let mut table = None;
224    let mut entries = vec![];
225
226    let mut next = inner.next().unwrap();
227    if next.as_rule() == Rule::parents {
228        parents = next
229            .into_inner()
230            .next()
231            .unwrap()
232            .into_inner()
233            .map(|p| p.as_str().to_string())
234            .collect();
235        next = inner.next().unwrap(); // move to table or entry
236    }
237
238    match next.as_rule() {
239        Rule::number_list => {
240            table = Some(parse_number_list(next));
241        }
242        Rule::entry => {
243            entries.push(parse_entry(next));
244            for entry in inner {
245                if entry.as_rule() == Rule::entry {
246                    entries.push(parse_entry(entry));
247                }
248            }
249        }
250        _ => {}
251    }
252
253    let entries = if entries.is_empty() {
254        None
255    } else {
256        Some(entries)
257    };
258
259    Probability {
260        label,
261        parents,
262        table,
263        entries,
264    }
265}
266
267fn parse_entry(pair: Pair<Rule>) -> (Vec<String>, Vec<f64>) {
268    let mut inner = pair.into_inner();
269    let values = inner
270        .next()
271        .unwrap()
272        .into_inner()
273        .map(|v| v.as_str().to_string())
274        .collect();
275    let probs = parse_number_list(inner.next().unwrap());
276    (values, probs)
277}
278
279fn parse_number_list(pair: Pair<Rule>) -> Vec<f64> {
280    pair.into_inner()
281        .map(|n| n.as_str().parse::<f64>().unwrap())
282        .collect()
283}