use crate::command::chat::constants::{HOOK_LLM_MAX_TOKENS, HOOK_LLM_TEMPERATURE};
use crate::command::chat::infra::hook::definition::*;
use crate::command::chat::infra::hook::types::*;
use crate::command::chat::storage::ModelProvider;
use crate::util::log::write_info_log;
use std::io::Write;
use std::path::PathBuf;
use std::process::Command;
use std::sync::{Arc, Mutex};
/// Dispatch a hook to the executor matching its kind.
///
/// Shell hooks run as a subprocess, LLM hooks call the injected model
/// provider, and builtin hooks invoke an in-process handler. A builtin
/// handler that returns `None` is treated as "no modification".
pub(crate) fn execute_hook_with_provider(
    kind: &HookKind,
    context: &HookContext,
    provider: &Option<Arc<Mutex<ModelProvider>>>,
) -> Result<HookResult, String> {
    match kind {
        HookKind::Shell(shell) => execute_shell_hook(shell, context),
        HookKind::Llm(llm) => execute_llm_hook(llm, context, provider),
        HookKind::Builtin(builtin) => Ok((builtin.handler)(context).unwrap_or_default()),
    }
}
/// Instruction appended verbatim to every LLM hook prompt so the model
/// replies with a machine-parseable JSON object; the extracted JSON is
/// later deserialized into `HookResult`, so the schema listed here must
/// stay in sync with that type's fields.
///
/// NOTE: the raw string below is sent to the model at runtime — do not
/// reformat or translate its contents.
const LLM_HOOK_FORMAT_INSTRUCTION: &str = r#"
---
You are a hook function. You MUST respond with ONLY a valid JSON object matching this schema (no markdown, no explanation outside JSON):
{
"user_input": "string (optional, replace user message)",
"assistant_output": "string (optional, replace assistant output)",
"messages": [{"role":"user","content":"..."}] (optional, replace message list),
"system_prompt": "string (optional, replace system prompt)",
"tool_arguments": "string (optional, replace tool arguments JSON)",
"tool_result": "string (optional, replace tool result)",
"tool_error": "string (optional, replace tool error)",
"inject_messages": [{"role":"user","content":"..."}] (optional, append messages),
"action": "stop" or "skip" (optional, stop=abort pipeline, skip=skip current step),
"retry_feedback": "string (optional, feedback to retry with)",
"additional_context": "string (optional, append to system_prompt)",
"system_message": "string (optional, show toast to user)"
}
Return {} if no modification needed."#;
/// Substitute `{{placeholder}}` tokens in `template` with values taken
/// from `context`.
///
/// Placeholders whose context field is absent are replaced with the
/// empty string; tokens not in the table below are left untouched.
pub(crate) fn render_prompt_template(template: &str, context: &HookContext) -> String {
    let substitutions: [(&str, &str); 8] = [
        ("{{event}}", context.event.as_str()),
        ("{{cwd}}", &context.cwd),
        ("{{user_input}}", context.user_input.as_deref().unwrap_or("")),
        (
            "{{assistant_output}}",
            context.assistant_output.as_deref().unwrap_or(""),
        ),
        ("{{tool_name}}", context.tool_name.as_deref().unwrap_or("")),
        (
            "{{tool_arguments}}",
            context.tool_arguments.as_deref().unwrap_or(""),
        ),
        (
            "{{tool_result}}",
            context.tool_result.as_deref().unwrap_or(""),
        ),
        ("{{model}}", context.model.as_deref().unwrap_or("")),
    ];
    substitutions
        .iter()
        .fold(template.to_string(), |rendered, (token, value)| {
            rendered.replace(token, value)
        })
}
/// Return the slice of `text` spanning the first `{` through the last
/// `}`, or `None` when no such span exists (e.g. the braces are missing
/// or reversed). Used to strip prose/markdown the model wraps around its
/// JSON reply.
pub(crate) fn extract_json_from_llm_output(text: &str) -> Option<&str> {
    match (text.find('{'), text.rfind('}')) {
        // Both delimiters are ASCII, so these byte indices are always
        // valid char boundaries for slicing.
        (Some(open), Some(close)) if close > open => text.get(open..=close),
        _ => None,
    }
}
/// Truncate `s` to at most `max_bytes` bytes without splitting a UTF-8
/// code point. Slicing `&s[..n]` at an arbitrary byte index panics when
/// `n` falls inside a multi-byte character — likely here, since model
/// output and error bodies can contain non-ASCII text.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    let mut end = max_bytes;
    // Index 0 is always a char boundary, so this loop terminates.
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}

