Skip to main content

difflore_cli/hooks/
mod.rs

1//! Platform hook adapter layer.
2//!
//! Each supported AI coding client (Claude Code, Cursor, Gemini CLI,
//! Windsurf, …) expects lifecycle hooks to speak its own JSON dialect on
//! stdin/stdout. This module defines the `PlatformAdapter` trait every
//! client implementation conforms to, plus the `get_platform_adapter`
//! dispatcher the CLI uses to look up the right adapter by name when the
//! `difflore-hook` shim runs.
9//!
10//! The CLI's job is thin: read stdin, hand it to the adapter, get a
11//! normalised `HookEvent`, run `DiffLore` logic, hand the `HookResult`
12//! back to the adapter to get platform-specific JSON, write to stdout.
13//! Any per-platform quirk lives *inside* the adapter — the CLI stays
14//! platform-agnostic.
15
16pub mod claude_code;
17pub mod cursor;
18pub mod gemini_cli;
19// Since-last-session recap banner used by the `SessionStart` dispatch arm.
20pub mod session_banner;
21pub(crate) mod synth;
22pub mod types;
23pub mod windsurf;
24
25/// Static, generic half of an adapter — owns the raw payload type,
26/// the human-readable label used in parse-error messages, and the
27/// canonical-event mapping. `PlatformAdapter` (above) stays
28/// object-safe for the `Box<dyn PlatformAdapter>` dispatch site;
29/// this trait carries the type-level pieces (associated types,
30/// associated consts) the dispatcher doesn't need.
31///
32/// Adapters implement BOTH traits: `PlatformAdapter` for runtime
33/// dispatch, `PayloadAdapter` for the parse pipeline. The two are
34/// glued together by `PlatformAdapter::parse_stdin` delegating to
35/// `Self::parse_stdin_default`.
36pub(crate) trait PayloadAdapter {
37    /// Strongly-typed view of the IDE's stdin envelope. Each adapter
38    /// keeps its own per-IDE struct (the wire shapes diverge enough
39    /// that a one-size struct would be a constant grab-bag of
40    /// `Option<Value>`s).
41    type Raw: serde::de::DeserializeOwned;
42
43    /// Used in the "invalid <label> hook JSON" parse-error message.
44    /// Keeps each adapter's wording self-explanatory in logs.
45    const PARSE_LABEL: &'static str;
46
47    /// Map a parsed `Raw` into the canonical `HookEvent`. Adapters
48    /// are responsible for: validating the discriminator field,
49    /// dispatching by event name, and pulling per-event payload
50    /// fields out of the (often loosely-typed) `Raw`.
51    fn into_canonical(raw: Self::Raw) -> Result<types::HookEvent, String>;
52
53    /// Default `parse_stdin` body: trim, deserialize into `Raw`,
54    /// hand off to `into_canonical`. Adapters' `PlatformAdapter::
55    /// parse_stdin` impls delegate here so the boilerplate lives in
56    /// exactly one place.
57    fn parse_stdin_default(raw: &str) -> Result<types::HookEvent, String> {
58        let payload: Self::Raw = serde_json::from_str(raw.trim())
59            .map_err(|e| format!("invalid {} hook JSON: {e}", Self::PARSE_LABEL))?;
60        Self::into_canonical(payload)
61    }
62}
63
64/// Contract every platform adapter implements. The trait is object-safe
65/// on purpose so `get_platform_adapter` can return a `Box<dyn
66/// PlatformAdapter>` and the dispatch site doesn't need to know the
67/// concrete type at compile time. That makes adding a new client (say
68/// Cursor) a pure module-level addition — no changes to the CLI dispatch
69/// loop beyond the `get_platform_adapter` match arm.
70pub trait PlatformAdapter: Send + Sync {
71    /// Stable identifier used in logs + telemetry. Must match the string
72    /// `get_platform_adapter` dispatches on so `adapter.name() ==
73    /// requested_name` round-trips.
74    fn name(&self) -> &'static str;
75
76    /// Parse a single hook invocation's stdin payload into our canonical
77    /// `HookEvent`. Adapters SHOULD be permissive about unknown fields
78    /// (clients evolve faster than we can ship adapter updates) and
79    /// strict only about the tiny subset they actually need.
80    ///
81    /// On unrecognised / unsupported events, return `Err` with a human-
82    /// readable reason. The CLI logs the error and no-ops — hooks must
83    /// never block the user workflow, even on malformed input.
84    fn parse_stdin(&self, raw: &str) -> Result<types::HookEvent, String>;
85
86    /// Format a `HookResult` as the exact JSON the client expects on
87    /// stdout. Returns a complete, newline-free string; the caller
88    /// prints it + a trailing newline. Formatting is infallible because
89    /// `HookResult` is a fixed shape we control.
90    fn format_output(&self, result: types::HookResult) -> String;
91
92    /// Bucket an error produced by the hook's core work so the CLI can
93    /// pick an exit code (see `main.rs` `Hook::Run`). Default walks the
94    /// `anyhow` error chain for transport-ish hints (io kinds, reqwest
95    /// connect/timeout, HTTP 5xx) vs client-ish hints (HTTP 4xx, serde
96    /// parse failures). Adapters can override when their transport layer
97    /// carries richer context than the default sniffer can see.
98    fn classify_error(&self, err: &anyhow::Error) -> types::ErrorClass {
99        default_classify_error(err)
100    }
101}
102
103/// Default error classifier shared by every adapter. Kept as a free
104/// function (not an inherent method) so unit tests don't need to
105/// construct a concrete adapter to exercise it.
106pub fn default_classify_error(err: &anyhow::Error) -> types::ErrorClass {
107    use types::ErrorClass;
108
109    for cause in err.chain() {
110        // reqwest: connection refused / timeout / DNS resolution failure
111        // all surface through these two predicates. HTTP status is
112        // checked separately — `is_connect` and `is_timeout` return
113        // false on a 5xx response body.
114        if let Some(re) = cause.downcast_ref::<reqwest::Error>() {
115            if re.is_timeout() || re.is_connect() {
116                return ErrorClass::Transport;
117            }
118            if let Some(status) = re.status() {
119                if status.is_server_error() {
120                    return ErrorClass::Transport;
121                }
122                // Retryable 4xx subset that semantically belongs in
123                // Transport, not Client — the assistant session must
124                // not block on them. 429 (Too Many Requests) and 408
125                // (Request Timeout) are infrastructure-level signals
126                // ("wait + retry"), same as a 5xx or DNS failure.
127                // See memory `project_error_path_actionable_playbook.md`
128                // — `format_cloud_err` already routes 429 through the
129                // transport-style hint; classifier should match.
130                if status.as_u16() == 429 || status.as_u16() == 408 {
131                    return ErrorClass::Transport;
132                }
133                if status.is_client_error() {
134                    return ErrorClass::Client;
135                }
136            }
137        }
138
139        // std::io: connection refused by the listener, socket half
140        // closed mid-request, kernel timeout. All transport-class.
141        if let Some(io) = cause.downcast_ref::<std::io::Error>() {
142            use std::io::ErrorKind::{ConnectionRefused, ConnectionReset, NotConnected, TimedOut};
143            if matches!(
144                io.kind(),
145                ConnectionRefused | TimedOut | ConnectionReset | NotConnected
146            ) {
147                return ErrorClass::Transport;
148            }
149        }
150
151        // serde: a parse failure means the other side sent us something
152        // malformed. Not our transport's fault — surface so the parser
153        // (ours or theirs) gets fixed.
154        if cause.downcast_ref::<serde_json::Error>().is_some() {
155            return ErrorClass::Client;
156        }
157    }
158
159    ErrorClass::Fatal
160}
161
#[cfg(test)]
mod classifier_tests {
    use super::*;
    use std::io::{Error as IoError, ErrorKind};
    use types::ErrorClass;

