1use harn_lexer::Span;
27use harn_parser::builtin_signatures::{self, BuiltinSignature};
28use harn_parser::typechecker::format_type;
29use harn_parser::TypeExpr;
30
31use crate::chunk::{CompiledFunction, ParamSlot};
32use crate::value::{ArgTypeMismatchError, ArityExpect, ArityMismatchError, VmError, VmValue};
33
34pub fn assert_value_matches_type(
57 value: &VmValue,
58 expected: &TypeExpr,
59 callee: &str,
60 param: &str,
61 span: Option<Span>,
62) -> Result<(), VmError> {
63 assert_value_matches_type_with_generics(value, expected, callee, param, span, &[], &[])
64}
65
66fn assert_value_matches_type_with_generics(
67 value: &VmValue,
68 expected: &TypeExpr,
69 callee: &str,
70 param: &str,
71 span: Option<Span>,
72 type_params: &[String],
73 nominal_type_names: &[String],
74) -> Result<(), VmError> {
75 if matches_type_with_generics(value, expected, type_params, nominal_type_names) {
76 Ok(())
77 } else {
78 Err(VmError::ArgTypeMismatch(Box::new(ArgTypeMismatchError {
79 callee: callee.to_string(),
80 param: param.to_string(),
81 expected: format_type(expected),
82 got: value.type_name(),
83 span,
84 })))
85 }
86}
87
88fn user_param_for_arg(func: &CompiledFunction, index: usize) -> Option<&ParamSlot> {
89 if func.has_rest_param && index >= func.params.len().saturating_sub(1) {
90 func.params.last()
91 } else {
92 func.params.get(index)
93 }
94}
95
96fn builtin_param_for_arg(
97 sig: &BuiltinSignature,
98 index: usize,
99) -> Option<&harn_parser::builtin_signatures::Param> {
100 if sig.has_rest && index >= sig.params.len().saturating_sub(1) {
101 sig.params.last()
102 } else {
103 sig.params.get(index)
104 }
105}
106
/// Test-only convenience: structural match with no generic parameters and no
/// nominal type names in scope.
#[cfg(test)]
fn matches_type(value: &VmValue, expected: &TypeExpr) -> bool {
    let no_names: &[String] = &[];
    matches_type_with_generics(value, expected, no_names, no_names)
}
113
114fn matches_type_with_generics(
115 value: &VmValue,
116 expected: &TypeExpr,
117 type_params: &[String],
118 nominal_type_names: &[String],
119) -> bool {
120 match expected {
121 TypeExpr::Named(name) => match name.as_str() {
122 _ if type_params.iter().any(|param| param == name) => true,
123 "any" | "unknown" => true,
124 "int" => matches!(value, VmValue::Int(_)),
125 "float" => matches!(value, VmValue::Float(_) | VmValue::Int(_)),
126 "number" => matches!(value, VmValue::Int(_) | VmValue::Float(_)),
127 "string" => matches!(value, VmValue::String(_)),
128 "bool" => matches!(value, VmValue::Bool(_)),
129 "nil" => matches!(value, VmValue::Nil),
130 "list" => matches!(value, VmValue::List(_)),
131 "dict" => matches!(value, VmValue::Dict(_)),
132 "bytes" => matches!(value, VmValue::Bytes(_)),
133 "duration" => matches!(value, VmValue::Duration(_)),
134 "set" => matches!(value, VmValue::Set(_)),
135 "range" => matches!(value, VmValue::Range(_)),
136 "iter" => matches!(value, VmValue::Iter(_)),
137 "generator" | "Generator" => matches!(value, VmValue::Generator(_)),
138 "stream" | "Stream" => matches!(value, VmValue::Stream(_)),
139 "channel" => matches!(value, VmValue::Channel(_)),
140 "task_handle" => matches!(value, VmValue::TaskHandle(_)),
141 "atomic" => matches!(value, VmValue::Atomic(_)),
142 "rng" => matches!(value, VmValue::Rng(_)),
143 "sync_permit" => matches!(value, VmValue::SyncPermit(_)),
144 "mcp_client" => matches!(value, VmValue::McpClient(_)),
145 "pair" => matches!(value, VmValue::Pair(_)),
146 "enum" => matches!(value, VmValue::EnumVariant { .. }),
147 "struct" => matches!(value, VmValue::StructInstance { .. }),
148 "closure" => matches!(
149 value,
150 VmValue::Closure(_) | VmValue::BuiltinRef(_) | VmValue::BuiltinRefId { .. }
151 ),
152 _ => {
153 if !nominal_type_names.iter().any(|ty| ty == name) {
154 true
155 } else {
156 value
157 .struct_name()
158 .is_some_and(|struct_name| struct_name == name)
159 || matches!(value, VmValue::EnumVariant { enum_name, .. } if enum_name.as_ref() == name)
160 }
161 }
162 },
163 TypeExpr::Union(members) => members
164 .iter()
165 .any(|m| matches_type_with_generics(value, m, type_params, nominal_type_names)),
166 TypeExpr::Intersection(members) => members
167 .iter()
168 .all(|m| matches_type_with_generics(value, m, type_params, nominal_type_names)),
169 TypeExpr::List(inner) => match value {
170 VmValue::List(items) => items
171 .iter()
172 .all(|v| matches_type_with_generics(v, inner, type_params, nominal_type_names)),
173 _ => false,
174 },
175 TypeExpr::DictType(_, vt) => match value {
176 VmValue::Dict(map) => map
177 .values()
178 .all(|v| matches_type_with_generics(v, vt, type_params, nominal_type_names)),
179 _ => false,
180 },
181 TypeExpr::Iter(_) | TypeExpr::Generator(_) | TypeExpr::Stream(_) => match value {
182 VmValue::List(_) | VmValue::Generator(_) | VmValue::Stream(_) => true,
185 _ => false,
186 },
187 TypeExpr::Shape(fields) => match value {
188 VmValue::Dict(map) => fields.iter().all(|f| match map.get(&f.name) {
189 Some(VmValue::Nil) if f.optional => true,
193 Some(v) => {
194 matches_type_with_generics(v, &f.type_expr, type_params, nominal_type_names)
195 }
196 None => f.optional,
197 }),
198 VmValue::StructInstance { .. } => {
199 fields.iter().all(|f| match value.struct_field(&f.name) {
200 Some(VmValue::Nil) if f.optional => true,
201 Some(v) => {
202 matches_type_with_generics(v, &f.type_expr, type_params, nominal_type_names)
203 }
204 None => f.optional,
205 })
206 }
207 _ => false,
208 },
209 TypeExpr::Applied { name, args } => match (name.as_str(), args.as_slice()) {
210 ("list", [inner]) => matches_type_with_generics(
211 value,
212 &TypeExpr::List(Box::new(inner.clone())),
213 type_params,
214 nominal_type_names,
215 ),
216 ("dict", [k, v]) => matches_type_with_generics(
217 value,
218 &TypeExpr::DictType(Box::new(k.clone()), Box::new(v.clone())),
219 type_params,
220 nominal_type_names,
221 ),
222 ("Option", [inner]) => {
223 matches!(value, VmValue::Nil)
224 || matches_type_with_generics(value, inner, type_params, nominal_type_names)
225 }
226 _ => true,
230 },
231 TypeExpr::FnType { .. } => matches!(
232 value,
233 VmValue::Closure(_) | VmValue::BuiltinRef(_) | VmValue::BuiltinRefId { .. }
234 ),
235 TypeExpr::Never => false,
236 TypeExpr::LitString(s) => matches!(value, VmValue::String(rs) if rs.as_ref() == s),
237 TypeExpr::LitInt(i) => matches!(value, VmValue::Int(rv) if rv == i),
238 }
239}
240
241pub fn validate_user_call(
245 func: &CompiledFunction,
246 args: &[VmValue],
247 span: Option<Span>,
248) -> Result<(), VmError> {
249 let total = func.params.len();
250 let required = func.required_param_count();
251 let got = args.len();
252
253 let arity_ok = if func.has_rest_param {
254 got >= total.saturating_sub(1)
256 } else {
257 got >= required && got <= total
258 };
259
260 if !arity_ok {
261 let expected = arity_expect_for(func);
262 return Err(VmError::ArityMismatch(Box::new(ArityMismatchError {
263 callee: func.name.clone(),
264 expected,
265 got,
266 span,
267 })));
268 }
269
270 for (i, value) in args.iter().enumerate() {
271 let Some(slot) = user_param_for_arg(func, i) else {
272 continue;
273 };
274 let Some(expected) = &slot.type_expr else {
275 continue;
276 };
277 if matches!(expected, TypeExpr::Named(name) if func.declares_type_param(name)) {
278 continue;
279 }
280 if let Some(schema) = crate::compiler::Compiler::type_expr_to_schema_value(expected) {
281 crate::schema::schema_assert_param(value, &slot.name, &schema)?;
282 continue;
283 }
284 assert_value_matches_type_with_generics(
285 value,
286 expected,
287 &func.name,
288 &slot.name,
289 span,
290 &func.type_params,
291 &func.nominal_type_names,
292 )?;
293 }
294
295 Ok(())
296}
297
298pub fn validate_builtin_call(
305 name: &str,
306 args: &[VmValue],
307 span: Option<Span>,
308) -> Result<(), VmError> {
309 let Some(sig) = builtin_signatures::lookup(name) else {
310 return Ok(());
311 };
312 validate_against_signature(name, sig, args, span)
313}
314
315pub fn validate_against_signature(
319 name: &str,
320 sig: &BuiltinSignature,
321 args: &[VmValue],
322 span: Option<Span>,
323) -> Result<(), VmError> {
324 let total = sig.params.len();
325 let required = sig.required_params();
326 let got = args.len();
327
328 let arity_ok = if sig.has_rest {
329 got >= total.saturating_sub(1)
330 } else {
331 got >= required && got <= total
332 };
333
334 if !arity_ok {
335 let expected = if sig.has_rest {
336 ArityExpect::AtLeast(total.saturating_sub(1))
337 } else if required == total {
338 ArityExpect::Exact(total)
339 } else {
340 ArityExpect::Range {
341 min: required,
342 max: total,
343 }
344 };
345 return Err(VmError::ArityMismatch(Box::new(ArityMismatchError {
346 callee: name.to_string(),
347 expected,
348 got,
349 span,
350 })));
351 }
352
353 for (i, value) in args.iter().enumerate() {
354 let Some(param) = builtin_param_for_arg(sig, i) else {
355 continue;
356 };
357 if param.optional && matches!(value, VmValue::Nil) {
358 continue;
359 }
360 let expected = param.ty.to_type_expr();
365 if matches!(&expected, TypeExpr::Named(n) if sig.is_type_param(n)) {
366 continue;
367 }
368 if param.ty.is_any() {
371 continue;
372 }
373 if matches!(param.ty, harn_parser::builtin_signatures::Ty::SchemaOf(_)) {
374 continue;
375 }
376 assert_value_matches_type(value, &expected, name, param.name, span)?;
377 }
378
379 Ok(())
380}
381
382fn arity_expect_for(func: &CompiledFunction) -> ArityExpect {
386 let total = func.params.len();
387 let required = func.required_param_count();
388 if func.has_rest_param {
389 ArityExpect::AtLeast(total.saturating_sub(1))
390 } else if required == total {
391 ArityExpect::Exact(total)
392 } else {
393 ArityExpect::Range {
394 min: required,
395 max: total,
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::rc::Rc;

    /// Build a bare named type expression such as `int`.
    fn named(name: &str) -> TypeExpr {
        TypeExpr::Named(name.into())
    }

    fn int_value(n: i64) -> VmValue {
        VmValue::Int(n)
    }

    fn str_value(s: &str) -> VmValue {
        VmValue::String(Rc::from(s))
    }

    fn list_value(items: Vec<VmValue>) -> VmValue {
        VmValue::List(Rc::new(items))
    }

    #[test]
    fn matches_primitive_types() {
        let cases = [
            (int_value(42), "int", true),
            (int_value(42), "string", false),
            (str_value("x"), "string", true),
            (VmValue::Bool(true), "bool", true),
            (VmValue::Nil, "nil", true),
        ];
        for (value, ty, want) in cases {
            assert_eq!(matches_type(&value, &named(ty)), want, "{ty}");
        }
    }

    #[test]
    fn float_accepts_int_promotion() {
        // An int is accepted where a float is expected.
        for value in [int_value(3), VmValue::Float(3.0)] {
            assert!(matches_type(&value, &named("float")));
        }
    }

    #[test]
    fn union_accepts_any_member() {
        let union = TypeExpr::Union(vec![named("int"), named("string")]);
        for member in [int_value(1), str_value("y")] {
            assert!(matches_type(&member, &union));
        }
        assert!(!matches_type(&VmValue::Bool(true), &union));
    }

    #[test]
    fn optional_accepts_nil() {
        let optional = TypeExpr::Union(vec![named("string"), named("nil")]);
        for accepted in [VmValue::Nil, str_value("x")] {
            assert!(matches_type(&accepted, &optional));
        }
        assert!(!matches_type(&int_value(1), &optional));
    }

    #[test]
    fn list_validates_elements() {
        let list_of_ints = TypeExpr::List(Box::new(named("int")));
        let all_ints = list_value(vec![int_value(1), int_value(2)]);
        let mixed = list_value(vec![int_value(1), str_value("x")]);
        assert!(matches_type(&all_ints, &list_of_ints));
        assert!(!matches_type(&mixed, &list_of_ints));
    }

    #[test]
    fn shape_validates_required_fields() {
        let shape = TypeExpr::Shape(vec![harn_parser::ShapeField {
            name: "x".into(),
            type_expr: named("int"),
            optional: false,
        }]);
        let populated = VmValue::Dict(Rc::new(BTreeMap::from([(
            "x".to_string(),
            int_value(7),
        )])));
        let empty = VmValue::Dict(Rc::new(BTreeMap::new()));
        assert!(matches_type(&populated, &shape));
        assert!(!matches_type(&empty, &shape));
    }

    #[test]
    fn named_type_matches_user_struct_name() {
        let nominal = ["MyStruct".to_string()];
        let custom = named("MyStruct");
        let instance = VmValue::struct_instance("MyStruct", Default::default());
        assert!(!matches_type_with_generics(&int_value(1), &custom, &[], &nominal));
        assert!(matches_type_with_generics(&instance, &custom, &[], &nominal));
    }

    #[test]
    fn lit_int_requires_value_equality() {
        let forty_two = TypeExpr::LitInt(42);
        assert!(matches_type(&int_value(42), &forty_two));
        assert!(!matches_type(&int_value(7), &forty_two));
    }

    #[test]
    fn assert_value_returns_arg_type_mismatch_on_fail() {
        let outcome = assert_value_matches_type(&str_value("abc"), &named("int"), "myFn", "n", None);
        match outcome.unwrap_err() {
            VmError::ArgTypeMismatch(mismatch) => {
                assert_eq!(mismatch.callee, "myFn");
                assert_eq!(mismatch.param, "n");
                assert_eq!(mismatch.expected, "int");
                assert_eq!(mismatch.got, "string");
                assert!(mismatch.span.is_none());
            }
            other => panic!("expected ArgTypeMismatch, got {other:?}"),
        }
    }
}