1use crate::error::LuaError;
31use mlua::{Function, Lua, Table};
32use orcs_runtime::sandbox::SandboxPolicy;
33use std::path::{Path, PathBuf};
34use std::sync::Arc;
35
/// Factory for sandboxed Lua VMs.
///
/// Holds the sandbox policy that is handed to the `orcs` helper registration
/// for every VM it creates, plus an ordered list of directories consulted by
/// the custom `require` installed in those VMs.
#[derive(Debug, Clone)]
pub struct LuaEnv {
    /// Policy shared (via `Arc::clone`) with `register_base_orcs_functions`
    /// each time `create_lua` builds a VM.
    sandbox: Arc<dyn SandboxPolicy>,
    /// Directories searched, in insertion order, for `<name>.lua` and
    /// `<name>/init.lua` when Lua code calls `require`. Empty by default.
    search_paths: Vec<PathBuf>,
}
48
49impl LuaEnv {
50 #[must_use]
52 pub fn new(sandbox: Arc<dyn SandboxPolicy>) -> Self {
53 Self {
54 sandbox,
55 search_paths: Vec::new(),
56 }
57 }
58
59 #[must_use]
64 pub fn with_search_path(mut self, path: impl AsRef<Path>) -> Self {
65 self.search_paths.push(path.as_ref().to_path_buf());
66 self
67 }
68
69 #[must_use]
71 pub fn with_search_paths(mut self, paths: impl IntoIterator<Item = impl AsRef<Path>>) -> Self {
72 for p in paths {
73 self.search_paths.push(p.as_ref().to_path_buf());
74 }
75 self
76 }
77
78 #[must_use]
80 pub fn search_paths(&self) -> &[PathBuf] {
81 &self.search_paths
82 }
83
84 pub fn create_lua(&self) -> Result<Lua, LuaError> {
96 let lua = Lua::new();
97
98 let saved_require: Option<Function> = lua.globals().get("require").ok();
100 let saved_package: Option<Table> = lua.globals().get("package").ok();
101
102 crate::orcs_helpers::register_base_orcs_functions(&lua, Arc::clone(&self.sandbox))?;
105
106 self.setup_require(&lua, saved_require, saved_package)?;
108
109 Ok(lua)
110 }
111
112 fn setup_require(
122 &self,
123 lua: &Lua,
124 saved_require: Option<Function>,
125 saved_package: Option<Table>,
126 ) -> Result<(), LuaError> {
127 let Some(_original_require) = saved_require else {
129 tracing::warn!("Lua VM missing require function, skipping require setup");
130 return Ok(());
131 };
132 let Some(package) = saved_package else {
133 tracing::warn!("Lua VM missing package table, skipping require setup");
134 return Ok(());
135 };
136
137 package
139 .set("path", "")
140 .map_err(|e| LuaError::InvalidScript(format!("set package.path: {e}")))?;
141 package
142 .set("cpath", "")
143 .map_err(|e| LuaError::InvalidScript(format!("set package.cpath: {e}")))?;
144
145 let loaded: Table = match package.get("loaded") {
147 Ok(t) => t,
148 Err(_) => lua
149 .create_table()
150 .map_err(|e| LuaError::InvalidScript(format!("create package.loaded: {e}")))?,
151 };
152 package
153 .set("loaded", loaded)
154 .map_err(|e| LuaError::InvalidScript(format!("set package.loaded: {e}")))?;
155
156 let empty_searchers = lua
158 .create_table()
159 .map_err(|e| LuaError::InvalidScript(format!("create searchers: {e}")))?;
160 package
161 .set("searchers", empty_searchers)
162 .map_err(|e| LuaError::InvalidScript(format!("set package.searchers: {e}")))?;
163
164 lua.globals()
166 .set("package", package)
167 .map_err(|e| LuaError::InvalidScript(format!("restore package: {e}")))?;
168
169 let search_paths = self.search_paths.clone();
171
172 let custom_require = lua.create_function(move |lua, name: String| {
173 let package: Table = lua
175 .globals()
176 .get("package")
177 .map_err(|e| mlua::Error::RuntimeError(format!("package table missing: {e}")))?;
178 let loaded: Table = package
179 .get("loaded")
180 .map_err(|e| mlua::Error::RuntimeError(format!("package.loaded missing: {e}")))?;
181
182 if let Ok(cached) = loaded.get::<mlua::Value>(name.as_str()) {
183 if cached != mlua::Value::Nil {
184 return Ok(cached);
185 }
186 }
187
188 let module_rel = name.replace('.', "/");
190 for base in &search_paths {
191 let file_path = base.join(format!("{module_rel}.lua"));
193 if let Some(source) = try_read_within_base(&file_path, base) {
194 let result = eval_module(lua, &source, &name, &file_path)?;
195 loaded.set(name.as_str(), result.clone())?;
196 return Ok(result);
197 }
198
199 let init_path = base.join(&module_rel).join("init.lua");
201 if let Some(source) = try_read_within_base(&init_path, base) {
202 let result = eval_module(lua, &source, &name, &init_path)?;
203 loaded.set(name.as_str(), result.clone())?;
204 return Ok(result);
205 }
206 }
207
208 let searched: Vec<_> = search_paths
210 .iter()
211 .flat_map(|base| {
212 [
213 format!("{}/{module_rel}.lua", base.display()),
214 format!("{}/{module_rel}/init.lua", base.display()),
215 ]
216 })
217 .collect();
218
219 Err(mlua::Error::RuntimeError(format!(
220 "module '{}' not found (searched: {})",
221 name,
222 searched.join(", ")
223 )))
224 })?;
225
226 lua.globals()
227 .set("require", custom_require)
228 .map_err(|e| LuaError::InvalidScript(format!("set require: {e}")))?;
229
230 Ok(())
231 }
232}
233
/// Reads `path` as UTF-8 text, but only if it resolves to a regular file
/// located inside `base`.
///
/// Both paths are canonicalized before the containment check. This resolves
/// `..` components and symlinks, so candidates that escape `base` are
/// rejected, whether through relative segments, absolute joins or symlinks.
/// Every failure yields `None`: missing file, unresolvable base, directory,
/// or unreadable / non-UTF-8 content.
fn try_read_within_base(path: &Path, base: &Path) -> Option<String> {
    let resolved = path.canonicalize().ok()?;
    let root = base.canonicalize().ok()?;

    let readable = resolved.starts_with(&root) && resolved.is_file();
    if !readable {
        return None;
    }
    std::fs::read_to_string(resolved).ok()
}
251
252fn eval_module(
254 lua: &Lua,
255 source: &str,
256 name: &str,
257 path: &Path,
258) -> Result<mlua::Value, mlua::Error> {
259 lua.load(source)
260 .set_name(format!("{name} ({path})", path = path.display()))
261 .eval()
262 .map_err(|e| {
263 mlua::Error::RuntimeError(format!(
264 "error loading module '{name}' from {}: {e}",
265 path.display()
266 ))
267 })
268}
269
#[cfg(test)]
mod tests {
    //! Behavioural tests for `LuaEnv`: the sandboxed globals, the custom
    //! filesystem `require`, and the `orcs` helpers it registers.

