1use mlua::prelude::*;
17use serde_json::Value as JsonValue;
18
19use crate::policy::PathOp;
20use crate::util::{check_path, classify, with_config, TableKind};
21
22pub fn module(lua: &Lua) -> LuaResult<LuaTable> {
23 let t = lua.create_table()?;
24
25 t.set(
26 "decode",
27 lua.create_function(|lua, s: String| {
28 let max_depth = with_config(lua, |c| c.max_json_depth)?;
29 let value: JsonValue = serde_json::from_str(&s)
30 .map_err(|e| LuaError::external(format!("json.decode: {e}")))?;
31 json_to_lua(lua, &value, max_depth)
32 })?,
33 )?;
34
35 t.set(
36 "encode",
37 lua.create_function(|lua, value: LuaValue| {
38 let max_depth = with_config(lua, |c| c.max_json_depth)?;
39 let json = lua_to_json(&value, max_depth)?;
40 serde_json::to_string(&json)
41 .map_err(|e| LuaError::external(format!("json.encode: {e}")))
42 })?,
43 )?;
44
45 t.set(
46 "encode_pretty",
47 lua.create_function(|lua, value: LuaValue| {
48 let max_depth = with_config(lua, |c| c.max_json_depth)?;
49 let json = lua_to_json(&value, max_depth)?;
50 serde_json::to_string_pretty(&json)
51 .map_err(|e| LuaError::external(format!("json.encode_pretty: {e}")))
52 })?,
53 )?;
54
55 t.set(
56 "read_file",
57 lua.create_function(|lua, path: String| {
58 let access = check_path(lua, &path, PathOp::Read)?;
59 let max_depth = with_config(lua, |c| c.max_json_depth)?;
60 let content = access.read_to_string().map_err(LuaError::external)?;
61 let value: JsonValue = serde_json::from_str(&content)
62 .map_err(|e| LuaError::external(format!("json.read_file: {e}")))?;
63 json_to_lua(lua, &value, max_depth)
64 })?,
65 )?;
66
67 t.set(
68 "write_file",
69 lua.create_function(|lua, (path, value): (String, LuaValue)| {
70 let access = check_path(lua, &path, PathOp::Write)?;
71 let max_depth = with_config(lua, |c| c.max_json_depth)?;
72 let json = lua_to_json(&value, max_depth)?;
73 let content = serde_json::to_string_pretty(&json)
74 .map_err(|e| LuaError::external(format!("json.write_file: {e}")))?;
75 access
76 .write(content.as_bytes())
77 .map_err(LuaError::external)?;
78 Ok(true)
79 })?,
80 )?;
81
82 Ok(t)
83}
84
85pub(crate) fn json_to_lua(lua: &Lua, value: &JsonValue, max_depth: usize) -> LuaResult<LuaValue> {
88 json_to_lua_inner(lua, value, 0, max_depth)
89}
90
91fn json_to_lua_inner(
92 lua: &Lua,
93 value: &JsonValue,
94 depth: usize,
95 max_depth: usize,
96) -> LuaResult<LuaValue> {
97 if depth > max_depth {
98 return Err(LuaError::external(format!(
99 "JSON nesting too deep (limit: {max_depth})"
100 )));
101 }
102 match value {
103 JsonValue::Null => Ok(LuaValue::Nil),
104 JsonValue::Bool(b) => Ok(LuaValue::Boolean(*b)),
105 JsonValue::Number(n) => {
106 if let Some(i) = n.as_i64() {
107 Ok(LuaValue::Integer(i))
108 } else if let Some(f) = n.as_f64() {
109 Ok(LuaValue::Number(f))
110 } else {
111 Err(LuaError::external(format!(
112 "JSON number {n} is not representable as i64 or f64"
113 )))
114 }
115 }
116 JsonValue::String(s) => lua.create_string(s).map(LuaValue::String),
117 JsonValue::Array(arr) => {
118 let table = lua.create_table()?;
119 for (i, v) in arr.iter().enumerate() {
120 table.set(i + 1, json_to_lua_inner(lua, v, depth + 1, max_depth)?)?;
121 }
122 Ok(LuaValue::Table(table))
123 }
124 JsonValue::Object(map) => {
125 let table = lua.create_table()?;
126 for (k, v) in map {
127 table.set(k.as_str(), json_to_lua_inner(lua, v, depth + 1, max_depth)?)?;
128 }
129 Ok(LuaValue::Table(table))
130 }
131 }
132}
133
134pub(crate) fn lua_to_json(value: &LuaValue, max_depth: usize) -> LuaResult<JsonValue> {
137 lua_to_json_inner(value, 0, max_depth)
138}
139
140fn lua_to_json_inner(value: &LuaValue, depth: usize, max_depth: usize) -> LuaResult<JsonValue> {
141 if depth > max_depth {
142 return Err(LuaError::external(format!(
143 "Lua table nesting too deep for JSON (limit: {max_depth})"
144 )));
145 }
146 match value {
147 LuaValue::Nil => Ok(JsonValue::Null),
148 LuaValue::LightUserData(u) if u.0.is_null() => Ok(JsonValue::Null),
153 LuaValue::Boolean(b) => Ok(JsonValue::Bool(*b)),
154 LuaValue::Integer(i) => Ok(JsonValue::Number((*i).into())),
155 LuaValue::Number(n) => serde_json::Number::from_f64(*n)
156 .map(JsonValue::Number)
157 .ok_or_else(|| LuaError::external(format!("cannot convert {n} to JSON number"))),
158 LuaValue::String(s) => Ok(JsonValue::String(s.to_str()?.to_string())),
159 LuaValue::Table(t) => lua_table_to_json(t, depth, max_depth),
160 _ => Err(LuaError::external("unsupported type for JSON conversion")),
161 }
162}
163
164fn lua_table_to_json(table: &LuaTable, depth: usize, max_depth: usize) -> LuaResult<JsonValue> {
165 match classify(table)? {
166 TableKind::Array(len) => {
167 let mut arr = Vec::with_capacity(len);
168 for i in 1..=len {
169 let v: LuaValue = table.raw_get(i)?;
170 arr.push(lua_to_json_inner(&v, depth + 1, max_depth)?);
171 }
172 Ok(JsonValue::Array(arr))
173 }
174 TableKind::Map(pairs) => {
175 let mut map = serde_json::Map::new();
176 for (k, v) in pairs {
177 let key = match k {
178 LuaValue::String(s) => s.to_str()?.to_string(),
179 LuaValue::Integer(i) => i.to_string(),
180 LuaValue::Number(n) => n.to_string(),
181 other => {
182 return Err(LuaError::external(format!(
183 "unsupported table key type for JSON: {}",
184 other.type_name()
185 )));
186 }
187 };
188 map.insert(key, lua_to_json_inner(&v, depth + 1, max_depth)?);
189 }
190 Ok(JsonValue::Object(map))
191 }
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    use crate::util::test_eval as eval;

    #[test]
    fn decode_object() {
        let s: String = eval(
            r#"
            local t = std.json.decode('{"a":1,"b":"hello"}')
            return tostring(t.a) .. "," .. t.b
            "#,
        );
        assert_eq!(s, "1,hello");
    }

    #[test]
    fn decode_array() {
        let n: i64 = eval(
            r#"
            local arr = std.json.decode('[10,20,30]')
            return #arr
            "#,
        );
        assert_eq!(n, 3);
    }

    #[test]
    fn decode_invalid_returns_error() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();
        let result: mlua::Result<LuaValue> = lua.load(r#"return std.json.decode("{bad")"#).eval();
        assert!(result.is_err());
    }

    #[test]
    fn encode_roundtrip() {
        let s: String = eval(
            r#"
            local original = '{"name":"test","values":[1,2,3]}'
            local t = std.json.decode(original)
            local encoded = std.json.encode(t)
            local t2 = std.json.decode(encoded)
            return t2.name .. "," .. tostring(#t2.values)
            "#,
        );
        assert_eq!(s, "test,3");
    }

    #[test]
    fn encode_empty_table_as_object() {
        let s: String = eval(
            r#"
            return std.json.encode({})
            "#,
        );
        assert_eq!(s, "{}");
    }

    #[test]
    fn encode_nested_structure() {
        let s: String = eval(
            r#"
            return std.json.encode({items = {1, 2}, meta = {ok = true}})
            "#,
        );
        let v: serde_json::Value = serde_json::from_str(&s).unwrap();
        assert!(v["items"].is_array());
        assert_eq!(v["meta"]["ok"], true);
    }

    #[test]
    fn encode_pretty_has_newlines() {
        let s: String = eval(
            r#"
            return std.json.encode_pretty({a = 1})
            "#,
        );
        assert!(s.contains('\n'));
    }

    #[test]
    fn decode_null_becomes_nil() {
        let b: bool = eval(
            r#"
            local t = std.json.decode('{"x":null}')
            return t.x == nil
            "#,
        );
        assert!(b);
    }

    #[test]
    fn decode_boolean() {
        let s: String = eval(
            r#"
            local t = std.json.decode('{"flag":true}')
            return type(t.flag)
            "#,
        );
        assert_eq!(s, "boolean");
    }

    #[test]
    fn max_depth_enforced_on_decode() {
        let lua = Lua::new();
        let config = crate::config::Config::builder()
            .max_json_depth(2)
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        let result: mlua::Result<LuaValue> = lua
            .load(r#"return std.json.decode('{"a":{"b":{"c":1}}}')"#)
            .eval();
        assert!(result.is_err());
    }

    #[test]
    fn max_depth_enforced_on_encode() {
        let lua = Lua::new();
        let config = crate::config::Config::builder()
            .max_json_depth(2)
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        let result: mlua::Result<LuaValue> = lua
            .load(r#"return std.json.encode({a = {b = {c = 1}}})"#)
            .eval();
        assert!(result.is_err());
    }

    #[test]
    fn encode_rejects_boolean_key() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        let result: mlua::Result<LuaValue> = lua
            .load(
                r#"
                local t = {}
                t[true] = "val"
                return std.json.encode(t)
                "#,
            )
            .eval();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("unsupported table key type"));
    }

    #[test]
    fn encode_accepts_mlua_null_sentinel() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        // `LuaValue::NULL` is a null light userdata; the encoder must emit it as JSON null.
        lua.globals().set("_null", LuaValue::NULL).unwrap();

        let s: String = lua
            .load(r#"return std.json.encode({ x = _null, y = 1 })"#)
            .eval()
            .unwrap();

        let v: serde_json::Value = serde_json::from_str(&s).unwrap();
        assert!(v["x"].is_null());
        assert_eq!(v["y"], 1);
    }

    #[test]
    fn encode_rejects_non_null_light_userdata() {
        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();

        // Only the *null* light userdata is treated as JSON null; any other pointer is rejected.
        let mut dummy = 42u8;
        let ud = LuaValue::LightUserData(mlua::LightUserData(
            &mut dummy as *mut _ as *mut std::ffi::c_void,
        ));
        lua.globals().set("_ud", ud).unwrap();

        let result: mlua::Result<String> = lua.load(r#"return std.json.encode(_ud)"#).eval();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsupported type for JSON conversion"));
    }

    #[test]
    fn read_file_and_write_file_roundtrip() {
        // Include the process id so concurrent `cargo test` invocations don't share a directory.
        let dir = std::env::temp_dir().join(format!(
            "mlua_bat_test_json_file_{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("data.json");

        let lua = Lua::new();
        crate::register_all(&lua, "std").unwrap();
        // Pass the path as a Lua value instead of splicing it into source: Windows paths
        // contain backslashes, which a Lua string literal would treat as escape sequences.
        lua.globals()
            .set("_path", path.to_string_lossy().into_owned())
            .unwrap();

        let result: mlua::Result<String> = lua
            .load(
                r#"
                std.json.write_file(_path, {name = "test", ok = true})
                local t = std.json.read_file(_path)
                return t.name
                "#,
            )
            .eval();

        // Clean up before asserting so a failure doesn't leave the directory behind.
        let _ = std::fs::remove_dir_all(&dir);
        assert_eq!(result.unwrap(), "test");
    }
}