1#![allow(clippy::wildcard_imports)]
2
3use std::{io, slice::Iter, sync::LazyLock};
4
5use derive_more::Debug;
6
7mod init_board;
8mod init_game;
9mod init_players;
10
11use concurrency::parallel_exec;
12use init_board::init_board;
13use init_game::init_game;
14use init_players::init_players;
15use pql_parser::{ast, parse};
16use rustc_hash::*;
17use vm::{Vm, VmStackValue, VmStackValueNum, VmStoreVarIdx, *};
18
19use super::*;
20use crate::{LocInfo, PQLLong, vm::VmInstruction};
21
22#[derive(Debug)]
23pub struct StatementsRunner<'src> {
24 pub(crate) src: &'src str,
25 pub n_trials: usize,
26 pub n_threads: usize,
27 #[debug(skip)]
28 pub stream_out: Box<dyn io::Write>,
29 #[debug(skip)]
30 pub stream_err: Box<dyn io::Write>,
31}
32
33type RangeProc = Box<dyn FnOnce() -> Result<PQLRange, PQLError> + Send>;
34
35pub fn init_vm(stmt: &ast::Stmt, n_trials: usize) -> Result<Vm, PQLError> {
38 let fc = &stmt.from;
39 let sels = &stmt.selectors;
40
41 let game = init_game(fc)?;
42 let board_range_proc = init_board(fc);
43 let (player_names, mut ranges_procs) = init_players(fc, game);
44
45 ranges_procs.push(board_range_proc);
46 let mut ranges = parallel_exec(ranges_procs)?
47 .into_iter()
48 .collect::<Result<Vec<_>, _>>()?;
49 let board_range = ranges.pop().unwrap();
50
51 let (instructions, store) = instruction::init(sels, game, &player_names)?;
52
53 let player_ranges = ranges;
54
55 Ok(Vm {
56 board_range,
57 player_ranges,
58 instructions,
59 store,
60 stack: VmStack::default(),
61 buffer: VmBuffer::new(player_names.len(), game),
62 rng: Rng::default(),
63 n_trials,
64 n_failed: 0,
65 })
66}
67
68impl<'src> StatementsRunner<'src> {
69 pub const fn new(
70 src: &'src str,
71 n_trials: usize,
72 n_threads: usize,
73 stream_out: Box<dyn io::Write>,
74 stream_err: Box<dyn io::Write>,
75 ) -> Self {
76 Self {
77 src,
78 n_trials,
79 n_threads,
80 stream_out,
81 stream_err,
82 }
83 }
84
85 #[allow(clippy::missing_panics_doc)]
86 pub fn run(&mut self) {
87 match parse(self.src) {
88 Ok(stmts) => {
89 for stmt in &stmts {
90 match self.run_stmt(stmt) {
91 Ok(()) => (),
92 Err(e) => self.report_error(&e),
93 }
94 }
95
96 self.stream_out
97 .flush()
98 .expect("Failed to write to output stream");
99 }
100 Err(e) => {
101 self.report_error(&PQLError::from(e));
102 }
103 }
104 }
105
106 fn run_stmt(&mut self, stmt: &ast::Stmt) -> Result<(), PQLError> {
107 let n = self.n_threads;
108 let n_trials = self.n_trials / n;
109
110 let vm = init_vm(stmt, n_trials)?;
111
112 let mut vms: Vec<Vm> = Vec::with_capacity(n);
113 vms.push(vm);
114
115 while vms.len() != n {
116 vms.push(vms[0].clone());
117 }
118
119 let procs = vms
120 .into_iter()
121 .map(|vm| Box::new(move || vm.try_run()) as _)
122 .collect::<Vec<_>>();
123
124 match (parallel_exec(procs)?)
125 .into_iter()
126 .collect::<Result<Vec<Vm>, PQLError>>()
127 {
128 Ok(vms) => {
129 self.aggregate_outputs(&vms, &stmt.selectors, self.n_trials);
130 }
131
132 Err(e) => self.report_error(&e),
133 }
134
135 Ok(())
136 }
137
138 fn report_error(&mut self, e: &PQLError) {
139 let loc: Option<LocInfo> = e.into();
140
141 writeln!(self.stream_err, "Error:")
142 .and_then(|()| writeln!(self.stream_err, "{e}"))
143 .and_then(|()| {
144 if let Some((a, b)) = loc {
145 writeln!(self.stream_err, "{}", &self.src[a..b])
146 } else {
147 Ok(())
148 }
149 })
150 .expect("Failed to write to error stream");
151 }
152
153 #[allow(clippy::cast_precision_loss)]
154 fn aggregate_outputs(
155 &mut self,
156 vms: &[Vm],
157 selectors: &[ast::Selector],
158 n_trials: usize,
159 ) {
160 fn next_write(iter: &mut Iter<'_, VmInstruction>) -> VmStoreVarIdx {
161 loop {
162 match iter.next() {
163 Some(VmInstruction::Write(idx)) => return *idx,
164 Some(_) => (),
165 None => todo!(),
166 }
167 }
168 }
169
170 let vec_ins = vms[0].instructions.clone();
171 let _game: PQLGame = (&vms[0].buffer).into();
172 let mut iter = vec_ins.iter();
173
174 for (n, selector) in selectors.iter().enumerate() {
175 let idx = next_write(&mut iter);
176
177 let name = selector.alias.as_ref().map_or_else(
178 || format!("{:?}{}", selector.kind, n + 1),
179 |id| id.inner.into(),
180 );
181
182 let res = match selector.kind {
183 ast::SelectorKind::Avg => {
184 let v: VmStackValue = vms
185 .iter()
186 .map(|vm| {
187 *vm.store
188 .downcast_get::<&VmStackValue>(idx)
189 .unwrap()
190 })
191 .reduce(|m, e| m.try_add(e).unwrap())
192 .unwrap();
193
194 let c: PQLLong = vms
195 .iter()
196 .map(|vm| {
197 *vm.store.downcast_get::<&PQLLong>(idx + 1).unwrap()
198 })
199 .sum::<PQLLong>();
200
201 let n: VmStackValueNum = v.try_into().unwrap();
202 writeln!(
203 self.stream_out,
204 "{name} = {}",
205 n.cast_dbl() / c as f64
206 )
207 }
208
209 ast::SelectorKind::Count => {
210 let c: PQLLong = vms
211 .iter()
212 .map(|vm| {
213 *vm.store.downcast_get::<&PQLLong>(idx).unwrap()
214 })
215 .sum::<PQLLong>();
216
217 writeln!(
218 self.stream_out,
219 "{name} = {}%({})",
220 100.0 * c as f64 / n_trials as f64,
221 c
222 )
223 }
224
225 ast::SelectorKind::Max => {
226 let v: VmStackValue = vms
227 .iter()
228 .map(|vm| {
229 *vm.store
230 .downcast_get::<&VmStackValue>(idx)
231 .unwrap()
232 })
233 .reduce(|m, e| if m >= e { m } else { e })
234 .unwrap();
235
236 writeln!(self.stream_out, "{name} = {v}")
237 }
238
239 ast::SelectorKind::Min => {
240 let v: VmStackValue = vms
241 .iter()
242 .map(|vm| {
243 *vm.store
244 .downcast_get::<&VmStackValue>(idx)
245 .unwrap()
246 })
247 .reduce(|m, e| if m <= e { m } else { e })
248 .unwrap();
249
250 writeln!(self.stream_out, "{name} = {v}")
251 }
252 };
253
254 res.expect("Failed to write to output stream");
255 }
256 }
257}