1use anyhow::{anyhow, Result};
2use std::collections::HashSet;
3use std::io::{Read, Write};
4
5use crate::amd::{AmdFamily, AmdGenerator};
6use crate::applet::Applet;
7use crate::arm::{ArmGenerator, ArmSimdGenerator};
8use crate::complexify::Complexifier;
9use crate::config::Config;
10use crate::generator::Generator;
11use crate::machine::MachineCode;
12use crate::matrix::{combine_matrixes, Matrix};
13use crate::mir::{CompiledMir, Mir};
14use crate::model::Program;
15use crate::riscv64::RiscV;
16use crate::symbol::Loc;
17use crate::utils::*;
18
19use rayon::prelude::*;
20
/// Selects which backend executes the compiled program.
///
/// NOTE(review): `Native` and `Amd` currently fall into the
/// "unrecognized `ty`" arm of `Application::new` — presumably they are
/// resolved to a concrete backend before reaching it; confirm.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum CompilerType {
    /// Portable bytecode interpreter; no native code is generated.
    ByteCode,
    Native,
    Amd,
    /// x86-64 scalar code using AVX instructions.
    AmdAVX,
    /// x86-64 scalar code using SSE instructions.
    AmdSSE,
    /// AArch64 code generator.
    Arm,
    /// RV64 code generator.
    RiscV,
    /// Deprecated: prints a warning and falls back to the interpreter.
    Debug,
}
41
/// Runtime container for a compiled program plus its execution state.
///
/// `#[repr(C)]` matters: `as_applet` transmutes `&Application` to `&Applet`,
/// so the field order below is part of the ABI. NOTE(review): presumably
/// `Applet` is also `#[repr(C)]` with a matching layout prefix — confirm
/// before reordering any field.
#[repr(C)]
pub struct Application {
    /// Scalar JIT-compiled code, when a native backend was selected.
    pub compiled: Option<MachineCode<f64>>,
    /// Lazily built SIMD variant (see `prepare_simd`).
    pub compiled_simd: Option<MachineCode<f64>>,
    /// SIMD batching enabled (config flag and loop-free program).
    pub use_simd: bool,
    /// Thread fan-out enabled (config flag and small working set).
    pub use_threads: bool,
    pub count_states: usize,
    pub count_params: usize,
    pub count_obs: usize,
    pub count_diffs: usize,
    pub config: Config,
    pub prog: Program,
    /// Lazily built "fast" kernel (see `prepare_fast`).
    pub compiled_fast: Option<MachineCode<f64>>,
    /// Bytecode interpreter; always present as the fallback executor.
    pub bytecode: CompiledMir,
    /// Parameter buffer, sized `count_params + 1`.
    pub params: Vec<f64>,
    /// Whether the program qualifies for the "fast" kernel path.
    pub can_fast: bool,
    // Offsets into the scratch memory: [states | obs | diffs].
    pub first_state: usize,
    pub first_param: usize,
    pub first_obs: usize,
    pub first_diff: usize,
}
67
68impl Application {
    /// Builds an `Application` from a parsed `Program`.
    ///
    /// `reals` marks locations known to stay real-valued; it is consulted
    /// only when the configuration is complex, in which case the MIR is
    /// rewritten by the `Complexifier` first.
    ///
    /// # Errors
    ///
    /// Fails when MIR compilation or complexification fails, or when
    /// `config.compiler_type()` is not one of the handled backends.
    pub fn new(mut prog: Program, reals: HashSet<Loc>) -> Result<Application> {
        // Scratch-memory layout: [states | obs | diffs]. Parameters live in
        // the separate `params` buffer, hence `first_param` is also 0.
        let first_state = 0;
        let first_param = 0;
        let first_obs = first_state + prog.count_states;
        let first_diff = first_obs + prog.count_obs;

        let count_states = prog.count_states;
        let count_params = prog.count_params;
        let count_obs = prog.count_obs;
        let count_diffs = prog.count_diffs;

        // One extra slot beyond the declared parameters — presumably an
        // implicit parameter (e.g. time); TODO confirm.
        let params = vec![0.0; count_params + 1];

        let config = prog.config().clone();

        let mut mir = prog.builder.compile_mir()?;

        if config.is_complex() {
            mir = Complexifier::new(&reals, config.clone()).complexify(&mir)?;
        }

        // Run the native code generator, if one was requested. `ByteCode`
        // (and the deprecated `Debug`) skip JIT entirely.
        let compiled = match config.compiler_type() {
            CompilerType::AmdAVX => Some(Self::compile_avx(&mir, &mut prog)?),
            CompilerType::AmdSSE => Some(Self::compile_sse(&mir, &mut prog)?),
            CompilerType::Arm => Some(Self::compile_arm(&mir, &mut prog)?),
            CompilerType::RiscV => Some(Self::compile_riscv(&mir, &mut prog)?),
            CompilerType::ByteCode => None,
            CompilerType::Debug => {
                println!("`ty = debug` is deprecated");
                None
            }
            // `Native` and `Amd` land here — NOTE(review): confirm they are
            // resolved to a concrete backend before this point.
            _ => return Err(anyhow!("unrecognized `ty`")),
        };

        // SIMD batching only works for loop-free programs; threading is only
        // enabled for small per-column working sets (cutoff: 128 slots).
        let use_simd = config.use_simd() && prog.count_loops == 0;
        let use_threads = config.use_threads() && prog.mem_size() < 128;

        // The "fast" kernel targets tiny closed-form models only.
        let can_fast = config.may_fast()
            && count_states <= 8
            && count_params == 0
            && count_obs == 1
            && count_diffs == 0;

        // The bytecode interpreter is always compiled as the fallback.
        let bytecode = Self::compile_bytecode(mir, &mut prog)?;

        Ok(Application {
            prog,
            compiled,
            compiled_simd: None,
            compiled_fast: None,
            bytecode,
            params,
            use_simd,
            use_threads,
            can_fast,
            first_state,
            first_param,
            first_obs,
            first_diff,
            count_states,
            count_params,
            count_obs,
            count_diffs,
            config,
        })
    }
137
    /// Consumes this application and wraps it in an `Applet`.
    pub fn seal(self) -> Result<Applet> {
        Applet::new(self)
    }
141
    /// Reinterprets `&Application` as `&Applet` without conversion.
    ///
    /// SAFETY: relies on `Application` being `#[repr(C)]` and on `Applet`
    /// having a compatible memory layout. NOTE(review): that layout contract
    /// is not visible in this file — confirm `Applet` is `#[repr(C)]` with a
    /// matching field prefix before changing either struct.
    pub fn as_applet(&self) -> &Applet {
        unsafe { std::mem::transmute(self) }
    }
145
146 fn compile<G: Generator>(
149 mir: &Mir,
150 prog: &mut Program,
151 mut generator: G,
152 size: usize,
153 arch: &str,
154 lanes: usize,
155 ) -> Result<MachineCode<f64>> {
156 let mem: Vec<f64> = vec![0.0; size];
157 prog.builder.compile_from_mir(
158 mir,
159 &mut generator,
160 prog.count_states,
161 prog.count_obs,
162 prog.count_params,
163 )?;
164
165 Ok(MachineCode::new(arch, generator.bytes(), mem, false, lanes))
166 }
167
168 fn compile_fast<G: Generator>(
169 mir: &Mir,
170 prog: &mut Program,
171 mut generator: G,
172 idx_ret: u32,
173 arch: &str,
174 ) -> Result<MachineCode<f64>> {
175 let mem: Vec<f64> = Vec::new();
176 prog.builder.compile_fast_from_mir(
177 mir,
178 &mut generator,
179 prog.count_states,
180 prog.count_obs,
181 idx_ret as i32,
182 )?;
183
184 Ok(MachineCode::new(arch, generator.bytes(), mem, true, 1))
185 }
186
187 fn compile_bytecode(mir: Mir, prog: &mut Program) -> Result<CompiledMir> {
188 let mem: Vec<f64> = vec![0.0; prog.mem_size()];
189 let stack: Vec<f64> = vec![0.0; prog.builder.block().sym_table.num_stack];
190
191 Ok(CompiledMir::new(mir, mem, stack))
192 }
193
194 fn compile_sse(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
195 Self::compile::<AmdGenerator>(
196 mir,
197 prog,
198 AmdGenerator::new(AmdFamily::SSEScalar, prog.config().clone()),
199 prog.mem_size(),
200 "x86_64",
201 1,
202 )
203 }
204
205 fn compile_avx(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
206 Self::compile::<AmdGenerator>(
207 mir,
208 prog,
209 AmdGenerator::new(AmdFamily::AvxScalar, prog.config().clone()),
210 prog.mem_size(),
211 "x86_64",
212 1,
213 )
214 }
215
216 fn compile_avx_simd(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
217 Self::compile::<AmdGenerator>(
218 mir,
219 prog,
220 AmdGenerator::new(AmdFamily::AvxVector, prog.config().clone()),
221 prog.mem_size() * 4,
222 "x86_64",
223 4,
224 )
225 }
226
227 fn compile_arm(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
228 Self::compile::<ArmGenerator>(
229 mir,
230 prog,
231 ArmGenerator::new(prog.config().clone()),
232 prog.mem_size(),
233 "aarch64",
234 1,
235 )
236 }
237
238 fn compile_arm_simd(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
239 Self::compile::<ArmSimdGenerator>(
240 mir,
241 prog,
242 ArmSimdGenerator::new(prog.config().clone()),
243 prog.mem_size() * 2,
244 "aarch64",
245 2,
246 )
247 }
248
249 fn compile_riscv(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
250 Self::compile::<RiscV>(
251 mir,
252 prog,
253 RiscV::new(prog.config().clone()),
254 prog.mem_size(),
255 "riscv64",
256 1,
257 )
258 }
259
260 fn compile_amd_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
261 if prog.config().has_avx() {
262 Self::compile_fast(
263 mir,
264 prog,
265 AmdGenerator::new(AmdFamily::AvxScalar, prog.config().clone()),
266 idx_ret,
267 "x86_64",
268 )
269 } else {
270 Self::compile_fast(
271 mir,
272 prog,
273 AmdGenerator::new(AmdFamily::SSEScalar, prog.config().clone()),
274 idx_ret,
275 "x86_64",
276 )
277 }
278 }
279
280 fn compile_arm_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
281 Self::compile_fast(
282 mir,
283 prog,
284 ArmGenerator::new(prog.config().clone()),
285 idx_ret,
286 "aarch64",
287 )
288 }
289
290 fn compile_riscv_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
291 Self::compile_fast(
292 mir,
293 prog,
294 RiscV::new(prog.config().clone()),
295 idx_ret,
296 "riscv64",
297 )
298 }
299
300 #[inline]
303 pub fn exec(&mut self) {
304 if let Some(compiled) = &mut self.compiled {
305 compiled.exec(&self.params[..])
306 } else {
307 self.bytecode.exec(&self.params[..]);
308 }
309 }
310
311 pub fn exec_callable(&mut self, xx: &[f64]) -> f64 {
312 if let Some(compiled) = &mut self.compiled {
313 let mem = compiled.mem_mut();
314 mem[self.first_state..self.first_state + self.count_states].copy_from_slice(xx);
315 compiled.exec(&self.params[..]);
316 compiled.mem()[self.first_obs]
317 } else {
318 let mem = self.bytecode.mem_mut();
319 mem[self.first_state..self.first_state + self.count_states].copy_from_slice(xx);
320 self.bytecode.exec(&self.params[..]);
321 self.bytecode.mem()[self.first_obs]
322 }
323 }
324
325 pub fn prepare_simd(&mut self) {
326 if self.compiled_simd.is_none() && self.use_simd {
328 if self.config.has_avx() {
329 self.compiled_simd =
330 Self::compile_avx_simd(&self.bytecode.mir, &mut self.prog).ok();
331 } else if self.config.is_arm64() {
332 self.compiled_simd =
333 Self::compile_arm_simd(&self.bytecode.mir, &mut self.prog).ok();
334 }
335 };
336 }
337
338 fn prepare_fast(&mut self) {
339 if self.compiled_simd.is_none() && self.can_fast {
341 if self.config.is_amd64() {
342 self.compiled_fast = Self::compile_amd_fast(
343 &self.bytecode.mir,
344 &mut self.prog,
345 self.first_obs as u32,
346 )
347 .ok();
348 } else if self.config.is_arm64() {
349 self.compiled_fast = Self::compile_arm_fast(
350 &self.bytecode.mir,
351 &mut self.prog,
352 self.first_obs as u32,
353 )
354 .ok();
355 } else if self.config.is_riscv64() {
356 self.compiled_fast = Self::compile_riscv_fast(
357 &self.bytecode.mir,
358 &mut self.prog,
359 self.first_obs as u32,
360 )
361 .ok();
362 }
363 };
364 }
365
366 pub fn get_fast(&mut self) -> Option<CompiledFunc<f64>> {
367 self.prepare_fast();
368 self.compiled_fast.as_ref().map(|c| c.func())
369 }
370
    /// Evaluates the model over every column of `states`, writing results
    /// into `obs`, choosing the fastest available strategy:
    /// column-by-column copy, scalar indirect, or SIMD indirect.
    ///
    /// NOTE(review): when `self.compiled` is `None` this does nothing — the
    /// bytecode fallback is only reachable via `exec_vectorized_simple`;
    /// confirm interpreter-only programs are not expected to take this path.
    pub fn exec_vectorized(&mut self, states: &mut Matrix, obs: &mut Matrix) {
        if let Some(compiled) = &self.compiled {
            // Kernels without indirect addressing must copy each column in
            // and out of scratch memory.
            if !compiled.support_indirect() {
                self.exec_vectorized_simple(states, obs);
                return;
            }

            // Lazily build the SIMD variant (no-op if disabled/unavailable).
            self.prepare_simd();

            if let Some(simd) = &self.compiled_simd {
                self.exec_vectorized_simd(states, obs, self.use_threads, simd.count_lanes());
            } else {
                self.exec_vectorized_scalar(states, obs, self.use_threads);
            }
        }
    }
