Skip to main content

difflore_cli/
hook_forward.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use interprocess::local_socket::tokio::prelude::*;
5use interprocess::local_socket::{GenericFilePath, ListenerOptions, ToFsName};
6use serde::{Deserialize, Serialize};
7use tokio::time::timeout;
8
9pub const ENV: &str = difflore_core::env::DIFFLORE_HOOK_FORWARD;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq)]
12pub enum Mode {
13    Auto,
14    Always,
15    Never,
16}
17
18impl Mode {
19    fn from_env() -> Self {
20        match difflore_core::env::var(ENV)
21            .unwrap_or_else(|| "auto".to_owned())
22            .to_ascii_lowercase()
23            .as_str()
24        {
25            "always" => Self::Always,
26            "never" | "off" | "0" | "false" => Self::Never,
27            _ => Self::Auto,
28        }
29    }
30}
31
32impl std::fmt::Display for Mode {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        match self {
35            Self::Auto => write!(f, "auto"),
36            Self::Always => write!(f, "always"),
37            Self::Never => write!(f, "never"),
38        }
39    }
40}
41
42pub enum Attempt {
43    Used(String),
44    Unavailable { mode: Mode, error: String },
45    Disabled,
46}
47
48#[derive(Clone)]
49pub struct State {
50    pub db: difflore_core::SqlitePool,
51    pub index_pool: difflore_core::SqlitePool,
52}
53
54#[derive(Debug, Serialize, Deserialize)]
55struct Request {
56    client: String,
57    raw: String,
58}
59
60#[derive(Debug, Serialize, Deserialize)]
61struct Response {
62    ok: bool,
63    #[serde(default, skip_serializing_if = "Option::is_none")]
64    output: Option<String>,
65    #[serde(default, skip_serializing_if = "Option::is_none")]
66    error: Option<String>,
67}
68
69pub async fn try_forward(client: &str, raw: &str) -> Attempt {
70    let mode = Mode::from_env();
71    if mode == Mode::Never {
72        return Attempt::Disabled;
73    }
74    let req = Request {
75        client: client.to_owned(),
76        raw: raw.to_owned(),
77    };
78    let fut = roundtrip(&req);
79    match timeout(Duration::from_secs(5), fut).await {
80        Ok(Ok(output)) => Attempt::Used(output),
81        Ok(Err(error)) => Attempt::Unavailable {
82            mode,
83            error: error.to_string(),
84        },
85        Err(_) => Attempt::Unavailable {
86            mode,
87            error: "timed out connecting to hook forwarder".to_owned(),
88        },
89    }
90}
91
92async fn roundtrip(req: &Request) -> anyhow::Result<String> {
93    let line = serde_json::to_string(req)?;
94    let response_line = ipc_roundtrip(&(line + "\n")).await?;
95    let response: Response = serde_json::from_str(response_line.trim())?;
96    if response.ok {
97        Ok(response
98            .output
99            .unwrap_or_else(|| "{\"continue\":true}".to_owned()))
100    } else {
101        Err(anyhow::anyhow!(
102            "{}",
103            response
104                .error
105                .unwrap_or_else(|| "hook forwarder returned an unknown error".to_owned())
106        ))
107    }
108}
109
110pub async fn run_server() -> anyhow::Result<()> {
111    let db = difflore_core::db::init_db()
112        .await
113        .map_err(anyhow::Error::msg)?;
114    let index_pool = difflore_core::context::index_db::get_pool_for_cwd().await?;
115    let state = Arc::new(State { db, index_pool });
116    run_ipc_server(state).await
117}
118
119async fn handle_request(state: &State, line: &str) -> Response {
120    let trace = difflore_core::env::trace_hook();
121    let started = std::time::Instant::now();
122    let req: Request = match serde_json::from_str(line.trim()) {
123        Ok(req) => req,
124        Err(e) => {
125            return Response {
126                ok: false,
127                output: None,
128                error: Some(format!("invalid forward request: {e}")),
129            };
130        }
131    };
132    let adapter = crate::hooks::get_platform_adapter(&req.client);
133    let response = match crate::hook_runtime::hook_output_for_raw(
134        &req.client,
135        &*adapter,
136        &req.raw,
137        false,
138        true,
139        Some(state),
140    )
141    .await
142    {
143        Ok(output) => {
144            if trace {
145                eprintln!(
146                    "[difflore.forward.trace] hook_output={}ms",
147                    started.elapsed().as_millis()
148                );
149            }
150            Response {
151                ok: true,
152                output: Some(output),
153                error: None,
154            }
155        }
156        Err(e) => {
157            if trace {
158                eprintln!(
159                    "[difflore.forward.trace] hook_error={}ms",
160                    started.elapsed().as_millis()
161                );
162            }
163            Response {
164                ok: false,
165                output: None,
166                error: Some(e.to_string()),
167            }
168        }
169    };
170    if trace {
171        eprintln!(
172            "[difflore.forward.trace] response_ready={}ms",
173            started.elapsed().as_millis()
174        );
175    }
176    response
177}
178
179/// Cross-platform local-socket endpoint. `interprocess` interprets
180/// the same path as a Unix-domain socket on Unix and as a named-pipe-
181/// equivalent in the local namespace on Windows.
182fn endpoint() -> anyhow::Result<std::path::PathBuf> {
183    Ok(difflore_core::paths::data_home()
184        .map_err(anyhow::Error::msg)?
185        .join("hook-forward.sock"))
186}
187
188async fn ipc_roundtrip(request_line: &str) -> anyhow::Result<String> {
189    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
190
191    let path = endpoint()?;
192    let name = path.to_fs_name::<GenericFilePath>()?;
193    let stream = LocalSocketStream::connect(name).await?;
194    let (reader, mut writer) = stream.split();
195    writer.write_all(request_line.as_bytes()).await?;
196    writer.flush().await?;
197    let mut reader = BufReader::new(reader);
198    let mut response = String::new();
199    reader.read_line(&mut response).await?;
200    if response.trim().is_empty() {
201        anyhow::bail!("hook forwarder returned an empty response");
202    }
203    Ok(response)
204}
205
206async fn run_ipc_server(state: Arc<State>) -> anyhow::Result<()> {
207    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
208
209    let socket = endpoint()?;
210    if let Some(parent) = socket.parent() {
211        std::fs::create_dir_all(parent)?;
212    }
213    // On Unix the listener takes a real filesystem path; remove any
214    // stale socket file from a prior run. On Windows the named-pipe
215    // namespace doesn't have a file to remove, but the call is a noop.
216    let _ = std::fs::remove_file(&socket);
217    let name = socket.to_fs_name::<GenericFilePath>()?;
218    let listener = ListenerOptions::new().name(name).create_tokio()?;
219    loop {
220        let stream = listener.accept().await?;
221        let state = Arc::<State>::clone(&state);
222        tokio::spawn(async move {
223            let trace = difflore_core::env::trace_hook();
224            let started = std::time::Instant::now();
225            let (reader, mut writer) = stream.split();
226            let mut reader = BufReader::new(reader);
227            let mut line = String::new();
228            if reader.read_line(&mut line).await.is_err() {
229                return;
230            }
231            if trace {
232                eprintln!(
233                    "[difflore.forward.trace] request_read={}ms",
234                    started.elapsed().as_millis()
235                );
236            }
237            let response = handle_request(&state, &line).await;
238            let response_line = match serde_json::to_string(&response) {
239                Ok(s) => s + "\n",
240                Err(_) => "{\"ok\":false,\"error\":\"serialize response failed\"}\n".to_owned(),
241            };
242            let _ = writer.write_all(response_line.as_bytes()).await;
243            let _ = writer.flush().await;
244            if trace {
245                eprintln!(
246                    "[difflore.forward.trace] response_written={}ms",
247                    started.elapsed().as_millis()
248                );
249            }
250        });
251    }
252}