formalang 0.0.3-beta

FormaLang compiler frontend: lexer, parser, semantic analyzer, and IR lowering.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
//! Phase 2 / 2b / 2c / 2e: rewrite IR references after specialisation.
//!
//! - [`rewrite_module`] turns every `ResolvedType::Generic` into the
//!   concrete `Struct/Enum/Trait` id of its cloned specialisation, plus
//!   matching rewrites of `IrTraitRef` slots (constraints and impl trait
//!   refs).
//! - [`specialise_impls`] clones each generic-targeting impl block once
//!   per specialised target with the body's `TypeParam` slots
//!   substituted, returning the `(orig_idx, spec_target) → new_idx` map.
//! - [`rewrite_dispatch_impl_ids`] retargets `DispatchKind::Static` calls
//!   onto the cloned impls.
//! - [`devirtualise_concrete_receivers`] resolves any
//!   `DispatchKind::Virtual` whose receiver is now a concrete type to
//!   `Static` against the specialised impl.

use std::collections::HashMap;

use crate::ir::{GenericBase, IrExpr, IrGenericParam, IrImpl, IrModule, IrTraitRef, ResolvedType};

use super::expr_walk::iter_expr_children_mut;
use super::specialise::{substitute_expr_types, substitute_type, Instantiation};
use super::walkers::{walk_function_types_mut, walk_module_types_mut};

/// Lookup table produced by Phase 2b ([`specialise_impls`]).
///
/// - Key: `(index of the original generic impl in `module.impls`,
///   specialised target base)`.
/// - Value: index in `module.impls` of the clone that was appended for
///   that specialised target.
pub(super) type ImplRemap = HashMap<(usize, GenericBase), usize>;

/// Phase 2 entry point: replace generic references throughout `module`
/// with their specialisations.
///
/// Runs two sweeps, in this order:
///
/// 1. Every `ResolvedType` reachable through `walk_module_types_mut` is
///    passed to [`rewrite_type`], which swaps each `ResolvedType::Generic`
///    found in `mapping` for its concrete base.
/// 2. [`rewrite_trait_refs`] handles the `IrTraitRef` slots (generic-param
///    constraints, struct trait lists, impl trait refs), which do not live
///    inside a `ResolvedType` and are therefore missed by the first sweep.
///
/// `mapping` maps an instantiation `(generic base, type args)` to the base
/// of the specialisation that was cloned for it.
pub(super) fn rewrite_module(module: &mut IrModule, mapping: &HashMap<Instantiation, GenericBase>) {
    walk_module_types_mut(module, |ty: &mut ResolvedType| rewrite_type(ty, mapping));
    rewrite_trait_refs(module, mapping);
}

/// Retarget a single trait reference onto its specialised trait.
///
/// A reference with no type arguments is left untouched. Otherwise the
/// arguments are rewritten first (so the lookup key has the same
/// post-rewrite shape as the keys stored in `mapping`), and if `mapping`
/// holds a `GenericBase::Trait` entry for `(trait, args)` the reference is
/// pointed at that trait id and its argument list is emptied. When no such
/// entry exists the (rewritten) arguments stay in place.
fn rewrite_trait_ref(tr: &mut IrTraitRef, mapping: &HashMap<Instantiation, GenericBase>) {
    if !tr.args.is_empty() {
        tr.args.iter_mut().for_each(|arg| rewrite_type(arg, mapping));

        let lookup = (GenericBase::Trait(tr.trait_id), tr.args.clone());
        if let Some(&GenericBase::Trait(specialised_id)) = mapping.get(&lookup) {
            tr.trait_id = specialised_id;
            tr.args.clear();
        }
    }
}

