Skip to main content

harn_vm/
typecheck.rs

1//! Runtime type & arity validation, shared between user-defined function
2//! calls and registry-known builtin calls.
3//!
4//! Every call-site validation in the VM funnels through three entry points:
5//!
6//! - [`assert_value_matches_type`] — given a [`VmValue`] and a
7//!   [`TypeExpr`], decide whether the value satisfies the type. The
8//!   single source of runtime truth for `int`/`string`/`list<T>`/...
9//!   compatibility, mirroring the static [`TypeChecker::types_compatible`]
10//!   semantics on values rather than type expressions.
11//! - [`validate_user_call`] — arity check + per-arg declared-type
12//!   assertion for compiled user-defined functions
13//!   ([`crate::chunk::CompiledFunction`]).
14//! - [`validate_builtin_call`] — arity check + per-arg type assertion
15//!   for builtins, driven by the parser's
16//!   [`harn_parser::builtin_signatures`] registry. The runtime never
17//!   re-implements per-builtin validation; the registry is the contract.
18//!
19//! All three return [`crate::value::VmError`] variants
20//! ([`VmError::ArityMismatch`], [`VmError::ArgTypeMismatch`]) on failure
21//! so error UX is uniform. Callers may pass an optional
22//! [`harn_lexer::Span`] when they have a source location for the call
23//! site (e.g. derived from the chunk's PC→span table); when omitted the
24//! error renders without a positional suffix.
25
26use harn_lexer::Span;
27use harn_parser::builtin_signatures::{self, BuiltinSignature};
28use harn_parser::typechecker::format_type;
29use harn_parser::TypeExpr;
30
31use crate::chunk::{CompiledFunction, ParamSlot};
32use crate::value::{ArgTypeMismatchError, ArityExpect, ArityMismatchError, VmError, VmValue};
33
34/// Validate that `value` satisfies `expected`. Returns `Ok(())` when the
35/// value is acceptable, otherwise an [`VmError::ArgTypeMismatch`] tagged
36/// with `callee` / `param` / `span` for the caller's diagnostic.
37///
38/// The dispatch table mirrors the static checker's `types_compatible`
39/// rules:
40/// - `Named("any")` and the special generic-parameter sentinel skip
41///   validation (any value passes).
42/// - `Named("number")` accepts `int` or `float`.
43/// - `Optional<T>` / `T | nil` accepts the inner type or `Nil`.
44/// - `list<T>`, `dict<K, V>`, `iter<T>`, `Generator<T>`, `Stream<T>`
45///   check the container; element-level validation is per element when
46///   the value is a literal `VmValue::List` / `VmValue::Dict` whose
47///   contents are cheap to walk. For lazy iterators / streams we skip
48///   element validation (they may be infinite or expensive).
49/// - `Shape{...}` validates field presence and per-field types against
50///   `VmValue::Dict` and `VmValue::StructInstance`.
51/// - `Union(...)` accepts any matching alternative.
52/// - `Intersection(...)` accepts only when *every* alternative matches.
53/// - Literal types (`LitInt`, `LitString`) require value equality with
54///   the literal.
55/// - `Never` always rejects.
56pub fn assert_value_matches_type(
57    value: &VmValue,
58    expected: &TypeExpr,
59    callee: &str,
60    param: &str,
61    span: Option<Span>,
62) -> Result<(), VmError> {
63    assert_value_matches_type_with_generics(value, expected, callee, param, span, &[], &[])
64}
65
66fn assert_value_matches_type_with_generics(
67    value: &VmValue,
68    expected: &TypeExpr,
69    callee: &str,
70    param: &str,
71    span: Option<Span>,
72    type_params: &[String],
73    nominal_type_names: &[String],
74) -> Result<(), VmError> {
75    if matches_type_with_generics(value, expected, type_params, nominal_type_names) {
76        Ok(())
77    } else {
78        Err(VmError::ArgTypeMismatch(Box::new(ArgTypeMismatchError {
79            callee: callee.to_string(),
80            param: param.to_string(),
81            expected: format_type(expected),
82            got: value.type_name(),
83            span,
84        })))
85    }
86}
87
88fn user_param_for_arg(func: &CompiledFunction, index: usize) -> Option<&ParamSlot> {
89    if func.has_rest_param && index >= func.params.len().saturating_sub(1) {
90        func.params.last()
91    } else {
92        func.params.get(index)
93    }
94}
95
96fn builtin_param_for_arg(
97    sig: &BuiltinSignature,
98    index: usize,
99) -> Option<&harn_parser::builtin_signatures::Param> {
100    if sig.has_rest && index >= sig.params.len().saturating_sub(1) {
101        sig.params.last()
102    } else {
103        sig.params.get(index)
104    }
105}
106
/// Test-only shorthand for [`matches_type_with_generics`] with no
/// generic parameters and no nominal type names in scope. Kept internal
/// so the public API only exposes `Result`-returning forms.
#[cfg(test)]
fn matches_type(value: &VmValue, expected: &TypeExpr) -> bool {
    let nothing_in_scope: &[String] = &[];
    matches_type_with_generics(value, expected, nothing_in_scope, nothing_in_scope)
}
113
114fn matches_type_with_generics(
115    value: &VmValue,
116    expected: &TypeExpr,
117    type_params: &[String],
118    nominal_type_names: &[String],
119) -> bool {
120    match expected {
121        TypeExpr::Named(name) => match name.as_str() {
122            _ if type_params.iter().any(|param| param == name) => true,
123            "any" | "unknown" => true,
124            "int" => matches!(value, VmValue::Int(_)),
125            "float" => matches!(value, VmValue::Float(_) | VmValue::Int(_)),
126            "number" => matches!(value, VmValue::Int(_) | VmValue::Float(_)),
127            "string" => matches!(value, VmValue::String(_)),
128            "bool" => matches!(value, VmValue::Bool(_)),
129            "nil" => matches!(value, VmValue::Nil),
130            "list" => matches!(value, VmValue::List(_)),
131            "dict" => matches!(value, VmValue::Dict(_)),
132            "bytes" => matches!(value, VmValue::Bytes(_)),
133            "duration" => matches!(value, VmValue::Duration(_)),
134            "set" => matches!(value, VmValue::Set(_)),
135            "range" => matches!(value, VmValue::Range(_)),
136            "iter" => matches!(value, VmValue::Iter(_)),
137            "generator" | "Generator" => matches!(value, VmValue::Generator(_)),
138            "stream" | "Stream" => matches!(value, VmValue::Stream(_)),
139            "channel" => matches!(value, VmValue::Channel(_)),
140            "task_handle" => matches!(value, VmValue::TaskHandle(_)),
141            "atomic" => matches!(value, VmValue::Atomic(_)),
142            "rng" => matches!(value, VmValue::Rng(_)),
143            "sync_permit" => matches!(value, VmValue::SyncPermit(_)),
144            "mcp_client" => matches!(value, VmValue::McpClient(_)),
145            "pair" => matches!(value, VmValue::Pair(_)),
146            "enum" => matches!(value, VmValue::EnumVariant { .. }),
147            "struct" => matches!(value, VmValue::StructInstance { .. }),
148            "closure" => matches!(
149                value,
150                VmValue::Closure(_) | VmValue::BuiltinRef(_) | VmValue::BuiltinRefId { .. }
151            ),
152            _ => {
153                if !nominal_type_names.iter().any(|ty| ty == name) {
154                    true
155                } else {
156                    value
157                        .struct_name()
158                        .is_some_and(|struct_name| struct_name == name)
159                        || matches!(value, VmValue::EnumVariant { enum_name, .. } if enum_name.as_ref() == name)
160                }
161            }
162        },
163        TypeExpr::Union(members) => members
164            .iter()
165            .any(|m| matches_type_with_generics(value, m, type_params, nominal_type_names)),
166        TypeExpr::Intersection(members) => members
167            .iter()
168            .all(|m| matches_type_with_generics(value, m, type_params, nominal_type_names)),
169        TypeExpr::List(inner) => match value {
170            VmValue::List(items) => items
171                .iter()
172                .all(|v| matches_type_with_generics(v, inner, type_params, nominal_type_names)),
173            _ => false,
174        },
175        TypeExpr::DictType(_, vt) => match value {
176            VmValue::Dict(map) => map
177                .values()
178                .all(|v| matches_type_with_generics(v, vt, type_params, nominal_type_names)),
179            _ => false,
180        },
181        TypeExpr::Iter(_) | TypeExpr::Generator(_) | TypeExpr::Stream(_) => match value {
182            // Lazy / async sequences: only check the container shape;
183            // element-level validation would force evaluation.
184            VmValue::List(_) | VmValue::Generator(_) | VmValue::Stream(_) => true,
185            _ => false,
186        },
187        TypeExpr::Shape(fields) => match value {
188            VmValue::Dict(map) => fields.iter().all(|f| match map.get(&f.name) {
189                Some(v) => {
190                    matches_type_with_generics(v, &f.type_expr, type_params, nominal_type_names)
191                }
192                None => f.optional,
193            }),
194            VmValue::StructInstance { .. } => {
195                fields.iter().all(|f| match value.struct_field(&f.name) {
196                    Some(v) => {
197                        matches_type_with_generics(v, &f.type_expr, type_params, nominal_type_names)
198                    }
199                    None => f.optional,
200                })
201            }
202            _ => false,
203        },
204        TypeExpr::Applied { name, args } => match (name.as_str(), args.as_slice()) {
205            ("list", [inner]) => matches_type_with_generics(
206                value,
207                &TypeExpr::List(Box::new(inner.clone())),
208                type_params,
209                nominal_type_names,
210            ),
211            ("dict", [k, v]) => matches_type_with_generics(
212                value,
213                &TypeExpr::DictType(Box::new(k.clone()), Box::new(v.clone())),
214                type_params,
215                nominal_type_names,
216            ),
217            ("Option", [inner]) => {
218                matches!(value, VmValue::Nil)
219                    || matches_type_with_generics(value, inner, type_params, nominal_type_names)
220            }
221            // Result<T, E>, custom user-applied generics, Schema<T>, etc.
222            // fall through to permissive — runtime can't determine the
223            // active variant without more semantic knowledge.
224            _ => true,
225        },
226        TypeExpr::FnType { .. } => matches!(
227            value,
228            VmValue::Closure(_) | VmValue::BuiltinRef(_) | VmValue::BuiltinRefId { .. }
229        ),
230        TypeExpr::Never => false,
231        TypeExpr::LitString(s) => matches!(value, VmValue::String(rs) if rs.as_ref() == s),
232        TypeExpr::LitInt(i) => matches!(value, VmValue::Int(rv) if rv == i),
233    }
234}
235
236/// Validate a user-defined function call: arity (respecting defaults +
237/// rest), then per-parameter declared-type assertion for parameters
238/// that carry a [`TypeExpr`] in their [`crate::chunk::ParamSlot`].
239pub fn validate_user_call(
240    func: &CompiledFunction,
241    args: &[VmValue],
242    span: Option<Span>,
243) -> Result<(), VmError> {
244    let total = func.params.len();
245    let required = func.required_param_count();
246    let got = args.len();
247
248    let arity_ok = if func.has_rest_param {
249        // Rest absorbs everything >= (total - 1).
250        got >= total.saturating_sub(1)
251    } else {
252        got >= required && got <= total
253    };
254
255    if !arity_ok {
256        let expected = arity_expect_for(func);
257        return Err(VmError::ArityMismatch(Box::new(ArityMismatchError {
258            callee: func.name.clone(),
259            expected,
260            got,
261            span,
262        })));
263    }
264
265    for (i, value) in args.iter().enumerate() {
266        let Some(slot) = user_param_for_arg(func, i) else {
267            continue;
268        };
269        let Some(expected) = &slot.type_expr else {
270            continue;
271        };
272        if matches!(expected, TypeExpr::Named(name) if func.declares_type_param(name)) {
273            continue;
274        }
275        if let Some(schema) = crate::compiler::Compiler::type_expr_to_schema_value(expected) {
276            crate::schema::schema_assert_param(value, &slot.name, &schema)?;
277            continue;
278        }
279        assert_value_matches_type_with_generics(
280            value,
281            expected,
282            &func.name,
283            &slot.name,
284            span,
285            &func.type_params,
286            &func.nominal_type_names,
287        )?;
288    }
289
290    Ok(())
291}
292
293/// Validate a builtin call against the parser's signature registry.
294/// Returns `Ok(())` when the builtin is unknown to the registry — the
295/// alignment guarantee enforced at registration time means unknown
296/// names are necessarily internal/special-purpose builtins
297/// (e.g. compiler-synthesized `__*`) that don't need runtime
298/// validation.
299pub fn validate_builtin_call(
300    name: &str,
301    args: &[VmValue],
302    span: Option<Span>,
303) -> Result<(), VmError> {
304    let Some(sig) = builtin_signatures::lookup(name) else {
305        return Ok(());
306    };
307    validate_against_signature(name, sig, args, span)
308}
309
310/// Shared implementation for [`validate_builtin_call`] (and any future
311/// callers that already have a signature in hand). Public so test
312/// harnesses can drive it directly with synthetic signatures.
313pub fn validate_against_signature(
314    name: &str,
315    sig: &BuiltinSignature,
316    args: &[VmValue],
317    span: Option<Span>,
318) -> Result<(), VmError> {
319    let total = sig.params.len();
320    let required = sig.required_params();
321    let got = args.len();
322
323    let arity_ok = if sig.has_rest {
324        got >= total.saturating_sub(1)
325    } else {
326        got >= required && got <= total
327    };
328
329    if !arity_ok {
330        let expected = if sig.has_rest {
331            ArityExpect::AtLeast(total.saturating_sub(1))
332        } else if required == total {
333            ArityExpect::Exact(total)
334        } else {
335            ArityExpect::Range {
336                min: required,
337                max: total,
338            }
339        };
340        return Err(VmError::ArityMismatch(Box::new(ArityMismatchError {
341            callee: name.to_string(),
342            expected,
343            got,
344            span,
345        })));
346    }
347
348    for (i, value) in args.iter().enumerate() {
349        let Some(param) = builtin_param_for_arg(sig, i) else {
350            continue;
351        };
352        if param.optional && matches!(value, VmValue::Nil) {
353            continue;
354        }
355        // Generic type parameters inside builtin signatures are not
356        // resolvable at the value level — the static checker handles
357        // them. Skip type-param positions at runtime to avoid bogus
358        // mismatches.
359        let expected = param.ty.to_type_expr();
360        if matches!(&expected, TypeExpr::Named(n) if sig.is_type_param(n)) {
361            continue;
362        }
363        // `any` is always satisfied; format_type would render "any"
364        // and the runtime predicate accepts everything anyway.
365        if param.ty.is_any() {
366            continue;
367        }
368        if matches!(param.ty, harn_parser::builtin_signatures::Ty::SchemaOf(_)) {
369            continue;
370        }
371        assert_value_matches_type(value, &expected, name, param.name, span)?;
372    }
373
374    Ok(())
375}
376
377/// Compute the [`ArityExpect`] to embed in an [`VmError::ArityMismatch`]
378/// for a user-defined function. Respects defaults and rest-param flags
379/// so the message reads naturally.
380fn arity_expect_for(func: &CompiledFunction) -> ArityExpect {
381    let total = func.params.len();
382    let required = func.required_param_count();
383    if func.has_rest_param {
384        ArityExpect::AtLeast(total.saturating_sub(1))
385    } else if required == total {
386        ArityExpect::Exact(total)
387    } else {
388        ArityExpect::Range {
389            min: required,
390            max: total,
391        }
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::rc::Rc;

    /// Integer VM value.
    fn vm_int(n: i64) -> VmValue {
        VmValue::Int(n)
    }

    /// String VM value.
    fn vm_string(s: &str) -> VmValue {
        VmValue::String(Rc::from(s))
    }

    /// Named type expression, e.g. `named("int")`.
    fn named(name: &str) -> TypeExpr {
        TypeExpr::Named(name.into())
    }

    fn ty_int() -> TypeExpr {
        named("int")
    }

    fn ty_string() -> TypeExpr {
        named("string")
    }

    #[test]
    fn matches_primitive_types() {
        let forty_two = vm_int(42);
        assert!(matches_type(&forty_two, &ty_int()));
        assert!(!matches_type(&forty_two, &ty_string()));
        assert!(matches_type(&vm_string("x"), &ty_string()));
        assert!(matches_type(&VmValue::Bool(true), &named("bool")));
        assert!(matches_type(&VmValue::Nil, &named("nil")));
    }

    #[test]
    fn float_accepts_int_promotion() {
        // Mirrors the static rule: `int` is assignable to `float`.
        let float_ty = named("float");
        assert!(matches_type(&vm_int(3), &float_ty));
        assert!(matches_type(&VmValue::Float(3.0), &float_ty));
    }

    #[test]
    fn union_accepts_any_member() {
        let int_or_string = TypeExpr::Union(vec![ty_int(), ty_string()]);
        assert!(matches_type(&vm_int(1), &int_or_string));
        assert!(matches_type(&vm_string("y"), &int_or_string));
        assert!(!matches_type(&VmValue::Bool(true), &int_or_string));
    }

    #[test]
    fn optional_accepts_nil() {
        let string_or_nil = TypeExpr::Union(vec![ty_string(), named("nil")]);
        assert!(matches_type(&VmValue::Nil, &string_or_nil));
        assert!(matches_type(&vm_string("x"), &string_or_nil));
        assert!(!matches_type(&vm_int(1), &string_or_nil));
    }

    #[test]
    fn list_validates_elements() {
        let list_of_int = TypeExpr::List(Box::new(ty_int()));
        let all_ints = VmValue::List(Rc::new(vec![vm_int(1), vm_int(2)]));
        let mixed = VmValue::List(Rc::new(vec![vm_int(1), vm_string("x")]));
        assert!(matches_type(&all_ints, &list_of_int));
        assert!(!matches_type(&mixed, &list_of_int));
    }

    #[test]
    fn shape_validates_required_fields() {
        let required_x = harn_parser::ShapeField {
            name: "x".into(),
            type_expr: ty_int(),
            optional: false,
        };
        let shape = TypeExpr::Shape(vec![required_x]);

        let with_x = VmValue::Dict(Rc::new(BTreeMap::from([("x".to_string(), vm_int(7))])));
        let without_x = VmValue::Dict(Rc::new(BTreeMap::new()));

        assert!(matches_type(&with_x, &shape));
        assert!(!matches_type(&without_x, &shape));
    }

    #[test]
    fn named_type_matches_user_struct_name() {
        let custom = named("MyStruct");
        let nominal = ["MyStruct".to_string()];

        // A non-struct value must not satisfy a declared nominal type.
        assert!(!matches_type_with_generics(
            &vm_int(1),
            &custom,
            &[],
            &nominal
        ));

        let instance = VmValue::struct_instance("MyStruct", Default::default());
        assert!(matches_type_with_generics(&instance, &custom, &[], &nominal));
    }

    #[test]
    fn lit_int_requires_value_equality() {
        let literal = TypeExpr::LitInt(42);
        assert!(matches_type(&vm_int(42), &literal));
        assert!(!matches_type(&vm_int(7), &literal));
    }

    #[test]
    fn assert_value_returns_arg_type_mismatch_on_fail() {
        let failure =
            assert_value_matches_type(&vm_string("abc"), &ty_int(), "myFn", "n", None).unwrap_err();
        let details = match failure {
            VmError::ArgTypeMismatch(details) => details,
            other => panic!("expected ArgTypeMismatch, got {other:?}"),
        };
        assert_eq!(details.callee, "myFn");
        assert_eq!(details.param, "n");
        assert_eq!(details.expected, "int");
        assert_eq!(details.got, "string");
        assert!(details.span.is_none());
    }
}