1use std::{io, slice::Iter, sync::LazyLock};
2
3use derive_more::Debug;
4
5mod init_board;
6mod init_game;
7mod init_players;
8
9use concurrency::parallel_exec;
10use init_board::init_board;
11use init_game::init_game;
12use init_players::init_players;
13use pql_parser::{ast, parse};
14use rustc_hash::*;
15use vm::{Vm, VmStackValue, VmStackValueNum, VmStoreVarIdx, *};
16
17use super::*;
18use crate::{LocInfo, PQLLong, vm::VmInstruction};
19
/// Parses a PQL source string and executes its statements, writing selector
/// results and error reports to caller-supplied streams.
#[derive(Debug)]
pub struct StatementsRunner<'src> {
    /// PQL source text to parse; also sliced to show the offending snippet
    /// when an error carries a location.
    pub(crate) src: &'src str,
    /// Total number of trials requested for each statement; it is divided
    /// among the worker VMs.
    pub n_trials: usize,
    /// Number of worker VMs created for each statement and handed to
    /// `parallel_exec`.
    pub n_threads: usize,
    /// Destination for selector results (excluded from `Debug` output).
    #[debug(skip)]
    pub stream_out: Box<dyn io::Write>,
    /// Destination for error reports (excluded from `Debug` output).
    #[debug(skip)]
    pub stream_err: Box<dyn io::Write>,
}
30
31type RangeProc = Box<dyn FnOnce() -> Result<PQLRange, PQLError> + Send>;
32
33pub fn init_vm(stmt: &ast::Stmt, n_trials: usize) -> Result<Vm, PQLError> {
36 let fc = &stmt.from;
37 let sels = &stmt.selectors;
38
39 let game = init_game(fc)?;
40 let board_range_proc = init_board(fc);
41 let (player_names, mut ranges_procs) = init_players(fc, game);
42
43 ranges_procs.push(board_range_proc);
44 let mut ranges = parallel_exec(ranges_procs)?
45 .into_iter()
46 .collect::<Result<Vec<_>, _>>()?;
47 let board_range = ranges.pop().unwrap();
48
49 let (instructions, store) = instruction::init(sels, game, &player_names)?;
50
51 let player_ranges = ranges;
52
53 Ok(Vm {
54 board_range,
55 player_ranges,
56 instructions,
57 store,
58 stack: VmStack::default(),
59 buffer: VmBuffer::new(player_names.len(), game),
60 rng: Rng::default(),
61 n_trials,
62 n_failed: 0,
63 })
64}
65
66impl<'src> StatementsRunner<'src> {
67 pub const fn new(
68 src: &'src str,
69 n_trials: usize,
70 n_threads: usize,
71 stream_out: Box<dyn io::Write>,
72 stream_err: Box<dyn io::Write>,
73 ) -> Self {
74 Self {
75 src,
76 n_trials,
77 n_threads,
78 stream_out,
79 stream_err,
80 }
81 }
82
83 #[allow(clippy::missing_panics_doc)]
84 pub fn run(&mut self) {
85 match parse(self.src) {
86 Ok(stmts) => {
87 for stmt in &stmts {
88 match self.run_stmt(stmt) {
89 Ok(()) => (),
90 Err(e) => self.report_error(&e),
91 }
92 }
93
94 self.stream_out
95 .flush()
96 .expect("Failed to write to output stream");
97 }
98 Err(e) => {
99 self.report_error(&PQLError::from(e));
100 }
101 }
102 }
103
104 fn run_stmt(&mut self, stmt: &ast::Stmt) -> Result<(), PQLError> {
105 let n = self.n_threads;
106 let n_trials = self.n_trials / n;
107
108 let vm = init_vm(stmt, n_trials)?;
109
110 let mut vms: Vec<Vm> = Vec::with_capacity(n);
111 vms.push(vm);
112
113 while vms.len() != n {
114 vms.push(vms[0].clone());
115 }
116
117 let procs = vms
118 .into_iter()
119 .map(|vm| Box::new(move || vm.try_run()) as _)
120 .collect::<Vec<_>>();
121
122 match (parallel_exec(procs)?)
123 .into_iter()
124 .collect::<Result<Vec<Vm>, PQLError>>()
125 {
126 Ok(vms) => {
127 self.aggregate_outputs(&vms, &stmt.selectors, self.n_trials);
128 }
129
130 Err(e) => self.report_error(&e),
131 }
132
133 Ok(())
134 }
135
136 fn report_error(&mut self, e: &PQLError) {
137 let loc: Option<LocInfo> = e.into();
138
139 writeln!(self.stream_err, "Error:")
140 .and_then(|()| writeln!(self.stream_err, "{e}"))
141 .and_then(|()| {
142 if let Some((a, b)) = loc {
143 writeln!(self.stream_err, "{}", &self.src[a..b])
144 } else {
145 Ok(())
146 }
147 })
148 .expect("Failed to write to error stream");
149 }
150
151 #[allow(clippy::cast_precision_loss)]
152 fn aggregate_outputs(
153 &mut self,
154 vms: &[Vm],
155 selectors: &[ast::Selector],
156 n_trials: usize,
157 ) {
158 fn next_write(iter: &mut Iter<'_, VmInstruction>) -> VmStoreVarIdx {
159 loop {
160 match iter.next() {
161 Some(VmInstruction::Write(idx)) => return *idx,
162 Some(_) => (),
163 None => todo!(),
164 }
165 }
166 }
167
168 let vec_ins = vms[0].instructions.clone();
169 let _game: PQLGame = (&vms[0].buffer).into();
170 let mut iter = vec_ins.iter();
171
172 for (n, selector) in selectors.iter().enumerate() {
173 let idx = next_write(&mut iter);
174
175 let name = selector.alias.as_ref().map_or_else(
176 || format!("{:?}{}", selector.kind, n + 1),
177 |id| id.inner.into(),
178 );
179
180 let res = match selector.kind {
181 ast::SelectorKind::Avg => {
182 let v: VmStackValue = vms
183 .iter()
184 .map(|vm| {
185 *vm.store
186 .downcast_get::<&VmStackValue>(idx)
187 .unwrap()
188 })
189 .reduce(|m, e| m.try_add(e).unwrap())
190 .unwrap();
191
192 let c: PQLLong = vms
193 .iter()
194 .map(|vm| {
195 *vm.store.downcast_get::<&PQLLong>(idx + 1).unwrap()
196 })
197 .sum::<PQLLong>();
198
199 let n: VmStackValueNum = v.try_into().unwrap();
200 writeln!(
201 self.stream_out,
202 "{name} = {}",
203 n.cast_dbl() / c as f64
204 )
205 }
206
207 ast::SelectorKind::Count => {
208 let c: PQLLong = vms
209 .iter()
210 .map(|vm| {
211 *vm.store.downcast_get::<&PQLLong>(idx).unwrap()
212 })
213 .sum::<PQLLong>();
214
215 writeln!(
216 self.stream_out,
217 "{name} = {}%({})",
218 100.0 * c as f64 / n_trials as f64,
219 c
220 )
221 }
222
223 ast::SelectorKind::Max => {
224 let v: VmStackValue = vms
225 .iter()
226 .map(|vm| {
227 *vm.store
228 .downcast_get::<&VmStackValue>(idx)
229 .unwrap()
230 })
231 .reduce(|m, e| if m >= e { m } else { e })
232 .unwrap();
233
234 writeln!(self.stream_out, "{name} = {v}")
235 }
236
237 ast::SelectorKind::Min => {
238 let v: VmStackValue = vms
239 .iter()
240 .map(|vm| {
241 *vm.store
242 .downcast_get::<&VmStackValue>(idx)
243 .unwrap()
244 })
245 .reduce(|m, e| if m <= e { m } else { e })
246 .unwrap();
247
248 writeln!(self.stream_out, "{name} = {v}")
249 }
250 };
251
252 res.expect("Failed to write to output stream");
253 }
254 }
255}