1#[cfg(feature = "runtime-agnostic")]
4use async_codec_lite::{FramedRead, FramedWrite};
5#[cfg(feature = "runtime-agnostic")]
6use futures::io::{AsyncRead, AsyncWrite};
7
8#[cfg(feature = "runtime-tokio")]
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "runtime-tokio")]
11use tokio_util::codec::{FramedRead, FramedWrite};
12
13use futures::channel::mpsc;
14use futures::{future, join, stream, FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt};
15use tower::Service;
16use tracing::error;
17
18use crate::codec::{LanguageServerCodec, ParseError};
19use crate::jsonrpc::{Error, Id, Message, Request, Response};
20use crate::service::{ClientSocket, RequestStream, ResponseSink};
21
/// Concurrency limit that `Server::new` installs until `Server::concurrency_level` overrides it.
const DEFAULT_MAX_CONCURRENCY: usize = 4;
/// Capacity of the channel that carries pending service futures from the input reader to the
/// task processor inside `Server::serve`.
const MESSAGE_QUEUE_SIZE: usize = 100;
24
/// A socket that can be taken apart into a stream of [`Request`]s and a sink for [`Response`]s.
///
/// `Server::serve` writes every request yielded by the stream to its output, and pushes every
/// response it decodes from its input into the sink.
pub trait Loopback {
    /// Stream half: yields the requests that should be written to the output.
    type RequestStream: Stream<Item = Request>;
    /// Sink half: accepts the responses that were decoded from the input.
    type ResponseSink: Sink<Response> + Unpin;

