Skip to main content

mdvault_core/captures/
hooks.rs

//! Capture lifecycle hooks execution.
//!
//! This module provides support for running `before_insert` and `after_insert`
//! hooks defined in Lua capture specifications.
5
6use std::collections::HashMap;
7
8use crate::scripting::{LuaEngine, ScriptingError};
9
10use super::types::CaptureSpec;
11
/// Result of running a `before_insert` hook.
///
/// `content` is the text to insert. It is either the string returned by the
/// hook, or the original rendered content when the capture has no
/// `before_insert` hook or the hook returned `nil`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BeforeInsertResult {
    /// Modified content to insert (or the original if unchanged).
    pub content: String,
}
18
19/// Run the before_insert hook if defined.
20///
21/// The hook receives:
22/// - content: The rendered content template
23/// - vars: Table of all variables
24/// - target: Table with file, section, position
25///
26/// Returns: Modified content string
27pub fn run_before_insert_hook(
28    spec: &CaptureSpec,
29    content: &str,
30    vars: &HashMap<String, String>,
31) -> Result<BeforeInsertResult, ScriptingError> {
32    // If no hook defined, return content unchanged
33    if !spec.has_before_insert {
34        return Ok(BeforeInsertResult { content: content.to_string() });
35    }
36
37    // Get Lua source
38    let lua_source = spec
39        .lua_source
40        .as_ref()
41        .ok_or_else(|| ScriptingError::Other("Capture has no Lua source".to_string()))?;
42
43    // Create Lua engine
44    let engine = LuaEngine::sandboxed()?;
45    let lua = engine.lua();
46
47    // Execute the capture definition to get the table
48    let capture_table: mlua::Table =
49        lua.load(lua_source).eval().map_err(ScriptingError::Lua)?;
50
51    // Get the before_insert function
52    let hook_fn: mlua::Function = capture_table
53        .get("before_insert")
54        .map_err(ScriptingError::Lua)?;
55
56    // Build vars table
57    let vars_table = lua.create_table().map_err(ScriptingError::Lua)?;
58    for (k, v) in vars {
59        vars_table.set(k.as_str(), v.as_str()).map_err(ScriptingError::Lua)?;
60    }
61
62    // Build target table
63    let target_table = lua.create_table().map_err(ScriptingError::Lua)?;
64    target_table
65        .set("file", spec.target.file.as_str())
66        .map_err(ScriptingError::Lua)?;
67    if let Some(section) = &spec.target.section {
68        target_table.set("section", section.as_str()).map_err(ScriptingError::Lua)?;
69    }
70    let position_str = match spec.target.position {
71        super::types::CapturePosition::Begin => "begin",
72        super::types::CapturePosition::End => "end",
73    };
74    target_table.set("position", position_str).map_err(ScriptingError::Lua)?;
75
76    // Call the hook: before_insert(content, vars, target)
77    let result: mlua::Value = hook_fn
78        .call((content, vars_table, target_table))
79        .map_err(ScriptingError::Lua)?;
80
81    // Extract result - should be a string (modified content)
82    let modified_content = match result {
83        mlua::Value::String(s) => s.to_str().map_err(ScriptingError::Lua)?.to_string(),
84        mlua::Value::Nil => content.to_string(), // Hook returned nil, use original
85        _ => {
86            return Err(ScriptingError::Other(
87                "before_insert hook must return a string or nil".to_string(),
88            ));
89        }
90    };
91
92    Ok(BeforeInsertResult { content: modified_content })
93}
94
/// Result of running an `after_insert` hook.
///
/// Failures while running the hook are reported through the `Err` variant of
/// the returning function, not through this flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AfterInsertResult {
    /// Whether the hook ran successfully.
    pub success: bool,
}
101
102/// Run the after_insert hook if defined.
103///
104/// The hook receives:
105/// - content: The content that was inserted
106/// - vars: Table of all variables
107/// - target: Table with file, section, position
108/// - result: Table with target_file path and success status
109///
110/// Returns: Nothing (hook is for side effects only)
111pub fn run_after_insert_hook(
112    spec: &CaptureSpec,
113    content: &str,
114    vars: &HashMap<String, String>,
115    target_file: &std::path::Path,
116    section_matched: Option<(&str, u8)>,
117) -> Result<AfterInsertResult, ScriptingError> {
118    // If no hook defined, return success
119    if !spec.has_after_insert {
120        return Ok(AfterInsertResult { success: true });
121    }
122
123    // Get Lua source
124    let lua_source = spec
125        .lua_source
126        .as_ref()
127        .ok_or_else(|| ScriptingError::Other("Capture has no Lua source".to_string()))?;
128
129    // Create Lua engine
130    let engine = LuaEngine::sandboxed()?;
131    let lua = engine.lua();
132
133    // Execute the capture definition to get the table
134    let capture_table: mlua::Table =
135        lua.load(lua_source).eval().map_err(ScriptingError::Lua)?;
136
137    // Get the after_insert function
138    let hook_fn: mlua::Function = capture_table
139        .get("after_insert")
140        .map_err(ScriptingError::Lua)?;
141
142    // Build vars table
143    let vars_table = lua.create_table().map_err(ScriptingError::Lua)?;
144    for (k, v) in vars {
145        vars_table.set(k.as_str(), v.as_str()).map_err(ScriptingError::Lua)?;
146    }
147
148    // Build target table
149    let target_table = lua.create_table().map_err(ScriptingError::Lua)?;
150    target_table
151        .set("file", spec.target.file.as_str())
152        .map_err(ScriptingError::Lua)?;
153    if let Some(section) = &spec.target.section {
154        target_table.set("section", section.as_str()).map_err(ScriptingError::Lua)?;
155    }
156    let position_str = match spec.target.position {
157        super::types::CapturePosition::Begin => "begin",
158        super::types::CapturePosition::End => "end",
159    };
160    target_table.set("position", position_str).map_err(ScriptingError::Lua)?;
161
162    // Build result table
163    let result_table = lua.create_table().map_err(ScriptingError::Lua)?;
164    result_table
165        .set("target_file", target_file.to_string_lossy().as_ref())
166        .map_err(ScriptingError::Lua)?;
167    result_table.set("success", true).map_err(ScriptingError::Lua)?;
168    if let Some((section_title, level)) = section_matched {
169        result_table.set("section_title", section_title).map_err(ScriptingError::Lua)?;
170        result_table.set("section_level", level).map_err(ScriptingError::Lua)?;
171    }
172
173    // Call the hook: after_insert(content, vars, target, result)
174    let _: mlua::Value = hook_fn
175        .call((content, vars_table, target_table, result_table))
176        .map_err(ScriptingError::Lua)?;
177
178    Ok(AfterInsertResult { success: true })
179}
180
181#[cfg(test)]
182mod tests {
183    use super::*;
184    use crate::captures::lua_loader::load_capture_from_lua;
185    use std::fs;
186    use tempfile::TempDir;
187
188    fn write_lua_capture(dir: &std::path::Path, name: &str, content: &str) -> std::path::PathBuf {
189        let path = dir.join(format!("{}.lua", name));
190        fs::write(&path, content).unwrap();
191        path
192    }
193
194    #[test]
195    fn test_before_insert_hook_modifies_content() {
196        let temp = TempDir::new().unwrap();
197        let path = write_lua_capture(
198            temp.path(),
199            "test",
200            r#"
201return {
202    name = "test",
203    target = { file = "test.md", section = "Test" },
204    content = "- {{text}}",
205    before_insert = function(content, vars, target)
206        return "[HOOK] " .. content
207    end,
208}
209"#,
210        );
211
212        let spec = load_capture_from_lua(&path).unwrap();
213        assert!(spec.has_before_insert);
214
215        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
216        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();
217
218        assert_eq!(result.content, "[HOOK] - hello");
219    }
220
221    #[test]
222    fn test_before_insert_hook_returns_nil() {
223        let temp = TempDir::new().unwrap();
224        let path = write_lua_capture(
225            temp.path(),
226            "test",
227            r#"
228return {
229    name = "test",
230    target = { file = "test.md", section = "Test" },
231    content = "- {{text}}",
232    before_insert = function(content, vars, target)
233        return nil -- Let original content through
234    end,
235}
236"#,
237        );
238
239        let spec = load_capture_from_lua(&path).unwrap();
240        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
241        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();
242
243        assert_eq!(result.content, "- hello");
244    }
245
246    #[test]
247    fn test_no_hook_passes_through() {
248        let temp = TempDir::new().unwrap();
249        let path = write_lua_capture(
250            temp.path(),
251            "test",
252            r#"
253return {
254    name = "test",
255    target = { file = "test.md", section = "Test" },
256    content = "- {{text}}",
257}
258"#,
259        );
260
261        let spec = load_capture_from_lua(&path).unwrap();
262        assert!(!spec.has_before_insert);
263
264        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
265        let result = run_before_insert_hook(&spec, "- hello", &vars).unwrap();
266
267        assert_eq!(result.content, "- hello");
268    }
269
270    #[test]
271    fn test_after_insert_hook_runs() {
272        let temp = TempDir::new().unwrap();
273        let path = write_lua_capture(
274            temp.path(),
275            "test",
276            r#"
277return {
278    name = "test",
279    target = { file = "test.md", section = "Test" },
280    content = "- {{text}}",
281    after_insert = function(content, vars, target, result)
282        -- Side effect only, return value ignored
283        print("Inserted: " .. content)
284    end,
285}
286"#,
287        );
288
289        let spec = load_capture_from_lua(&path).unwrap();
290        assert!(spec.has_after_insert);
291
292        let vars: HashMap<String, String> = [("text".into(), "hello".into())].into();
293        let result = run_after_insert_hook(
294            &spec,
295            "- hello",
296            &vars,
297            std::path::Path::new("/tmp/test.md"),
298            Some(("Test", 2)),
299        )
300        .unwrap();
301
302        assert!(result.success);
303    }
304}