difflore_cli/
hook_forward.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use interprocess::local_socket::tokio::prelude::*;
5use interprocess::local_socket::{GenericFilePath, ListenerOptions, ToFsName};
6use serde::{Deserialize, Serialize};
7use tokio::time::timeout;
8
9pub const ENV: &str = difflore_core::env::DIFFLORE_HOOK_FORWARD;
10
/// Forwarding policy selected through the [`ENV`] environment variable.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum Mode {
    /// Default: used when the variable is absent or holds an unrecognised value.
    Auto,
    /// Forwarding explicitly requested (`always`). How callers treat a failed
    /// forward in this mode is decided outside this module — confirm at call sites.
    Always,
    /// Forwarding switched off; `try_forward` returns `Attempt::Disabled`.
    Never,
}
17
18impl Mode {
19 fn from_env() -> Self {
20 match difflore_core::env::var(ENV)
21 .unwrap_or_else(|| "auto".to_owned())
22 .to_ascii_lowercase()
23 .as_str()
24 {
25 "always" => Self::Always,
26 "never" | "off" | "0" | "false" => Self::Never,
27 _ => Self::Auto,
28 }
29 }
30}
31
32impl std::fmt::Display for Mode {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 Self::Auto => write!(f, "auto"),
36 Self::Always => write!(f, "always"),
37 Self::Never => write!(f, "never"),
38 }
39 }
40}
41
42pub enum Attempt {
43 Used(String),
44 Unavailable { mode: Mode, error: String },
45 Disabled,
46}
47
48#[derive(Clone)]
49pub struct State {
50 pub db: difflore_core::SqlitePool,
51 pub index_pool: difflore_core::SqlitePool,
52}
53
54#[derive(Debug, Serialize, Deserialize)]
55struct Request {
56 client: String,
57 raw: String,
58}
59
60#[derive(Debug, Serialize, Deserialize)]
61struct Response {
62 ok: bool,
63 #[serde(default, skip_serializing_if = "Option::is_none")]
64 output: Option<String>,
65 #[serde(default, skip_serializing_if = "Option::is_none")]
66 error: Option<String>,
67}
68
69pub async fn try_forward(client: &str, raw: &str) -> Attempt {
70 let mode = Mode::from_env();
71 if mode == Mode::Never {
72 return Attempt::Disabled;
73 }
74 let req = Request {
75 client: client.to_owned(),
76 raw: raw.to_owned(),
77 };
78 let fut = roundtrip(&req);
79 match timeout(Duration::from_secs(5), fut).await {
80 Ok(Ok(output)) => Attempt::Used(output),
81 Ok(Err(error)) => Attempt::Unavailable {
82 mode,
83 error: error.to_string(),
84 },
85 Err(_) => Attempt::Unavailable {
86 mode,
87 error: "timed out connecting to hook forwarder".to_owned(),
88 },
89 }
90}
91
92async fn roundtrip(req: &Request) -> anyhow::Result<String> {
93 let line = serde_json::to_string(req)?;
94 let response_line = ipc_roundtrip(&(line + "\n")).await?;
95 let response: Response = serde_json::from_str(response_line.trim())?;
96 if response.ok {
97 Ok(response
98 .output
99 .unwrap_or_else(|| "{\"continue\":true}".to_owned()))
100 } else {
101 Err(anyhow::anyhow!(
102 "{}",
103 response
104 .error
105 .unwrap_or_else(|| "hook forwarder returned an unknown error".to_owned())
106 ))
107 }
108}
109
110pub async fn run_server() -> anyhow::Result<()> {
111 let db = difflore_core::db::init_db()
112 .await
113 .map_err(anyhow::Error::msg)?;
114 let index_pool = difflore_core::context::index_db::get_pool_for_cwd().await?;
115 let state = Arc::new(State { db, index_pool });
116 run_ipc_server(state).await
117}
118
119async fn handle_request(state: &State, line: &str) -> Response {
120 let trace = difflore_core::env::trace_hook();
121 let started = std::time::Instant::now();
122 let req: Request = match serde_json::from_str(line.trim()) {
123 Ok(req) => req,
124 Err(e) => {
125 return Response {
126 ok: false,
127 output: None,
128 error: Some(format!("invalid forward request: {e}")),
129 };
130 }
131 };
132 let adapter = crate::hooks::get_platform_adapter(&req.client);
133 let response = match crate::hook_runtime::hook_output_for_raw(
134 &req.client,
135 &*adapter,
136 &req.raw,
137 false,
138 true,
139 Some(state),
140 )
141 .await
142 {
143 Ok(output) => {
144 if trace {
145 eprintln!(
146 "[difflore.forward.trace] hook_output={}ms",
147 started.elapsed().as_millis()
148 );
149 }
150 Response {
151 ok: true,
152 output: Some(output),
153 error: None,
154 }
155 }
156 Err(e) => {
157 if trace {
158 eprintln!(
159 "[difflore.forward.trace] hook_error={}ms",
160 started.elapsed().as_millis()
161 );
162 }
163 Response {
164 ok: false,
165 output: None,
166 error: Some(e.to_string()),
167 }
168 }
169 };
170 if trace {
171 eprintln!(
172 "[difflore.forward.trace] response_ready={}ms",
173 started.elapsed().as_millis()
174 );
175 }
176 response
177}
178
179fn endpoint() -> anyhow::Result<std::path::PathBuf> {
183 Ok(difflore_core::paths::data_home()
184 .map_err(anyhow::Error::msg)?
185 .join("hook-forward.sock"))
186}
187
188async fn ipc_roundtrip(request_line: &str) -> anyhow::Result<String> {
189 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
190
191 let path = endpoint()?;
192 let name = path.to_fs_name::<GenericFilePath>()?;
193 let stream = LocalSocketStream::connect(name).await?;
194 let (reader, mut writer) = stream.split();
195 writer.write_all(request_line.as_bytes()).await?;
196 writer.flush().await?;
197 let mut reader = BufReader::new(reader);
198 let mut response = String::new();
199 reader.read_line(&mut response).await?;
200 if response.trim().is_empty() {
201 anyhow::bail!("hook forwarder returned an empty response");
202 }
203 Ok(response)
204}
205
206async fn run_ipc_server(state: Arc<State>) -> anyhow::Result<()> {
207 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
208
209 let socket = endpoint()?;
210 if let Some(parent) = socket.parent() {
211 std::fs::create_dir_all(parent)?;
212 }
213 let _ = std::fs::remove_file(&socket);
217 let name = socket.to_fs_name::<GenericFilePath>()?;
218 let listener = ListenerOptions::new().name(name).create_tokio()?;
219 loop {
220 let stream = listener.accept().await?;
221 let state = Arc::<State>::clone(&state);
222 tokio::spawn(async move {
223 let trace = difflore_core::env::trace_hook();
224 let started = std::time::Instant::now();
225 let (reader, mut writer) = stream.split();
226 let mut reader = BufReader::new(reader);
227 let mut line = String::new();
228 if reader.read_line(&mut line).await.is_err() {
229 return;
230 }
231 if trace {
232 eprintln!(
233 "[difflore.forward.trace] request_read={}ms",
234 started.elapsed().as_millis()
235 );
236 }
237 let response = handle_request(&state, &line).await;
238 let response_line = match serde_json::to_string(&response) {
239 Ok(s) => s + "\n",
240 Err(_) => "{\"ok\":false,\"error\":\"serialize response failed\"}\n".to_owned(),
241 };
242 let _ = writer.write_all(response_line.as_bytes()).await;
243 let _ = writer.flush().await;
244 if trace {
245 eprintln!(
246 "[difflore.forward.trace] response_written={}ms",
247 started.elapsed().as_millis()
248 );
249 }
250 });
251 }
252}