1use crate::error::LuaError;
26use mlua::{Function, Lua, LuaSerdeExt, Table};
27use orcs_runtime::sandbox::SandboxPolicy;
28use std::path::Path;
29use std::sync::Arc;
30
31pub(crate) fn tool_read(path: &str, sandbox: &dyn SandboxPolicy) -> Result<(String, u64), String> {
35 let canonical = sandbox.validate_read(path).map_err(|e| e.to_string())?;
36
37 let metadata =
38 std::fs::metadata(&canonical).map_err(|e| format!("cannot read metadata: {path} ({e})"))?;
39
40 if !metadata.is_file() {
41 return Err(format!("not a file: {path}"));
42 }
43
44 let size = metadata.len();
45 let content =
46 std::fs::read_to_string(&canonical).map_err(|e| format!("read failed: {path} ({e})"))?;
47
48 Ok((content, size))
49}
50
51pub(crate) fn tool_write(
57 path: &str,
58 content: &str,
59 sandbox: &dyn SandboxPolicy,
60) -> Result<usize, String> {
61 let target = sandbox.validate_write(path).map_err(|e| e.to_string())?;
62
63 let parent = target
65 .parent()
66 .ok_or_else(|| format!("cannot determine parent directory: {path}"))?;
67 std::fs::create_dir_all(parent).map_err(|e| format!("cannot create parent directory: {e}"))?;
68
69 let bytes = content.len();
70
71 let mut temp = tempfile::NamedTempFile::new_in(parent)
74 .map_err(|e| format!("temp file creation failed: {path} ({e})"))?;
75
76 use std::io::Write;
77 temp.write_all(content.as_bytes())
78 .map_err(|e| format!("write failed: {path} ({e})"))?;
79
80 temp.persist(&target)
81 .map_err(|e| format!("rename failed: {path} ({e})"))?;
82
83 Ok(bytes)
84}
85
/// A single line that matched a grep pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct GrepMatch {
    /// 1-based line number of the match within its file.
    pub(crate) line_number: usize,
    /// Text of the matching line, without the line terminator.
    pub(crate) line: String,
}
92
/// Deepest directory nesting level `grep_dir` will descend into; directories
/// beyond this depth are skipped (with a debug log) rather than reported as
/// an error.
const MAX_GREP_DEPTH: usize = 32;