387
    /// Column-by-column execution for kernels without indirect addressing:
    /// each column of `states` is copied into scratch memory, the model is
    /// run once, and the observables are copied back out into `obs`.
    ///
    /// Used for both the JIT-compiled and the interpreted paths.
    pub fn exec_vectorized_simple(&mut self, states: &Matrix, obs: &mut Matrix) {
        assert!(states.ncols == obs.ncols);
        let n = states.ncols;
        let params = &self.params[..];

        if let Some(compiled) = &mut self.compiled {
            for t in 0..n {
                // Inner braces end the `mem_mut` borrow before `exec` needs
                // to borrow `compiled` again.
                {
                    let mem = compiled.mem_mut();
                    for i in 0..self.count_states {
                        mem[self.first_state + i] = states.get(i, t);
                    }
                }

                compiled.exec(params);

                {
                    let mem = compiled.mem_mut();
                    for i in 0..self.count_obs {
                        obs.set(i, t, mem[self.first_obs + i]);
                    }
                }
            }
        } else {
            // Same copy-run-copy loop, driven by the bytecode interpreter.
            for t in 0..n {
                {
                    let mem = self.bytecode.mem_mut();
                    for i in 0..self.count_states {
                        mem[self.first_state + i] = states.get(i, t);
                    }
                }

                self.bytecode.exec(params);

                {
                    let mem = self.bytecode.mem_mut();
                    for i in 0..self.count_obs {
                        obs.set(i, t, mem[self.first_obs + i]);
                    }
                }
            }
        }
    }
