mdvault_core/scripting/
hook_runner.rs

1//! Hook execution for lifecycle events.
2//!
3//! This module provides functions to run lifecycle hooks defined in type definitions.
4
5use super::engine::LuaEngine;
6use super::hooks::{HookError, NoteContext};
7use super::types::SandboxConfig;
8use super::vault_context::VaultContext;
9use crate::types::definition::TypeDefinition;
10use crate::types::validation::yaml_to_lua_table;
11
12/// Result of running an `on_update` hook that may modify the note.
13#[derive(Debug)]
14pub struct UpdateHookResult {
15    /// Whether the hook made changes to the note.
16    pub modified: bool,
17    /// The updated frontmatter (if modified).
18    pub frontmatter: Option<serde_yaml::Value>,
19    /// The updated content (if modified).
20    pub content: Option<String>,
21}
22
23/// Run the `on_create` hook for a type definition.
24///
25/// This function is called after a note is created to allow the type definition
26/// to perform additional operations like logging to daily notes or updating indexes.
27///
28/// # Arguments
29///
30/// * `typedef` - The type definition containing the hook
31/// * `note_ctx` - Context about the created note
32/// * `vault_ctx` - Vault context with access to repositories
33///
34/// # Returns
35///
36/// * `Ok(())` if the hook succeeds or doesn't exist
37/// * `Err(HookError)` on failure
38///
39/// # Example
40///
41/// ```ignore
42/// use mdvault_core::scripting::{run_on_create_hook, NoteContext, VaultContext};
43///
44/// let note_ctx = NoteContext::new(path, "task".into(), frontmatter, content);
45/// run_on_create_hook(&typedef, &note_ctx, vault_ctx)?;
46/// ```
47pub fn run_on_create_hook(
48    typedef: &TypeDefinition,
49    note_ctx: &NoteContext,
50    vault_ctx: VaultContext,
51) -> Result<(), HookError> {
52    // Skip if no hook defined
53    if !typedef.has_on_create_hook {
54        return Ok(());
55    }
56
57    // Create engine with vault context
58    let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
59        .map_err(|e| HookError::LuaError(e.to_string()))?;
60
61    let lua = engine.lua();
62
63    // Load and evaluate the type definition to get the table
64    let typedef_table: mlua::Table =
65        lua.load(&typedef.lua_source).eval().map_err(|e| {
66            HookError::LuaError(format!("failed to load type definition: {}", e))
67        })?;
68
69    // Build note table for the hook
70    let note_table =
71        lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
72
73    note_table
74        .set("path", note_ctx.path.to_string_lossy().to_string())
75        .map_err(|e| HookError::LuaError(e.to_string()))?;
76
77    note_table
78        .set("type", note_ctx.note_type.clone())
79        .map_err(|e| HookError::LuaError(e.to_string()))?;
80
81    note_table
82        .set("content", note_ctx.content.clone())
83        .map_err(|e| HookError::LuaError(e.to_string()))?;
84
85    // Convert frontmatter to Lua table
86    let fm_table = yaml_to_lua_table(lua, &note_ctx.frontmatter)
87        .map_err(|e| HookError::LuaError(e.to_string()))?;
88
89    note_table
90        .set("frontmatter", fm_table)
91        .map_err(|e| HookError::LuaError(e.to_string()))?;
92
93    // Get on_create function
94    let on_create_fn: mlua::Function = typedef_table.get("on_create").map_err(|e| {
95        HookError::LuaError(format!("on_create function not found: {}", e))
96    })?;
97
98    // Call the hook
99    // The hook receives the note table and can call mdv.template/capture/macro
100    // We don't currently use the return value, but hooks are expected to return the note
101    on_create_fn
102        .call::<()>(note_table)
103        .map_err(|e| HookError::Execution(format!("on_create hook failed: {}", e)))?;
104
105    Ok(())
106}
107
108/// Run the `on_update` hook for a type definition.
109///
110/// This function is called after a note is modified (via capture operations) to allow
111/// the type definition to perform additional operations like updating timestamps.
112///
113/// Unlike `on_create`, this hook can return a modified note which will be written back.
114///
115/// # Arguments
116///
117/// * `typedef` - The type definition containing the hook
118/// * `note_ctx` - Context about the updated note
119/// * `vault_ctx` - Vault context with access to repositories
120///
121/// # Returns
122///
123/// * `Ok(UpdateHookResult)` with any modifications from the hook
124/// * `Err(HookError)` on failure
125///
126/// # Example
127///
128/// ```ignore
129/// use mdvault_core::scripting::{run_on_update_hook, NoteContext, VaultContext};
130///
131/// let note_ctx = NoteContext::new(path, "task".into(), frontmatter, content);
132/// let result = run_on_update_hook(&typedef, &note_ctx, vault_ctx)?;
133/// if result.modified {
134///     // Write back the updated content
135/// }
136/// ```
137pub fn run_on_update_hook(
138    typedef: &TypeDefinition,
139    note_ctx: &NoteContext,
140    vault_ctx: VaultContext,
141) -> Result<UpdateHookResult, HookError> {
142    // Skip if no hook defined
143    if !typedef.has_on_update_hook {
144        return Ok(UpdateHookResult {
145            modified: false,
146            frontmatter: None,
147            content: None,
148        });
149    }
150
151    // Create engine with vault context
152    let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
153        .map_err(|e| HookError::LuaError(e.to_string()))?;
154
155    let lua = engine.lua();
156
157    // Load and evaluate the type definition to get the table
158    let typedef_table: mlua::Table =
159        lua.load(&typedef.lua_source).eval().map_err(|e| {
160            HookError::LuaError(format!("failed to load type definition: {}", e))
161        })?;
162
163    // Build note table for the hook
164    let note_table =
165        lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
166
167    note_table
168        .set("path", note_ctx.path.to_string_lossy().to_string())
169        .map_err(|e| HookError::LuaError(e.to_string()))?;
170
171    note_table
172        .set("type", note_ctx.note_type.clone())
173        .map_err(|e| HookError::LuaError(e.to_string()))?;
174
175    note_table
176        .set("content", note_ctx.content.clone())
177        .map_err(|e| HookError::LuaError(e.to_string()))?;
178
179    // Convert frontmatter to Lua table
180    let fm_table = yaml_to_lua_table(lua, &note_ctx.frontmatter)
181        .map_err(|e| HookError::LuaError(e.to_string()))?;
182
183    note_table
184        .set("frontmatter", fm_table)
185        .map_err(|e| HookError::LuaError(e.to_string()))?;
186
187    // Get on_update function
188    let on_update_fn: mlua::Function = typedef_table.get("on_update").map_err(|e| {
189        HookError::LuaError(format!("on_update function not found: {}", e))
190    })?;
191
192    // Call the hook - it may return a modified note table
193    let result: mlua::Value = on_update_fn
194        .call(note_table)
195        .map_err(|e| HookError::Execution(format!("on_update hook failed: {}", e)))?;
196
197    // Check if hook returned a modified note
198    match result {
199        mlua::Value::Table(returned_note) => {
200            // Extract frontmatter and content if present
201            let frontmatter: Option<serde_yaml::Value> =
202                if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
203                    Some(lua_table_to_yaml(&fm_table)?)
204                } else {
205                    None
206                };
207
208            let content: Option<String> = returned_note.get("content").ok();
209
210            let modified = frontmatter.is_some() || content.is_some();
211            Ok(UpdateHookResult { modified, frontmatter, content })
212        }
213        mlua::Value::Nil => {
214            // Hook returned nil, no modifications
215            Ok(UpdateHookResult { modified: false, frontmatter: None, content: None })
216        }
217        _ => {
218            // Unexpected return type
219            Ok(UpdateHookResult { modified: false, frontmatter: None, content: None })
220        }
221    }
222}
223
224/// Convert a Lua table to serde_yaml::Value.
225fn lua_table_to_yaml(table: &mlua::Table) -> Result<serde_yaml::Value, HookError> {
226    let mut map = serde_yaml::Mapping::new();
227
228    for pair in table.pairs::<mlua::Value, mlua::Value>() {
229        let (key, value) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
230
231        let yaml_key = match key {
232            mlua::Value::String(s) => {
233                let str_val =
234                    s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
235                serde_yaml::Value::String(str_val.to_string())
236            }
237            mlua::Value::Integer(i) => serde_yaml::Value::Number(i.into()),
238            _ => continue, // Skip non-string/integer keys
239        };
240
241        let yaml_value = lua_value_to_yaml(value)?;
242        map.insert(yaml_key, yaml_value);
243    }
244
245    Ok(serde_yaml::Value::Mapping(map))
246}
247
248/// Convert a single Lua value to serde_yaml::Value.
249fn lua_value_to_yaml(value: mlua::Value) -> Result<serde_yaml::Value, HookError> {
250    match value {
251        mlua::Value::Nil => Ok(serde_yaml::Value::Null),
252        mlua::Value::Boolean(b) => Ok(serde_yaml::Value::Bool(b)),
253        mlua::Value::Integer(i) => Ok(serde_yaml::Value::Number(i.into())),
254        mlua::Value::Number(n) => {
255            Ok(serde_yaml::Value::Number(serde_yaml::Number::from(n)))
256        }
257        mlua::Value::String(s) => {
258            let str_val = s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
259            Ok(serde_yaml::Value::String(str_val.to_string()))
260        }
261        mlua::Value::Table(t) => {
262            // Check if it's an array or a map
263            if is_lua_array(&t) {
264                let mut seq = Vec::new();
265                for pair in t.pairs::<i64, mlua::Value>() {
266                    let (_, v) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
267                    seq.push(lua_value_to_yaml(v)?);
268                }
269                Ok(serde_yaml::Value::Sequence(seq))
270            } else {
271                lua_table_to_yaml(&t)
272            }
273        }
274        _ => Ok(serde_yaml::Value::Null),
275    }
276}
277
278/// Check if a Lua table is an array (sequential integer keys starting from 1).
279fn is_lua_array(table: &mlua::Table) -> bool {
280    let len = table.raw_len();
281    if len == 0 {
282        // Could be empty table, check for any keys
283        table.pairs::<mlua::Value, mlua::Value>().next().is_none()
284    } else {
285        // Check if keys are 1..=len
286        for i in 1..=len {
287            if table.raw_get::<mlua::Value>(i).is_err() {
288                return false;
289            }
290        }
291        true
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    fn make_typedef_with_hook(lua_source: &str) -> TypeDefinition {
        TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            has_validate_fn: false,
            has_on_create_hook: true,
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: lua_source.to_string(),
        }
    }

    fn make_note_ctx() -> NoteContext {
        NoteContext {
            path: PathBuf::from("test.md"),
            note_type: "test".to_string(),
            frontmatter: serde_yaml::Value::Mapping(serde_yaml::Mapping::new()),
            content: "# Test".to_string(),
        }
    }

    /// Evaluate a Lua chunk that must return a table.
    fn eval_table(lua: &mlua::Lua, source: &str) -> mlua::Table {
        lua.load(source).eval().expect("Lua snippet should evaluate to a table")
    }

    #[test]
    fn test_skip_if_no_hook() {
        let typedef = TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            has_validate_fn: false,
            has_on_create_hook: false, // No hook
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: String::new(),
        };

        // `run_on_create_hook` needs a `VaultContext`, which cannot be built here
        // without real repositories. So this only checks the flag that the early
        // return keys off.
        assert!(!typedef.has_on_create_hook);
    }

    #[test]
    fn test_hook_receives_note_context() {
        // The hook echoes the note fields back so the test can check what it saw.
        let lua_source = r#"
            return {
                on_create = function(note)
                    return note.path, note.type, note.content, note.frontmatter
                end
            }
        "#;

        let typedef = make_typedef_with_hook(lua_source);
        let note_ctx = make_note_ctx();

        // Create a sandboxed engine to test the Lua code directly
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let typedef_table: mlua::Table = lua.load(&typedef.lua_source).eval().unwrap();

        // Build a note table with the same fields the hook runner sets.
        let note_table = lua.create_table().unwrap();
        note_table.set("path", note_ctx.path.to_string_lossy().to_string()).unwrap();
        note_table.set("type", note_ctx.note_type.clone()).unwrap();
        note_table.set("content", note_ctx.content.clone()).unwrap();
        note_table.set("frontmatter", lua.create_table().unwrap()).unwrap();

        let on_create: mlua::Function = typedef_table.get("on_create").unwrap();
        let (path, note_type, content, frontmatter): (String, String, String, mlua::Table) =
            on_create.call(note_table).unwrap();

        assert_eq!(path, "test.md");
        assert_eq!(note_type, "test");
        assert_eq!(content, "# Test");
        assert_eq!(lua_table_to_yaml(&frontmatter).ok(), Some(note_ctx.frontmatter.clone()));
    }

    #[test]
    fn test_lua_value_to_yaml_scalars() {
        assert_eq!(lua_value_to_yaml(mlua::Value::Nil).ok(), Some(serde_yaml::Value::Null));
        assert_eq!(
            lua_value_to_yaml(mlua::Value::Boolean(true)).ok(),
            Some(serde_yaml::Value::Bool(true))
        );
        assert_eq!(
            lua_value_to_yaml(mlua::Value::Integer(7)).ok(),
            Some(serde_yaml::Value::Number(serde_yaml::Number::from(7i64)))
        );
    }

    #[test]
    fn test_lua_table_to_yaml_nested_structures() {
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let table = eval_table(
            lua,
            r#"return { title = "Hello", count = 3, done = false, tags = { "a", "b", "c" } }"#,
        );
        let expected: serde_yaml::Value =
            serde_yaml::from_str("title: Hello\ncount: 3\ndone: false\ntags: [a, b, c]\n")
                .unwrap();

        assert_eq!(lua_table_to_yaml(&table).ok(), Some(expected));
    }

    #[test]
    fn test_empty_table_is_array() {
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let empty = lua.create_table().unwrap();
        assert!(is_lua_array(&empty));

        let map = eval_table(lua, r#"return { key = "value" }"#);
        assert!(!is_lua_array(&map));
    }
}