Skip to main content

mlua_batteries/
validate.rs

1//! Table validation module.
2//!
3//! Validates Lua table structure against a schema definition.
4//! Schemas are plain Lua tables — no external schema language required.
5//!
6//! # Schema formats
7//!
8//! **Shorthand** — type name only (field is optional):
9//!
10//! ```lua
11//! local schema = {name = "string", age = "number"}
12//! ```
13//!
14//! **Full** — table with constraints:
15//!
16//! ```lua
17//! local schema = {
18//!     name   = {type = "string", required = true, min_len = 1},
19//!     age    = {type = "number", min = 0, max = 150},
20//!     status = {type = "string", one_of = {"active", "inactive"}},
21//!     tags   = {type = "table"},
22//! }
23//! ```
24//!
25//! # Supported constraints
26//!
27//! | Key | Applies to | Description |
28//! |-----|-----------|-------------|
//! | `type` | all | Expected type: a Lua type name such as `string`, `number`, `boolean`, `table` or `function`; also `integer` (integer subtype only) or `any` |
//! | `required` | all | Field must be non-nil (default: false) |
//! | `min` | number/integer | Minimum value (inclusive) |
//! | `max` | number/integer | Maximum value (inclusive) |
//! | `min_len` | string | Minimum string length in bytes (inclusive) |
//! | `max_len` | string | Maximum string length in bytes (inclusive) |
35//! | `one_of` | string/number/integer/boolean | Allowed values list |
36//!
37//! # Usage
38//!
39//! ```lua
40//! local ok, errors = std.validate.check(data, schema)
41//! if not ok then
42//!     for _, msg in ipairs(errors) do print(msg) end
43//! end
44//! ```
45
46use mlua::prelude::*;
47
48pub fn module(lua: &Lua) -> LuaResult<LuaTable> {
49    let t = lua.create_table()?;
50
51    t.set(
52        "check",
53        lua.create_function(|lua, (data, schema): (LuaTable, LuaTable)| {
54            let mut errors: Vec<String> = Vec::new();
55            validate_table(&data, &schema, &mut errors)?;
56            if errors.is_empty() {
57                Ok((true, LuaValue::Nil))
58            } else {
59                let err_table = lua.create_table()?;
60                for (i, e) in errors.iter().enumerate() {
61                    err_table.set(i + 1, e.as_str())?;
62                }
63                Ok((false, LuaValue::Table(err_table)))
64            }
65        })?,
66    )?;
67
68    Ok(t)
69}
70
71struct FieldSpec {
72    type_name: Option<String>,
73    required: bool,
74    min: Option<f64>,
75    max: Option<f64>,
76    min_len: Option<usize>,
77    max_len: Option<usize>,
78    one_of: Option<Vec<LuaValue>>,
79}
80
81fn parse_field_spec(value: &LuaValue) -> LuaResult<FieldSpec> {
82    match value {
83        LuaValue::String(s) => Ok(FieldSpec {
84            type_name: Some(s.to_str()?.to_string()),
85            required: false,
86            min: None,
87            max: None,
88            min_len: None,
89            max_len: None,
90            one_of: None,
91        }),
92        LuaValue::Table(t) => {
93            let type_name: Option<String> = t.get("type")?;
94            let required: Option<bool> = t.get("required")?;
95            let min: Option<f64> = t.get("min")?;
96            let max: Option<f64> = t.get("max")?;
97            let min_len: Option<usize> = t.get("min_len")?;
98            let max_len: Option<usize> = t.get("max_len")?;
99            let one_of_table: Option<LuaTable> = t.get("one_of")?;
100            let one_of = match one_of_table {
101                Some(tbl) => {
102                    let mut vals = Vec::new();
103                    for v in tbl.sequence_values::<LuaValue>() {
104                        vals.push(v?);
105                    }
106                    Some(vals)
107                }
108                None => None,
109            };
110            Ok(FieldSpec {
111                type_name,
112                required: required.unwrap_or(false),
113                min,
114                max,
115                min_len,
116                max_len,
117                one_of,
118            })
119        }
120        other => Err(LuaError::external(format!(
121            "validate: schema field must be a string or table, got {}",
122            other.type_name()
123        ))),
124    }
125}
126
127fn validate_table(data: &LuaTable, schema: &LuaTable, errors: &mut Vec<String>) -> LuaResult<()> {
128    for pair in schema.pairs::<LuaValue, LuaValue>() {
129        let (key, spec_value) = pair?;
130        let key_str = format_key(&key);
131        let spec = parse_field_spec(&spec_value)?;
132        let value: LuaValue = data.get(key)?;
133        validate_field(&key_str, &value, &spec, errors);
134    }
135    Ok(())
136}
137
138fn validate_field(key: &str, value: &LuaValue, spec: &FieldSpec, errors: &mut Vec<String>) {
139    // nil check
140    if matches!(value, LuaValue::Nil) {
141        if spec.required {
142            errors.push(format!("{key}: required"));
143        }
144        return;
145    }
146
147    // type check
148    if let Some(ref expected) = spec.type_name {
149        if !matches_type(value, expected) {
150            errors.push(format!(
151                "{key}: expected {expected}, got {}",
152                lua_type_name(value)
153            ));
154            return; // skip further checks if type is wrong
155        }
156    }
157
158    // numeric range
159    if let Some(n) = as_number(value) {
160        if let Some(min) = spec.min {
161            if n < min {
162                errors.push(format!("{key}: must be >= {min}, got {n}"));
163            }
164        }
165        if let Some(max) = spec.max {
166            if n > max {
167                errors.push(format!("{key}: must be <= {max}, got {n}"));
168            }
169        }
170    }
171
172    // string length
173    if let LuaValue::String(s) = value {
174        let len = s.as_bytes().len();
175        if let Some(min_len) = spec.min_len {
176            if len < min_len {
177                errors.push(format!("{key}: length must be >= {min_len}, got {len}"));
178            }
179        }
180        if let Some(max_len) = spec.max_len {
181            if len > max_len {
182                errors.push(format!("{key}: length must be <= {max_len}, got {len}"));
183            }
184        }
185    }
186
187    // one_of
188    if let Some(ref allowed) = spec.one_of {
189        if !allowed.iter().any(|a| values_equal(a, value)) {
190            let allowed_str = allowed
191                .iter()
192                .map(format_display)
193                .collect::<Vec<_>>()
194                .join(", ");
195            errors.push(format!(
196                "{key}: must be one of [{allowed_str}], got {}",
197                format_display(value)
198            ));
199        }
200    }
201}
202
203fn matches_type(value: &LuaValue, expected: &str) -> bool {
204    match expected {
205        "string" => matches!(value, LuaValue::String(_)),
206        "number" => matches!(value, LuaValue::Number(_) | LuaValue::Integer(_)),
207        "integer" => matches!(value, LuaValue::Integer(_)),
208        "boolean" => matches!(value, LuaValue::Boolean(_)),
209        "table" => matches!(value, LuaValue::Table(_)),
210        "function" => matches!(value, LuaValue::Function(_)),
211        "any" => true,
212        _ => false,
213    }
214}
215
216fn lua_type_name(value: &LuaValue) -> &'static str {
217    match value {
218        LuaValue::Nil => "nil",
219        LuaValue::Boolean(_) => "boolean",
220        LuaValue::Integer(_) => "integer",
221        LuaValue::Number(_) => "number",
222        LuaValue::String(_) => "string",
223        LuaValue::Table(_) => "table",
224        LuaValue::Function(_) => "function",
225        _ => "userdata",
226    }
227}
228
229fn as_number(value: &LuaValue) -> Option<f64> {
230    match value {
231        LuaValue::Number(n) => Some(*n),
232        LuaValue::Integer(i) => Some(*i as f64),
233        _ => None,
234    }
235}
236
237fn values_equal(a: &LuaValue, b: &LuaValue) -> bool {
238    match (a, b) {
239        (LuaValue::String(a), LuaValue::String(b)) => a.as_bytes() == b.as_bytes(),
240        (LuaValue::Integer(a), LuaValue::Integer(b)) => a == b,
241        (LuaValue::Number(a), LuaValue::Number(b)) => a == b,
242        (LuaValue::Integer(a), LuaValue::Number(b)) => (*a as f64) == *b,
243        (LuaValue::Number(a), LuaValue::Integer(b)) => *a == (*b as f64),
244        (LuaValue::Boolean(a), LuaValue::Boolean(b)) => a == b,
245        (LuaValue::Nil, LuaValue::Nil) => true,
246        _ => false,
247    }
248}
249
250fn format_key(value: &LuaValue) -> String {
251    match value {
252        LuaValue::String(s) => s.to_string_lossy().to_string(),
253        LuaValue::Integer(i) => i.to_string(),
254        other => format!("<{}>", other.type_name()),
255    }
256}
257
258fn format_display(value: &LuaValue) -> String {
259    match value {
260        LuaValue::Nil => "nil".to_string(),
261        LuaValue::Boolean(b) => b.to_string(),
262        LuaValue::Integer(i) => i.to_string(),
263        LuaValue::Number(n) => n.to_string(),
264        LuaValue::String(s) => format!("\"{}\"", s.to_string_lossy()),
265        other => format!("<{}>", other.type_name()),
266    }
267}
268
#[cfg(test)]
mod tests {
    use crate::util::test_eval as eval;