431
    /// Invokes compiled kernel `f` for column `t` of the combined
    /// states/observables matrix `v`.
    ///
    /// The kernel receives a null first pointer and reads/writes through the
    /// matrix data pointer `p` instead — NOTE(review): presumably the null
    /// slot is the (unused) scratch-memory base for indirect kernels;
    /// confirm against the code generators.
    fn exec_single(t: usize, v: &Matrix, params: &[f64], f: CompiledFunc<f64>) {
        let p = v.p.as_ptr();
        f(std::ptr::null(), p, t, params.as_ptr());
    }
436
437 pub fn exec_vectorized_scalar(&mut self, states: &mut Matrix, obs: &mut Matrix, threads: bool) {
438 if let Some(compiled) = &mut self.compiled {
439 assert!(states.ncols == obs.ncols);
440 let n = states.ncols;
441 let f = compiled.func();
442 let params = &self.params[..];
443 let v = combine_matrixes(states, obs);
444
445 if threads {
446 (0..n)
447 .into_par_iter()
448 .for_each(|t| Self::exec_single(t, &v, params, f));
449 } else {
450 (0..n)
451 .for_each(|t| Self::exec_single(t, &v, params, f));
453 }
454 }
455 }
456
    /// Runs the SIMD kernel over `l`-wide column chunks, then finishes the
    /// leftover columns with the scalar kernel.
    ///
    /// `l` is the SIMD lane count; chunk `t` covers columns `t*l..(t+1)*l`.
    ///
    /// NOTE(review): if `compiled_simd` is `None`, only the tail `n0..n` is
    /// processed and the first `n0` columns are skipped — callers (see
    /// `exec_vectorized`) must only choose this path when the SIMD code
    /// exists; confirm no other caller can reach it without that guarantee.
    pub fn exec_vectorized_simd(
        &mut self,
        states: &mut Matrix,
        obs: &mut Matrix,
        threads: bool,
        l: usize,
    ) {
        if let Some(compiled) = &mut self.compiled {
            assert!(states.ncols == obs.ncols);
            let n = states.ncols;
            let params = &self.params[..];
            // First column not covered by a full l-wide chunk.
            let n0 = l * (n / l);
            let v = combine_matrixes(states, obs);

            // Main body: one SIMD kernel call per l-column chunk.
            if let Some(g) = &mut self.compiled_simd {
                let f = g.func();
                if threads {
                    (0..n / l)
                        .into_par_iter()
                        .for_each(|t| Self::exec_single(t, &v, params, f));
                } else {
                    (0..n / l).for_each(|t| Self::exec_single(t, &v, params, f));
                }
            }

            // Tail: remaining n - n0 columns with the scalar kernel.
            let f = compiled.func();

            if threads {
                (n0..n)
                    .into_par_iter()
                    .for_each(|t| Self::exec_single(t, &v, params, f));
            } else {
                (n0..n).for_each(|t| Self::exec_single(t, &v, params, f));
            }
        }
    }
