Skip to main content

dscode_extension_host/
ipc.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
8use tokio::net::UnixStream;
9use tokio::sync::{oneshot, Mutex, Notify, RwLock};
10use tracing::{debug, error, info};
11
/// Upper bound, in bytes, on the body of a single IPC frame (50 MiB).
///
/// `read_message` rejects any frame whose declared length exceeds this value
/// before allocating a buffer for it.
const MAX_MESSAGE_SIZE: usize = 50 * 1024 * 1024;
13
/// A single message exchanged over the IPC socket.
///
/// It is serialized to JSON by `write_message` and parsed back by
/// `read_message`; the field names below are therefore the JSON keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct IPCMessage {
    /// Correlation id: a reply carries the same `id` as the request it answers.
    pub(crate) id: String,
    /// Message type. Replies produced in this file append `-response` or
    /// `-error` to the type of the request they answer.
    pub(crate) r#type: String,
    /// Arbitrary JSON payload. For `-error` replies this file reads/writes the
    /// error text under the `"error"` key.
    pub(crate) payload: Value,
}
20
/// Shared, thread-safe callback used to serve incoming requests.
///
/// It is called with the request's message type and payload and returns a
/// boxed future resolving to either the reply payload (`Ok`) or an error
/// message (`Err`) that is sent back to the peer as an `-error` reply.
pub type IncomingRequestHandler = Arc<
    dyn Fn(String, Value) -> Pin<Box<dyn Future<Output = Result<Value, String>> + Send>>
        + Send
        + Sync,
