1use mlua::prelude::*;
47
48pub fn module(lua: &Lua) -> LuaResult<LuaTable> {
49 let t = lua.create_table()?;
50
51 t.set(
52 "check",
53 lua.create_function(|lua, (data, schema): (LuaTable, LuaTable)| {
54 let mut errors: Vec<String> = Vec::new();
55 validate_table(&data, &schema, &mut errors)?;
56 if errors.is_empty() {
57 Ok((true, LuaValue::Nil))
58 } else {
59 let err_table = lua.create_table()?;
60 for (i, e) in errors.iter().enumerate() {
61 err_table.set(i + 1, e.as_str())?;
62 }
63 Ok((false, LuaValue::Table(err_table)))
64 }
65 })?,
66 )?;
67
68 Ok(t)
69}
70
71struct FieldSpec {
72 type_name: Option<String>,
73 required: bool,
74 min: Option<f64>,
75 max: Option<f64>,
76 min_len: Option<usize>,
77 max_len: Option<usize>,
78 one_of: Option<Vec<LuaValue>>,
79}
80
81fn parse_field_spec(value: &LuaValue) -> LuaResult<FieldSpec> {
82 match value {
83 LuaValue::String(s) => Ok(FieldSpec {
84 type_name: Some(s.to_str()?.to_string()),
85 required: false,
86 min: None,
87 max: None,
88 min_len: None,
89 max_len: None,
90 one_of: None,
91 }),
92 LuaValue::Table(t) => {
93 let type_name: Option<String> = t.get("type")?;
94 let required: Option<bool> = t.get("required")?;
95 let min: Option<f64> = t.get("min")?;
96 let max: Option<f64> = t.get("max")?;
97 let min_len: Option<usize> = t.get("min_len")?;
98 let max_len: Option<usize> = t.get("max_len")?;
99 let one_of_table: Option<LuaTable> = t.get("one_of")?;
100 let one_of = match one_of_table {
101 Some(tbl) => {
102 let mut vals = Vec::new();
103 for v in tbl.sequence_values::<LuaValue>() {
104 vals.push(v?);
105 }
106 Some(vals)
107 }
108 None => None,
109 };
110 Ok(FieldSpec {
111 type_name,
112 required: required.unwrap_or(false),
113 min,
114 max,
115 min_len,
116 max_len,
117 one_of,
118 })
119 }
120 other => Err(LuaError::external(format!(
121 "validate: schema field must be a string or table, got {}",
122 other.type_name()
123 ))),
124 }
125}
126
127fn validate_table(data: &LuaTable, schema: &LuaTable, errors: &mut Vec<String>) -> LuaResult<()> {
128 for pair in schema.pairs::<LuaValue, LuaValue>() {
129 let (key, spec_value) = pair?;
130 let key_str = format_key(&key);
131 let spec = parse_field_spec(&spec_value)?;
132 let value: LuaValue = data.get(key)?;
133 validate_field(&key_str, &value, &spec, errors);
134 }
135 Ok(())
136}
137
138fn validate_field(key: &str, value: &LuaValue, spec: &FieldSpec, errors: &mut Vec<String>) {
139 if matches!(value, LuaValue::Nil) {
141 if spec.required {
142 errors.push(format!("{key}: required"));
143 }
144 return;
145 }
146
147 if let Some(ref expected) = spec.type_name {
149 if !matches_type(value, expected) {
150 errors.push(format!(
151 "{key}: expected {expected}, got {}",
152 lua_type_name(value)
153 ));
154 return; }
156 }
157
158 if let Some(n) = as_number(value) {
160 if let Some(min) = spec.min {
161 if n < min {
162 errors.push(format!("{key}: must be >= {min}, got {n}"));
163 }
164 }
165 if let Some(max) = spec.max {
166 if n > max {
167 errors.push(format!("{key}: must be <= {max}, got {n}"));
168 }
169 }
170 }
171
172 if let LuaValue::String(s) = value {
174 let len = s.as_bytes().len();
175 if let Some(min_len) = spec.min_len {
176 if len < min_len {
177 errors.push(format!("{key}: length must be >= {min_len}, got {len}"));
178 }
179 }
180 if let Some(max_len) = spec.max_len {
181 if len > max_len {
182 errors.push(format!("{key}: length must be <= {max_len}, got {len}"));
183 }
184 }
185 }
186
187 if let Some(ref allowed) = spec.one_of {
189 if !allowed.iter().any(|a| values_equal(a, value)) {
190 let allowed_str = allowed
191 .iter()
192 .map(format_display)
193 .collect::<Vec<_>>()
194 .join(", ");
195 errors.push(format!(
196 "{key}: must be one of [{allowed_str}], got {}",
197 format_display(value)
198 ));
199 }
200 }
201}
202
203fn matches_type(value: &LuaValue, expected: &str) -> bool {
204 match expected {
205 "string" => matches!(value, LuaValue::String(_)),
206 "number" => matches!(value, LuaValue::Number(_) | LuaValue::Integer(_)),
207 "integer" => matches!(value, LuaValue::Integer(_)),
208 "boolean" => matches!(value, LuaValue::Boolean(_)),
209 "table" => matches!(value, LuaValue::Table(_)),
210 "function" => matches!(value, LuaValue::Function(_)),
211 "any" => true,
212 _ => false,
213 }
214}
215
216fn lua_type_name(value: &LuaValue) -> &'static str {
217 match value {
218 LuaValue::Nil => "nil",
219 LuaValue::Boolean(_) => "boolean",
220 LuaValue::Integer(_) => "integer",
221 LuaValue::Number(_) => "number",
222 LuaValue::String(_) => "string",
223 LuaValue::Table(_) => "table",
224 LuaValue::Function(_) => "function",
225 _ => "userdata",
226 }
227}
228
229fn as_number(value: &LuaValue) -> Option<f64> {
230 match value {
231 LuaValue::Number(n) => Some(*n),
232 LuaValue::Integer(i) => Some(*i as f64),
233 _ => None,
234 }
235}
236
237fn values_equal(a: &LuaValue, b: &LuaValue) -> bool {
238 match (a, b) {
239 (LuaValue::String(a), LuaValue::String(b)) => a.as_bytes() == b.as_bytes(),
240 (LuaValue::Integer(a), LuaValue::Integer(b)) => a == b,
241 (LuaValue::Number(a), LuaValue::Number(b)) => a == b,
242 (LuaValue::Integer(a), LuaValue::Number(b)) => (*a as f64) == *b,
243 (LuaValue::Number(a), LuaValue::Integer(b)) => *a == (*b as f64),
244 (LuaValue::Boolean(a), LuaValue::Boolean(b)) => a == b,
245 (LuaValue::Nil, LuaValue::Nil) => true,
246 _ => false,
247 }
248}
249
250fn format_key(value: &LuaValue) -> String {
251 match value {
252 LuaValue::String(s) => s.to_string_lossy().to_string(),
253 LuaValue::Integer(i) => i.to_string(),
254 other => format!("<{}>", other.type_name()),
255 }
256}
257
258fn format_display(value: &LuaValue) -> String {
259 match value {
260 LuaValue::Nil => "nil".to_string(),
261 LuaValue::Boolean(b) => b.to_string(),
262 LuaValue::Integer(i) => i.to_string(),
263 LuaValue::Number(n) => n.to_string(),
264 LuaValue::String(s) => format!("\"{}\"", s.to_string_lossy()),
265 other => format!("<{}>", other.type_name()),
266 }
267}
268
#[cfg(test)]
mod tests {
    // Most tests evaluate a Lua chunk through `crate::util::test_eval` (aliased
    // `eval`) and convert its return value to the annotated Rust type.
    // NOTE(review): presumably `test_eval` registers this module as
    // `std.validate`; confirm against `crate::util`.
    use crate::util::test_eval as eval;

