Skip to main content

lean_ctx/core/plugins/
executor.rs

1use std::process::Stdio;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
6use super::registry::Plugin;
7
8#[derive(Debug, Clone, Serialize)]
9#[serde(tag = "hook", rename_all = "snake_case")]
10pub enum HookPoint {
11    OnSessionStart,
12    OnSessionEnd,
13    PreRead {
14        path: String,
15    },
16    PostCompress {
17        path: String,
18        original_tokens: usize,
19        compressed_tokens: usize,
20    },
21    OnKnowledgeUpdate {
22        fact_id: String,
23    },
24}
25
26impl HookPoint {
27    pub fn hook_name(&self) -> &'static str {
28        match self {
29            Self::OnSessionStart => "on_session_start",
30            Self::OnSessionEnd => "on_session_end",
31            Self::PreRead { .. } => "pre_read",
32            Self::PostCompress { .. } => "post_compress",
33            Self::OnKnowledgeUpdate { .. } => "on_knowledge_update",
34        }
35    }
36
37    pub fn all_hook_names() -> &'static [&'static str] {
38        &[
39            "on_session_start",
40            "on_session_end",
41            "pre_read",
42            "post_compress",
43            "on_knowledge_update",
44        ]
45    }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct HookResult {
50    pub plugin_name: String,
51    pub success: bool,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub output: Option<String>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub error: Option<String>,
56    pub duration_ms: u64,
57}
58
59pub fn execute_hook_sync(plugin: &Plugin, hook: &HookPoint) -> HookResult {
60    let hook_name = hook.hook_name();
61    let plugin_name = plugin.manifest.plugin.name.clone();
62
63    let Some(entry) = plugin.manifest.hooks.get(hook_name) else {
64        return HookResult {
65            plugin_name,
66            success: true,
67            output: None,
68            error: None,
69            duration_ms: 0,
70        };
71    };
72
73    let timeout = Duration::from_millis(entry.timeout_ms);
74    let start = std::time::Instant::now();
75
76    let hook_json = match serde_json::to_string(hook) {
77        Ok(j) => j,
78        Err(e) => {
79            return HookResult {
80                plugin_name,
81                success: false,
82                output: None,
83                error: Some(format!("failed to serialize hook data: {e}")),
84                duration_ms: start.elapsed().as_millis() as u64,
85            };
86        }
87    };
88
89    let parts: Vec<&str> = entry.command.split_whitespace().collect();
90    if parts.is_empty() {
91        return HookResult {
92            plugin_name,
93            success: false,
94            output: None,
95            error: Some("empty command".to_string()),
96            duration_ms: start.elapsed().as_millis() as u64,
97        };
98    }
99
100    let mut cmd = std::process::Command::new(parts[0]);
101    if parts.len() > 1 {
102        cmd.args(&parts[1..]);
103    }
104    cmd.stdin(Stdio::piped())
105        .stdout(Stdio::piped())
106        .stderr(Stdio::piped())
107        .env("LEAN_CTX_HOOK", hook_name)
108        .env("LEAN_CTX_PLUGIN_DIR", &plugin.path);
109
110    let mut child = match cmd.spawn() {
111        Ok(c) => c,
112        Err(e) => {
113            return HookResult {
114                plugin_name,
115                success: false,
116                output: None,
117                error: Some(format!("failed to spawn: {e}")),
118                duration_ms: start.elapsed().as_millis() as u64,
119            };
120        }
121    };
122
123    if let Some(ref mut stdin) = child.stdin.take() {
124        use std::io::Write;
125        let _ = stdin.write_all(hook_json.as_bytes());
126    }
127
128    let result = wait_with_timeout(&mut child, timeout);
129    let duration_ms = start.elapsed().as_millis() as u64;
130
131    match result {
132        Ok(output) => {
133            let stdout = String::from_utf8_lossy(&output.stdout).to_string();
134            let stderr = String::from_utf8_lossy(&output.stderr).to_string();
135            let success = output.status.success();
136            HookResult {
137                plugin_name,
138                success,
139                output: if stdout.is_empty() {
140                    None
141                } else {
142                    Some(stdout)
143                },
144                error: if stderr.is_empty() && success {
145                    None
146                } else if !stderr.is_empty() {
147                    Some(stderr)
148                } else {
149                    Some(format!("exit code: {}", output.status))
150                },
151                duration_ms,
152            }
153        }
154        Err(e) => HookResult {
155            plugin_name,
156            success: false,
157            output: None,
158            error: Some(e),
159            duration_ms,
160        },
161    }
162}
163
/// Waits for `child` to exit, giving up after `timeout`, and collects its output.
///
/// Whatever stdout/stderr pipes are still attached to `child` are taken and
/// drained on background threads *while* the child runs. Reading only after
/// exit would stall any child that writes more than the pipe can buffer: it
/// would block on the write, never exit, and be reported as a timeout.
///
/// The child is polled every 10 ms. If the deadline passes first, the child is
/// killed and then waited on so that it does not linger as a zombie process.
///
/// # Errors
///
/// * `"timeout after {N}ms"` when the child is still running at the deadline.
/// * `"wait error: {e}"` when polling the child's status fails.
fn wait_with_timeout(
    child: &mut std::process::Child,
    timeout: Duration,
) -> Result<std::process::Output, String> {
    use std::io::Read;

    /// Reads `pipe` to end-of-file on a new thread. A read error simply ends
    /// the read; the bytes received so far are kept.
    fn drain<R>(pipe: Option<R>) -> Option<std::thread::JoinHandle<Vec<u8>>>
    where
        R: Read + Send + 'static,
    {
        pipe.map(|mut stream| {
            std::thread::spawn(move || {
                let mut buf = Vec::new();
                let _ = stream.read_to_end(&mut buf);
                buf
            })
        })
    }

    /// Joins a reader thread, yielding an empty buffer if there was no pipe or
    /// the reader panicked.
    fn collect(reader: Option<std::thread::JoinHandle<Vec<u8>>>) -> Vec<u8> {
        reader
            .and_then(|handle| handle.join().ok())
            .unwrap_or_default()
    }

    let deadline = std::time::Instant::now() + timeout;

    // Start draining before polling so the child can never block on a full pipe.
    let stdout_reader = drain(child.stdout.take());
    let stderr_reader = drain(child.stderr.take());

    loop {
        match child.try_wait() {
            Ok(Some(status)) => {
                return Ok(std::process::Output {
                    status,
                    stdout: collect(stdout_reader),
                    stderr: collect(stderr_reader),
                });
            }
            Ok(None) => {
                if std::time::Instant::now() >= deadline {
                    let _ = child.kill();
                    // Reap the killed child; without this it stays a zombie.
                    let _ = child.wait();
                    // The reader threads are deliberately not joined here: a
                    // surviving grandchild could keep the pipes open and turn
                    // the timeout into an unbounded wait.
                    return Err(format!("timeout after {}ms", timeout.as_millis()));
                }
                std::thread::sleep(Duration::from_millis(10));
            }
            Err(e) => return Err(format!("wait error: {e}")),
        }
    }
}
209
210pub fn execute_hooks_for_point(plugins: &[&Plugin], hook: &HookPoint) -> Vec<HookResult> {
211    let hook_name = hook.hook_name();
212    plugins
213        .iter()
214        .filter(|p| p.enabled && p.manifest.hooks.contains_key(hook_name))
215        .map(|p| execute_hook_sync(p, hook))
216        .collect()
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `toml` as a plugin manifest and wraps it in an enabled [`Plugin`]
    /// rooted at `dir`. Panics if the manifest does not parse.
    fn plugin_from_toml(toml: &str, dir: &str) -> Plugin {
        let manifest = crate::core::plugins::manifest::PluginManifest::from_str(
            toml,
            &std::path::PathBuf::from("test.toml"),
        )
        .unwrap();

        Plugin {
            manifest,
            enabled: true,
            path: std::path::PathBuf::from(dir),
        }
    }

    /// One instance of every `HookPoint` variant, for exhaustive checks.
    fn every_hook_point() -> Vec<HookPoint> {
        vec![
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::PreRead { path: "x".into() },
            HookPoint::PostCompress {
                path: "x".into(),
                original_tokens: 100,
                compressed_tokens: 50,
            },
            HookPoint::OnKnowledgeUpdate {
                fact_id: "f1".into(),
            },
        ]
    }

    #[test]
    fn hook_point_names() {
        assert_eq!(HookPoint::OnSessionStart.hook_name(), "on_session_start");
        assert_eq!(HookPoint::OnSessionEnd.hook_name(), "on_session_end");
        assert_eq!(
            HookPoint::PreRead { path: "x".into() }.hook_name(),
            "pre_read"
        );
        assert_eq!(
            HookPoint::PostCompress {
                path: "x".into(),
                original_tokens: 100,
                compressed_tokens: 50,
            }
            .hook_name(),
            "post_compress"
        );
        assert_eq!(
            HookPoint::OnKnowledgeUpdate {
                fact_id: "f1".into()
            }
            .hook_name(),
            "on_knowledge_update"
        );
    }

    #[test]
    fn all_hook_names_complete() {
        let names = HookPoint::all_hook_names();
        assert_eq!(names.len(), 5);

        // Every variant's name must be listed, so the list cannot silently
        // drift from `hook_name()`.
        for hook in every_hook_point() {
            let name = hook.hook_name();
            assert!(names.contains(&name), "missing hook name: {name}");
        }
    }

    #[test]
    fn hook_point_serializes_to_json() {
        let hook = HookPoint::PostCompress {
            path: "/tmp/file.rs".into(),
            original_tokens: 1000,
            compressed_tokens: 200,
        };
        let json = serde_json::to_string(&hook).unwrap();
        assert!(json.contains("post_compress"));
        assert!(json.contains("1000"));
        assert!(json.contains("200"));
    }

    #[test]
    fn execute_missing_hook_is_noop() {
        let plugin = plugin_from_toml(
            r#"
[plugin]
name = "no-hooks"
version = "1.0.0"
"#,
            "/tmp/no-hooks",
        );

        let result = execute_hook_sync(&plugin, &HookPoint::OnSessionStart);
        assert!(result.success);
        assert_eq!(result.duration_ms, 0);
    }

    #[test]
    fn execute_nonexistent_binary_fails() {
        let plugin = plugin_from_toml(
            r#"
[plugin]
name = "bad-binary"
version = "1.0.0"

[hooks.on_session_start]
command = "__nonexistent_lean_ctx_test_binary__ start"
timeout_ms = 1000
"#,
            "/tmp/bad-binary",
        );

        let result = execute_hook_sync(&plugin, &HookPoint::OnSessionStart);
        assert!(!result.success);
        assert!(result.error.unwrap().contains("failed to spawn"));
    }

    #[cfg(unix)]
    #[test]
    fn execute_echo_plugin_succeeds() {
        let plugin = plugin_from_toml(
            r#"
[plugin]
name = "echo-plugin"
version = "1.0.0"

[hooks.on_session_start]
command = "echo hello"
timeout_ms = 2000
"#,
            "/tmp/echo-plugin",
        );

        let result = execute_hook_sync(&plugin, &HookPoint::OnSessionStart);
        assert!(result.success);
        assert!(result.output.unwrap().contains("hello"));
    }
}