    /// Lift a std error into `anyhow` and run the default classifier on it.
    fn classify<E>(source: E) -> ErrorClass
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        default_classify_error(&anyhow::Error::from(source))
    }

    #[test]
    fn io_kinds_map_to_expected_class() {
        let expectations = [
            (ErrorKind::ConnectionRefused, ErrorClass::Transport),
            (ErrorKind::TimedOut, ErrorClass::Transport),
            (ErrorKind::ConnectionReset, ErrorClass::Transport),
            (ErrorKind::NotConnected, ErrorClass::Transport),
            // Not on the transport allow-list: a permissions failure is a
            // real bug and must land in Fatal instead of being retried.
            (ErrorKind::PermissionDenied, ErrorClass::Fatal),
        ];
        for (kind, expected) in expectations {
            assert_eq!(classify(IoError::new(kind, "x")), expected, "for {kind:?}");
        }
    }

    #[test]
    fn serde_parse_error_is_client_and_plain_anyhow_is_fatal() {
        // Malformed JSON is the sender's bug, so it must surface as Client.
        let malformed = serde_json::from_str::<serde_json::Value>("{not json")
            .expect_err("`{not json` must not parse");
        assert_eq!(classify(malformed), ErrorClass::Client);

        // A bare message has no typed cause to downcast, so it is Fatal.
        let opaque = anyhow::anyhow!("something exploded");
        assert_eq!(default_classify_error(&opaque), ErrorClass::Fatal);
    }

    #[test]
    fn wrapped_io_transport_still_classifies_through_context() {
        // Production code nearly always adds `.context(...)` before an error
        // escapes, so the chain walk has to see through every layer.
        let wrapped = anyhow::Error::from(IoError::new(ErrorKind::ConnectionRefused, "down"))
            .context("fetch relevant rules")
            .context("hook dispatch");
        assert_eq!(default_classify_error(&wrapped), ErrorClass::Transport);
    }
}
209
210/// Dispatch by client name. Unknown names fall through to the
211/// Claude Code adapter as the pragmatic default — almost every
212/// `DiffLore` user today is on Claude Code, and a wrong-but-compatible
213/// parse fails loudly (via `parse_stdin`) while a panic would kill the
214/// user's whole assistant session.
215///
216/// Accepted aliases: `"claude-code"` / `"claude_code"` / `"claude"` all
217/// map to the Claude Code adapter so env-var typos don't silently
218/// reach a `Cursor`/`Zed` codepath that doesn't yet exist.
219pub fn get_platform_adapter(client_name: &str) -> Box<dyn PlatformAdapter> {
220    // Match case-insensitively + ignoring separator style ("gemini-cli"
221    // vs "gemini_cli") so env-var typos and different casing conventions
222    // in hook configs across tools all route to the right adapter.
223    let normalized = client_name.to_ascii_lowercase();
224    match normalized.as_str() {
225        "cursor" => Box::new(cursor::CursorAdapter),
226        "gemini-cli" | "gemini_cli" | "gemini" => Box::new(gemini_cli::GeminiCliAdapter),
227        "windsurf" => Box::new(windsurf::WindsurfAdapter),
228        // "claude-code"/"claude_code"/"claude" plus any unknown name
229        // deliberately route to Claude Code: see module docs.
230        _ => Box::new(claude_code::ClaudeCodeAdapter),
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// `(requested client name, expected PlatformAdapter::name())` pairs.
    /// Separator and casing variants must all route to the right adapter.
    /// Unknown names must fall back to Claude Code instead of panicking.
    const ROUTING_TABLE: &[(&str, &str)] = &[
        ("claude-code", "claude-code"),
        ("claude_code", "claude-code"),
        ("claude", "claude-code"),
        ("cursor", "cursor"),
        ("Cursor", "cursor"),
        ("gemini-cli", "gemini-cli"),
        ("gemini_cli", "gemini-cli"),
        ("gemini", "gemini-cli"),
        ("Gemini-CLI", "gemini-cli"),
        ("windsurf", "windsurf"),
        ("Windsurf", "windsurf"),
        ("definitely-not-a-real-client", "claude-code"),
    ];

    #[test]
    fn dispatch_routes_aliases_and_unknown_falls_back_to_claude_code() {
        for &(requested, expected) in ROUTING_TABLE {
            let adapter = get_platform_adapter(requested);
            assert_eq!(adapter.name(), expected, "alias {requested} misrouted");
        }
    }
}
261                "alias {input} misrouted"
262            );
263        }
264    }
265}