    // --- Shorthand specs: a bare type-name string ---

    #[test]
    fn shorthand_valid() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {name = "John", age = 30},
                {name = "string", age = "number"}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // An integer value under a "string" spec is reported with the "integer" type name.
    #[test]
    fn shorthand_type_mismatch() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {name = 42},
                {name = "string"}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains("expected string, got integer"), "got: {s}");
    }

    // Shorthand specs are optional: an absent field is not an error.
    #[test]
    fn shorthand_missing_optional_is_ok() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {},
                {name = "string"}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // --- required ---

    #[test]
    fn required_missing_field() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {},
                {name = {type = "string", required = true}}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains("required"), "got: {s}");
    }

    #[test]
    fn required_present_field() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {name = "John"},
                {name = {type = "string", required = true}}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // --- Numeric range: min / max (inclusive bounds) ---

    #[test]
    fn min_violated() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {age = -1},
                {age = {type = "number", min = 0}}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains(">= 0"), "got: {s}");
    }

    #[test]
    fn max_violated() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {age = 200},
                {age = {type = "number", max = 150}}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains("<= 150"), "got: {s}");
    }

    #[test]
    fn range_valid() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {age = 30},
                {age = {type = "number", min = 0, max = 150}}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // --- String length: min_len / max_len (measured in bytes) ---

    #[test]
    fn min_len_violated() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {name = ""},
                {name = {type = "string", min_len = 1}}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains("length must be >= 1"), "got: {s}");
    }

    #[test]
    fn max_len_violated() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {code = "ABCDEF"},
                {code = {type = "string", max_len = 3}}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains("length must be <= 3"), "got: {s}");
    }

    // --- one_of whitelist ---

    #[test]
    fn one_of_valid() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {status = "active"},
                {status = {type = "string", one_of = {"active", "inactive"}}}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // The message lists the allowed values, with strings shown in double quotes.
    #[test]
    fn one_of_violated() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {status = "unknown"},
                {status = {type = "string", one_of = {"active", "inactive"}}}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains("must be one of"), "got: {s}");
        assert!(s.contains("\"active\""), "got: {s}");
    }

    #[test]
    fn one_of_numeric() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {level = 2},
                {level = {type = "number", one_of = {1, 2, 3}}}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // --- "integer" vs "number" ---

    #[test]
    fn integer_accepts_integer() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {count = 42},
                {count = "integer"}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // A float literal is reported with the "number" type name and fails "integer".
    #[test]
    fn integer_rejects_float() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {count = 3.14},
                {count = "integer"}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains("expected integer, got number"), "got: {s}");
    }

    // --- "any" ---

    #[test]
    fn any_accepts_anything() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {data = "text", count = 42, flag = true},
                {data = "any", count = "any", flag = "any"}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // --- Error collection: all failures are reported, not just the first ---

    #[test]
    fn multiple_errors_collected() {
        let n: i64 = eval(
            r#"
            local ok, errs = std.validate.check(
                {name = 42, age = "old"},
                {name = "string", age = "number"}
            )
            return #errs
        "#,
        );
        assert_eq!(n, 2);
    }

    // --- "table" ---

    #[test]
    fn table_type_valid() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {tags = {"a", "b"}},
                {tags = "table"}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    #[test]
    fn table_type_rejects_string() {
        let s: String = eval(
            r#"
            local ok, errs = std.validate.check(
                {tags = "not a table"},
                {tags = "table"}
            )
            return errs[1]
        "#,
        );
        assert!(s.contains("expected table, got string"), "got: {s}");
    }

    // --- "boolean" ---

    #[test]
    fn boolean_valid() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check(
                {active = true},
                {active = "boolean"}
            )
            return ok
        "#,
        );
        assert!(ok);
    }

    // --- Edge cases ---

    // Fields not mentioned in the schema are never inspected.
    #[test]
    fn empty_schema_always_passes() {
        let ok: bool = eval(
            r#"
            local ok, _ = std.validate.check({anything = "here"}, {})
            return ok
        "#,
        );
        assert!(ok);
    }

    // A spec that is neither a string nor a table (here the number 42) must
    // raise a Lua error rather than return `false, errs`. This test builds its
    // own Lua state so it can observe the error instead of a converted value.
    #[test]
    fn schema_with_invalid_spec_returns_error() {
        let lua = mlua::Lua::new();
        crate::register_all(&lua, "std").unwrap();
        let result: mlua::Result<mlua::Value> = lua
            .load(r#"return std.validate.check({x = 1}, {x = 42})"#)
            .eval();
        assert!(result.is_err());
    }

    // A type mismatch suppresses the min/max checks: exactly one message.
    #[test]
    fn type_mismatch_skips_range_checks() {
        let n: i64 = eval(
            r#"
            local ok, errs = std.validate.check(
                {age = "not a number"},
                {age = {type = "number", min = 0, max = 150}}
            )
            return #errs
        "#,
        );
        assert_eq!(n, 1);
    }
}