1use super::engine::LuaEngine;
6use super::hooks::{HookError, NoteContext};
7use super::types::SandboxConfig;
8use super::vault_context::VaultContext;
9use crate::types::definition::TypeDefinition;
10use crate::types::validation::yaml_to_lua_table;
11
12#[derive(Debug)]
14pub struct HookResult {
15 pub modified: bool,
17 pub frontmatter: Option<serde_yaml::Value>,
19 pub content: Option<String>,
21}
22
23pub type UpdateHookResult = HookResult;
25
26pub fn run_on_create_hook(
55 typedef: &TypeDefinition,
56 note_ctx: &NoteContext,
57 vault_ctx: VaultContext,
58) -> Result<HookResult, HookError> {
59 if !typedef.has_on_create_hook {
61 return Ok(HookResult { modified: false, frontmatter: None, content: None });
62 }
63
64 let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
66 .map_err(|e| HookError::LuaError(e.to_string()))?;
67
68 let lua = engine.lua();
69
70 let typedef_table: mlua::Table =
72 lua.load(&typedef.lua_source).eval().map_err(|e| {
73 HookError::LuaError(format!("failed to load type definition: {}", e))
74 })?;
75
76 let note_table =
78 lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
79
80 note_table
81 .set("path", note_ctx.path.to_string_lossy().to_string())
82 .map_err(|e| HookError::LuaError(e.to_string()))?;
83
84 note_table
85 .set("type", note_ctx.note_type.clone())
86 .map_err(|e| HookError::LuaError(e.to_string()))?;
87
88 note_table
89 .set("content", note_ctx.content.clone())
90 .map_err(|e| HookError::LuaError(e.to_string()))?;
91
92 let fm_table = yaml_to_lua_table(lua, ¬e_ctx.frontmatter)
94 .map_err(|e| HookError::LuaError(e.to_string()))?;
95
96 note_table
97 .set("frontmatter", fm_table)
98 .map_err(|e| HookError::LuaError(e.to_string()))?;
99
100 let on_create_fn: mlua::Function = typedef_table.get("on_create").map_err(|e| {
102 HookError::LuaError(format!("on_create function not found: {}", e))
103 })?;
104
105 let result: mlua::Value = on_create_fn
107 .call(note_table)
108 .map_err(|e| HookError::Execution(format!("on_create hook failed: {}", e)))?;
109
110 match result {
112 mlua::Value::Table(returned_note) => {
113 let frontmatter: Option<serde_yaml::Value> =
115 if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
116 Some(lua_table_to_yaml(&fm_table)?)
117 } else {
118 None
119 };
120
121 let content: Option<String> = returned_note.get("content").ok();
122
123 let modified = frontmatter.is_some() || content.is_some();
124 Ok(HookResult { modified, frontmatter, content })
125 }
126 mlua::Value::Nil => {
127 Ok(HookResult { modified: false, frontmatter: None, content: None })
129 }
130 _ => {
131 Ok(HookResult { modified: false, frontmatter: None, content: None })
133 }
134 }
135}
136
137pub fn run_on_update_hook(
167 typedef: &TypeDefinition,
168 note_ctx: &NoteContext,
169 vault_ctx: VaultContext,
170) -> Result<UpdateHookResult, HookError> {
171 if !typedef.has_on_update_hook {
173 return Ok(UpdateHookResult {
174 modified: false,
175 frontmatter: None,
176 content: None,
177 });
178 }
179
180 let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
182 .map_err(|e| HookError::LuaError(e.to_string()))?;
183
184 let lua = engine.lua();
185
186 let typedef_table: mlua::Table =
188 lua.load(&typedef.lua_source).eval().map_err(|e| {
189 HookError::LuaError(format!("failed to load type definition: {}", e))
190 })?;
191
192 let note_table =
194 lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
195
196 note_table
197 .set("path", note_ctx.path.to_string_lossy().to_string())
198 .map_err(|e| HookError::LuaError(e.to_string()))?;
199
200 note_table
201 .set("type", note_ctx.note_type.clone())
202 .map_err(|e| HookError::LuaError(e.to_string()))?;
203
204 note_table
205 .set("content", note_ctx.content.clone())
206 .map_err(|e| HookError::LuaError(e.to_string()))?;
207
208 let fm_table = yaml_to_lua_table(lua, ¬e_ctx.frontmatter)
210 .map_err(|e| HookError::LuaError(e.to_string()))?;
211
212 note_table
213 .set("frontmatter", fm_table)
214 .map_err(|e| HookError::LuaError(e.to_string()))?;
215
216 let on_update_fn: mlua::Function = typedef_table.get("on_update").map_err(|e| {
218 HookError::LuaError(format!("on_update function not found: {}", e))
219 })?;
220
221 let result: mlua::Value = on_update_fn
223 .call(note_table)
224 .map_err(|e| HookError::Execution(format!("on_update hook failed: {}", e)))?;
225
226 match result {
228 mlua::Value::Table(returned_note) => {
229 let frontmatter: Option<serde_yaml::Value> =
231 if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
232 Some(lua_table_to_yaml(&fm_table)?)
233 } else {
234 None
235 };
236
237 let content: Option<String> = returned_note.get("content").ok();
238
239 let modified = frontmatter.is_some() || content.is_some();
240 Ok(UpdateHookResult { modified, frontmatter, content })
241 }
242 mlua::Value::Nil => {
243 Ok(UpdateHookResult { modified: false, frontmatter: None, content: None })
245 }
246 _ => {
247 Ok(UpdateHookResult { modified: false, frontmatter: None, content: None })
249 }
250 }
251}
252
253fn lua_table_to_yaml(table: &mlua::Table) -> Result<serde_yaml::Value, HookError> {
255 let mut map = serde_yaml::Mapping::new();
256
257 for pair in table.pairs::<mlua::Value, mlua::Value>() {
258 let (key, value) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
259
260 let yaml_key = match key {
261 mlua::Value::String(s) => {
262 let str_val =
263 s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
264 serde_yaml::Value::String(str_val.to_string())
265 }
266 mlua::Value::Integer(i) => serde_yaml::Value::Number(i.into()),
267 _ => continue, };
269
270 let yaml_value = lua_value_to_yaml(value)?;
271 map.insert(yaml_key, yaml_value);
272 }
273
274 Ok(serde_yaml::Value::Mapping(map))
275}
276
277fn lua_value_to_yaml(value: mlua::Value) -> Result<serde_yaml::Value, HookError> {
279 match value {
280 mlua::Value::Nil => Ok(serde_yaml::Value::Null),
281 mlua::Value::Boolean(b) => Ok(serde_yaml::Value::Bool(b)),
282 mlua::Value::Integer(i) => Ok(serde_yaml::Value::Number(i.into())),
283 mlua::Value::Number(n) => {
284 Ok(serde_yaml::Value::Number(serde_yaml::Number::from(n)))
285 }
286 mlua::Value::String(s) => {
287 let str_val = s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
288 Ok(serde_yaml::Value::String(str_val.to_string()))
289 }
290 mlua::Value::Table(t) => {
291 if is_lua_array(&t) {
293 let mut seq = Vec::new();
294 for pair in t.pairs::<i64, mlua::Value>() {
295 let (_, v) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
296 seq.push(lua_value_to_yaml(v)?);
297 }
298 Ok(serde_yaml::Value::Sequence(seq))
299 } else {
300 lua_table_to_yaml(&t)
301 }
302 }
303 _ => Ok(serde_yaml::Value::Null),
304 }
305}
306
307fn is_lua_array(table: &mlua::Table) -> bool {
309 let len = table.raw_len();
310 if len == 0 {
311 table.pairs::<mlua::Value, mlua::Value>().next().is_none()
313 } else {
314 for i in 1..=len {
316 if table.raw_get::<mlua::Value>(i).is_err() {
317 return false;
318 }
319 }
320 true
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Builds a `TypeDefinition` named `test` that declares an `on_create` hook whose
    /// Lua source is `lua_source`.
    fn make_typedef_with_hook(lua_source: &str) -> TypeDefinition {
        TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            has_validate_fn: false,
            has_on_create_hook: true,
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: lua_source.to_string(),
        }
    }

    /// Builds a minimal note: `test.md` of type `test`, body `# Test`, empty frontmatter.
    fn make_note_ctx() -> NoteContext {
        NoteContext {
            path: PathBuf::from("test.md"),
            note_type: "test".to_string(),
            frontmatter: serde_yaml::Value::Mapping(serde_yaml::Mapping::new()),
            content: "# Test".to_string(),
        }
    }

    /// Evaluates a Lua expression in a sandboxed engine and converts the result to YAML.
    fn lua_expr_to_yaml(expr: &str) -> serde_yaml::Value {
        let engine = LuaEngine::sandboxed().unwrap();
        let value: mlua::Value = engine.lua().load(expr).eval().unwrap();
        lua_value_to_yaml(value).unwrap()
    }

    #[test]
    fn test_skip_if_no_hook() {
        // NOTE(review): this does not call `run_on_create_hook`, which needs a
        // `VaultContext`; it only checks the flag the early return depends on.
        let typedef = TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            has_validate_fn: false,
            has_on_create_hook: false,
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: String::new(),
        };

        assert!(!typedef.has_on_create_hook);
    }

    #[test]
    fn test_hook_receives_note_context() {
        let lua_source = r#"
            return {
                on_create = function(note)
                    -- Echo the fields back so the test can check what the hook saw.
                    return {
                        path = note.path,
                        type = note.type,
                        content = note.content,
                        has_frontmatter = note.frontmatter ~= nil,
                    }
                end
            }
        "#;

        let typedef = make_typedef_with_hook(lua_source);
        let note_ctx = make_note_ctx();

        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();

        let typedef_table: mlua::Table = lua.load(&typedef.lua_source).eval().unwrap();

        let note_table = lua.create_table().unwrap();
        note_table
            .set("path", note_ctx.path.to_string_lossy().to_string())
            .unwrap();
        note_table.set("type", note_ctx.note_type.as_str()).unwrap();
        note_table.set("content", note_ctx.content.as_str()).unwrap();
        note_table.set("frontmatter", lua.create_table().unwrap()).unwrap();

        let on_create: mlua::Function = typedef_table.get("on_create").unwrap();
        let seen: mlua::Table = on_create.call(note_table).unwrap();

        assert_eq!(seen.get::<String>("path").unwrap(), "test.md");
        assert_eq!(seen.get::<String>("type").unwrap(), "test");
        assert_eq!(seen.get::<String>("content").unwrap(), "# Test");
        assert!(seen.get::<bool>("has_frontmatter").unwrap());
    }

    #[test]
    fn test_lua_scalars_convert_to_yaml() {
        assert_eq!(lua_expr_to_yaml("nil"), serde_yaml::Value::Null);
        assert_eq!(lua_expr_to_yaml("true"), serde_yaml::Value::Bool(true));
        assert_eq!(
            lua_expr_to_yaml("'hi'"),
            serde_yaml::Value::String("hi".to_string())
        );
    }

    #[test]
    fn test_sequence_table_converts_to_yaml_sequence() {
        let expected = serde_yaml::Value::Sequence(vec![
            serde_yaml::Value::String("a".to_string()),
            serde_yaml::Value::String("b".to_string()),
        ]);
        assert_eq!(lua_expr_to_yaml(r#"{ "a", "b" }"#), expected);
    }

    #[test]
    fn test_keyed_table_converts_to_yaml_mapping() {
        let mut expected = serde_yaml::Mapping::new();
        expected.insert(
            serde_yaml::Value::String("title".to_string()),
            serde_yaml::Value::String("x".to_string()),
        );
        assert_eq!(
            lua_expr_to_yaml(r#"{ title = "x" }"#),
            serde_yaml::Value::Mapping(expected)
        );
    }

    #[test]
    fn test_empty_table_converts_to_empty_sequence() {
        assert_eq!(
            lua_expr_to_yaml("{}"),
            serde_yaml::Value::Sequence(Vec::new())
        );
    }
}