synfx_dsp_jit/
jit.rs

1// Copyright (c) 2022 Weird Constructor <weirdconstructor@gmail.com>
2// This file is a part of synfx-dsp-jit. Released under GPL-3.0-or-later.
3// See README.md and COPYING for details.
4
5use crate::ast::{ASTBinOp, ASTBufOp, ASTFun, ASTNode, ASTLenOp};
6use crate::context::{
7    DSPFunction, DSPNodeContext, DSPNodeSigBit, DSPNodeType, DSPNodeTypeLibrary,
8    AUX_VAR_IDX_ISRATE, AUX_VAR_IDX_RESET, AUX_VAR_IDX_SRATE,
9};
10use cranelift::prelude::types::{F64, F32, I32, I64};
11use cranelift::prelude::InstBuilder;
12use cranelift::prelude::*;
13use cranelift_codegen::ir::immediates::Offset32;
14use cranelift_codegen::ir::UserFuncName;
15use cranelift_codegen::settings::{self, Configurable};
16use cranelift_jit::{JITBuilder, JITModule};
17use cranelift_module::default_libcall_names;
18use cranelift_module::{FuncId, Linkage, Module};
19use std::cell::RefCell;
20use std::collections::HashMap;
21use std::rc::Rc;
22use std::sync::Arc;
23
/// The Just In Time compiler that translates a [crate::ASTNode] tree into
/// machine code in the form of a [DSPFunction] structure you can use to execute it.
///
/// A [JIT] instance is consumed by [JIT::compile], so a new instance is needed
/// for every function you compile.
///
/// See also [JIT::compile] for an example.
pub struct JIT {
    /// The function builder context, which is reused across multiple
    /// FunctionBuilder instances.
    builder_context: FunctionBuilderContext,

    /// The main Cranelift context, which holds the state for codegen. Cranelift
    /// separates this from `Module` to allow for parallel compilation with a
    /// context per thread; this compiler only uses this single context.
    ctx: codegen::Context,

    /// The module with the JIT backend, which manages the JIT'd functions.
    /// Taken out (left as `None`) by [JIT::compile], which passes it on when
    /// finalizing the [DSPFunction].
    module: Option<JITModule>,

    /// The available DSP node types that can be called by the compiled code.
    dsp_lib: Rc<RefCell<DSPNodeTypeLibrary>>,