    /// Runs `std.validate.check(data, schema)` and returns the `ok` flag.
    fn passes(data: &str, schema: &str) -> bool {
        eval(&format!(
            "local ok, _ = std.validate.check({data}, {schema})\nreturn ok"
        ))
    }

    /// Runs `std.validate.check(data, schema)` and returns the first error message.
    fn first_error(data: &str, schema: &str) -> String {
        eval(&format!(
            "local _, errs = std.validate.check({data}, {schema})\nreturn errs[1]"
        ))
    }

    /// Runs `std.validate.check(data, schema)` and returns how many errors were reported.
    fn error_count(data: &str, schema: &str) -> i64 {
        eval(&format!(
            "local _, errs = std.validate.check({data}, {schema})\nreturn #errs"
        ))
    }

    // ─── shorthand (type-only) ────────────────────────────

    #[test]
    fn shorthand_valid() {
        assert!(passes(
            r#"{name = "John", age = 30}"#,
            r#"{name = "string", age = "number"}"#
        ));
    }

    #[test]
    fn shorthand_type_mismatch() {
        let s = first_error("{name = 42}", r#"{name = "string"}"#);
        assert!(s.contains("expected string, got integer"), "got: {s}");
    }

    #[test]
    fn shorthand_missing_optional_is_ok() {
        assert!(passes("{}", r#"{name = "string"}"#));
    }

    // ─── required ─────────────────────────────────────────

    #[test]
    fn required_missing_field() {
        let s = first_error("{}", r#"{name = {type = "string", required = true}}"#);
        assert!(s.contains("required"), "got: {s}");
    }

    #[test]
    fn required_present_field() {
        assert!(passes(
            r#"{name = "John"}"#,
            r#"{name = {type = "string", required = true}}"#
        ));
    }

    // ─── numeric range ────────────────────────────────────

    #[test]
    fn min_violated() {
        let s = first_error("{age = -1}", r#"{age = {type = "number", min = 0}}"#);
        assert!(s.contains(">= 0"), "got: {s}");
    }

    #[test]
    fn max_violated() {
        let s = first_error("{age = 200}", r#"{age = {type = "number", max = 150}}"#);
        assert!(s.contains("<= 150"), "got: {s}");
    }

    #[test]
    fn range_valid() {
        assert!(passes(
            "{age = 30}",
            r#"{age = {type = "number", min = 0, max = 150}}"#
        ));
    }

    // ─── string length ────────────────────────────────────

    #[test]
    fn min_len_violated() {
        let s = first_error(r#"{name = ""}"#, r#"{name = {type = "string", min_len = 1}}"#);
        assert!(s.contains("length must be >= 1"), "got: {s}");
    }

    #[test]
    fn max_len_violated() {
        let s = first_error(
            r#"{code = "ABCDEF"}"#,
            r#"{code = {type = "string", max_len = 3}}"#,
        );
        assert!(s.contains("length must be <= 3"), "got: {s}");
    }

    // ─── one_of ───────────────────────────────────────────

    #[test]
    fn one_of_valid() {
        assert!(passes(
            r#"{status = "active"}"#,
            r#"{status = {type = "string", one_of = {"active", "inactive"}}}"#
        ));
    }

    #[test]
    fn one_of_violated() {
        let s = first_error(
            r#"{status = "unknown"}"#,
            r#"{status = {type = "string", one_of = {"active", "inactive"}}}"#,
        );
        assert!(s.contains("must be one of"), "got: {s}");
        assert!(s.contains("\"active\""), "got: {s}");
    }

    #[test]
    fn one_of_numeric() {
        assert!(passes(
            "{level = 2}",
            r#"{level = {type = "number", one_of = {1, 2, 3}}}"#
        ));
    }

    // ─── integer type ─────────────────────────────────────

    #[test]
    fn integer_accepts_integer() {
        assert!(passes("{count = 42}", r#"{count = "integer"}"#));
    }

    #[test]
    fn integer_rejects_float() {
        let s = first_error("{count = 3.14}", r#"{count = "integer"}"#);
        assert!(s.contains("expected integer, got number"), "got: {s}");
    }

    // ─── any type ─────────────────────────────────────────

    #[test]
    fn any_accepts_anything() {
        assert!(passes(
            r#"{data = "text", count = 42, flag = true}"#,
            r#"{data = "any", count = "any", flag = "any"}"#
        ));
    }

    // ─── multiple errors ──────────────────────────────────

    #[test]
    fn multiple_errors_collected() {
        let n = error_count(
            r#"{name = 42, age = "old"}"#,
            r#"{name = "string", age = "number"}"#,
        );
        assert_eq!(n, 2);
    }

    // ─── table type ───────────────────────────────────────

    #[test]
    fn table_type_valid() {
        assert!(passes(r#"{tags = {"a", "b"}}"#, r#"{tags = "table"}"#));
    }

    #[test]
    fn table_type_rejects_string() {
        let s = first_error(r#"{tags = "not a table"}"#, r#"{tags = "table"}"#);
        assert!(s.contains("expected table, got string"), "got: {s}");
    }

    // ─── boolean type ─────────────────────────────────────

    #[test]
    fn boolean_valid() {
        assert!(passes("{active = true}", r#"{active = "boolean"}"#));
    }

    // ─── edge cases ───────────────────────────────────────

    #[test]
    fn empty_schema_always_passes() {
        assert!(passes(r#"{anything = "here"}"#, "{}"));
    }

    #[test]
    fn schema_with_invalid_spec_returns_error() {
        let lua = mlua::Lua::new();
        crate::register_all(&lua, "std").unwrap();
        let outcome: mlua::Result<mlua::Value> = lua
            .load(r#"return std.validate.check({x = 1}, {x = 42})"#)
            .eval();
        assert!(outcome.is_err());
    }

    #[test]
    fn type_mismatch_skips_range_checks() {
        let n = error_count(
            r#"{age = "not a number"}"#,
            r#"{age = {type = "number", min = 0, max = 150}}"#,
        );
        // Only the type error is reported, not min/max errors.
        assert_eq!(n, 1);
    }
}