/// Upper bound on the number of matches collected by a single grep call;
/// scanning stops once this many matches have been gathered.
const MAX_GREP_MATCHES: usize = 10_000;
98
99pub(crate) fn tool_grep(
106 pattern: &str,
107 path: &str,
108 sandbox: &dyn SandboxPolicy,
109) -> Result<Vec<GrepMatch>, String> {
110 let re = regex::Regex::new(pattern).map_err(|e| format!("invalid regex: {pattern} ({e})"))?;
111
112 let canonical = sandbox.validate_read(path).map_err(|e| e.to_string())?;
113 let mut matches = Vec::new();
114
115 let sandbox_root = sandbox.root();
116 if canonical.is_file() {
117 grep_file(&re, &canonical, &mut matches)?;
118 } else if canonical.is_dir() {
119 grep_dir(&re, &canonical, sandbox_root, &mut matches, 0)?;
120 } else {
121 return Err(format!("not a file or directory: {path}"));
122 }
123
124 Ok(matches)
125}
126
127fn grep_file(re: ®ex::Regex, path: &Path, matches: &mut Vec<GrepMatch>) -> Result<(), String> {
128 let content =
129 std::fs::read_to_string(path).map_err(|e| format!("read failed: {:?} ({e})", path))?;
130
131 for (i, line) in content.lines().enumerate() {
132 if matches.len() >= MAX_GREP_MATCHES {
133 break;
134 }
135 if re.is_match(line) {
136 matches.push(GrepMatch {
137 line_number: i + 1,
138 line: line.to_string(),
139 });
140 }
141 }
142
143 Ok(())
144}
145
146fn grep_dir(
155 re: ®ex::Regex,
156 dir: &Path,
157 sandbox_root: &Path,
158 matches: &mut Vec<GrepMatch>,
159 depth: usize,
160) -> Result<(), String> {
161 if depth > MAX_GREP_DEPTH {
162 tracing::debug!("grep: max depth ({MAX_GREP_DEPTH}) reached at {:?}", dir);
163 return Ok(());
164 }
165 if matches.len() >= MAX_GREP_MATCHES {
166 return Ok(());
167 }
168
169 let entries =
170 std::fs::read_dir(dir).map_err(|e| format!("cannot read directory: {:?} ({e})", dir))?;
171
172 for entry in entries.flatten() {
173 if matches.len() >= MAX_GREP_MATCHES {
174 break;
175 }
176
177 let path = entry.path();
178
179 let canonical = match path.canonicalize() {
181 Ok(c) if c.starts_with(sandbox_root) => c,
182 _ => continue, };
184
185 if canonical.is_file() {
186 let is_binary = {
188 use std::io::Read;
189 match std::fs::File::open(&canonical) {
190 Ok(mut file) => {
191 let mut buf = [0u8; 512];
192 match file.read(&mut buf) {
193 Ok(n) => buf[..n].contains(&0),
194 Err(_) => true, }
196 }
197 Err(_) => true, }
199 };
200 if is_binary {
201 continue;
202 }
203 if let Err(e) = grep_file(re, &canonical, matches) {
204 tracing::debug!("grep: skip {:?}: {e}", canonical);
205 }
206 } else if canonical.is_dir() {
207 if let Err(e) = grep_dir(re, &canonical, sandbox_root, matches, depth + 1) {
208 tracing::debug!("grep: skip dir {:?}: {e}", canonical);
209 }
210 }
211 }
212
213 Ok(())
214}
215
216pub(crate) fn tool_glob(
221 pattern: &str,
222 dir: Option<&str>,
223 sandbox: &dyn SandboxPolicy,
224) -> Result<Vec<String>, String> {
225 if pattern.contains("..") {
227 return Err("glob pattern must not contain '..'".to_string());
228 }
229
230 let full_pattern = match dir {
231 Some(d) => {
232 let base = sandbox.validate_read(d).map_err(|e| e.to_string())?;
233 if !base.is_dir() {
234 return Err(format!("not a directory: {d}"));
235 }
236 format!("{}/{pattern}", base.display())
237 }
238 None => {
239 format!("{}/{pattern}", sandbox.root().display())
240 }
241 };
242
243 let paths =
244 glob::glob(&full_pattern).map_err(|e| format!("invalid glob pattern: {pattern} ({e})"))?;
245
246 let sandbox_root = sandbox.root();
247 let mut results = Vec::new();
248 for entry in paths.flatten() {
249 if let Ok(canonical) = entry.canonicalize() {
251 if canonical.starts_with(sandbox_root) {
252 results.push(canonical.display().to_string());
253 }
254 }
255 }
256
257 results.sort();
258 Ok(results)
259}
260
261pub(crate) fn tool_mkdir(path: &str, sandbox: &dyn SandboxPolicy) -> Result<(), String> {
265 let target = sandbox.validate_write(path).map_err(|e| e.to_string())?;
266 std::fs::create_dir_all(&target).map_err(|e| format!("mkdir failed: {path} ({e})"))
267}
268
269pub(crate) fn tool_remove(path: &str, sandbox: &dyn SandboxPolicy) -> Result<(), String> {
274 sandbox.validate_write(path).map_err(|e| e.to_string())?;
276 let canonical = sandbox.validate_read(path).map_err(|e| e.to_string())?;
278
279 if canonical.is_file() {
280 std::fs::remove_file(&canonical).map_err(|e| format!("remove failed: {path} ({e})"))
281 } else if canonical.is_dir() {
282 std::fs::remove_dir_all(&canonical).map_err(|e| format!("remove failed: {path} ({e})"))
283 } else {
284 Err(format!("not found: {path}"))
285 }
286}
287
288pub(crate) fn tool_mv(src: &str, dst: &str, sandbox: &dyn SandboxPolicy) -> Result<(), String> {
293 let src_canonical = sandbox.validate_read(src).map_err(|e| e.to_string())?;
294 let dst_target = sandbox.validate_write(dst).map_err(|e| e.to_string())?;
295
296 if let Some(parent) = dst_target.parent() {
298 std::fs::create_dir_all(parent)
299 .map_err(|e| format!("cannot create parent directory: {e}"))?;
300 }
301
302 std::fs::rename(&src_canonical, &dst_target)
303 .map_err(|e| format!("mv failed: {src} -> {dst} ({e})"))
304}
305
/// One file-system entry reported by `tool_scan_dir`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ScanEntry {
    /// Full path of the entry as produced by the directory walk.
    pub path: String,
    /// Path relative to the scanned base directory.
    pub relative: String,
    /// `true` when the entry is a directory.
    pub is_dir: bool,
    /// Size in bytes from the entry metadata; `0` when metadata is unavailable.
    pub size: u64,
    /// Modification time in whole seconds since the Unix epoch; `0` when unavailable.
    pub modified: u64,
}
318
319pub(crate) fn tool_scan_dir(
321 path: &str,
322 recursive: bool,
323 exclude: &[String],
324 include: &[String],
325 max_depth: Option<usize>,
326 sandbox: &dyn SandboxPolicy,
327) -> Result<Vec<ScanEntry>, String> {
328 let base = sandbox.validate_read(path).map_err(|e| e.to_string())?;
329
330 if !base.is_dir() {
331 return Err(format!("not a directory: {path}"));
332 }
333
334 let exclude_set = build_glob_set(exclude)?;
335 let include_set = if include.is_empty() {
336 None
337 } else {
338 Some(build_glob_set(include)?)
339 };
340
341 let mut walker = walkdir::WalkDir::new(&base);
342 if !recursive {
343 walker = walker.max_depth(1);
344 } else if let Some(depth) = max_depth {
345 walker = walker.max_depth(depth);
346 }
347
348 let mut entries = Vec::new();
349 for entry in walker.into_iter().filter_map(|e| e.ok()) {
350 if entry.path() == base {
351 continue;
352 }
353
354 let relative = entry
355 .path()
356 .strip_prefix(&base)
357 .unwrap_or(entry.path())
358 .to_string_lossy()
359 .to_string();
360
361 if exclude_set.is_match(&relative) {
362 continue;
363 }
364
365 let is_dir = entry.file_type().is_dir();
366
367 if !is_dir {
368 if let Some(ref inc) = include_set {
369 if !inc.is_match(&relative) {
370 continue;
371 }
372 }
373 }
374
375 let metadata = entry.metadata().ok();
376 let size = metadata.as_ref().map_or(0, |m| m.len());
377 let modified = metadata
378 .and_then(|m| m.modified().ok())
379 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
380 .map_or(0, |d| d.as_secs());
381
382 entries.push(ScanEntry {
383 path: entry.path().to_string_lossy().to_string(),
384 relative,
385 is_dir,
386 size,
387 modified,
388 });
389 }
390
391 Ok(entries)
392}
393
394fn build_glob_set(patterns: &[String]) -> Result<globset::GlobSet, String> {
395 let mut builder = globset::GlobSetBuilder::new();
396 for pattern in patterns {
397 let glob =
398 globset::Glob::new(pattern).map_err(|e| format!("invalid glob '{pattern}': {e}"))?;
399 builder.add(glob);
400 }
401 builder
402 .build()
403 .map_err(|e| format!("glob set build error: {e}"))
404}
405
406pub(crate) struct FrontmatterResult {
410 pub frontmatter: Option<serde_json::Value>,
411 pub body: String,
412 pub format: Option<String>,
413}
414
415pub(crate) fn tool_parse_frontmatter_str(content: &str) -> Result<FrontmatterResult, String> {
419 let trimmed = content.trim_start();
420
421 if let Some(rest) = trimmed.strip_prefix("---") {
422 if let Some(end_idx) = rest.find("\n---") {
424 let yaml_str = &rest[..end_idx];
425 let body_start = end_idx + 4; let body = rest[body_start..].trim_start_matches('\n').to_string();
427
428 let value: serde_json::Value =
429 serde_yaml::from_str(yaml_str).map_err(|e| format!("YAML parse error: {e}"))?;
430
431 Ok(FrontmatterResult {
432 frontmatter: Some(value),
433 body,
434 format: Some("yaml".to_string()),
435 })
436 } else {
437 Ok(FrontmatterResult {
438 frontmatter: None,
439 body: content.to_string(),
440 format: None,
441 })
442 }
443 } else if let Some(rest) = trimmed.strip_prefix("+++") {
444 if let Some(end_idx) = rest.find("\n+++") {
446 let toml_str = &rest[..end_idx];
447 let body_start = end_idx + 4;
448 let body = rest[body_start..].trim_start_matches('\n').to_string();
449
450 let toml_value: toml::Value = toml_str
451 .parse()
452 .map_err(|e| format!("TOML parse error: {e}"))?;
453 let json_value = toml_to_json(toml_value);
454
455 Ok(FrontmatterResult {
456 frontmatter: Some(json_value),
457 body,
458 format: Some("toml".to_string()),
459 })
460 } else {
461 Ok(FrontmatterResult {
462 frontmatter: None,
463 body: content.to_string(),
464 format: None,
465 })
466 }
467 } else {
468 Ok(FrontmatterResult {
469 frontmatter: None,
470 body: content.to_string(),
471 format: None,
472 })
473 }
474}
475
476pub(crate) fn tool_parse_frontmatter(
478 path: &str,
479 sandbox: &dyn SandboxPolicy,
480) -> Result<FrontmatterResult, String> {
481 let canonical = sandbox.validate_read(path).map_err(|e| e.to_string())?;
482 let content =
483 std::fs::read_to_string(&canonical).map_err(|e| format!("read failed: {path} ({e})"))?;
484 tool_parse_frontmatter_str(&content)
485}
486
487pub(crate) fn tool_parse_toml(content: &str) -> Result<serde_json::Value, String> {
491 let toml_value: toml::Value = content
492 .parse()
493 .map_err(|e| format!("TOML parse error: {e}"))?;
494 Ok(toml_to_json(toml_value))
495}
496
497fn toml_to_json(value: toml::Value) -> serde_json::Value {
498 match value {
499 toml::Value::String(s) => serde_json::Value::String(s),
500 toml::Value::Integer(i) => serde_json::json!(i),
501 toml::Value::Float(f) => serde_json::json!(f),
502 toml::Value::Boolean(b) => serde_json::Value::Bool(b),
503 toml::Value::Datetime(d) => serde_json::Value::String(d.to_string()),
504 toml::Value::Array(arr) => {
505 serde_json::Value::Array(arr.into_iter().map(toml_to_json).collect())
506 }
507 toml::Value::Table(map) => {
508 let obj = map.into_iter().map(|(k, v)| (k, toml_to_json(v))).collect();
509 serde_json::Value::Object(obj)
510 }
511 }
512}
513
/// Paths partitioned by whether they matched any pattern of a glob set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct GlobMatchResult {
    /// Paths that matched at least one pattern, in input order.
    pub matched: Vec<String>,
    /// Paths that matched no pattern, in input order.
    pub unmatched: Vec<String>,
}
521
522pub(crate) fn tool_glob_match(
524 patterns: &[String],
525 paths: &[String],
526) -> Result<GlobMatchResult, String> {
527 let glob_set = build_glob_set(patterns)?;
528
529 let mut matched = Vec::new();
530 let mut unmatched = Vec::new();
531
532 for path in paths {
533 if glob_set.is_match(path) {
534 matched.push(path.clone());
535 } else {
536 unmatched.push(path.clone());
537 }
538 }
539
540 Ok(GlobMatchResult { matched, unmatched })
541}
542
543pub(crate) fn tool_load_lua(
553 lua: &Lua,
554 content: &str,
555 source_name: &str,
556) -> Result<mlua::Value, String> {
557 let env = lua
559 .create_table()
560 .map_err(|e| format!("env creation failed: {e}"))?;
561
562 let globals = lua.globals();
563
564 let safe_globals = [
566 "table",
567 "string",
568 "math",
569 "pairs",
570 "ipairs",
571 "next",
572 "type",
573 "tostring",
574 "tonumber",
575 "select",
576 "unpack",
577 "error",
578 "pcall",
579 "xpcall",
580 "rawget",
581 "rawset",
582 "rawequal",
583 "rawlen",
584 "setmetatable",
585 "getmetatable",
586 ];
587
588 for name in &safe_globals {
589 if let Ok(val) = globals.get::<mlua::Value>(*name) {
590 env.set(*name, val)
591 .map_err(|e| format!("env.{name}: {e}"))?;
592 }
593 }
594
595 let src = source_name.to_string();
597 let print_fn = lua
598 .create_function(move |_, args: mlua::MultiValue| {
599 let parts: Vec<String> = args.iter().map(|v| format!("{v:?}")).collect();
600 tracing::info!(source = %src, "[lua-sandbox] {}", parts.join("\t"));
601 Ok(())
602 })
603 .map_err(|e| format!("print fn: {e}"))?;
604 env.set("print", print_fn)
605 .map_err(|e| format!("env.print: {e}"))?;
606
607 let chunk = lua.load(content).set_name(source_name);
609
610 chunk
611 .set_environment(env)
612 .eval::<mlua::Value>()
613 .map_err(|e| format!("{source_name}: {e}"))
614}
615
616pub fn register_tool_functions(lua: &Lua, sandbox: Arc<dyn SandboxPolicy>) -> Result<(), LuaError> {
627 let orcs_table: Table = lua.globals().get("orcs")?;
628
629 lua.set_app_data(Arc::clone(&sandbox));
632
633 {
643 use crate::tool_registry::{dispatch_rust_tool, ensure_registry};
644
645 let registry = ensure_registry(lua)?;
646
647 fn get_tool(
650 registry: &crate::tool_registry::IntentRegistry,
651 name: &str,
652 ) -> Option<std::sync::Arc<dyn orcs_component::RustTool>> {
653 match registry.get_tool(name) {
654 Some(t) => Some(std::sync::Arc::clone(t)),
655 None => {
656 debug_assert!(false, "builtin tool '{name}' missing from IntentRegistry");
657 tracing::error!(
658 "builtin tool '{name}' missing from IntentRegistry — orcs.{name}() unavailable"
659 );
660 None
661 }
662 }
663 }
664
665 fn register_deny_stub(
669 lua: &Lua,
670 orcs_table: &Table,
671 name: &str,
672 ) -> Result<(), mlua::Error> {
673 let tool_name = name.to_string();
674 let f = lua.create_function(move |lua, _args: mlua::MultiValue| {
675 let result = lua.create_table()?;
676 result.set("ok", false)?;
677 result.set("error", format!("tool unavailable: {tool_name}"))?;
678 Ok(result)
679 })?;
680 orcs_table.set(name, f)?;
681 Ok(())
682 }
683
684 macro_rules! register_wrapper {
687 ($tool:expr, $name:literal, |$arg:ident: String|) => {
689 if let Some(t) = $tool {
690 let f = lua.create_function(move |lua, $arg: String| {
691 let args = lua.create_table()?;
692 args.set(stringify!($arg), $arg)?;
693 dispatch_rust_tool(lua, &*t, &args)
694 })?;
695 orcs_table.set($name, f)?;
696 } else {
697 register_deny_stub(lua, &orcs_table, $name)?;
698 }
699 };
700 ($tool:expr, $name:literal, |$a1:ident: String, $a2:ident: String|) => {
702 if let Some(t) = $tool {
703 let f = lua.create_function(move |lua, ($a1, $a2): (String, String)| {
704 let args = lua.create_table()?;
705 args.set(stringify!($a1), $a1)?;
706 args.set(stringify!($a2), $a2)?;
707 dispatch_rust_tool(lua, &*t, &args)
708 })?;
709 orcs_table.set($name, f)?;
710 } else {
711 register_deny_stub(lua, &orcs_table, $name)?;
712 }
713 };
714 ($tool:expr, $name:literal, |$a1:ident: String, $a2:ident: Option<String>|) => {
716 if let Some(t) = $tool {
717 let f =
718 lua.create_function(move |lua, ($a1, opt): (String, Option<String>)| {
719 let args = lua.create_table()?;
720 args.set(stringify!($a1), $a1)?;
721 if let Some(v) = opt {
722 args.set(stringify!($a2), v)?;
723 }
724 dispatch_rust_tool(lua, &*t, &args)
725 })?;
726 orcs_table.set($name, f)?;
727 } else {
728 register_deny_stub(lua, &orcs_table, $name)?;
729 }
730 };
731 }
732
733 let read_tool = get_tool(®istry, "read");
734 let write_tool = get_tool(®istry, "write");
735 let grep_tool = get_tool(®istry, "grep");
736 let glob_tool = get_tool(®istry, "glob");
737 let mkdir_tool = get_tool(®istry, "mkdir");
738 let remove_tool = get_tool(®istry, "remove");
739 let mv_tool = get_tool(®istry, "mv");
740 drop(registry);
741
742 register_wrapper!(read_tool, "read", |path: String|);
743 register_wrapper!(write_tool, "write", |path: String, content: String|);
744 register_wrapper!(grep_tool, "grep", |pattern: String, path: String|);
745 register_wrapper!(glob_tool, "glob", |pattern: String, dir: Option<String>|);
746 register_wrapper!(mkdir_tool, "mkdir", |path: String|);
747 register_wrapper!(remove_tool, "remove", |path: String|);
748 register_wrapper!(mv_tool, "mv", |src: String, dst: String|);
749 }
750
751 let sb = Arc::clone(&sandbox);
753 let scan_dir_fn = lua.create_function(move |lua, config: Table| {
754 let path: String = config.get("path")?;
755 let recursive: bool = config.get("recursive").unwrap_or(true);
756 let max_depth: Option<usize> = config.get("max_depth").ok();
757
758 let exclude: Vec<String> = config
759 .get::<Table>("exclude")
760 .map(|t| {
761 t.sequence_values::<String>()
762 .filter_map(|v| v.ok())
763 .collect()
764 })
765 .unwrap_or_default();
766
767 let include: Vec<String> = config
768 .get::<Table>("include")
769 .map(|t| {
770 t.sequence_values::<String>()
771 .filter_map(|v| v.ok())
772 .collect()
773 })
774 .unwrap_or_default();
775
776 match tool_scan_dir(&path, recursive, &exclude, &include, max_depth, sb.as_ref()) {
777 Ok(entries) => {
778 let result = lua.create_table()?;
779 for (i, entry) in entries.iter().enumerate() {
780 let t = lua.create_table()?;
781 t.set("path", entry.path.as_str())?;
782 t.set("relative", entry.relative.as_str())?;
783 t.set("is_dir", entry.is_dir)?;
784 t.set("size", entry.size)?;
785 t.set("modified", entry.modified)?;
786 result.set(i + 1, t)?;
787 }
788 Ok(result)
789 }
790 Err(e) => Err(mlua::Error::RuntimeError(e)),
791 }
792 })?;
793 orcs_table.set("scan_dir", scan_dir_fn)?;
794
795 let sb = Arc::clone(&sandbox);
797 let parse_fm_fn =
798 lua.create_function(move |lua, path: String| {
799 match tool_parse_frontmatter(&path, sb.as_ref()) {
800 Ok(result) => frontmatter_result_to_lua(lua, result),
801 Err(e) => {
802 let t = lua.create_table()?;
803 t.set("ok", false)?;
804 t.set("error", e)?;
805 Ok(t)
806 }
807 }
808 })?;
809 orcs_table.set("parse_frontmatter", parse_fm_fn)?;
810
811 let parse_fm_str_fn = lua.create_function(move |lua, content: String| {
813 match tool_parse_frontmatter_str(&content) {
814 Ok(result) => frontmatter_result_to_lua(lua, result),
815 Err(e) => {
816 let t = lua.create_table()?;
817 t.set("ok", false)?;
818 t.set("error", e)?;
819 Ok(t)
820 }
821 }
822 })?;
823 orcs_table.set("parse_frontmatter_str", parse_fm_str_fn)?;
824
825 let parse_toml_fn =
827 lua.create_function(
828 move |lua, content: String| match tool_parse_toml(&content) {
829 Ok(value) => lua.to_value(&value).map_err(|e| {
830 mlua::Error::RuntimeError(format!("TOML to Lua conversion failed: {e}"))
831 }),
832 Err(e) => Err(mlua::Error::RuntimeError(e)),
833 },
834 )?;
835 orcs_table.set("parse_toml", parse_toml_fn)?;
836
837 let glob_match_fn =
839 lua.create_function(move |lua, (patterns_tbl, paths_tbl): (Table, Table)| {
840 let patterns: Vec<String> = patterns_tbl
841 .sequence_values::<String>()
842 .filter_map(|v| v.ok())
843 .collect();
844 let paths: Vec<String> = paths_tbl
845 .sequence_values::<String>()
846 .filter_map(|v| v.ok())
847 .collect();
848
849 match tool_glob_match(&patterns, &paths) {
850 Ok(result) => {
851 let t = lua.create_table()?;
852
853 let matched = lua.create_table()?;
854 for (i, m) in result.matched.iter().enumerate() {
855 matched.set(i + 1, m.as_str())?;
856 }
857 t.set("matched", matched)?;
858
859 let unmatched = lua.create_table()?;
860 for (i, u) in result.unmatched.iter().enumerate() {
861 unmatched.set(i + 1, u.as_str())?;
862 }
863 t.set("unmatched", unmatched)?;
864
865 Ok(t)
866 }
867 Err(e) => Err(mlua::Error::RuntimeError(e)),
868 }
869 })?;
870 orcs_table.set("glob_match", glob_match_fn)?;
871
872 let load_lua_fn = lua.create_function(
874 move |lua, (content, source_name): (String, Option<String>)| {
875 let name = source_name.as_deref().unwrap_or("(eval)");
876 tool_load_lua(lua, &content, name).map_err(mlua::Error::RuntimeError)
877 },
878 )?;
879 orcs_table.set("load_lua", load_lua_fn)?;
880
881 tracing::debug!(
882 "Registered orcs tool functions: read, write, grep, glob, mkdir, remove, mv, scan_dir, parse_frontmatter, parse_toml, glob_match, load_lua (sandbox_root={})",
883 sandbox.root().display()
884 );
885 Ok(())
886}
887
888fn frontmatter_result_to_lua(lua: &Lua, result: FrontmatterResult) -> Result<Table, mlua::Error> {
890 let t = lua.create_table()?;
891 match result.frontmatter {
892 Some(fm) => {
893 let lua_fm = lua.to_value(&fm)?;
894 t.set("frontmatter", lua_fm)?;
895 }
896 None => t.set("frontmatter", mlua::Value::Nil)?,
897 }
898 t.set("body", result.body)?;
899 match result.format {
900 Some(f) => t.set("format", f)?,
901 None => t.set("format", mlua::Value::Nil)?,
902 }
903 Ok(t)
904}
905
/// Hook state read by the tool hook dispatcher.
///
/// `wrap_tools_with_hooks` looks this value up in Lua app data on every
/// dispatch; when it is absent, hook dispatch does nothing and the wrapped
/// tool runs normally.
pub(crate) struct ToolHookContext {
    /// Shared hook registry that tool hook points are dispatched against.
    pub(crate) registry: orcs_hook::SharedHookRegistry,
    /// Component id passed to the registry and embedded in each hook context.
    pub(crate) component_id: orcs_types::ComponentId,
}
917
/// Names of the `orcs.*` functions that `wrap_tools_with_hooks` wraps with
/// pre/post hook dispatch. Functions not listed here are left untouched.
const HOOKABLE_TOOLS: &[&str] = &[
    "read",
    "write",
    "grep",
    "glob",
    "mkdir",
    "remove",
    "mv",
    "scan_dir",
    "parse_frontmatter",
];
930
/// Wraps every function named in `HOOKABLE_TOOLS` so that hooks run before
/// and after the tool.
///
/// Installs `orcs._dispatch_tool_hook(phase, tool_name, args)` and then, for
/// each hookable tool currently present as a function on `orcs`, replaces it
/// with a Lua wrapper that:
///
/// 1. dispatches the `"pre"` hook with the call arguments packed into a
///    table, returning the hook's value instead of calling the tool when that
///    value is non-nil;
/// 2. calls the original tool (only its first return value is kept);
/// 3. dispatches the `"post"` hook with that result, returning the hook's
///    value when non-nil and the original result otherwise.
///
/// The dispatcher returns `nil` ("no override") when no `ToolHookContext` is
/// stored in Lua app data or when `phase` is neither `"pre"` nor `"post"`.
///
/// # Errors
///
/// Returns an error when the `orcs` global is missing, when the dispatcher
/// cannot be created or registered, or when a wrapper chunk fails to execute
/// (`LuaError::InvalidScript`).
pub(crate) fn wrap_tools_with_hooks(lua: &Lua) -> Result<(), LuaError> {
    let orcs_table: Table = lua.globals().get("orcs")?;

    let dispatch_fn = lua.create_function(
        |lua, (phase, tool_name, args_val): (String, String, mlua::Value)| {
            // Copy what is needed out of app data inside a block so the
            // app-data borrow ends before hooks run.
            let (registry, component_id) = {
                let ctx = lua.app_data_ref::<ToolHookContext>();
                let Some(ctx) = ctx else {
                    return Ok(mlua::Value::Nil);
                };
                (
                    std::sync::Arc::clone(&ctx.registry),
                    ctx.component_id.clone(),
                )
            };

            let point = match phase.as_str() {
                "pre" => orcs_hook::HookPoint::ToolPreExecute,
                "post" => orcs_hook::HookPoint::ToolPostExecute,
                _ => return Ok(mlua::Value::Nil),
            };

            // Arguments that cannot be represented as JSON are passed to
            // hooks as `null` rather than failing the tool call.
            let args_json: serde_json::Value =
                lua.from_value(args_val).unwrap_or(serde_json::Value::Null);

            let payload = serde_json::json!({
                "tool": tool_name,
                "args": args_json,
            });

            // For post hooks, keep a copy so a modified payload can be
            // detected after dispatch.
            let original_payload = if phase == "post" {
                Some(payload.clone())
            } else {
                None
            };

            // NOTE(review): the channel id is freshly generated, the principal
            // is `System` and the fifth argument is a literal 0 — its meaning
            // is defined by `HookContext::new`; confirm against that API.
            let hook_ctx = orcs_hook::HookContext::new(
                point,
                component_id.clone(),
                orcs_types::ChannelId::new(),
                orcs_types::Principal::System,
                0,
                payload,
            );

            // Dispatch under the read lock; a poisoned lock is recovered so
            // hooks keep working after a panic elsewhere.
            let action = {
                let guard = registry.read().unwrap_or_else(|poisoned| {
                    tracing::warn!("hook registry lock poisoned, using inner value");
                    poisoned.into_inner()
                });
                guard.dispatch(point, &component_id, None, hook_ctx)
            };

            match action {
                // Abort: report a failure table in place of the tool result.
                orcs_hook::HookAction::Abort { reason } => {
                    let result = lua.create_table()?;
                    result.set("ok", false)?;
                    result.set("error", format!("blocked by hook: {reason}"))?;
                    Ok(mlua::Value::Table(result))
                }
                // Skip / Replace: the hook-provided value becomes the result.
                orcs_hook::HookAction::Skip(val) | orcs_hook::HookAction::Replace(val) => {
                    lua.to_value(&val)
                }
                // Continue: only a post hook that changed the payload
                // overrides the result; pre-hook payload changes are ignored.
                // NOTE(review): the value returned here is the whole payload
                // (`tool` + `args`), not just the tool result — confirm this
                // is what callers expect.
                orcs_hook::HookAction::Continue(ctx) => {
                    if let Some(original) = original_payload {
                        if ctx.payload != original {
                            lua.to_value(&ctx.payload)
                        } else {
                            Ok(mlua::Value::Nil)
                        }
                    } else {
                        Ok(mlua::Value::Nil)
                    }
                }
            }
        },
    )?;
    orcs_table.set("_dispatch_tool_hook", dispatch_fn)?;

    for &name in HOOKABLE_TOOLS {
        // Tools that are not registered as functions are left alone.
        if orcs_table.get::<Function>(name).is_err() {
            continue;
        }

        // The wrapper captures the original function in a local upvalue and
        // rebinds `orcs.<name>`; `{{...}}` is an escaped Lua `{...}`.
        let wrap_code = format!(
            r#"
            do
                local _orig = orcs.{name}
                orcs.{name} = function(...)
                    local pre = orcs._dispatch_tool_hook("pre", "{name}", {{...}})
                    if pre ~= nil then return pre end
                    local result = _orig(...)
                    local post = orcs._dispatch_tool_hook("post", "{name}", result)
                    if post ~= nil then return post end
                    return result
                end
            end
            "#,
        );

        lua.load(&wrap_code)
            .exec()
            .map_err(|e| LuaError::InvalidScript(format!("failed to wrap tool '{name}': {e}")))?;
    }

    tracing::debug!("Wrapped tools with hook dispatch: {:?}", HOOKABLE_TOOLS);
    Ok(())
}
1074
1075#[cfg(test)]
1076mod tests {
1077 use super::*;
1078 use orcs_runtime::sandbox::ProjectSandbox;
1079 use orcs_runtime::WorkDir;
1080 use std::fs;
1081 use std::path::PathBuf;
1082
1083 fn test_sandbox() -> (WorkDir, PathBuf, Arc<dyn SandboxPolicy>) {
1086 let wd = WorkDir::temporary().expect("should create temporary work directory");
1087 let root = wd
1088 .path()
1089 .canonicalize()
1090 .expect("should canonicalize work dir path");
1091 let sandbox = ProjectSandbox::new(&root).expect("should create project sandbox");
1092 (wd, root, Arc::new(sandbox))
1093 }
1094
1095 #[test]
1098 fn read_existing_file() {
1099 let (_td, root, sandbox) = test_sandbox();
1100 let file = root.join("test.txt");
1101 fs::write(&file, "hello world").expect("should write test file");
1102
1103 let (content, size) = tool_read(
1104 file.to_str().expect("path should be valid UTF-8"),
1105 sandbox.as_ref(),
1106 )
1107 .expect("should read existing file");
1108 assert_eq!(content, "hello world");
1109 assert_eq!(size, 11);
1110 }
1111
1112 #[test]
1113 fn read_nonexistent_file() {
1114 let (_td, _root, sandbox) = test_sandbox();
1115 let result = tool_read("nonexistent.txt", sandbox.as_ref());
1116 assert!(result.is_err());
1117 }
1118
1119 #[test]
1120 fn read_directory_fails() {
1121 let (_td, root, sandbox) = test_sandbox();
1122 let sub = root.join("subdir");
1123 fs::create_dir_all(&sub).expect("should create subdirectory");
1124
1125 let result = tool_read(
1126 sub.to_str().expect("path should be valid UTF-8"),
1127 sandbox.as_ref(),
1128 );
1129 assert!(result.is_err());
1130 assert!(result
1131 .expect_err("should fail for directory")
1132 .contains("not a file"));
1133 }
1134
1135 #[test]
1136 fn read_outside_root_rejected() {
1137 let (_td, _root, sandbox) = test_sandbox();
1138 let result = tool_read("/etc/hosts", sandbox.as_ref());
1139 assert!(result.is_err());
1140 assert!(result
1141 .expect_err("should deny access outside root")
1142 .contains("access denied"));
1143 }
1144
1145 #[test]
1148 fn write_new_file() {
1149 let (_td, root, sandbox) = test_sandbox();
1150 let file = root.join("new.txt");
1151
1152 let bytes = tool_write(
1153 file.to_str().expect("path should be valid UTF-8"),
1154 "new content",
1155 sandbox.as_ref(),
1156 )
1157 .expect("should write new file");
1158 assert_eq!(bytes, 11);
1159 assert_eq!(
1160 fs::read_to_string(&file).expect("should read written file"),
1161 "new content"
1162 );
1163 }
1164
1165 #[test]
1166 fn write_overwrites_existing() {
1167 let (_td, root, sandbox) = test_sandbox();
1168 let file = root.join("existing.txt");
1169 fs::write(&file, "old").expect("should write initial file");
1170
1171 tool_write(
1172 file.to_str().expect("path should be valid UTF-8"),
1173 "new",
1174 sandbox.as_ref(),
1175 )
1176 .expect("should overwrite existing file");
1177 assert_eq!(
1178 fs::read_to_string(&file).expect("should read overwritten file"),
1179 "new"
1180 );
1181 }
1182
1183 #[test]
1184 fn write_creates_parent_dirs() {
1185 let (_td, root, sandbox) = test_sandbox();
1186 let file = root.join("sub/dir/file.txt");
1187
1188 tool_write(
1189 file.to_str().expect("path should be valid UTF-8"),
1190 "nested",
1191 sandbox.as_ref(),
1192 )
1193 .expect("should write file with parent dir creation");
1194 assert_eq!(
1195 fs::read_to_string(&file).expect("should read nested file"),
1196 "nested"
1197 );
1198 }
1199
1200 #[test]
1201 fn write_atomic_no_temp_leftover() {
1202 let (_td, root, sandbox) = test_sandbox();
1203 let file = root.join("atomic.txt");
1204
1205 tool_write(
1206 file.to_str().expect("path should be valid UTF-8"),
1207 "content",
1208 sandbox.as_ref(),
1209 )
1210 .expect("should write file atomically");
1211
1212 let temp = file.with_extension("tmp.orcs");
1214 assert!(!temp.exists());
1215 }
1216
1217 #[test]
1218 fn write_outside_root_rejected() {
1219 let (_td, _root, sandbox) = test_sandbox();
1220 let result = tool_write("/etc/evil.txt", "bad", sandbox.as_ref());
1221 assert!(result.is_err());
1222 assert!(result
1223 .expect_err("should deny write outside root")
1224 .contains("access denied"));
1225 }
1226
1227 #[test]
1230 fn grep_finds_matches() {
1231 let (_td, root, sandbox) = test_sandbox();
1232 let file = root.join("search.txt");
1233 fs::write(&file, "line one\nline two\nthird line").expect("should write search file");
1234
1235 let matches = tool_grep(
1236 "line",
1237 file.to_str().expect("path should be valid UTF-8"),
1238 sandbox.as_ref(),
1239 )
1240 .expect("should find grep matches");
1241 assert_eq!(matches.len(), 3);
1242 assert_eq!(matches[0].line_number, 1);
1243 assert_eq!(matches[0].line, "line one");
1244 }
1245
1246 #[test]
1247 fn grep_regex_pattern() {
1248 let (_td, root, sandbox) = test_sandbox();
1249 let file = root.join("regex.txt");
1250 fs::write(&file, "foo123\nbar456\nfoo789").expect("should write regex test file");
1251
1252 let matches = tool_grep(
1253 r"foo\d+",
1254 file.to_str().expect("path should be valid UTF-8"),
1255 sandbox.as_ref(),
1256 )
1257 .expect("should find regex matches");
1258 assert_eq!(matches.len(), 2);
1259 }
1260
1261 #[test]
1262 fn grep_no_matches() {
1263 let (_td, root, sandbox) = test_sandbox();
1264 let file = root.join("empty.txt");
1265 fs::write(&file, "nothing here").expect("should write test file");
1266
1267 let matches = tool_grep(
1268 "nonexistent",
1269 file.to_str().expect("path should be valid UTF-8"),
1270 sandbox.as_ref(),
1271 )
1272 .expect("should return empty matches without error");
1273 assert!(matches.is_empty());
1274 }
1275
1276 #[test]
1277 fn grep_invalid_regex() {
1278 let (_td, root, sandbox) = test_sandbox();
1279 let file = root.join("test.txt");
1280 fs::write(&file, "content").expect("should write test file");
1281
1282 let result = tool_grep(
1283 "[invalid",
1284 file.to_str().expect("path should be valid UTF-8"),
1285 sandbox.as_ref(),
1286 );
1287 assert!(result.is_err());
1288 assert!(result
1289 .expect_err("should fail for invalid regex")
1290 .contains("invalid regex"));
1291 }
1292
1293 #[test]
1294 fn grep_directory_recursive() {
1295 let (_td, root, sandbox) = test_sandbox();
1296 let sub = root.join("sub");
1297 fs::create_dir_all(&sub).expect("should create subdirectory");
1298
1299 fs::write(root.join("a.txt"), "target line\nother").expect("should write a.txt");
1300 fs::write(sub.join("b.txt"), "no match\ntarget here").expect("should write b.txt");
1301
1302 let matches = tool_grep(
1303 "target",
1304 root.to_str().expect("path should be valid UTF-8"),
1305 sandbox.as_ref(),
1306 )
1307 .expect("should find recursive grep matches");
1308 assert_eq!(matches.len(), 2);
1309 }
1310
1311 #[test]
1312 fn grep_outside_root_rejected() {
1313 let (_td, _root, sandbox) = test_sandbox();
1314 let result = tool_grep("pattern", "/etc", sandbox.as_ref());
1315 assert!(result.is_err());
1316 assert!(result
1317 .expect_err("should deny grep outside root")
1318 .contains("access denied"));
1319 }
1320
/// A flat `*.txt` glob returns only the files with that extension.
#[test]
fn glob_finds_files() {
    let (_td, root, sandbox) = test_sandbox();

    // Two files that match the pattern and one that does not.
    fs::write(root.join("a.txt"), "").expect("should write a.txt");
    fs::write(root.join("b.txt"), "").expect("should write b.txt");
    fs::write(root.join("c.rs"), "").expect("should write c.rs");

    let base_dir = Some(root.to_str().expect("path should be valid UTF-8"));
    let found = tool_glob("*.txt", base_dir, sandbox.as_ref())
        .expect("should find txt files via glob");

    assert_eq!(found.len(), 2);
}
1338
/// A `**/*.rs` glob finds matching files both at the root and in a subdirectory.
#[test]
fn glob_recursive() {
    let (_td, root, sandbox) = test_sandbox();
    let nested = root.join("sub");
    fs::create_dir_all(&nested).expect("should create subdirectory");
    fs::write(root.join("top.rs"), "").expect("should write top.rs");
    fs::write(nested.join("nested.rs"), "").expect("should write nested.rs");

    let base_dir = Some(root.to_str().expect("path should be valid UTF-8"));
    let found = tool_glob("**/*.rs", base_dir, sandbox.as_ref())
        .expect("should find rs files recursively");

    assert_eq!(found.len(), 2);
}
1355
/// A glob that matches nothing succeeds with an empty list rather than an error.
#[test]
fn glob_no_matches() {
    let (_td, root, sandbox) = test_sandbox();
    let base_dir = Some(root.to_str().expect("path should be valid UTF-8"));

    let found = tool_glob("*.xyz", base_dir, sandbox.as_ref())
        .expect("should return empty matches for no-match glob");

    assert!(found.is_empty());
}
1367
/// A malformed glob pattern (unclosed bracket) is rejected with an error.
#[test]
fn glob_invalid_pattern() {
    let (_td, root, sandbox) = test_sandbox();
    let base_dir = Some(root.to_str().expect("path should be valid UTF-8"));

    let outcome = tool_glob("[invalid", base_dir, sandbox.as_ref());

    assert!(outcome.is_err());
}
1378
/// Globbing with a base directory outside the sandbox fails with "access denied".
#[test]
fn glob_outside_root_rejected() {
    let (_td, _root, sandbox) = test_sandbox();

    let outcome = tool_glob("*", Some("/etc"), sandbox.as_ref());

    assert!(outcome.is_err());
    let message = outcome.expect_err("should deny glob outside root");
    assert!(message.contains("access denied"));
}
1388
/// A pattern containing `..` components is refused, and the error mentions `'..'`.
#[test]
fn glob_rejects_dotdot_in_pattern() {
    let (_td, _root, sandbox) = test_sandbox();

    let outcome = tool_glob("../../**/*", None, sandbox.as_ref());

    assert!(outcome.is_err());
    let message = outcome.expect_err("should reject dotdot pattern");
    assert!(message.contains("'..'"), "expected dotdot rejection");
}
1401
/// A match buried 35 directories deep is not reported; only the shallow file is.
#[test]
fn grep_respects_depth_limit() {
    let (_td, root, sandbox) = test_sandbox();

    // Build root/d0/d1/.../d34 — 35 nested levels below the sandbox root.
    let deepest = (0..35).fold(root.clone(), |dir, i| dir.join(format!("d{i}")));
    fs::create_dir_all(&deepest).expect("should create deep directory structure");
    fs::write(deepest.join("deep.txt"), "needle").expect("should write deep file");

    fs::write(root.join("shallow.txt"), "needle").expect("should write shallow file");

    let root_str = root.to_str().expect("path should be valid UTF-8");
    let hits = tool_grep("needle", root_str, sandbox.as_ref())
        .expect("should grep respecting depth limit");

    assert_eq!(hits.len(), 1);
}
1428
/// After registration, `orcs.read`, `orcs.write`, `orcs.grep` and `orcs.glob`
/// are all retrievable as Lua functions.
#[test]
fn register_tools_in_lua() {
    let (_td, _root, sandbox) = test_sandbox();
    let lua = Lua::new();
    let fresh_table = lua.create_table().expect("should create orcs table");
    lua.globals()
        .set("orcs", fresh_table)
        .expect("should set orcs global");

    register_tool_functions(&lua, sandbox).expect("should register tool functions");

    let orcs: Table = lua
        .globals()
        .get("orcs")
        .expect("should get orcs table back");
    for tool_name in ["read", "write", "grep", "glob"] {
        assert!(orcs.get::<mlua::Function>(tool_name).is_ok());
    }
}
1451
/// `orcs.read` invoked from Lua returns a table carrying `ok`, `content` and `size`.
#[test]
fn lua_read_file() {
    let (_td, root, sandbox) = test_sandbox();
    let target = root.join("lua_read.txt");
    fs::write(&target, "lua content").expect("should write lua read test file");

    let lua = Lua::new();
    let orcs_table = lua.create_table().expect("should create orcs table");
    lua.globals()
        .set("orcs", orcs_table)
        .expect("should set orcs global");
    register_tool_functions(&lua, sandbox).expect("should register tool functions");

    // Double any backslashes so the path survives inside a Lua string literal.
    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua.load(&code).eval().expect("should eval lua read");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let content = result
        .get::<String>("content")
        .expect("should have content field");
    assert_eq!(content, "lua content");

    let size = result.get::<u64>("size").expect("should have size field");
    assert_eq!(size, 11);
}
1482
/// `orcs.write` invoked from Lua reports success and the content lands on disk.
#[test]
fn lua_write_file() {
    let (_td, root, sandbox) = test_sandbox();
    let target = root.join("lua_write.txt");

    let lua = Lua::new();
    let orcs_table = lua.create_table().expect("should create orcs table");
    lua.globals()
        .set("orcs", orcs_table)
        .expect("should set orcs global");
    register_tool_functions(&lua, sandbox).expect("should register tool functions");

    // Double any backslashes so the path survives inside a Lua string literal.
    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.write("{}", "written from lua")"#, lua_path);
    let result: Table = lua.load(&code).eval().expect("should eval lua write");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let on_disk = fs::read_to_string(&target).expect("should read lua-written file");
    assert_eq!(on_disk, "written from lua");
}
1506
/// `orcs.grep` invoked from Lua reports the number of matching lines in `count`.
#[test]
fn lua_grep_file() {
    let (_td, root, sandbox) = test_sandbox();
    let target = root.join("lua_grep.txt");
    // Two of the three lines contain "alpha".
    fs::write(&target, "alpha\nbeta\nalpha_two").expect("should write grep test file");

    let lua = Lua::new();
    let orcs_table = lua.create_table().expect("should create orcs table");
    lua.globals()
        .set("orcs", orcs_table)
        .expect("should set orcs global");
    register_tool_functions(&lua, sandbox).expect("should register tool functions");

    // Double any backslashes so the path survives inside a Lua string literal.
    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.grep("alpha", "{}")"#, lua_path);
    let result: Table = lua.load(&code).eval().expect("should eval lua grep");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let count = result
        .get::<usize>("count")
        .expect("should have count field");
    assert_eq!(count, 2);
}
1533
/// `orcs.glob` invoked from Lua reports the number of matched files in `count`.
#[test]
fn lua_glob_files() {
    let (_td, root, sandbox) = test_sandbox();
    fs::write(root.join("a.lua"), "").expect("should write a.lua");
    fs::write(root.join("b.lua"), "").expect("should write b.lua");

    let lua = Lua::new();
    let orcs_table = lua.create_table().expect("should create orcs table");
    lua.globals()
        .set("orcs", orcs_table)
        .expect("should set orcs global");
    register_tool_functions(&lua, sandbox).expect("should register tool functions");

    // Double any backslashes so the path survives inside a Lua string literal.
    let lua_root = root.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.glob("*.lua", "{}")"#, lua_root);
    let result: Table = lua.load(&code).eval().expect("should eval lua glob");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let count = result
        .get::<usize>("count")
        .expect("should have count field");
    assert_eq!(count, 2);
}
1560
/// Reading a missing file from Lua yields `ok = false` plus a string `error`
/// field instead of raising a Lua error.
#[test]
fn lua_read_nonexistent_returns_error() {
    let (_td, _root, sandbox) = test_sandbox();
    let lua = Lua::new();
    let orcs_table = lua.create_table().expect("should create orcs table");
    lua.globals()
        .set("orcs", orcs_table)
        .expect("should set orcs global");
    register_tool_functions(&lua, sandbox).expect("should register tool functions");

    let script = r#"return orcs.read("nonexistent_file_xyz.txt")"#;
    let result: Table = lua
        .load(script)
        .eval()
        .expect("should eval lua read for nonexistent file");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(!ok);
    assert!(result.get::<String>("error").is_ok());
}
1578
/// Reading a path outside the sandbox from Lua yields `ok = false` and an
/// "access denied" error message.
#[test]
fn lua_read_outside_sandbox_returns_error() {
    let (_td, _root, sandbox) = test_sandbox();
    let lua = Lua::new();
    let orcs_table = lua.create_table().expect("should create orcs table");
    lua.globals()
        .set("orcs", orcs_table)
        .expect("should set orcs global");
    register_tool_functions(&lua, sandbox).expect("should register tool functions");

    let script = r#"return orcs.read("/etc/hosts")"#;
    let result: Table = lua
        .load(script)
        .eval()
        .expect("should eval lua read for outside sandbox");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(!ok);

    let error = result
        .get::<String>("error")
        .expect("should have error field");
    assert!(
        error.contains("access denied"),
        "expected 'access denied', got: {error}"
    );
}
1602
1603 #[cfg(unix)]
1606 mod symlink_tests {
1607 use super::*;
1608 use std::os::unix::fs::symlink;
1609
/// A symlink inside the sandbox that points outside it must not leak the
/// outside directory's files into glob results.
#[test]
fn glob_skips_symlink_outside_sandbox() {
    let (_td, root, sandbox) = test_sandbox();

    // A second temporary directory, outside the sandbox, holding the file
    // that must stay hidden.
    let outside = WorkDir::temporary().expect("should create outside temp work dir");
    let outside_root = outside
        .path()
        .canonicalize()
        .expect("should canonicalize outside path");
    fs::write(outside_root.join("leaked.txt"), "secret").expect("should write leaked file");

    // `root/escape` -> outside directory; `root/ok.txt` is the only legitimate match.
    symlink(&outside_root, root.join("escape")).expect("should create escape symlink");
    fs::write(root.join("ok.txt"), "safe").expect("should write ok file");

    let found =
        tool_glob("**/*.txt", None, sandbox.as_ref()).expect("should glob without error");

    found
        .iter()
        .for_each(|f| assert!(!f.contains("leaked"), "leaked file found: {f}"));
    assert_eq!(found.len(), 1, "only ok.txt should be found");
}
1630
/// A recursive grep must not match a file reachable only through a symlink
/// that leaves the sandbox.
#[test]
fn grep_dir_skips_symlink_outside_sandbox() {
    let (_td, root, sandbox) = test_sandbox();

    // The outside directory contains the same needle as the in-sandbox file.
    let outside = WorkDir::temporary().expect("should create outside temp work dir");
    let outside_root = outside
        .path()
        .canonicalize()
        .expect("should canonicalize outside path");
    fs::write(outside_root.join("secret.txt"), "password123")
        .expect("should write secret file");

    symlink(&outside_root, root.join("escape")).expect("should create escape symlink");
    fs::write(root.join("ok.txt"), "password123").expect("should write ok file");

    let root_str = root.to_str().expect("path should be valid UTF-8");
    let hits =
        tool_grep("password", root_str, sandbox.as_ref()).expect("should grep without error");

    assert_eq!(hits.len(), 1, "symlinked outside file should be skipped");
}
1653
/// Writing through a symlink that resolves outside the sandbox must fail.
#[test]
fn write_via_symlink_escape_rejected() {
    let (_td, root, sandbox) = test_sandbox();
    let outside = WorkDir::temporary().expect("should create outside temp work dir");
    let outside_root = outside
        .path()
        .canonicalize()
        .expect("should canonicalize outside path");
    symlink(&outside_root, root.join("escape")).expect("should create escape symlink");

    // The target is nominally under the sandbox root but passes through the symlink.
    let target = root.join("escape/evil.txt");
    let target_str = target.to_str().expect("path should be valid UTF-8");
    let outcome = tool_write(target_str, "evil", sandbox.as_ref());

    assert!(
        outcome.is_err(),
        "write via symlink escape should be rejected"
    );
}
1676
/// Reading through a symlink that resolves outside the sandbox must fail.
#[test]
fn read_via_symlink_escape_rejected() {
    let (_td, root, sandbox) = test_sandbox();
    let outside = WorkDir::temporary().expect("should create outside temp work dir");
    let outside_root = outside
        .path()
        .canonicalize()
        .expect("should canonicalize outside path");
    fs::write(outside_root.join("secret.txt"), "secret").expect("should write secret file");
    symlink(&outside_root, root.join("escape")).expect("should create escape symlink");

    // The target is nominally under the sandbox root but passes through the symlink.
    let target = root.join("escape/secret.txt");
    let target_str = target.to_str().expect("path should be valid UTF-8");
    let outcome = tool_read(target_str, sandbox.as_ref());

    assert!(
        outcome.is_err(),
        "read via symlink escape should be rejected"
    );
}
1700 }
1701
1702 mod tool_hook_tests {
1705 use super::*;
1706 use orcs_hook::{HookPoint, HookRegistry};
1707 use orcs_types::ComponentId;
1708
1709 fn setup_lua_with_hooks() -> (Lua, orcs_hook::SharedHookRegistry, WorkDir) {
1710 let wd = WorkDir::temporary().expect("should create work dir for hooks");
1711 let root = wd
1712 .path()
1713 .canonicalize()
1714 .expect("should canonicalize hook test root");
1715 let sandbox: Arc<dyn SandboxPolicy> =
1716 Arc::new(ProjectSandbox::new(&root).expect("should create hook sandbox"));
1717
1718 let lua = Lua::new();
1719 let orcs = lua.create_table().expect("should create orcs table");
1720 lua.globals()
1721 .set("orcs", orcs)
1722 .expect("should set orcs global");
1723 register_tool_functions(&lua, sandbox).expect("should register tool functions");
1724
1725 let registry = std::sync::Arc::new(std::sync::RwLock::new(HookRegistry::new()));
1726 let comp_id = ComponentId::builtin("test");
1727
1728 lua.set_app_data(ToolHookContext {
1729 registry: std::sync::Arc::clone(®istry),
1730 component_id: comp_id,
1731 });
1732
1733 wrap_tools_with_hooks(&lua).expect("should wrap tools with hooks");
1734
1735 (lua, registry, wd)
1736 }
1737
/// Wrapping the tools installs `orcs._dispatch_tool_hook` as a Lua function.
#[test]
fn dispatch_function_registered() {
    let (lua, _registry, _td) = setup_lua_with_hooks();
    let orcs: Table = lua.globals().get("orcs").expect("should get orcs table");

    let dispatcher = orcs.get::<Function>("_dispatch_tool_hook");

    assert!(dispatcher.is_ok());
}
1744
/// With an empty hook registry, a wrapped `orcs.read` behaves like the plain tool.
#[test]
fn tools_work_normally_without_hooks() {
    let (lua, _registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");
    let target = root.join("test.txt");
    fs::write(&target, "hello").expect("should write test file");

    // Double any backslashes so the path survives inside a Lua string literal.
    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval read without hooks");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let content = result
        .get::<String>("content")
        .expect("should have content field");
    assert_eq!(content, "hello");
}
1770
/// An aborting pre-execute hook turns `orcs.read` into an `ok = false` result
/// whose error names both the hook block and the hook's reason.
#[test]
fn pre_hook_abort_blocks_read() {
    let (lua, registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");
    let target = root.join("secret.txt");
    fs::write(&target, "top secret").expect("should write secret file");

    // Register in a single statement so the write guard is released before
    // the Lua code runs.
    let abort_hook = orcs_hook::testing::MockHook::aborter(
        "block-read",
        "*::*",
        HookPoint::ToolPreExecute,
        "access denied by policy",
    );
    registry
        .write()
        .expect("should acquire write lock")
        .register(Box::new(abort_hook));

    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval read with abort hook");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(!ok);

    let error = result
        .get::<String>("error")
        .expect("should have error field");
    assert!(
        error.contains("blocked by hook"),
        "expected 'blocked by hook', got: {error}"
    );
    assert!(error.contains("access denied by policy"));
}
1808
/// A skipping pre-execute hook short-circuits `orcs.read`: the caller sees the
/// hook-supplied table instead of the real file content.
#[test]
fn pre_hook_skip_returns_custom_value() {
    let (lua, registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");
    let target = root.join("real.txt");
    fs::write(&target, "real content").expect("should write real file");

    // Register in a single statement so the write guard is released before
    // the Lua code runs.
    let skip_hook = orcs_hook::testing::MockHook::skipper(
        "skip-read",
        "*::*",
        HookPoint::ToolPreExecute,
        serde_json::json!({"ok": true, "content": "cached", "size": 6}),
    );
    registry
        .write()
        .expect("should acquire write lock")
        .register(Box::new(skip_hook));

    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval read with skip hook");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let content = result
        .get::<String>("content")
        .expect("should have content field");
    assert_eq!(content, "cached");
}
1844
/// A pass-through pre-execute hook lets `orcs.read` run and return the real content.
#[test]
fn pre_hook_continue_allows_tool() {
    let (lua, registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");
    let target = root.join("allowed.txt");
    fs::write(&target, "allowed content").expect("should write allowed file");

    // Register in a single statement so the write guard is released before
    // the Lua code runs.
    let pass_hook = orcs_hook::testing::MockHook::pass_through(
        "pass-read",
        "*::*",
        HookPoint::ToolPreExecute,
    );
    registry
        .write()
        .expect("should acquire write lock")
        .register(Box::new(pass_hook));

    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval read with continue hook");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let content = result
        .get::<String>("content")
        .expect("should have content field");
    assert_eq!(content, "allowed content");
}
1880
/// A replacing post-execute hook substitutes its own table for the tool's result.
#[test]
fn post_hook_replace_changes_result() {
    let (lua, registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");
    let target = root.join("original.txt");
    fs::write(&target, "original").expect("should write original file");

    // Register in a single statement so the write guard is released before
    // the Lua code runs.
    let replace_hook = orcs_hook::testing::MockHook::replacer(
        "replace-result",
        "*::*",
        HookPoint::ToolPostExecute,
        serde_json::json!({"ok": true, "content": "replaced", "size": 8}),
    );
    registry
        .write()
        .expect("should acquire write lock")
        .register(Box::new(replace_hook));

    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval read with replace hook");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let content = result
        .get::<String>("content")
        .expect("should have content field");
    assert_eq!(content, "replaced");
}
1916
/// A pass-through post-execute hook leaves the tool's result untouched.
#[test]
fn post_hook_continue_preserves_result() {
    let (lua, registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");
    let target = root.join("keep.txt");
    fs::write(&target, "keep this").expect("should write keep file");

    // Register in a single statement so the write guard is released before
    // the Lua code runs.
    let observer_hook = orcs_hook::testing::MockHook::pass_through(
        "observe-only",
        "*::*",
        HookPoint::ToolPostExecute,
    );
    registry
        .write()
        .expect("should acquire write lock")
        .register(Box::new(observer_hook));

    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval read with observe hook");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let content = result
        .get::<String>("content")
        .expect("should have content field");
    assert_eq!(content, "keep this");
}
1951
/// An aborting pre-execute hook stops `orcs.write`: the call reports the
/// hook's reason and no file is created.
#[test]
fn pre_hook_abort_blocks_write() {
    let (lua, registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");

    // Register in a single statement so the write guard is released before
    // the Lua code runs.
    let abort_hook = orcs_hook::testing::MockHook::aborter(
        "block-write",
        "*::*",
        HookPoint::ToolPreExecute,
        "writes disabled",
    );
    registry
        .write()
        .expect("should acquire write lock")
        .register(Box::new(abort_hook));

    let lua_path = root
        .join("blocked.txt")
        .display()
        .to_string()
        .replace('\\', "\\\\");
    let code = format!(r#"return orcs.write("{}", "evil")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval write with abort hook");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(!ok);

    let error = result
        .get::<String>("error")
        .expect("should have error field");
    assert!(error.contains("writes disabled"));

    // The tool must never have run, so the file must not exist.
    let blocked_path = root.join("blocked.txt");
    assert!(!blocked_path.exists());
}
1987
/// The payload handed to a pre-execute hook carries both a `tool` and an
/// `args` entry; the hook itself asserts this while the read runs.
#[test]
fn hooks_receive_tool_name_in_payload() {
    let (lua, registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");
    let target = root.join("check.txt");
    fs::write(&target, "data").expect("should write check file");

    // The closure stays inline so its parameter type is inferred from
    // `MockHook::modifier`. Registering in a single statement releases the
    // write guard before the Lua code runs.
    registry
        .write()
        .expect("should acquire write lock")
        .register(Box::new(orcs_hook::testing::MockHook::modifier(
            "check-tool",
            "*::*",
            HookPoint::ToolPreExecute,
            |ctx| {
                for key in ["tool", "args"] {
                    assert!(ctx.payload.get(key).is_some());
                }
            },
        )));

    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval read with modifier hook");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);
}
2022
/// Tools wrapped with hooks still work when no `ToolHookContext` was ever
/// installed as Lua app data.
#[test]
fn no_context_tools_work_normally() {
    let wd = WorkDir::temporary().expect("should create work dir");
    let root = wd.path().canonicalize().expect("should canonicalize root");
    let sandbox: Arc<dyn SandboxPolicy> =
        Arc::new(ProjectSandbox::new(&root).expect("should create sandbox"));

    let lua = Lua::new();
    let orcs_table = lua.create_table().expect("should create orcs table");
    lua.globals()
        .set("orcs", orcs_table)
        .expect("should set orcs global");
    register_tool_functions(&lua, sandbox).expect("should register tool functions");

    // Deliberately no `set_app_data` call before wrapping.
    wrap_tools_with_hooks(&lua).expect("should wrap tools with hooks");

    let target = root.join("nocontext.txt");
    fs::write(&target, "works").expect("should write nocontext file");

    let lua_path = target.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.read("{}")"#, lua_path);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval read without hook context");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(ok);

    let content = result
        .get::<String>("content")
        .expect("should have content field");
    assert_eq!(content, "works");
}
2062
/// An aborting pre-execute hook stops `orcs.glob` and surfaces the hook's reason.
#[test]
fn pre_hook_abort_blocks_glob() {
    let (lua, registry, td) = setup_lua_with_hooks();
    let root = td.path().canonicalize().expect("should canonicalize root");
    fs::write(root.join("a.txt"), "").expect("should write test file");

    // Register in a single statement so the write guard is released before
    // the Lua code runs.
    let abort_hook = orcs_hook::testing::MockHook::aborter(
        "block-glob",
        "*::*",
        HookPoint::ToolPreExecute,
        "glob not allowed",
    );
    registry
        .write()
        .expect("should acquire write lock")
        .register(Box::new(abort_hook));

    let lua_root = root.display().to_string().replace('\\', "\\\\");
    let code = format!(r#"return orcs.glob("*.txt", "{}")"#, lua_root);
    let result: Table = lua
        .load(&code)
        .eval()
        .expect("should eval glob with abort hook");

    let ok = result.get::<bool>("ok").expect("should have ok field");
    assert!(!ok);

    let error = result
        .get::<String>("error")
        .expect("should have error field");
    assert!(error.contains("glob not allowed"));
}
2093 }
2094
2095 }