/// Apply [`rewrite_trait_ref`] to every `IrTraitRef` slot of the module.
///
/// Slots visited, in order:
/// - structs: the `traits` list, then the constraints of each generic param;
/// - enums, traits and free functions: generic-param constraints;
/// - impls: generic-param constraints, then the optional `trait_ref`.
fn rewrite_trait_refs(module: &mut IrModule, mapping: &HashMap<Instantiation, GenericBase>) {
    /// Rewrite the constraint list of each generic parameter in `params`.
    fn rewrite_constraints(
        params: &mut [IrGenericParam],
        mapping: &HashMap<Instantiation, GenericBase>,
    ) {
        params
            .iter_mut()
            .flat_map(|param| param.constraints.iter_mut())
            .for_each(|constraint| rewrite_trait_ref(constraint, mapping));
    }

    for strukt in module.structs.iter_mut() {
        strukt
            .traits
            .iter_mut()
            .for_each(|tr| rewrite_trait_ref(tr, mapping));
        rewrite_constraints(&mut strukt.generic_params, mapping);
    }
    for enumeration in module.enums.iter_mut() {
        rewrite_constraints(&mut enumeration.generic_params, mapping);
    }
    for trait_def in module.traits.iter_mut() {
        rewrite_constraints(&mut trait_def.generic_params, mapping);
    }
    for impl_block in module.impls.iter_mut() {
        rewrite_constraints(&mut impl_block.generic_params, mapping);
        if let Some(tr) = impl_block.trait_ref.as_mut() {
            rewrite_trait_ref(tr, mapping);
        }
    }
    for function in module.functions.iter_mut() {
        rewrite_constraints(&mut function.generic_params, mapping);
    }
}

/// Rewrite `ty` in place, replacing each `ResolvedType::Generic` that has
/// an entry in `mapping` with the concrete `Struct` / `Enum` / `Trait`
/// type of its specialisation.
///
/// The traversal is bottom-up: component types (array elements, tuple
/// fields, closure params, generic args, ...) are rewritten before the
/// enclosing generic is looked up, because the keys of `mapping` hold
/// fully rewritten argument types. A generic without an entry in `mapping`
/// keeps its `Generic` shape, but its arguments are still rewritten.
pub(super) fn rewrite_type(ty: &mut ResolvedType, mapping: &HashMap<Instantiation, GenericBase>) {
    match ty {
        // Leaves: nothing nested to rewrite.
        ResolvedType::Primitive(_)
        | ResolvedType::Struct(_)
        | ResolvedType::Trait(_)
        | ResolvedType::Enum(_)
        | ResolvedType::TypeParam(_)
        | ResolvedType::Error => {}

        // Single-component wrappers.
        ResolvedType::Array(elem) | ResolvedType::Range(elem) | ResolvedType::Optional(elem) => {
            rewrite_type(elem, mapping);
        }

        ResolvedType::Tuple(fields) => {
            fields
                .iter_mut()
                .for_each(|(_, field_ty)| rewrite_type(field_ty, mapping));
        }
        ResolvedType::Dictionary { key_ty, value_ty } => {
            rewrite_type(key_ty, mapping);
            rewrite_type(value_ty, mapping);
        }
        ResolvedType::Closure {
            param_tys,
            return_ty,
        } => {
            param_tys
                .iter_mut()
                .for_each(|(_, param_ty)| rewrite_type(param_ty, mapping));
            rewrite_type(return_ty, mapping);
        }
        ResolvedType::External { type_args, .. } => {
            type_args
                .iter_mut()
                .for_each(|arg| rewrite_type(arg, mapping));
        }

        ResolvedType::Generic { base, args } => {
            // Inner args first, then the outer lookup with the rewritten args.
            args.iter_mut().for_each(|arg| rewrite_type(arg, mapping));
            let specialised = mapping.get(&(*base, args.clone())).copied();
            match specialised {
                Some(GenericBase::Struct(id)) => *ty = ResolvedType::Struct(id),
                Some(GenericBase::Enum(id)) => *ty = ResolvedType::Enum(id),
                Some(GenericBase::Trait(id)) => *ty = ResolvedType::Trait(id),
                None => {}
            }
        }
    }
}

// Phase 2b: specialise impl blocks targeting generic structs/enums

