mimium_lang/
mir.rs

1// Mid-level intermediate representation that is more like imperative form than hir.
2use crate::{
3    compiler::IoChannelInfo,
4    interner::{Symbol, TypeNodeId},
5    types::{PType, RecordTypeField, Type, TypeSize},
6};
7use std::{cell::OnceCell, path::PathBuf, sync::Arc};
8// Import StateTreeSkeleton for function state information
9use state_tree::tree::{SizedType, StateTreeSkeleton};
10
11pub mod print;
12
13// #[derive(Debug, Clone, PartialEq)]
14// pub struct Global(VPtr);
15
/// A formal argument of a MIR [`Function`]: a symbol paired with the argument's type.
///
/// The second field is what [`Function::get_argtypes`] collects.
/// NOTE(review): the symbol is presumably the argument's name — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Argument(pub Symbol, pub TypeNodeId);
18
/// Index type carried by [`Value::Register`].
pub type VReg = u64;
/// A value that MIR instructions take as an operand or are paired with inside a [`Block`].
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum Value {
    /// Wraps another value as a global.
    /// NOTE(review): presumably the operand of `GetGlobal`/`SetGlobal` — confirm.
    Global(VPtr),
    // Argument(usize, Arc<Argument>), //index,
    /// A function argument, identified by an index.
    Argument(usize),
    /// holds SSA index(position in infinite registers)
    Register(VReg),
    /// Wraps another value.
    /// NOTE(review): presumably refers to internal state (the "internal state"
    /// comment found on `None` may have been meant for this variant) — confirm.
    State(VPtr),
    /// idx of the function in the program
    Function(usize),
    /// native function (Rust function item or closure)
    ExtFunction(Symbol, TypeNodeId),
    /// NOTE(review): the original comment on this variant read "internal state",
    /// which looks misplaced; presumably this denotes the absence of a value — confirm.
    None,
}
35
/// Shared, reference-counted handle to a [`Value`].
pub type VPtr = Arc<Value>;
37
/// A single MIR instruction.
///
/// Instructions are stored in a [`Block`] as `(VPtr, Instruction)` pairs.
#[derive(Debug, Clone, PartialEq)]
pub enum Instruction {
    /// Unsigned integer payload.
    Uinteger(u64),
    /// Signed integer payload.
    Integer(i64),
    /// Constant float.
    Float(f64),
    /// String, held as a [`Symbol`].
    String(Symbol),
    /// Allocate memory from the stack depending on the size.
    Alloc(TypeNodeId),
    /// Load a value to a register from the pointer type.
    Load(VPtr, TypeNodeId),
    /// Store a value to the stack: (destination, source, type).
    Store(VPtr, VPtr, TypeNodeId),
    /// Instruction for computing destination address like LLVM's GetElementPtr.
    /// This instruction does no actual computation on runtime.
    GetElement {
        value: VPtr,
        ty: TypeNodeId, // type of the composite value like tuple or struct
        tuple_offset: u64,
    },
    /// Call a function: (callee, arguments with their types, type of the return value).
    ///
    /// When the callee is a [`Value::Function`] whose function is stateful,
    /// `Block::get_total_stateskeleton` includes that function's state skeleton.
    Call(VPtr, Vec<(VPtr, TypeNodeId)>, TypeNodeId),
    /// Same operand layout as [`Instruction::Call`].
    /// NOTE(review): presumably the closure-call counterpart of `Call` — confirm.
    CallCls(VPtr, Vec<(VPtr, TypeNodeId)>, TypeNodeId),
    /// NOTE(review): presumably reads the given global with the given type — confirm.
    GetGlobal(VPtr, TypeNodeId),
    /// NOTE(review): presumably writes to a global; operand order is not documented here — confirm.
    SetGlobal(VPtr, VPtr, TypeNodeId),
    /// Make a closure with upindexes.
    Closure(VPtr),
    /// Closes upvalues of a specific closure. Always inserted right before the Return instruction.
    CloseUpValues(VPtr, TypeNodeId),
    //label to funcproto  and localvar offset?
    /// Read an upvalue by index.
    GetUpValue(u64, TypeNodeId),
    /// Write an upvalue by index.
    SetUpValue(u64, VPtr, TypeNodeId),
    //internal state: feed and delay
    /// NOTE(review): presumably advances the internal-state offset by the given amount — confirm.
    PushStateOffset(u64),
    /// NOTE(review): presumably reverts a matching `PushStateOffset` — confirm.
    PopStateOffset(u64),
    /// Load internal state to a register (destination).
    GetState(TypeNodeId),

