Skip to main content

mlua_batteries/
json.rs

1//! JSON encode/decode module.
2//!
3//! ```lua
4//! local json = std.json
5//! local t = json.decode('{"a":1}')
6//! local s = json.encode(t)
7//! local s2 = json.encode_pretty(t)
8//! ```
9//!
10//! # Empty tables
11//!
//! An empty Lua table `{}` is encoded as a JSON **object** `{}`, not an
//! array `[]`: `classify` treats any table whose `raw_len()` is 0 as a map,
//! since an empty Lua table carries no hint of whether it was meant as one.
15
16use mlua::prelude::*;
17use serde_json::Value as JsonValue;
18
19use crate::policy::PathOp;
20use crate::util::{check_path, classify, with_config, TableKind};
21
22pub fn module(lua: &Lua) -> LuaResult<LuaTable> {
23    let t = lua.create_table()?;
24
25    t.set(
26        "decode",
27        lua.create_function(|lua, s: String| {
28            let max_depth = with_config(lua, |c| c.max_json_depth)?;
29            let value: JsonValue = serde_json::from_str(&s)
30                .map_err(|e| LuaError::external(format!("json.decode: {e}")))?;
31            json_to_lua(lua, &value, max_depth)
32        })?,
33    )?;
34
35    t.set(
36        "encode",
37        lua.create_function(|lua, value: LuaValue| {
38            let max_depth = with_config(lua, |c| c.max_json_depth)?;
39            let json = lua_to_json(&value, max_depth)?;
40            serde_json::to_string(&json)
41                .map_err(|e| LuaError::external(format!("json.encode: {e}")))
42        })?,
43    )?;
44
45    t.set(
46        "encode_pretty",
47        lua.create_function(|lua, value: LuaValue| {
48            let max_depth = with_config(lua, |c| c.max_json_depth)?;
49            let json = lua_to_json(&value, max_depth)?;
50            serde_json::to_string_pretty(&json)
51                .map_err(|e| LuaError::external(format!("json.encode_pretty: {e}")))
52        })?,
53    )?;
54
55    t.set(
56        "read_file",
57        lua.create_function(|lua, path: String| {
58            let access = check_path(lua, &path, PathOp::Read)?;
59            let max_depth = with_config(lua, |c| c.max_json_depth)?;
60            let content = access.read_to_string().map_err(LuaError::external)?;
61            let value: JsonValue = serde_json::from_str(&content)
62                .map_err(|e| LuaError::external(format!("json.read_file: {e}")))?;
63            json_to_lua(lua, &value, max_depth)
64        })?,
65    )?;
66
67    t.set(
68        "write_file",
69        lua.create_function(|lua, (path, value): (String, LuaValue)| {
70            let access = check_path(lua, &path, PathOp::Write)?;
71            let max_depth = with_config(lua, |c| c.max_json_depth)?;
72            let json = lua_to_json(&value, max_depth)?;
73            let content = serde_json::to_string_pretty(&json)
74                .map_err(|e| LuaError::external(format!("json.write_file: {e}")))?;
75            access
76                .write(content.as_bytes())
77                .map_err(LuaError::external)?;
78            Ok(true)
79        })?,
80    )?;
81
82    Ok(t)
83}
84
85// ─── Conversion: JSON → Lua ────────────────────────────
86
87pub(crate) fn json_to_lua(lua: &Lua, value: &JsonValue, max_depth: usize) -> LuaResult<LuaValue> {
88    json_to_lua_inner(lua, value, 0, max_depth)
89}
90
91fn json_to_lua_inner(
92    lua: &Lua,
93    value: &JsonValue,
94    depth: usize,
95    max_depth: usize,
96) -> LuaResult<LuaValue> {
97    if depth > max_depth {
98        return Err(LuaError::external(format!(
99            "JSON nesting too deep (limit: {max_depth})"
100        )));
101    }
102    match value {
103        JsonValue::Null => Ok(LuaValue::Nil),
104        JsonValue::Bool(b) => Ok(LuaValue::Boolean(*b)),
105        JsonValue::Number(n) => {
106            if let Some(i) = n.as_i64() {
107                Ok(LuaValue::Integer(i))
108            } else if let Some(f) = n.as_f64() {
109                Ok(LuaValue::Number(f))
110            } else {
111                Err(LuaError::external(format!(
112                    "JSON number {n} is not representable as i64 or f64"
113                )))
114            }
115        }
116        JsonValue::String(s) => lua.create_string(s).map(LuaValue::String),
117        JsonValue::Array(arr) => {
118            let table = lua.create_table()?;
119            for (i, v) in arr.iter().enumerate() {
120                table.set(i + 1, json_to_lua_inner(lua, v, depth + 1, max_depth)?)?;
121            }
122            Ok(LuaValue::Table(table))
123        }
124        JsonValue::Object(map) => {
125            let table = lua.create_table()?;
126            for (k, v) in map {
127                table.set(k.as_str(), json_to_lua_inner(lua, v, depth + 1, max_depth)?)?;
128            }
129            Ok(LuaValue::Table(table))
130        }
131    }
132}
133
134// ─── Conversion: Lua → JSON ────────────────────────────
135
136pub(crate) fn lua_to_json(value: &LuaValue, max_depth: usize) -> LuaResult<JsonValue> {
137    lua_to_json_inner(value, 0, max_depth)
138}
139
140fn lua_to_json_inner(value: &LuaValue, depth: usize, max_depth: usize) -> LuaResult<JsonValue> {
141    if depth > max_depth {
142        return Err(LuaError::external(format!(
143            "Lua table nesting too deep for JSON (limit: {max_depth})"
144        )));
145    }
146    match value {
147        LuaValue::Nil => Ok(JsonValue::Null),
148        // `mlua::serde::LuaSerdeExt` maps JSON `null` to `Value::NULL`, which is
149        // `LightUserData(ptr::null_mut())`.  Recognize that sentinel so values
150        // produced by mlua's serde bridge can round-trip through `json.encode`.
151        // Non-null `LightUserData` (app-specific) continues to error below.
152        LuaValue::LightUserData(u) if u.0.is_null() => Ok(JsonValue::Null),
153        LuaValue::Boolean(b) => Ok(JsonValue::Bool(*b)),
154        LuaValue::Integer(i) => Ok(JsonValue::Number((*i).into())),
155        LuaValue::Number(n) => serde_json::Number::from_f64(*n)
156            .map(JsonValue::Number)
157            .ok_or_else(|| LuaError::external(format!("cannot convert {n} to JSON number"))),
158        LuaValue::String(s) => Ok(JsonValue::String(s.to_str()?.to_string())),
159        LuaValue::Table(t) => lua_table_to_json(t, depth, max_depth),
160        _ => Err(LuaError::external("unsupported type for JSON conversion")),
161    }
162}
163
164fn lua_table_to_json(table: &LuaTable, depth: usize, max_depth: usize) -> LuaResult<JsonValue> {
165    match classify(table)? {
166        TableKind::Array(len) => {
167            let mut arr = Vec::with_capacity(len);
168            for i in 1..=len {
169                let v: LuaValue = table.raw_get(i)?;
170                arr.push(lua_to_json_inner(&v, depth + 1, max_depth)?);
171            }
172            Ok(JsonValue::Array(arr))
173        }
174        TableKind::Map(pairs) => {
175            let mut map = serde_json::Map::new();
176            for (k, v) in pairs {
177                let key = match k {
178                    LuaValue::String(s) => s.to_str()?.to_string(),
179                    LuaValue::Integer(i) => i.to_string(),
180                    LuaValue::Number(n) => n.to_string(),
181                    other => {
182                        return Err(LuaError::external(format!(
183                            "unsupported table key type for JSON: {}",
184                            other.type_name()
185                        )));
186                    }
187                };
188                map.insert(key, lua_to_json_inner(&v, depth + 1, max_depth)?);
189            }
190            Ok(JsonValue::Object(map))
191        }
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    use crate::util::test_eval as eval;

    #[test]
    fn decode_object() {
        let s: String = eval(
            r#"
            local t = std.json.decode('{"a":1,"b":"hello"}')
            return tostring(t.a) .. "," .. t.b
        "#,
        );
        assert_eq!(s, "1,hello");
    }

    #[test]
    fn decode_array() {
        let n: i64 = eval(
            r#"
            local arr = std.json.decode('[10,20,30]')
            return #arr
        "#,
        );
        assert_eq!(n, 3);
    }

    #[test]
    fn decode_invalid_returns_error() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();
        let result: mlua::Result<LuaValue> = lua.load(r#"return std.json.decode("{bad")"#).eval();
        assert!(result.is_err());
    }

    #[test]
    fn encode_roundtrip() {
        let s: String = eval(
            r#"
            local original = '{"name":"test","values":[1,2,3]}'
            local t = std.json.decode(original)
            local encoded = std.json.encode(t)
            local t2 = std.json.decode(encoded)
            return t2.name .. "," .. tostring(#t2.values)
        "#,
        );
        assert_eq!(s, "test,3");
    }

    #[test]
    fn encode_empty_table_as_object() {
        let s: String = eval(
            r#"
            return std.json.encode({})
        "#,
        );
        assert_eq!(s, "{}");
    }

    #[test]
    fn encode_nested_structure() {
        let s: String = eval(
            r#"
            return std.json.encode({items = {1, 2}, meta = {ok = true}})
        "#,
        );
        let v: serde_json::Value = serde_json::from_str(&s).unwrap();
        assert!(v["items"].is_array());
        assert_eq!(v["meta"]["ok"], true);
    }

    #[test]
    fn encode_pretty_has_newlines() {
        let s: String = eval(
            r#"
            return std.json.encode_pretty({a = 1})
        "#,
        );
        assert!(s.contains('\n'));
    }

    #[test]
    fn decode_null_becomes_nil() {
        let b: bool = eval(
            r#"
            local t = std.json.decode('{"x":null}')
            return t.x == nil
        "#,
        );
        assert!(b);
    }

    #[test]
    fn decode_boolean() {
        let s: String = eval(
            r#"
            local t = std.json.decode('{"flag":true}')
            return type(t.flag)
        "#,
        );
        assert_eq!(s, "boolean");
    }

    #[test]
    fn max_depth_enforced_on_decode() {
        let lua = Lua::new();
        let config = crate::config::Config::builder()
            .max_json_depth(2)
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        // depth 3: {"a":{"b":{"c":1}}}
        let result: mlua::Result<LuaValue> = lua
            .load(r#"return std.json.decode('{"a":{"b":{"c":1}}}')"#)
            .eval();
        assert!(result.is_err());
    }

    #[test]
    fn max_depth_enforced_on_encode() {
        let lua = Lua::new();
        let config = crate::config::Config::builder()
            .max_json_depth(2)
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        // depth 3 nested table
        let result: mlua::Result<LuaValue> = lua
            .load(r#"return std.json.encode({a = {b = {c = 1}}})"#)
            .eval();
        assert!(result.is_err());
    }

    #[test]
    fn encode_rejects_boolean_key() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        let result: mlua::Result<LuaValue> = lua
            .load(
                r#"
                local t = {}
                t[true] = "val"
                return std.json.encode(t)
            "#,
            )
            .eval();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("unsupported table key type"));
    }

    #[test]
    fn encode_accepts_mlua_null_sentinel() {
        // `mlua::Value::NULL` is the sentinel produced by `LuaSerdeExt::to_value`
        // for JSON `null`.  Encode must map it back to JSON `null` so that values
        // going through mlua's serde bridge can be re-encoded.
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        lua.globals().set("_null", LuaValue::NULL).unwrap();

        let s: String = lua
            .load(r#"return std.json.encode({ x = _null, y = 1 })"#)
            .eval()
            .unwrap();

        let v: serde_json::Value = serde_json::from_str(&s).unwrap();
        assert!(v["x"].is_null());
        assert_eq!(v["y"], 1);
    }

    #[test]
    fn encode_rejects_non_null_light_userdata() {
        // Guardrail: only the canonical NULL sentinel is accepted.  App-specific
        // LightUserData pointers must still error — we don't want to silently
        // serialize arbitrary pointers as `null`.
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        let mut dummy = 42u8;
        let ud = LuaValue::LightUserData(mlua::LightUserData(
            &mut dummy as *mut _ as *mut std::ffi::c_void,
        ));
        lua.globals().set("_ud", ud).unwrap();

        let result: mlua::Result<String> =
            lua.load(r#"return std.json.encode(_ud)"#).eval();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsupported type for JSON conversion"));
    }

    #[test]
    fn read_file_and_write_file_roundtrip() {
        // Per-process directory so concurrent test runs don't race on one file.
        let dir = std::env::temp_dir().join(format!(
            "mlua_bat_test_json_file_{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("data.json");

        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();
        // Hand the path over as a global instead of splicing it into Lua
        // source: Windows paths contain backslashes, which Lua would parse as
        // (invalid) escape sequences inside a string literal.
        lua.globals()
            .set("_path", path.to_string_lossy().into_owned())
            .unwrap();

        let result: mlua::Result<String> = lua
            .load(
                r#"
                std.json.write_file(_path, {name = "test", ok = true})
                local t = std.json.read_file(_path)
                return t.name
            "#,
            )
            .eval();
        // Clean up before asserting so a failure doesn't leave the directory behind.
        let _ = std::fs::remove_dir_all(&dir);
        assert_eq!(result.unwrap(), "test");
    }
}
416}