open_pql/runner/
mod.rs

1#![allow(clippy::wildcard_imports)]
2
3use std::{io, slice::Iter, sync::LazyLock};
4
5use derive_more::Debug;
6
7mod init_board;
8mod init_game;
9mod init_players;
10
11use concurrency::parallel_exec;
12use init_board::init_board;
13use init_game::init_game;
14use init_players::init_players;
15use pql_parser::{ast, parse};
16use rustc_hash::*;
17use vm::{Vm, VmStackValue, VmStackValueNum, VmStoreVarIdx, *};
18
19use super::*;
20use crate::{LocInfo, PQLLong, vm::VmInstruction};
21
22#[derive(Debug)]
23pub struct StatementsRunner<'src> {
24    pub(crate) src: &'src str,
25    pub n_trials: usize,
26    pub n_threads: usize,
27    #[debug(skip)]
28    pub stream_out: Box<dyn io::Write>,
29    #[debug(skip)]
30    pub stream_err: Box<dyn io::Write>,
31}
32
33type RangeProc = Box<dyn FnOnce() -> Result<PQLRange, PQLError> + Send>;
34
35/// # Panics
36/// ranges at least consists of one player and the board
37pub fn init_vm(stmt: &ast::Stmt, n_trials: usize) -> Result<Vm, PQLError> {
38    let fc = &stmt.from;
39    let sels = &stmt.selectors;
40
41    let game = init_game(fc)?;
42    let board_range_proc = init_board(fc);
43    let (player_names, mut ranges_procs) = init_players(fc, game);
44
45    ranges_procs.push(board_range_proc);
46    let mut ranges = parallel_exec(ranges_procs)?
47        .into_iter()
48        .collect::<Result<Vec<_>, _>>()?;
49    let board_range = ranges.pop().unwrap();
50
51    let (instructions, store) = instruction::init(sels, game, &player_names)?;
52
53    let player_ranges = ranges;
54
55    Ok(Vm {
56        board_range,
57        player_ranges,
58        instructions,
59        store,
60        stack: VmStack::default(),
61        buffer: VmBuffer::new(player_names.len(), game),
62        rng: Rng::default(),
63        n_trials,
64        n_failed: 0,
65    })
66}
67
68impl<'src> StatementsRunner<'src> {
69    pub const fn new(
70        src: &'src str,
71        n_trials: usize,
72        n_threads: usize,
73        stream_out: Box<dyn io::Write>,
74        stream_err: Box<dyn io::Write>,
75    ) -> Self {
76        Self {
77            src,
78            n_trials,
79            n_threads,
80            stream_out,
81            stream_err,
82        }
83    }
84
85    #[allow(clippy::missing_panics_doc)]
86    pub fn run(&mut self) {
87        match parse(self.src) {
88            Ok(stmts) => {
89                for stmt in &stmts {
90                    match self.run_stmt(stmt) {
91                        Ok(()) => (),
92                        Err(e) => self.report_error(&e),
93                    }
94                }
95
96                self.stream_out
97                    .flush()
98                    .expect("Failed to write to output stream");
99            }
100            Err(e) => {
101                self.report_error(&PQLError::from(e));
102            }
103        }
104    }
105
106    fn run_stmt(&mut self, stmt: &ast::Stmt) -> Result<(), PQLError> {
107        let n = self.n_threads;
108        let n_trials = self.n_trials / n;
109
110        let vm = init_vm(stmt, n_trials)?;
111
112        let mut vms: Vec<Vm> = Vec::with_capacity(n);
113        vms.push(vm);
114
115        while vms.len() != n {
116            vms.push(vms[0].clone());
117        }
118
119        let procs = vms
120            .into_iter()
121            .map(|vm| Box::new(move || vm.try_run()) as _)
122            .collect::<Vec<_>>();
123
124        match (parallel_exec(procs)?)
125            .into_iter()
126            .collect::<Result<Vec<Vm>, PQLError>>()
127        {
128            Ok(vms) => {
129                self.aggregate_outputs(&vms, &stmt.selectors, self.n_trials);
130            }
131
132            Err(e) => self.report_error(&e),
133        }
134
135        Ok(())
136    }
137
138    fn report_error(&mut self, e: &PQLError) {
139        let loc: Option<LocInfo> = e.into();
140
141        writeln!(self.stream_err, "Error:")
142            .and_then(|()| writeln!(self.stream_err, "{e}"))
143            .and_then(|()| {
144                if let Some((a, b)) = loc {
145                    writeln!(self.stream_err, "{}", &self.src[a..b])
146                } else {
147                    Ok(())
148                }
149            })
150            .expect("Failed to write to error stream");
151    }
152
153    #[allow(clippy::cast_precision_loss)]
154    fn aggregate_outputs(
155        &mut self,
156        vms: &[Vm],
157        selectors: &[ast::Selector],
158        n_trials: usize,
159    ) {
160        fn next_write(iter: &mut Iter<'_, VmInstruction>) -> VmStoreVarIdx {
161            loop {
162                match iter.next() {
163                    Some(VmInstruction::Write(idx)) => return *idx,
164                    Some(_) => (),
165                    None => todo!(),
166                }
167            }
168        }
169
170        let vec_ins = vms[0].instructions.clone();
171        let _game: PQLGame = (&vms[0].buffer).into();
172        let mut iter = vec_ins.iter();
173
174        for (n, selector) in selectors.iter().enumerate() {
175            let idx = next_write(&mut iter);
176
177            let name = selector.alias.as_ref().map_or_else(
178                || format!("{:?}{}", selector.kind, n + 1),
179                |id| id.inner.into(),
180            );
181
182            let res = match selector.kind {
183                ast::SelectorKind::Avg => {
184                    let v: VmStackValue = vms
185                        .iter()
186                        .map(|vm| {
187                            *vm.store
188                                .downcast_get::<&VmStackValue>(idx)
189                                .unwrap()
190                        })
191                        .reduce(|m, e| m.try_add(e).unwrap())
192                        .unwrap();
193
194                    let c: PQLLong = vms
195                        .iter()
196                        .map(|vm| {
197                            *vm.store.downcast_get::<&PQLLong>(idx + 1).unwrap()
198                        })
199                        .sum::<PQLLong>();
200
201                    let n: VmStackValueNum = v.try_into().unwrap();
202                    writeln!(
203                        self.stream_out,
204                        "{name} = {}",
205                        n.cast_dbl() / c as f64
206                    )
207                }
208
209                ast::SelectorKind::Count => {
210                    let c: PQLLong = vms
211                        .iter()
212                        .map(|vm| {
213                            *vm.store.downcast_get::<&PQLLong>(idx).unwrap()
214                        })
215                        .sum::<PQLLong>();
216
217                    writeln!(
218                        self.stream_out,
219                        "{name} = {}%({})",
220                        100.0 * c as f64 / n_trials as f64,
221                        c
222                    )
223                }
224
225                ast::SelectorKind::Max => {
226                    let v: VmStackValue = vms
227                        .iter()
228                        .map(|vm| {
229                            *vm.store
230                                .downcast_get::<&VmStackValue>(idx)
231                                .unwrap()
232                        })
233                        .reduce(|m, e| if m >= e { m } else { e })
234                        .unwrap();
235
236                    writeln!(self.stream_out, "{name} = {v}")
237                }
238
239                ast::SelectorKind::Min => {
240                    let v: VmStackValue = vms
241                        .iter()
242                        .map(|vm| {
243                            *vm.store
244                                .downcast_get::<&VmStackValue>(idx)
245                                .unwrap()
246                        })
247                        .reduce(|m, e| if m <= e { m } else { e })
248                        .unwrap();
249
250                    writeln!(self.stream_out, "{name} = {v}")
251                }
252            };
253
254            res.expect("Failed to write to output stream");
255        }
256    }
257}