    /// Conditional jump: condition, basic block index for the then statement,
    /// the else statement, and the merge block.
    JmpIf(VPtr, u64, u64, u64),
    /// Basic block index (for return statement).
    /// NOTE(review): this is `i16` while `JmpIf` uses `u64` block indices — confirm whether
    /// this is a relative offset or an inconsistency.
    Jmp(i16),
    /// Merge.
    Phi(VPtr, VPtr),

    /// Return a value of the given type.
    Return(VPtr, TypeNodeId),
    /// Value to update state.
    ///
    /// Contributes a `Feed` entry, sized from the type, in `Block::get_total_stateskeleton`.
    ReturnFeed(VPtr, TypeNodeId),

    /// Delay. The first field is used as the `len` of the `Delay` state entry in
    /// `Block::get_total_stateskeleton`.
    Delay(u64, VPtr, VPtr),
    /// Contributes a one-word `Mem` state entry in `Block::get_total_stateskeleton`.
    Mem(VPtr),

    // Primitive Operations
    AddF(VPtr, VPtr),
    SubF(VPtr, VPtr),
    MulF(VPtr, VPtr),
    DivF(VPtr, VPtr),
    ModF(VPtr, VPtr),
    NegF(VPtr),
    AbsF(VPtr),
    SinF(VPtr),
    CosF(VPtr),
    PowF(VPtr, VPtr),
    LogF(VPtr),
    SqrtF(VPtr),

    // Primitive Operations for int
    AddI(VPtr, VPtr),
    SubI(VPtr, VPtr),
    MulI(VPtr, VPtr),
    DivI(VPtr, VPtr),
    ModI(VPtr, VPtr),
    NegI(VPtr),
    AbsI(VPtr),

    // NOTE(review): the arities of `PowI` (one operand) and `LogI` (two operands) are the
    // reverse of `PowF` (two) and `LogF` (one); they look swapped — confirm against users
    // of these variants before changing.
    PowI(VPtr),
    LogI(VPtr, VPtr),
    // primitive Operations for bool
    Not(VPtr),
    Eq(VPtr, VPtr),
    Ne(VPtr, VPtr),
    Gt(VPtr, VPtr),
    Ge(VPtr, VPtr),
    Lt(VPtr, VPtr),
    Le(VPtr, VPtr),
    And(VPtr, VPtr),
    Or(VPtr, VPtr),

    CastFtoI(VPtr),
    CastItoF(VPtr),
    CastItoB(VPtr),

