Skip to main content

symjit/
runnable.rs

1use anyhow::{anyhow, Result};
2use std::collections::HashSet;
3use std::io::{Read, Write};
4
5use crate::amd::{AmdComplexGenerator, AmdSSEGenerator, AmdScalarGenerator, AmdVectorGenerator};
6use crate::applet::Applet;
7use crate::arm::{ArmComplexGenerator, ArmGenerator, ArmSimdGenerator};
8use crate::complexify::Complexifier;
9use crate::config::Config;
10use crate::generator::Generator;
11use crate::machine::MachineCode;
12use crate::matrix::{combine_matrixes, Matrix};
13use crate::mir::{CompiledMir, Mir};
14use crate::model::Program;
15use crate::riscv64::RiscV;
16use crate::symbol::Loc;
17use crate::utils::*;
18
19use rayon::prelude::*;
20
/// Selects the backend used to turn a model into something executable.
///
/// NOTE(review): how `Native` and `Amd` get resolved to a concrete backend is
/// not visible in this file — confirm against `Config::compiler_type`.
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum CompilerType {
    /// Generates bytecode that is executed by the interpreter.
    ByteCode,
    /// Generates code for the detected CPU (default).
    Native,
    /// Generates x86-64 (AMD64) code.
    Amd,
    /// Generates AVX code for the x86-64 architecture.
    AmdAVX,
    /// Generates SSE2 code for the x86-64 architecture.
    AmdSSE,
    /// Generates aarch64 (ARM64) code.
    Arm,
    /// Generates riscv64 (RISC-V) code.
    RiscV,
    /// Debug mode: generates both bytecode and native code
    /// and compares the outputs.
    Debug,
}
41
/// A compiled model together with everything needed to run, recompile,
/// dump, and serialize it.
///
/// The leading fields mirror the layout of `Applet`; `as_applet` relies on
/// that to reinterpret a `&Application` as a `&Applet`.
#[repr(C)] // to ensure binary compatibility with Applet
pub struct Application {
    // Applet compatibility
    // Important! The order of these fields is critical and should be
    // the same as the order of Applet fields.
    /// Scalar native code; `None` when the compiler type produces no native
    /// code (bytecode / debug), in which case `bytecode` is executed instead.
    pub compiled: Option<MachineCode<f64>>,
    /// SIMD native code, built lazily by `prepare_simd`.
    pub compiled_simd: Option<MachineCode<f64>>,
    /// `config.use_simd()` and the program has no loops.
    pub use_simd: bool,
    /// Copy of `config.use_threads()`.
    pub use_threads: bool,
    /// Number of state variables, copied from the program.
    pub count_states: usize,
    /// Number of parameters, copied from the program.
    pub count_params: usize,
    /// Number of observables, copied from the program.
    pub count_obs: usize,
    /// Number of differential equations, copied from the program.
    pub count_diffs: usize,
    /// Clone of the program configuration taken at construction time.
    pub config: Config,
    // Non-Applet fields
    /// The program this application was built from.
    pub prog: Program,
    /// Fast-call native code, built lazily by `prepare_fast`.
    pub compiled_fast: Option<MachineCode<f64>>,
    /// Interpreter form of the MIR; also owns the MIR used for lazy compiles.
    pub bytecode: CompiledMir,
    /// Parameter values passed to every execution; starts as
    /// `count_params + 1` zeros.
    pub params: Vec<f64>,
    /// Whether a fast-call function may be generated (at most 8 states,
    /// no parameters, exactly one observable, no differential equations, and
    /// `config.may_fast()`).
    pub can_fast: bool,
    /// Index of the first state slot in execution memory (always 0).
    pub first_state: usize,
    /// Index of the first parameter (always 0).
    pub first_param: usize,
    /// Index of the first observable slot: `first_state + count_states`.
    pub first_obs: usize,
    /// Index of the first differential slot: `first_obs + count_obs`.
    pub first_diff: usize,
    /// Locations the caller declared as real-valued; forwarded to the
    /// complexifier and persisted by `save`.
    pub reals: HashSet<Loc>,
    /// Copy of the MIR taken before complexification; `Some` only when the
    /// configuration is complex.
    pub original: Option<Mir>,
}
69
70impl Application {
71    pub fn new(mut prog: Program, reals: HashSet<Loc>) -> Result<Application> {
72        /*
73         * Stop-gap measure. A better solution would be to add `times_real`,
74         * `divide_real`, and `load_param_real` to generators.
75         */
76        if !reals.is_empty() {
77            prog.builder.config.set_fast_complex(false);
78        }
79
80        let mut mir = Mir::new(prog.config().clone());
81        prog.builder.compile_mir(&mut mir)?;
82        prog.builder.optimize_mir(&mut mir)?;
83        Self::with_mir(prog, reals, mir)
84    }
85
86    pub fn with_mir(mut prog: Program, reals: HashSet<Loc>, mut mir: Mir) -> Result<Application> {
87        let first_state = 0;
88        let first_param = 0;
89        let first_obs = first_state + prog.count_states;
90        let first_diff = first_obs + prog.count_obs;
91
92        let count_states = prog.count_states;
93        let count_params = prog.count_params;
94        let count_obs = prog.count_obs;
95        let count_diffs = prog.count_diffs;
96
97        let params = vec![0.0; count_params + 1];
98
99        let config = prog.config().clone();
100        let mut original: Option<Mir> = None;
101        let compiled: Option<MachineCode<f64>>;
102
103        if config.is_complex() {
104            original = Some(mir.clone());
105            let complexified = Complexifier::new(&reals, config.clone()).complexify(&mir)?;
106
107            if config.fast_complex() {
108                /*
109                crate::allocator::GreedyAllocator::new(config.clone(), config.available_registers() as usize - 4)
110                    .optimize(&mut mir)?;
111                */
112                compiled = Self::compile_ty(&config, &mir, &mut prog)?;
113            } else {
114                compiled = Self::compile_ty(&config, &complexified, &mut prog)?;
115            }
116
117            mir = complexified;
118        } else {
119            compiled = Self::compile_ty(&config, &mir, &mut prog)?;
120        }
121
122        let use_simd = config.use_simd() && prog.count_loops == 0;
123        let use_threads = config.use_threads();
124
125        let can_fast = config.may_fast()
126            && count_states <= 8
127            && count_params == 0
128            && count_obs == 1
129            && count_diffs == 0;
130
131        // bytecode takes the ownership of mir
132        let bytecode = Self::compile_bytecode(mir, &mut prog)?;
133
134        Ok(Application {
135            prog,
136            compiled,
137            compiled_simd: None,
138            compiled_fast: None,
139            bytecode,
140            params,
141            use_simd,
142            use_threads,
143            can_fast,
144            first_state,
145            first_param,
146            first_obs,
147            first_diff,
148            count_states,
149            count_params,
150            count_obs,
151            count_diffs,
152            config,
153            reals,
154            original,
155        })
156    }
157
158    fn compile_ty(
159        config: &Config,
160        mir: &Mir,
161        prog: &mut Program,
162    ) -> Result<Option<MachineCode<f64>>> {
163        let compiled = match config.compiler_type() {
164            CompilerType::AmdAVX => Some(Self::compile_avx(mir, prog)?),
165            CompilerType::AmdSSE => Some(Self::compile_sse(mir, prog)?),
166            CompilerType::Arm => Some(Self::compile_arm(mir, prog)?),
167            CompilerType::RiscV => Some(Self::compile_riscv(mir, prog)?),
168            CompilerType::ByteCode => None,
169            CompilerType::Debug => {
170                println!("`ty = debug` is deprecated");
171                None
172            }
173            _ => return Err(anyhow!("unrecognized `ty`")),
174        };
175
176        Ok(compiled)
177    }
178
    /// Consumes this `Application` and converts it into an [`Applet`] by
    /// calling `Applet::new`.
    ///
    /// # Errors
    /// Propagates whatever error `Applet::new` returns.
    pub fn seal(self) -> Result<Applet> {
        Applet::new(self)
    }
182
183    pub fn as_applet(&self) -> &Applet {
184        unsafe { std::mem::transmute(self) }
185    }
186
187    /********************* compile_* functions *************************/
188
189    fn compile<G: Generator>(
190        mir: &Mir,
191        prog: &mut Program,
192        mut generator: G,
193        size: usize,
194        arch: &str,
195        lanes: usize,
196    ) -> Result<MachineCode<f64>> {
197        let mem: Vec<f64> = vec![0.0; size];
198        prog.builder.compile_from_mir(
199            mir,
200            &mut generator,
201            prog.count_states,
202            prog.count_obs,
203            prog.count_params,
204        )?;
205
206        Ok(MachineCode::new(
207            arch,
208            generator.bytes(),
209            mem,
210            false,
211            lanes,
212            prog.config().huge(),
213        ))
214    }
215
216    fn compile_fast<G: Generator>(
217        mir: &Mir,
218        prog: &mut Program,
219        mut generator: G,
220        idx_ret: u32,
221        arch: &str,
222    ) -> Result<MachineCode<f64>> {
223        let mem: Vec<f64> = Vec::new();
224        prog.builder.compile_fast_from_mir(
225            mir,
226            &mut generator,
227            prog.count_states,
228            prog.count_obs,
229            idx_ret as i32,
230        )?;
231
232        Ok(MachineCode::new(
233            arch,
234            generator.bytes(),
235            mem,
236            true,
237            1,
238            prog.config().huge(),
239        ))
240    }
241
242    fn compile_bytecode(mir: Mir, prog: &mut Program) -> Result<CompiledMir> {
243        let mem: Vec<f64> = vec![0.0; prog.mem_size()];
244        let stack: Vec<f64> = vec![0.0; prog.builder.stack_size()];
245
246        Ok(CompiledMir::new(mir, mem, stack))
247    }
248
249    fn compile_sse(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
250        Self::compile::<AmdSSEGenerator>(
251            mir,
252            prog,
253            AmdSSEGenerator::new(prog.config().clone()),
254            prog.mem_size(),
255            "x86_64",
256            1,
257        )
258    }
259
260    fn compile_avx(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
261        if prog.config().is_complex() && prog.config().fast_complex() {
262            Self::compile::<AmdComplexGenerator>(
263                mir,
264                prog,
265                AmdComplexGenerator::new(prog.config().clone()),
266                prog.mem_size(),
267                "x86_64",
268                1,
269            )
270        } else {
271            Self::compile::<AmdScalarGenerator>(
272                mir,
273                prog,
274                AmdScalarGenerator::new(prog.config().clone()),
275                prog.mem_size(),
276                "x86_64",
277                1,
278            )
279        }
280    }
281
282    fn compile_avx_simd(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
283        Self::compile::<AmdVectorGenerator>(
284            mir,
285            prog,
286            AmdVectorGenerator::new(prog.config().clone()),
287            prog.mem_size() * 4,
288            "x86_64",
289            4,
290        )
291    }
292
293    fn compile_arm(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
294        if prog.config().is_complex() && prog.config().fast_complex() {
295            Self::compile::<ArmComplexGenerator>(
296                mir,
297                prog,
298                ArmComplexGenerator::new(prog.config().clone()),
299                prog.mem_size(),
300                "aarch64",
301                1,
302            )
303        } else {
304            Self::compile::<ArmGenerator>(
305                mir,
306                prog,
307                ArmGenerator::new(prog.config().clone()),
308                prog.mem_size(),
309                "aarch64",
310                1,
311            )
312        }
313    }
314
315    fn compile_arm_simd(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
316        Self::compile::<ArmSimdGenerator>(
317            mir,
318            prog,
319            ArmSimdGenerator::new(prog.config().clone()),
320            prog.mem_size() * 2,
321            "aarch64",
322            2,
323        )
324    }
325
326    fn compile_riscv(mir: &Mir, prog: &mut Program) -> Result<MachineCode<f64>> {
327        Self::compile::<RiscV>(
328            mir,
329            prog,
330            RiscV::new(prog.config().clone()),
331            prog.mem_size(),
332            "riscv64",
333            1,
334        )
335    }
336
337    fn compile_amd_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
338        if prog.config().has_avx() {
339            Self::compile_fast(
340                mir,
341                prog,
342                AmdScalarGenerator::new(prog.config().clone()),
343                idx_ret,
344                "x86_64",
345            )
346        } else {
347            Self::compile_fast(
348                mir,
349                prog,
350                AmdSSEGenerator::new(prog.config().clone()),
351                idx_ret,
352                "x86_64",
353            )
354        }
355    }
356
357    fn compile_arm_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
358        Self::compile_fast(
359            mir,
360            prog,
361            ArmGenerator::new(prog.config().clone()),
362            idx_ret,
363            "aarch64",
364        )
365    }
366
367    fn compile_riscv_fast(mir: &Mir, prog: &mut Program, idx_ret: u32) -> Result<MachineCode<f64>> {
368        Self::compile_fast(
369            mir,
370            prog,
371            RiscV::new(prog.config().clone()),
372            idx_ret,
373            "riscv64",
374        )
375    }
376
377    /**********************************************************/
378
379    #[inline]
380    pub fn exec(&mut self) {
381        if let Some(compiled) = &mut self.compiled {
382            compiled.exec(&self.params[..])
383        } else {
384            self.bytecode.exec(&self.params[..]);
385        }
386    }
387
388    pub fn exec_callable(&mut self, xx: &[f64]) -> f64 {
389        if let Some(compiled) = &mut self.compiled {
390            let mem = compiled.mem_mut();
391            mem[self.first_state..self.first_state + self.count_states].copy_from_slice(xx);
392            compiled.exec(&self.params[..]);
393            compiled.mem()[self.first_obs]
394        } else {
395            let mem = self.bytecode.mem_mut();
396            mem[self.first_state..self.first_state + self.count_states].copy_from_slice(xx);
397            self.bytecode.exec(&self.params[..]);
398            self.bytecode.mem()[self.first_obs]
399        }
400    }
401
402    pub fn prepare_simd(&mut self) {
403        // SIMD compilation is lazy!
404        if self.compiled_simd.is_none() && self.use_simd {
405            if self.config.has_avx() {
406                self.compiled_simd =
407                    Self::compile_avx_simd(&self.bytecode.mir, &mut self.prog).ok();
408            } else if self.config.is_arm64() {
409                self.compiled_simd =
410                    Self::compile_arm_simd(&self.bytecode.mir, &mut self.prog).ok();
411            }
412        };
413    }
414
415    fn prepare_fast(&mut self) {
416        // fast func compilation is lazy!
417        if self.compiled_simd.is_none() && self.can_fast {
418            if self.config.is_amd64() {
419                self.compiled_fast = Self::compile_amd_fast(
420                    &self.bytecode.mir,
421                    &mut self.prog,
422                    self.first_obs as u32,
423                )
424                .ok();
425            } else if self.config.is_arm64() {
426                self.compiled_fast = Self::compile_arm_fast(
427                    &self.bytecode.mir,
428                    &mut self.prog,
429                    self.first_obs as u32,
430                )
431                .ok();
432            } else if self.config.is_riscv64() {
433                self.compiled_fast = Self::compile_riscv_fast(
434                    &self.bytecode.mir,
435                    &mut self.prog,
436                    self.first_obs as u32,
437                )
438                .ok();
439            }
440        };
441    }
442
443    pub fn get_fast(&mut self) -> Option<CompiledFunc<f64>> {
444        self.prepare_fast();
445        self.compiled_fast.as_ref().map(|c| c.func())
446    }
447
448    pub fn exec_vectorized(&mut self, states: &mut Matrix, obs: &mut Matrix) {
449        if let Some(compiled) = &self.compiled {
450            if !compiled.support_indirect() {
451                self.exec_vectorized_simple(states, obs);
452                return;
453            }
454
455            self.prepare_simd();
456
457            if let Some(simd) = &self.compiled_simd {
458                self.exec_vectorized_simd(states, obs, self.use_threads, simd.count_lanes());
459            } else {
460                self.exec_vectorized_scalar(states, obs, self.use_threads);
461            }
462        }
463    }
464
465    pub fn exec_vectorized_simple(&mut self, states: &Matrix, obs: &mut Matrix) {
466        assert!(states.ncols == obs.ncols);
467        let n = states.ncols;
468        let params = &self.params[..];
469
470        if let Some(compiled) = &mut self.compiled {
471            for t in 0..n {
472                {
473                    let mem = compiled.mem_mut();
474                    for i in 0..self.count_states {
475                        mem[self.first_state + i] = states.get(i, t);
476                    }
477                }
478
479                compiled.exec(params);
480
481                {
482                    let mem = compiled.mem_mut();
483                    for i in 0..self.count_obs {
484                        obs.set(i, t, mem[self.first_obs + i]);
485                    }
486                }
487            }
488        } else {
489            for t in 0..n {
490                {
491                    let mem = self.bytecode.mem_mut();
492                    for i in 0..self.count_states {
493                        mem[self.first_state + i] = states.get(i, t);
494                    }
495                }
496
497                self.bytecode.exec(params);
498
499                {
500                    let mem = self.bytecode.mem_mut();
501                    for i in 0..self.count_obs {
502                        obs.set(i, t, mem[self.first_obs + i]);
503                    }
504                }
505            }
506        }
507    }
508
509    fn exec_single(t: usize, v: &Matrix, params: &[f64], f: CompiledFunc<f64>) {
510        let p = v.p.as_ptr();
511        f(std::ptr::null(), p, t, params.as_ptr());
512    }
513
514    pub fn exec_vectorized_scalar(&mut self, states: &mut Matrix, obs: &mut Matrix, threads: bool) {
515        if let Some(compiled) = &mut self.compiled {
516            assert!(states.ncols == obs.ncols);
517            let n = states.ncols;
518            let f = compiled.func();
519            let params = &self.params[..];
520            let v = combine_matrixes(states, obs);
521
522            if threads {
523                (0..n)
524                    .into_par_iter()
525                    .for_each(|t| Self::exec_single(t, &v, params, f));
526            } else {
527                (0..n)
528                    //.into_iter()
529                    .for_each(|t| Self::exec_single(t, &v, params, f));
530            }
531        }
532    }
533
534    pub fn exec_vectorized_simd(
535        &mut self,
536        states: &mut Matrix,
537        obs: &mut Matrix,
538        threads: bool,
539        l: usize,
540    ) {
541        if let Some(compiled) = &mut self.compiled {
542            assert!(states.ncols == obs.ncols);
543            let n = states.ncols;
544            let params = &self.params[..];
545            let n0 = l * (n / l);
546            let v = combine_matrixes(states, obs);
547
548            if let Some(g) = &mut self.compiled_simd {
549                let f = g.func();
550                if threads {
551                    (0..n / l)
552                        .into_par_iter()
553                        .for_each(|t| Self::exec_single(t, &v, params, f));
554                } else {
555                    (0..n / l).for_each(|t| Self::exec_single(t, &v, params, f));
556                }
557            }
558
559            let f = compiled.func();
560
561            if threads {
562                (n0..n)
563                    .into_par_iter()
564                    .for_each(|t| Self::exec_single(t, &v, params, f));
565            } else {
566                (n0..n).for_each(|t| Self::exec_single(t, &v, params, f));
567            }
568        }
569    }
570
571    pub fn dump(&mut self, name: &str, what: &str) -> bool {
572        match what {
573            "scalar" => {
574                if let Some(f) = &self.compiled {
575                    f.dump(name);
576                    true
577                } else {
578                    false
579                }
580            }
581            "simd" => {
582                self.prepare_simd();
583
584                if let Some(f) = &self.compiled_simd {
585                    f.dump(name);
586                    true
587                } else {
588                    false
589                }
590            }
591            "fast" => {
592                self.prepare_fast();
593
594                if let Some(f) = &self.compiled_fast {
595                    f.dump(name);
596                    true
597                } else {
598                    false
599                }
600            }
601            "bytecode" => {
602                self.bytecode.dump(name);
603                true
604            }
605            "stats" => {
606                let size = if let Some(f) = &self.compiled {
607                    f.as_machine().unwrap().size
608                } else {
609                    0
610                };
611                self.bytecode.mir.print_stats(name, size);
612                true
613            }
614            _ => false,
615        }
616    }
617
618    pub fn dumps(&self) -> Vec<u8> {
619        if let Some(f) = &self.compiled {
620            f.dumps()
621        } else {
622            Vec::new()
623        }
624    }
625
626    /************************** save/load ******************************/
627
    /// Magic number written as the first word of a serialized `Application`
    /// and verified when loading one.
    const MAGIC: usize = 0x40568795410d08e9;