/// For each impl block whose target is a generic struct/enum, append one
/// cloned impl per specialisation of that target (with `TypeParam`s
/// substituted for the concrete type args of that specialisation). The
/// originals are retained in `module.impls` for now; they are dropped in
/// Phase 3 by `drop_specialised_generic_impls`.
///
/// Dispatch sites (`DispatchKind::Static { impl_id }`) still reference the
/// original generic-impl slot after this runs. Backends that iterate
/// `module.impls` to locate methods on a specialised type will find them
/// correctly here; Phase 2c (`rewrite_dispatch_impl_ids`) uses the
/// returned [`ImplRemap`] to retarget `DispatchKind::Static { impl_id }`
/// sites onto the cloned impl for each specialisation.
pub(super) fn specialise_impls(
    module: &mut IrModule,
    mapping: &HashMap<Instantiation, GenericBase>,
) -> ImplRemap {
    // Group specialisations by original generic base.
    type Spec = (Vec<ResolvedType>, GenericBase);
    let mut by_base: HashMap<GenericBase, Vec<Spec>> = HashMap::new();
    for ((orig_base, args), spec_base) in mapping {
        by_base
            .entry(*orig_base)
            .or_default()
            .push((args.clone(), *spec_base));
    }
    let mut new_impls: Vec<IrImpl> = Vec::new();
    let mut impl_remap: ImplRemap = HashMap::new();

    for (orig_idx, imp) in module.impls.iter().enumerate() {
        let base = match imp.target {
            crate::ir::ImplTarget::Struct(id) => GenericBase::Struct(id),
            crate::ir::ImplTarget::Enum(id) => GenericBase::Enum(id),
            // Primitive impls aren't generic; nothing to specialise.
            crate::ir::ImplTarget::Primitive(_) => continue,
        };
        let Some(specs) = by_base.get(&base) else {
            continue;
        };
        let generic_param_names: Vec<String> = match base {
            GenericBase::Struct(sid) => module
                .get_struct(sid)
                .map(|s| s.generic_params.iter().map(|p| p.name.clone()).collect())
                .unwrap_or_default(),
            GenericBase::Enum(eid) => module
                .get_enum(eid)
                .map(|e| e.generic_params.iter().map(|p| p.name.clone()).collect())
                .unwrap_or_default(),
            // An impl never targets a trait base directly — `imp.target`
            // is `ImplTarget::Struct(_)` or `ImplTarget::Enum(_)`. This
            // arm is unreachable but kept for match exhaustiveness.
            GenericBase::Trait(_) => Vec::new(),
        };
        if generic_param_names.is_empty() {
            continue;
        }
        for (args, spec_base) in specs {
            if generic_param_names.len() != args.len() {
                continue;
            }
            let subs: HashMap<String, ResolvedType> = generic_param_names
                .iter()
                .cloned()
                .zip(args.iter().cloned())
                .collect();
            let mut clone = imp.clone();
            clone.target = match spec_base {
                GenericBase::Struct(id) => crate::ir::ImplTarget::Struct(*id),
                GenericBase::Enum(id) => crate::ir::ImplTarget::Enum(*id),
                // Impl targets are struct/enum only — see above.
                GenericBase::Trait(_) => continue,
            };
            for func in &mut clone.functions {
                for param in &mut func.params {
                    if let Some(ty) = &mut param.ty {
                        substitute_type(ty, &subs);
                    }
                    if let Some(default) = &mut param.default {
                        substitute_expr_types(default, &subs);
                    }
                }
                if let Some(ret_ty) = &mut func.return_type {
                    substitute_type(ret_ty, &subs);
                }
                if let Some(body) = &mut func.body {
                    substitute_expr_types(body, &subs);
                }
            }
            walk_impl_types_mut(&mut clone, &mut |ty| rewrite_type(ty, mapping));
            // Record the (orig_idx, spec_target) → new_idx mapping so
            // dispatch-site rewriting can find the right clone.
            let new_idx = module.impls.len().saturating_add(new_impls.len());
            impl_remap.insert((orig_idx, *spec_base), new_idx);
            new_impls.push(clone);
        }
    }

    module.impls.extend(new_impls);
    impl_remap
}

/// Map a receiver type to the [`GenericBase`] used to key impl lookups.
///
/// - `Struct(id)` / `Enum(id)` yield the matching `GenericBase`.
/// - `Optional(inner)` is looked through: the result is that of `inner`.
/// - Every other variant (primitives, traits, arrays, ranges, tuples,
///   unresolved generics, type params, externals, dictionaries, closures,
///   errors) yields `None`.
pub(super) fn receiver_to_base(ty: &ResolvedType) -> Option<GenericBase> {
    match ty {
        ResolvedType::Primitive(_)
        | ResolvedType::Trait(_)
        | ResolvedType::Array(_)
        | ResolvedType::Range(_)
        | ResolvedType::Tuple(_)
        | ResolvedType::Generic { .. }
        | ResolvedType::TypeParam(_)
        | ResolvedType::External { .. }
        | ResolvedType::Dictionary { .. }
        | ResolvedType::Closure { .. }
        | ResolvedType::Error => None,
        ResolvedType::Optional(wrapped) => receiver_to_base(wrapped),
        ResolvedType::Enum(enum_id) => Some(GenericBase::Enum(*enum_id)),
        ResolvedType::Struct(struct_id) => Some(GenericBase::Struct(*struct_id)),
    }
}