    /// Array Literal. Array in mimium is basically variable-length and immutable.
    /// So the VM may allocate the array in a special heap area, or treat it just like static data in the program.
    /// MIR does not distinguish how the VM treats the array.
    Array(Vec<VPtr>, TypeNodeId), // array literal, values and type of the array
    GetArrayElem(VPtr, VPtr, TypeNodeId), // get array element, at specific index , type of the element
    SetArrayElem(VPtr, VPtr, VPtr, TypeNodeId), // set array element at index
    /// NOTE(review): presumably a placeholder emitted when lowering fails — confirm.
    Error,
}
138
/// A basic block: an ordered list of instructions, each paired with a [`VPtr`].
///
/// NOTE(review): the paired value is presumably the destination of the instruction — confirm.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct Block(pub Vec<(VPtr, Instruction)>);
141
142impl Block {
143    pub fn get_total_stateskeleton(
144        &self,
145        functions: &[Function],
146    ) -> Vec<StateTreeSkeleton<StateType>> {
147        let mut children = vec![];
148        for (_v, instr) in &self.0 {
149            match instr {
150                Instruction::Delay(len, _, _) => {
151                    children.push(Box::new(StateTreeSkeleton::Delay { len: *len }))
152                }
153                Instruction::Mem(_) => {
154                    children.push(Box::new(StateTreeSkeleton::Mem(StateType(1))))
155                }
156                Instruction::ReturnFeed(_, ty) => {
157                    children.push(Box::new(StateTreeSkeleton::Feed(StateType::from(*ty))))
158                }
159                Instruction::Call(v, _, _) => {
160                    if let Value::Function(idx) = **v {
161                        let func = &functions[idx];
162                        if func.is_stateful() {
163                            children.push(Box::new(func.state_skeleton.clone()))
164                        }
165                    }
166                }
167                _ => {}
168            }
169        }
170        children.into_iter().map(|e| *e).collect()
171    }
172}
173
/// Where a captured value lives in the upper (enclosing) function.
#[derive(Debug, Clone, PartialEq)]
pub enum UpIndex {
    Local(usize),   // index of local variables in upper functions
    Upvalue(usize), // index of upvalues in upper functions
}
179
/// Description of an open upvalue.
///
/// NOTE(review): field meanings below are inferred from names and types only — confirm.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct OpenUpValue {
    /// Position of the value (presumably a stack/register position — confirm).
    pub pos: usize,
    /// Size of the value, as a [`TypeSize`].
    pub size: TypeSize,
    /// Whether the value is a closure.
    pub is_closure: bool,
}
186
/// Size descriptor used as the payload type of [`StateTreeSkeleton`] in MIR.
///
/// The wrapped `u64` is the value reported as the word size by this type's
/// `SizedType` implementation.
#[derive(Debug, Clone, PartialEq)]
pub struct StateType(pub u64);
189impl state_tree::tree::SizedType for StateType {
190    fn word_size(&self) -> u64 {
191        self.0
192    }
193}
194impl From<TypeNodeId> for StateType {
195    fn from(t: TypeNodeId) -> Self {
196        match t.to_type() {
197            Type::Primitive(PType::Unit) => StateType(0),
198            Type::Primitive(PType::Numeric) | Type::Function { .. } => StateType(1),
199            Type::Record(fields) => StateType(
200                fields
201                    .iter()
202                    .map(|RecordTypeField { ty, .. }| ty.word_size())
203                    .sum(),
204            ),
205            Type::Tuple(elems) => StateType(elems.iter().map(|ty| ty.word_size()).sum()),
206            Type::Array(_elem_ty) => StateType(1),
207            _ => todo!(),
208        }
209    }
210}
211
/// A function in MIR form: its signature, captured values, basic blocks and state layout.
#[derive(Debug, Clone, PartialEq)]
pub struct Function {
    /// Index of this function.
    /// NOTE(review): presumably its position in `Mir::functions` — confirm.
    pub index: usize,
    /// Label of the function; `Mir` looks functions up by this (e.g. `"dsp"`).
    pub label: Symbol,
    /// Formal arguments.
    pub args: Vec<Argument>,
    // pub argtypes: Vec<TypeNodeId>,
    /// Return type; the cell starts out empty in [`Function::new`].
    pub return_type: OnceCell<TypeNodeId>, // TODO: None is the state when the type is not inferred yet.
    /// Values registered through `get_or_insert_upvalue`; a value's position in this
    /// list is the index that method returns.
    pub upindexes: Vec<Arc<Value>>,
    /// NOTE(review): presumably the index of the enclosing (upper) function, `None` for
    /// top-level functions — confirm.
    pub upperfn_i: Option<usize>,
    /// Basic blocks of the function; [`Function::new`] starts with one empty block.
    pub body: Vec<Block>,
    /// StateTree skeleton information for this function's state layout
    pub state_skeleton: StateTreeSkeleton<StateType>,
}
225
226impl Function {
227    pub fn new(
228        index: usize,
229        name: Symbol,
230        args: &[Argument],
231        // argtypes: &[TypeNodeId],
232        state_skeleton: Vec<StateTreeSkeleton<StateType>>,
233        upperfn_i: Option<usize>,
234    ) -> Self {
235        let state_boxed = state_skeleton.into_iter().map(Box::new).collect();
236        Self {
237            index,
238            label: name,
239            args: args.to_vec(),
240            // argtypes: argtypes.to_vec(),
241            return_type: OnceCell::new(),
242            upindexes: vec![],
243            upperfn_i,
244            body: vec![Block::default()],
245            state_skeleton: StateTreeSkeleton::FnCall(state_boxed),
246        }
247    }
248    pub fn add_new_basicblock(&mut self) -> usize {
249        self.body.push(Block(vec![]));
250        self.body.len() - 1
251    }
252    pub fn get_argtypes(&self) -> Vec<TypeNodeId> {
253        self.args.iter().map(|a| a.1).collect()
254    }
255    pub fn get_or_insert_upvalue(&mut self, v: &Arc<Value>) -> usize {
256        self.upindexes
257            .iter()
258            .position(|vt| v == vt)
259            .unwrap_or_else(|| {
260                self.upindexes.push(v.clone());
261                self.upindexes.len() - 1
262            })
263    }
264    pub fn push_state_skeleton(&mut self, skeleton: StateTreeSkeleton<StateType>) {
265        if let StateTreeSkeleton::FnCall(children) = &mut self.state_skeleton {
266            children.push(Box::new(skeleton))
267        } else {
268            panic!("State skeleton for function must be FnCall type");
269        }
270    }
271    pub fn is_stateful(&self) -> bool {
272        if let StateTreeSkeleton::FnCall(children) = &self.state_skeleton {
273            !children.is_empty()
274        } else {
275            panic!("State skeleton for function must be FnCall type");
276        }
277    }
278}
279
/// A MIR program: a list of functions plus an optional path of the originating file.
#[derive(Debug, Clone, Default)]
pub struct Mir {
    /// All functions of the program.
    pub functions: Vec<Function>,
    /// Path passed to [`Mir::new`], if any.
    pub file_path: Option<PathBuf>,
}
285
286impl Mir {
287    pub fn new(file_path: Option<PathBuf>) -> Self {
288        Self {
289            file_path,
290            ..Default::default()
291        }
292    }
293    pub fn get_dsp_iochannels(&self) -> Option<IoChannelInfo> {
294        self.functions
295            .iter()
296            .find(|f| f.label.as_str() == "dsp")
297            .and_then(|f| {
298                // log::info!("input_type:{:?} output_type:{:?}", f.get_argtypes().as_slice(), f.return_type.get().as_ref());
299                let input = match f.get_argtypes().as_slice() {
300                    [] => Some(0),
301                    [t] => t.to_type().get_iochannel_count(),
302                    _ => None,
303                };
304                let output = f
305                    .return_type
306                    .get()
307                    .and_then(|t| t.to_type().get_iochannel_count());
308                input.and_then(|input| output.map(|output| IoChannelInfo { input, output }))
309            })
310    }
311
312    /// Get the StateTreeSkeleton for a specific function by name
313    pub fn get_function_state_skeleton(
314        &self,
315        function_name: &str,
316    ) -> Option<&StateTreeSkeleton<StateType>> {
317        self.functions.iter().find_map(|f| {
318            if f.label.as_str() == function_name {
319                Some(&f.state_skeleton)
320            } else {
321                None
322            }
323        })
324    }
325
326    /// Get the StateTreeSkeleton for the dsp function (commonly used for audio processing)
327    pub fn get_dsp_state_skeleton(&self) -> Option<&StateTreeSkeleton<StateType>> {
328        self.get_function_state_skeleton("dsp")
329    }
330}