causal_hub/io/bif/
parser.rs1use itertools::Itertools;
2use ndarray::prelude::*;
3use pest::{Parser, iterators::Pair};
4use pest_derive::Parser;
5
6use crate::{
7 models::{BN, CPD, CatBN, CatCPD, DiGraph, Graph, Labelled},
8 types::{Map, States},
9};
10
11#[derive(Debug)]
12struct Network {
13 pub name: String,
14 pub properties: Vec<Property>,
15 pub variables: Vec<Variable>,
16 pub probabilities: Vec<Probability>,
17}
18
/// A single key/value property, stored as the raw text of its two inner pairs.
#[derive(Debug)]
struct Property {
    /// Text of the first inner pair of the `property` rule.
    pub key: String,
    /// Text of the second inner pair of the `property` rule.
    pub value: String,
}
24
25#[derive(Debug)]
26struct Variable {
27 pub label: String,
28 pub states: Vec<String>,
29 pub _properties: Vec<Property>,
30}
31
/// A probability block for one variable, in one of two forms.
///
/// `parse_probability` fills `table` when the block body is a `number_list`
/// and `entries` when the block body consists of `entry` rules; the other
/// field is left as `None`.
#[derive(Debug)]
struct Probability {
    /// Label of the variable this block describes.
    pub label: String,
    /// Labels of the conditioning variables; empty when none are declared.
    pub parents: Vec<String>,
    /// Flat list of probabilities taken from a `number_list` body.
    pub table: Option<Vec<f64>>,
    /// One `(parent state values, probabilities)` pair per `entry` rule.
    pub entries: Option<Vec<(Vec<String>, Vec<f64>)>>,
}
39
40#[derive(Parser)]
42#[grammar = "src/io/bif/bif.pest"]
43pub struct BifParser;
44
45impl BifParser {
46 pub fn parse_str(bif: &str) -> CatBN {
48 let network = Self::parse(Rule::file, bif)
49 .expect("Failed to parse BIF file.")
50 .map(build_ast)
51 .next()
52 .expect("Failed to parse BIF file.");
53 let properties: Map<_, _> = network
55 .properties
56 .into_iter()
57 .map(|p| (p.key, p.value))
58 .collect();
59 let name = Some(network.name);
61 let description = properties.get("description").cloned();
62 let states: States = network
64 .variables
65 .into_iter()
66 .map(|v| (v.label, v.states.into_iter().collect()))
67 .collect();
68 let cpds: Vec<_> = network
70 .probabilities
71 .into_iter()
72 .map(|p| {
73 let variable = States::from_iter([(
75 p.label.clone(),
76 states
77 .get(&p.label)
78 .expect("Failed to get variable states.")
79 .clone(),
80 )]);
81 let conditioning_variables: States = p
83 .parents
84 .iter()
85 .map(|x| {
86 let states = states.get(x).expect("Failed to get variable states.");
87 (x.to_string(), states.iter().cloned().collect())
88 })
89 .collect();
90 let parameters = match (p.table, p.entries) {
92 (Some(table), None) => Array1::from_vec(table).insert_axis(Axis(0)),
93 (None, Some(entries)) => {
94 let entries: Map<_, _> = entries.into_iter().collect();
96 let entries: Vec<_> = conditioning_variables
98 .iter()
99 .map(|(_, states)| states)
100 .cloned()
101 .multi_cartesian_product()
102 .map(|states| &entries[&states])
103 .collect();
104 let shape = (entries.len(), entries[0].len());
106 let parameters: Array1<_> =
108 entries.into_iter().flatten().copied().collect();
109 parameters
111 .into_shape_with_order(shape)
112 .expect("Failed to reshape parameters.")
113 }
114 _ => unreachable!(),
115 };
116 let parameters = ¶meters / parameters.sum_axis(Axis(1)).insert_axis(Axis(1));
118 CatCPD::new(variable, conditioning_variables, parameters)
120 })
121 .collect();
122
123 let mut graph = DiGraph::empty(states.keys());
125 cpds.iter().for_each(|p| {
126 assert_eq!(p.labels().len(), 1);
128 let x = &p.labels()[0];
130 let x = graph
131 .labels()
132 .get_index_of(x)
133 .unwrap_or_else(|| panic!("Failed to get index of label '{x}'."));
134 p.conditioning_labels().into_iter().for_each(|z| {
136 let z = graph
138 .labels()
139 .get_index_of(z)
140 .unwrap_or_else(|| panic!("Failed to get index of label '{z}'."));
141 graph.add_edge(z, x);
143 });
144 });
145
146 CatBN::with_optionals(name, description, graph, cpds)
148 }
149}
150
151fn build_ast(pair: Pair<Rule>) -> Network {
152 assert_eq!(pair.as_rule(), Rule::file);
153
154 let mut name = String::new();
155 let mut properties = vec![];
156 let mut variables = vec![];
157 let mut probabilities = vec![];
158
159 for item in pair.into_inner() {
160 match item.as_rule() {
161 Rule::network => {
162 let mut inner = item.into_inner();
163 name = inner.next().unwrap().as_str().to_string();
164 for p in inner {
165 if p.as_rule() == Rule::property {
166 properties.push(parse_property(p));
167 }
168 }
169 }
170 Rule::variable => variables.push(parse_variable(item)),
171 Rule::probability => probabilities.push(parse_probability(item)),
172 _ => {}
173 }
174 }
175
176 Network {
177 name,
178 properties,
179 variables,
180 probabilities,
181 }
182}
183
184fn parse_property(pair: Pair<Rule>) -> Property {
185 let mut inner = pair.into_inner();
186 let key = inner.next().unwrap().as_str().to_string();
187 let value = inner.next().unwrap().as_str().to_string();
188
189 Property { key, value }
190}
191
192fn parse_variable(pair: Pair<Rule>) -> Variable {
193 let mut inner = pair.into_inner();
194 let label = inner.next().unwrap().as_str().to_string();
195
196 inner.next(); let values_pair = inner.next().unwrap(); let states = values_pair
200 .into_inner()
201 .map(|v| v.as_str().to_string())
202 .collect();
203
204 inner.next(); let properties = inner
207 .filter(|p| p.as_rule() == Rule::property)
208 .map(parse_property)
209 .collect();
210
211 Variable {
212 label,
213 states,
214 _properties: properties,
215 }
216}
217
218fn parse_probability(pair: Pair<Rule>) -> Probability {
219 let mut inner = pair.into_inner();
220 let label = inner.next().unwrap().as_str().to_string();
221
222 let mut parents = vec![];
223 let mut table = None;
224 let mut entries = vec![];
225
226 let mut next = inner.next().unwrap();
227 if next.as_rule() == Rule::parents {
228 parents = next
229 .into_inner()
230 .next()
231 .unwrap()
232 .into_inner()
233 .map(|p| p.as_str().to_string())
234 .collect();
235 next = inner.next().unwrap(); }
237
238 match next.as_rule() {
239 Rule::number_list => {
240 table = Some(parse_number_list(next));
241 }
242 Rule::entry => {
243 entries.push(parse_entry(next));
244 for entry in inner {
245 if entry.as_rule() == Rule::entry {
246 entries.push(parse_entry(entry));
247 }
248 }
249 }
250 _ => {}
251 }
252
253 let entries = if entries.is_empty() {
254 None
255 } else {
256 Some(entries)
257 };
258
259 Probability {
260 label,
261 parents,
262 table,
263 entries,
264 }
265}
266
267fn parse_entry(pair: Pair<Rule>) -> (Vec<String>, Vec<f64>) {
268 let mut inner = pair.into_inner();
269 let values = inner
270 .next()
271 .unwrap()
272 .into_inner()
273 .map(|v| v.as_str().to_string())
274 .collect();
275 let probs = parse_number_list(inner.next().unwrap());
276 (values, probs)
277}
278
279fn parse_number_list(pair: Pair<Rule>) -> Vec<f64> {
280 pair.into_inner()
281 .map(|n| n.as_str().parse::<f64>().unwrap())
282 .collect()
283}