Skip to main content

mlua_batteries/
json.rs

1//! JSON encode/decode module.
2//!
3//! ```lua
4//! local json = std.json
5//! local t = json.decode('{"a":1}')
6//! local s = json.encode(t)
7//! local s2 = json.encode_pretty(t)
8//! ```
9//!
10//! # Empty tables
11//!
//! An empty Lua table `{}` is encoded as a JSON **object** `{}`,
//! not an array `[]`. This matches the `classify` heuristic: a table
//! with `raw_len() == 0` is treated as a map. As a consequence, decoding
//! `[]` and re-encoding the result yields `{}`.
15
16use mlua::prelude::*;
17use serde_json::Value as JsonValue;
18
19use crate::policy::PathOp;
20use crate::util::{check_path, classify, with_config, TableKind};
21
22pub fn module(lua: &Lua) -> LuaResult<LuaTable> {
23    let t = lua.create_table()?;
24
25    t.set(
26        "decode",
27        lua.create_function(|lua, s: String| {
28            let max_depth = with_config(lua, |c| c.max_json_depth)?;
29            let value: JsonValue = serde_json::from_str(&s)
30                .map_err(|e| LuaError::external(format!("json.decode: {e}")))?;
31            json_to_lua(lua, &value, max_depth)
32        })?,
33    )?;
34
35    t.set(
36        "encode",
37        lua.create_function(|lua, value: LuaValue| {
38            let max_depth = with_config(lua, |c| c.max_json_depth)?;
39            let json = lua_to_json(&value, max_depth)?;
40            serde_json::to_string(&json)
41                .map_err(|e| LuaError::external(format!("json.encode: {e}")))
42        })?,
43    )?;
44
45    t.set(
46        "encode_pretty",
47        lua.create_function(|lua, value: LuaValue| {
48            let max_depth = with_config(lua, |c| c.max_json_depth)?;
49            let json = lua_to_json(&value, max_depth)?;
50            serde_json::to_string_pretty(&json)
51                .map_err(|e| LuaError::external(format!("json.encode_pretty: {e}")))
52        })?,
53    )?;
54
55    t.set(
56        "read_file",
57        lua.create_function(|lua, path: String| {
58            let access = check_path(lua, &path, PathOp::Read)?;
59            let max_depth = with_config(lua, |c| c.max_json_depth)?;
60            let content = access.read_to_string().map_err(LuaError::external)?;
61            let value: JsonValue = serde_json::from_str(&content)
62                .map_err(|e| LuaError::external(format!("json.read_file: {e}")))?;
63            json_to_lua(lua, &value, max_depth)
64        })?,
65    )?;
66
67    t.set(
68        "write_file",
69        lua.create_function(|lua, (path, value): (String, LuaValue)| {
70            let access = check_path(lua, &path, PathOp::Write)?;
71            let max_depth = with_config(lua, |c| c.max_json_depth)?;
72            let json = lua_to_json(&value, max_depth)?;
73            let content = serde_json::to_string_pretty(&json)
74                .map_err(|e| LuaError::external(format!("json.write_file: {e}")))?;
75            access
76                .write(content.as_bytes())
77                .map_err(LuaError::external)?;
78            Ok(true)
79        })?,
80    )?;
81
82    Ok(t)
83}
84
85// ─── Conversion: JSON → Lua ────────────────────────────
86
87pub(crate) fn json_to_lua(lua: &Lua, value: &JsonValue, max_depth: usize) -> LuaResult<LuaValue> {
88    json_to_lua_inner(lua, value, 0, max_depth)
89}
90
91fn json_to_lua_inner(
92    lua: &Lua,
93    value: &JsonValue,
94    depth: usize,
95    max_depth: usize,
96) -> LuaResult<LuaValue> {
97    if depth > max_depth {
98        return Err(LuaError::external(format!(
99            "JSON nesting too deep (limit: {max_depth})"
100        )));
101    }
102    match value {
103        JsonValue::Null => Ok(LuaValue::Nil),
104        JsonValue::Bool(b) => Ok(LuaValue::Boolean(*b)),
105        JsonValue::Number(n) => {
106            if let Some(i) = n.as_i64() {
107                Ok(LuaValue::Integer(i))
108            } else if let Some(f) = n.as_f64() {
109                Ok(LuaValue::Number(f))
110            } else {
111                Err(LuaError::external(format!(
112                    "JSON number {n} is not representable as i64 or f64"
113                )))
114            }
115        }
116        JsonValue::String(s) => lua.create_string(s).map(LuaValue::String),
117        JsonValue::Array(arr) => {
118            let table = lua.create_table()?;
119            for (i, v) in arr.iter().enumerate() {
120                table.set(i + 1, json_to_lua_inner(lua, v, depth + 1, max_depth)?)?;
121            }
122            Ok(LuaValue::Table(table))
123        }
124        JsonValue::Object(map) => {
125            let table = lua.create_table()?;
126            for (k, v) in map {
127                table.set(k.as_str(), json_to_lua_inner(lua, v, depth + 1, max_depth)?)?;
128            }
129            Ok(LuaValue::Table(table))
130        }
131    }
132}
133
134// ─── Conversion: Lua → JSON ────────────────────────────
135
136pub(crate) fn lua_to_json(value: &LuaValue, max_depth: usize) -> LuaResult<JsonValue> {
137    lua_to_json_inner(value, 0, max_depth)
138}
139
140fn lua_to_json_inner(value: &LuaValue, depth: usize, max_depth: usize) -> LuaResult<JsonValue> {
141    if depth > max_depth {
142        return Err(LuaError::external(format!(
143            "Lua table nesting too deep for JSON (limit: {max_depth})"
144        )));
145    }
146    match value {
147        LuaValue::Nil => Ok(JsonValue::Null),
148        LuaValue::Boolean(b) => Ok(JsonValue::Bool(*b)),
149        LuaValue::Integer(i) => Ok(JsonValue::Number((*i).into())),
150        LuaValue::Number(n) => serde_json::Number::from_f64(*n)
151            .map(JsonValue::Number)
152            .ok_or_else(|| LuaError::external(format!("cannot convert {n} to JSON number"))),
153        LuaValue::String(s) => Ok(JsonValue::String(s.to_str()?.to_string())),
154        LuaValue::Table(t) => lua_table_to_json(t, depth, max_depth),
155        _ => Err(LuaError::external("unsupported type for JSON conversion")),
156    }
157}
158
159fn lua_table_to_json(table: &LuaTable, depth: usize, max_depth: usize) -> LuaResult<JsonValue> {
160    match classify(table)? {
161        TableKind::Array(len) => {
162            let mut arr = Vec::with_capacity(len);
163            for i in 1..=len {
164                let v: LuaValue = table.raw_get(i)?;
165                arr.push(lua_to_json_inner(&v, depth + 1, max_depth)?);
166            }
167            Ok(JsonValue::Array(arr))
168        }
169        TableKind::Map(pairs) => {
170            let mut map = serde_json::Map::new();
171            for (k, v) in pairs {
172                let key = match k {
173                    LuaValue::String(s) => s.to_str()?.to_string(),
174                    LuaValue::Integer(i) => i.to_string(),
175                    LuaValue::Number(n) => n.to_string(),
176                    other => {
177                        return Err(LuaError::external(format!(
178                            "unsupported table key type for JSON: {}",
179                            other.type_name()
180                        )));
181                    }
182                };
183                map.insert(key, lua_to_json_inner(&v, depth + 1, max_depth)?);
184            }
185            Ok(JsonValue::Object(map))
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    use crate::util::test_eval as eval;

    /// Removes a scratch directory when dropped, so a failing assertion does
    /// not leave files behind in the system temp dir.
    struct ScratchDir(std::path::PathBuf);

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn decode_object() {
        let s: String = eval(
            r#"
            local t = std.json.decode('{"a":1,"b":"hello"}')
            return tostring(t.a) .. "," .. t.b
        "#,
        );
        assert_eq!(s, "1,hello");
    }

    #[test]
    fn decode_array() {
        let n: i64 = eval(
            r#"
            local arr = std.json.decode('[10,20,30]')
            return #arr
        "#,
        );
        assert_eq!(n, 3);
    }

    #[test]
    fn decode_invalid_returns_error() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();
        let result: mlua::Result<LuaValue> = lua.load(r#"return std.json.decode("{bad")"#).eval();
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("json.decode"), "unexpected error: {err_msg}");
    }

    #[test]
    fn encode_roundtrip() {
        let s: String = eval(
            r#"
            local original = '{"name":"test","values":[1,2,3]}'
            local t = std.json.decode(original)
            local encoded = std.json.encode(t)
            local t2 = std.json.decode(encoded)
            return t2.name .. "," .. tostring(#t2.values)
        "#,
        );
        assert_eq!(s, "test,3");
    }

    #[test]
    fn encode_empty_table_as_object() {
        let s: String = eval(
            r#"
            return std.json.encode({})
        "#,
        );
        assert_eq!(s, "{}");
    }

    #[test]
    fn encode_nested_structure() {
        let s: String = eval(
            r#"
            return std.json.encode({items = {1, 2}, meta = {ok = true}})
        "#,
        );
        let v: serde_json::Value = serde_json::from_str(&s).unwrap();
        assert!(v["items"].is_array());
        assert_eq!(v["meta"]["ok"], true);
    }

    #[test]
    fn encode_pretty_has_newlines() {
        let s: String = eval(
            r#"
            return std.json.encode_pretty({a = 1})
        "#,
        );
        assert!(s.contains('\n'));
    }

    #[test]
    fn decode_null_becomes_nil() {
        let b: bool = eval(
            r#"
            local t = std.json.decode('{"x":null}')
            return t.x == nil
        "#,
        );
        assert!(b);
    }

    #[test]
    fn decode_boolean() {
        let s: String = eval(
            r#"
            local t = std.json.decode('{"flag":true}')
            return type(t.flag)
        "#,
        );
        assert_eq!(s, "boolean");
    }

    #[test]
    fn max_depth_enforced_on_decode() {
        let lua = Lua::new();
        let config = crate::config::Config::builder()
            .max_json_depth(2)
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        // depth 3: {"a":{"b":{"c":1}}}
        let result: mlua::Result<LuaValue> = lua
            .load(r#"return std.json.decode('{"a":{"b":{"c":1}}}')"#)
            .eval();
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("too deep"), "unexpected error: {err_msg}");
    }

    #[test]
    fn max_depth_enforced_on_encode() {
        let lua = Lua::new();
        let config = crate::config::Config::builder()
            .max_json_depth(2)
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        // depth 3 nested table
        let result: mlua::Result<LuaValue> = lua
            .load(r#"return std.json.encode({a = {b = {c = 1}}})"#)
            .eval();
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("too deep"), "unexpected error: {err_msg}");
    }

    #[test]
    fn encode_rejects_boolean_key() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        let result: mlua::Result<LuaValue> = lua
            .load(
                r#"
                local t = {}
                t[true] = "val"
                return std.json.encode(t)
            "#,
            )
            .eval();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("unsupported table key type"));
    }

    #[test]
    fn encode_rejects_function_value() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        let result: mlua::Result<LuaValue> = lua.load(r#"return std.json.encode(print)"#).eval();
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("unsupported type for JSON conversion"),
            "unexpected error: {err_msg}"
        );
    }

    #[test]
    fn encode_rejects_nan() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        let result: mlua::Result<LuaValue> = lua.load(r#"return std.json.encode(0/0)"#).eval();
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("cannot convert"), "unexpected error: {err_msg}");
    }

    #[test]
    fn read_file_and_write_file_roundtrip() {
        // Per-process directory so concurrent test runs cannot clobber each
        // other; the guard removes it even if an assertion below fails.
        let dir = ScratchDir(
            std::env::temp_dir().join(format!("mlua_bat_test_json_file_{}", std::process::id())),
        );
        std::fs::create_dir_all(&dir.0).unwrap();
        let path = dir.0.join("data.json");

        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();
        // Pass the path as a Lua value instead of splicing it into the source:
        // a Windows path such as `C:\Users\...` contains backslashes that Lua
        // would parse as (invalid) escape sequences.
        lua.globals()
            .set("json_path", path.to_string_lossy().into_owned())
            .unwrap();

        let s: String = lua
            .load(
                r#"
                std.json.write_file(json_path, {name = "test", ok = true})
                local t = std.json.read_file(json_path)
                return t.name
            "#,
            )
            .eval()
            .unwrap();
        assert_eq!(s, "test");
    }
}
368}