1use std::{io, slice::Iter};
2
3use derive_more::Debug;
4
5mod init_board;
6mod init_game;
7mod init_players;
8
9use concurrency::parallel_exec;
10use init_board::init_board;
11use init_game::init_game;
12use init_players::init_players;
13use lazy_static::lazy_static;
14use pql_parser::{ast, parse};
15use rustc_hash::*;
16use vm::{Vm, VmStackValue, VmStackValueNum, VmStoreVarIdx, *};
17
18use super::*;
19use crate::{vm::VmInstruction, LocInfo, PQLLong};
20
21#[derive(Debug)]
22pub struct StatementsRunner<'src> {
23 pub(crate) src: &'src str,
24 pub n_trials: usize,
25 pub n_threads: usize,
26 #[debug(skip)]
27 pub stream_out: Box<dyn io::Write>,
28 #[debug(skip)]
29 pub stream_err: Box<dyn io::Write>,
30}
31
32type RangeProc = Box<dyn FnOnce() -> Result<PQLRange, PQLError> + Send>;
33
34pub fn init_vm(stmt: &ast::Stmt, n_trials: usize) -> Result<Vm, PQLError> {
37 let fc = &stmt.from;
38 let sels = &stmt.selectors;
39
40 let game = init_game(fc)?;
41 let board_range_proc = init_board(fc);
42 let (player_names, mut ranges_procs) = init_players(fc, game);
43
44 ranges_procs.push(board_range_proc);
45 let mut ranges = parallel_exec(ranges_procs)?
46 .into_iter()
47 .collect::<Result<Vec<_>, _>>()?;
48 let board_range = ranges.pop().unwrap();
49
50 let (instructions, store) = instruction::init(sels, game, &player_names)?;
51
52 let player_ranges = ranges;
53
54 Ok(Vm {
55 board_range,
56 player_ranges,
57 instructions,
58 store,
59 stack: VmStack::default(),
60 buffer: VmBuffer::new(player_names.len(), game),
61 rng: Rng::default(),
62 n_trials,
63 n_failed: 0,
64 })
65}
66
67impl<'src> StatementsRunner<'src> {
68 pub const fn new(
69 src: &'src str,
70 n_trials: usize,
71 n_threads: usize,
72 stream_out: Box<dyn io::Write>,
73 stream_err: Box<dyn io::Write>,
74 ) -> Self {
75 Self {
76 src,
77 n_trials,
78 n_threads,
79 stream_out,
80 stream_err,
81 }
82 }
83
84 #[allow(clippy::missing_panics_doc)]
85 pub fn run(&mut self) {
86 match parse(self.src) {
87 Ok(stmts) => {
88 for stmt in &stmts {
89 match self.run_stmt(stmt) {
90 Ok(()) => (),
91 Err(e) => self.report_error(&e),
92 }
93 }
94
95 self.stream_out
96 .flush()
97 .expect("Failed to write to output stream");
98 }
99 Err(e) => {
100 self.report_error(&PQLError::from(e));
101 }
102 }
103 }
104
105 fn run_stmt(&mut self, stmt: &ast::Stmt) -> Result<(), PQLError> {
106 let n = self.n_threads;
107 let n_trials = self.n_trials / n;
108
109 let vm = init_vm(stmt, n_trials)?;
110
111 let mut vms: Vec<Vm> = Vec::with_capacity(n);
112 vms.push(vm);
113
114 while vms.len() != n {
115 vms.push(vms[0].clone());
116 }
117
118 let procs = vms
119 .into_iter()
120 .map(|vm| Box::new(move || vm.try_run()) as _)
121 .collect::<Vec<_>>();
122
123 match (parallel_exec(procs)?)
124 .into_iter()
125 .collect::<Result<Vec<Vm>, PQLError>>()
126 {
127 Ok(vms) => {
128 self.aggregate_outputs(&vms, &stmt.selectors, self.n_trials);
129 }
130
131 Err(e) => self.report_error(&e),
132 }
133
134 Ok(())
135 }
136
137 fn report_error(&mut self, e: &PQLError) {
138 let loc: Option<LocInfo> = e.into();
139
140 writeln!(self.stream_err, "Error:")
141 .and_then(|()| writeln!(self.stream_err, "{e}"))
142 .and_then(|()| {
143 if let Some((a, b)) = loc {
144 writeln!(self.stream_err, "{}", &self.src[a..b])
145 } else {
146 Ok(())
147 }
148 })
149 .expect("Failed to write to error stream");
150 }
151
152 #[allow(clippy::cast_precision_loss)]
153 fn aggregate_outputs(
154 &mut self,
155 vms: &[Vm],
156 selectors: &[ast::Selector],
157 n_trials: usize,
158 ) {
159 fn next_write(iter: &mut Iter<'_, VmInstruction>) -> VmStoreVarIdx {
160 loop {
161 match iter.next() {
162 Some(VmInstruction::Write(idx)) => return *idx,
163 Some(_) => (),
164 None => todo!(),
165 };
166 }
167 }
168
169 let vec_ins = vms[0].instructions.clone();
170 let _game: PQLGame = (&vms[0].buffer).into();
171 let mut iter = vec_ins.iter();
172
173 for (n, selector) in selectors.iter().enumerate() {
174 let idx = next_write(&mut iter);
175
176 let name = selector.alias.as_ref().map_or_else(
177 || format!("{:?}{}", selector.kind, n + 1),
178 |id| id.inner.into(),
179 );
180
181 let res = match selector.kind {
182 ast::SelectorKind::Avg => {
183 let v: VmStackValue = vms
184 .iter()
185 .map(|vm| {
186 *vm.store
187 .downcast_get::<&VmStackValue>(idx)
188 .unwrap()
189 })
190 .reduce(|m, e| m.try_add(e).unwrap())
191 .unwrap();
192
193 let c: PQLLong = vms
194 .iter()
195 .map(|vm| {
196 *vm.store.downcast_get::<&PQLLong>(idx + 1).unwrap()
197 })
198 .sum::<PQLLong>();
199
200 let n: VmStackValueNum = v.try_into().unwrap();
201 writeln!(
202 self.stream_out,
203 "{name} = {}",
204 n.cast_dbl() / c as f64
205 )
206 }
207
208 ast::SelectorKind::Count => {
209 let c: PQLLong = vms
210 .iter()
211 .map(|vm| {
212 *vm.store.downcast_get::<&PQLLong>(idx).unwrap()
213 })
214 .sum::<PQLLong>();
215
216 writeln!(
217 self.stream_out,
218 "{name} = {}%({})",
219 100.0 * c as f64 / n_trials as f64,
220 c
221 )
222 }
223
224 ast::SelectorKind::Max => {
225 let v: VmStackValue = vms
226 .iter()
227 .map(|vm| {
228 *vm.store
229 .downcast_get::<&VmStackValue>(idx)
230 .unwrap()
231 })
232 .reduce(|m, e| (if m >= e { m } else { e }))
233 .unwrap();
234
235 writeln!(self.stream_out, "{name} = {v}")
236 }
237
238 ast::SelectorKind::Min => {
239 let v: VmStackValue = vms
240 .iter()
241 .map(|vm| {
242 *vm.store
243 .downcast_get::<&VmStackValue>(idx)
244 .unwrap()
245 })
246 .reduce(|m, e| (if m <= e { m } else { e }))
247 .unwrap();
248
249 writeln!(self.stream_out, "{name} = {v}")
250 }
251 };
252
253 res.expect("Failed to write to output stream");
254 }
255 }
256}