1use std::collections::HashMap;
7
8use crate::scripting::{LuaEngine, ScriptingError};
9
10use super::types::CaptureSpec;
11
/// Outcome of running a capture's `before_insert` hook.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BeforeInsertResult {
    /// The content to insert: the hook's string return value, or the original
    /// content when the hook returned `nil` or no hook is declared.
    pub content: String,
}
18
19pub fn run_before_insert_hook(
28 spec: &CaptureSpec,
29 content: &str,
30 vars: &HashMap<String, String>,
31) -> Result<BeforeInsertResult, ScriptingError> {
32 if !spec.has_before_insert {
34 return Ok(BeforeInsertResult { content: content.to_string() });
35 }
36
37 let lua_source = spec
39 .lua_source
40 .as_ref()
41 .ok_or_else(|| ScriptingError::Other("Capture has no Lua source".to_string()))?;
42
43 let engine = LuaEngine::sandboxed()?;
45 let lua = engine.lua();
46
47 let capture_table: mlua::Table =
49 lua.load(lua_source).eval().map_err(ScriptingError::Lua)?;
50
51 let hook_fn: mlua::Function =
53 capture_table.get("before_insert").map_err(ScriptingError::Lua)?;
54
55 let vars_table = lua.create_table().map_err(ScriptingError::Lua)?;
57 for (k, v) in vars {
58 vars_table.set(k.as_str(), v.as_str()).map_err(ScriptingError::Lua)?;
59 }
60
61 let target_table = lua.create_table().map_err(ScriptingError::Lua)?;
63 target_table.set("file", spec.target.file.as_str()).map_err(ScriptingError::Lua)?;
64 if let Some(section) = &spec.target.section {
65 target_table.set("section", section.as_str()).map_err(ScriptingError::Lua)?;
66 }
67 let position_str = match spec.target.position {
68 super::types::CapturePosition::Begin => "begin",
69 super::types::CapturePosition::End => "end",
70 };
71 target_table.set("position", position_str).map_err(ScriptingError::Lua)?;
72
73 let result: mlua::Value =
75 hook_fn.call((content, vars_table, target_table)).map_err(ScriptingError::Lua)?;
76
77 let modified_content = match result {
79 mlua::Value::String(s) => s.to_str().map_err(ScriptingError::Lua)?.to_string(),
80 mlua::Value::Nil => content.to_string(), _ => {
82 return Err(ScriptingError::Other(
83 "before_insert hook must return a string or nil".to_string(),
84 ));
85 }
86 };
87
88 Ok(BeforeInsertResult { content: modified_content })
89}
90
91#[derive(Debug)]
93pub struct AfterInsertResult {
94 pub success: bool,
96}
97
98pub fn run_after_insert_hook(
108 spec: &CaptureSpec,
109 content: &str,
110 vars: &HashMap<String, String>,
111 target_file: &std::path::Path,
112 section_matched: Option<(&str, u8)>,
113) -> Result<AfterInsertResult, ScriptingError> {
114 if !spec.has_after_insert {
116 return Ok(AfterInsertResult { success: true });
117 }
118
119 let lua_source = spec
121 .lua_source
122 .as_ref()
123 .ok_or_else(|| ScriptingError::Other("Capture has no Lua source".to_string()))?;
124
125 let engine = LuaEngine::sandboxed()?;
127 let lua = engine.lua();
128
129 let capture_table: mlua::Table =
131 lua.load(lua_source).eval().map_err(ScriptingError::Lua)?;
132
133 let hook_fn: mlua::Function =
135 capture_table.get("after_insert").map_err(ScriptingError::Lua)?;
136
137 let vars_table = lua.create_table().map_err(ScriptingError::Lua)?;
139 for (k, v) in vars {
140 vars_table.set(k.as_str(), v.as_str()).map_err(ScriptingError::Lua)?;
141 }
142
143 let target_table = lua.create_table().map_err(ScriptingError::Lua)?;
145 target_table.set("file", spec.target.file.as_str()).map_err(ScriptingError::Lua)?;
146 if let Some(section) = &spec.target.section {
147 target_table.set("section", section.as_str()).map_err(ScriptingError::Lua)?;
148 }
149 let position_str = match spec.target.position {
150 super::types::CapturePosition::Begin => "begin",
151 super::types::CapturePosition::End => "end",
152 };
153 target_table.set("position", position_str).map_err(ScriptingError::Lua)?;
154
155 let result_table = lua.create_table().map_err(ScriptingError::Lua)?;
157 result_table
158 .set("target_file", target_file.to_string_lossy().as_ref())
159 .map_err(ScriptingError::Lua)?;
160 result_table.set("success", true).map_err(ScriptingError::Lua)?;
161 if let Some((section_title, level)) = section_matched {
162 result_table.set("section_title", section_title).map_err(ScriptingError::Lua)?;
163 result_table.set("section_level", level).map_err(ScriptingError::Lua)?;
164 }
165
166 let _: mlua::Value = hook_fn
168 .call((content, vars_table, target_table, result_table))
169 .map_err(ScriptingError::Lua)?;
170
171 Ok(AfterInsertResult { success: true })
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::captures::lua_loader::load_capture_from_lua;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `<dir>/<name>.lua` and returns the file's path.
    fn write_lua_capture(
        dir: &std::path::Path,
        name: &str,
        content: &str,
    ) -> std::path::PathBuf {
        let path = dir.join(format!("{}.lua", name));
        fs::write(&path, content).unwrap();
        path
    }

    #[test]
    fn test_before_insert_hook_modifies_content() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return "[HOOK] " .. content
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        assert!(spec.has_before_insert);

        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();

        assert_eq!(result.content, "[HOOK] - hello");
    }

    #[test]
    fn test_before_insert_hook_returns_nil() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return nil -- Let original content through
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();

        assert_eq!(result.content, "- hello");
    }

    #[test]
    fn test_before_insert_hook_receives_vars_and_target() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return vars.text .. "|" .. target.file .. "|" .. target.section
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();

        assert_eq!(result.content, "hello|test.md|Test");
    }

    #[test]
    fn test_before_insert_hook_rejects_non_string_return() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return 42
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let err = run_before_insert_hook(&spec, "- hello", &vars).unwrap_err();

        assert!(matches!(err, ScriptingError::Other(_)));
    }

    #[test]
    fn test_no_hook_passes_through() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        assert!(!spec.has_before_insert);

        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();

        assert_eq!(result.content, "- hello");
    }

    #[test]
    fn test_after_insert_hook_runs() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    after_insert = function(content, vars, target, result)
        -- Side effect only, return value ignored
        print("Inserted: " .. content)
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        assert!(spec.has_after_insert);

        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_after_insert_hook(
            &spec,
            "- hello",
            &vars,
            std::path::Path::new("/tmp/test.md"),
            Some(("Test", 2)),
        )
        .unwrap();

        assert!(result.success);
    }

    #[test]
    fn test_after_insert_hook_receives_result_table() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    after_insert = function(content, vars, target, result)
        if result.target_file ~= "/tmp/test.md"
            or result.success ~= true
            or result.section_title ~= "Test"
            or result.section_level ~= 2
        then
            error("unexpected result table")
        end
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_after_insert_hook(
            &spec,
            "- hello",
            &vars,
            std::path::Path::new("/tmp/test.md"),
            Some(("Test", 2)),
        )
        .unwrap();

        assert!(result.success);
    }

    #[test]
    fn test_no_after_insert_hook_is_noop() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        assert!(!spec.has_after_insert);

        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_after_insert_hook(
            &spec,
            "- hello",
            &vars,
            std::path::Path::new("/tmp/test.md"),
            None,
        )
        .unwrap();

        assert!(result.success);
    }
}