open_pql/runner/
mod.rs

1use std::{io, slice::Iter};
2
3use derive_more::Debug;
4
5mod init_board;
6mod init_game;
7mod init_players;
8
9use concurrency::parallel_exec;
10use init_board::init_board;
11use init_game::init_game;
12use init_players::init_players;
13use lazy_static::lazy_static;
14use pql_parser::{ast, parse};
15use rustc_hash::*;
16use vm::{Vm, VmStackValue, VmStackValueNum, VmStoreVarIdx, *};
17
18use super::*;
19use crate::{vm::VmInstruction, LocInfo, PQLLong};
20
/// Parses a PQL source string and runs each statement in it, writing selector
/// results to `stream_out` and error reports to `stream_err`.
#[derive(Debug)]
pub struct StatementsRunner<'src> {
    /// The PQL source text; error reports also print spans sliced from it.
    pub(crate) src: &'src str,
    /// Total number of trials requested for each statement.
    pub n_trials: usize,
    /// Requested number of worker VMs a statement's trials are split across.
    pub n_threads: usize,
    /// Destination for selector results (omitted from `Debug` output).
    #[debug(skip)]
    pub stream_out: Box<dyn io::Write>,
    /// Destination for error reports (omitted from `Debug` output).
    #[debug(skip)]
    pub stream_err: Box<dyn io::Write>,
}
31
/// A boxed, `Send`able one-shot closure that produces a single [`PQLRange`]
/// (or a [`PQLError`]) when invoked.
// NOTE(review): not referenced in this file's visible code; presumably the
// closure type produced by the `init_board` / `init_players` submodules — confirm.
type RangeProc = Box<dyn FnOnce() -> Result<PQLRange, PQLError> + Send>;
33
34/// # Panics
35/// ranges at least consists of one player and the board
36pub fn init_vm(stmt: &ast::Stmt, n_trials: usize) -> Result<Vm, PQLError> {
37    let fc = &stmt.from;
38    let sels = &stmt.selectors;
39
40    let game = init_game(fc)?;
41    let board_range_proc = init_board(fc);
42    let (player_names, mut ranges_procs) = init_players(fc, game);
43
44    ranges_procs.push(board_range_proc);
45    let mut ranges = parallel_exec(ranges_procs)?
46        .into_iter()
47        .collect::<Result<Vec<_>, _>>()?;
48    let board_range = ranges.pop().unwrap();
49
50    let (instructions, store) = instruction::init(sels, game, &player_names)?;
51
52    let player_ranges = ranges;
53
54    Ok(Vm {
55        board_range,
56        player_ranges,
57        instructions,
58        store,
59        stack: VmStack::default(),
60        buffer: VmBuffer::new(player_names.len(), game),
61        rng: Rng::default(),
62        n_trials,
63        n_failed: 0,
64    })
65}
66
67impl<'src> StatementsRunner<'src> {
68    pub const fn new(
69        src: &'src str,
70        n_trials: usize,
71        n_threads: usize,
72        stream_out: Box<dyn io::Write>,
73        stream_err: Box<dyn io::Write>,
74    ) -> Self {
75        Self {
76            src,
77            n_trials,
78            n_threads,
79            stream_out,
80            stream_err,
81        }
82    }
83
84    #[allow(clippy::missing_panics_doc)]
85    pub fn run(&mut self) {
86        match parse(self.src) {
87            Ok(stmts) => {
88                for stmt in &stmts {
89                    match self.run_stmt(stmt) {
90                        Ok(()) => (),
91                        Err(e) => self.report_error(&e),
92                    }
93                }
94
95                self.stream_out
96                    .flush()
97                    .expect("Failed to write to output stream");
98            }
99            Err(e) => {
100                self.report_error(&PQLError::from(e));
101            }
102        }
103    }
104
105    fn run_stmt(&mut self, stmt: &ast::Stmt) -> Result<(), PQLError> {
106        let n = self.n_threads;
107        let n_trials = self.n_trials / n;
108
109        let vm = init_vm(stmt, n_trials)?;
110
111        let mut vms: Vec<Vm> = Vec::with_capacity(n);
112        vms.push(vm);
113
114        while vms.len() != n {
115            vms.push(vms[0].clone());
116        }
117
118        let procs = vms
119            .into_iter()
120            .map(|vm| Box::new(move || vm.try_run()) as _)
121            .collect::<Vec<_>>();
122
123        match (parallel_exec(procs)?)
124            .into_iter()
125            .collect::<Result<Vec<Vm>, PQLError>>()
126        {
127            Ok(vms) => {
128                self.aggregate_outputs(&vms, &stmt.selectors, self.n_trials);
129            }
130
131            Err(e) => self.report_error(&e),
132        }
133
134        Ok(())
135    }
136
137    fn report_error(&mut self, e: &PQLError) {
138        let loc: Option<LocInfo> = e.into();
139
140        writeln!(self.stream_err, "Error:")
141            .and_then(|()| writeln!(self.stream_err, "{e}"))
142            .and_then(|()| {
143                if let Some((a, b)) = loc {
144                    writeln!(self.stream_err, "{}", &self.src[a..b])
145                } else {
146                    Ok(())
147                }
148            })
149            .expect("Failed to write to error stream");
150    }
151
152    #[allow(clippy::cast_precision_loss)]
153    fn aggregate_outputs(
154        &mut self,
155        vms: &[Vm],
156        selectors: &[ast::Selector],
157        n_trials: usize,
158    ) {
159        fn next_write(iter: &mut Iter<'_, VmInstruction>) -> VmStoreVarIdx {
160            loop {
161                match iter.next() {
162                    Some(VmInstruction::Write(idx)) => return *idx,
163                    Some(_) => (),
164                    None => todo!(),
165                };
166            }
167        }
168
169        let vec_ins = vms[0].instructions.clone();
170        let _game: PQLGame = (&vms[0].buffer).into();
171        let mut iter = vec_ins.iter();
172
173        for (n, selector) in selectors.iter().enumerate() {
174            let idx = next_write(&mut iter);
175
176            let name = selector.alias.as_ref().map_or_else(
177                || format!("{:?}{}", selector.kind, n + 1),
178                |id| id.inner.into(),
179            );
180
181            let res = match selector.kind {
182                ast::SelectorKind::Avg => {
183                    let v: VmStackValue = vms
184                        .iter()
185                        .map(|vm| {
186                            *vm.store
187                                .downcast_get::<&VmStackValue>(idx)
188                                .unwrap()
189                        })
190                        .reduce(|m, e| m.try_add(e).unwrap())
191                        .unwrap();
192
193                    let c: PQLLong = vms
194                        .iter()
195                        .map(|vm| {
196                            *vm.store.downcast_get::<&PQLLong>(idx + 1).unwrap()
197                        })
198                        .sum::<PQLLong>();
199
200                    let n: VmStackValueNum = v.try_into().unwrap();
201                    writeln!(
202                        self.stream_out,
203                        "{name} = {}",
204                        n.cast_dbl() / c as f64
205                    )
206                }
207
208                ast::SelectorKind::Count => {
209                    let c: PQLLong = vms
210                        .iter()
211                        .map(|vm| {
212                            *vm.store.downcast_get::<&PQLLong>(idx).unwrap()
213                        })
214                        .sum::<PQLLong>();
215
216                    writeln!(
217                        self.stream_out,
218                        "{name} = {}%({})",
219                        100.0 * c as f64 / n_trials as f64,
220                        c
221                    )
222                }
223
224                ast::SelectorKind::Max => {
225                    let v: VmStackValue = vms
226                        .iter()
227                        .map(|vm| {
228                            *vm.store
229                                .downcast_get::<&VmStackValue>(idx)
230                                .unwrap()
231                        })
232                        .reduce(|m, e| (if m >= e { m } else { e }))
233                        .unwrap();
234
235                    writeln!(self.stream_out, "{name} = {v}")
236                }
237
238                ast::SelectorKind::Min => {
239                    let v: VmStackValue = vms
240                        .iter()
241                        .map(|vm| {
242                            *vm.store
243                                .downcast_get::<&VmStackValue>(idx)
244                                .unwrap()
245                        })
246                        .reduce(|m, e| (if m <= e { m } else { e }))
247                        .unwrap();
248
249                    writeln!(self.stream_out, "{name} = {v}")
250                }
251            };
252
253            res.expect("Failed to write to output stream");
254        }
255    }
256}