1use std::process::Stdio;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use super::registry::Plugin;
7
8#[derive(Debug, Clone, Serialize)]
9#[serde(tag = "hook", rename_all = "snake_case")]
10pub enum HookPoint {
11 OnSessionStart,
12 OnSessionEnd,
13 PreRead {
14 path: String,
15 },
16 PostCompress {
17 path: String,
18 original_tokens: usize,
19 compressed_tokens: usize,
20 },
21 OnKnowledgeUpdate {
22 fact_id: String,
23 },
24}
25
26impl HookPoint {
27 pub fn hook_name(&self) -> &'static str {
28 match self {
29 Self::OnSessionStart => "on_session_start",
30 Self::OnSessionEnd => "on_session_end",
31 Self::PreRead { .. } => "pre_read",
32 Self::PostCompress { .. } => "post_compress",
33 Self::OnKnowledgeUpdate { .. } => "on_knowledge_update",
34 }
35 }
36
37 pub fn all_hook_names() -> &'static [&'static str] {
38 &[
39 "on_session_start",
40 "on_session_end",
41 "pre_read",
42 "post_compress",
43 "on_knowledge_update",
44 ]
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct HookResult {
50 pub plugin_name: String,
51 pub success: bool,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub output: Option<String>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub error: Option<String>,
56 pub duration_ms: u64,
57}
58
59pub fn execute_hook_sync(plugin: &Plugin, hook: &HookPoint) -> HookResult {
60 let hook_name = hook.hook_name();
61 let plugin_name = plugin.manifest.plugin.name.clone();
62
63 let Some(entry) = plugin.manifest.hooks.get(hook_name) else {
64 return HookResult {
65 plugin_name,
66 success: true,
67 output: None,
68 error: None,
69 duration_ms: 0,
70 };
71 };
72
73 let timeout = Duration::from_millis(entry.timeout_ms);
74 let start = std::time::Instant::now();
75
76 let hook_json = match serde_json::to_string(hook) {
77 Ok(j) => j,
78 Err(e) => {
79 return HookResult {
80 plugin_name,
81 success: false,
82 output: None,
83 error: Some(format!("failed to serialize hook data: {e}")),
84 duration_ms: start.elapsed().as_millis() as u64,
85 };
86 }
87 };
88
89 let parts: Vec<&str> = entry.command.split_whitespace().collect();
90 if parts.is_empty() {
91 return HookResult {
92 plugin_name,
93 success: false,
94 output: None,
95 error: Some("empty command".to_string()),
96 duration_ms: start.elapsed().as_millis() as u64,
97 };
98 }
99
100 let mut cmd = std::process::Command::new(parts[0]);
101 if parts.len() > 1 {
102 cmd.args(&parts[1..]);
103 }
104 cmd.stdin(Stdio::piped())
105 .stdout(Stdio::piped())
106 .stderr(Stdio::piped())
107 .env("LEAN_CTX_HOOK", hook_name)
108 .env("LEAN_CTX_PLUGIN_DIR", &plugin.path);
109
110 let mut child = match cmd.spawn() {
111 Ok(c) => c,
112 Err(e) => {
113 return HookResult {
114 plugin_name,
115 success: false,
116 output: None,
117 error: Some(format!("failed to spawn: {e}")),
118 duration_ms: start.elapsed().as_millis() as u64,
119 };
120 }
121 };
122
123 if let Some(ref mut stdin) = child.stdin.take() {
124 use std::io::Write;
125 let _ = stdin.write_all(hook_json.as_bytes());
126 }
127
128 let result = wait_with_timeout(&mut child, timeout);
129 let duration_ms = start.elapsed().as_millis() as u64;
130
131 match result {
132 Ok(output) => {
133 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
134 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
135 let success = output.status.success();
136 HookResult {
137 plugin_name,
138 success,
139 output: if stdout.is_empty() {
140 None
141 } else {
142 Some(stdout)
143 },
144 error: if stderr.is_empty() && success {
145 None
146 } else if !stderr.is_empty() {
147 Some(stderr)
148 } else {
149 Some(format!("exit code: {}", output.status))
150 },
151 duration_ms,
152 }
153 }
154 Err(e) => HookResult {
155 plugin_name,
156 success: false,
157 output: None,
158 error: Some(e),
159 duration_ms,
160 },
161 }
162}
163
/// Waits for `child` to exit, giving up after `timeout`, and collects its output.
///
/// Captured stdout/stderr pipes are drained on background threads while the child is
/// polled. A child that writes more than the OS pipe buffer can hold therefore keeps
/// running instead of blocking on a full pipe, which would show up as a spurious
/// timeout. On timeout the child is killed and then reaped with `wait()`, so no zombie
/// process is left behind. The drain threads are left detached and finish once the
/// pipes close.
///
/// # Errors
///
/// Returns `"timeout after {N}ms"` when the child is still running at the deadline, or
/// `"wait error: …"` when polling the child's status fails.
fn wait_with_timeout(
    child: &mut std::process::Child,
    timeout: Duration,
) -> Result<std::process::Output, String> {
    /// Reads `pipe` to EOF on a separate thread; `None` if the pipe was not captured.
    fn spawn_drain<R>(pipe: Option<R>) -> Option<std::thread::JoinHandle<Vec<u8>>>
    where
        R: std::io::Read + Send + 'static,
    {
        pipe.map(|mut pipe| {
            std::thread::spawn(move || {
                let mut buf = Vec::new();
                // Best effort, as before: on a read error keep whatever already arrived.
                let _ = std::io::Read::read_to_end(&mut pipe, &mut buf);
                buf
            })
        })
    }

    /// Collects a drain thread's bytes (empty if there was no pipe or the thread panicked).
    fn join_drain(handle: Option<std::thread::JoinHandle<Vec<u8>>>) -> Vec<u8> {
        handle.and_then(|h| h.join().ok()).unwrap_or_default()
    }

    // Start draining before polling so the child can never stall on a full pipe.
    let stdout = spawn_drain(child.stdout.take());
    let stderr = spawn_drain(child.stderr.take());

    let deadline = std::time::Instant::now() + timeout;
    let status = loop {
        match child.try_wait() {
            Ok(Some(status)) => break status,
            Ok(None) if std::time::Instant::now() >= deadline => {
                let _ = child.kill();
                // Reap the killed process so it does not linger as a zombie.
                let _ = child.wait();
                return Err(format!("timeout after {}ms", timeout.as_millis()));
            }
            Ok(None) => std::thread::sleep(Duration::from_millis(10)),
            Err(e) => return Err(format!("wait error: {e}")),
        }
    };

    Ok(std::process::Output {
        status,
        stdout: join_drain(stdout),
        stderr: join_drain(stderr),
    })
}
209
210pub fn execute_hooks_for_point(plugins: &[&Plugin], hook: &HookPoint) -> Vec<HookResult> {
211 let hook_name = hook.hook_name();
212 plugins
213 .iter()
214 .filter(|p| p.enabled && p.manifest.hooks.contains_key(hook_name))
215 .map(|p| execute_hook_sync(p, hook))
216 .collect()
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `toml` as a plugin manifest and wraps it in an enabled plugin at `dir`.
    fn plugin_from_toml(toml: &str, dir: &str) -> Plugin {
        let manifest = crate::core::plugins::manifest::PluginManifest::from_str(
            toml,
            &std::path::PathBuf::from("test.toml"),
        )
        .unwrap();
        Plugin {
            manifest,
            enabled: true,
            path: std::path::PathBuf::from(dir),
        }
    }

    #[test]
    fn hook_point_names() {
        let cases = [
            (HookPoint::OnSessionStart, "on_session_start"),
            (HookPoint::OnSessionEnd, "on_session_end"),
            (HookPoint::PreRead { path: "x".into() }, "pre_read"),
            (
                HookPoint::PostCompress {
                    path: "x".into(),
                    original_tokens: 100,
                    compressed_tokens: 50,
                },
                "post_compress",
            ),
            (
                HookPoint::OnKnowledgeUpdate {
                    fact_id: "f1".into(),
                },
                "on_knowledge_update",
            ),
        ];
        for (hook, expected) in &cases {
            assert_eq!(hook.hook_name(), *expected);
        }
    }

    #[test]
    fn all_hook_names_complete() {
        let names = HookPoint::all_hook_names();
        assert_eq!(names.len(), 5);
        for expected in ["on_session_start", "pre_read", "post_compress"] {
            assert!(names.contains(&expected));
        }
    }

    #[test]
    fn hook_point_serializes_to_json() {
        let hook = HookPoint::PostCompress {
            path: "/tmp/file.rs".into(),
            original_tokens: 1000,
            compressed_tokens: 200,
        };
        let json = serde_json::to_string(&hook).unwrap();
        for needle in ["post_compress", "1000", "200"] {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn execute_missing_hook_is_noop() {
        let plugin = plugin_from_toml(
            r#"
[plugin]
name = "no-hooks"
version = "1.0.0"
"#,
            "/tmp/no-hooks",
        );

        let result = execute_hook_sync(&plugin, &HookPoint::OnSessionStart);
        assert!(result.success);
        assert_eq!(result.duration_ms, 0);
    }

    #[test]
    fn execute_nonexistent_binary_fails() {
        let plugin = plugin_from_toml(
            r#"
[plugin]
name = "bad-binary"
version = "1.0.0"

[hooks.on_session_start]
command = "__nonexistent_lean_ctx_test_binary__ start"
timeout_ms = 1000
"#,
            "/tmp/bad-binary",
        );

        let result = execute_hook_sync(&plugin, &HookPoint::OnSessionStart);
        assert!(!result.success);
        let error = result.error.unwrap();
        assert!(error.contains("failed to spawn"));
    }

    #[cfg(unix)]
    #[test]
    fn execute_echo_plugin_succeeds() {
        let plugin = plugin_from_toml(
            r#"
[plugin]
name = "echo-plugin"
version = "1.0.0"

[hooks.on_session_start]
command = "echo hello"
timeout_ms = 2000
"#,
            "/tmp/echo-plugin",
        );

        let result = execute_hook_sync(&plugin, &HookPoint::OnSessionStart);
        assert!(result.success);
        let output = result.output.unwrap();
        assert!(output.contains("hello"));
    }
}