use super::engine::LuaEngine;
use super::hooks::{HookError, NoteContext};
use super::types::SandboxConfig;
use super::vault_context::VaultContext;
use crate::types::definition::TypeDefinition;
use crate::types::validation::yaml_to_lua_table;
use tracing::debug;
/// Outcome of running an `on_create` or `on_update` Lua hook.
///
/// Each optional field is `Some` only when the note table returned by the hook
/// supplied that piece of data, and `modified` is `true` exactly when at least
/// one of the optional fields is `Some`. A hook that is absent, returns `nil`,
/// or returns a non-table value yields the all-`None`, unmodified result
/// (which is also what [`Default`] produces).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct HookResult {
    /// Whether any of `frontmatter`, `content` or `variables` is set.
    pub modified: bool,
    /// The returned note's `frontmatter` table, converted to YAML.
    pub frontmatter: Option<serde_yaml::Value>,
    /// The returned note's `content` string. `on_create` reports it only when
    /// it differs from the original content.
    pub content: Option<String>,
    /// The returned note's `variables` table, converted to YAML. Only
    /// `on_create` populates this; `on_update` always leaves it `None`.
    pub variables: Option<serde_yaml::Value>,
}

/// Result of `run_on_update_hook`: the same type as [`HookResult`], with
/// `variables` always `None`.
pub type UpdateHookResult = HookResult;
pub fn run_on_create_hook(
typedef: &TypeDefinition,
note_ctx: &NoteContext,
vault_ctx: VaultContext,
) -> Result<HookResult, HookError> {
if !typedef.has_on_create_hook {
return Ok(HookResult {
modified: false,
frontmatter: None,
content: None,
variables: None,
});
}
let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
.map_err(|e| HookError::LuaError(e.to_string()))?;
let lua = engine.lua();
let typedef_table: mlua::Table =
lua.load(&typedef.lua_source).eval().map_err(|e| {
HookError::LuaError(format!("failed to load type definition: {}", e))
})?;
let note_table =
lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("path", note_ctx.path.to_string_lossy().to_string())
.map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("type", note_ctx.note_type.clone())
.map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("content", note_ctx.content.clone())
.map_err(|e| HookError::LuaError(e.to_string()))?;
let fm_table = yaml_to_lua_table(lua, ¬e_ctx.frontmatter)
.map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("frontmatter", fm_table)
.map_err(|e| HookError::LuaError(e.to_string()))?;
debug!("Hook input variables: {:?}", note_ctx.variables);
let vars_table = yaml_to_lua_table(lua, ¬e_ctx.variables)
.map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("variables", vars_table)
.map_err(|e| HookError::LuaError(e.to_string()))?;
let on_create_fn: mlua::Function = typedef_table.get("on_create").map_err(|e| {
HookError::LuaError(format!("on_create function not found: {}", e))
})?;
let result: mlua::Value = on_create_fn
.call(note_table)
.map_err(|e| HookError::Execution(format!("on_create hook failed: {}", e)))?;
match result {
mlua::Value::Table(returned_note) => {
let frontmatter: Option<serde_yaml::Value> =
if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
Some(lua_table_to_yaml(&fm_table)?)
} else {
None
};
let content: Option<String> = returned_note.get("content").ok();
let content = match content {
Some(ref c) if c != ¬e_ctx.content => content,
_ => None,
};
let variables: Option<serde_yaml::Value> =
if let Ok(vars_table) = returned_note.get::<mlua::Table>("variables") {
let v = Some(lua_table_to_yaml(&vars_table)?);
debug!("Hook output variables: {:?}", v);
v
} else {
None
};
let modified =
frontmatter.is_some() || content.is_some() || variables.is_some();
Ok(HookResult { modified, frontmatter, content, variables })
}
mlua::Value::Nil => {
Ok(HookResult {
modified: false,
frontmatter: None,
content: None,
variables: None,
})
}
_ => {
Ok(HookResult {
modified: false,
frontmatter: None,
content: None,
variables: None,
})
}
}
}
pub fn run_on_update_hook(
typedef: &TypeDefinition,
note_ctx: &NoteContext,
vault_ctx: VaultContext,
) -> Result<UpdateHookResult, HookError> {
if !typedef.has_on_update_hook {
return Ok(UpdateHookResult {
modified: false,
frontmatter: None,
content: None,
variables: None,
});
}
let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)
.map_err(|e| HookError::LuaError(e.to_string()))?;
let lua = engine.lua();
let typedef_table: mlua::Table =
lua.load(&typedef.lua_source).eval().map_err(|e| {
HookError::LuaError(format!("failed to load type definition: {}", e))
})?;
let note_table =
lua.create_table().map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("path", note_ctx.path.to_string_lossy().to_string())
.map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("type", note_ctx.note_type.clone())
.map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("content", note_ctx.content.clone())
.map_err(|e| HookError::LuaError(e.to_string()))?;
let fm_table = yaml_to_lua_table(lua, ¬e_ctx.frontmatter)
.map_err(|e| HookError::LuaError(e.to_string()))?;
note_table
.set("frontmatter", fm_table)
.map_err(|e| HookError::LuaError(e.to_string()))?;
let on_update_fn: mlua::Function = typedef_table.get("on_update").map_err(|e| {
HookError::LuaError(format!("on_update function not found: {}", e))
})?;
let result: mlua::Value = on_update_fn
.call(note_table)
.map_err(|e| HookError::Execution(format!("on_update hook failed: {}", e)))?;
match result {
mlua::Value::Table(returned_note) => {
let frontmatter: Option<serde_yaml::Value> =
if let Ok(fm_table) = returned_note.get::<mlua::Table>("frontmatter") {
Some(lua_table_to_yaml(&fm_table)?)
} else {
None
};
let content: Option<String> = returned_note.get("content").ok();
let modified = frontmatter.is_some() || content.is_some();
Ok(UpdateHookResult { modified, frontmatter, content, variables: None })
}
mlua::Value::Nil => {
Ok(UpdateHookResult {
modified: false,
frontmatter: None,
content: None,
variables: None,
})
}
_ => {
Ok(UpdateHookResult {
modified: false,
frontmatter: None,
content: None,
variables: None,
})
}
}
}
/// Converts a Lua table into a YAML mapping.
///
/// String and integer keys are carried over. Entries whose key is any other
/// Lua type are skipped. Values are converted recursively through
/// [`lua_value_to_yaml`].
///
/// # Errors
///
/// Returns [`HookError::LuaError`] if table traversal fails or a string key is
/// not valid UTF-8, and propagates any error from converting a value.
fn lua_table_to_yaml(table: &mlua::Table) -> Result<serde_yaml::Value, HookError> {
    table
        .pairs::<mlua::Value, mlua::Value>()
        .try_fold(
            serde_yaml::Mapping::new(),
            |mut mapping, entry| -> Result<serde_yaml::Mapping, HookError> {
                let (lua_key, lua_val) =
                    entry.map_err(|err| HookError::LuaError(err.to_string()))?;
                let yaml_key = match lua_key {
                    mlua::Value::String(lua_str) => {
                        let utf8 = lua_str
                            .to_str()
                            .map_err(|err| HookError::LuaError(err.to_string()))?;
                        serde_yaml::Value::String(utf8.to_string())
                    }
                    mlua::Value::Integer(n) => serde_yaml::Value::Number(n.into()),
                    // Other key types are not carried over into the mapping.
                    _ => return Ok(mapping),
                };
                mapping.insert(yaml_key, lua_value_to_yaml(lua_val)?);
                Ok(mapping)
            },
        )
        .map(serde_yaml::Value::Mapping)
}
fn lua_value_to_yaml(value: mlua::Value) -> Result<serde_yaml::Value, HookError> {
match value {
mlua::Value::Nil => Ok(serde_yaml::Value::Null),
mlua::Value::Boolean(b) => Ok(serde_yaml::Value::Bool(b)),
mlua::Value::Integer(i) => Ok(serde_yaml::Value::Number(i.into())),
mlua::Value::Number(n) => {
Ok(serde_yaml::Value::Number(serde_yaml::Number::from(n)))
}
mlua::Value::String(s) => {
let str_val = s.to_str().map_err(|e| HookError::LuaError(e.to_string()))?;
Ok(serde_yaml::Value::String(str_val.to_string()))
}
mlua::Value::Table(t) => {
if is_lua_array(&t) {
let mut seq = Vec::new();
for pair in t.pairs::<i64, mlua::Value>() {
let (_, v) = pair.map_err(|e| HookError::LuaError(e.to_string()))?;
seq.push(lua_value_to_yaml(v)?);
}
Ok(serde_yaml::Value::Sequence(seq))
} else {
lua_table_to_yaml(&t)
}
}
_ => Ok(serde_yaml::Value::Null),
}
}
fn is_lua_array(table: &mlua::Table) -> bool {
let len = table.raw_len();
if len == 0 {
table.pairs::<mlua::Value, mlua::Value>().next().is_none()
} else {
for i in 1..=len {
if table.raw_get::<mlua::Value>(i).is_err() {
return false;
}
}
true
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Builds a minimal `TypeDefinition`. Only the `on_create` flag and the
    /// Lua source vary between tests.
    fn make_typedef(has_on_create_hook: bool, lua_source: &str) -> TypeDefinition {
        TypeDefinition {
            name: "test".to_string(),
            description: None,
            source_path: PathBuf::new(),
            schema: HashMap::new(),
            output: None,
            frontmatter_order: None,
            variables: crate::vars::VarsMap::new(),
            has_validate_fn: false,
            has_on_create_hook,
            has_on_update_hook: false,
            is_builtin_override: false,
            lua_source: lua_source.to_string(),
        }
    }

    fn make_note_ctx() -> NoteContext {
        NoteContext {
            path: PathBuf::from("test.md"),
            note_type: "test".to_string(),
            frontmatter: serde_yaml::Value::Mapping(serde_yaml::Mapping::new()),
            content: "# Test".to_string(),
            variables: serde_yaml::Value::Null,
        }
    }

    /// Evaluates `chunk` (which must return a table) in a fresh sandboxed
    /// engine and converts the result to YAML.
    fn lua_chunk_to_yaml(chunk: &str) -> serde_yaml::Value {
        let engine = LuaEngine::sandboxed().unwrap();
        let table: mlua::Table = engine.lua().load(chunk).eval().unwrap();
        lua_table_to_yaml(&table).unwrap()
    }

    /// Calling `run_on_create_hook` requires a `VaultContext`, which these
    /// tests do not build. This only checks that the fixture disables the hooks.
    #[test]
    fn test_skip_if_no_hook() {
        let typedef = make_typedef(false, "");
        assert!(!typedef.has_on_create_hook);
        assert!(!typedef.has_on_update_hook);
    }

    #[test]
    fn test_hook_receives_note_context() {
        let lua_source = r#"
            return {
                on_create = function(note)
                    -- Just verify we can access note fields
                    local _ = note.path
                    local _ = note.type
                    local _ = note.content
                    local _ = note.frontmatter
                    return note
                end
            }
        "#;
        let typedef = make_typedef(true, lua_source);
        let note_ctx = make_note_ctx();
        let engine = LuaEngine::sandboxed().unwrap();
        let lua = engine.lua();
        let typedef_table: mlua::Table = lua.load(&typedef.lua_source).eval().unwrap();
        let note_table = lua.create_table().unwrap();
        note_table
            .set("path", note_ctx.path.to_string_lossy().into_owned())
            .unwrap();
        note_table.set("type", note_ctx.note_type.clone()).unwrap();
        note_table.set("content", note_ctx.content.clone()).unwrap();
        note_table.set("frontmatter", lua.create_table().unwrap()).unwrap();
        let on_create: mlua::Function = typedef_table.get("on_create").unwrap();
        let result = on_create.call::<mlua::Value>(note_table).unwrap();
        let mlua::Value::Table(returned) = result else {
            panic!("on_create should return the note table");
        };
        assert_eq!(returned.get::<String>("path").unwrap(), "test.md");
        assert_eq!(returned.get::<String>("content").unwrap(), "# Test");
    }

    #[test]
    fn test_lua_table_to_yaml_converts_scalars_and_sequences() {
        let yaml = lua_chunk_to_yaml(
            r#"return { title = "Hello", count = 3, done = true, tags = { "a", "b" } }"#,
        );
        assert_eq!(yaml["title"].as_str(), Some("Hello"));
        assert_eq!(yaml["count"].as_f64(), Some(3.0));
        assert_eq!(yaml["done"].as_bool(), Some(true));
        let tags: Vec<&str> = yaml["tags"]
            .as_sequence()
            .expect("tags should be a sequence")
            .iter()
            .filter_map(serde_yaml::Value::as_str)
            .collect();
        assert_eq!(tags, ["a", "b"]);
    }

    #[test]
    fn test_lua_empty_table_becomes_empty_sequence() {
        let yaml = lua_chunk_to_yaml("return { items = {} }");
        assert_eq!(yaml["items"], serde_yaml::Value::Sequence(Vec::new()));
    }

    #[test]
    fn test_lua_nested_table_becomes_mapping() {
        let yaml = lua_chunk_to_yaml(r#"return { meta = { author = "me" } }"#);
        assert!(yaml["meta"].is_mapping());
        assert_eq!(yaml["meta"]["author"].as_str(), Some("me"));
    }
}