1use std::collections::HashMap;
7
8use crate::scripting::{LuaEngine, ScriptingError};
9
10use super::types::CaptureSpec;
11
/// Outcome of [`run_before_insert_hook`].
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so callers and
/// tests can copy and compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BeforeInsertResult {
    /// Content to insert: the string returned by the hook, or the original
    /// content when the capture has no hook or the hook returned `nil`.
    pub content: String,
}
18
19pub fn run_before_insert_hook(
28 spec: &CaptureSpec,
29 content: &str,
30 vars: &HashMap<String, String>,
31) -> Result<BeforeInsertResult, ScriptingError> {
32 if !spec.has_before_insert {
34 return Ok(BeforeInsertResult { content: content.to_string() });
35 }
36
37 let lua_source = spec
39 .lua_source
40 .as_ref()
41 .ok_or_else(|| ScriptingError::Other("Capture has no Lua source".to_string()))?;
42
43 let engine = LuaEngine::sandboxed()?;
45 let lua = engine.lua();
46
47 let capture_table: mlua::Table =
49 lua.load(lua_source).eval().map_err(ScriptingError::Lua)?;
50
51 let hook_fn: mlua::Function = capture_table
53 .get("before_insert")
54 .map_err(ScriptingError::Lua)?;
55
56 let vars_table = lua.create_table().map_err(ScriptingError::Lua)?;
58 for (k, v) in vars {
59 vars_table.set(k.as_str(), v.as_str()).map_err(ScriptingError::Lua)?;
60 }
61
62 let target_table = lua.create_table().map_err(ScriptingError::Lua)?;
64 target_table
65 .set("file", spec.target.file.as_str())
66 .map_err(ScriptingError::Lua)?;
67 if let Some(section) = &spec.target.section {
68 target_table.set("section", section.as_str()).map_err(ScriptingError::Lua)?;
69 }
70 let position_str = match spec.target.position {
71 super::types::CapturePosition::Begin => "begin",
72 super::types::CapturePosition::End => "end",
73 };
74 target_table.set("position", position_str).map_err(ScriptingError::Lua)?;
75
76 let result: mlua::Value = hook_fn
78 .call((content, vars_table, target_table))
79 .map_err(ScriptingError::Lua)?;
80
81 let modified_content = match result {
83 mlua::Value::String(s) => s.to_str().map_err(ScriptingError::Lua)?.to_string(),
84 mlua::Value::Nil => content.to_string(), _ => {
86 return Err(ScriptingError::Other(
87 "before_insert hook must return a string or nil".to_string(),
88 ));
89 }
90 };
91
92 Ok(BeforeInsertResult { content: modified_content })
93}
94
/// Outcome of [`run_after_insert_hook`].
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so callers and
/// tests can copy and compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AfterInsertResult {
    /// `true` whenever `run_after_insert_hook` returns `Ok`, whether the hook
    /// ran or the capture has none; hook failures are reported as `Err`
    /// rather than as `false`.
    pub success: bool,
}
101
102pub fn run_after_insert_hook(
112 spec: &CaptureSpec,
113 content: &str,
114 vars: &HashMap<String, String>,
115 target_file: &std::path::Path,
116 section_matched: Option<(&str, u8)>,
117) -> Result<AfterInsertResult, ScriptingError> {
118 if !spec.has_after_insert {
120 return Ok(AfterInsertResult { success: true });
121 }
122
123 let lua_source = spec
125 .lua_source
126 .as_ref()
127 .ok_or_else(|| ScriptingError::Other("Capture has no Lua source".to_string()))?;
128
129 let engine = LuaEngine::sandboxed()?;
131 let lua = engine.lua();
132
133 let capture_table: mlua::Table =
135 lua.load(lua_source).eval().map_err(ScriptingError::Lua)?;
136
137 let hook_fn: mlua::Function = capture_table
139 .get("after_insert")
140 .map_err(ScriptingError::Lua)?;
141
142 let vars_table = lua.create_table().map_err(ScriptingError::Lua)?;
144 for (k, v) in vars {
145 vars_table.set(k.as_str(), v.as_str()).map_err(ScriptingError::Lua)?;
146 }
147
148 let target_table = lua.create_table().map_err(ScriptingError::Lua)?;
150 target_table
151 .set("file", spec.target.file.as_str())
152 .map_err(ScriptingError::Lua)?;
153 if let Some(section) = &spec.target.section {
154 target_table.set("section", section.as_str()).map_err(ScriptingError::Lua)?;
155 }
156 let position_str = match spec.target.position {
157 super::types::CapturePosition::Begin => "begin",
158 super::types::CapturePosition::End => "end",
159 };
160 target_table.set("position", position_str).map_err(ScriptingError::Lua)?;
161
162 let result_table = lua.create_table().map_err(ScriptingError::Lua)?;
164 result_table
165 .set("target_file", target_file.to_string_lossy().as_ref())
166 .map_err(ScriptingError::Lua)?;
167 result_table.set("success", true).map_err(ScriptingError::Lua)?;
168 if let Some((section_title, level)) = section_matched {
169 result_table.set("section_title", section_title).map_err(ScriptingError::Lua)?;
170 result_table.set("section_level", level).map_err(ScriptingError::Lua)?;
171 }
172
173 let _: mlua::Value = hook_fn
175 .call((content, vars_table, target_table, result_table))
176 .map_err(ScriptingError::Lua)?;
177
178 Ok(AfterInsertResult { success: true })
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::captures::lua_loader::load_capture_from_lua;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `<dir>/<name>.lua` and returns the path of the new file.
    fn write_lua_capture(dir: &std::path::Path, name: &str, content: &str) -> std::path::PathBuf {
        let file = dir.join([name, ".lua"].concat());
        fs::write(&file, content).unwrap();
        file
    }

    /// Stores `lua` as `test.lua` inside `dir` and loads it as a capture spec.
    fn load_spec(dir: &TempDir, lua: &str) -> CaptureSpec {
        let file = write_lua_capture(dir.path(), "test", lua);
        load_capture_from_lua(&file).unwrap()
    }

    /// The single-variable map (`text = "hello"`) shared by every test.
    fn hello_vars() -> HashMap<String, String> {
        HashMap::from([("text".to_string(), "hello".to_string())])
    }

    #[test]
    fn test_before_insert_hook_modifies_content() {
        let temp = TempDir::new().unwrap();
        let spec = load_spec(
            &temp,
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return "[HOOK] " .. content
    end,
}
"#,
        );
        assert!(spec.has_before_insert);

        let outcome = run_before_insert_hook(&spec, "- hello", &hello_vars()).unwrap();

        // The hook's returned string replaces the content.
        assert_eq!(outcome.content, "[HOOK] - hello");
    }

    #[test]
    fn test_before_insert_hook_returns_nil() {
        let temp = TempDir::new().unwrap();
        let spec = load_spec(
            &temp,
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return nil -- Let original content through
    end,
}
"#,
        );

        let outcome = run_before_insert_hook(&spec, "- hello", &hello_vars()).unwrap();

        // A nil return leaves the content unchanged.
        assert_eq!(outcome.content, "- hello");
    }

    #[test]
    fn test_no_hook_passes_through() {
        let temp = TempDir::new().unwrap();
        let spec = load_spec(
            &temp,
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
}
"#,
        );
        assert!(!spec.has_before_insert);

        let outcome = run_before_insert_hook(&spec, "- hello", &hello_vars()).unwrap();

        // Without a hook the content is returned as given.
        assert_eq!(outcome.content, "- hello");
    }

    #[test]
    fn test_after_insert_hook_runs() {
        let temp = TempDir::new().unwrap();
        let spec = load_spec(
            &temp,
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    after_insert = function(content, vars, target, result)
        -- Side effect only, return value ignored
        print("Inserted: " .. content)
    end,
}
"#,
        );
        assert!(spec.has_after_insert);

        let inserted_into = std::path::Path::new("/tmp/test.md");
        let outcome = run_after_insert_hook(
            &spec,
            "- hello",
            &hello_vars(),
            inserted_into,
            Some(("Test", 2)),
        )
        .unwrap();

        assert!(outcome.success);
    }
}