Skip to main content

difflore_cli/hooks/
gemini_cli.rs

1//! Gemini CLI hook adapter.
2//!
//! Gemini CLI supports 11 lifecycle hooks; `DiffLore` maps 5 of them onto 4
//! canonical `HookEvent` variants and ignores the rest as not actionable:
5//!
6//!   | Gemini event   | Canonical event              |
7//!   |----------------|------------------------------|
8//!   | `SessionStart`   | `SessionStart { cwd }`       |
9//!   | `BeforeAgent`    | `SessionStart { cwd }` *    |
10//!   | `AfterAgent`     | `Stop`                       |
11//!   | `AfterTool`      | `PostToolUse { … }`          |
12//!   | `SessionEnd`     | `SessionEnd`                 |
13//!   | `BeforeTool`     | no-op (pre-execution noise)  |
14//!   | `PreCompress`    | no-op                        |
//!   | `Notification`   | no-op                        |
16//!
17//! \* `BeforeAgent` is treated as a session-start so `DiffLore`'s
18//! per-session warmup also fires when users resume sessions; this
19//! costs nothing extra — `ensure_ready` is cached per process.
20//!
21//! Example stdin (verified against Gemini CLI's published hook schema):
22//!
23//! ```json
24//! {
25//!   "session_id": "...",
26//!   "cwd": "/path/to/repo",
27//!   "hook_event_name": "AfterTool",
28//!   "tool_name": "WriteFile",
29//!   "tool_input":  { "path": "src/foo.py", "content": "..." },
30//!   "tool_response": { "success": true, "output": "…" },
31//!   "transcript_path": "/abs/path/to/transcript.jsonl"
32//! }
33//! ```
34//!
35//! **ANSI escape stripping**: Gemini CLI is known to leak raw ANSI
36//! color sequences through tool output into system messages. Our
37//! `format_output` path strips them out of `systemMessage` before
38//! shipping so the user doesn't see `\x1b[31m` garbage in their UI.
39
40use serde::{Deserialize, Serialize};
41use serde_json::{Value, json};
42
43use super::synth;
44use super::types::{HookEvent, HookResult};
45use super::{PayloadAdapter, PlatformAdapter};
46
/// Zero-sized marker — no adapter state.
///
/// All behaviour lives in the trait impls for this type. The derives cost
/// nothing for a unit struct and let the adapter be logged, copied, and
/// default-constructed by generic code.
#[derive(Debug, Clone, Copy, Default)]
pub struct GeminiCliAdapter;
49
50/// Typed view of Gemini CLI's hook stdin. Every field is optional:
51/// Gemini ships different subsets per event, and we reject only when
52/// `hook_event_name` itself is absent (structurally invalid).
53#[derive(Debug, Clone, Deserialize, Serialize, Default)]
54#[serde(rename_all = "snake_case")]
55pub(crate) struct GeminiHookPayload {
56    #[serde(default)]
57    hook_event_name: Option<String>,
58    #[serde(default)]
59    session_id: Option<String>,
60    #[serde(default)]
61    cwd: Option<String>,
62    #[serde(default)]
63    transcript_path: Option<String>,
64    #[serde(default)]
65    tool_name: Option<String>,
66    #[serde(default)]
67    tool_input: Option<Value>,
68    #[serde(default)]
69    tool_response: Option<Value>,
70    #[serde(default)]
71    prompt: Option<String>,
72    /// `AfterAgent` carries the full assistant text here.
73    #[serde(default)]
74    prompt_response: Option<String>,
75}
76
77impl GeminiHookPayload {
78    fn into_canonical(self) -> Result<HookEvent, String> {
79        let event_name = self
80            .hook_event_name
81            .as_deref()
82            .ok_or_else(|| "missing hook_event_name".to_owned())?;
83        match event_name {
84            // SessionStart and BeforeAgent both signal "new session/turn
85            // starting" — collapse both into SessionStart so warmup logic
86            // in the CLI runs once either way.
87            "SessionStart" | "BeforeAgent" => Ok(HookEvent::SessionStart {
88                cwd: self.cwd.unwrap_or_default(),
89                session_id: None,
90            }),
91            "AfterAgent" => Ok(HookEvent::Stop {
92                session_id: None,
93                transcript_path: None,
94                cwd: None,
95            }),
96            "AfterTool" => Ok(after_tool_event(self)),
97            "SessionEnd" => Ok(HookEvent::SessionEnd {
98                session_id: None,
99                transcript_path: None,
100                cwd: None,
101            }),
102            // BeforeTool / PreCompress / Notification: Gemini fires
103            // these for every tool call / every compaction / every
104            // permission prompt — way too chatty for rule retrieval,
105            // which wants "after the code actually changed" signals.
106            "BeforeTool" | "PreCompress" | "Notification" => Err(format!(
107                "Gemini CLI event {event_name} is intentionally ignored"
108            )),
109            other => Err(format!("unsupported Gemini CLI hook event: {other}")),
110        }
111    }
112}
113
114/// Build a `PostToolUse` from Gemini's `AfterTool` payload.
115///
116/// The `tool_input` shape is tool-specific (Gemini's built-ins include
117/// `WriteFile`, `Edit`, `ReadFile`, `ShellCommand`, …). We probe the
118/// common path keys so typical file-mutation tools flow through with a
119/// `file_path` set.
120///
121/// Tool-name normalisation: the dispatch layer only acts on the
122/// canonical Claude Code names (`Edit`/`Write`/`MultiEdit`); Gemini's
123/// `WriteFile` would otherwise get noop'd and Gemini users would
124/// silently miss rule injection on every file write. Map it here so
125/// downstream callers don't need a per-platform allowlist.
126fn after_tool_event(p: GeminiHookPayload) -> HookEvent {
127    let raw_tool_name = p.tool_name.clone().unwrap_or_default();
128    let tool_name = match raw_tool_name.as_str() {
129        "WriteFile" => "Write".to_owned(),
130        _ => raw_tool_name,
131    };
132    let file_path = p
133        .tool_input
134        .as_ref()
135        .and_then(|v| {
136            v.get("file_path")
137                .or_else(|| v.get("path"))
138                .or_else(|| v.get("file"))
139        })
140        .and_then(|v| v.as_str())
141        .map(String::from);
142    let diff = synthesise_diff(p.tool_input.as_ref(), p.tool_response.as_ref());
143    let (old_text, new_text) = synth::extract_edit_strings(p.tool_input.as_ref());
144    HookEvent::PostToolUse {
145        tool_name,
146        file_path,
147        diff,
148        session_id: p.session_id,
149        new_text,
150        old_text,
151    }
152}
153
154/// Best-effort diff synthesis for Gemini tool payloads.
155///
156/// Handles three common shapes:
157///   - `{ "old_string", "new_string" }` (Edit)
158///   - `{ "content" }` (`WriteFile`)
159///   - `{ "command", "output" }` (`ShellCommand`) — surfaces both so
160///     rule retrievers keying off "npm install" or "curl" still match.
161///
162/// Any tool output that carries ANSI escape sequences (common with
163/// `ShellCommand`) is stripped before being folded into the diff;
164/// downstream text matching shouldn't have to care about terminal
165/// colour codes.
166fn synthesise_diff(tool_input: Option<&Value>, tool_response: Option<&Value>) -> Option<String> {
167    let input = tool_input?;
168    if let (Some(old), Some(new)) = (
169        input.get("old_string").and_then(|v| v.as_str()),
170        input.get("new_string").and_then(|v| v.as_str()),
171    ) {
172        return Some(synth::diff_old_new(old, new));
173    }
174    if let Some(content) = input.get("content").and_then(|v| v.as_str()) {
175        return Some(synth::diff_content(content));
176    }
177    if let Some(cmd) = input.get("command").and_then(|v| v.as_str()) {
178        // Gemini stuffs shell output into `tool_response.output` — we
179        // sanitise ANSI before it lands in the retriever's text index.
180        let cleaned = tool_response
181            .and_then(|v| v.get("output"))
182            .and_then(|v| v.as_str())
183            .map(strip_ansi);
184        return synth::diff_shell(Some(cmd), cleaned.as_deref());
185    }
186    None
187}
188
/// Strip ANSI escape sequences (CSI / OSC / ESC-prefixed control codes)
/// from `s`. Implemented as a tiny state machine rather than via the
/// `regex` crate so the CLI doesn't pick up an extra dependency just
/// for this one call site.
///
/// Handled forms:
///   - CSI sequences `ESC [ … final-byte` where final-byte is
///     `@ … ~` (0x40..=0x7E).
///   - 8-bit CSI prefix (U+009B) with the same body.
///   - OSC sequences `ESC ] … BEL` or `ESC ] … ESC \` (window titles,
///     hyperlinks). An unterminated OSC loses only its `ESC ]` introducer
///     so the rest of the text is not swallowed.
///   - Simple two-byte ESC sequences (e.g. `ESC 7`, `ESC M`), where the
///     second byte is printable ASCII (0x20..=0x7E).
///   - A stray ESC followed by anything else (control byte, another ESC,
///     a non-ASCII character, or end of input): only the ESC is removed.
///
/// Returns `s` unchanged on the fast path when neither ESC (0x1B) nor
/// 8-bit CSI (U+009B) appears at all.
pub(crate) fn strip_ansi(s: &str) -> String {
    if !s.contains('\x1b') && !s.contains('\u{009b}') {
        return s.to_owned();
    }
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            // 8-bit CSI (U+009B, encoded in UTF-8 as 0xC2 0x9B).
            0xC2 if bytes.get(i + 1) == Some(&0x9B) => {
                i = skip_csi_body(bytes, i + 2);
            }
            0x1B => {
                i = match bytes.get(i + 1) {
                    Some(b'[') => skip_csi_body(bytes, i + 2),
                    Some(b']') => skip_osc_body(bytes, i + 2),
                    // Two-byte sequence: the second byte must be printable
                    // ASCII. Consuming anything else could split a
                    // multi-byte character or eat a newline.
                    Some(0x20..=0x7E) => i + 2,
                    // Stray ESC: drop it alone and re-examine what follows.
                    _ => i + 1,
                };
            }
            b => {
                out.push(b);
                i += 1;
            }
        }
    }
    // Every dropped span starts on a character boundary and ends right
    // after an ASCII byte, so `out` is valid UTF-8. Should that invariant
    // ever break, degrade lossily instead of returning the unstripped
    // input with its escape bytes intact.
    String::from_utf8(out)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Skip the body of a CSI sequence starting at `i`, returning the