    use super::*;
    use orcs_runtime::sandbox::ProjectSandbox;

    /// Sandbox rooted at the current directory, for tests that do not load
    /// modules from disk.
    fn test_policy() -> Arc<dyn SandboxPolicy> {
        Arc::new(ProjectSandbox::new(".").expect("test sandbox"))
    }

    /// `create_lua` yields a VM with the `orcs` helper table installed.
    #[test]
    fn create_lua_returns_working_vm() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua should succeed");

        let result: String = lua
            .load(r#"return type(orcs.log)"#)
            .eval()
            .expect("orcs.log should exist");
        assert_eq!(result, "function");
    }

    /// The custom `require` is installed as a global function.
    #[test]
    fn create_lua_has_require() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(r#"return type(require)"#)
            .eval()
            .expect("require should exist");
        assert_eq!(result, "function");
    }

    /// The `package` table is re-published after sandboxing.
    #[test]
    fn create_lua_has_package() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(r#"return type(package)"#)
            .eval()
            .expect("package should exist");
        assert_eq!(result, "table");
    }

    /// The `io` library is not reachable from scripts.
    #[test]
    fn sandbox_io_removed() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua.load(r#"return type(io)"#).eval().expect("eval");
        assert_eq!(result, "nil");
    }

    /// `loadfile` is not reachable from scripts.
    #[test]
    fn sandbox_loadfile_removed() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua.load(r#"return type(loadfile)"#).eval().expect("eval");
        assert_eq!(result, "nil");
    }