    /// Consumes the socket and returns its stream half and its sink half, in that order.
    fn split(self) -> (Self::RequestStream, Self::ResponseSink);
}
39
40impl Loopback for ClientSocket {
41 type RequestStream = RequestStream;
42 type ResponseSink = ResponseSink;
43
44 #[inline]
45 fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
46 self.split()
47 }
48}
49
/// Transport that connects a request-handling service to an input reader and an output writer.
///
/// Configure it with `Server::new` (and optionally `Server::concurrency_level`), then drive it
/// with `Server::serve`.
#[derive(Debug)]
pub struct Server<I, O, L = ClientSocket> {
    /// Source of incoming messages; decoded with `LanguageServerCodec` in `serve`.
    stdin: I,
    /// Destination of outgoing messages; encoded with `LanguageServerCodec` in `serve`.
    stdout: O,
    /// Socket that is split into a request stream and a response sink when serving.
    loopback: L,
    /// Upper bound on how many service futures are polled concurrently.
    max_concurrency: usize,
}
58
59impl<I, O, L> Server<I, O, L>
60where
61 I: AsyncRead + Unpin,
62 O: AsyncWrite,
63 L: Loopback,
64 <L::ResponseSink as Sink<Response>>::Error: std::error::Error,
65{
66 pub fn new(stdin: I, stdout: O, socket: L) -> Self {
68 Server {
69 stdin,
70 stdout,
71 loopback: socket,
72 max_concurrency: DEFAULT_MAX_CONCURRENCY,
73 }
74 }
75
76 pub fn concurrency_level(mut self, max: usize) -> Self {
97 self.max_concurrency = max;
98 self
99 }
100
101 pub async fn serve<T>(self, mut service: T)
103 where
104 T: Service<Request, Response = Option<Response>> + 'static,
105 T::Error: Into<Box<dyn std::error::Error>>
106 {
107 let (client_requests, mut client_responses) = self.loopback.split();
108 let (client_requests, client_abort) = stream::abortable(client_requests);
109 let (mut responses_tx, responses_rx) = mpsc::channel(0);
110 let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(MESSAGE_QUEUE_SIZE);
111
112 let mut framed_stdin = FramedRead::new(self.stdin, LanguageServerCodec::default());
113 let framed_stdout = FramedWrite::new(self.stdout, LanguageServerCodec::default());
114
115 let process_server_tasks = server_tasks_rx
116 .buffer_unordered(self.max_concurrency)
117 .filter_map(future::ready)
118 .map(|res| Ok(Message::Response(res)))
119 .forward(responses_tx.clone().sink_map_err(|_| unreachable!()))
120 .map(|_| ());
121
122 let print_output = stream::select(responses_rx, client_requests.map(Message::Request))
123 .map(Ok)
124 .forward(framed_stdout.sink_map_err(|e| error!("failed to encode message: {}", e)))
125 .map(|_| ());
126
127 let read_input = async {
128 while let Some(msg) = framed_stdin.next().await {
129 match msg {
130 Ok(Message::Request(req)) => {
131 if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await {
132 error!("{}", display_sources(err.into().as_ref()));
133 return;
134 }
135
136 let fut = service.call(req).unwrap_or_else(|err| {
137 error!("{}", display_sources(err.into().as_ref()));
138 None
139 });
140
141 server_tasks_tx.send(fut).await.unwrap();
142 }
143 Ok(Message::Response(res)) => {
144 if let Err(err) = client_responses.send(res).await {
145 error!("{}", display_sources(&err));
146 return;
147 }
148 }
149 Err(err) => {
150 error!("failed to decode message: {}", err);
151 let res = Response::from_error(Id::Null, to_jsonrpc_error(err));
152 responses_tx.send(Message::Response(res)).await.unwrap();
153 }
154 }
155 }
156
157 server_tasks_tx.disconnect();
158 responses_tx.disconnect();
159 client_abort.abort();
160 };
161
162 join!(print_output, read_input, process_server_tasks);
163 }
164}
165
/// Renders `error` followed by every error in its `source()` chain, separated by `": "`.
///
/// An error without a source is rendered as its own `Display` output.
fn display_sources(error: &dyn std::error::Error) -> String {
    let mut rendered = error.to_string();
    let mut next = error.source();

    // Walk the chain outermost-first, appending each cause after a separator.
    while let Some(cause) = next {
        rendered.push_str(": ");
        rendered.push_str(&cause.to_string());
        next = cause.source();
    }

    rendered
}
173
/// Maps a decoding failure to the JSON-RPC error that is reported back.
///
/// A body error whose `is_data()` check is true becomes "invalid request"; every other failure
/// becomes "parse error".
#[cfg(feature = "runtime-tokio")]
fn to_jsonrpc_error(err: ParseError) -> Error {
    let is_invalid_request = matches!(&err, ParseError::Body(body) if body.is_data());

    if is_invalid_request {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
181
/// Maps a decoding failure to the JSON-RPC error that is reported back.
///
/// The `ParseError` is looked up in `err.source()`. A body error whose `is_data()` check is
/// true becomes "invalid request"; anything else, including a missing or differently typed
/// source, becomes "parse error".
#[cfg(feature = "runtime-agnostic")]
fn to_jsonrpc_error(err: impl std::error::Error) -> Error {
    let parse_error = err
        .source()
        .and_then(|cause| cause.downcast_ref::<ParseError>());

    if let Some(ParseError::Body(body)) = parse_error {
        if body.is_data() {
            return Error::invalid_request();
        }
    }

    Error::parse_error()
}
189
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    #[cfg(feature = "runtime-agnostic")]
    use futures::io::Cursor;
    #[cfg(feature = "runtime-tokio")]
    use std::io::Cursor;

    use futures::future::Ready;
    use futures::{future, sink, stream};

    use super::*;

    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    const RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{"capabilities":{}},"id":1}"#;

    /// Service that is always ready and answers every request with `RESPONSE`.
    #[derive(Debug)]
    struct MockService;

    impl Service<Request> for MockService {
        type Response = Option<Response>;
        type Error = String;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request) -> Self::Future {
            let canned: Response = serde_json::from_str(RESPONSE).unwrap();
            future::ready(Ok(Some(canned)))
        }
    }

    /// Loopback that emits a fixed list of requests and discards every response.
    struct MockLoopback(Vec<Request>);

    impl Loopback for MockLoopback {
        type RequestStream = stream::Iter<std::vec::IntoIter<Request>>;
        type ResponseSink = sink::Drain<Response>;

        fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
            let MockLoopback(requests) = self;
            (stream::iter(requests), sink::drain())
        }
    }

    /// Prefixes `body` with the `Content-Length` header used on the wire.
    fn frame(body: &str) -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body).into_bytes()
    }

    fn mock_request() -> Vec<u8> {
        frame(REQUEST)
    }

    fn mock_response() -> Vec<u8> {
        frame(RESPONSE)
    }

    /// Input holding one framed request, paired with an empty output buffer.
    fn mock_stdio() -> (Cursor<Vec<u8>>, Vec<u8>) {
        let input = Cursor::new(mock_request());
        let output = Vec::new();
        (input, output)
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_on_stdio() {
        let (mut stdin, mut stdout) = mock_stdio();
        let server = Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]));
        server.serve(MockService).await;

        // 80 bytes = 22-byte header + 58-byte request body, i.e. the input was fully consumed.
        assert_eq!(stdin.position(), 80);
        assert_eq!(stdout, mock_response());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn interleaves_messages() {
        let looped_request: Request = serde_json::from_str(REQUEST).unwrap();
        let socket = MockLoopback(vec![looped_request]);

        let (mut stdin, mut stdout) = mock_stdio();
        let server = Server::new(&mut stdin, &mut stdout, socket);
        server.serve(MockService).await;

        assert_eq!(stdin.position(), 80);

        // The loopback request is expected on the output ahead of the service's response.
        let mut expected = mock_request();
        expected.extend(mock_response());
        assert_eq!(stdout, expected);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_invalid_json() {
        let invalid = r#"{"jsonrpc":"2.0","method":"#;
        let mut stdin = Cursor::new(frame(invalid));
        let mut stdout = Vec::new();

        let server = Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]));
        server.serve(MockService).await;

        // 48 bytes = 22-byte header + 26-byte truncated body.
        assert_eq!(stdin.position(), 48);
        let err = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":null}"#;
        assert_eq!(stdout, frame(err));
    }
}