1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use std::io;
use std::ops::ControlFlow;
use std::task::{Context, Poll};

use lsp_types::request::{self, Request};
use tower_layer::Layer;
use tower_service::Service;

use crate::{
    AnyEvent, AnyNotification, AnyRequest, ClientSocket, Error, JsonValue, LspService,
    ResponseError, Result,
};

/// Internal event sent through the client socket once the monitored client process has
/// exited. `ClientProcessMonitor`'s `LspService::emit` intercepts it and stops with a
/// protocol error instead of forwarding it to the inner service.
struct ClientProcessExited;

/// Middleware that watches the LSP client's process and stops serving once it exits.
///
/// When an `initialize` request carries a `processId`, a background task waits for that
/// process to terminate and then emits an internal event through `client`. Handling that
/// event in [`LspService::emit`] returns `ControlFlow::Break` with an error. All other
/// requests, notifications and events are forwarded to the wrapped `service` unchanged.
pub struct ClientProcessMonitor<S> {
    service: S,
    client: ClientSocket,
}

impl<S> ClientProcessMonitor<S> {
    /// Wraps `service`; `client` is used to deliver the process-exit event back to the loop.
    #[must_use]
    pub fn new(service: S, client: ClientSocket) -> Self {
        ClientProcessMonitor { service, client }
    }
}

impl<S: LspService> Service<AnyRequest> for ClientProcessMonitor<S> {
    type Response = JsonValue;
    type Error = ResponseError;
    type Future = S::Future;

    /// Readiness is entirely delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Forwards `req` to the wrapped service.
    ///
    /// Before forwarding, if `req` is an `initialize` request whose params object has an
    /// integer `processId` that fits in an `i32`, a detached task is spawned. That task
    /// emits `ClientProcessExited` to the client socket once the process has terminated.
    fn call(&mut self, req: AnyRequest) -> Self::Future {
        let client_pid: Option<i32> = if req.method == request::Initialize::METHOD {
            req.params
                .as_object()
                .and_then(|params| params.get("processId"))
                .and_then(|value| value.as_i64())
                .and_then(|raw| i32::try_from(raw).ok())
        } else {
            None
        };

        if let Some(pid) = client_pid {
            let socket = self.client.clone();
            tokio::spawn(async move {
                // Best effort: only a successful wait triggers the event; if the process
                // cannot be watched, nothing is emitted.
                if wait_for_pid(pid).await.is_ok() {
                    // The receiving side may already be gone; a closed channel is ignored.
                    let _: Result<_, _> = socket.emit(ClientProcessExited);
                }
            });
        }

        self.service.call(req)
    }
}

impl<S: LspService> LspService for ClientProcessMonitor<S> {
    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
        self.service.notify(notif)
    }

    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
        match event.downcast::<ClientProcessExited>() {
            Ok(ClientProcessExited) => {
                ControlFlow::Break(Err(Error::Protocol("Client process exited".into())))
            }
            Err(event) => self.service.emit(event),
        }
    }
}

/// Waits until the process identified by `pid` terminates.
///
/// Resolves to `Ok(())` once the process has exited. This includes the case where
/// `pidfd_open` reports `ESRCH` (no such process), which is treated as "already exited".
///
/// # Errors
///
/// Returns an error if `pid` is not a strictly positive process ID, or if opening the pidfd
/// or registering/polling it with the tokio reactor fails.
async fn wait_for_pid(pid: i32) -> io::Result<()> {
    use rustix::io::Errno;
    use rustix::process::{pidfd_open, Pid, PidfdFlags};
    use tokio::io::unix::{AsyncFd, AsyncFdReadyGuard};

    let invalid_pid = || io::Error::new(io::ErrorKind::Other, format!("Invalid PID {pid}"));

    // `pid` comes from the client's `initialize` params, so it is untrusted. Zero and negative
    // values do not name a single process (they denote process groups / "any" in
    // kill/waitpid), so they must never reach `Pid::from_raw`.
    if pid <= 0 {
        return Err(invalid_pid());
    }
    let pid = pid
        .try_into()
        .ok()
        // SAFETY: `pid` was checked above to be strictly positive, which satisfies the
        // `Pid::from_raw` requirement of being a valid process ID (or zero).
        .and_then(|pid| unsafe { Pid::from_raw(pid) })
        .ok_or_else(invalid_pid)?;

    let pidfd = match pidfd_open(pid, PidfdFlags::NONBLOCK) {
        Ok(pidfd) => pidfd,
        // No such process: it has already exited (or never existed).
        Err(Errno::SRCH) => return Ok(()),
        Err(err) => return Err(err.into()),
    };

    // A pidfd becomes readable when the process terminates, so one readiness event is enough.
    let pidfd = AsyncFd::new(pidfd)?;
    let _guard: AsyncFdReadyGuard<'_, _> = pidfd.readable().await?;
    Ok(())
}

/// [`Layer`] that wraps services in a [`ClientProcessMonitor`].
///
/// Every service produced by this layer receives its own clone of `client`.
#[must_use]
pub struct ClientProcessMonitorLayer {
    client: ClientSocket,
}

impl ClientProcessMonitorLayer {
    /// Creates a layer that hands (a clone of) `client` to each monitor it builds.
    pub fn new(client: ClientSocket) -> Self {
        ClientProcessMonitorLayer { client }
    }
}

impl<S> Layer<S> for ClientProcessMonitorLayer {
    type Service = ClientProcessMonitor<S>;

    /// Wraps `inner` in a [`ClientProcessMonitor`] holding a clone of this layer's socket.
    fn layer(&self, inner: S) -> Self::Service {
        let socket = self.client.clone();
        ClientProcessMonitor::new(inner, socket)
    }
}