    /// The `debug` library is not reachable from scripts.
    #[test]
    fn sandbox_debug_removed() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua.load(r#"return type(debug)"#).eval().expect("eval");
        assert_eq!(result, "nil");
    }

    /// Native-module search path is cleared.
    #[test]
    fn sandbox_cpath_empty() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua.load(r#"return package.cpath"#).eval().expect("eval");
        assert_eq!(result, "");
    }

    /// Stock Lua search path is cleared.
    #[test]
    fn sandbox_path_empty() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua.load(r#"return package.path"#).eval().expect("eval");
        assert_eq!(result, "");
    }

    /// `os.execute` is stripped from the `os` table.
    #[test]
    fn sandbox_os_execute_removed() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(r#"return type(os.execute)"#)
            .eval()
            .expect("eval os.execute type");
        assert_eq!(result, "nil", "os.execute must be removed");
    }

    /// `os.remove` is stripped from the `os` table.
    #[test]
    fn sandbox_os_remove_removed() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(r#"return type(os.remove)"#)
            .eval()
            .expect("eval os.remove type");
        assert_eq!(result, "nil", "os.remove must be removed");
    }

    /// `dofile` is not reachable from scripts.
    #[test]
    fn sandbox_dofile_removed() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(r#"return type(dofile)"#)
            .eval()
            .expect("eval dofile type");
        assert_eq!(result, "nil", "dofile must be removed");
    }

    /// `load` is not reachable from scripts.
    #[test]
    fn sandbox_load_removed() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(r#"return type(load)"#)
            .eval()
            .expect("eval load type");
        assert_eq!(result, "nil", "load must be removed");
    }

    /// Harmless `os` functions survive the sandboxing.
    #[test]
    fn sandbox_os_safe_functions_preserved() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(r#"return type(os.time)"#)
            .eval()
            .expect("eval os.time type");
        assert_eq!(result, "function", "os.time must be preserved");

        let result: String = lua
            .load(r#"return type(os.clock)"#)
            .eval()
            .expect("eval os.clock type");
        assert_eq!(result, "function", "os.clock must be preserved");
    }

    /// Requiring an unknown module fails with a "not found" error.
    #[test]
    fn require_nonexistent_errors() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result = lua.load(r#"require("nonexistent_module_xyz")"#).exec();
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("not found"),
            "error should say 'not found', got: {err_str}"
        );
    }

    /// Dotted module names map to nested `<dir>/<name>.lua` files.
    #[test]
    fn require_filesystem_module() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let lib_dir = dir.path().join("lib");
        std::fs::create_dir_all(&lib_dir).expect("create lib dir");
        std::fs::write(
            lib_dir.join("helper.lua"),
            r#"
            local M = {}
            function M.greet() return "hello from helper" end
            return M
            "#,
        )
        .expect("write helper.lua");

