async_lsp/
concurrency.rs

//! Incoming request multiplexing limits and cancellation.
//!
//! *Applies to both Language Servers and Language Clients.*
//!
//! Note that the [`crate::MainLoop`] can poll multiple ongoing requests
//! out-of-box, while this middleware provides these additional features:
//! 1. Limit concurrent incoming requests to at most `max_concurrency`.
//! 2. Cancellation of incoming requests via client notification `$/cancelRequest`.
9use std::collections::HashMap;
10use std::future::Future;
11use std::num::NonZeroUsize;
12use std::ops::ControlFlow;
13use std::pin::Pin;
14use std::sync::{Arc, Weak};
15use std::task::{Context, Poll};
16use std::thread::available_parallelism;
17
18use futures::stream::{AbortHandle, Abortable};
19use futures::task::AtomicWaker;
20use lsp_types::notification::{self, Notification};
21use pin_project_lite::pin_project;
22use tower_layer::Layer;
23use tower_service::Service;
24
25use crate::{
26    AnyEvent, AnyNotification, AnyRequest, ErrorCode, LspService, RequestId, ResponseError, Result,
27};
28
/// The middleware for incoming request multiplexing limits and cancellation.
///
/// See [module level documentations](self) for details.
pub struct Concurrency<S> {
    /// The wrapped inner service that actually handles requests.
    service: S,
    /// Upper bound on concurrently running incoming requests.
    max_concurrency: NonZeroUsize,
    /// A specialized single-acquire-multiple-release semaphore, using `Arc::weak_count` as tokens.
    semaphore: Arc<AtomicWaker>,
    /// Abort handles of in-flight requests keyed by request id; purged lazily in `call`.
    ongoing: HashMap<RequestId, AbortHandle>,
}
39
// Presumably generates accessor methods for the inner `service` field; the
// macro is defined elsewhere in this crate — confirm its expansion there.
define_getters!(impl[S] Concurrency<S>, service: S);
41
impl<S: LspService> Service<AnyRequest> for Concurrency<S>
where
    S::Error: From<ResponseError>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    /// Ready only while fewer than `max_concurrency` responses are in flight.
    ///
    /// Each in-flight `ResponseFuture` holds a `Weak` of `self.semaphore` (via
    /// `SemaphoreGuard`), so `Arc::weak_count` is the number of outstanding tokens.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if Arc::weak_count(&self.semaphore) >= self.max_concurrency.get() {
            // Implicit `Acquire`.
            self.semaphore.register(cx.waker());
            // Re-check after registering the waker: a guard may have been dropped
            // between the first check and `register`, whose wake-up would otherwise
            // be lost. No guards dropped between the check and register?
            if Arc::weak_count(&self.semaphore) >= self.max_concurrency.get() {
                return Poll::Pending;
            }
        }

        // Here we have `weak_count < max_concurrency`. The service is ready for new calls.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: AnyRequest) -> Self::Future {
        // Take one semaphore token (a `Weak`); it is released when the guard drops.
        let guard = SemaphoreGuard(Arc::downgrade(&self.semaphore));
        debug_assert!(
            Arc::weak_count(&self.semaphore) <= self.max_concurrency.get(),
            "`poll_ready` is not called before `call`",
        );

        let (handle, registration) = AbortHandle::new_pair();

        // Regularly purge completed or dead tasks. See also `AbortOnDrop` below.
        // This costs 2*N time to remove at least N tasks, results in amortized O(1) time cost
        // for each spawned task.
        if self.ongoing.len() >= self.max_concurrency.get() * 2 {
            self.ongoing.retain(|_, handle| !handle.is_aborted());
        }
        self.ongoing.insert(req.id.clone(), handle.clone());

        let fut = self.service.call(req);
        let fut = Abortable::new(fut, registration);
        ResponseFuture {
            fut,
            _abort_on_drop: AbortOnDrop(handle),
            _guard: guard,
        }
    }
}
90
/// RAII token of the concurrency semaphore: the held `Weak` is what
/// `poll_ready` counts via `Arc::weak_count`; dropping the guard releases it.
struct SemaphoreGuard(Weak<AtomicWaker>);
92
impl Drop for SemaphoreGuard {
    fn drop(&mut self) {
        // If the middleware itself is already gone, there is nobody to wake.
        if let Some(sema) = self.0.upgrade() {
            if let Some(waker) = sema.take() {
                // Return the token first.
                drop(sema);
                // Wake up `poll_ready`. Implicit "Release".
                waker.wake();
            }
        }
    }
}
105
/// By default, the `AbortHandle` only transfers information from the handle to the
/// `Abortable<_>`, not in reverse. But we want to set the flag on drop (either success or
/// failure), so that the `ongoing` map can be purged regularly without bloating indefinitely.
struct AbortOnDrop(AbortHandle);
110
impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        // Flag the task as aborted so `Concurrency::call`'s periodic
        // `retain(!is_aborted)` sweep can purge it from the `ongoing` map.
        self.0.abort();
    }
}
116
pin_project! {
    /// The [`Future`] type used by the [`Concurrency`] middleware.
    pub struct ResponseFuture<Fut> {
        // The inner response future, wrapped so a `$/cancelRequest` can abort it.
        #[pin]
        fut: Abortable<Fut>,
        // NB. Comes before `SemaphoreGuard`. So that when the guard wake up the caller, it is able
        // to purge the current future from `ongoing` map immediately.
        _abort_on_drop: AbortOnDrop,
        // Holds one concurrency token; dropped last so the waker fires after cleanup.
        _guard: SemaphoreGuard,
    }
}
128
129impl<Fut, Response, Error> Future for ResponseFuture<Fut>
130where
131    Fut: Future<Output = Result<Response, Error>>,
132    Error: From<ResponseError>,
133{
134    type Output = Fut::Output;
135
136    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137        match self.project().fut.poll(cx) {
138            Poll::Pending => Poll::Pending,
139            Poll::Ready(Ok(inner_ret)) => Poll::Ready(inner_ret),
140            Poll::Ready(Err(_aborted)) => Poll::Ready(Err(ResponseError {
141                code: ErrorCode::REQUEST_CANCELLED,
142                message: "Client cancelled the request".into(),
143                data: None,
144            }
145            .into())),
146        }
147    }
148}
149
150impl<S: LspService> LspService for Concurrency<S>
151where
152    S::Error: From<ResponseError>,
153{
154    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
155        if notif.method == notification::Cancel::METHOD {
156            if let Ok(params) = serde_json::from_value::<lsp_types::CancelParams>(notif.params) {
157                self.ongoing.remove(&params.id);
158            }
159            return ControlFlow::Continue(());
160        }
161        self.service.notify(notif)
162    }
163
164    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
165        self.service.emit(event)
166    }
167}
168
/// The builder of [`Concurrency`] middleware.
///
/// Its [`Default`] configuration has `max_concurrency` set to the result of
/// [`std::thread::available_parallelism`], falling back to `1` if that fails.
///
/// See [module level documentations](self) for details.
#[derive(Clone, Debug)]
#[must_use]
pub struct ConcurrencyBuilder {
    /// Maximum number of concurrently processed incoming requests.
    max_concurrency: NonZeroUsize,
}
180
181impl Default for ConcurrencyBuilder {
182    fn default() -> Self {
183        Self::new(available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap()))
184    }
185}
186
187impl ConcurrencyBuilder {
188    /// Create the middleware with concurrency limit `max_concurrency`.
189    pub fn new(max_concurrency: NonZeroUsize) -> Self {
190        Self { max_concurrency }
191    }
192}
193
/// A type alias of [`ConcurrencyBuilder`] conforming to the naming convention of [`tower_layer`].
///
/// Both names refer to the same builder; use whichever reads better at the call site.
pub type ConcurrencyLayer = ConcurrencyBuilder;
196
197impl<S> Layer<S> for ConcurrencyBuilder {
198    type Service = Concurrency<S>;
199
200    fn layer(&self, inner: S) -> Self::Service {
201        Concurrency {
202            service: inner,
203            max_concurrency: self.max_concurrency,
204            semaphore: Arc::new(AtomicWaker::new()),
205            // See `Concurrency::call` for why the factor 2.
206            ongoing: HashMap::with_capacity(
207                self.max_concurrency
208                    .get()
209                    .checked_mul(2)
210                    .expect("max_concurrency overflow"),
211            ),
212        }
213    }
214}