1use harn_lexer::Span;
27use harn_parser::builtin_signatures::{self, BuiltinSignature};
28use harn_parser::typechecker::format_type;
29use harn_parser::TypeExpr;
30
31use crate::chunk::{CompiledFunction, ParamSlot};
32use crate::value::{ArgTypeMismatchError, ArityExpect, ArityMismatchError, VmError, VmValue};
33
34pub fn assert_value_matches_type(
57 value: &VmValue,
58 expected: &TypeExpr,
59 callee: &str,
60 param: &str,
61 span: Option<Span>,
62) -> Result<(), VmError> {
63 assert_value_matches_type_with_generics(value, expected, callee, param, span, &[], &[])
64}
65
66fn assert_value_matches_type_with_generics(
67 value: &VmValue,
68 expected: &TypeExpr,
69 callee: &str,
70 param: &str,
71 span: Option<Span>,
72 type_params: &[String],
73 nominal_type_names: &[String],
74) -> Result<(), VmError> {
75 if matches_type_with_generics(value, expected, type_params, nominal_type_names) {
76 Ok(())
77 } else {
78 Err(VmError::ArgTypeMismatch(Box::new(ArgTypeMismatchError {
79 callee: callee.to_string(),
80 param: param.to_string(),
81 expected: format_type(expected),
82 got: value.type_name(),
83 span,
84 })))
85 }
86}
87
88fn user_param_for_arg(func: &CompiledFunction, index: usize) -> Option<&ParamSlot> {
89 if func.has_rest_param && index >= func.params.len().saturating_sub(1) {
90 func.params.last()
91 } else {
92 func.params.get(index)
93 }
94}
95
96fn builtin_param_for_arg(
97 sig: &BuiltinSignature,
98 index: usize,
99) -> Option<&harn_parser::builtin_signatures::Param> {
100 if sig.has_rest && index >= sig.params.len().saturating_sub(1) {
101 sig.params.last()
102 } else {
103 sig.params.get(index)
104 }
105}
106
/// Test-only shorthand for [`matches_type_with_generics`] with no generic type
/// parameters and no nominal type names in scope.
#[cfg(test)]
fn matches_type(value: &VmValue, expected: &TypeExpr) -> bool {
    let no_type_params: &[String] = &[];
    let no_nominal_names: &[String] = &[];
    matches_type_with_generics(value, expected, no_type_params, no_nominal_names)
}
113
114fn matches_type_with_generics(
115 value: &VmValue,
116 expected: &TypeExpr,
117 type_params: &[String],
118 nominal_type_names: &[String],
119) -> bool {
120 match expected {
121 TypeExpr::Named(name) => match name.as_str() {
122 _ if type_params.iter().any(|param| param == name) => true,
123 "any" | "unknown" => true,
124 "int" => matches!(value, VmValue::Int(_)),
125 "float" => matches!(value, VmValue::Float(_) | VmValue::Int(_)),
126 "number" => matches!(value, VmValue::Int(_) | VmValue::Float(_)),
127 "string" => matches!(value, VmValue::String(_)),
128 "bool" => matches!(value, VmValue::Bool(_)),
129 "nil" => matches!(value, VmValue::Nil),
130 "list" => matches!(value, VmValue::List(_)),
131 "dict" => matches!(value, VmValue::Dict(_)),
132 "bytes" => matches!(value, VmValue::Bytes(_)),
133 "duration" => matches!(value, VmValue::Duration(_)),
134 "set" => matches!(value, VmValue::Set(_)),
135 "range" => matches!(value, VmValue::Range(_)),
136 "iter" => matches!(value, VmValue::Iter(_)),
137 "generator" | "Generator" => matches!(value, VmValue::Generator(_)),
138 "stream" | "Stream" => matches!(value, VmValue::Stream(_)),
139 "channel" => matches!(value, VmValue::Channel(_)),
140 "task_handle" => matches!(value, VmValue::TaskHandle(_)),
141 "atomic" => matches!(value, VmValue::Atomic(_)),
142 "rng" => matches!(value, VmValue::Rng(_)),
143 "sync_permit" => matches!(value, VmValue::SyncPermit(_)),
144 "mcp_client" => matches!(value, VmValue::McpClient(_)),
145 "pair" => matches!(value, VmValue::Pair(_)),
146 "enum" => matches!(value, VmValue::EnumVariant { .. }),
147 "struct" => matches!(value, VmValue::StructInstance { .. }),
148 "closure" => matches!(
149 value,
150 VmValue::Closure(_) | VmValue::BuiltinRef(_) | VmValue::BuiltinRefId { .. }
151 ),
152 _ => {
153 if !nominal_type_names.iter().any(|ty| ty == name) {
154 true
155 } else {
156 value
157 .struct_name()
158 .is_some_and(|struct_name| struct_name == name)
159 || matches!(value, VmValue::EnumVariant { enum_name, .. } if enum_name.as_ref() == name)
160 }
161 }
162 },
163 TypeExpr::Union(members) => members
164 .iter()
165 .any(|m| matches_type_with_generics(value, m, type_params, nominal_type_names)),
166 TypeExpr::Intersection(members) => members
167 .iter()
168 .all(|m| matches_type_with_generics(value, m, type_params, nominal_type_names)),
169 TypeExpr::List(inner) => match value {
170 VmValue::List(items) => items
171 .iter()
172 .all(|v| matches_type_with_generics(v, inner, type_params, nominal_type_names)),
173 _ => false,
174 },
175 TypeExpr::DictType(_, vt) => match value {
176 VmValue::Dict(map) => map
177 .values()
178 .all(|v| matches_type_with_generics(v, vt, type_params, nominal_type_names)),
179 _ => false,
180 },
181 TypeExpr::Iter(_) | TypeExpr::Generator(_) | TypeExpr::Stream(_) => match value {
182 VmValue::List(_) | VmValue::Generator(_) | VmValue::Stream(_) => true,
185 _ => false,
186 },
187 TypeExpr::Shape(fields) => match value {
188 VmValue::Dict(map) => fields.iter().all(|f| match map.get(&f.name) {
189 Some(v) => {
190 matches_type_with_generics(v, &f.type_expr, type_params, nominal_type_names)
191 }
192 None => f.optional,
193 }),
194 VmValue::StructInstance { .. } => {
195 fields.iter().all(|f| match value.struct_field(&f.name) {
196 Some(v) => {
197 matches_type_with_generics(v, &f.type_expr, type_params, nominal_type_names)
198 }
199 None => f.optional,
200 })
201 }
202 _ => false,
203 },
204 TypeExpr::Applied { name, args } => match (name.as_str(), args.as_slice()) {
205 ("list", [inner]) => matches_type_with_generics(
206 value,
207 &TypeExpr::List(Box::new(inner.clone())),
208 type_params,
209 nominal_type_names,
210 ),
211 ("dict", [k, v]) => matches_type_with_generics(
212 value,
213 &TypeExpr::DictType(Box::new(k.clone()), Box::new(v.clone())),
214 type_params,
215 nominal_type_names,
216 ),
217 ("Option", [inner]) => {
218 matches!(value, VmValue::Nil)
219 || matches_type_with_generics(value, inner, type_params, nominal_type_names)
220 }
221 _ => true,
225 },
226 TypeExpr::FnType { .. } => matches!(
227 value,
228 VmValue::Closure(_) | VmValue::BuiltinRef(_) | VmValue::BuiltinRefId { .. }
229 ),
230 TypeExpr::Never => false,
231 TypeExpr::LitString(s) => matches!(value, VmValue::String(rs) if rs.as_ref() == s),
232 TypeExpr::LitInt(i) => matches!(value, VmValue::Int(rv) if rv == i),
233 }
234}
235
236pub fn validate_user_call(
240 func: &CompiledFunction,
241 args: &[VmValue],
242 span: Option<Span>,
243) -> Result<(), VmError> {
244 let total = func.params.len();
245 let required = func.required_param_count();
246 let got = args.len();
247
248 let arity_ok = if func.has_rest_param {
249 got >= total.saturating_sub(1)
251 } else {
252 got >= required && got <= total
253 };
254
255 if !arity_ok {
256 let expected = arity_expect_for(func);
257 return Err(VmError::ArityMismatch(Box::new(ArityMismatchError {
258 callee: func.name.clone(),
259 expected,
260 got,
261 span,
262 })));
263 }
264
265 for (i, value) in args.iter().enumerate() {
266 let Some(slot) = user_param_for_arg(func, i) else {
267 continue;
268 };
269 let Some(expected) = &slot.type_expr else {
270 continue;
271 };
272 if matches!(expected, TypeExpr::Named(name) if func.declares_type_param(name)) {
273 continue;
274 }
275 if let Some(schema) = crate::compiler::Compiler::type_expr_to_schema_value(expected) {
276 crate::schema::schema_assert_param(value, &slot.name, &schema)?;
277 continue;
278 }
279 assert_value_matches_type_with_generics(
280 value,
281 expected,
282 &func.name,
283 &slot.name,
284 span,
285 &func.type_params,
286 &func.nominal_type_names,
287 )?;
288 }
289
290 Ok(())
291}
292
293pub fn validate_builtin_call(
300 name: &str,
301 args: &[VmValue],
302 span: Option<Span>,
303) -> Result<(), VmError> {
304 let Some(sig) = builtin_signatures::lookup(name) else {
305 return Ok(());
306 };
307 validate_against_signature(name, sig, args, span)
308}
309
310pub fn validate_against_signature(
314 name: &str,
315 sig: &BuiltinSignature,
316 args: &[VmValue],
317 span: Option<Span>,
318) -> Result<(), VmError> {
319 let total = sig.params.len();
320 let required = sig.required_params();
321 let got = args.len();
322
323 let arity_ok = if sig.has_rest {
324 got >= total.saturating_sub(1)
325 } else {
326 got >= required && got <= total
327 };
328
329 if !arity_ok {
330 let expected = if sig.has_rest {
331 ArityExpect::AtLeast(total.saturating_sub(1))
332 } else if required == total {
333 ArityExpect::Exact(total)
334 } else {
335 ArityExpect::Range {
336 min: required,
337 max: total,
338 }
339 };
340 return Err(VmError::ArityMismatch(Box::new(ArityMismatchError {
341 callee: name.to_string(),
342 expected,
343 got,
344 span,
345 })));
346 }
347
348 for (i, value) in args.iter().enumerate() {
349 let Some(param) = builtin_param_for_arg(sig, i) else {
350 continue;
351 };
352 if param.optional && matches!(value, VmValue::Nil) {
353 continue;
354 }
355 let expected = param.ty.to_type_expr();
360 if matches!(&expected, TypeExpr::Named(n) if sig.is_type_param(n)) {
361 continue;
362 }
363 if param.ty.is_any() {
366 continue;
367 }
368 if matches!(param.ty, harn_parser::builtin_signatures::Ty::SchemaOf(_)) {
369 continue;
370 }
371 assert_value_matches_type(value, &expected, name, param.name, span)?;
372 }
373
374 Ok(())
375}
376
377fn arity_expect_for(func: &CompiledFunction) -> ArityExpect {
381 let total = func.params.len();
382 let required = func.required_param_count();
383 if func.has_rest_param {
384 ArityExpect::AtLeast(total.saturating_sub(1))
385 } else if required == total {
386 ArityExpect::Exact(total)
387 } else {
388 ArityExpect::Range {
389 min: required,
390 max: total,
391 }
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;

    fn vm_int(n: i64) -> VmValue {
        VmValue::Int(n)
    }

    fn vm_string(s: &str) -> VmValue {
        VmValue::String(Rc::from(s))
    }

    fn ty_int() -> TypeExpr {
        TypeExpr::Named("int".into())
    }

    fn ty_string() -> TypeExpr {
        TypeExpr::Named("string".into())
    }

    #[test]
    fn matches_primitive_types() {
        assert!(matches_type(&vm_int(42), &ty_int()));
        assert!(!matches_type(&vm_int(42), &ty_string()));
        assert!(matches_type(&vm_string("x"), &ty_string()));
        assert!(matches_type(
            &VmValue::Bool(true),
            &TypeExpr::Named("bool".into())
        ));
        assert!(matches_type(&VmValue::Nil, &TypeExpr::Named("nil".into())));
    }

    #[test]
    fn float_accepts_int_promotion() {
        assert!(matches_type(&vm_int(3), &TypeExpr::Named("float".into())));
        assert!(matches_type(
            &VmValue::Float(3.0),
            &TypeExpr::Named("float".into())
        ));
        // The reverse does not hold: a float is not an int.
        assert!(!matches_type(&VmValue::Float(1.5), &ty_int()));
    }

    #[test]
    fn number_accepts_int_and_float() {
        let number = TypeExpr::Named("number".into());
        assert!(matches_type(&vm_int(1), &number));
        assert!(matches_type(&VmValue::Float(1.5), &number));
        assert!(!matches_type(&vm_string("1"), &number));
    }

    #[test]
    fn union_accepts_any_member() {
        let union = TypeExpr::Union(vec![ty_int(), ty_string()]);
        assert!(matches_type(&vm_int(1), &union));
        assert!(matches_type(&vm_string("y"), &union));
        assert!(!matches_type(&VmValue::Bool(true), &union));
    }

    #[test]
    fn intersection_requires_every_member() {
        let both = TypeExpr::Intersection(vec![ty_int(), TypeExpr::Named("number".into())]);
        assert!(matches_type(&vm_int(1), &both));
        assert!(!matches_type(&VmValue::Float(1.5), &both));
    }

    #[test]
    fn optional_accepts_nil() {
        let opt = TypeExpr::Union(vec![ty_string(), TypeExpr::Named("nil".into())]);
        assert!(matches_type(&VmValue::Nil, &opt));
        assert!(matches_type(&vm_string("x"), &opt));
        assert!(!matches_type(&vm_int(1), &opt));
    }

    #[test]
    fn applied_option_accepts_nil_or_inner() {
        let opt = TypeExpr::Applied {
            name: "Option".into(),
            args: vec![ty_int()],
        };
        assert!(matches_type(&VmValue::Nil, &opt));
        assert!(matches_type(&vm_int(3), &opt));
        assert!(!matches_type(&vm_string("x"), &opt));
    }

    #[test]
    fn list_validates_elements() {
        let list_int = TypeExpr::List(Box::new(ty_int()));
        let good = VmValue::List(Rc::new(vec![vm_int(1), vm_int(2)]));
        let bad = VmValue::List(Rc::new(vec![vm_int(1), vm_string("x")]));
        assert!(matches_type(&good, &list_int));
        assert!(!matches_type(&bad, &list_int));
        assert!(!matches_type(&vm_int(1), &list_int));
    }

    #[test]
    fn applied_list_validates_elements() {
        let list_int = TypeExpr::Applied {
            name: "list".into(),
            args: vec![ty_int()],
        };
        let good = VmValue::List(Rc::new(vec![vm_int(1), vm_int(2)]));
        let bad = VmValue::List(Rc::new(vec![vm_string("x")]));
        let empty = VmValue::List(Rc::new(Vec::new()));
        assert!(matches_type(&good, &list_int));
        assert!(!matches_type(&bad, &list_int));
        assert!(matches_type(&empty, &list_int));
    }

    #[test]
    fn dict_type_validates_values() {
        let dict_int = TypeExpr::DictType(Box::new(ty_string()), Box::new(ty_int()));
        let applied = TypeExpr::Applied {
            name: "dict".into(),
            args: vec![ty_string(), ty_int()],
        };
        let mut good = std::collections::BTreeMap::new();
        good.insert("a".to_string(), vm_int(1));
        let good = VmValue::Dict(Rc::new(good));
        let mut bad = std::collections::BTreeMap::new();
        bad.insert("a".to_string(), vm_string("x"));
        let bad = VmValue::Dict(Rc::new(bad));
        assert!(matches_type(&good, &dict_int));
        assert!(!matches_type(&bad, &dict_int));
        assert!(matches_type(&good, &applied));
        assert!(!matches_type(&bad, &applied));
    }

    #[test]
    fn shape_validates_required_fields() {
        let shape = TypeExpr::Shape(vec![harn_parser::ShapeField {
            name: "x".into(),
            type_expr: ty_int(),
            optional: false,
        }]);
        let mut good = std::collections::BTreeMap::new();
        good.insert("x".to_string(), vm_int(7));
        assert!(matches_type(&VmValue::Dict(Rc::new(good)), &shape));
        assert!(!matches_type(
            &VmValue::Dict(Rc::new(std::collections::BTreeMap::new())),
            &shape
        ));
    }

    #[test]
    fn named_type_matches_user_struct_name() {
        let custom = TypeExpr::Named("MyStruct".into());
        assert!(!matches_type_with_generics(
            &vm_int(1),
            &custom,
            &[],
            &["MyStruct".to_string()]
        ));
        assert!(matches_type_with_generics(
            &VmValue::struct_instance("MyStruct", Default::default()),
            &custom,
            &[],
            &["MyStruct".to_string()]
        ));
        assert!(!matches_type_with_generics(
            &VmValue::struct_instance("Other", Default::default()),
            &custom,
            &[],
            &["MyStruct".to_string()]
        ));
    }

    #[test]
    fn unlisted_named_type_is_not_enforced() {
        let unlisted = TypeExpr::Named("Unlisted".into());
        assert!(matches_type(&vm_int(1), &unlisted));
        assert!(matches_type(&VmValue::Nil, &unlisted));
    }

    #[test]
    fn type_param_name_matches_any_value() {
        let t = TypeExpr::Named("T".into());
        let params = ["T".to_string()];
        assert!(matches_type_with_generics(&vm_int(1), &t, &params, &[]));
        assert!(matches_type_with_generics(&vm_string("x"), &t, &params, &[]));
    }

    #[test]
    fn lit_int_requires_value_equality() {
        assert!(matches_type(&vm_int(42), &TypeExpr::LitInt(42)));
        assert!(!matches_type(&vm_int(7), &TypeExpr::LitInt(42)));
    }

    #[test]
    fn lit_string_requires_value_equality() {
        let lit = TypeExpr::LitString("ok".into());
        assert!(matches_type(&vm_string("ok"), &lit));
        assert!(!matches_type(&vm_string("no"), &lit));
        assert!(!matches_type(&vm_int(1), &lit));
    }

    #[test]
    fn never_matches_nothing() {
        assert!(!matches_type(&VmValue::Nil, &TypeExpr::Never));
        assert!(!matches_type(&vm_int(0), &TypeExpr::Never));
    }

    #[test]
    fn assert_value_returns_arg_type_mismatch_on_fail() {
        let err =
            assert_value_matches_type(&vm_string("abc"), &ty_int(), "myFn", "n", None).unwrap_err();
        match err {
            VmError::ArgTypeMismatch(err) => {
                assert_eq!(err.callee, "myFn");
                assert_eq!(err.param, "n");
                assert_eq!(err.expected, "int");
                assert_eq!(err.got, "string");
                assert!(err.span.is_none());
            }
            other => panic!("expected ArgTypeMismatch, got {other:?}"),
        }
    }
}