        let sandbox = Arc::new(ProjectSandbox::new(dir.path()).expect("sandbox for tempdir"));
        let env = LuaEnv::new(sandbox).with_search_path(dir.path());

        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(
                r#"
                local helper = require("lib.helper")
                return helper.greet()
                "#,
            )
            .eval()
            .expect("require filesystem module");
        assert_eq!(result, "hello from helper");
    }

    /// A directory module is loaded from its `init.lua`.
    #[test]
    fn require_filesystem_init_lua() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let mod_dir = dir.path().join("mymod");
        std::fs::create_dir_all(&mod_dir).expect("create mymod dir");
        std::fs::write(mod_dir.join("init.lua"), r#"return { name = "mymod" }"#)
            .expect("write init.lua");

        let sandbox = Arc::new(ProjectSandbox::new(dir.path()).expect("sandbox"));
        let env = LuaEnv::new(sandbox).with_search_path(dir.path());

        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(
                r#"
                local m = require("mymod")
                return m.name
                "#,
            )
            .eval()
            .expect("require init.lua");
        assert_eq!(result, "mymod");
    }

    /// A second `require` returns the cached value without re-running the file.
    #[test]
    fn require_filesystem_is_cached() {
        let dir = tempfile::tempdir().expect("create temp dir");
        std::fs::write(
            dir.path().join("counter.lua"),
            r#"
            _counter = (_counter or 0) + 1
            return { count = _counter }
            "#,
        )
        .expect("write counter.lua");

        let sandbox = Arc::new(ProjectSandbox::new(dir.path()).expect("sandbox"));
        let env = LuaEnv::new(sandbox).with_search_path(dir.path());

        let lua = env.create_lua().expect("create_lua");

        // Inspect the second result and the global load counter: the first
        // result's `count` is 1 regardless of caching, so it proves nothing.
        let (same_table, loads): (bool, i64) = lua
            .load(
                r#"
                local a = require("counter")
                local b = require("counter")
                return a == b, _counter
                "#,
            )
            .eval()
            .expect("require should cache");
        assert!(same_table, "second require should return the cached table");
        assert_eq!(loads, 1, "module should be loaded only once");
    }

    /// When several search paths contain the module, the first one wins.
    #[test]
    fn search_path_priority_first_wins() {
        let dir1 = tempfile::tempdir().expect("create temp dir 1");
        let dir2 = tempfile::tempdir().expect("create temp dir 2");

        std::fs::write(
            dir1.path().join("shared.lua"),
            r#"return { source = "dir1" }"#,
        )
        .expect("write shared.lua to dir1");
        std::fs::write(
            dir2.path().join("shared.lua"),
            r#"return { source = "dir2" }"#,
        )
        .expect("write shared.lua to dir2");

        let sandbox = Arc::new(ProjectSandbox::new(dir1.path()).expect("sandbox"));
        let env = LuaEnv::new(sandbox)
            .with_search_path(dir1.path())
            .with_search_path(dir2.path());

        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(
                r#"
                local m = require("shared")
                return m.source
                "#,
            )
            .eval()
            .expect("require shared");
        assert_eq!(result, "dir1", "first search path should win");
    }

    /// A top-level module file is loaded from the search path.
    #[test]
    fn require_filesystem_module_loads() {
        let dir = tempfile::tempdir().expect("create temp dir");
        std::fs::write(
            dir.path().join("my_module.lua"),
            r#"return { source = "filesystem" }"#,
        )
        .expect("write module");

        let sandbox = Arc::new(ProjectSandbox::new(dir.path()).expect("sandbox"));
        let env = LuaEnv::new(sandbox).with_search_path(dir.path());

        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(
                r#"
                local m = require("my_module")
                return m.source
                "#,
            )
            .eval()
            .expect("require should load from filesystem");
        assert_eq!(result, "filesystem");
    }

    /// The "not found" error names the module and both probed candidates.
    #[test]
    fn require_error_lists_searched_paths() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let sandbox = Arc::new(ProjectSandbox::new(dir.path()).expect("sandbox"));
        let env = LuaEnv::new(sandbox).with_search_path(dir.path());

        let lua = env.create_lua().expect("create_lua");

        let result = lua.load(r#"require("missing")"#).exec();
        let err_str = result.unwrap_err().to_string();
        assert!(err_str.contains("missing"), "should contain module name");
        assert!(
            err_str.contains("not found"),
            "should say not found: {err_str}"
        );
        assert!(
            err_str.contains("missing.lua"),
            "should list the <name>.lua candidate: {err_str}"
        );
        assert!(
            err_str.contains("init.lua"),
            "should list the <name>/init.lua candidate: {err_str}"
        );
    }

    /// `with_search_paths` appends every supplied path.
    #[test]
    fn with_search_paths_batch() {
        let env = LuaEnv::new(test_policy()).with_search_paths(["/a", "/b", "/c"]);
        assert_eq!(env.search_paths().len(), 3);
    }

    /// A new environment has no search paths.
    #[test]
    fn search_paths_empty_by_default() {
        let env = LuaEnv::new(test_policy());
        assert!(env.search_paths().is_empty());
    }

    /// `orcs.log` is callable from Lua.
    #[test]
    fn orcs_log_works() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        lua.load(r#"orcs.log("info", "hello from LuaEnv")"#)
            .exec()
            .expect("orcs.log should work");
    }

    /// `orcs.pwd` is a non-empty string.
    #[test]
    fn orcs_pwd_works() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: String = lua
            .load(r#"return orcs.pwd"#)
            .eval()
            .expect("orcs.pwd should work");
        assert!(!result.is_empty(), "pwd should not be empty");
    }

    /// `orcs.json_parse` decodes JSON into Lua tables.
    #[test]
    fn orcs_json_parse_works() {
        let env = LuaEnv::new(test_policy());
        let lua = env.create_lua().expect("create_lua");

        let result: i64 = lua
            .load(r#"return orcs.json_parse('{"x":42}').x"#)
            .eval()
            .expect("json_parse should work");
        assert_eq!(result, 42);
    }
}