mdvault_core/scripting/
hook_runner.rs1use super::engine::LuaEngine;
6use super::hooks::{HookError, NoteContext};
7use super::types::SandboxConfig;
8use super::vault_context::VaultContext;
9use crate::types::definition::TypeDefinition;
10use crate::types::validation::yaml_to_lua_table;
11
12#[derive(Debug)]
14pub struct UpdateHookResult {
15 pub modified: bool,
17 pub frontmatter: Option<serde_yaml::Value>,
19 pub content: Option<String>,
21}
22
23pub fn run_on_create_hook(
48 typedef: &TypeDefinition,
49 note_ctx: &NoteContext,
50 vault_ctx: VaultContext,
51) -> Result<(), HookError> {
52 if !typedef.has_on_create_hook {
54 return Ok(());
55 }
56
57 let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
59 .map_err(|e| HookError::LuaError(e.to_string()))?;
60
61 let lua = engine.lua();
62
63 let typedef_table: mlua::Table =
65 lua.load(&typedef.lua_source).eval().map_err(|e| {
66 HookError::LuaError(format!("failed to load type definition: {}", e))
67 })?;
68
69 let note_table =
71 lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
72
73 note_table
74 .set("path", note_ctx.path.to_string_lossy().to_string())
75 .map_err(|e| HookError::LuaError(e.to_string()))?;
76
77 note_table
78 .set("type", note_ctx.note_type.clone())
79 .map_err(|e| HookError::LuaError(e.to_string()))?;
80
81 note_table
82 .set("content", note_ctx.content.clone())
83 .map_err(|e| HookError::LuaError(e.to_string()))?;
84
85 let fm_table = yaml_to_lua_table(lua, ¬e_ctx.frontmatter)
87 .map_err(|e| HookError::LuaError(e.to_string()))?;
88
89 note_table
90 .set("frontmatter", fm_table)
91 .map_err(|e| HookError::LuaError(e.to_string()))?;
92
93 let on_create_fn: mlua::Function = typedef_table.get("on_create").map_err(|e| {
95 HookError::LuaError(format!("on_create function not found: {}", e))
96 })?;
97
98 on_create_fn
102 .call::<()>(note_table)
103 .map_err(|e| HookError::Execution(format!("on_create hook failed: {}", e)))?;
104
105 Ok(())
106}
107
108pub fn run_on_update_hook(
138 typedef: &TypeDefinition,
139 note_ctx: &NoteContext,
140 vault_ctx: VaultContext,
141) -> Result<UpdateHookResult, HookError> {
142 if !typedef.has_on_update_hook {
144 return Ok(UpdateHookResult {
145 modified: false,
146 frontmatter: None,
147 content: None,
148 });
149 }
150
151 let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
153 .map_err(|e| HookError::LuaError(e.to_string()))?;
154
155 let lua = engine.lua();
156
157 let typedef_table: mlua::Table =
159 lua.load(&typedef.lua_source).eval().map_err(|e| {
160 HookError::LuaError(format!("failed to load type definition: {}", e))
161 })?;
162
163 let note_table =
165 lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
166
167 note_table
168 .set("path", note_ctx.path.to_string_lossy().to_string())
169 .map_err(|e| HookError::LuaError(e.to_string()))?;
170
171 note_table
172 .set("type", note_ctx.note_type.clone())
173 .map_err(|e| HookError::LuaError(e.to_string()))?;
174
175 note_table
176 .set("content", note_ctx.content.clone())
177 .map_err(|e| HookError::LuaError(e.to_string()))?;
178
179 let fm_table = yaml_to_lua_table(lua, ¬e_ctx.frontmatter)
181 .map_err(|e| HookError::LuaError(e.to_string()))?;
182
183 note_table
184 .set("frontmatter", fm_table)
185 .map_err(|e| HookError::LuaError(e.to_string()))?;
186
187 let on_update_fn: mlua::Function = typedef_table.get("on_update").map_err(|e| {
189 HookError::LuaError(format!("on_update function not found: {}", e))
190 })?;
191
192 let result: mlua::Value = on_update_fn
194 .call(note_table)
195 .map_err(|e| HookError::Execution(format!("on_update hook failed: {}", e)))?;
196
197 match result {
199 mlua::Value::Table(returned_note) => {
200 let frontmatter: Option<serde_yaml::Value> =
202 if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
203 Some(lua_table_to_yaml(&fm_table)?)
204 } else {
205 None
206 };
207
208 let content: Option<String> = returned_note.get("content").ok();
209
210 let modified = frontmatter.is_some() || content.is_some();
211 Ok(UpdateHookResult { modified, frontmatter, content })
212 }
213 mlua::Value::Nil => {
214 Ok(UpdateHookResult { modified: false, frontmatter: None, content: None })
216 }
217 _ => {
218 Ok(UpdateHookResult { modified: false, frontmatter: None, content: None })
220 }
221 }
222}
223
224fn lua_table_to_yaml(table: &mlua::Table) -> Result<serde_yaml::Value, HookError> {
226 let mut map = serde_yaml::Mapping::new();
227
228 for pair in table.pairs::<mlua::Value, mlua::Value>() {
229 let (key, value) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
230
231 let yaml_key = match key {
232 mlua::Value::String(s) => {
233 let str_val =
234 s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
235 serde_yaml::Value::String(str_val.to_string())
236 }
237 mlua::Value::Integer(i) => serde_yaml::Value::Number(i.into()),
238 _ => continue, };
240
241 let yaml_value = lua_value_to_yaml(value)?;
242 map.insert(yaml_key, yaml_value);
243 }
244
245 Ok(serde_yaml::Value::Mapping(map))
246}
247
248fn lua_value_to_yaml(value: mlua::Value) -> Result<serde_yaml::Value, HookError> {
250 match value {
251 mlua::Value::Nil => Ok(serde_yaml::Value::Null),
252 mlua::Value::Boolean(b) => Ok(serde_yaml::Value::Bool(b)),
253 mlua::Value::Integer(i) => Ok(serde_yaml::Value::Number(i.into())),
254 mlua::Value::Number(n) => {
255 Ok(serde_yaml::Value::Number(serde_yaml::Number::from(n)))
256 }
257 mlua::Value::String(s) => {
258 let str_val = s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
259 Ok(serde_yaml::Value::String(str_val.to_string()))
260 }
261 mlua::Value::Table(t) => {
262 if is_lua_array(&t) {
264 let mut seq = Vec::new();
265 for pair in t.pairs::<i64, mlua::Value>() {
266 let (_, v) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
267 seq.push(lua_value_to_yaml(v)?);
268 }
269 Ok(serde_yaml::Value::Sequence(seq))
270 } else {
271 lua_table_to_yaml(&t)
272 }
273 }
274 _ => Ok(serde_yaml::Value::Null),
275 }
276}
277
278fn is_lua_array(table: &mlua::Table) -> bool {
280 let len = table.raw_len();
281 if len == 0 {
282 table.pairs::<mlua::Value, mlua::Value>().next().is_none()
284 } else {
285 for i in 1..=len {
287 if table.raw_get::<mlua::Value>(i).is_err() {
288 return false;
289 }
290 }
291 true
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Builds a minimal type definition that declares an `on_create` hook
    /// backed by the given Lua source.
    fn make_typedef_with_hook(lua_source: &str) -> TypeDefinition {
        TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            has_validate_fn: false,
            has_on_create_hook: true,
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: lua_source.to_string(),
        }
    }

    /// Builds a note context with empty frontmatter and a one-line body.
    fn make_note_ctx() -> NoteContext {
        NoteContext {
            path: PathBuf::from("test.md"),
            note_type: "test".to_string(),
            frontmatter: serde_yaml::Value::Mapping(serde_yaml::Mapping::new()),
            content: "# Test".to_string(),
        }
    }

    #[test]
    fn test_skip_if_no_hook() {
        // Same definition as the helper produces, minus the hook and its source.
        let typedef = TypeDefinition {
            has_on_create_hook: false,
            lua_source: String::new(),
            ..make_typedef_with_hook("")
        };

        let _note_ctx = make_note_ctx();

        // The runner is not invoked here; this only pins the flag that
        // `run_on_create_hook` checks before returning early.
        assert!(!typedef.has_on_create_hook);
    }

    #[test]
    fn test_hook_receives_note_context() {
        let lua_source = r#"
            return {
                on_create = function(note)
                    -- Just verify we can access note fields
                    local _ = note.path
                    local _ = note.type
                    local _ = note.content
                    local _ = note.frontmatter
                    return note
                end
            }
        "#;

        let _typedef = make_typedef_with_hook(lua_source);
        let _note_ctx = make_note_ctx();

        // Drive the hook directly through a sandboxed engine.
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let definition: mlua::Table = lua.load(lua_source).eval().unwrap();

        // Hand-built equivalent of the note table the runner passes to hooks.
        let note = lua.create_table().unwrap();
        note.set("path", "test.md").unwrap();
        note.set("type", "test").unwrap();
        note.set("content", "# Test").unwrap();
        note.set("frontmatter", lua.create_table().unwrap()).unwrap();

        let hook: mlua::Function = definition.get("on_create").unwrap();
        let outcome = hook.call::<mlua::Value>(note);

        assert!(outcome.is_ok());
    }
}