/// index just past the sequence.
///
/// The body is any run of parameter bytes (0x30..=0x3F) and intermediate
/// bytes (0x20..=0x2F), terminated by a single final byte (0x40..=0x7E)
/// which is consumed. Any other byte means the sequence was truncated or
/// malformed: scanning stops *before* that byte so ordinary text after a
/// cut-off sequence is kept.
fn skip_csi_body(bytes: &[u8], mut i: usize) -> usize {
    while let Some(&c) = bytes.get(i) {
        match c {
            0x20..=0x3F => i += 1,
            0x40..=0x7E => return i + 1,
            _ => return i,
        }
    }
    i
}

/// Skip the body of an OSC sequence whose payload starts at `start`.
///
/// Returns the index just past the terminator (BEL, or the two-byte
/// string terminator `ESC \`). When no terminator exists, returns `start`
/// so the caller drops only the introducer and keeps the following text.
fn skip_osc_body(bytes: &[u8], start: usize) -> usize {
    let mut i = start;
    while let Some(&c) = bytes.get(i) {
        match c {
            0x07 => return i + 1,
            0x1B if bytes.get(i + 1) == Some(&b'\\') => return i + 2,
            _ => i += 1,
        }
    }
    start
}
256
257impl PayloadAdapter for GeminiCliAdapter {
258    type Raw = GeminiHookPayload;
259    const PARSE_LABEL: &'static str = "Gemini CLI";
260
261    fn into_canonical(raw: Self::Raw) -> Result<HookEvent, String> {
262        raw.into_canonical()
263    }
264}
265
266impl PlatformAdapter for GeminiCliAdapter {
267    fn name(&self) -> &'static str {
268        "gemini-cli"
269    }
270
271    fn parse_stdin(&self, raw: &str) -> Result<HookEvent, String> {
272        Self::parse_stdin_default(raw)
273    }
274
275    fn format_output(&self, result: HookResult) -> String {
276        // Gemini CLI's documented output shape:
277        //   { continue, suppressOutput, systemMessage, hookSpecificOutput }
278        // `continue` is always included to prevent accidental agent
279        // termination on future Gemini builds that treat absence as
280        // "stop".
281        let mut obj = json!({
282            "continue": result.continue_,
283            "suppressOutput": false,
284        });
285        if let Some(msg) = result.system_message {
286            obj["systemMessage"] = Value::String(strip_ansi(&msg));
287        }
288        if let Some(ctx) = result.additional_context {
289            // Gemini CLI pipes `hookSpecificOutput.additionalContext`
290            // back into the conversation transcript for SessionStart
291            // and AfterTool events — the natural place for advisory
292            // rule injection.
293            obj["hookSpecificOutput"] = json!({
294                "additionalContext": ctx,
295            });
296        }
297        crate::commands::util::json_compact_or(&obj, "{\"continue\":true}")
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `raw` through the adapter's stdin parser, panicking on error.
    fn parse(raw: &str) -> HookEvent {
        GeminiCliAdapter.parse_stdin(raw).unwrap()
    }

    /// Format `result` with the adapter and parse the output back as JSON.
    fn render(result: HookResult) -> Value {
        let out = GeminiCliAdapter.format_output(result);
        serde_json::from_str(&out).unwrap()
    }

    #[test]
    fn parse_session_start_reads_cwd() {
        let parsed = parse(r#"{"hook_event_name":"SessionStart","cwd":"/tmp/x"}"#);
        let expected = HookEvent::SessionStart {
            cwd: "/tmp/x".into(),
            session_id: None,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_before_agent_maps_to_session_start() {
        // BeforeAgent is deliberately folded into SessionStart so that
        // session warmup fires for it as well.
        let parsed = parse(r#"{"hook_event_name":"BeforeAgent","cwd":"/home/me/p"}"#);
        let expected = HookEvent::SessionStart {
            cwd: "/home/me/p".into(),
            session_id: None,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_after_agent_maps_to_stop() {
        let parsed = parse(r#"{"hook_event_name":"AfterAgent"}"#);
        let expected = HookEvent::Stop {
            session_id: None,
            transcript_path: None,
            cwd: None,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_after_tool_extracts_file_path_and_diff() {
        let raw = r#"{
            "hook_event_name": "AfterTool",
            "tool_name": "Edit",
            "tool_input": {
                "file_path": "src/foo.py",
                "old_string": "a=1",
                "new_string": "a=2"
            }
        }"#;
        let HookEvent::PostToolUse {
            tool_name,
            file_path,
            diff,
            ..
        } = parse(raw)
        else {
            panic!("expected PostToolUse");
        };
        assert_eq!(tool_name, "Edit");
        assert_eq!(file_path.as_deref(), Some("src/foo.py"));
        let d = diff.unwrap();
        assert!(d.contains("-a=1") && d.contains("+a=2"));
    }

    #[test]
    fn parse_after_tool_normalises_writefile_to_write() {
        // Regression: `WriteFile` used to pass through unmapped, so the
        // dispatcher's `Edit|Write|MultiEdit` allowlist noop'd every
        // Gemini file write and rule injection was silently skipped.
        let raw = r#"{
            "hook_event_name": "AfterTool",
            "tool_name": "WriteFile",
            "tool_input": {
                "file_path": "src/new.py",
                "content": "print('hi')"
            }
        }"#;
        let HookEvent::PostToolUse { tool_name, .. } = parse(raw) else {
            panic!("expected PostToolUse");
        };
        assert_eq!(tool_name, "Write");
    }

    #[test]
    fn parse_after_tool_shell_strips_ansi_from_output() {
        // Regression: Gemini's ShellCommand often leaks ANSI colour
        // codes — they must be stripped before reaching rule retrieval.
        //
        // The payload is assembled with `json!` so the ESC byte (0x1B)
        // exists only in the runtime string; serde_json serialises it as
        // a `\u001b` unicode escape (legal JSON). The source file thus
        // stays ASCII-clean and survives Edit / Write tools that strip
        // C0 control bytes from string literals.
        let esc = '\u{001b}';
        let output = format!("{esc}[31mred{esc}[0m plain");
        let raw = json!({
            "hook_event_name": "AfterTool",
            "tool_name": "ShellCommand",
            "tool_input": { "command": "ls" },
            "tool_response": { "output": output },
        })
        .to_string();

        let HookEvent::PostToolUse { diff, .. } = parse(&raw) else {
            panic!("expected PostToolUse");
        };
        let d = diff.unwrap();
        assert!(d.contains("$ ls"));
        assert!(d.contains("+red plain"), "got: {d:?}");
        assert!(!d.contains('\x1b'), "ANSI escape leaked into diff: {d:?}");
    }

    #[test]
    fn parse_ignored_events_error_loudly_so_cli_noops() {
        // BeforeTool / PreCompress / Notification are intentionally not
        // modelled. They must produce an error (so the CLI no-ops) and
        // never fall through to Stop or any other actionable event.
        let adapter = GeminiCliAdapter;
        for ev in ["BeforeTool", "PreCompress", "Notification"] {
            let raw = format!(r#"{{"hook_event_name":"{ev}"}}"#);
            let err = adapter.parse_stdin(&raw).unwrap_err();
            assert!(err.contains("ignored"), "for {ev}: {err}");
        }
    }

    #[test]
    fn format_output_includes_continue_and_suppress_output() {
        let v = render(HookResult::noop());
        assert_eq!(v["continue"], true);
        assert_eq!(v["suppressOutput"], false);
    }

    #[test]
    fn format_output_nests_additional_context_under_hook_specific_output() {
        let v = render(HookResult::with_context("R1"));
        assert_eq!(v["hookSpecificOutput"]["additionalContext"], "R1");
    }

    #[test]
    fn format_output_strips_ansi_from_system_message() {
        let mut r = HookResult::noop();
        r.system_message = Some("\u{001b}[31mred\u{001b}[0m OK".into());
        let v = render(r);
        let msg = v["systemMessage"].as_str().unwrap();
        assert!(!msg.contains('\x1b'), "ANSI leaked: {msg:?}");
        assert!(msg.contains("red OK"), "content lost: {msg:?}");
    }
}
467}