// mimium_lang/compiler.rs

1pub mod bytecodegen;
2pub(crate) mod intrinsics;
3pub mod mirgen;
4pub mod parser;
5pub mod translate_staging;
6pub mod typing;
7
8#[cfg(not(target_arch = "wasm32"))]
9pub mod wasmgen;
10
11use serde::{Deserialize, Serialize};
12
13use crate::plugin::{ExtFunTypeInfo, MacroFunction};
14use thiserror::Error;
15
16/// Stage information for multi-stage programming.
17/// Moved from plugin.rs to be shared across compiler modules.
18#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
19pub enum EvalStage {
20    /// Persistent stage - accessible from all stages (like builtins)
21    Persistent,
22    /// Specific stage number
23    Stage(u8),
24}
25
26impl std::fmt::Display for EvalStage {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            EvalStage::Persistent => write!(f, "persistent"),
30            EvalStage::Stage(n) => write!(f, "{n}"),
31        }
32    }
33}
34
35impl EvalStage {
36    pub fn is_available_in_macro(&self) -> bool {
37        matches!(self, EvalStage::Persistent | EvalStage::Stage(0))
38    }
39    pub fn is_available_in_vm(&self) -> bool {
40        matches!(self, EvalStage::Persistent | EvalStage::Stage(1))
41    }
42
43    /// Format the stage for error messages
44    pub fn format_for_error(&self) -> String {
45        match self {
46            EvalStage::Persistent => "persistent".to_string(),
47            EvalStage::Stage(n) => n.to_string(),
48        }
49    }
50
51    /// Increment the current stage for bracket expressions
52    pub fn increment(self) -> EvalStage {
53        match self {
54            EvalStage::Persistent => EvalStage::Persistent, // Persistent stays persistent
55            EvalStage::Stage(n) => EvalStage::Stage(n + 1),
56        }
57    }
58
59    /// Decrement the current stage for escape expressions
60    pub fn decrement(self) -> EvalStage {
61        match self {
62            EvalStage::Persistent => EvalStage::Persistent, // Persistent stays persistent
63            EvalStage::Stage(n) => EvalStage::Stage(n.saturating_sub(1)),
64        }
65    }
66}
67
68#[derive(Debug, Clone, Error)]
69pub enum ErrorKind {
70    #[error("Type Mismatch, expected {0}, but the actual was {1}.")]
71    TypeMismatch(Type, Type),
72    #[error("Circular loop of type definition")]
73    CircularType,
74    #[error("Tuple index out of range, number of elements are {0} but accessed with {1}.")]
75    IndexOutOfRange(u16, u16),
76    #[error("Index access for non tuple-type {0}.")]
77    IndexForNonTuple(Type),
78    #[error("Variable Not Found.")]
79    VariableNotFound(String),
80    #[error("Feed can take only non-funtion type.")]
81    NonPrimitiveInFeed,
82    #[error("Application to non-function type value.")]
83    NotApplicable,
84    #[error("Array index out of bounds.")]
85    IndexOutOfBounds,
86    #[error("Type error in expression.")]
87    TypeError,
88    #[error("Unknown error.")]
89    Unknown,
90}
91
92#[derive(Debug, Clone, Error)]
93#[error("{0}")]
94pub struct CompileError(pub ErrorKind, pub Span);
95
96impl ReportableError for CompileError {
97    fn get_labels(&self) -> Vec<(crate::utils::metadata::Location, String)> {
98        todo!()
99    }
100}
101
102use std::path::PathBuf;
103
104use mirgen::recursecheck;
105
106use crate::{
107    interner::{ExprNodeId, Symbol, TypeNodeId},
108    mir::Mir,
109    runtime::vm,
110    types::Type,
111    utils::{error::ReportableError, metadata::Span},
112};
113pub fn emit_ast(
114    src: &str,
115    path: Option<PathBuf>,
116) -> Result<ExprNodeId, Vec<Box<dyn ReportableError>>> {
117    let (ast, _module_info, errs) = parser::parse_to_expr(src, path.clone());
118    if errs.is_empty() {
119        let ast = parser::add_global_context(ast, path.clone().unwrap_or_default());
120        let (ast, _errs) =
121            mirgen::convert_pronoun::convert_pronoun(ast, path.clone().unwrap_or_default());
122        Ok(recursecheck::convert_recurse(
123            ast,
124            path.clone().unwrap_or_default(),
125        ))
126    } else {
127        Err(errs)
128    }
129}
130
131#[derive(Clone, Copy, Debug, Default)]
132pub struct Config {
133    pub self_eval_mode: bytecodegen::SelfEvalMode,
134}
135
136pub struct Context {
137    ext_fns: Vec<ExtFunTypeInfo>,
138    macros: Vec<Box<dyn MacroFunction>>,
139    file_path: Option<PathBuf>,
140    config: Config,
141}
142// Compiler implements Send because MacroFunction may modify system plugins but it will never conflict with VM execution
143unsafe impl Send for Context {}
144
/// Input/output channel counts of the DSP function.
///
/// `Eq` and `Hash` are derived alongside `PartialEq`. This type is a plain pair
/// of integers, so equality is total and it can be used as a map or set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct IoChannelInfo {
    /// Number of input channels.
    pub input: u32,
    /// Number of output channels.
    pub output: u32,
}
150
151impl Context {
152    pub fn new(
153        ext_fns: impl IntoIterator<Item = ExtFunTypeInfo>,
154        macros: impl IntoIterator<Item = Box<dyn MacroFunction>>,
155        file_path: Option<PathBuf>,
156        config: Config,
157    ) -> Self {
158        Self {
159            ext_fns: ext_fns.into_iter().collect(),
160            macros: macros.into_iter().collect(),
161            file_path,
162            config,
163        }
164    }
165    pub fn get_ext_typeinfos(&self) -> Vec<(Symbol, TypeNodeId)> {
166        self.ext_fns
167            .clone()
168            .into_iter()
169            .map(|ExtFunTypeInfo { name, ty, .. }| (name, ty))
170            .chain(self.macros.iter().map(|m| (m.get_name(), m.get_type())))
171            .collect()
172    }
173    pub fn emit_mir(&self, src: &str) -> Result<Mir, Vec<Box<dyn ReportableError>>> {
174        let path = self.file_path.clone();
175        let (ast, module_info, mut parse_errs) = parser::parse_to_expr(src, path);
176        // let ast = parser::add_global_context(ast, self.file_path.unwrap_or_default());
177        let mir = mirgen::compile_with_module_info(
178            ast,
179            self.get_ext_typeinfos().as_slice(),
180            &self.macros,
181            self.file_path.clone(),
182            module_info,
183        );
184        if parse_errs.is_empty() {
185            mir
186        } else {
187            let _ = mir.map_err(|mut e| {
188                parse_errs.append(&mut e);
189            });
190            Err(parse_errs)
191        }
192    }
193    pub fn emit_bytecode(&self, src: &str) -> Result<vm::Program, Vec<Box<dyn ReportableError>>> {
194        let mir = self.emit_mir(src)?;
195        let config = bytecodegen::Config {
196            self_eval_mode: self.config.self_eval_mode,
197        };
198        Ok(bytecodegen::gen_bytecode(mir, config))
199    }
200
201    /// Compile source code to a WASM module.
202    ///
203    /// Returns the WASM binary bytes together with the DSP function's
204    /// [`StateTreeSkeleton`] (used for state-preserving hot-swap) and the
205    /// I/O channel configuration.
206    #[cfg(not(target_arch = "wasm32"))]
207    pub fn emit_wasm(&self, src: &str) -> Result<WasmOutput, Vec<Box<dyn ReportableError>>> {
208        let mir = self.emit_mir(src)?;
209        let io_channels = mir.get_dsp_iochannels();
210        let dsp_state_skeleton = mir.get_dsp_state_skeleton().cloned();
211        let ext_fns: Vec<crate::plugin::ExtFunTypeInfo> = self.ext_fns.clone();
212        let mut generator = wasmgen::WasmGenerator::new(std::sync::Arc::new(mir), &ext_fns);
213        let bytes = generator.generate().map_err(|e| {
214            vec![Box::new(crate::utils::error::SimpleError {
215                message: e,
216                span: crate::utils::metadata::Location::default(),
217            }) as Box<dyn ReportableError>]
218        })?;
219        Ok(WasmOutput {
220            bytes,
221            dsp_state_skeleton,
222            io_channels,
223            ext_fns,
224        })
225    }
226}
227
228/// Output of WASM compilation via [`Context::emit_wasm`].
229#[cfg(not(target_arch = "wasm32"))]
230pub struct WasmOutput {
231    /// The compiled WASM module binary.
232    pub bytes: Vec<u8>,
233    /// State tree skeleton of the DSP function (for state migration).
234    pub dsp_state_skeleton: Option<state_tree::tree::StateTreeSkeleton<crate::mir::StateType>>,
235    /// I/O channel info extracted from the DSP function signature.
236    pub io_channels: Option<IoChannelInfo>,
237    /// External function type infos required to instantiate the generated module.
238    pub ext_fns: Vec<crate::plugin::ExtFunTypeInfo>,
239}
240
241// pub fn interpret_top(
242//     content: String,
243//     global_ctx: &mut ast_interpreter::Context,
244// ) -> Result<ast_interpreter::Value, Vec<Box<dyn ReportableError>>> {
245//     let ast = emit_ast(&content, None)?;
246//     ast_interpreter::eval_ast(ast, global_ctx).map_err(|e| {
247//         let eb: Box<dyn ReportableError> = Box::new(e);
248//         vec![eb]
249//     })
250// }
251
#[cfg(test)]
mod test {
    use crate::{function, interner::ToSymbol, numeric};