629}
630
631fn save_reals(stream: &mut impl Write, reals: &HashSet<Loc>) -> Result<()> {
632    let num_elems = reals.len();
633    stream.write_all(&num_elems.to_le_bytes())?;
634
635    for r in reals.iter() {
636        let b = match r {
637            Loc::Mem(idx) => 0x100000000 | (*idx as usize),
638            Loc::Stack(idx) => 0x200000000 | (*idx as usize),
639            Loc::Param(idx) => 0x300000000 | (*idx as usize),
640        };
641        stream.write_all(&b.to_le_bytes())?;
642    }
643
644    Ok(())
645}
646
647fn load_reals(stream: &mut impl Read) -> Result<HashSet<Loc>> {
648    let mut bytes: [u8; 8] = [0; 8];
649
650    stream.read_exact(&mut bytes)?;
651    let num_elems = usize::from_le_bytes(bytes);
652
653    let mut reals: HashSet<Loc> = HashSet::new();
654
655    for _ in 0..num_elems {
656        stream.read_exact(&mut bytes)?;
657        let b = usize::from_le_bytes(bytes);
658
659        let r = match b >> 32 {
660            1 => Loc::Mem((b & 0xffffffff) as u32),
661            2 => Loc::Stack((b & 0xffffffff) as u32),
662            3 => Loc::Param((b & 0xffffffff) as u32),
663            _ => return Err(anyhow!("invalid loc")),
664        };
665        reals.insert(r);
666    }
667
668    Ok(reals)
669}
670
671impl Storage for Application {
672    fn save(&self, stream: &mut impl Write) -> Result<()> {
673        stream.write_all(&Self::MAGIC.to_le_bytes())?;
674
675        let version: usize = 3;
676        stream.write_all(&version.to_le_bytes())?;
677
678        self.prog.save(stream)?;
679
680        let mut mask: usize = 0;
681
682        if self.compiled.is_some() && self.compiled.as_ref().unwrap().as_machine().is_some() {
683            mask |= 1;
684        };
685
686        if self.compiled_fast.is_some()
687            && self.compiled_fast.as_ref().unwrap().as_machine().is_some()
688        {
689            mask |= 2;
690        }
691
692        if self.compiled_simd.is_some()
693            && self.compiled_simd.as_ref().unwrap().as_machine().is_some()
694        {
695            mask |= 4;
696        }
697
698        stream.write_all(&mask.to_le_bytes())?;
699
700        match &self.original {
701            Some(mir) => mir.save(stream)?,
702            None => self.bytecode.mir.save(stream)?,
703        }
704
705        save_reals(stream, &self.reals)?;
706
707        Ok(())
708    }
709
710    fn load(stream: &mut impl Read, config: &Config) -> Result<Self> {
711        let mut bytes: [u8; 8] = [0; 8];
712
713        stream.read_exact(&mut bytes)?;
714
715        if usize::from_le_bytes(bytes) != Self::MAGIC {
716            return Err(anyhow!("invalid magic number (Application)"));
717        }
718
719        stream.read_exact(&mut bytes)?;
720
721        if usize::from_le_bytes(bytes) != 3 {
722            return Err(anyhow!("invalid sjb version"));
723        }
724
725        let prog = Program::load(stream, config)?;
726
727        stream.read_exact(&mut bytes)?;
728        let mask = usize::from_le_bytes(bytes);
729
730        let mir = Mir::load(stream, prog.config())?;
731
732        let reals = load_reals(stream)?;
733
734        let mut app = Application::with_mir(prog, reals, mir)?;
735
736        if mask & 2 != 0 {
737            app.prepare_fast();
738        }
739
740        if mask & 4 != 0 {
741            app.prepare_simd();
742        }
743
744        Ok(app)
745    }
746}