1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
8use tokio::net::UnixStream;
9use tokio::sync::{oneshot, Mutex, Notify, RwLock};
10use tracing::{debug, error, info};
11
/// Largest message body, in bytes, that `read_message` will accept: 50 MiB (50 * 2^20).
const MAX_MESSAGE_SIZE: usize = 50 << 20;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub(crate) struct IPCMessage {
16 pub(crate) id: String,
17 pub(crate) r#type: String,
18 pub(crate) payload: Value,
19}
20
21pub type IncomingRequestHandler = Arc<
22 dyn Fn(String, Value) -> Pin<Box<dyn Future<Output = Result<Value, String>> + Send>>
23 + Send
24 + Sync,
25>;
26
27async fn read_message<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<IPCMessage, String> {
28 let mut len_buf = [0u8; 4];
29 reader
30 .read_exact(&mut len_buf)
31 .await
32 .map_err(|e| format!("Failed to read message length: {}", e))?;
33 let len = u32::from_be_bytes(len_buf) as usize;
34 if len == 0 {
35 return Err("Zero-length message".to_string());
36 }
37 if len > MAX_MESSAGE_SIZE {
38 return Err(format!("Message too large: {} bytes", len));
39 }
40 let mut body = vec![0u8; len];
41 reader
42 .read_exact(&mut body)
43 .await
44 .map_err(|e| format!("Failed to read message body: {}", e))?;
45 let msg: IPCMessage =
46 serde_json::from_slice(&body).map_err(|e| format!("Failed to parse IPC message: {}", e))?;
47 Ok(msg)
48}
49
50async fn write_message<W: AsyncWriteExt + Unpin>(
51 writer: &mut W, msg: &IPCMessage,
52) -> Result<(), String> {
53 let body =
54 serde_json::to_vec(msg).map_err(|e| format!("Failed to serialize IPC message: {}", e))?;
55 let len = body.len() as u32;
56 writer
57 .write_all(&len.to_be_bytes())
58 .await
59 .map_err(|e| format!("Failed to write message length: {}", e))?;
60 writer.write_all(&body).await.map_err(|e| format!("Failed to write message body: {}", e))?;
61 writer.flush().await.map_err(|e| format!("Failed to flush message: {}", e))?;
62 Ok(())
63}
64
65type PendingRequestMap = Arc<Mutex<HashMap<String, oneshot::Sender<Result<Value, String>>>>>;
66
67pub struct ExtensionIpc {
68 write: Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
69 pending_requests: PendingRequestMap,
70 message_id: Arc<Mutex<u64>>,
71 alive: Arc<std::sync::atomic::AtomicBool>,
72}
73
74impl ExtensionIpc {
75 pub fn new(stream: UnixStream) -> Self {
76 let (read_half, write_half) = stream.into_split();
77 let write = Arc::new(Mutex::new(write_half));
78 let pending_requests: PendingRequestMap =
79 Arc::new(Mutex::new(HashMap::new()));
80 let message_id = Arc::new(Mutex::new(0u64));
81 let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
82
83 let pending_clone = Arc::clone(&pending_requests);
84 let alive_clone = Arc::clone(&alive);
85
86 tokio::spawn(async move {
87 let mut reader = BufReader::new(read_half);
88 loop {
89 match read_message(&mut reader).await {
90 Ok(msg) => {
91 let mut pending = pending_clone.lock().await;
92 if let Some(sender) = pending.remove(&msg.id) {
93 if msg.r#type.ends_with("-error") {
94 let error_msg = msg
95 .payload
96 .get("error")
97 .and_then(|e| e.as_str())
98 .unwrap_or("Unknown error")
99 .to_string();
100 let _ = sender.send(Err(error_msg));
101 } else {
102 let _ = sender.send(Ok(msg.payload));
103 }
104 }
105 }
106 Err(e) => {
107 if e.contains("Failed to read message length") {
108 break;
109 }
110 error!(error = %e, "Read error");
111 break;
112 }
113 }
114 }
115 alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
124 let mut pending = pending_clone.lock().await;
125 pending.clear();
126 });
127
128 Self { write, pending_requests, message_id, alive }
129 }
130
131 pub fn is_alive(&self) -> bool {
132 self.alive.load(std::sync::atomic::Ordering::Relaxed)
133 }
134
135 pub async fn request(&self, msg_type: &str, payload: Value) -> Result<Value, String> {
136 if !self.alive.load(std::sync::atomic::Ordering::Relaxed) {
140 return Err("IPC connection is not alive".to_string());
141 }
142
143 let id = {
144 let mut message_id = self.message_id.lock().await;
145 *message_id += 1;
146 format!("req_{}", *message_id)
147 };
148
149 let (tx, rx) = oneshot::channel();
150 {
151 let mut pending = self.pending_requests.lock().await;
152 pending.insert(id.clone(), tx);
153 }
154
155 let message = IPCMessage { id: id.clone(), r#type: msg_type.to_string(), payload };
156
157 {
158 let mut writer = self.write.lock().await;
159 write_message(&mut *writer, &message).await?;
160 }
161
162 match tokio::time::timeout(std::time::Duration::from_secs(30), rx).await {
173 Ok(Ok(result)) => result,
174 Ok(Err(_)) => {
175 let mut pending = self.pending_requests.lock().await;
179 pending.remove(&id);
180 Err("IPC connection closed while awaiting response".to_string())
181 }
182 Err(_) => {
183 let mut pending = self.pending_requests.lock().await;
189 pending.remove(&id);
190 Err(format!("IPC request '{}' timed out after 30s", msg_type))
191 }
192 }
193 }
194
195 pub async fn send(&self, msg_type: &str, payload: Value) -> Result<(), String> {
196 let id = {
197 let mut message_id = self.message_id.lock().await;
198 *message_id += 1;
199 format!("msg_{}", *message_id)
200 };
201
202 let message = IPCMessage { id, r#type: msg_type.to_string(), payload };
203
204 {
205 let mut writer = self.write.lock().await;
206 write_message(&mut *writer, &message).await?;
207 }
208
209 Ok(())
210 }
211}
212
213pub struct IncomingIpc {
214 listener: Arc<tokio::net::UnixListener>,
215 handler: Option<IncomingRequestHandler>,
216 running: Arc<Mutex<bool>>,
217 shutdown: Arc<Notify>,
218}
219
220impl IncomingIpc {
221 pub fn new(socket_path: &str) -> Result<Self, String> {
222 let path = socket_path
223 .strip_prefix("ipc://")
224 .ok_or_else(|| format!("Invalid IPC URL: {}", socket_path))?;
225
226 if std::path::Path::new(path).exists() {
227 let _ = std::fs::remove_file(path);
228 }
229
230 let listener = tokio::net::UnixListener::bind(path)
231 .map_err(|e| format!("Failed to bind IPC socket at {}: {}", path, e))?;
232
233 info!(path = path, "Listening on incoming IPC socket");
234
235 Ok(Self {
236 listener: Arc::new(listener),
237 handler: None,
238 running: Arc::new(Mutex::new(false)),
239 shutdown: Arc::new(Notify::new()),
240 })
241 }
242
243 pub fn set_handler(&mut self, handler: IncomingRequestHandler) {
244 self.handler = Some(handler);
245 }
246
247 pub async fn start(&self) -> Result<(), String> {
248 if self.handler.is_none() {
249 return Err("No handler set for incoming requests".to_string());
250 }
251
252 let mut running = self.running.lock().await;
253 if *running {
254 return Ok(());
255 }
256 *running = true;
257 drop(running);
258
259 let handler = self
260 .handler
261 .clone()
262 .ok_or_else(|| "No handler set for incoming requests".to_string())?;
263 let running_flag = Arc::clone(&self.running);
264 let listener = Arc::clone(&self.listener);
265 let shutdown = Arc::clone(&self.shutdown);
266
267 tokio::spawn(async move {
268 loop {
269 tokio::select! {
270 accept_result = listener.accept() => {
271 match accept_result {
272 Ok((stream, _)) => {
273 let handler = handler.clone();
274 let running_flag = Arc::clone(&running_flag);
275 tokio::spawn(async move {
276 if let Err(e) =
277 handle_incoming_connection(stream, handler, running_flag).await
278 {
279 error!(error = %e, "Connection handler error");
280 }
281 });
282 }
283 Err(e) => {
284 error!(error = %e, "Accept error");
285 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
286 }
287 }
288 }
289 _ = shutdown.notified() => {
290 break;
291 }
292 }
293 }
294
295 let mut running = running_flag.lock().await;
296 *running = false;
297 info!("Stopped listening on incoming IPC socket");
298 });
299
300 Ok(())
301 }
302
303 pub async fn stop(&self) {
304 let mut running = self.running.lock().await;
305 *running = false;
306 drop(running);
307 self.shutdown.notify_waiters();
308 }
309}
310
311async fn handle_incoming_connection(
312 stream: tokio::net::UnixStream, handler: IncomingRequestHandler, running_flag: Arc<Mutex<bool>>,
313) -> Result<(), String> {
314 let (read_half, write_half) = stream.into_split();
315 let mut reader = BufReader::new(read_half);
316 let write = Arc::new(Mutex::new(write_half));
317
318 loop {
319 {
320 let running = running_flag.lock().await;
321 if !*running {
322 break;
323 }
324 }
325
326 let msg = match read_message(&mut reader).await {
327 Ok(msg) => msg,
328 Err(e) => {
329 if e.contains("Failed to read message length") {
330 break;
331 }
332 error!(error = %e, "Read error, closing connection");
333 break;
334 }
335 };
336
337 let response_payload = match handler(msg.r#type.clone(), msg.payload.clone()).await {
338 Ok(payload) => payload,
339 Err(e) => {
340 let error_response = IPCMessage {
341 id: msg.id.clone(),
342 r#type: format!("{}-error", msg.r#type),
343 payload: serde_json::json!({ "error": e }),
344 };
345
346 let mut writer = write.lock().await;
347 let _ = write_message(&mut *writer, &error_response).await;
348 continue;
349 }
350 };
351
352 let response = IPCMessage {
353 id: msg.id,
354 r#type: format!("{}-response", msg.r#type),
355 payload: response_payload,
356 };
357
358 let mut writer = write.lock().await;
359 let _ = write_message(&mut *writer, &response).await;
360 }
361
362 Ok(())
363}
364
365pub struct IpcManager {
366 outgoing: Arc<RwLock<HashMap<String, Arc<ExtensionIpc>>>>,
367 incoming: Arc<RwLock<HashMap<String, Arc<IncomingIpc>>>>,
368}
369
370impl Default for IpcManager {
371 fn default() -> Self {
372 Self::new()
373 }
374}
375
376impl IpcManager {
377 pub fn new() -> Self {
378 Self {
379 outgoing: Arc::new(RwLock::new(HashMap::new())),
380 incoming: Arc::new(RwLock::new(HashMap::new())),
381 }
382 }
383
384 pub async fn connect_outgoing(&self, id: &str, socket_path: &str) -> Result<(), String> {
385 debug!(id = id, socket_path = socket_path, "Connecting outgoing");
386
387 let path = socket_path
388 .strip_prefix("ipc://")
389 .ok_or_else(|| format!("Invalid IPC URL: {}", socket_path))?;
390
391 let max_attempts = 10;
392 for attempt in 1..=max_attempts {
393 match UnixStream::connect(path).await {
394 Ok(stream) => {
395 let ipc = Arc::new(ExtensionIpc::new(stream));
396 let mut conns = self.outgoing.write().await;
397 conns.insert(id.to_string(), ipc);
398 info!(id = id, "Outgoing connection established");
399 return Ok(());
400 }
401 Err(_) if attempt < max_attempts => {
402 let delay = tokio::time::Duration::from_millis(attempt as u64 * 200);
403 tokio::time::sleep(delay).await;
404 continue;
405 }
406 Err(e) => {
407 return Err(format!(
408 "Failed to connect after {} attempts: {}",
409 max_attempts, e
410 ));
411 }
412 }
413 }
414
415 Err("Failed to connect".to_string())
416 }
417
418 pub async fn setup_incoming(
419 &self, id: &str, socket_path: &str, handler: IncomingRequestHandler,
420 ) -> Result<(), String> {
421 debug!(id = id, socket_path = socket_path, "Setting up incoming");
422
423 let mut ipc = IncomingIpc::new(socket_path)?;
424 ipc.set_handler(handler);
425
426 let ipc_arc = Arc::new(ipc);
427 ipc_arc.start().await?;
428
429 let mut conns = self.incoming.write().await;
430 conns.insert(id.to_string(), ipc_arc);
431
432 info!(id = id, "Incoming connection ready");
433 Ok(())
434 }
435
436 pub async fn disconnect(&self, id: &str) {
437 let mut outgoing = self.outgoing.write().await;
438 let mut incoming = self.incoming.write().await;
439
440 outgoing.remove(id);
441 if let Some(incoming) = incoming.remove(id) {
442 incoming.stop().await;
443 }
444 }
445
446 pub async fn is_connected(&self, id: &str) -> bool {
447 let conns = self.outgoing.read().await;
448 conns.get(id).is_some_and(|ipc| ipc.is_alive())
449 }
450
451 pub async fn reconnect_outgoing(&self, id: &str, socket_path: &str) -> Result<(), String> {
452 {
453 let mut conns = self.outgoing.write().await;
454 if let Some(old) = conns.remove(id) {
455 drop(old);
456 }
457 }
458 self.connect_outgoing(id, socket_path).await
459 }
460
461 pub async fn request(&self, id: &str, msg_type: &str, payload: Value) -> Result<Value, String> {
462 let ipc = self
463 .get_outgoing(id)
464 .await
465 .ok_or_else(|| format!("Extension host '{}' not connected", id))?;
466 ipc.request(msg_type, payload).await
467 }
468
469 pub async fn get_outgoing(&self, id: &str) -> Option<Arc<ExtensionIpc>> {
470 let conns = self.outgoing.read().await;
471 conns.get(id).cloned()
472 }
473
474 pub async fn connect(&self, id: &str, socket_path: &str) -> Result<(), String> {
475 self.connect_outgoing(id, socket_path).await
476 }
477
478 pub async fn get(&self, id: &str) -> Option<Arc<ExtensionIpc>> {
479 self.get_outgoing(id).await
480 }
481}