Skip to main content

mdvault_core/scripting/
hook_runner.rs

1//! Hook execution for lifecycle events.
2//!
3//! This module provides functions to run lifecycle hooks defined in type definitions.
4
5use super::engine::LuaEngine;
6use super::hooks::{HookError, NoteContext};
7use super::types::SandboxConfig;
8use super::vault_context::VaultContext;
9use crate::types::definition::TypeDefinition;
10use crate::types::validation::yaml_to_lua_table;
11use tracing::debug;
12
13/// Result of running a hook that may modify the note.
14#[derive(Debug)]
15pub struct HookResult {
16    /// Whether the hook made changes to the note.
17    pub modified: bool,
18    /// The updated frontmatter (if modified).
19    pub frontmatter: Option<serde_yaml::Value>,
20    /// The updated content (if modified).
21    pub content: Option<String>,
22    /// The updated template variables (if modified).
23    pub variables: Option<serde_yaml::Value>,
24}
25
26/// Alias for backwards compatibility.
27pub type UpdateHookResult = HookResult;
28
29/// Run the `on_create` hook for a type definition.
30///
31/// This function is called after a note is created to allow the type definition
32/// to perform additional operations like logging to daily notes, updating indexes,
33/// or modifying the note's frontmatter.
34///
35/// # Arguments
36///
37/// * `typedef` - The type definition containing the hook
38/// * `note_ctx` - Context about the created note
39/// * `vault_ctx` - Vault context with access to repositories
40///
41/// # Returns
42///
43/// * `Ok(HookResult)` with any modifications from the hook
44/// * `Err(HookError)` on failure
45///
46/// # Example
47///
48/// ```ignore
49/// use mdvault_core::scripting::{run_on_create_hook, NoteContext, VaultContext};
50///
51/// let note_ctx = NoteContext::new(path, "task".into(), frontmatter, content, variables);
52/// let result = run_on_create_hook(&typedef, &note_ctx, vault_ctx)?;
53/// if result.modified {
54///     // Write back the updated content
55/// }
56/// ```
57pub fn run_on_create_hook(
58    typedef: &TypeDefinition,
59    note_ctx: &NoteContext,
60    vault_ctx: VaultContext,
61) -> Result<HookResult, HookError> {
62    // Skip if no hook defined
63    if !typedef.has_on_create_hook {
64        return Ok(HookResult {
65            modified: false,
66            frontmatter: None,
67            content: None,
68            variables: None,
69        });
70    }
71
72    // Create engine with vault context
73    let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
74        .map_err(|e| HookError::LuaError(e.to_string()))?;
75
76    let lua = engine.lua();
77
78    // Load and evaluate the type definition to get the table
79    let typedef_table: mlua::Table =
80        lua.load(&typedef.lua_source).eval().map_err(|e| {
81            HookError::LuaError(format!("failed to load type definition: {}", e))
82        })?;
83
84    // Build note table for the hook
85    let note_table =
86        lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
87
88    note_table
89        .set("path", note_ctx.path.to_string_lossy().to_string())
90        .map_err(|e| HookError::LuaError(e.to_string()))?;
91
92    note_table
93        .set("type", note_ctx.note_type.clone())
94        .map_err(|e| HookError::LuaError(e.to_string()))?;
95
96    note_table
97        .set("content", note_ctx.content.clone())
98        .map_err(|e| HookError::LuaError(e.to_string()))?;
99
100    // Convert frontmatter to Lua table
101    let fm_table = yaml_to_lua_table(lua, &note_ctx.frontmatter)
102        .map_err(|e| HookError::LuaError(e.to_string()))?;
103
104    note_table
105        .set("frontmatter", fm_table)
106        .map_err(|e| HookError::LuaError(e.to_string()))?;
107
108    // Convert variables to Lua table
109    debug!("Hook input variables: {:?}", note_ctx.variables);
110    let vars_table = yaml_to_lua_table(lua, &note_ctx.variables)
111        .map_err(|e| HookError::LuaError(e.to_string()))?;
112
113    note_table
114        .set("variables", vars_table)
115        .map_err(|e| HookError::LuaError(e.to_string()))?;
116
117    // Get on_create function
118    let on_create_fn: mlua::Function = typedef_table.get("on_create").map_err(|e| {
119        HookError::LuaError(format!("on_create function not found: {}", e))
120    })?;
121
122    // Call the hook - it may return a modified note table
123    let result: mlua::Value = on_create_fn
124        .call(note_table)
125        .map_err(|e| HookError::Execution(format!("on_create hook failed: {}", e)))?;
126
127    // Check if hook returned a modified note
128    match result {
129        mlua::Value::Table(returned_note) => {
130            // Extract frontmatter, content, and variables if present
131            let frontmatter: Option<serde_yaml::Value> =
132                if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
133                    Some(lua_table_to_yaml(&fm_table)?)
134                } else {
135                    None
136                };
137
138            let content: Option<String> = returned_note.get("content").ok();
139            let content = match content {
140                Some(ref c) if c != &note_ctx.content => content,
141                _ => None,
142            };
143
144            let variables: Option<serde_yaml::Value> =
145                if let Ok(vars_table) = returned_note.get::<mlua::Table>("variables") {
146                    let v = Some(lua_table_to_yaml(&vars_table)?);
147                    debug!("Hook output variables: {:?}", v);
148                    v
149                } else {
150                    None
151                };
152
153            let modified =
154                frontmatter.is_some() || content.is_some() || variables.is_some();
155            Ok(HookResult { modified, frontmatter, content, variables })
156        }
157        mlua::Value::Nil => {
158            // Hook returned nil, no modifications
159            Ok(HookResult {
160                modified: false,
161                frontmatter: None,
162                content: None,
163                variables: None,
164            })
165        }
166        _ => {
167            // Unexpected return type
168            Ok(HookResult {
169                modified: false,
170                frontmatter: None,
171                content: None,
172                variables: None,
173            })
174        }
175    }
176}
177
178/// Run the `on_update` hook for a type definition.
179///
180/// This function is called after a note is modified (via capture operations) to allow
181/// the type definition to perform additional operations like updating timestamps.
182///
183/// Unlike `on_create`, this hook can return a modified note which will be written back.
184///
185/// # Arguments
186///
187/// * `typedef` - The type definition containing the hook
188/// * `note_ctx` - Context about the updated note
189/// * `vault_ctx` - Vault context with access to repositories
190///
191/// # Returns
192///
193/// * `Ok(UpdateHookResult)` with any modifications from the hook
194/// * `Err(HookError)` on failure
195///
196/// # Example
197///
198/// ```ignore
199/// use mdvault_core::scripting::{run_on_update_hook, NoteContext, VaultContext};
200///
201/// let note_ctx = NoteContext::new(path, "task".into(), frontmatter, content);
202/// let result = run_on_update_hook(&typedef, &note_ctx, vault_ctx)?;
203/// if result.modified {
204///     // Write back the updated content
205/// }
206/// ```
207pub fn run_on_update_hook(
208    typedef: &TypeDefinition,
209    note_ctx: &NoteContext,
210    vault_ctx: VaultContext,
211) -> Result<UpdateHookResult, HookError> {
212    // Skip if no hook defined
213    if !typedef.has_on_update_hook {
214        return Ok(UpdateHookResult {
215            modified: false,
216            frontmatter: None,
217            content: None,
218            variables: None,
219        });
220    }
221
222    // Create engine with vault context
223    let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
224        .map_err(|e| HookError::LuaError(e.to_string()))?;
225
226    let lua = engine.lua();
227
228    // Load and evaluate the type definition to get the table
229    let typedef_table: mlua::Table =
230        lua.load(&typedef.lua_source).eval().map_err(|e| {
231            HookError::LuaError(format!("failed to load type definition: {}", e))
232        })?;
233
234    // Build note table for the hook
235    let note_table =
236        lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
237
238    note_table
239        .set("path", note_ctx.path.to_string_lossy().to_string())
240        .map_err(|e| HookError::LuaError(e.to_string()))?;
241
242    note_table
243        .set("type", note_ctx.note_type.clone())
244        .map_err(|e| HookError::LuaError(e.to_string()))?;
245
246    note_table
247        .set("content", note_ctx.content.clone())
248        .map_err(|e| HookError::LuaError(e.to_string()))?;
249
250    // Convert frontmatter to Lua table
251    let fm_table = yaml_to_lua_table(lua, &note_ctx.frontmatter)
252        .map_err(|e| HookError::LuaError(e.to_string()))?;
253
254    note_table
255        .set("frontmatter", fm_table)
256        .map_err(|e| HookError::LuaError(e.to_string()))?;
257
258    // Get on_update function
259    let on_update_fn: mlua::Function = typedef_table.get("on_update").map_err(|e| {
260        HookError::LuaError(format!("on_update function not found: {}", e))
261    })?;
262
263    // Call the hook - it may return a modified note table
264    let result: mlua::Value = on_update_fn
265        .call(note_table)
266        .map_err(|e| HookError::Execution(format!("on_update hook failed: {}", e)))?;
267
268    // Check if hook returned a modified note
269    match result {
270        mlua::Value::Table(returned_note) => {
271            // Extract frontmatter and content if present
272            let frontmatter: Option<serde_yaml::Value> =
273                if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
274                    Some(lua_table_to_yaml(&fm_table)?)
275                } else {
276                    None
277                };
278
279            let content: Option<String> = returned_note.get("content").ok();
280
281            let modified = frontmatter.is_some() || content.is_some();
282            Ok(UpdateHookResult { modified, frontmatter, content, variables: None })
283        }
284        mlua::Value::Nil => {
285            // Hook returned nil, no modifications
286            Ok(UpdateHookResult {
287                modified: false,
288                frontmatter: None,
289                content: None,
290                variables: None,
291            })
292        }
293        _ => {
294            // Unexpected return type
295            Ok(UpdateHookResult {
296                modified: false,
297                frontmatter: None,
298                content: None,
299                variables: None,
300            })
301        }
302    }
303}
304
305/// Convert a Lua table to serde_yaml::Value.
306fn lua_table_to_yaml(table: &mlua::Table) -> Result<serde_yaml::Value, HookError> {
307    let mut map = serde_yaml::Mapping::new();
308
309    for pair in table.pairs::<mlua::Value, mlua::Value>() {
310        let (key, value) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
311
312        let yaml_key = match key {
313            mlua::Value::String(s) => {
314                let str_val =
315                    s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
316                serde_yaml::Value::String(str_val.to_string())
317            }
318            mlua::Value::Integer(i) => serde_yaml::Value::Number(i.into()),
319            _ => continue, // Skip non-string/integer keys
320        };
321
322        let yaml_value = lua_value_to_yaml(value)?;
323        map.insert(yaml_key, yaml_value);
324    }
325
326    Ok(serde_yaml::Value::Mapping(map))
327}
328
329/// Convert a single Lua value to serde_yaml::Value.
330fn lua_value_to_yaml(value: mlua::Value) -> Result<serde_yaml::Value, HookError> {
331    match value {
332        mlua::Value::Nil => Ok(serde_yaml::Value::Null),
333        mlua::Value::Boolean(b) => Ok(serde_yaml::Value::Bool(b)),
334        mlua::Value::Integer(i) => Ok(serde_yaml::Value::Number(i.into())),
335        mlua::Value::Number(n) => {
336            Ok(serde_yaml::Value::Number(serde_yaml::Number::from(n)))
337        }
338        mlua::Value::String(s) => {
339            let str_val = s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
340            Ok(serde_yaml::Value::String(str_val.to_string()))
341        }
342        mlua::Value::Table(t) => {
343            // Check if it's an array or a map
344            if is_lua_array(&t) {
345                let mut seq = Vec::new();
346                for pair in t.pairs::<i64, mlua::Value>() {
347                    let (_, v) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
348                    seq.push(lua_value_to_yaml(v)?);
349                }
350                Ok(serde_yaml::Value::Sequence(seq))
351            } else {
352                lua_table_to_yaml(&t)
353            }
354        }
355        _ => Ok(serde_yaml::Value::Null),
356    }
357}
358
359/// Check if a Lua table is an array (sequential integer keys starting from 1).
360fn is_lua_array(table: &mlua::Table) -> bool {
361    let len = table.raw_len();
362    if len == 0 {
363        // Could be empty table, check for any keys
364        table.pairs::<mlua::Value, mlua::Value>().next().is_none()
365    } else {
366        // Check if keys are 1..=len
367        for i in 1..=len {
368            if table.raw_get::<mlua::Value>(i).is_err() {
369                return false;
370            }
371        }
372        true
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Build a minimal `test` type definition around `lua_source`.
    fn make_typedef(lua_source: &str, has_on_create_hook: bool) -> TypeDefinition {
        TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            output: None,
            frontmatter_order: None,
            variables: crate::vars::VarsMap::new(),
            has_validate_fn: false,
            has_on_create_hook,
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: lua_source.to_string(),
        }
    }

    /// Build a type definition that declares an `on_create` hook.
    fn make_typedef_with_hook(lua_source: &str) -> TypeDefinition {
        make_typedef(lua_source, true)
    }

    fn make_note_ctx() -> NoteContext {
        NoteContext {
            path: PathBuf::from("test.md"),
            note_type: "test".to_string(),
            frontmatter: serde_yaml::Value::Mapping(serde_yaml::Mapping::new()),
            content: "# Test".to_string(),
            variables: serde_yaml::Value::Null,
        }
    }

    #[test]
    fn test_skip_if_no_hook() {
        // When `has_on_create_hook` is false, `run_on_create_hook` returns an
        // unmodified `HookResult` before touching `vault_ctx`. Building a
        // `VaultContext` needs real repositories, so this test can only check the
        // flag that guards that early return.
        let typedef = make_typedef("", false);
        assert!(!typedef.has_on_create_hook);
        assert!(typedef.lua_source.is_empty());
    }

    #[test]
    fn test_hook_receives_note_context() {
        // Exercise the Lua side of the hook protocol directly; calling
        // `run_on_create_hook` would require a real `VaultContext`.
        let lua_source = r#"
            return {
                on_create = function(note)
                    return {
                        path = note.path,
                        type = note.type,
                        content = note.content .. "\nappended",
                        frontmatter = note.frontmatter,
                    }
                end
            }
        "#;

        let typedef = make_typedef_with_hook(lua_source);
        let note_ctx = make_note_ctx();

        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let typedef_table: mlua::Table = lua.load(&typedef.lua_source).eval().unwrap();

        let note_table = lua.create_table().unwrap();
        note_table.set("path", note_ctx.path.to_string_lossy().to_string()).unwrap();
        note_table.set("type", note_ctx.note_type.clone()).unwrap();
        note_table.set("content", note_ctx.content.clone()).unwrap();
        note_table.set("frontmatter", lua.create_table().unwrap()).unwrap();

        let on_create: mlua::Function = typedef_table.get("on_create").unwrap();
        let returned: mlua::Table = on_create.call(note_table).unwrap();

        assert_eq!(returned.get::<String>("path").unwrap(), "test.md");
        assert_eq!(returned.get::<String>("type").unwrap(), "test");
        assert_eq!(returned.get::<String>("content").unwrap(), "# Test\nappended");
    }

    #[test]
    fn test_lua_table_to_yaml_converts_nested_values() {
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();
        let table: mlua::Table = lua
            .load(
                r#"return {
                    title = "Hello",
                    done = true,
                    tags = { "a", "b" },
                    meta = { status = "open" },
                }"#,
            )
            .eval()
            .unwrap();

        let yaml = lua_table_to_yaml(&table).unwrap();

        assert_eq!(yaml["title"], serde_yaml::Value::from("Hello"));
        assert_eq!(yaml["done"], serde_yaml::Value::Bool(true));
        assert_eq!(
            yaml["tags"],
            serde_yaml::Value::Sequence(vec!["a".into(), "b".into()])
        );
        assert_eq!(yaml["meta"]["status"], serde_yaml::Value::from("open"));
    }
}