1use super::engine::LuaEngine;
6use super::hooks::{HookError, NoteContext};
7use super::types::SandboxConfig;
8use super::vault_context::VaultContext;
9use crate::types::definition::TypeDefinition;
10use crate::types::validation::yaml_to_lua_table;
11use tracing::debug;
12
13#[derive(Debug)]
15pub struct HookResult {
16 pub modified: bool,
18 pub frontmatter: Option<serde_yaml::Value>,
20 pub content: Option<String>,
22 pub variables: Option<serde_yaml::Value>,
24}
25
26pub type UpdateHookResult = HookResult;
28
29pub fn run_on_create_hook(
58 typedef: &TypeDefinition,
59 note_ctx: &NoteContext,
60 vault_ctx: VaultContext,
61) -> Result<HookResult, HookError> {
62 if !typedef.has_on_create_hook {
64 return Ok(HookResult {
65 modified: false,
66 frontmatter: None,
67 content: None,
68 variables: None,
69 });
70 }
71
72 let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
74 .map_err(|e| HookError::LuaError(e.to_string()))?;
75
76 let lua = engine.lua();
77
78 let typedef_table: mlua::Table =
80 lua.load(&typedef.lua_source).eval().map_err(|e| {
81 HookError::LuaError(format!("failed to load type definition: {}", e))
82 })?;
83
84 let note_table =
86 lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
87
88 note_table
89 .set("path", note_ctx.path.to_string_lossy().to_string())
90 .map_err(|e| HookError::LuaError(e.to_string()))?;
91
92 note_table
93 .set("type", note_ctx.note_type.clone())
94 .map_err(|e| HookError::LuaError(e.to_string()))?;
95
96 note_table
97 .set("content", note_ctx.content.clone())
98 .map_err(|e| HookError::LuaError(e.to_string()))?;
99
100 let fm_table = yaml_to_lua_table(lua, ¬e_ctx.frontmatter)
102 .map_err(|e| HookError::LuaError(e.to_string()))?;
103
104 note_table
105 .set("frontmatter", fm_table)
106 .map_err(|e| HookError::LuaError(e.to_string()))?;
107
108 debug!("Hook input variables: {:?}", note_ctx.variables);
110 let vars_table = yaml_to_lua_table(lua, ¬e_ctx.variables)
111 .map_err(|e| HookError::LuaError(e.to_string()))?;
112
113 note_table
114 .set("variables", vars_table)
115 .map_err(|e| HookError::LuaError(e.to_string()))?;
116
117 let on_create_fn: mlua::Function = typedef_table.get("on_create").map_err(|e| {
119 HookError::LuaError(format!("on_create function not found: {}", e))
120 })?;
121
122 let result: mlua::Value = on_create_fn
124 .call(note_table)
125 .map_err(|e| HookError::Execution(format!("on_create hook failed: {}", e)))?;
126
127 match result {
129 mlua::Value::Table(returned_note) => {
130 let frontmatter: Option<serde_yaml::Value> =
132 if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
133 Some(lua_table_to_yaml(&fm_table)?)
134 } else {
135 None
136 };
137
138 let content: Option<String> = returned_note.get("content").ok();
139 let content = match content {
140 Some(ref c) if c != ¬e_ctx.content => content,
141 _ => None,
142 };
143
144 let variables: Option<serde_yaml::Value> =
145 if let Ok(vars_table) = returned_note.get::<mlua::Table>("variables") {
146 let v = Some(lua_table_to_yaml(&vars_table)?);
147 debug!("Hook output variables: {:?}", v);
148 v
149 } else {
150 None
151 };
152
153 let modified =
154 frontmatter.is_some() || content.is_some() || variables.is_some();
155 Ok(HookResult { modified, frontmatter, content, variables })
156 }
157 mlua::Value::Nil => {
158 Ok(HookResult {
160 modified: false,
161 frontmatter: None,
162 content: None,
163 variables: None,
164 })
165 }
166 _ => {
167 Ok(HookResult {
169 modified: false,
170 frontmatter: None,
171 content: None,
172 variables: None,
173 })
174 }
175 }
176}
177
178pub fn run_on_update_hook(
208 typedef: &TypeDefinition,
209 note_ctx: &NoteContext,
210 vault_ctx: VaultContext,
211) -> Result<UpdateHookResult, HookError> {
212 if !typedef.has_on_update_hook {
214 return Ok(UpdateHookResult {
215 modified: false,
216 frontmatter: None,
217 content: None,
218 variables: None,
219 });
220 }
221
222 let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
224 .map_err(|e| HookError::LuaError(e.to_string()))?;
225
226 let lua = engine.lua();
227
228 let typedef_table: mlua::Table =
230 lua.load(&typedef.lua_source).eval().map_err(|e| {
231 HookError::LuaError(format!("failed to load type definition: {}", e))
232 })?;
233
234 let note_table =
236 lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
237
238 note_table
239 .set("path", note_ctx.path.to_string_lossy().to_string())
240 .map_err(|e| HookError::LuaError(e.to_string()))?;
241
242 note_table
243 .set("type", note_ctx.note_type.clone())
244 .map_err(|e| HookError::LuaError(e.to_string()))?;
245
246 note_table
247 .set("content", note_ctx.content.clone())
248 .map_err(|e| HookError::LuaError(e.to_string()))?;
249
250 let fm_table = yaml_to_lua_table(lua, ¬e_ctx.frontmatter)
252 .map_err(|e| HookError::LuaError(e.to_string()))?;
253
254 note_table
255 .set("frontmatter", fm_table)
256 .map_err(|e| HookError::LuaError(e.to_string()))?;
257
258 let on_update_fn: mlua::Function = typedef_table.get("on_update").map_err(|e| {
260 HookError::LuaError(format!("on_update function not found: {}", e))
261 })?;
262
263 let result: mlua::Value = on_update_fn
265 .call(note_table)
266 .map_err(|e| HookError::Execution(format!("on_update hook failed: {}", e)))?;
267
268 match result {
270 mlua::Value::Table(returned_note) => {
271 let frontmatter: Option<serde_yaml::Value> =
273 if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
274 Some(lua_table_to_yaml(&fm_table)?)
275 } else {
276 None
277 };
278
279 let content: Option<String> = returned_note.get("content").ok();
280
281 let modified = frontmatter.is_some() || content.is_some();
282 Ok(UpdateHookResult { modified, frontmatter, content, variables: None })
283 }
284 mlua::Value::Nil => {
285 Ok(UpdateHookResult {
287 modified: false,
288 frontmatter: None,
289 content: None,
290 variables: None,
291 })
292 }
293 _ => {
294 Ok(UpdateHookResult {
296 modified: false,
297 frontmatter: None,
298 content: None,
299 variables: None,
300 })
301 }
302 }
303}
304
305fn lua_table_to_yaml(table: &mlua::Table) -> Result<serde_yaml::Value, HookError> {
307 let mut map = serde_yaml::Mapping::new();
308
309 for pair in table.pairs::<mlua::Value, mlua::Value>() {
310 let (key, value) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
311
312 let yaml_key = match key {
313 mlua::Value::String(s) => {
314 let str_val =
315 s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
316 serde_yaml::Value::String(str_val.to_string())
317 }
318 mlua::Value::Integer(i) => serde_yaml::Value::Number(i.into()),
319 _ => continue, };
321
322 let yaml_value = lua_value_to_yaml(value)?;
323 map.insert(yaml_key, yaml_value);
324 }
325
326 Ok(serde_yaml::Value::Mapping(map))
327}
328
329fn lua_value_to_yaml(value: mlua::Value) -> Result<serde_yaml::Value, HookError> {
331 match value {
332 mlua::Value::Nil => Ok(serde_yaml::Value::Null),
333 mlua::Value::Boolean(b) => Ok(serde_yaml::Value::Bool(b)),
334 mlua::Value::Integer(i) => Ok(serde_yaml::Value::Number(i.into())),
335 mlua::Value::Number(n) => {
336 Ok(serde_yaml::Value::Number(serde_yaml::Number::from(n)))
337 }
338 mlua::Value::String(s) => {
339 let str_val = s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
340 Ok(serde_yaml::Value::String(str_val.to_string()))
341 }
342 mlua::Value::Table(t) => {
343 if is_lua_array(&t) {
345 let mut seq = Vec::new();
346 for pair in t.pairs::<i64, mlua::Value>() {
347 let (_, v) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
348 seq.push(lua_value_to_yaml(v)?);
349 }
350 Ok(serde_yaml::Value::Sequence(seq))
351 } else {
352 lua_table_to_yaml(&t)
353 }
354 }
355 _ => Ok(serde_yaml::Value::Null),
356 }
357}
358
359fn is_lua_array(table: &mlua::Table) -> bool {
361 let len = table.raw_len();
362 if len == 0 {
363 table.pairs::<mlua::Value, mlua::Value>().next().is_none()
365 } else {
366 for i in 1..=len {
368 if table.raw_get::<mlua::Value>(i).is_err() {
369 return false;
370 }
371 }
372 true
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Builds a minimal `TypeDefinition` flagged as having an `on_create` hook.
    fn make_typedef_with_hook(lua_source: &str) -> TypeDefinition {
        TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            output: None,
            frontmatter_order: None,
            variables: crate::vars::VarsMap::new(),
            has_validate_fn: false,
            has_on_create_hook: true,
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: lua_source.to_string(),
        }
    }

    /// Builds a simple note context: `test.md`, type `test`, body `# Test`.
    fn make_note_ctx() -> NoteContext {
        NoteContext {
            path: PathBuf::from("test.md"),
            note_type: "test".to_string(),
            frontmatter: serde_yaml::Value::Mapping(serde_yaml::Mapping::new()),
            content: "# Test".to_string(),
            variables: serde_yaml::Value::Null,
        }
    }

    #[test]
    fn test_skip_if_no_hook() {
        // This only checks the flag the early return keys on. Calling
        // `run_on_create_hook` itself would additionally need a `VaultContext`.
        let typedef = TypeDefinition {
            has_on_create_hook: false,
            has_on_update_hook: false,
            ..make_typedef_with_hook("")
        };

        assert!(!typedef.has_on_create_hook);
    }

    #[test]
    fn test_hook_receives_note_context() {
        let lua_source = r#"
            return {
                on_create = function(note)
                    -- Echo the fields back so the Rust side can check what the hook saw.
                    return {
                        path = note.path,
                        type = note.type,
                        content = note.content,
                        frontmatter = note.frontmatter,
                    }
                end
            }
        "#;

        let typedef = make_typedef_with_hook(lua_source);
        let note_ctx = make_note_ctx();

        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let typedef_table: mlua::Table = lua.load(&typedef.lua_source).eval().unwrap();

        let note_table = lua.create_table().unwrap();
        note_table
            .set("path", note_ctx.path.to_string_lossy().to_string())
            .unwrap();
        note_table.set("type", note_ctx.note_type.clone()).unwrap();
        note_table.set("content", note_ctx.content.clone()).unwrap();
        note_table.set("frontmatter", lua.create_table().unwrap()).unwrap();

        let on_create: mlua::Function = typedef_table.get("on_create").unwrap();
        let echoed: mlua::Table = on_create.call(note_table).unwrap();

        assert_eq!(echoed.get::<String>("path").unwrap(), "test.md");
        assert_eq!(echoed.get::<String>("type").unwrap(), "test");
        assert_eq!(echoed.get::<String>("content").unwrap(), "# Test");
        assert!(echoed.get::<mlua::Table>("frontmatter").is_ok());
    }

    #[test]
    fn test_lua_array_converts_to_ordered_sequence() {
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let value: mlua::Value = lua.load(r#"return { "a", "b", "c" }"#).eval().unwrap();
        let expected = serde_yaml::Value::Sequence(
            ["a", "b", "c"]
                .iter()
                .map(|s| serde_yaml::Value::String(s.to_string()))
                .collect(),
        );

        assert_eq!(lua_value_to_yaml(value).unwrap(), expected);
    }

    #[test]
    fn test_lua_record_converts_to_mapping() {
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let table: mlua::Table = lua
            .load(r#"return { title = "Hello", done = true }"#)
            .eval()
            .unwrap();
        let yaml = lua_table_to_yaml(&table).unwrap();

        assert_eq!(yaml["title"], serde_yaml::Value::String("Hello".to_string()));
        assert_eq!(yaml["done"], serde_yaml::Value::Bool(true));
    }
}