>;
26
27async fn read_message<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<IPCMessage, String> {
28    let mut len_buf = [0u8; 4];
29    reader
30        .read_exact(&mut len_buf)
31        .await
32        .map_err(|e| format!("Failed to read message length: {}", e))?;
33    let len = u32::from_be_bytes(len_buf) as usize;
34    if len == 0 {
35        return Err("Zero-length message".to_string());
36    }
37    if len > MAX_MESSAGE_SIZE {
38        return Err(format!("Message too large: {} bytes", len));
39    }
40    let mut body = vec![0u8; len];
41    reader
42        .read_exact(&mut body)
43        .await
44        .map_err(|e| format!("Failed to read message body: {}", e))?;
45    let msg: IPCMessage =
46        serde_json::from_slice(&body).map_err(|e| format!("Failed to parse IPC message: {}", e))?;
47    Ok(msg)
48}
49
50async fn write_message<W: AsyncWriteExt + Unpin>(
51    writer: &mut W, msg: &IPCMessage,
52) -> Result<(), String> {
53    let body =
54        serde_json::to_vec(msg).map_err(|e| format!("Failed to serialize IPC message: {}", e))?;
55    let len = body.len() as u32;
56    writer
57        .write_all(&len.to_be_bytes())
58        .await
59        .map_err(|e| format!("Failed to write message length: {}", e))?;
60    writer.write_all(&body).await.map_err(|e| format!("Failed to write message body: {}", e))?;
61    writer.flush().await.map_err(|e| format!("Failed to flush message: {}", e))?;
62    Ok(())
63}
64
/// Requests awaiting a reply, keyed by message id.
///
/// Each value is the sending half of the one-shot channel on which the
/// matching `ExtensionIpc::request` call is waiting.
type PendingRequestMap = Arc<Mutex<HashMap<String, oneshot::Sender<Result<Value, String>>>>>;
66
/// Client side of one IPC connection: sends requests/notifications and
/// matches incoming replies to the requests that are waiting for them.
pub struct ExtensionIpc {
    /// Write half of the socket; the mutex is held for the duration of each
    /// frame write.
    write: Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
    /// In-flight requests, shared with the background read task.
    pending_requests: PendingRequestMap,
    /// Counter used to build message ids; incremented once per outgoing message.
    message_id: Arc<Mutex<u64>>,
    /// Starts as `true`; set to `false` by the background read task when its
    /// loop ends.
    alive: Arc<std::sync::atomic::AtomicBool>,
}
73
74impl ExtensionIpc {
75    pub fn new(stream: UnixStream) -> Self {
76        let (read_half, write_half) = stream.into_split();
77        let write = Arc::new(Mutex::new(write_half));
78        let pending_requests: PendingRequestMap =
79            Arc::new(Mutex::new(HashMap::new()));
80        let message_id = Arc::new(Mutex::new(0u64));
81        let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
82
83        let pending_clone = Arc::clone(&pending_requests);
84        let alive_clone = Arc::clone(&alive);
85
86        tokio::spawn(async move {
87            let mut reader = BufReader::new(read_half);
88            loop {
89                match read_message(&mut reader).await {
90                    Ok(msg) => {
91                        let mut pending = pending_clone.lock().await;
92                        if let Some(sender) = pending.remove(&msg.id) {
93                            if msg.r#type.ends_with("-error") {
94                                let error_msg = msg
95                                    .payload
96                                    .get("error")
97                                    .and_then(|e| e.as_str())
98                                    .unwrap_or("Unknown error")
99                                    .to_string();
100                                let _ = sender.send(Err(error_msg));
101                            } else {
102                                let _ = sender.send(Ok(msg.payload));
103                            }
104                        }
105                    }
106                    Err(e) => {
107                        if e.contains("Failed to read message length") {
108                            break;
109                        }
110                        error!(error = %e, "Read error");
111                        break;
112                    }
113                }
114            }
115            // When the read loop breaks (socket EOF or read error):
116            // 1. alive flag set to false (Relaxed ordering is sufficient —
117            //    eventual consistency is fine for a "dead connection" signal)
118            // 2. All pending requests are cleared by dropping their Senders,
119            //    which causes each Receiver to get RecvError
120            // 3. Any future request() calls will fail the alive check
121            // 4. The IpcManager still holds a reference to this ExtensionIpc —
122            //    it must be explicitly removed or replaced on reconnection
123            alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
124            let mut pending = pending_clone.lock().await;
125            pending.clear();
126        });
127
128        Self { write, pending_requests, message_id, alive }
129    }
130
131    pub fn is_alive(&self) -> bool {
132        self.alive.load(std::sync::atomic::Ordering::Relaxed)
133    }
134
135    pub async fn request(&self, msg_type: &str, payload: Value) -> Result<Value, String> {
136        // Check connection liveness before sending. This is a best-effort check —
137        // the connection could die between this check and the actual write.
138        // The timeout below protects against that race.
139        if !self.alive.load(std::sync::atomic::Ordering::Relaxed) {
140            return Err("IPC connection is not alive".to_string());
141        }
142
143        let id = {
144            let mut message_id = self.message_id.lock().await;
145            *message_id += 1;
146            format!("req_{}", *message_id)
147        };
148
149        let (tx, rx) = oneshot::channel();
150        {
151            let mut pending = self.pending_requests.lock().await;
152            pending.insert(id.clone(), tx);
153        }
154
155        let message = IPCMessage { id: id.clone(), r#type: msg_type.to_string(), payload };
156
157        {
158            let mut writer = self.write.lock().await;
159            write_message(&mut *writer, &message).await?;
160        }
161
162        // Request timeout: 30 seconds.
163        // Three possible outcomes for a pending request:
164        //   1. Response received: resolved normally via oneshot channel
165        //   2. Connection closed: oneshot Sender dropped, Receiver gets RecvError
166        //   3. Timeout: tokio::time::timeout fires, request cleaned up
167        //
168        // Without this timeout, if the extension host becomes unresponsive
169        // (infinite loop, deadlock) but the connection stays alive, the caller
170        // blocks forever. Connection close (outcome 2) only helps when the
171        // process actually crashes or the socket breaks.
172        match tokio::time::timeout(std::time::Duration::from_secs(30), rx).await {
173            Ok(Ok(result)) => result,
174            Ok(Err(_)) => {
175                // Sender dropped — connection closed while request was in flight.
176                // The read loop (spawned task) detected socket EOF and cleared
177                // pending requests by dropping all Senders.
178                let mut pending = self.pending_requests.lock().await;
179                pending.remove(&id);
180                Err("IPC connection closed while awaiting response".to_string())
181            }
182            Err(_) => {
183                // Timeout — extension host did not respond within 30 seconds.
184                // This can happen if the host is in an infinite loop, deadlocked,
185                // or simply overwhelmed. The request is removed from pending to
186                // prevent memory leaks. If a response arrives later (after timeout),
187                // it will be silently dropped (no matching pending entry).
188                let mut pending = self.pending_requests.lock().await;
189                pending.remove(&id);
190                Err(format!("IPC request '{}' timed out after 30s", msg_type))
191            }
192        }
193    }
194
195    pub async fn send(&self, msg_type: &str, payload: Value) -> Result<(), String> {
196        let id = {
197            let mut message_id = self.message_id.lock().await;
198            *message_id += 1;
199            format!("msg_{}", *message_id)
200        };
201
202        let message = IPCMessage { id, r#type: msg_type.to_string(), payload };
203
204        {
205            let mut writer = self.write.lock().await;
206            write_message(&mut *writer, &message).await?;
207        }
208
209        Ok(())
210    }
211}
212
/// Server side of the IPC channel: accepts connections on a Unix socket and
/// dispatches every received message to a handler.
pub struct IncomingIpc {
    /// Bound listener, shared with the accept task spawned by `start`.
    listener: Arc<tokio::net::UnixListener>,
    /// Request handler; must be set before `start` is called.
    handler: Option<IncomingRequestHandler>,
    /// `true` while the accept task is meant to be running; connection tasks
    /// also check it before reading each message.
    running: Arc<Mutex<bool>>,
    /// Used by `stop` to wake the accept task so it can exit.
    shutdown: Arc<Notify>,
}
219
220impl IncomingIpc {
221    pub fn new(socket_path: &str) -> Result<Self, String> {
222        let path = socket_path
223            .strip_prefix("ipc://")
224            .ok_or_else(|| format!("Invalid IPC URL: {}", socket_path))?;
225
226        if std::path::Path::new(path).exists() {
227            let _ = std::fs::remove_file(path);
228        }
229
230        let listener = tokio::net::UnixListener::bind(path)
231            .map_err(|e| format!("Failed to bind IPC socket at {}: {}", path, e))?;
232
233        info!(path = path, "Listening on incoming IPC socket");
234
235        Ok(Self {
236            listener: Arc::new(listener),
237            handler: None,
238            running: Arc::new(Mutex::new(false)),
239            shutdown: Arc::new(Notify::new()),
240        })
241    }
242
243    pub fn set_handler(&mut self, handler: IncomingRequestHandler) {
244        self.handler = Some(handler);
245    }
246
247    pub async fn start(&self) -> Result<(), String> {
248        if self.handler.is_none() {
249            return Err("No handler set for incoming requests".to_string());
250        }
251
252        let mut running = self.running.lock().await;
253        if *running {
254            return Ok(());
255        }
256        *running = true;
257        drop(running);
258
259        let handler = self
260            .handler
261            .clone()
262            .ok_or_else(|| "No handler set for incoming requests".to_string())?;
263        let running_flag = Arc::clone(&self.running);
264        let listener = Arc::clone(&self.listener);
265        let shutdown = Arc::clone(&self.shutdown);
266
267        tokio::spawn(async move {
268            loop {
269                tokio::select! {
270                    accept_result = listener.accept() => {
271                        match accept_result {
272                            Ok((stream, _)) => {
273                                let handler = handler.clone();
274                                let running_flag = Arc::clone(&running_flag);
275                                tokio::spawn(async move {
276                                    if let Err(e) =
277                                        handle_incoming_connection(stream, handler, running_flag).await
278                                    {
279                                        error!(error = %e, "Connection handler error");
280                                    }
281                                });
282                            }
283                            Err(e) => {
284                                error!(error = %e, "Accept error");
285                                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
286                            }
287                        }
288                    }
289                    _ = shutdown.notified() => {
290                        break;
291                    }
292                }
293            }
294
295            let mut running = running_flag.lock().await;
296            *running = false;
297            info!("Stopped listening on incoming IPC socket");
298        });
299
300        Ok(())
301    }
302
303    pub async fn stop(&self) {
304        let mut running = self.running.lock().await;
305        *running = false;
306        drop(running);
307        self.shutdown.notify_waiters();
308    }
309}
310
311async fn handle_incoming_connection(
312    stream: tokio::net::UnixStream, handler: IncomingRequestHandler, running_flag: Arc<Mutex<bool>>,
313) -> Result<(), String> {
314    let (read_half, write_half) = stream.into_split();
315    let mut reader = BufReader::new(read_half);
316    let write = Arc::new(Mutex::new(write_half));
317
318    loop {
319        {
320            let running = running_flag.lock().await;
321            if !*running {
322                break;
323            }
324        }
325
326        let msg = match read_message(&mut reader).await {
327            Ok(msg) => msg,
328            Err(e) => {
329                if e.contains("Failed to read message length") {
330                    break;
331                }
332                error!(error = %e, "Read error, closing connection");
333                break;
334            }
335        };
336
337        let response_payload = match handler(msg.r#type.clone(), msg.payload.clone()).await {
338            Ok(payload) => payload,
339            Err(e) => {
340                let error_response = IPCMessage {
341                    id: msg.id.clone(),
342                    r#type: format!("{}-error", msg.r#type),
343                    payload: serde_json::json!({ "error": e }),
344                };
345
346                let mut writer = write.lock().await;
347                let _ = write_message(&mut *writer, &error_response).await;
348                continue;
349            }
350        };
351
352        let response = IPCMessage {
353            id: msg.id,
354            r#type: format!("{}-response", msg.r#type),
355            payload: response_payload,
356        };
357
358        let mut writer = write.lock().await;
359        let _ = write_message(&mut *writer, &response).await;
360    }
361
362    Ok(())
363}
364
/// Registry of IPC endpoints, each keyed by a caller-chosen id.
pub struct IpcManager {
    /// Outgoing (client) connections by id.
    outgoing: Arc<RwLock<HashMap<String, Arc<ExtensionIpc>>>>,
    /// Incoming (listener) endpoints by id.
    incoming: Arc<RwLock<HashMap<String, Arc<IncomingIpc>>>>,
}
369
impl Default for IpcManager {
    /// Equivalent to [`IpcManager::new`]: a manager with no connections.
    fn default() -> Self {
        Self::new()
    }
}
375
376impl IpcManager {
377    pub fn new() -> Self {
378        Self {
379            outgoing: Arc::new(RwLock::new(HashMap::new())),
380            incoming: Arc::new(RwLock::new(HashMap::new())),
381        }
382    }
383
384    pub async fn connect_outgoing(&self, id: &str, socket_path: &str) -> Result<(), String> {
385        debug!(id = id, socket_path = socket_path, "Connecting outgoing");
386
387        let path = socket_path
388            .strip_prefix("ipc://")
389            .ok_or_else(|| format!("Invalid IPC URL: {}", socket_path))?;
390
391        let max_attempts = 10;
392        for attempt in 1..=max_attempts {
393            match UnixStream::connect(path).await {
394                Ok(stream) => {
395                    let ipc = Arc::new(ExtensionIpc::new(stream));
396                    let mut conns = self.outgoing.write().await;
397                    conns.insert(id.to_string(), ipc);
398                    info!(id = id, "Outgoing connection established");
399                    return Ok(());
400                }
401                Err(_) if attempt < max_attempts => {
402                    let delay = tokio::time::Duration::from_millis(attempt as u64 * 200);
403                    tokio::time::sleep(delay).await;
404                    continue;
405                }
406                Err(e) => {
407                    return Err(format!(
408                        "Failed to connect after {} attempts: {}",
409                        max_attempts, e
410                    ));
411                }
412            }
413        }
414
415        Err("Failed to connect".to_string())
416    }
417
418    pub async fn setup_incoming(
419        &self, id: &str, socket_path: &str, handler: IncomingRequestHandler,
420    ) -> Result<(), String> {
421        debug!(id = id, socket_path = socket_path, "Setting up incoming");
422
423        let mut ipc = IncomingIpc::new(socket_path)?;
424        ipc.set_handler(handler);
425
426        let ipc_arc = Arc::new(ipc);
427        ipc_arc.start().await?;
428
429        let mut conns = self.incoming.write().await;
430        conns.insert(id.to_string(), ipc_arc);
431
432        info!(id = id, "Incoming connection ready");
433        Ok(())
434    }
435
436    pub async fn disconnect(&self, id: &str) {
437        let mut outgoing = self.outgoing.write().await;
438        let mut incoming = self.incoming.write().await;
439
440        outgoing.remove(id);
441        if let Some(incoming) = incoming.remove(id) {
442            incoming.stop().await;
443        }
444    }
445
446    pub async fn is_connected(&self, id: &str) -> bool {
447        let conns = self.outgoing.read().await;
448        conns.get(id).is_some_and(|ipc| ipc.is_alive())
449    }
450
451    pub async fn reconnect_outgoing(&self, id: &str, socket_path: &str) -> Result<(), String> {
452        {
453            let mut conns = self.outgoing.write().await;
454            if let Some(old) = conns.remove(id) {
455                drop(old);
456            }
457        }
458        self.connect_outgoing(id, socket_path).await
459    }
460
461    pub async fn request(&self, id: &str, msg_type: &str, payload: Value) -> Result<Value, String> {
462        let ipc = self
463            .get_outgoing(id)
464            .await
465            .ok_or_else(|| format!("Extension host '{}' not connected", id))?;
466        ipc.request(msg_type, payload).await
467    }
468
469    pub async fn get_outgoing(&self, id: &str) -> Option<Arc<ExtensionIpc>> {
470        let conns = self.outgoing.read().await;
471        conns.get(id).cloned()
472    }
473
474    pub async fn connect(&self, id: &str, socket_path: &str) -> Result<(), String> {
475        self.connect_outgoing(id, socket_path).await
476    }
477
478    pub async fn get(&self, id: &str) -> Option<Arc<ExtensionIpc>> {
479        self.get_outgoing(id).await
480    }
481}