/// Execute an LLM hook: render its prompt template, call the provider's
/// OpenAI-compatible `/chat/completions` endpoint, and parse the model's
/// JSON reply into a `HookResult`.
///
/// Returns an error if no provider was injected, the HTTP call fails,
/// or the model output contains no parseable JSON. An empty reply or
/// `{}` yields `HookResult::default()` (no modification).
pub(crate) fn execute_llm_hook(
    hook: &LlmHook,
    context: &HookContext,
    provider_opt: &Option<Arc<Mutex<ModelProvider>>>,
) -> Result<HookResult, String> {
    let provider_arc = provider_opt
        .as_ref()
        .ok_or("LLM hook 无法执行:未注入 provider")?;
    // Clone under the lock so the guard is released before the network call.
    let mut provider = provider_arc
        .lock()
        .map_err(|e| format!("获取 provider 锁失败: {}", e))?
        .clone();
    // A hook may pin a specific model, overriding the provider default.
    if let Some(ref model) = hook.model {
        provider.model = model.clone();
    }
    let rendered = render_prompt_template(&hook.prompt, context);
    let full_prompt = format!("{}{}", rendered, LLM_HOOK_FORMAT_INSTRUCTION);
    let system_msg = "You are a hook function. Respond ONLY with the JSON object as instructed.";
    let url = format!(
        "{}/chat/completions",
        provider.api_base.trim_end_matches('/')
    );
    let request_body = serde_json::json!({
        "model": provider.model,
        "messages": [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": full_prompt}
        ],
        "temperature": HOOK_LLM_TEMPERATURE,
        "max_tokens": HOOK_LLM_MAX_TOKENS,
    });
    let request_str = serde_json::to_string(&request_body)
        .map_err(|e| format!("序列化 LLM hook 请求失败: {}", e))?;
    let timeout_secs = hook.timeout;
    // NOTE(review): a fresh runtime is built per hook call; this panics if
    // invoked from inside an existing tokio runtime — confirm callers are
    // synchronous before reusing this path in async code.
    let rt =
        tokio::runtime::Runtime::new().map_err(|e| format!("创建 tokio runtime 失败: {}", e))?;
    rt.block_on(async {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(timeout_secs))
            .build()
            .map_err(|e| format!("创建 HTTP client 失败: {}", e))?;
        let resp = client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", provider.api_key))
            .body(request_str)
            .send()
            .await
            .map_err(|e| format!("LLM hook 请求失败: {}", e))?;
        let status = resp.status();
        let body = resp
            .text()
            .await
            .map_err(|e| format!("读取 LLM hook 响应失败: {}", e))?;
        if !status.is_success() {
            return Err(format!(
                "LLM hook API 错误: HTTP {} (body: {})",
                status,
                // Char-boundary-safe: `&body[..body.len().min(500)]` would
                // panic on a multi-byte character at byte 500.
                truncate_utf8(&body, 500)
            ));
        }
        let parsed: serde_json::Value = serde_json::from_str(&body)
            .map_err(|e| format!("解析 LLM hook 响应 JSON 失败: {}", e))?;
        let content = parsed["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .trim();
        // Empty reply or literal "{}" means "no modification requested".
        if content.is_empty() || content == "{}" {
            return Ok(HookResult::default());
        }
        let json_str = match extract_json_from_llm_output(content) {
            Some(s) => s,
            None => {
                return Err(format!(
                    "LLM hook 输出中未找到 JSON (输出: {})",
                    truncate_utf8(content, 500)
                ));
            }
        };
        let hook_result: HookResult = serde_json::from_str(json_str).map_err(|e| {
            format!(
                "解析 LLM hook JSON 失败: {} (提取的 JSON: {})",
                e,
                truncate_utf8(json_str, 500)
            )
        })?;
        write_info_log(
            "execute_llm_hook",
            &format!(
                "LLM hook 完成 (prompt_len={}, model={}), action={:?}",
                hook.prompt.len(),
                provider.model,
                hook_result.action
            ),
        );
        Ok(hook_result)
    })
}
/// Run a shell hook: spawn `sh -c <command>`, feed the serialized
/// `HookContext` as JSON on stdin, and parse the child's stdout as a
/// `HookResult`.
///
/// Environment contract for the child process:
/// - `JCLI_HOOK_EVENT`: the event name that triggered the hook
/// - `JCLI_CWD`: the user's current working directory (also the child's cwd)
/// - `JCLI_HOOK_DIR`: the hook's own directory, if configured; it is also
///   prepended to `PATH` so scripts bundled next to the hook resolve
///
/// Output contract: empty stdout or `{}` means "no modification"; any
/// other stdout must be a JSON `HookResult`. A non-zero exit status is an
/// error (stderr is appended to the message). The child is killed with
/// SIGKILL if it does not finish within `hook.timeout` seconds.
pub(crate) fn execute_shell_hook(
    hook: &ShellHook,
    context: &HookContext,
) -> Result<HookResult, String> {
    let context_json =
        serde_json::to_string(context).map_err(|e| format!("序列化 context 失败: {}", e))?;
    // Best-effort: fall back to "." if the cwd is unreadable (e.g. deleted).
    let user_cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
    let hook_dir_str = hook
        .dir_path
        .as_ref()
        .map(|p| p.display().to_string())
        .unwrap_or_default();
    let mut cmd = Command::new("sh");
    cmd.arg("-c")
        .arg(&hook.command)
        .current_dir(&user_cwd)
        .env("JCLI_HOOK_EVENT", context.event.as_str())
        .env("JCLI_CWD", user_cwd.display().to_string())
        .env("JCLI_HOOK_DIR", &hook_dir_str);
    if let Some(ref hook_dir) = hook.dir_path {
        let existing_path = std::env::var("PATH").unwrap_or_default();
        // Prepend so the hook's own executables win over system ones.
        let new_path = if existing_path.is_empty() {
            hook_dir.display().to_string()
        } else {
            format!("{}:{}", hook_dir.display(), existing_path)
        };
        cmd.env("PATH", new_path);
    }
    let mut child = cmd
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| format!("启动 hook 进程失败: {}", e))?;
    // Capture the pid now: `child` is moved into the wait thread below,
    // but we still need the pid to kill the process on timeout.
    let pid = child.id();
    if let Some(mut stdin) = child.stdin.take() {
        // Write errors are ignored (hook may not read stdin). Dropping the
        // handle at the end of this block closes the pipe so the child
        // sees EOF.
        let _ = stdin.write_all(context_json.as_bytes());
    }
    // std::process has no wait-with-timeout, so wait on a helper thread
    // and bound the wait with `recv_timeout` on a channel.
    let (tx, rx) = std::sync::mpsc::channel();
    std::thread::spawn(move || {
        let _ = tx.send(child.wait_with_output());
    });
    let timeout = std::time::Duration::from_secs(hook.timeout);
    match rx.recv_timeout(timeout) {
        Ok(Ok(output)) => {
            let stderr_str = String::from_utf8_lossy(&output.stderr).trim().to_string();
            // Stderr is logged even on success — hooks use it for diagnostics.
            if !stderr_str.is_empty() {
                write_info_log(
                    "execute_shell_hook",
                    &format!("Hook stderr ({}): {}", hook.command, stderr_str),
                );
            }
            if !output.status.success() {
                let mut err = format!("Hook 退出码: {:?}", output.status.code());
                if !stderr_str.is_empty() {
                    err.push_str(&format!(", stderr: {}", stderr_str));
                }
                return Err(err);
            }
            let stdout = String::from_utf8_lossy(&output.stdout);
            let stdout = stdout.trim();
            // Empty output or "{}" means the hook requests no modification.
            if stdout.is_empty() || stdout == "{}" {
                return Ok(HookResult::default());
            }
            let result: HookResult = serde_json::from_str(stdout)
                .map_err(|e| format!("解析 hook 输出 JSON 失败: {} (输出: {})", e, stdout))?;
            write_info_log(
                "execute_shell_hook",
                &format!(
                    "Hook 完成 (cmd: {}), action={:?}",
                    hook.command, result.action
                ),
            );
            Ok(result)
        }
        Ok(Err(e)) => Err(format!("等待 hook 进程失败: {}", e)),
        Err(_) => {
            // Timeout: forcibly kill the child; the wait thread will reap it.
            // NOTE(review): this signals only the `sh` pid — grandchildren
            // in its process group may survive; confirm whether a process
            // group kill (negative pid) is needed.
            let _ = nix::sys::signal::kill(
                nix::unistd::Pid::from_raw(pid as i32),
                nix::sys::signal::Signal::SIGKILL,
            );
            Err(format!("Hook 超时 ({}s): {}", hook.timeout, hook.command))
        }
    }
}