open_pql/runner/
mod.rs

1use std::{io, slice::Iter, sync::LazyLock};
2
3use derive_more::Debug;
4
5mod init_board;
6mod init_game;
7mod init_players;
8
9use concurrency::parallel_exec;
10use init_board::init_board;
11use init_game::init_game;
12use init_players::init_players;
13use pql_parser::{ast, parse};
14use rustc_hash::*;
15use vm::{Vm, VmStackValue, VmStackValueNum, VmStoreVarIdx, *};
16
17use super::*;
18use crate::{LocInfo, PQLLong, vm::VmInstruction};
19
20#[derive(Debug)]
21pub struct StatementsRunner<'src> {
22    pub(crate) src: &'src str,
23    pub n_trials: usize,
24    pub n_threads: usize,
25    #[debug(skip)]
26    pub stream_out: Box<dyn io::Write>,
27    #[debug(skip)]
28    pub stream_err: Box<dyn io::Write>,
29}
30
31type RangeProc = Box<dyn FnOnce() -> Result<PQLRange, PQLError> + Send>;
32
33/// # Panics
34/// ranges at least consists of one player and the board
35pub fn init_vm(stmt: &ast::Stmt, n_trials: usize) -> Result<Vm, PQLError> {
36    let fc = &stmt.from;
37    let sels = &stmt.selectors;
38
39    let game = init_game(fc)?;
40    let board_range_proc = init_board(fc);
41    let (player_names, mut ranges_procs) = init_players(fc, game);
42
43    ranges_procs.push(board_range_proc);
44    let mut ranges = parallel_exec(ranges_procs)?
45        .into_iter()
46        .collect::<Result<Vec<_>, _>>()?;
47    let board_range = ranges.pop().unwrap();
48
49    let (instructions, store) = instruction::init(sels, game, &player_names)?;
50
51    let player_ranges = ranges;
52
53    Ok(Vm {
54        board_range,
55        player_ranges,
56        instructions,
57        store,
58        stack: VmStack::default(),
59        buffer: VmBuffer::new(player_names.len(), game),
60        rng: Rng::default(),
61        n_trials,
62        n_failed: 0,
63    })
64}
65
66impl<'src> StatementsRunner<'src> {
67    pub const fn new(
68        src: &'src str,
69        n_trials: usize,
70        n_threads: usize,
71        stream_out: Box<dyn io::Write>,
72        stream_err: Box<dyn io::Write>,
73    ) -> Self {
74        Self {
75            src,
76            n_trials,
77            n_threads,
78            stream_out,
79            stream_err,
80        }
81    }
82
83    #[allow(clippy::missing_panics_doc)]
84    pub fn run(&mut self) {
85        match parse(self.src) {
86            Ok(stmts) => {
87                for stmt in &stmts {
88                    match self.run_stmt(stmt) {
89                        Ok(()) => (),
90                        Err(e) => self.report_error(&e),
91                    }
92                }
93
94                self.stream_out
95                    .flush()
96                    .expect("Failed to write to output stream");
97            }
98            Err(e) => {
99                self.report_error(&PQLError::from(e));
100            }
101        }
102    }
103
104    fn run_stmt(&mut self, stmt: &ast::Stmt) -> Result<(), PQLError> {
105        let n = self.n_threads;
106        let n_trials = self.n_trials / n;
107
108        let vm = init_vm(stmt, n_trials)?;
109
110        let mut vms: Vec<Vm> = Vec::with_capacity(n);
111        vms.push(vm);
112
113        while vms.len() != n {
114            vms.push(vms[0].clone());
115        }
116
117        let procs = vms
118            .into_iter()
119            .map(|vm| Box::new(move || vm.try_run()) as _)
120            .collect::<Vec<_>>();
121
122        match (parallel_exec(procs)?)
123            .into_iter()
124            .collect::<Result<Vec<Vm>, PQLError>>()
125        {
126            Ok(vms) => {
127                self.aggregate_outputs(&vms, &stmt.selectors, self.n_trials);
128            }
129
130            Err(e) => self.report_error(&e),
131        }
132
133        Ok(())
134    }
135
136    fn report_error(&mut self, e: &PQLError) {
137        let loc: Option<LocInfo> = e.into();
138
139        writeln!(self.stream_err, "Error:")
140            .and_then(|()| writeln!(self.stream_err, "{e}"))
141            .and_then(|()| {
142                if let Some((a, b)) = loc {
143                    writeln!(self.stream_err, "{}", &self.src[a..b])
144                } else {
145                    Ok(())
146                }
147            })
148            .expect("Failed to write to error stream");
149    }
150
151    #[allow(clippy::cast_precision_loss)]
152    fn aggregate_outputs(
153        &mut self,
154        vms: &[Vm],
155        selectors: &[ast::Selector],
156        n_trials: usize,
157    ) {
158        fn next_write(iter: &mut Iter<'_, VmInstruction>) -> VmStoreVarIdx {
159            loop {
160                match iter.next() {
161                    Some(VmInstruction::Write(idx)) => return *idx,
162                    Some(_) => (),
163                    None => todo!(),
164                }
165            }
166        }
167
168        let vec_ins = vms[0].instructions.clone();
169        let _game: PQLGame = (&vms[0].buffer).into();
170        let mut iter = vec_ins.iter();
171
172        for (n, selector) in selectors.iter().enumerate() {
173            let idx = next_write(&mut iter);
174
175            let name = selector.alias.as_ref().map_or_else(
176                || format!("{:?}{}", selector.kind, n + 1),
177                |id| id.inner.into(),
178            );
179
180            let res = match selector.kind {
181                ast::SelectorKind::Avg => {
182                    let v: VmStackValue = vms
183                        .iter()
184                        .map(|vm| {
185                            *vm.store
186                                .downcast_get::<&VmStackValue>(idx)
187                                .unwrap()
188                        })
189                        .reduce(|m, e| m.try_add(e).unwrap())
190                        .unwrap();
191
192                    let c: PQLLong = vms
193                        .iter()
194                        .map(|vm| {
195                            *vm.store.downcast_get::<&PQLLong>(idx + 1).unwrap()
196                        })
197                        .sum::<PQLLong>();
198
199                    let n: VmStackValueNum = v.try_into().unwrap();
200                    writeln!(
201                        self.stream_out,
202                        "{name} = {}",
203                        n.cast_dbl() / c as f64
204                    )
205                }
206
207                ast::SelectorKind::Count => {
208                    let c: PQLLong = vms
209                        .iter()
210                        .map(|vm| {
211                            *vm.store.downcast_get::<&PQLLong>(idx).unwrap()
212                        })
213                        .sum::<PQLLong>();
214
215                    writeln!(
216                        self.stream_out,
217                        "{name} = {}%({})",
218                        100.0 * c as f64 / n_trials as f64,
219                        c
220                    )
221                }
222
223                ast::SelectorKind::Max => {
224                    let v: VmStackValue = vms
225                        .iter()
226                        .map(|vm| {
227                            *vm.store
228                                .downcast_get::<&VmStackValue>(idx)
229                                .unwrap()
230                        })
231                        .reduce(|m, e| if m >= e { m } else { e })
232                        .unwrap();
233
234                    writeln!(self.stream_out, "{name} = {v}")
235                }
236
237                ast::SelectorKind::Min => {
238                    let v: VmStackValue = vms
239                        .iter()
240                        .map(|vm| {
241                            *vm.store
242                                .downcast_get::<&VmStackValue>(idx)
243                                .unwrap()
244                        })
245                        .reduce(|m, e| if m <= e { m } else { e })
246                        .unwrap();
247
248                    writeln!(self.stream_out, "{name} = {v}")
249                }
250            };
251
252            res.expect("Failed to write to output stream");
253        }
254    }
255}