/// Retarget statically dispatched method calls inside one expression tree.
///
/// Children are processed before `expr` itself. If `expr` is a method call
/// with `DispatchKind::Static { impl_id }`, its receiver type resolves to a
/// base via [`receiver_to_base`], and `impl_remap` has an entry for
/// `(impl_id, base)`, then `impl_id` is replaced by the index of the
/// specialised clone. Anything else is left unchanged.
fn dispatch_rewrite_expr(expr: &mut IrExpr, impl_remap: &ImplRemap) {
    use crate::ir::{DispatchKind, ImplId};

    // Post-order: nested method calls are rewritten first.
    iter_expr_children_mut(expr)
        .into_iter()
        .for_each(|child| dispatch_rewrite_expr(child, impl_remap));

    let IrExpr::MethodCall {
        receiver,
        dispatch: DispatchKind::Static { impl_id },
        ..
    } = expr
    else {
        return;
    };
    let Some(target_base) = receiver_to_base(receiver.ty()) else {
        return;
    };

    let original_idx = impl_id.0 as usize;
    if let Some(&clone_idx) = impl_remap.get(&(original_idx, target_base)) {
        // An index that does not fit in `u32` saturates to `u32::MAX`.
        *impl_id = ImplId(u32::try_from(clone_idx).unwrap_or(u32::MAX));
    }
}

/// Phase 2c: point every `DispatchKind::Static { impl_id }` call site at
/// the per-specialisation impl clone recorded in `impl_remap`.
///
/// Does nothing when `impl_remap` is empty. Otherwise every expression
/// root in the module is visited: bodies and parameter defaults of free
/// functions and of impl functions, field defaults of structs and of enum
/// variants, and the values of module-level lets.
pub(super) fn rewrite_dispatch_impl_ids(module: &mut IrModule, impl_remap: &ImplRemap) {
    if impl_remap.is_empty() {
        return;
    }

    // Free functions.
    for func in module.functions.iter_mut() {
        if let Some(body) = func.body.as_mut() {
            dispatch_rewrite_expr(body, impl_remap);
        }
        func.params
            .iter_mut()
            .filter_map(|param| param.default.as_mut())
            .for_each(|default| dispatch_rewrite_expr(default, impl_remap));
    }

    // Functions inside impl blocks (originals and clones alike).
    for func in module.impls.iter_mut().flat_map(|imp| imp.functions.iter_mut()) {
        if let Some(body) = func.body.as_mut() {
            dispatch_rewrite_expr(body, impl_remap);
        }
        func.params
            .iter_mut()
            .filter_map(|param| param.default.as_mut())
            .for_each(|default| dispatch_rewrite_expr(default, impl_remap));
    }

    // Struct field defaults.
    module
        .structs
        .iter_mut()
        .flat_map(|s| s.fields.iter_mut())
        .filter_map(|field| field.default.as_mut())
        .for_each(|default| dispatch_rewrite_expr(default, impl_remap));

    // Enum variant field defaults.
    module
        .enums
        .iter_mut()
        .flat_map(|e| e.variants.iter_mut())
        .flat_map(|variant| variant.fields.iter_mut())
        .filter_map(|field| field.default.as_mut())
        .for_each(|default| dispatch_rewrite_expr(default, impl_remap));

    // Module-level let bindings.
    module
        .lets
        .iter_mut()
        .for_each(|binding| dispatch_rewrite_expr(&mut binding.value, impl_remap));
}

