Skip to main content

mdvault_core/captures/
hooks.rs

//! Capture lifecycle hook execution.
//!
//! This module runs the `before_insert` and `after_insert` hooks
//! defined in Lua capture specifications.
5
6use std::collections::HashMap;
7
8use crate::scripting::{LuaEngine, ScriptingError};
9
10use super::types::CaptureSpec;
11
/// Result of running a `before_insert` hook.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BeforeInsertResult {
    /// Content to insert: the string returned by the hook, or the original
    /// content when the hook returned `nil` or no hook is defined.
    pub content: String,
}
18
19/// Run the before_insert hook if defined.
20///
21/// The hook receives:
22/// - content: The rendered content template
23/// - vars: Table of all variables
24/// - target: Table with file, section, position
25///
26/// Returns: Modified content string
27pub fn run_before_insert_hook(
28    spec: &CaptureSpec,
29    content: &str,
30    vars: &HashMap<String, String>,
31) -> Result<BeforeInsertResult, ScriptingError> {
32    // If no hook defined, return content unchanged
33    if !spec.has_before_insert {
34        return Ok(BeforeInsertResult { content: content.to_string() });
35    }
36
37    // Get Lua source
38    let lua_source = spec
39        .lua_source
40        .as_ref()
41        .ok_or_else(|| ScriptingError::Other("Capture has no Lua source".to_string()))?;
42
43    // Create Lua engine
44    let engine = LuaEngine::sandboxed()?;
45    let lua = engine.lua();
46
47    // Execute the capture definition to get the table
48    let capture_table: mlua::Table =
49        lua.load(lua_source).eval().map_err(ScriptingError::Lua)?;
50
51    // Get the before_insert function
52    let hook_fn: mlua::Function =
53        capture_table.get("before_insert").map_err(ScriptingError::Lua)?;
54
55    // Build vars table
56    let vars_table = lua.create_table().map_err(ScriptingError::Lua)?;
57    for (k, v) in vars {
58        vars_table.set(k.as_str(), v.as_str()).map_err(ScriptingError::Lua)?;
59    }
60
61    // Build target table
62    let target_table = lua.create_table().map_err(ScriptingError::Lua)?;
63    target_table.set("file", spec.target.file.as_str()).map_err(ScriptingError::Lua)?;
64    if let Some(section) = &spec.target.section {
65        target_table.set("section", section.as_str()).map_err(ScriptingError::Lua)?;
66    }
67    let position_str = match spec.target.position {
68        super::types::CapturePosition::Begin => "begin",
69        super::types::CapturePosition::End => "end",
70    };
71    target_table.set("position", position_str).map_err(ScriptingError::Lua)?;
72
73    // Call the hook: before_insert(content, vars, target)
74    let result: mlua::Value =
75        hook_fn.call((content, vars_table, target_table)).map_err(ScriptingError::Lua)?;
76
77    // Extract result - should be a string (modified content)
78    let modified_content = match result {
79        mlua::Value::String(s) => s.to_str().map_err(ScriptingError::Lua)?.to_string(),
80        mlua::Value::Nil => content.to_string(), // Hook returned nil, use original
81        _ => {
82            return Err(ScriptingError::Other(
83                "before_insert hook must return a string or nil".to_string(),
84            ));
85        }
86    };
87
88    Ok(BeforeInsertResult { content: modified_content })
89}
90
91/// Result of running an after_insert hook.
92#[derive(Debug)]
93pub struct AfterInsertResult {
94    /// Whether the hook ran successfully
95    pub success: bool,
96}
97
98/// Run the after_insert hook if defined.
99///
100/// The hook receives:
101/// - content: The content that was inserted
102/// - vars: Table of all variables
103/// - target: Table with file, section, position
104/// - result: Table with target_file path and success status
105///
106/// Returns: Nothing (hook is for side effects only)
107pub fn run_after_insert_hook(
108    spec: &CaptureSpec,
109    content: &str,
110    vars: &HashMap<String, String>,
111    target_file: &std::path::Path,
112    section_matched: Option<(&str, u8)>,
113) -> Result<AfterInsertResult, ScriptingError> {
114    // If no hook defined, return success
115    if !spec.has_after_insert {
116        return Ok(AfterInsertResult { success: true });
117    }
118
119    // Get Lua source
120    let lua_source = spec
121        .lua_source
122        .as_ref()
123        .ok_or_else(|| ScriptingError::Other("Capture has no Lua source".to_string()))?;
124
125    // Create Lua engine
126    let engine = LuaEngine::sandboxed()?;
127    let lua = engine.lua();
128
129    // Execute the capture definition to get the table
130    let capture_table: mlua::Table =
131        lua.load(lua_source).eval().map_err(ScriptingError::Lua)?;
132
133    // Get the after_insert function
134    let hook_fn: mlua::Function =
135        capture_table.get("after_insert").map_err(ScriptingError::Lua)?;
136
137    // Build vars table
138    let vars_table = lua.create_table().map_err(ScriptingError::Lua)?;
139    for (k, v) in vars {
140        vars_table.set(k.as_str(), v.as_str()).map_err(ScriptingError::Lua)?;
141    }
142
143    // Build target table
144    let target_table = lua.create_table().map_err(ScriptingError::Lua)?;
145    target_table.set("file", spec.target.file.as_str()).map_err(ScriptingError::Lua)?;
146    if let Some(section) = &spec.target.section {
147        target_table.set("section", section.as_str()).map_err(ScriptingError::Lua)?;
148    }
149    let position_str = match spec.target.position {
150        super::types::CapturePosition::Begin => "begin",
151        super::types::CapturePosition::End => "end",
152    };
153    target_table.set("position", position_str).map_err(ScriptingError::Lua)?;
154
155    // Build result table
156    let result_table = lua.create_table().map_err(ScriptingError::Lua)?;
157    result_table
158        .set("target_file", target_file.to_string_lossy().as_ref())
159        .map_err(ScriptingError::Lua)?;
160    result_table.set("success", true).map_err(ScriptingError::Lua)?;
161    if let Some((section_title, level)) = section_matched {
162        result_table.set("section_title", section_title).map_err(ScriptingError::Lua)?;
163        result_table.set("section_level", level).map_err(ScriptingError::Lua)?;
164    }
165
166    // Call the hook: after_insert(content, vars, target, result)
167    let _: mlua::Value = hook_fn
168        .call((content, vars_table, target_table, result_table))
169        .map_err(ScriptingError::Lua)?;
170
171    Ok(AfterInsertResult { success: true })
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::captures::lua_loader::load_capture_from_lua;
    use std::fs;
    use tempfile::TempDir;

    /// Write `content` to `<dir>/<name>.lua` and return the resulting path.
    fn write_lua_capture(
        dir: &std::path::Path,
        name: &str,
        content: &str,
    ) -> std::path::PathBuf {
        let path = dir.join(format!("{}.lua", name));
        fs::write(&path, content).unwrap();
        path
    }

    #[test]
    fn test_before_insert_hook_modifies_content() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return "[HOOK] " .. content
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        assert!(spec.has_before_insert);

        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();

        assert_eq!(result.content, "[HOOK] - hello");
    }

    #[test]
    fn test_before_insert_hook_returns_nil() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return nil -- Let original content through
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();

        assert_eq!(result.content, "- hello");
    }

    #[test]
    fn test_before_insert_hook_receives_vars_and_target() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        if target.position ~= "begin" and target.position ~= "end" then
            error("unexpected position")
        end
        return content .. "|" .. vars.text .. "|" .. target.section
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();

        assert_eq!(result.content, "- hello|hello|Test");
    }

    #[test]
    fn test_before_insert_hook_rejects_non_string_return() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    before_insert = function(content, vars, target)
        return 42
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();

        assert!(run_before_insert_hook(&spec, "- hello", &vars).is_err());
    }

    #[test]
    fn test_no_hook_passes_through() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        assert!(!spec.has_before_insert);

        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();

        assert_eq!(result.content, "- hello");
    }

    #[test]
    fn test_after_insert_hook_runs() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    after_insert = function(content, vars, target, result)
        -- Side effect only, return value ignored
        print("Inserted: " .. content)
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        assert!(spec.has_after_insert);

        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_after_insert_hook(
            &spec,
            "- hello",
            &vars,
            std::path::Path::new("/tmp/test.md"),
            Some(("Test", 2)),
        )
        .unwrap();

        assert!(result.success);
    }

    #[test]
    fn test_after_insert_hook_receives_result_table() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
    after_insert = function(content, vars, target, result)
        if result.success ~= true then error("success not set") end
        if result.target_file ~= "/tmp/test.md" then error("bad target_file") end
        if result.section_title ~= "Test" then error("bad section_title") end
        if result.section_level ~= 2 then error("bad section_level") end
    end,
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_after_insert_hook(
            &spec,
            "- hello",
            &vars,
            std::path::Path::new("/tmp/test.md"),
            Some(("Test", 2)),
        )
        .unwrap();

        assert!(result.success);
    }

    #[test]
    fn test_no_after_insert_hook_succeeds() {
        let temp = TempDir::new().unwrap();
        let path = write_lua_capture(
            temp.path(),
            "test",
            r#"
return {
    name = "test",
    target = { file = "test.md", section = "Test" },
    content = "- {{text}}",
}
"#,
        );

        let spec = load_capture_from_lua(&path).unwrap();
        assert!(!spec.has_after_insert);

        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
        let result = run_after_insert_hook(
            &spec,
            "- hello",
            &vars,
            std::path::Path::new("/tmp/test.md"),
            None,
        )
        .unwrap();

        assert!(result.success);
    }
}