493
494 pub fn dump(&mut self, name: &str, what: &str) -> bool {
495 match what {
496 "scalar" => {
497 if let Some(f) = &self.compiled {
498 f.dump(name);
499 true
500 } else {
501 false
502 }
503 }
504 "simd" => {
505 self.prepare_simd();
506
507 if let Some(f) = &self.compiled_simd {
508 f.dump(name);
509 true
510 } else {
511 false
512 }
513 }
514 "fast" => {
515 self.prepare_fast();
516
517 if let Some(f) = &self.compiled_fast {
518 f.dump(name);
519 true
520 } else {
521 false
522 }
523 }
524 _ => false,
525 }
526 }
527
528 pub fn dumps(&self) -> Vec<u8> {
529 if let Some(f) = &self.compiled {
530 f.dumps()
531 } else {
532 Vec::new()
533 }
534 }
535
    /// File-format magic number written/checked by `save`/`load`.
    /// NOTE(review): a `usize` this large only compiles on 64-bit targets —
    /// presumably only 64-bit is supported; confirm or use `u64`.
    const MAGIC: usize = 0x40568795410d08e9;
539}
540
541impl Storage for Application {
542 fn save(&self, stream: &mut impl Write) -> Result<()> {
543 stream.write_all(&Self::MAGIC.to_le_bytes())?;
544
545 let version: usize = 1;
546 stream.write_all(&version.to_le_bytes())?;
547
548 self.prog.save(stream)?;
549
550 let mut mask: usize = 0;
551
552 if self.compiled.is_some() && self.compiled.as_ref().unwrap().as_machine().is_some() {
553 mask |= 1;
554 };
555
556 if self.compiled_fast.is_some()
557 && self.compiled_fast.as_ref().unwrap().as_machine().is_some()
558 {
559 mask |= 2;
560 }
561
562 if self.compiled_simd.is_some()
563 && self.compiled_simd.as_ref().unwrap().as_machine().is_some()
564 {
565 mask |= 4;
566 }
567
568 stream.write_all(&mask.to_le_bytes())?;
569
570 if let Some(compiled) = &self.compiled {
571 compiled.as_machine().unwrap().save(stream)?;
572 }
573
574 if let Some(compiled) = &self.compiled_fast {
575 compiled.as_machine().unwrap().save(stream)?;
576 }
577
578 if let Some(compiled) = &self.compiled_simd {
579 compiled.as_machine().unwrap().save(stream)?;
580 }
581
582 Ok(())
583 }
584
    /// Deserializes an `Application` previously written by `save`.
    ///
    /// Reads magic, version, the `Program`, the presence mask, and then the
    /// machine-code sections in the same fixed order `save` wrote them
    /// (scalar, fast, simd). The derived fields below must be kept in sync
    /// with the identical computations in `Application::new`.
    ///
    /// # Errors
    ///
    /// Fails on a wrong magic number, an unsupported version, or any
    /// underlying read/parse error.
    fn load(stream: &mut impl Read) -> Result<Self> {
        // Header words are 8-byte little-endian — NOTE(review): decoding
        // them into `usize` assumes a 64-bit target.
        let mut bytes: [u8; 8] = [0; 8];

        stream.read_exact(&mut bytes)?;

        if usize::from_le_bytes(bytes) != Self::MAGIC {
            return Err(anyhow!("invalid magic number"));
        }

        stream.read_exact(&mut bytes)?;

        if usize::from_le_bytes(bytes) != 1 {
            return Err(anyhow!("invalid sjb version"));
        }

        let mut prog = Program::load(stream)?;

        stream.read_exact(&mut bytes)?;
        let mask = usize::from_le_bytes(bytes);

        // Mask bits announce which sections follow: 1 = scalar, 2 = fast,
        // 4 = simd — same encoding and order as in `save`.
        let compiled: Option<MachineCode<f64>> = if mask & 1 != 0 {
            Some(MachineCode::load(stream)?)
        } else {
            None
        };

        let compiled_fast: Option<MachineCode<f64>> = if mask & 2 != 0 {
            Some(MachineCode::load(stream)?)
        } else {
            None
        };

        let compiled_simd: Option<MachineCode<f64>> = if mask & 4 != 0 {
            Some(MachineCode::load(stream)?)
        } else {
            None
        };

        // Scratch-memory layout [states | obs | diffs]; see `new`.
        let first_state = 0;
        let first_param = 0;
        let first_obs = first_state + prog.count_states;
        let first_diff = first_obs + prog.count_obs;

        let count_states = prog.count_states;
        let count_params = prog.count_params;
        let count_obs = prog.count_obs;
        let count_diffs = prog.count_diffs;

        let params = vec![0.0; count_params + 1];

        let config = prog.config().clone();
        // The MIR itself is not serialized; an empty one backs the
        // interpreter after a load.
        let mir = Mir::new(config.clone());

        let use_simd = config.use_simd() && prog.count_loops == 0;
        let use_threads = config.use_threads() && prog.mem_size() < 128;

        let can_fast = config.may_fast()
            && count_states <= 8
            && count_params == 0
            && count_obs == 1
            && count_diffs == 0;

        let bytecode = Self::compile_bytecode(mir, &mut prog)?;

        Ok(Application {
            prog,
            compiled,
            compiled_simd,
            compiled_fast,
            bytecode,
            params,
            use_simd,
            use_threads,
            can_fast,
            first_state,
            first_param,
            first_obs,
            first_diff,
            count_states,
            count_params,
            count_obs,
            count_diffs,
            config,
        })
    }
670}