    use super::*;

    /// A mono-input DSP whose output is a 2-tuple, using a stateful `self` counter.
    fn get_source() -> &'static str {
        // Ideally the type annotation `input:float` would be unnecessary,
        // but it is required for now because of a subtyping issue.
        r#"
fn counter(){
    self + 1
}
fn dsp(input:float){
    let res = input + counter()
    (0,res)
}
"#
    }

    /// A DSP that passes a 2-tuple input straight through.
    fn get_tuple_source() -> &'static str {
        r#"
fn dsp(input:(float,float)){
    input
}
"#
    }

    /// Context with a single persistent external function `add: (num, num) -> num`.
    fn test_context() -> Context {
        let addfn = ExtFunTypeInfo::new(
            "add".to_symbol(),
            function!(vec![numeric!(), numeric!()], numeric!()),
            EvalStage::Persistent,
        );
        Context::new([addfn], [], None, Config::default())
    }

    #[test]
    fn mir_channelcount() {
        let src = get_source();
        let ctx = test_context();
        let mir = ctx.emit_mir(src).unwrap();
        log::trace!("Mir: {mir}");
        let iochannels = mir.get_dsp_iochannels().unwrap();
        assert_eq!(iochannels.input, 1);
        assert_eq!(iochannels.output, 2);
    }

    #[test]
    fn mir_tuple_channelcount() {
        let ctx = test_context();
        let mir = ctx.emit_mir(get_tuple_source()).unwrap();
        let iochannels = mir.get_dsp_iochannels().unwrap();
        assert_eq!(iochannels.input, 2);
        assert_eq!(iochannels.output, 2);
    }

    #[test]
    fn bytecode_channelcount() {
        let src = get_source();
        let ctx = test_context();
        let prog = ctx.emit_bytecode(src).unwrap();
        let iochannels = prog.iochannels.unwrap();
        assert_eq!(iochannels.input, 1);
        assert_eq!(iochannels.output, 2);
    }

    #[test]
    fn bytecode_tuple_channelcount() {
        let ctx = test_context();
        let prog = ctx.emit_bytecode(get_tuple_source()).unwrap();
        let iochannels = prog.iochannels.unwrap();
        assert_eq!(iochannels.input, 2);
        assert_eq!(iochannels.output, 2);
    }
}