1pub mod rpc;
2pub mod util;
3
4use async_trait::async_trait;
5use futures::{
6 channel::oneshot, future::FusedFuture, lock::Mutex as AsyncMutex, sink::Sink, Future,
7 FutureExt, SinkExt,
8};
9use handler::Handler;
10use lsp_types::{
11 notification::{self, Notification},
12 request as req,
13 request::Request,
14 NumberOrString,
15};
16use serde::{de::DeserializeOwned, Serialize};
17use std::{
18 collections::HashMap,
19 io, mem,
20 pin::Pin,
21 sync::{
22 atomic::{AtomicBool, Ordering},
23 Arc, Mutex,
24 },
25 task::{Poll, Waker},
26};
27use tracing::Instrument;
28
29mod handler;
30
31#[cfg(any(feature = "tokio-stdio", feature = "tokio-tcp"))]
32pub mod listen;
33
34#[derive(Debug, Clone, Default)]
35struct Cancellation {
36 cancelled: Arc<AtomicBool>,
37 waker: Arc<Mutex<Option<Waker>>>,
38}
39
40impl Cancellation {
41 pub fn token(&self) -> CancelToken {
42 CancelToken {
43 cancelled: self.cancelled.clone(),
44 waker_set: Arc::new(AtomicBool::new(false)),
45 waker: self.waker.clone(),
46 }
47 }
48
49 pub fn cancel(&mut self) {
50 self.cancelled.store(true, Ordering::SeqCst);
51
52 if let Some(w) = (*self.waker.lock().unwrap()).take() {
53 w.wake();
54 }
55 }
56}
57
58#[derive(Debug, Clone)]
59pub struct CancelToken {
60 cancelled: Arc<AtomicBool>,
61 waker_set: Arc<AtomicBool>,
62 waker: Arc<Mutex<Option<Waker>>>,
63}
64
65impl CancelToken {
66 pub fn is_cancelled(&self) -> bool {
67 self.cancelled.load(Ordering::SeqCst)
68 }
69
70 pub fn as_err(&mut self) -> CancelTokenErr {
71 CancelTokenErr(self)
72 }
73}
74
75impl Future for CancelToken {
76 type Output = ();
77
78 #[allow(unused_mut)] fn poll(
80 mut self: std::pin::Pin<&mut Self>,
81 cx: &mut std::task::Context<'_>,
82 ) -> Poll<Self::Output> {
83 if self.cancelled.load(Ordering::SeqCst) {
84 Poll::Ready(())
85 } else {
86 if !self.waker_set.load(Ordering::SeqCst) {
87 *self.waker.lock().unwrap() = Some(cx.waker().clone());
88 }
89
90 Poll::Pending
91 }
92 }
93}
94
95impl FusedFuture for CancelToken {
96 fn is_terminated(&self) -> bool {
97 false
98 }
99}
100
101pub struct CancelTokenErr<'t>(&'t mut CancelToken);
102
103impl Future for CancelTokenErr<'_> {
104 type Output = Result<(), rpc::Error>;
105 fn poll(
106 mut self: std::pin::Pin<&mut Self>,
107 cx: &mut std::task::Context<'_>,
108 ) -> Poll<Self::Output> {
109 match self.0.poll_unpin(cx) {
110 Poll::Ready(_) => Poll::Ready(Err(rpc::Error::request_cancelled())),
111 Poll::Pending => Poll::Pending,
112 }
113 }
114}
115
116impl FusedFuture for CancelTokenErr<'_> {
117 fn is_terminated(&self) -> bool {
118 false
119 }
120}
121
122#[async_trait(?Send)]
123pub trait ResponseWriter: Sized {
124 async fn write_response<R: Serialize>(
125 mut self,
126 response: &rpc::Response<R>,
127 ) -> Result<(), io::Error>;
128}
129
130#[async_trait(?Send)]
131pub trait RequestWriter {
132 async fn write_request<
133 R: Request<Params = P>,
134 P: Serialize + DeserializeOwned + core::fmt::Debug,
135 >(
136 &mut self,
137 params: Option<R::Params>,
138 ) -> Result<rpc::Response<R::Result>, io::Error>;
139
140 async fn write_notification<
141 N: Notification<Params = P>,
142 P: Serialize + DeserializeOwned + core::fmt::Debug,
143 >(
144 &mut self,
145 params: Option<N::Params>,
146 ) -> Result<(), io::Error>;
147
148 async fn cancel(&mut self) -> Result<(), io::Error>;
149}
150
151trait NewTrait: Future<Output = ()> {}
152impl<T> NewTrait for T where T: Future<Output = ()> {}
153
154type DeferredTasks = Arc<AsyncMutex<Vec<Pin<Box<dyn NewTrait>>>>>;
155
156#[derive(Clone)]
157pub struct Context<W: Clone> {
158 inner: Arc<AsyncMutex<Inner<W>>>,
159 cancel_token: CancelToken,
160 last_req_id: Option<rpc::RequestId>, rw: Arc<AsyncMutex<Box<dyn MessageWriter>>>,
162 world: W,
163 deferred: DeferredTasks,
164}
165
166impl<W: Clone> std::ops::Deref for Context<W> {
167 type Target = W;
168
169 fn deref(&self) -> &Self::Target {
170 &self.world
171 }
172}
173
174impl<W: Clone> Context<W> {
175 pub async fn is_initialized(&self) -> bool {
176 self.inner.lock().await.initialized
177 }
178
179 pub async fn is_shutting_down(&self) -> bool {
180 self.inner.lock().await.shutting_down
181 }
182
183 pub fn world(&self) -> &W {
184 &self.world
185 }
186
187 pub fn cancel_token(&mut self) -> &mut CancelToken {
188 &mut self.cancel_token
189 }
190
191 pub async fn defer<F: Future<Output = ()> + 'static>(&self, fut: F) {
197 self.deferred.lock().await.push(Box::pin(fut));
198 }
199}
200
201#[async_trait(?Send)]
202impl<W: Clone> RequestWriter for Context<W> {
203 #[tracing::instrument(level = tracing::Level::TRACE, skip(self))]
204 async fn write_request<
205 R: Request<Params = P>,
206 P: Serialize + DeserializeOwned + core::fmt::Debug,
207 >(
208 &mut self,
209 params: Option<R::Params>,
210 ) -> Result<rpc::Response<R::Result>, io::Error> {
211 let mut inner = self.inner.lock().await;
212 let req_id = inner.next_request_id;
213 inner.next_request_id += 1;
214
215 let mut rw = self.rw.lock().await;
216
217 let id = NumberOrString::Number(req_id);
218
219 let request = rpc::Request::new()
220 .with_id(id.clone().into())
221 .with_method(R::METHOD)
222 .with_params(params);
223
224 let span = tracing::debug_span!("sending request", ?request);
225
226 rw.send(request.into_message()).instrument(span).await?;
227
228 self.last_req_id = Some(id.clone());
229
230 let (send, recv) = oneshot::channel();
231 inner.requests.insert(id, send);
232
233 drop(inner);
234
235 let res = recv.await.unwrap();
236
237 tracing::trace!(response = ?res, "received response");
238
239 self.last_req_id = None;
240
241 Ok(res.into_params())
242 }
243
244 #[tracing::instrument(level = tracing::Level::TRACE, skip(self))]
245 async fn write_notification<
246 N: Notification<Params = P>,
247 P: Serialize + DeserializeOwned + core::fmt::Debug,
248 >(
249 &mut self,
250 params: Option<N::Params>,
251 ) -> Result<(), io::Error> {
252 let mut rw = self.rw.lock().await;
253 rw.send(
254 rpc::Request::new()
255 .with_method(N::METHOD)
256 .with_params(params)
257 .into_message(),
258 )
259 .await
260 }
261
262 async fn cancel(&mut self) -> Result<(), io::Error> {
263 if let Some(id) = Option::take(&mut self.last_req_id) {
264 self.write_notification::<notification::Cancel, _>(Some(lsp_types::CancelParams { id }))
265 .await
266 } else {
267 Ok(())
268 }
269 }
270}
271
272pub trait MessageWriter: Sink<rpc::Message, Error = io::Error> + Unpin {}
273impl<T: Sink<rpc::Message, Error = io::Error> + Unpin> MessageWriter for T {}
274
275struct Inner<W: Clone> {
276 next_request_id: i32,
277 initialized: bool,
278 shutting_down: bool,
279 handlers: HashMap<String, Box<dyn Handler<W>>>,
280 tasks: HashMap<rpc::RequestId, Cancellation>,
281 requests: HashMap<rpc::RequestId, oneshot::Sender<rpc::Response<serde_json::Value>>>,
282}
283
284impl<W: Clone> Inner<W> {
285 fn task_done(&mut self, id: &rpc::RequestId) {
286 if let Some(mut t) = self.tasks.remove(id) {
287 t.cancel();
288 tracing::trace!(?id, "task completed");
289 }
290 }
291}
292
293pub struct Server<W: Clone> {
294 inner: Arc<AsyncMutex<Inner<W>>>,
295}
296
297impl<W: Clone> Server<W> {
298 #[allow(clippy::new_ret_no_self)]
299 pub fn new() -> ServerBuilder<W> {
300 ServerBuilder {
301 inner: Inner {
302 next_request_id: 0,
303 initialized: false,
304 shutting_down: false,
305 handlers: HashMap::new(),
306 tasks: HashMap::new(),
307 requests: HashMap::new(),
308 },
309 }
310 }
311
312 pub fn handle_message(
313 &self,
314 world: W,
315 message: rpc::Message,
316 writer: impl MessageWriter + Clone + 'static,
317 ) -> impl Future<Output = Result<(), io::Error>> {
318 let inner = self.inner.clone();
319
320 async move {
321 if message.is_response() {
322 Server::handle_response(inner, message.into_response()).await;
323 Ok(())
324 } else {
325 Server::handle_request(inner, world, message.into_request(), writer).await
326 }
327 }
328 }
329
330 pub async fn is_shutting_down(&self) -> bool {
331 self.inner.lock().await.shutting_down
332 }
333
334 #[tracing::instrument(level = tracing::Level::TRACE)]
335 async fn handle_response(
336 inner: Arc<AsyncMutex<Inner<W>>>,
337 response: rpc::Response<serde_json::Value>,
338 ) {
339 if let Some(sender) = inner.lock().await.requests.remove(&response.id) {
340 sender.send(response).ok();
341 } else {
342 tracing::error!(?response, "unexpected response")
343 }
344 }
345
346 #[tracing::instrument(level = tracing::Level::TRACE, skip(data, writer))]
347 async fn handle_request(
348 inner: Arc<AsyncMutex<Inner<W>>>,
349 data: W,
350 request: rpc::Request<serde_json::Value>,
351 mut writer: impl MessageWriter + Clone + 'static,
352 ) -> Result<(), io::Error> {
353 if &request.jsonrpc != "2.0" {
354 tracing::error!("JSON-RPC message version is not 2.0");
355 return writer
356 .send(
357 rpc::Response::error(
358 rpc::Error::invalid_request()
359 .with_data("only JSON-RPC version 2.0 is accepted"),
360 )
361 .into_message(),
362 )
363 .await;
364 }
365
366 if request.id.is_some() {
367 tracing::debug!(
368 id = ?request.id.as_ref().unwrap(),
369 method = %request.method,
370 "request received"
371 );
372 let mut s = inner.lock().await;
373
374 if s.shutting_down {
375 tracing::warn!(
376 id = ?request.id.as_ref().unwrap(),
377 method = %request.method,
378 "received request while shutting down"
379 );
380
381 writer
382 .send(
383 rpc::Response::error(
384 rpc::Error::invalid_request().with_data("server is shutting down"),
385 )
386 .into_message(),
387 )
388 .await?;
389 return Ok(());
390 }
391
392 if request.method == req::Shutdown::METHOD {
393 tracing::info!(
394 id = ?request.id.as_ref().unwrap(),
395 method = %request.method,
396 "received shutdown request"
397 );
398
399 s.shutting_down = true;
400 }
401
402 let is_initialize = request.method == req::Initialize::METHOD;
403
404 if !s.initialized && !is_initialize {
405 tracing::error!(
406 id = ?request.id.as_ref().unwrap(),
407 method = %request.method,
408 "server not yet initialized"
409 );
410
411 writer
412 .send(rpc::Response::error(rpc::Error::server_not_initialized()).into_message())
413 .await?;
414 return Ok(());
415 }
416
417 if s.handlers.contains_key(&request.method) {
418 let mut handler = s.handlers.get_mut(&request.method).unwrap().clone();
419
420 let id = request.id.clone().unwrap();
421
422 drop(s);
424
425 let ctx = Server::create_context(
426 inner.clone(),
427 Arc::new(AsyncMutex::new(Box::new(writer.clone()))),
428 data,
429 &request,
430 )
431 .await;
432
433 let handler_span = tracing::trace_span!(
434 "request handler",
435 method = %request.method,
436 );
437
438 let method = request.method.clone();
439
440 handler
441 .handle(ctx.clone(), request, Some(&mut writer))
442 .instrument(handler_span)
443 .await;
444
445 let deferred = mem::take(&mut (*ctx.deferred.lock().await));
446
447 for d in deferred {
448 let deferred_span = tracing::trace_span!(
449 "deferred task",
450 %method,
451 );
452
453 d.instrument(deferred_span).await
454 }
455
456 let mut s = inner.lock().await;
457
458 s.task_done(&id);
459 if is_initialize {
460 s.initialized = true;
461 }
462 drop(s);
463
464 Ok(())
465 } else if request.method == req::Shutdown::METHOD {
466 writer
468 .send(
469 rpc::Response::success(())
470 .with_request_id(request.id.unwrap())
471 .into_message(),
472 )
473 .await
474 } else {
475 tracing::error!(
476 method = %request.method,
477 "no request handler registered"
478 );
479
480 writer
481 .send(
482 rpc::Response::error(rpc::Error::method_not_found())
483 .with_request_id(request.id.unwrap())
484 .into_message(),
485 )
486 .await
487 }
488 } else {
489 tracing::debug!(
490 method = %request.method,
491 "notification received"
492 );
493
494 if request.method == lsp_types::notification::Cancel::METHOD {
495 if let Some(p) = request.params {
496 if let Ok(c) = serde_json::from_value::<lsp_types::CancelParams>(p) {
497 inner.lock().await.task_done(&c.id);
498 }
499 }
500 return Ok(());
501 }
502
503 let mut s = inner.lock().await;
504
505 if s.handlers.contains_key(&request.method) {
506 let mut handler = s.handlers.get_mut(&request.method).unwrap().clone();
507 drop(s);
508
509 let ctx = Server::create_context(
510 inner,
511 Arc::new(AsyncMutex::new(Box::new(writer))),
512 data,
513 &request,
514 )
515 .await;
516
517 let handler_span = tracing::trace_span!(
518 "notification handler",
519 method = %request.method,
520 );
521
522 let method = request.method.clone();
523
524 handler
525 .handle(ctx.clone(), request, None)
526 .instrument(handler_span)
527 .await;
528
529 let deferred = mem::take(&mut (*ctx.deferred.lock().await));
530
531 for d in deferred {
532 let deferred_span = tracing::trace_span!(
533 "deferred task",
534 %method,
535 );
536
537 d.instrument(deferred_span).await
538 }
539 } else {
540 tracing::warn!(
541 method = %request.method,
542 "no notification handler registered"
543 );
544 }
545
546 Ok(())
547 }
548 }
549
550 async fn create_context<D>(
551 inner: Arc<AsyncMutex<Inner<W>>>,
552 rw: Arc<AsyncMutex<Box<dyn MessageWriter>>>,
553 world: W,
554 req: &rpc::Request<D>,
555 ) -> Context<W> {
556 let cancel = Cancellation::default();
557 let cancel_token = cancel.token();
558
559 if let Some(id) = &req.id {
560 inner.lock().await.tasks.insert(id.clone(), cancel);
561 }
562
563 Context {
564 cancel_token,
565 world,
566 inner,
567 last_req_id: None,
568 rw,
569 deferred: Default::default(),
570 }
571 }
572}
573
574pub struct ServerBuilder<W: Clone + 'static> {
575 inner: Inner<W>,
576}
577
578impl<W: Clone + 'static> ServerBuilder<W> {
579 pub fn on_notification<N, F>(mut self, handler: fn(Context<W>, Params<N::Params>) -> F) -> Self
580 where
581 N: Notification + 'static,
582 F: Future<Output = ()> + 'static,
583 {
584 self.inner.handlers.insert(
585 N::METHOD.into(),
586 Box::new(handler::NotificationHandler::<N, _, _>::new(handler)),
587 );
588 tracing::info!(method = N::METHOD, "registered notification handler");
589 self
590 }
591
592 pub fn on_request<R, F>(mut self, handler: fn(Context<W>, Params<R::Params>) -> F) -> Self
593 where
594 R: Request + 'static,
595 F: Future<Output = Result<R::Result, rpc::Error>> + 'static,
596 {
597 self.inner.handlers.insert(
598 R::METHOD.into(),
599 Box::new(handler::RequestHandler::<R, _, _>::new(handler)),
600 );
601 tracing::info!(method = R::METHOD, "registered request handler");
602 self
603 }
604
605 pub fn build(self) -> Server<W> {
606 Server {
607 inner: Arc::new(AsyncMutex::new(self.inner)),
608 }
609 }
610}
611
612pub struct Params<P>(Option<P>);
613
614impl<P> Params<P> {
615 pub fn optional(self) -> Option<P> {
616 self.0
617 }
618
619 pub fn required(self) -> Result<P, rpc::Error> {
620 match self.0 {
621 None => Err(rpc::Error::invalid_params().with_data("params are required")),
622 Some(p) => Ok(p),
623 }
624 }
625}
626
627impl<P> From<Option<P>> for Params<P> {
628 fn from(p: Option<P>) -> Self {
629 Self(p)
630 }
631}