mdvault_core/scripting/
hook_runner.rs

1//! Hook execution for lifecycle events.
2//!
3//! This module provides functions to run lifecycle hooks defined in type definitions.
4
5use super::engine::LuaEngine;
6use super::hooks::{HookError, NoteContext};
7use super::types::SandboxConfig;
8use super::vault_context::VaultContext;
9use crate::types::definition::TypeDefinition;
10use crate::types::validation::yaml_to_lua_table;
11
12/// Result of running a hook that may modify the note.
13#[derive(Debug)]
14pub struct HookResult {
15    /// Whether the hook made changes to the note.
16    pub modified: bool,
17    /// The updated frontmatter (if modified).
18    pub frontmatter: Option<serde_yaml::Value>,
19    /// The updated content (if modified).
20    pub content: Option<String>,
21}
22
/// Alias of [`HookResult`] kept for backwards compatibility.
///
/// `run_on_update_hook` is declared to return this name; it is the same type as
/// [`HookResult`], so new code can use either.
pub type UpdateHookResult = HookResult;
25
26/// Run the `on_create` hook for a type definition.
27///
28/// This function is called after a note is created to allow the type definition
29/// to perform additional operations like logging to daily notes, updating indexes,
30/// or modifying the note's frontmatter.
31///
32/// # Arguments
33///
34/// * `typedef` - The type definition containing the hook
35/// * `note_ctx` - Context about the created note
36/// * `vault_ctx` - Vault context with access to repositories
37///
38/// # Returns
39///
40/// * `Ok(HookResult)` with any modifications from the hook
41/// * `Err(HookError)` on failure
42///
43/// # Example
44///
45/// ```ignore
46/// use mdvault_core::scripting::{run_on_create_hook, NoteContext, VaultContext};
47///
48/// let note_ctx = NoteContext::new(path, "task".into(), frontmatter, content);
49/// let result = run_on_create_hook(&typedef, &note_ctx, vault_ctx)?;
50/// if result.modified {
51///     // Write back the updated content
52/// }
53/// ```
54pub fn run_on_create_hook(
55    typedef: &TypeDefinition,
56    note_ctx: &NoteContext,
57    vault_ctx: VaultContext,
58) -> Result<HookResult, HookError> {
59    // Skip if no hook defined
60    if !typedef.has_on_create_hook {
61        return Ok(HookResult { modified: false, frontmatter: None, content: None });
62    }
63
64    // Create engine with vault context
65    let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
66        .map_err(|e| HookError::LuaError(e.to_string()))?;
67
68    let lua = engine.lua();
69
70    // Load and evaluate the type definition to get the table
71    let typedef_table: mlua::Table =
72        lua.load(&typedef.lua_source).eval().map_err(|e| {
73            HookError::LuaError(format!("failed to load type definition: {}", e))
74        })?;
75
76    // Build note table for the hook
77    let note_table =
78        lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
79
80    note_table
81        .set("path", note_ctx.path.to_string_lossy().to_string())
82        .map_err(|e| HookError::LuaError(e.to_string()))?;
83
84    note_table
85        .set("type", note_ctx.note_type.clone())
86        .map_err(|e| HookError::LuaError(e.to_string()))?;
87
88    note_table
89        .set("content", note_ctx.content.clone())
90        .map_err(|e| HookError::LuaError(e.to_string()))?;
91
92    // Convert frontmatter to Lua table
93    let fm_table = yaml_to_lua_table(lua, &note_ctx.frontmatter)
94        .map_err(|e| HookError::LuaError(e.to_string()))?;
95
96    note_table
97        .set("frontmatter", fm_table)
98        .map_err(|e| HookError::LuaError(e.to_string()))?;
99
100    // Get on_create function
101    let on_create_fn: mlua::Function = typedef_table.get("on_create").map_err(|e| {
102        HookError::LuaError(format!("on_create function not found: {}", e))
103    })?;
104
105    // Call the hook - it may return a modified note table
106    let result: mlua::Value = on_create_fn
107        .call(note_table)
108        .map_err(|e| HookError::Execution(format!("on_create hook failed: {}", e)))?;
109
110    // Check if hook returned a modified note
111    match result {
112        mlua::Value::Table(returned_note) => {
113            // Extract frontmatter and content if present
114            let frontmatter: Option<serde_yaml::Value> =
115                if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
116                    Some(lua_table_to_yaml(&fm_table)?)
117                } else {
118                    None
119                };
120
121            let content: Option<String> = returned_note.get("content").ok();
122
123            let modified = frontmatter.is_some() || content.is_some();
124            Ok(HookResult { modified, frontmatter, content })
125        }
126        mlua::Value::Nil => {
127            // Hook returned nil, no modifications
128            Ok(HookResult { modified: false, frontmatter: None, content: None })
129        }
130        _ => {
131            // Unexpected return type
132            Ok(HookResult { modified: false, frontmatter: None, content: None })
133        }
134    }
135}
136
137/// Run the `on_update` hook for a type definition.
138///
139/// This function is called after a note is modified (via capture operations) to allow
140/// the type definition to perform additional operations like updating timestamps.
141///
142/// Unlike `on_create`, this hook can return a modified note which will be written back.
143///
144/// # Arguments
145///
146/// * `typedef` - The type definition containing the hook
147/// * `note_ctx` - Context about the updated note
148/// * `vault_ctx` - Vault context with access to repositories
149///
150/// # Returns
151///
152/// * `Ok(UpdateHookResult)` with any modifications from the hook
153/// * `Err(HookError)` on failure
154///
155/// # Example
156///
157/// ```ignore
158/// use mdvault_core::scripting::{run_on_update_hook, NoteContext, VaultContext};
159///
160/// let note_ctx = NoteContext::new(path, "task".into(), frontmatter, content);
161/// let result = run_on_update_hook(&typedef, &note_ctx, vault_ctx)?;
162/// if result.modified {
163///     // Write back the updated content
164/// }
165/// ```
166pub fn run_on_update_hook(
167    typedef: &TypeDefinition,
168    note_ctx: &NoteContext,
169    vault_ctx: VaultContext,
170) -> Result<UpdateHookResult, HookError> {
171    // Skip if no hook defined
172    if !typedef.has_on_update_hook {
173        return Ok(UpdateHookResult {
174            modified: false,
175            frontmatter: None,
176            content: None,
177        });
178    }
179
180    // Create engine with vault context
181    let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
182        .map_err(|e| HookError::LuaError(e.to_string()))?;
183
184    let lua = engine.lua();
185
186    // Load and evaluate the type definition to get the table
187    let typedef_table: mlua::Table =
188        lua.load(&typedef.lua_source).eval().map_err(|e| {
189            HookError::LuaError(format!("failed to load type definition: {}", e))
190        })?;
191
192    // Build note table for the hook
193    let note_table =
194        lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
195
196    note_table
197        .set("path", note_ctx.path.to_string_lossy().to_string())
198        .map_err(|e| HookError::LuaError(e.to_string()))?;
199
200    note_table
201        .set("type", note_ctx.note_type.clone())
202        .map_err(|e| HookError::LuaError(e.to_string()))?;
203
204    note_table
205        .set("content", note_ctx.content.clone())
206        .map_err(|e| HookError::LuaError(e.to_string()))?;
207
208    // Convert frontmatter to Lua table
209    let fm_table = yaml_to_lua_table(lua, &note_ctx.frontmatter)
210        .map_err(|e| HookError::LuaError(e.to_string()))?;
211
212    note_table
213        .set("frontmatter", fm_table)
214        .map_err(|e| HookError::LuaError(e.to_string()))?;
215
216    // Get on_update function
217    let on_update_fn: mlua::Function = typedef_table.get("on_update").map_err(|e| {
218        HookError::LuaError(format!("on_update function not found: {}", e))
219    })?;
220
221    // Call the hook - it may return a modified note table
222    let result: mlua::Value = on_update_fn
223        .call(note_table)
224        .map_err(|e| HookError::Execution(format!("on_update hook failed: {}", e)))?;
225
226    // Check if hook returned a modified note
227    match result {
228        mlua::Value::Table(returned_note) => {
229            // Extract frontmatter and content if present
230            let frontmatter: Option<serde_yaml::Value> =
231                if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
232                    Some(lua_table_to_yaml(&fm_table)?)
233                } else {
234                    None
235                };
236
237            let content: Option<String> = returned_note.get("content").ok();
238
239            let modified = frontmatter.is_some() || content.is_some();
240            Ok(UpdateHookResult { modified, frontmatter, content })
241        }
242        mlua::Value::Nil => {
243            // Hook returned nil, no modifications
244            Ok(UpdateHookResult { modified: false, frontmatter: None, content: None })
245        }
246        _ => {
247            // Unexpected return type
248            Ok(UpdateHookResult { modified: false, frontmatter: None, content: None })
249        }
250    }
251}
252
253/// Convert a Lua table to serde_yaml::Value.
254fn lua_table_to_yaml(table: &mlua::Table) -> Result<serde_yaml::Value, HookError> {
255    let mut map = serde_yaml::Mapping::new();
256
257    for pair in table.pairs::<mlua::Value, mlua::Value>() {
258        let (key, value) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
259
260        let yaml_key = match key {
261            mlua::Value::String(s) => {
262                let str_val =
263                    s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
264                serde_yaml::Value::String(str_val.to_string())
265            }
266            mlua::Value::Integer(i) => serde_yaml::Value::Number(i.into()),
267            _ => continue, // Skip non-string/integer keys
268        };
269
270        let yaml_value = lua_value_to_yaml(value)?;
271        map.insert(yaml_key, yaml_value);
272    }
273
274    Ok(serde_yaml::Value::Mapping(map))
275}
276
277/// Convert a single Lua value to serde_yaml::Value.
278fn lua_value_to_yaml(value: mlua::Value) -> Result<serde_yaml::Value, HookError> {
279    match value {
280        mlua::Value::Nil => Ok(serde_yaml::Value::Null),
281        mlua::Value::Boolean(b) => Ok(serde_yaml::Value::Bool(b)),
282        mlua::Value::Integer(i) => Ok(serde_yaml::Value::Number(i.into())),
283        mlua::Value::Number(n) => {
284            Ok(serde_yaml::Value::Number(serde_yaml::Number::from(n)))
285        }
286        mlua::Value::String(s) => {
287            let str_val = s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
288            Ok(serde_yaml::Value::String(str_val.to_string()))
289        }
290        mlua::Value::Table(t) => {
291            // Check if it's an array or a map
292            if is_lua_array(&t) {
293                let mut seq = Vec::new();
294                for pair in t.pairs::<i64, mlua::Value>() {
295                    let (_, v) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
296                    seq.push(lua_value_to_yaml(v)?);
297                }
298                Ok(serde_yaml::Value::Sequence(seq))
299            } else {
300                lua_table_to_yaml(&t)
301            }
302        }
303        _ => Ok(serde_yaml::Value::Null),
304    }
305}
306
307/// Check if a Lua table is an array (sequential integer keys starting from 1).
308fn is_lua_array(table: &mlua::Table) -> bool {
309    let len = table.raw_len();
310    if len == 0 {
311        // Could be empty table, check for any keys
312        table.pairs::<mlua::Value, mlua::Value>().next().is_none()
313    } else {
314        // Check if keys are 1..=len
315        for i in 1..=len {
316            if table.raw_get::<mlua::Value>(i).is_err() {
317                return false;
318            }
319        }
320        true
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    fn make_typedef_with_hook(lua_source: &str) -> TypeDefinition {
        TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            has_validate_fn: false,
            has_on_create_hook: true,
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: lua_source.to_string(),
        }
    }

    fn make_note_ctx() -> NoteContext {
        NoteContext {
            path: PathBuf::from("test.md"),
            note_type: "test".to_string(),
            frontmatter: serde_yaml::Value::Mapping(serde_yaml::Mapping::new()),
            content: "# Test".to_string(),
        }
    }

    /// Evaluate `source` (which must return a table) in a sandboxed engine and
    /// convert the resulting table with `lua_table_to_yaml`.
    fn eval_table_to_yaml(source: &str) -> serde_yaml::Value {
        let engine = LuaEngine::sandboxed().unwrap();
        let table: mlua::Table = engine.lua().load(source).eval().unwrap();
        lua_table_to_yaml(&table).unwrap()
    }

    #[test]
    fn test_skip_if_no_hook() {
        let typedef = TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            has_validate_fn: false,
            has_on_create_hook: false, // No hook
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: String::new(),
        };

        // When `has_on_create_hook` is false, `run_on_create_hook` returns an
        // unmodified `HookResult` before touching its `VaultContext`. A
        // VaultContext cannot easily be built here without real repositories,
        // so this only checks the flag the early return depends on.
        let _note_ctx = make_note_ctx();
        assert!(!typedef.has_on_create_hook);
    }

    #[test]
    fn test_hook_receives_note_context() {
        // The hook reads every note field and hands the table straight back.
        let lua_source = r#"
            return {
                on_create = function(note)
                    -- Just verify we can access note fields
                    local _ = note.path
                    local _ = note.type
                    local _ = note.content
                    local _ = note.frontmatter
                    return note
                end
            }
        "#;

        let typedef = make_typedef_with_hook(lua_source);
        let note_ctx = make_note_ctx();

        // Drive the Lua side directly with a sandboxed engine (no vault needed).
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let typedef_table: mlua::Table = lua.load(&typedef.lua_source).eval().unwrap();

        let note_table = lua.create_table().unwrap();
        note_table
            .set("path", note_ctx.path.to_string_lossy().to_string())
            .unwrap();
        note_table.set("type", note_ctx.note_type.as_str()).unwrap();
        note_table.set("content", note_ctx.content.as_str()).unwrap();
        note_table.set("frontmatter", lua.create_table().unwrap()).unwrap();

        let on_create: mlua::Function = typedef_table.get("on_create").unwrap();
        let returned: mlua::Table = on_create.call(note_table).unwrap();

        assert_eq!(returned.get::<String>("path").unwrap(), "test.md");
        assert_eq!(returned.get::<String>("type").unwrap(), "test");
        assert_eq!(returned.get::<String>("content").unwrap(), "# Test");
    }

    #[test]
    fn test_lua_table_to_yaml_scalars() {
        let yaml = eval_table_to_yaml(r#"return { title = "Hello", count = 3, done = true }"#);
        assert_eq!(yaml["title"].as_str(), Some("Hello"));
        assert_eq!(yaml["count"].as_f64(), Some(3.0));
        assert_eq!(yaml["done"].as_bool(), Some(true));
    }

    #[test]
    fn test_lua_table_to_yaml_sequence_keeps_order() {
        let yaml = eval_table_to_yaml(r#"return { tags = { "a", "b", "c" } }"#);
        let expected =
            serde_yaml::Value::Sequence(vec!["a".into(), "b".into(), "c".into()]);
        assert_eq!(yaml["tags"], expected);
    }

    #[test]
    fn test_lua_table_to_yaml_nested_mapping() {
        let yaml = eval_table_to_yaml(r#"return { meta = { owner = "me" } }"#);
        assert!(yaml["meta"].is_mapping());
        assert_eq!(yaml["meta"]["owner"].as_str(), Some("me"));
    }

    #[test]
    fn test_is_lua_array_classification() {
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let array: mlua::Table = lua.load("return { 1, 2, 3 }").eval().unwrap();
        let map: mlua::Table = lua.load("return { a = 1 }").eval().unwrap();
        let empty: mlua::Table = lua.load("return {}").eval().unwrap();

        assert!(is_lua_array(&array));
        assert!(!is_lua_array(&map));
        assert!(is_lua_array(&empty));
    }
}