1use std::collections::HashMap;
10use std::future::Future;
11use std::num::NonZeroUsize;
12use std::ops::ControlFlow;
13use std::pin::Pin;
14use std::sync::{Arc, Weak};
15use std::task::{Context, Poll};
16use std::thread::available_parallelism;
17
18use futures::stream::{AbortHandle, Abortable};
19use futures::task::AtomicWaker;
20use lsp_types::notification::{self, Notification};
21use pin_project_lite::pin_project;
22use tower_layer::Layer;
23use tower_service::Service;
24
25use crate::{
26 AnyEvent, AnyNotification, AnyRequest, ErrorCode, LspService, RequestId, ResponseError, Result,
27};
28
/// Middleware that limits the number of concurrently processed requests and
/// reacts to client `$/cancelRequest` notifications by aborting the matching
/// in-flight request.
pub struct Concurrency<S> {
    // The wrapped inner service.
    service: S,
    // Upper bound on simultaneously in-flight requests.
    max_concurrency: NonZeroUsize,
    // A counting semaphore in disguise: each in-flight request holds one
    // `Weak` to this allocation (via `SemaphoreGuard`), so the occupancy is
    // `Arc::weak_count`. The `AtomicWaker` stores the task waiting in
    // `poll_ready` for a slot to free up.
    semaphore: Arc<AtomicWaker>,
    // Abort handles of in-flight requests keyed by request id, used by the
    // cancellation path in `notify`. Purged lazily in `call`.
    ongoing: HashMap<RequestId, AbortHandle>,
}

// Generated accessor methods for the wrapped `service` field.
define_getters!(impl[S] Concurrency<S>, service: S);
41
impl<S: LspService> Service<AnyRequest> for Concurrency<S>
where
    S::Error: From<ResponseError>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    /// Ready only while fewer than `max_concurrency` requests are in flight.
    ///
    /// Occupancy is tracked as the weak count of `semaphore`: each live
    /// `SemaphoreGuard` holds one `Weak` reference to it.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if Arc::weak_count(&self.semaphore) >= self.max_concurrency.get() {
            self.semaphore.register(cx.waker());
            // Re-check after registering the waker: a slot may have been
            // released between the first check and the registration. Without
            // this second check that release (and its wake) could be missed,
            // leaving this task pending forever.
            if Arc::weak_count(&self.semaphore) >= self.max_concurrency.get() {
                return Poll::Pending;
            }
        }

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: AnyRequest) -> Self::Future {
        // Acquire a slot: the guard's `Weak` bumps the weak count that
        // `poll_ready` inspects; dropping the guard releases the slot.
        let guard = SemaphoreGuard(Arc::downgrade(&self.semaphore));
        debug_assert!(
            Arc::weak_count(&self.semaphore) <= self.max_concurrency.get(),
            "`poll_ready` is not called before `call`",
        );

        let (handle, registration) = AbortHandle::new_pair();

        // Lazily purge entries whose requests already finished or were
        // cancelled, once the map grows past twice the concurrency limit.
        // This keeps the map bounded without a per-completion callback.
        if self.ongoing.len() >= self.max_concurrency.get() * 2 {
            self.ongoing.retain(|_, handle| !handle.is_aborted());
        }
        self.ongoing.insert(req.id.clone(), handle.clone());

        let fut = self.service.call(req);
        let fut = Abortable::new(fut, registration);
        ResponseFuture {
            fut,
            _abort_on_drop: AbortOnDrop(handle),
            _guard: guard,
        }
    }
}
90
91struct SemaphoreGuard(Weak<AtomicWaker>);
92
93impl Drop for SemaphoreGuard {
94 fn drop(&mut self) {
95 if let Some(sema) = self.0.upgrade() {
96 if let Some(waker) = sema.take() {
97 drop(sema);
99 waker.wake();
101 }
102 }
103 }
104}
105
/// Aborts the wrapped handle when dropped, so dropping a `ResponseFuture`
/// cancels the underlying request. Since `Concurrency::ongoing` stores a
/// clone of the same handle, the abort also makes `is_aborted()` true there,
/// letting `call`'s `retain` sweep reclaim the map entry.
struct AbortOnDrop(AbortHandle);

impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        self.0.abort();
    }
}
116
pin_project! {
    /// The response future returned by the [`Concurrency`] middleware.
    pub struct ResponseFuture<Fut> {
        #[pin]
        fut: Abortable<Fut>,
        // Aborts the request when this future is dropped without completing.
        _abort_on_drop: AbortOnDrop,
        // Releases the concurrency slot (and wakes a waiting `poll_ready`)
        // when this future completes or is dropped.
        _guard: SemaphoreGuard,
    }
}
128
129impl<Fut, Response, Error> Future for ResponseFuture<Fut>
130where
131 Fut: Future<Output = Result<Response, Error>>,
132 Error: From<ResponseError>,
133{
134 type Output = Fut::Output;
135
136 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137 match self.project().fut.poll(cx) {
138 Poll::Pending => Poll::Pending,
139 Poll::Ready(Ok(inner_ret)) => Poll::Ready(inner_ret),
140 Poll::Ready(Err(_aborted)) => Poll::Ready(Err(ResponseError {
141 code: ErrorCode::REQUEST_CANCELLED,
142 message: "Client cancelled the request".into(),
143 data: None,
144 }
145 .into())),
146 }
147 }
148}
149
150impl<S: LspService> LspService for Concurrency<S>
151where
152 S::Error: From<ResponseError>,
153{
154 fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
155 if notif.method == notification::Cancel::METHOD {
156 if let Ok(params) = serde_json::from_value::<lsp_types::CancelParams>(notif.params) {
157 self.ongoing.remove(¶ms.id);
158 }
159 return ControlFlow::Continue(());
160 }
161 self.service.notify(notif)
162 }
163
164 fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
165 self.service.emit(event)
166 }
167}
168
/// Builder/layer configuring the [`Concurrency`] middleware.
#[derive(Clone, Debug)]
#[must_use]
pub struct ConcurrencyBuilder {
    // Maximum number of requests processed concurrently.
    max_concurrency: NonZeroUsize,
}
180
181impl Default for ConcurrencyBuilder {
182 fn default() -> Self {
183 Self::new(available_parallelism().unwrap_or(NonZeroUsize::new(1).unwrap()))
184 }
185}
186
187impl ConcurrencyBuilder {
188 pub fn new(max_concurrency: NonZeroUsize) -> Self {
190 Self { max_concurrency }
191 }
192}
193
/// Layer alias: the builder itself implements [`tower_layer::Layer`].
pub type ConcurrencyLayer = ConcurrencyBuilder;
196
197impl<S> Layer<S> for ConcurrencyBuilder {
198 type Service = Concurrency<S>;
199
200 fn layer(&self, inner: S) -> Self::Service {
201 Concurrency {
202 service: inner,
203 max_concurrency: self.max_concurrency,
204 semaphore: Arc::new(AtomicWaker::new()),
205 ongoing: HashMap::with_capacity(
207 self.max_concurrency
208 .get()
209 .checked_mul(2)
210 .expect("max_concurrency overflow"),
211 ),
212 }
213 }
214}