/// Phase 2e devirtualisation: walk every method call and rewrite
/// `DispatchKind::Virtual` to `Static` when the receiver type is now
/// concrete (Struct/Enum). Reads `module.impls` to find the impl
/// providing the requested trait method on the receiver type.
///
/// Calls whose receiver is still a `TypeParam` (uninstantiated generic
/// function bodies) stay `Virtual` and are tolerated downstream —
/// those bodies are dropped during compaction or never reached by a
/// backend's specialisation root set.
pub(super) fn devirtualise_concrete_receivers(module: &mut IrModule) {
    // Clone the impls table so we can read it while mutating function
    // bodies. impls don't change shape during devirt; we only consult
    // them for `(target, trait_id, method_name)` lookup.
    let impls_snapshot = module.impls.clone();
    for func in &mut module.functions {
        if let Some(body) = &mut func.body {
            devirtualise_expr(body, &impls_snapshot);
        }
        for param in &mut func.params {
            if let Some(default) = &mut param.default {
                devirtualise_expr(default, &impls_snapshot);
            }
        }
    }
    for imp in &mut module.impls {
        for func in &mut imp.functions {
            if let Some(body) = &mut func.body {
                devirtualise_expr(body, &impls_snapshot);
            }
            for param in &mut func.params {
                if let Some(default) = &mut param.default {
                    devirtualise_expr(default, &impls_snapshot);
                }
            }
        }
    }
    for s in &mut module.structs {
        for field in &mut s.fields {
            if let Some(default) = &mut field.default {
                devirtualise_expr(default, &impls_snapshot);
            }
        }
    }
    for e in &mut module.enums {
        for variant in &mut e.variants {
            for field in &mut variant.fields {
                if let Some(default) = &mut field.default {
                    devirtualise_expr(default, &impls_snapshot);
                }
            }
        }
    }
    for l in &mut module.lets {
        devirtualise_expr(&mut l.value, &impls_snapshot);
    }
}

/// Devirtualise the method calls inside one expression tree.
///
/// Children are processed before `expr` itself. If `expr` is a method call
/// with `DispatchKind::Virtual { trait_id, .. }` and its receiver type
/// resolves to a struct/enum base via [`receiver_to_base`], `impls` is
/// searched for the first impl that
/// - targets that struct/enum,
/// - implements `trait_id`, and
/// - contains a function named like the called method.
///
/// On a hit the dispatch becomes `DispatchKind::Static` with the index of
/// that impl in `impls`; otherwise the call stays `Virtual`. Impls on
/// primitives never match.
fn devirtualise_expr(expr: &mut IrExpr, impls: &[IrImpl]) {
    use crate::ir::{DispatchKind, ImplId, ImplTarget};

    // Post-order: nested calls are devirtualised first.
    for child in iter_expr_children_mut(expr) {
        devirtualise_expr(child, impls);
    }

    let IrExpr::MethodCall {
        receiver,
        method,
        dispatch,
        ..
    } = expr
    else {
        return;
    };
    let DispatchKind::Virtual {
        trait_id: virt_trait_id,
        ..
    } = dispatch
    else {
        return;
    };
    let Some(target_base) = receiver_to_base(receiver.ty()) else {
        return;
    };
    // Copy the id out so the borrow of `dispatch` ends before the
    // assignment below.
    let virt_trait_id = *virt_trait_id;

    // `method` is only compared, so it is read through the existing
    // borrow instead of being cloned for every virtual call site.
    let provides_method = |imp: &IrImpl| {
        let impl_base = match imp.target {
            ImplTarget::Struct(id) => GenericBase::Struct(id),
            ImplTarget::Enum(id) => GenericBase::Enum(id),
            // Primitive impls don't carry a GenericBase target — no
            // virtual-to-static devirtualisation applies.
            ImplTarget::Primitive(_) => return false,
        };
        impl_base == target_base
            && imp.trait_id() == Some(virt_trait_id)
            && imp.functions.iter().any(|f| f.name == *method)
    };

    if let Some(impl_idx) = impls.iter().position(provides_method) {
        // An index that does not fit in `u32` saturates to `u32::MAX`.
        *dispatch = DispatchKind::Static {
            impl_id: ImplId(u32::try_from(impl_idx).unwrap_or(u32::MAX)),
        };
    }
}

/// Call `visit` on the types of every function in `imp`, by delegating
/// each function to `walk_function_types_mut`.
pub(super) fn walk_impl_types_mut(imp: &mut IrImpl, visit: &mut impl FnMut(&mut ResolvedType)) {
    imp.functions
        .iter_mut()
        .for_each(|func| walk_function_types_mut(func, &mut *visit));
}