    /// The current [DSPNodeContext] we compile a [DSPFunction] for.
    dsp_ctx: Rc<RefCell<DSPNodeContext>>,
}
48
49impl JIT {
50    /// Create a new JIT compiler instance.
51    ///
52    /// Because every newly compile function gets it's own fresh module,
53    /// you need to recreate a [JIT] instance for every time you compile
54    /// a function.
55    ///
56    ///```
57    /// use synfx_dsp_jit::*;
58    /// let lib = get_standard_library();
59    /// let ctx = DSPNodeContext::new_ref();
60    ///
61    /// let jit = JIT::new(lib.clone(), ctx.clone());
62    /// // ...
63    /// ctx.borrow_mut().free();
64    ///```
65    pub fn new(
66        dsp_lib: Rc<RefCell<DSPNodeTypeLibrary>>,
67        dsp_ctx: Rc<RefCell<DSPNodeContext>>,
68    ) -> Self {
69        let mut flag_builder = settings::builder();
70        flag_builder
71            .set("use_colocated_libcalls", "false")
72            .expect("Setting 'use_colocated_libcalls' works");
73        // FIXME set back to true once the x64 backend supports it.
74        flag_builder.set("is_pic", "false").expect("Setting 'is_pic' works");
75        let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
76            panic!("host machine is not supported: {}", msg);
77        });
78        let isa = isa_builder
79            .finish(settings::Flags::new(flag_builder))
80            .expect("ISA Builder finish works");
81        let mut builder = JITBuilder::with_isa(isa, default_libcall_names());
82
83        dsp_lib
84            .borrow()
85            .for_each(|typ| -> Result<(), JITCompileError> {
86                builder.symbol(typ.name(), typ.function_ptr());
87                Ok(())
88            })
89            .expect("symbol adding works");
90
91        let module = JITModule::new(builder);
92        Self {
93            builder_context: FunctionBuilderContext::new(),
94            ctx: module.make_context(),
95            module: Some(module),
96            dsp_lib,
97            dsp_ctx,
98        }
99    }
100
101    /// Compiles a [crate::ASTFun] / [crate::ASTNode] tree into a [DSPFunction].
102    ///
103    /// There are some checks done by the compiler, see the possible errors in [JITCompileError].
104    /// Otherwise the usage is pretty straight forward, here is another example:
105    ///```
106    /// use synfx_dsp_jit::*;
107    /// let lib = get_standard_library();
108    /// let ctx = DSPNodeContext::new_ref();
109    ///
110    /// let jit = JIT::new(lib.clone(), ctx.clone());
111    /// let mut fun = jit.compile(ASTFun::new(Box::new(ASTNode::Lit(0.424242))))
112    ///     .expect("Compiles fine");
113    ///
114    /// // ...
115    /// fun.init(44100.0, None);
116    /// // ...
117    /// let (mut sig1, mut sig2) = (0.0, 0.0);
118    /// let ret = fun.exec(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, &mut sig1, &mut sig2);
119    /// // ...
120    ///
121    /// // Compile a different function now...
122    /// let jit = JIT::new(lib.clone(), ctx.clone());
123    /// let mut new_fun = jit.compile(ASTFun::new(Box::new(ASTNode::Lit(0.33333))))
124    ///     .expect("Compiles fine");
125    ///
126    /// // Make sure to preserve any (possible) state...
127    /// new_fun.init(44100.0, Some(&fun));
128    /// // ...
129    /// let (mut sig1, mut sig2) = (0.0, 0.0);
130    /// let ret = new_fun.exec(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, &mut sig1, &mut sig2);
131    /// // ...
132    ///
133    /// ctx.borrow_mut().free();
134    ///```
135    pub fn compile(mut self, prog: ASTFun) -> Result<Box<DSPFunction>, JITCompileError> {
136        let module = self.module.as_mut().expect("Module still loaded");
137        let ptr_type = module.target_config().pointer_type();
138
139        for param_idx in 0..prog.param_count() {
140            if prog.param_is_ref(param_idx) {
141                self.ctx.func.signature.params.push(AbiParam::new(ptr_type));
142            } else {
143                self.ctx.func.signature.params.push(AbiParam::new(F64));
144            };
145        }
146
147        self.ctx.func.signature.returns.push(AbiParam::new(F64));
148
149        let id = module
150            .declare_function("dsp", Linkage::Export, &self.ctx.func.signature)
151            .map_err(|e| JITCompileError::DeclareTopFunError(e.to_string()))?;
152
153        self.ctx.func.name = UserFuncName::user(0, id.as_u32());
154
155        // Then, translate the AST nodes into Cranelift IR.
156        self.translate(prog)?;
157
158        let mut module = self.module.take().expect("Module still loaded");
159        module.define_function(id, &mut self.ctx).map_err(|e| {
160            match e {
161                cranelift_module::ModuleError::Compilation(e) => {
162                    JITCompileError::DefineTopFunError(cranelift_codegen::print_errors::pretty_error(
163                        &self.ctx.func,
164                        e,
165                    ))
166                },
167                _ => {
168                    JITCompileError::DefineTopFunError(format!("{:?}", e))
169                }
170            }
171        })?;
172
173        module.clear_context(&mut self.ctx);
174        match module.finalize_definitions() {
175            Ok(()) => (),
176            Err(e) => {
177                return Err(JITCompileError::ModuleError(
178                    format!("{}", e)));
179            }
180        }
181
182        let code = module.get_finalized_function(id);
183
184        let dsp_fun = self
185            .dsp_ctx
186            .borrow_mut()
187            .finalize_dsp_function(code, module)
188            .expect("DSPFunction present in DSPNodeContext.");
189
190        Ok(dsp_fun)
191    }
192
193    fn translate(&mut self, fun: ASTFun) -> Result<(), JITCompileError> {
194        let builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
195
196        let module = self.module.as_mut().expect("Module still loaded");
197        let dsp_lib = self.dsp_lib.clone();
198        let dsp_lib = dsp_lib.borrow();
199        let dsp_ctx = self.dsp_ctx.clone();
200        let mut dsp_ctx = dsp_ctx.borrow_mut();
201
202        let debug = dsp_ctx.debug_enabled();
203
204        let debug_str = {
205            let mut trans = DSPFunctionTranslator::new(&mut *dsp_ctx, &*dsp_lib, builder, module);
206            trans.register_functions()?;
207            trans.translate(fun, debug)?
208        };
209
210        if let Some(debug_str) = debug_str {
211            dsp_ctx.cranelift_ir_dump = debug_str;
212        }
213
214        Ok(())
215    }
216
217    //    pub fn translate_ast_node(&mut self, builder: FunctionBuilder<'a>,
218}
219
/// Resolves the name of a built-in mathematical constant to its `f64` value.
///
/// The known names are listed in the table below; any other name (matching is
/// exact and case-sensitive) yields `None`.
fn constant_lookup(name: &str) -> Option<f64> {
    use std::f64::consts;

    const KNOWN_CONSTANTS: [(&str, f64); 14] = [
        ("PI", consts::PI),
        ("TAU", consts::TAU),
        ("E", consts::E),
        ("1PI", consts::FRAC_1_PI),
        ("2PI", consts::FRAC_2_PI),
        ("PI2", consts::FRAC_PI_2),
        ("PI3", consts::FRAC_PI_3),
        ("PI4", consts::FRAC_PI_4),
        ("PI6", consts::FRAC_PI_6),
        ("PI8", consts::FRAC_PI_8),
        ("1SQRT2", consts::FRAC_1_SQRT_2),
        ("2SQRT_PI", consts::FRAC_2_SQRT_PI),
        ("LN2", consts::LN_2),
        ("LN10", consts::LN_10),
    ];

    KNOWN_CONSTANTS
        .iter()
        .find(|(key, _)| *key == name)
        .map(|&(_, value)| value)
}
239
/// Translates an [ASTFun] into Cranelift IR for the top-level DSP function.
///
/// Created during [JIT]'s translation step; it holds the [FunctionBuilder]
/// until the builder is finalized at the end of translation.
pub(crate) struct DSPFunctionTranslator<'a, 'b, 'c> {
    /// The node context the function is compiled for (configuration, buffer
    /// declarations, persistent variables, DSP node instances).
    dsp_ctx: &'c mut DSPNodeContext,
    /// Library of DSP node types that can be called from the AST.
    dsp_lib: &'b DSPNodeTypeLibrary,
    /// The active function builder; `None` once it has been finalized.
    builder: Option<FunctionBuilder<'a>>,
    /// Maps source-level variable names to declared Cranelift variables.
    variables: HashMap<String, Variable>,
    /// Index used for the next Cranelift variable to be declared.
    var_index: usize,
    /// The JIT module in which functions are declared.
    module: &'a mut JITModule,
    /// Imported DSP node functions by name: the node type and its [FuncId].
    dsp_node_functions: HashMap<String, (Arc<dyn DSPNodeType>, FuncId)>,
    /// Width of a target pointer in bytes; used as the stride when indexing
    /// the per-buffer/per-table pointer and length arrays and node states.
    ptr_w: u32,
}
250
/// Error enum for JIT compilation errors.
#[derive(Debug, Clone)]
pub enum JITCompileError {
    /// A function parameter has no name.
    BadDefinedParams,
    /// Presumably a reference to an unknown function; carries its name.
    UnknownFunction(String),
    /// A variable (including internal ones such as `&pv`, `&aux`, `&bufs`)
    /// is not defined; carries the variable name.
    UndefinedVariable(String),
    /// A table index is not below the number of configured tables.
    UnknownTable(usize),
    /// A `%` return value access is malformed (only `%1` to `%5` are valid);
    /// carries the accessed name.
    InvalidReturnValueAccess(String),
    /// Declaring a function in the module failed (the top-level function or
    /// an imported DSP node function); carries the error text.
    DeclareTopFunError(String),
    /// Defining (compiling) the top-level function failed; carries the error text.
    DefineTopFunError(String),
    /// A call refers to a DSP node type that is not in the library.
    UndefinedDSPNode(String),
    /// A buffer index is not below the configured buffer count.
    UnknownBuffer(usize),
    /// A buffer write has no value to write; carries the buffer index.
    NoValueBufferWrite(usize),
    /// A DSP node call has fewer value arguments than its signature needs;
    /// carries the node name and the node uid.
    NotEnoughArgsInCall(String, u64),
    /// Finalizing the module's definitions failed; carries the error text.
    ModuleError(String),
    /// Adding the DSP node instance to the context failed; carries the error
    /// message and the node uid.
    NodeStateError(String, u64),
}
268
/// Shorthand for the translator's active function builder.
///
/// Expands to a mutable reference to the value inside the `builder` field of
/// the given receiver and panics if that builder has already been taken out
/// for finalization.
macro_rules! b {
    ($this:ident) => {
        $this
            .builder
            .as_mut()
            .expect("FunctionBuilder not finalized")
    };
}
274
275impl<'a, 'b, 'c> DSPFunctionTranslator<'a, 'b, 'c> {
276    pub fn new(
277        dsp_ctx: &'c mut DSPNodeContext,
278        dsp_lib: &'b DSPNodeTypeLibrary,
279        builder: FunctionBuilder<'a>,
280        module: &'a mut JITModule,
281    ) -> Self {
282        dsp_ctx.init_dsp_function();
283
284        let builder = Some(builder);
285
286        Self {
287            dsp_ctx,
288            dsp_lib,
289            var_index: 0,
290            variables: HashMap::new(),
291            builder,
292            module,
293            dsp_node_functions: HashMap::new(),
294            ptr_w: 8,
295        }
296    }
297
298    pub fn register_functions(&mut self) -> Result<(), JITCompileError> {
299        let ptr_type = self.module.target_config().pointer_type();
300
301        let mut dsp_node_functions = HashMap::new();
302        self.dsp_lib.for_each(|typ| {
303            let mut sig = self.module.make_signature();
304            let mut i = 0;
305            while let Some(bit) = typ.signature(i) {
306                match bit {
307                    DSPNodeSigBit::Value => {
308                        sig.params.push(AbiParam::new(F64));
309                    }
310                    DSPNodeSigBit::DSPStatePtr
311                    | DSPNodeSigBit::NodeStatePtr
312                    | DSPNodeSigBit::MultReturnPtr => {
313                        sig.params.push(AbiParam::new(ptr_type));
314                    }
315                }
316                i += 1;
317            }
318
319            if typ.has_return_value() {
320                sig.returns.push(AbiParam::new(F64));
321            }
322
323            let func_id = self
324                .module
325                .declare_function(typ.name(), cranelift_module::Linkage::Import, &sig)
326                .map_err(|e| JITCompileError::DeclareTopFunError(e.to_string()))?;
327
328            dsp_node_functions.insert(typ.name().to_string(), (typ.clone(), func_id));
329
330            Ok(())
331        })?;
332
333        self.dsp_node_functions = dsp_node_functions;
334
335        Ok(())
336    }
337
338    /// Declare a single variable declaration.
339    fn declare_variable(&mut self, typ: types::Type, name: &str) -> Variable {
340        let var = Variable::new(self.var_index);
341        //d// println!("DECLARE {} = {}", name, self.var_index);
342
343        if !self.variables.contains_key(name) {
344            self.variables.insert(name.into(), var);
345            b!(self).declare_var(var, typ);
346            self.var_index += 1;
347        }
348
349        var
350    }
351
352    fn translate(&mut self, fun: ASTFun, debug: bool) -> Result<Option<String>, JITCompileError> {
353        let ptr_type = self.module.target_config().pointer_type();
354        self.ptr_w = ptr_type.bytes();
355
356        let entry_block = b!(self).create_block();
357        b!(self).append_block_params_for_function_params(entry_block);
358        b!(self).switch_to_block(entry_block);
359        b!(self).seal_block(entry_block);
360
361        self.variables.clear();
362
363        // declare and define parameters:
364        for param_idx in 0..fun.param_count() {
365            let val = b!(self).block_params(entry_block)[param_idx];
366
367            match fun.param_name(param_idx) {
368                Some(param_name) => {
369                    let var = if fun.param_is_ref(param_idx) {
370                        self.declare_variable(ptr_type, param_name)
371                    } else {
372                        self.declare_variable(F64, param_name)
373                    };
374
375                    b!(self).def_var(var, val);
376                }
377                None => {
378                    return Err(JITCompileError::BadDefinedParams);
379                }
380            }
381        }
382
383        // declare and define local variables:
384        for local_name in fun.local_variables().iter() {
385            let zero = b!(self).ins().f64const(0.0);
386            let var = self.declare_variable(F64, local_name);
387            b!(self).def_var(var, zero);
388        }
389
390        let v = self.compile(fun.ast_ref())?;
391
392        b!(self).ins().return_(&[v]);
393
394        let result = if debug {
395            Some(format!("{}", b!(self).func.display()))
396        } else {
397            None
398        };
399
400        self.builder.take().expect("builder not finalized yet").finalize();
401
402        Ok(result)
403    }
404
405    fn ins_b_to_f64(&mut self, v: Value) -> Value {
406//        let bint = self.b!(self).ins().bint(I32, v);
407        b!(self).ins().fcvt_from_uint(F64, v)
408    }
409
410    fn compile(&mut self, ast: &ASTNode) -> Result<Value, JITCompileError> {
411        match ast {
412            ASTNode::Lit(v) => Ok(b!(self).ins().f64const(*v)),
413            ASTNode::Var(name) => {
414                if let Some(c) = constant_lookup(name) {
415                    Ok(b!(self).ins().f64const(c))
416                } else if name.starts_with('&') {
417                    let variable = self
418                        .variables
419                        .get(name)
420                        .ok_or_else(|| JITCompileError::UndefinedVariable(name.to_string()))?;
421                    let ptr = b!(self).use_var(*variable);
422                    Ok(b!(self).ins().load(F64, MemFlags::new(), ptr, 0))
423                } else if name.starts_with('$') {
424                    let aux_vars = self
425                        .variables
426                        .get("&aux")
427                        .ok_or_else(|| JITCompileError::UndefinedVariable("&aux".to_string()))?;
428
429                    let pvs = b!(self).use_var(*aux_vars);
430                    let offs = match &name[..] {
431                        "$srate" => AUX_VAR_IDX_SRATE,
432                        "$israte" => AUX_VAR_IDX_ISRATE,
433                        "$reset" => AUX_VAR_IDX_RESET,
434                        _ => return Err(JITCompileError::UndefinedVariable(name.to_string())),
435                    };
436                    let aux_value = b!(self).ins().load(
437                        F64,
438                        MemFlags::new(),
439                        pvs,
440                        Offset32::new(offs as i32 * F64.bytes() as i32),
441                    );
442                    Ok(aux_value)
443                } else if name.starts_with('*') {
444                    let pv_index = self
445                        .dsp_ctx
446                        .get_persistent_variable_index(name)
447                        .map_err(|_| JITCompileError::UndefinedVariable(name.to_string()))?;
448
449                    let persistent_vars = self
450                        .variables
451                        .get("&pv")
452                        .ok_or_else(|| JITCompileError::UndefinedVariable("&pv".to_string()))?;
453                    let pvs = b!(self).use_var(*persistent_vars);
454                    let pers_value = b!(self).ins().load(
455                        F64,
456                        MemFlags::new(),
457                        pvs,
458                        Offset32::new(pv_index as i32 * F64.bytes() as i32),
459                    );
460                    Ok(pers_value)
461                } else if name.starts_with('%') {
462                    if name.len() > 2 {
463                        return Err(JITCompileError::InvalidReturnValueAccess(name.to_string()));
464                    }
465
466                    let offs: i32 = match name.chars().nth(1) {
467                        Some('1') => 0,
468                        Some('2') => 1,
469                        Some('3') => 2,
470                        Some('4') => 3,
471                        Some('5') => 4,
472                        _ => {
473                            return Err(JITCompileError::InvalidReturnValueAccess(
474                                name.to_string(),
475                            ));
476                        }
477                    };
478
479                    let return_vals = self
480                        .variables
481                        .get("&rv")
482                        .ok_or_else(|| JITCompileError::UndefinedVariable("&rv".to_string()))?;
483                    let rvs = b!(self).use_var(*return_vals);
484                    let ret_value = b!(self).ins().load(
485                        F64,
486                        MemFlags::new(),
487                        rvs,
488                        Offset32::new(offs * F64.bytes() as i32),
489                    );
490                    Ok(ret_value)
491                } else {
492                    let variable = self
493                        .variables
494                        .get(name)
495                        .ok_or_else(|| JITCompileError::UndefinedVariable(name.to_string()))?;
496                    Ok(b!(self).use_var(*variable))
497                }
498            }
499            ASTNode::Assign(name, ast) => {
500                let value = self.compile(ast)?;
501
502                if name.starts_with('&') {
503                    let variable = self
504                        .variables
505                        .get(name)
506                        .ok_or_else(|| JITCompileError::UndefinedVariable(name.to_string()))?;
507                    let ptr = b!(self).use_var(*variable);
508                    b!(self).ins().store(MemFlags::new(), value, ptr, 0);
509                } else if name.starts_with('*') {
510                    let pv_index = self
511                        .dsp_ctx
512                        .get_persistent_variable_index(name)
513                        .map_err(|_| JITCompileError::UndefinedVariable(name.to_string()))?;
514
515                    let persistent_vars = self
516                        .variables
517                        .get("&pv")
518                        .ok_or_else(|| JITCompileError::UndefinedVariable("&pv".to_string()))?;
519                    let pvs = b!(self).use_var(*persistent_vars);
520                    b!(self).ins().store(
521                        MemFlags::new(),
522                        value,
523                        pvs,
524                        Offset32::new(pv_index as i32 * F64.bytes() as i32),
525                    );
526                } else {
527                    let variable = self
528                        .variables
529                        .get(name)
530                        .ok_or_else(|| JITCompileError::UndefinedVariable(name.to_string()))?;
531                    b!(self).def_var(*variable, value);
532                }
533
534                Ok(value)
535            }
536            ASTNode::BinOp(op, a, b) => {
537                let value_a = self.compile(a)?;
538                let value_b = self.compile(b)?;
539                let value = match op {
540                    ASTBinOp::Add => b!(self).ins().fadd(value_a, value_b),
541                    ASTBinOp::Sub => b!(self).ins().fsub(value_a, value_b),
542                    ASTBinOp::Mul => b!(self).ins().fmul(value_a, value_b),
543                    ASTBinOp::Div => b!(self).ins().fdiv(value_a, value_b),
544                    ASTBinOp::Eq => {
545                        let cmp_res = b!(self).ins().fcmp(FloatCC::Equal, value_a, value_b);
546                        self.ins_b_to_f64(cmp_res)
547                    }
548                    ASTBinOp::Ne => {
549                        let cmp_res = b!(self).ins().fcmp(FloatCC::Equal, value_a, value_b);
550                        let bnot = b!(self).ins().bnot(cmp_res);
551//                        let bint = b!(self).ins().bint(I32, bnot);
552                        b!(self).ins().fcvt_from_uint(F64, bnot)
553                    }
554                    ASTBinOp::Ge => {
555                        let cmp_res =
556                            b!(self).ins().fcmp(FloatCC::GreaterThanOrEqual, value_a, value_b);
557                        self.ins_b_to_f64(cmp_res)
558                    }
559                    ASTBinOp::Le => {
560                        let cmp_res =
561                            b!(self).ins().fcmp(FloatCC::LessThanOrEqual, value_a, value_b);
562                        self.ins_b_to_f64(cmp_res)
563                    }
564                    ASTBinOp::Gt => {
565                        let cmp_res =
566                            b!(self).ins().fcmp(FloatCC::GreaterThan, value_a, value_b);
567                        self.ins_b_to_f64(cmp_res)
568                    }
569                    ASTBinOp::Lt => {
570                        let cmp_res = b!(self).ins().fcmp(FloatCC::LessThan, value_a, value_b);
571                        self.ins_b_to_f64(cmp_res)
572                    }
573                };
574
575                Ok(value)
576            }
577            ASTNode::BufDeclare { buf_idx, len } => {
578                if *buf_idx >= self.dsp_ctx.config.buffer_count {
579                    return Err(JITCompileError::UnknownBuffer(*buf_idx));
580                }
581
582                self.dsp_ctx.buffer_declare[*buf_idx] = *len;
583
584                Ok(b!(self).ins().f64const(0.0))
585            },
586            ASTNode::Len(op) => {
587                let (buf_idx, buf_lens) = match op {
588                    ASTLenOp::Buffer(buf_idx) => {
589                        if *buf_idx >= self.dsp_ctx.config.buffer_count {
590                            return Err(JITCompileError::UnknownBuffer(*buf_idx));
591                        }
592
593                        let buf_lens = self.variables.get("&buf_lens").ok_or_else(|| {
594                            JITCompileError::UndefinedVariable("&buf_lens".to_string())
595                        })?;
596
597                        (*buf_idx, buf_lens)
598                    },
599                    ASTLenOp::Table(tbl_idx) => {
600                        let tbl_lens = self.variables.get("&table_lens").ok_or_else(|| {
601                            JITCompileError::UndefinedVariable("&table_lens".to_string())
602                        })?;
603
604                        if *tbl_idx >= self.dsp_ctx.config.tables.len() {
605                            return Err(JITCompileError::UnknownTable(*tbl_idx));
606                        }
607
608                        (*tbl_idx, tbl_lens)
609                    },
610                };
611
612                let lenptr = b!(self).use_var(*buf_lens);
613                let len = b!(self).ins().load(
614                    I64,
615                    MemFlags::new(),
616                    lenptr,
617                    Offset32::new(buf_idx as i32 * self.ptr_w as i32),
618                );
619
620                Ok(b!(self).ins().fcvt_from_uint(F64, len))
621            },
622            ASTNode::BufOp { op, idx, val } => {
623                let idx = self.compile(idx)?;
624
625                let ptr_type = self.module.target_config().pointer_type();
626
627                let (buf_var, buf_idx, buf_lens) = match op {
628                    ASTBufOp::Write(buf_idx)
629                    | ASTBufOp::Read(buf_idx)
630                    | ASTBufOp::ReadLin(buf_idx) => {
631                        let buf_var = self.variables.get("&bufs").ok_or_else(|| {
632                            JITCompileError::UndefinedVariable("&bufs".to_string())
633                        })?;
634
635                        let buf_lens = self.variables.get("&buf_lens").ok_or_else(|| {
636                            JITCompileError::UndefinedVariable("&buf_lens".to_string())
637                        })?;
638
639                        if *buf_idx >= self.dsp_ctx.config.buffer_count {
640                            return Err(JITCompileError::UnknownBuffer(*buf_idx));
641                        }
642
643                        (buf_var, buf_idx, buf_lens)
644                    }
645                    ASTBufOp::TableRead(tbl_idx) | ASTBufOp::TableReadLin(tbl_idx) => {
646                        let buf_var = self.variables.get("&tables").ok_or_else(|| {
647                            JITCompileError::UndefinedVariable("&tables".to_string())
648                        })?;
649
650                        let tbl_lens = self.variables.get("&table_lens").ok_or_else(|| {
651                            JITCompileError::UndefinedVariable("&table_lens".to_string())
652                        })?;
653
654                        if *tbl_idx >= self.dsp_ctx.config.tables.len() {
655                            return Err(JITCompileError::UnknownTable(*tbl_idx));
656                        }
657
658                        (buf_var, tbl_idx, tbl_lens)
659                    }
660                };
661
662                let bptr = b!(self).use_var(*buf_var);
663                let buffer = b!(self).ins().load(
664                    ptr_type,
665                    MemFlags::new(),
666                    bptr,
667                    Offset32::new(*buf_idx as i32 * self.ptr_w as i32),
668                );
669
670                let lenptr = b!(self).use_var(*buf_lens);
671                let len = b!(self).ins().load(
672                    I64,
673                    MemFlags::new(),
674                    lenptr,
675                    Offset32::new(*buf_idx as i32 * self.ptr_w as i32),
676                );
677
678                let orig_idx = idx;
679                let idx = b!(self).ins().floor(idx);
680                let orig_fint_idx = idx;
681                let idx = b!(self).ins().fcvt_to_uint(I64, idx);
682                let orig_int_idx = idx;
683
684                let data_width =
685                    match op {
686                        ASTBufOp::TableReadLin { .. } | ASTBufOp::TableRead { .. } => F32.bytes() as i64,
687                        _ => F64.bytes() as i64
688                    };
689
690                let idx = b!(self).ins().urem(idx, len);
691                let idx = b!(self).ins().imul_imm(idx, data_width);
692                let ptr = b!(self).ins().iadd(buffer, idx);
693
694                match op {
695                    ASTBufOp::Write { .. } => {
696                        let val = val
697                            .as_ref()
698                            .ok_or_else(|| JITCompileError::NoValueBufferWrite(*buf_idx))?;
699                        let val = self.compile(val)?;
700
701                        b!(self).ins().store(MemFlags::new(), val, ptr, 0);
702                        Ok(b!(self).ins().f64const(0.0))
703                    }
704                    ASTBufOp::Read { .. } => {
705                        Ok(b!(self).ins().load(F64, MemFlags::new(), ptr, 0))
706                    }
707                    ASTBufOp::TableRead { .. } => {
708                        let sample = b!(self).ins().load(F32, MemFlags::new(), ptr, 0);
709                        Ok(b!(self).ins().fpromote(F64, sample))
710                    }
711                    ASTBufOp::ReadLin { .. } | ASTBufOp::TableReadLin { .. } => {
712                        let fract = b!(self).ins().fsub(orig_idx, orig_fint_idx);
713                        let idx = b!(self).ins().iadd_imm(orig_int_idx, 1 as i64);
714                        let idx = b!(self).ins().urem(idx, len);
715                        let idx = b!(self).ins().imul_imm(idx, data_width);
716                        let ptr2 = b!(self).ins().iadd(buffer, idx);
717
718                        let (a, b) =
719                            if data_width == (I32.bytes() as i64) {
720                                let a = b!(self).ins().load(F32, MemFlags::new(), ptr, 0);
721                                let b = b!(self).ins().load(F32, MemFlags::new(), ptr2, 0);
722                                let a = b!(self).ins().fpromote(F64, a);
723                                let b = b!(self).ins().fpromote(F64, b);
724                                (a, b)
725                            } else {
726                                let a = b!(self).ins().load(F64, MemFlags::new(), ptr, 0);
727                                let b = b!(self).ins().load(F64, MemFlags::new(), ptr2, 0);
728                                (a, b)
729                            };
730                        let one = b!(self).ins().f64const(1.0);
731                        let fract_1 = b!(self).ins().fsub(one, fract);
732                        let a = b!(self).ins().fmul(a, fract_1);
733                        let b = b!(self).ins().fmul(b, fract);
734                        Ok(b!(self).ins().fadd(a, b))
735                    }
736                }
737            }
738            ASTNode::Call(name, dsp_node_uid, args) => {
739                let func = self
740                    .dsp_node_functions
741                    .get(name)
742                    .ok_or_else(|| JITCompileError::UndefinedDSPNode(name.to_string()))?
743                    .clone();
744                let node_type = func.0;
745                let func_id = func.1;
746
747                let ptr_type = self.module.target_config().pointer_type();
748
749                let mut dsp_node_fun_params = vec![];
750                let mut i = 0;
751                let mut arg_idx = 0;
752                while let Some(bit) = node_type.signature(i) {
753                    match bit {
754                        DSPNodeSigBit::Value => {
755                            if arg_idx >= args.len() {
756                                return Err(JITCompileError::NotEnoughArgsInCall(
757                                    name.to_string(),
758                                    *dsp_node_uid,
759                                ));
760                            }
761                            dsp_node_fun_params.push(self.compile(&args[arg_idx])?);
762                            arg_idx += 1;
763                        }
764                        DSPNodeSigBit::DSPStatePtr => {
765                            let state_var = self.variables.get("&state").ok_or_else(|| {
766                                JITCompileError::UndefinedVariable("&state".to_string())
767                            })?;
768                            dsp_node_fun_params.push(b!(self).use_var(*state_var));
769                        }
770                        DSPNodeSigBit::NodeStatePtr => {
771                            let node_state_index = match self
772                                .dsp_ctx
773                                .add_dsp_node_instance(node_type.clone(), *dsp_node_uid)
774                            {
775                                Err(e) => {
776                                    return Err(JITCompileError::NodeStateError(e, *dsp_node_uid));
777                                }
778                                Ok(idx) => idx,
779                            };
780
781                            let fstate_var = self.variables.get("&fstate").ok_or_else(|| {
782                                JITCompileError::UndefinedVariable("&fstate".to_string())
783                            })?;
784                            let fptr = b!(self).use_var(*fstate_var);
785                            let func_state = b!(self).ins().load(
786                                ptr_type,
787                                MemFlags::new(),
788                                fptr,
789                                Offset32::new(node_state_index as i32 * self.ptr_w as i32),
790                            );
791                            dsp_node_fun_params.push(func_state);
792                        }
793                        DSPNodeSigBit::MultReturnPtr => {
794                            let ret_var = self.variables.get("&rv").ok_or_else(|| {
795                                JITCompileError::UndefinedVariable("&rv".to_string())
796                            })?;
797                            dsp_node_fun_params.push(b!(self).use_var(*ret_var));
798                        }
799                    }
800
801                    i += 1;
802                }
803
804                let local_callee = self.module.declare_func_in_func(func_id, b!(self).func);
805                let call = b!(self).ins().call(local_callee, &dsp_node_fun_params);
806                if node_type.has_return_value() {
807                    Ok(b!(self).inst_results(call)[0])
808                } else {
809                    Ok(b!(self).ins().f64const(0.0))
810                }
811            }
812            ASTNode::If(cond, then, els) => {
813                let condition_value = if let ASTNode::BinOp(op, a, b) = cond.as_ref() {
814                    let val = match op {
815                        ASTBinOp::Eq => {
816                            let a = self.compile(a)?;
817                            let b = self.compile(b)?;
818                            b!(self).ins().fcmp(FloatCC::Equal, a, b)
819                        }
820                        ASTBinOp::Ne => {
821                            let a = self.compile(a)?;
822                            let b = self.compile(b)?;
823                            let eq = b!(self).ins().fcmp(FloatCC::Equal, a, b);
824                            b!(self).ins().bnot(eq)
825                        }
826                        ASTBinOp::Gt => {
827                            let a = self.compile(a)?;
828                            let b = self.compile(b)?;
829                            b!(self).ins().fcmp(FloatCC::GreaterThan, a, b)
830                        }
831                        ASTBinOp::Lt => {
832                            let a = self.compile(a)?;
833                            let b = self.compile(b)?;
834                            b!(self).ins().fcmp(FloatCC::LessThan, a, b)
835                        }
836                        ASTBinOp::Ge => {
837                            let a = self.compile(a)?;
838                            let b = self.compile(b)?;
839                            b!(self).ins().fcmp(FloatCC::GreaterThanOrEqual, a, b)
840                        }
841                        ASTBinOp::Le => {
842                            let a = self.compile(a)?;
843                            let b = self.compile(b)?;
844                            b!(self).ins().fcmp(FloatCC::LessThanOrEqual, a, b)
845                        }
846                        _ => self.compile(cond)?,
847                    };
848
849                    val
850                } else {
851                    let res = self.compile(cond)?;
852                    let cmpv = b!(self).ins().f64const(0.5);
853                    b!(self).ins().fcmp(FloatCC::GreaterThanOrEqual, res, cmpv)
854                };
855
856                let then_block = b!(self).create_block();
857                let else_block = b!(self).create_block();
858                let merge_block = b!(self).create_block();
859
860                // If-else constructs in the toy language have a return value.
861                // In traditional SSA form, this would produce a PHI between
862                // the then and else bodies. Cranelift uses block parameters,
863                // so set up a parameter in the merge block, and we'll pass
864                // the return values to it from the branches.
865                b!(self).append_block_param(merge_block, F64);
866
867                // Test the if condition and conditionally branch.
868                b!(self).ins().brif(condition_value, then_block, &[], else_block, &[]);
869
870                b!(self).switch_to_block(then_block);
871                b!(self).seal_block(then_block);
872                let then_return = self.compile(then)?;
873
874                // Jump to the merge block, passing it the block return value.
875                b!(self).ins().jump(merge_block, &[then_return]);
876
877                b!(self).switch_to_block(else_block);
878                b!(self).seal_block(else_block);
879                let else_return = if let Some(els) = els {
880                    self.compile(els)?
881                } else {
882                    b!(self).ins().f64const(0.0)
883                };
884
885                // Jump to the merge block, passing it the block return value.
886                b!(self).ins().jump(merge_block, &[else_return]);
887
888                // Switch to the merge block for subsequent statements.
889                b!(self).switch_to_block(merge_block);
890
891                // We've now seen all the predecessors of the merge block.
892                b!(self).seal_block(merge_block);
893
894                // Read the value of the if-else by reading the merge block
895                // parameter.
896                let phi = b!(self).block_params(merge_block)[0];
897
898                Ok(phi)
899            }
900            ASTNode::Stmts(stmts) => {
901                let mut value = None;
902                for ast in stmts {
903                    value = Some(self.compile(ast)?);
904                }
905                if let Some(value) = value {
906                    Ok(value)
907                } else {
908                    Ok(b!(self).ins().f64const(0.0))
909                }
910            }
911        }
912    }
913}
914
915/// Returns a [DSPFunction] that does nothing. This can be helpful for initializing
916/// structures you want to send to the DSP thread.
917pub fn get_nop_function(
918    lib: Rc<RefCell<DSPNodeTypeLibrary>>,
919    dsp_ctx: Rc<RefCell<DSPNodeContext>>,
920) -> Box<DSPFunction> {
921    let jit = JIT::new(lib, dsp_ctx);
922    jit.compile(ASTFun::new(Box::new(ASTNode::Lit(0.0)))).expect("No compile error")
923}