Skip to main content

mist_codegen/
class_decl.rs

1use std::collections::HashMap;
2
3use mist_parser::ast::*;
4
5use crate::{Context, GenRust, GetRust, RustCodegen};
6
7pub struct ClassProcessedData {
8    visibility: Visibility,
9    name: Identifier,
10    generics: GenericsDecl,
11    inherits: Option<ExprPath>,
12    self_path: ExprPath,
13    self_ty: TypeExpr,
14    fields: Vec<Spanned<FieldDeclStmt>>,
15    constructor: Spanned<ClassConstructor>,
16    items: Vec<ClassItem>,
17    methods: Vec<Spanned<FunctionDecl>>,
18    v_table: Vec<Identifier>,
19    override_v_table: HashMap<Override, Spanned<Vec<Identifier>>>,
20}
21
22impl ClassProcessedData {
23    pub fn analyze(
24        visibility: &Visibility,
25        name: &Identifier,
26        generics: &GenericsDecl,
27        inherits: &Option<ExprPath>,
28        fields: &Vec<Spanned<FieldDeclStmt>>,
29        constructor: &Spanned<ClassConstructor>,
30        items: &Vec<ClassItem>,
31    ) -> Self {
32        let self_path = ExprPath(vec![ExprPathSegment {
33            ident: name.clone(),
34            generics: generics.clone().into(),
35        }]);
36
37        let self_ty = get_type_from_path(&self_path);
38
39        let methods = items
40            .iter()
41            .filter_map(|item| match item {
42                ClassItem::ImplDecl(_) => None,
43                ClassItem::Method(method) => Some(method.clone()),
44            })
45            .collect::<Vec<Spanned<FunctionDecl>>>();
46
47        let mut v_table = Vec::new();
48        let mut override_v_table = std::collections::HashMap::new();
49
50        for method in &methods {
51            if matches!(method.item.visibility, Visibility::Public) {
52                match &method.item.is_override {
53                    None => {
54                        v_table.push(method.item.name.clone());
55                    }
56                    Some(override_spec) => {
57                        override_v_table
58                            .entry(override_spec.clone())
59                            .or_insert_with(|| Spanned {
60                                line: method.line,
61                                column: method.column,
62                                item: Vec::new(),
63                            })
64                            .item
65                            .push(method.item.name.clone());
66                    }
67                }
68            }
69        }
70
71        ClassProcessedData {
72            visibility: visibility.clone(),
73            name: name.clone(),
74            generics: generics.clone(),
75            inherits: inherits.clone(),
76            self_path,
77            self_ty,
78            fields: fields.clone(),
79            constructor: constructor.clone(),
80            items: items.clone(),
81            methods,
82            v_table,
83            override_v_table,
84        }
85    }
86
87    // ── Stage 2: Code Emission ───────────────────────────────────────
88
89    pub fn emit(&self, ctx: &mut Context, cg: &mut RustCodegen) {
90        ctx.expr_super = self.inherits.clone();
91
92        self.emit_struct_decl(cg);
93        self.emit_impl_block(ctx, cg);
94        self.emit_impl_decls(ctx, cg);
95        self.emit_deref_impls(cg);
96
97        ctx.expr_super = None;
98    }
99
100    fn emit_struct_decl(&self, cg: &mut RustCodegen) {
101        cg.addln(&format!(
102            "{}struct {}{} {{",
103            self.visibility.get_rust(),
104            self.name.get_rust(),
105            self.generics.get_rust()
106        ));
107        cg.indent += 1;
108
109        cg.add_indentedln(
110            "pub _m_oop: (&'static [*const std::ffi::c_void], *mut std::ffi::c_void),",
111        );
112
113        if let Some(ref inherits) = self.inherits {
114            cg.add_indented("pub _super: Box<");
115            cg.add(&get_type_from_path(inherits).get_rust());
116            cg.addln(">,");
117        }
118
119        for field in &self.fields {
120            cg.add_indentedln(&field.get_comment());
121            cg.add_indentedln(&field.item.decl.get_rust());
122        }
123
124        cg.indent -= 1;
125        cg.addln("}\n");
126    }
127
128    fn emit_impl_block(&self, ctx: &mut Context, cg: &mut RustCodegen) {
129        cg.addln(&format!(
130            "impl{} {} {{",
131            self.generics.get_rust(),
132            self.self_ty.get_rust()
133        ));
134        cg.indent += 1;
135
136        self.emit_v_table(cg);
137
138        self.emit_super_v_table(cg);
139        self.emit_super_v_tests(cg);
140
141        self.emit_constructor(ctx, cg);
142        self.emit_methods(ctx, cg);
143
144        cg.indent -= 1;
145        cg.addln("}\n");
146    }
147
148    fn emit_v_table(&self, cg: &mut RustCodegen) {
149        for (i, method_name) in self.v_table.iter().enumerate() {
150            cg.add_indentedln(&format!(
151                "pub const __FN_{}: usize = {i};",
152                method_name.0.to_uppercase()
153            ));
154        }
155
156        cg.add_indentedln(&format!(
157            "pub const __V_TABLE: [*const std::ffi::c_void; {}] = [",
158            self.v_table.len()
159        ));
160        cg.indent += 1;
161
162        for method_name in &self.v_table {
163            cg.add_indented("Self::__m_");
164            cg.add(&method_name.get_rust());
165            cg.add(" as *const std::ffi::c_void");
166            cg.addln(",");
167        }
168
169        cg.indent -= 1;
170        cg.add_indentedln("];");
171    }
172
173    fn emit_super_v_table(&self, cg: &mut RustCodegen) {
174        if self.override_v_table.is_empty() {
175            return;
176        }
177
178        cg.add_indentedln(&format!(
179            "pub const __SUPER_V_TABLES: [&'static [*const std::ffi::c_void]; {}] = [",
180            self.override_v_table.len()
181        ));
182        cg.indent += 1;
183
184        for (override_tier, overriden_method_idents) in &self.override_v_table {
185            let target_path = match &override_tier.0 {
186                Some(path) => path.clone(),
187                None => {
188                    if let Some(parent_path) = &self.inherits {
189                        parent_path.clone()
190                    } else {
191                        continue; // Safeguard if AST has a dangling override without inheritance
192                    }
193                }
194            };
195
196            let target_rust_path = target_path.get_rust();
197
198            // 2. Emit the block for this specific index table
199            cg.add_indentedln("&{");
200            cg.indent += 1;
201
202            // Initialize this sub-table with the target parent class's base vtable
203            cg.add_indentedln(&format!("let mut table = {}::__V_TABLE;", target_rust_path));
204
205            // Patch the slots for every method registered under this specific override tier
206            for method_ident in &overriden_method_idents.item {
207                cg.add_indentedln(&format!(
208                    "table[{}::__FN_{}] = {}::__m_{} as *const std::ffi::c_void;",
209                    target_rust_path,
210                    method_ident.0.to_uppercase(),
211                    self.self_path.get_rust(),
212                    method_ident.get_rust()
213                ));
214            }
215
216            cg.add_indentedln("table");
217            cg.indent -= 1;
218            cg.add_indentedln("},");
219        }
220
221        cg.indent -= 1;
222        cg.add_indentedln("];");
223    }
224
225    fn emit_super_v_tests(&self, cg: &mut RustCodegen) {
226        cg.add_indentedln(&format!("const fn __test_vt() {{"));
227        cg.indent += 1;
228
229        for (override_tier, _) in self.override_v_table.iter() {
230            let target_path = match &override_tier.0 {
231                Some(path) => path.clone(),
232
233                None => {
234                    if let Some(parent_path) = &self.inherits {
235                        parent_path.clone()
236                    } else {
237                        continue; // Safeguard if AST has a dangling override without inheritance
238                    }
239                }
240            };
241
242            let target_rust_path = target_path.get_rust();
243
244            for method in &self.methods {
245                if method.item.is_override.as_ref() == Some(override_tier) {
246                    let mut params = method
247                        .item
248                        .params
249                        .0
250                        .clone()
251                        .into_iter()
252                        .filter_map(|v| v.type_)
253                        .collect::<Vec<_>>();
254
255                    if params.is_empty() {
256                        continue;
257                    }
258
259                    if let TypeExpr::Ref { mutable, .. } = params.remove(0) {
260                        cg.add_indentedln(&method.get_comment());
261                        cg.add_indented(&format!("{}::__m_", target_rust_path));
262                        cg.add(&method.item.name.get_rust());
263                        cg.add(" as ");
264
265                        params.insert(
266                            0,
267                            TypeExpr::Ref {
268                                lifetime: None,
269                                mutable,
270                                ty: Box::new(get_type_from_path(&target_path)),
271                            },
272                        );
273
274                        cg.add(
275                            &TypeExpr::StaticFn(
276                                params,
277                                method.item.return_type.clone().map(Box::new),
278                            )
279                            .get_rust(),
280                        );
281                        cg.addln(";");
282                    }
283                }
284            }
285        }
286
287        cg.indent -= 1;
288        cg.add_indentedln("}");
289    }
290
291    fn emit_constructor(&self, ctx: &mut Context, cg: &mut RustCodegen) {
292        let constructor_comment = self.constructor.get_comment();
293
294        cg.add_indentedln("#[allow(invalid_value)]");
295        cg.add_indentedln(&constructor_comment);
296
297        cg.add_indented(&format!(
298            "{}fn new{}(",
299            self.constructor.item.visibility.get_rust(),
300            self.constructor.item.generics.get_rust()
301        ));
302
303        let params = self
304            .constructor
305            .item
306            .params
307            .0
308            .clone()
309            .into_iter()
310            .enumerate()
311            .map(|(idx, mut v)| {
312                v.name = construct_pattern(&v.name, idx);
313                (idx, v)
314            })
315            .collect::<Vec<_>>();
316
317        for (i, param) in &params {
318            if *i > 0 {
319                cg.add(", ");
320            }
321            param.gen_rust(ctx, cg);
322        }
323
324        cg.addln(") -> Box<Self> {");
325        cg.indent += 1;
326
327        cg.add_indentedln("let mut this = Box::new(unsafe { std::mem::MaybeUninit::<Self>::zeroed().assume_init() });");
328        cg.add_indentedln("let this_ptr = &mut *this as *mut Self as *mut std::ffi::c_void;");
329        cg.add_indentedln("this._m_oop = (&Self::__V_TABLE, this_ptr);");
330
331        // Inline field declarations and initializers
332        for field in &self.fields {
333            let comment = field.get_comment();
334
335            if let Some(init) = &field.item.init {
336                cg.add_indentedln(&comment);
337                cg.add_indentedln(&format!("this.{} = ", field.item.decl.name.get_rust()));
338                init.gen_rust(ctx, cg);
339            }
340        }
341
342        cg.add_indented("this.constructor(");
343        for (i, param) in &params {
344            if *i > 0 {
345                cg.add(", ");
346            }
347            ctx.expr_ensure_semicolon = false;
348            param.name.gen_rust(ctx, cg);
349        }
350        cg.addln(");");
351
352        if self.inherits.is_some() && !self.override_v_table.is_empty() {
353            for (idx, (override_tier, v)) in self.override_v_table.iter().enumerate() {
354                match &override_tier.0 {
355                    // Direct base class layout updates
356                    None => {
357                        cg.add_indentedln(&format!(
358                            "this._super._m_oop.0 = Self::__SUPER_V_TABLES[{}];",
359                            idx
360                        ));
361                    }
362                    // Deep ancestor trait table updates
363                    Some(path) => {
364                        cg.add_indentedln(&v.get_comment());
365                        cg.add_indentedln(&format!(
366                            "(|v: &mut {}| {{v._m_oop.0 = Self::__SUPER_V_TABLES[{}];}})(&mut this);",
367                            path.get_rust(),
368                            idx
369                        ));
370                    }
371                }
372            }
373        }
374
375        cg.add_indentedln(&constructor_comment);
376
377        cg.add_indentedln("this");
378        cg.indent -= 1;
379        cg.add_indentedln("}\n");
380
381        // Generate matching inner initialization body block
382        let mut constructor_params = vec![VarDecl {
383            name: Pattern::Path(false, Path(vec![Identifier(String::from("self"))])),
384            type_: Some(TypeExpr::Ref {
385                lifetime: None,
386                mutable: true,
387                ty: Box::new(TypeExpr::Path(
388                    Path(vec![Identifier(String::from("Self"))]),
389                    None,
390                )),
391            }),
392        }];
393
394        constructor_params.append(&mut self.constructor.item.params.0.clone());
395
396        Spanned {
397            line: self.constructor.line,
398            column: self.constructor.column,
399            item: FunctionDecl {
400                visibility: self.constructor.item.visibility.clone(),
401                is_override: None,
402                name: Identifier(String::from("constructor")),
403                generics: self.constructor.item.generics.clone(),
404                params: ParamList(constructor_params),
405                return_type: Some(TypeExpr::Tuple(Vec::new())),
406                body: Some(self.constructor.item.body.clone()),
407            },
408        }
409        .gen_rust(ctx, cg);
410    }
411
412    fn emit_methods(&self, ctx: &mut Context, cg: &mut RustCodegen) {
413        for method in &self.methods {
414            match method.item.visibility {
415                Visibility::Public => {
416                    if method.item.is_override.is_none() {
417                        gen_method_point(&method.item, ctx, cg);
418                    }
419
420                    let mut prefixed = method.clone();
421                    prefixed.item.name.0.insert_str(0, "__m_");
422                    prefixed.gen_rust(ctx, cg);
423                }
424                _ => {
425                    method.gen_rust(ctx, cg);
426                }
427            }
428        }
429    }
430
431    fn emit_impl_decls(&self, ctx: &mut Context, cg: &mut RustCodegen) {
432        for item in &self.items {
433            if let ClassItem::ImplDecl(impl_) = item {
434                let mut impl_ = impl_.clone();
435                impl_.item.trait_ = Some(impl_.item.target);
436                impl_.item.target = TypeExpr::Path(Path(vec![self.name.clone()]), None);
437                impl_.gen_rust(ctx, cg);
438            }
439        }
440    }
441
442    fn emit_deref_impls(&self, cg: &mut RustCodegen) {
443        if let Some(ref inherits) = self.inherits {
444            let generics_str = self.generics.get_rust();
445            let generics_expr_str = Generics::from(self.generics.clone()).get_rust();
446
447            cg.add(&format!(
448                "impl{} std::ops::Deref for {}{}",
449                generics_str,
450                self.name.get_rust(),
451                generics_expr_str
452            ));
453
454            cg.addln(" {");
455            cg.indent += 1;
456
457            cg.add_indented("type Target = ");
458            cg.add(&inherits.get_rust());
459            cg.addln(";");
460
461            cg.add_indentedln("fn deref(&self) -> &Self::Target { &self._super }");
462
463            cg.indent -= 1;
464            cg.addln("}");
465
466            cg.add(&format!(
467                "impl{} std::ops::DerefMut for {}{}",
468                generics_str,
469                self.name.get_rust(),
470                generics_expr_str
471            ));
472
473            cg.addln(" {");
474            cg.indent += 1;
475
476            cg.add_indentedln("fn deref_mut(&mut self) -> &mut Self::Target { &mut self._super }");
477
478            cg.indent -= 1;
479            cg.addln("}");
480        }
481    }
482}
483
484pub fn class_decl(
485    ctx: &mut Context,
486    cg: &mut RustCodegen,
487    visibility: &Visibility,
488    name: &Identifier,
489    generics: &GenericsDecl,
490    inherits: &Option<ExprPath>,
491    fields: &Vec<Spanned<FieldDeclStmt>>,
492    constructor: &Spanned<ClassConstructor>,
493    items: &Vec<ClassItem>,
494) {
495    let data = ClassProcessedData::analyze(
496        visibility,
497        name,
498        generics,
499        inherits,
500        fields,
501        constructor,
502        items,
503    );
504    data.emit(ctx, cg);
505}
506
507fn construct_pattern(pat: &Pattern, idx: usize) -> Pattern {
508    match pat {
509        Pattern::Literal(v) => Pattern::Literal(v.clone()),
510        Pattern::Path(is_mut, v) => Pattern::Path(*is_mut, v.clone().into()),
511        _ => Pattern::Path(false, Path(vec![Identifier(format!("_{idx}"))])),
512    }
513}
514
515pub fn gen_method_point(method: &FunctionDecl, ctx: &mut Context, cg: &mut RustCodegen) {
516    cg.add(&format!(
517        "{}fn {}{}(",
518        method.visibility.get_rust(),
519        method.name.get_rust(),
520        method.generics.get_rust(),
521    ));
522
523    let params = method
524        .params
525        .0
526        .clone()
527        .into_iter()
528        .enumerate()
529        .map(|(idx, mut v)| {
530            v.name = construct_pattern(&v.name, idx);
531            (idx, v)
532        })
533        .collect::<Vec<_>>();
534
535    for (i, param) in &params {
536        if *i > 0 {
537            cg.add(", ");
538        }
539        param.gen_rust(ctx, cg);
540    }
541
542    cg.add(") ");
543    if let Some(return_type) = &method.return_type {
544        cg.add("-> ");
545        cg.add(&return_type.get_rust());
546    }
547
548    cg.addln("{");
549    cg.indent += 1;
550
551    cg.add_indentedln("unsafe {");
552    cg.indent += 1;
553
554    cg.add_indentedln(&format!(
555        "let func_ptr = self._m_oop.0[Self::__FN_{}];",
556        method.name.0.to_uppercase()
557    ));
558
559    cg.add_indented("let func: ");
560
561    let mut param_types: Vec<TypeExpr> = method
562        .params
563        .clone()
564        .0
565        .into_iter()
566        .filter_map(|v| v.type_)
567        .collect();
568
569    if !param_types.is_empty() {
570        param_types.remove(0);
571    }
572
573    param_types.insert(
574        0,
575        TypeExpr::UnsafePtr {
576            mutable: true,
577            ty: Box::new(TypeExpr::Path(
578                Path(vec![
579                    Identifier(String::from("std")),
580                    Identifier(String::from("ffi")),
581                    Identifier(String::from("c_void")),
582                ]),
583                None,
584            )),
585        },
586    );
587
588    cg.add(&TypeExpr::StaticFn(param_types, method.return_type.clone().map(Box::new)).get_rust());
589    cg.addln(" = std::mem::transmute(func_ptr);");
590
591    cg.add_indented("func(self._m_oop.1");
592
593    for (i, param) in &params {
594        if *i == 0 {
595            continue; // self._m_oop.1 already fulfills it
596        }
597        cg.add(", ");
598        ctx.expr_ensure_semicolon = false;
599        param.name.gen_rust(ctx, cg);
600    }
601
602    cg.addln(")");
603
604    cg.indent -= 1;
605    cg.add_indentedln("}");
606
607    cg.indent -= 1;
608    cg.add_indentedln("}");
609}
610
611pub fn get_type_from_path(path: &ExprPath) -> TypeExpr {
612    TypeExpr::Path(
613        Path(path.0.iter().map(|v| v.ident.clone()).collect::<Vec<_>>()),
614        path.0.last().unwrap().generics.clone(),
615    )
616}