1use std::{
2 collections::VecDeque,
3 fmt,
4 io::{self, Write},
5 marker::PhantomData,
6 marker::Unpin,
7 pin::Pin,
8 str,
9 task::{self, Poll},
10};
11
12use anyhow::anyhow;
13
14use combine::{
15 error::ParseError,
16 from_str,
17 parser::{
18 combinator::{any_send_partial_state, AnySendPartialState},
19 range::{range, take, take_while1},
20 },
21 skip_many,
22 stream::{easy, PartialStream, RangeStream},
23 Parser,
24};
25
26use bytes::{buf::Buf, BufMut, BytesMut};
27
28use tokio_util::codec::{Decoder, Encoder};
29
30use futures::{channel::mpsc, prelude::*, Sink, Stream};
31
32use jsonrpc_core::{Error, ErrorCode, Params, RpcMethodSimple, RpcNotificationSimple, Value};
33
34use lsp_types::{notification, LogMessageParams, MessageType};
35
36use serde;
37use serde_json::{self, from_value, to_string, to_value};
38
39use crate::BoxFuture;
40
/// Failure produced by a language server command.
///
/// When a command wrapped in `ServerCommand` fails, `message` becomes the JSON-RPC error
/// message and `data`, if any, is serialized into the error's `data` field.
#[derive(Debug, PartialEq)]
pub struct ServerError<E> {
    pub message: String,
    pub data: Option<E>,
}

/// Any displayable error converts into a `ServerError` that carries its `Display` text and no
/// extra data, which lets `?` be used inside command implementations.
impl<E, D> From<E> for ServerError<D>
where
    E: fmt::Display,
{
    fn from(err: E) -> ServerError<D> {
        let message = ToString::to_string(&err);
        ServerError { data: None, message }
    }
}
58
59pub trait LanguageServerCommand<P>: Send + Sync + 'static
60where
61 Self::Future: Send + 'static,
62{
63 type Future: Future<Output = Result<Self::Output, ServerError<Self::Error>>> + Send + 'static;
64 type Output: serde::Serialize;
65 type Error: serde::Serialize;
66 fn execute(&self, param: P) -> Self::Future;
67
68 fn invalid_params(&self) -> Option<Self::Error> {
69 None
70 }
71}
72
73impl<'de, F, R, P, O, E> LanguageServerCommand<P> for F
74where
75 F: Fn(P) -> R + Send + Sync + 'static,
76 R: Future<Output = Result<O, ServerError<E>>> + Send + 'static,
77 P: serde::Deserialize<'de>,
78 O: serde::Serialize,
79 E: serde::Serialize,
80{
81 type Future = F::Output;
82 type Output = O;
83 type Error = E;
84
85 fn execute(&self, param: P) -> Self::Future {
86 self(param)
87 }
88}
89
90pub trait LanguageServerNotification<P>: Send + Sync + 'static {
91 fn execute(&self, param: P);
92}
93
94impl<'de, F, P> LanguageServerNotification<P> for F
95where
96 F: Fn(P) + Send + Sync + 'static,
97 P: serde::Deserialize<'de> + 'static,
98{
99 fn execute(&self, param: P) {
100 self(param)
101 }
102}
103pub struct ServerCommand<T, P>(pub T, PhantomData<fn(P)>);
104
105impl<T, P> ServerCommand<T, P> {
106 pub fn method(command: T) -> ServerCommand<T, P>
107 where
108 T: LanguageServerCommand<P>,
109 P: for<'de> serde::Deserialize<'de> + 'static,
110 {
111 ServerCommand(command, PhantomData)
112 }
113
114 pub fn notification(command: T) -> ServerCommand<T, P>
115 where
116 T: LanguageServerNotification<P>,
117 P: for<'de> serde::Deserialize<'de> + 'static,
118 {
119 ServerCommand(command, PhantomData)
120 }
121}
122
123impl<P, T> RpcMethodSimple for ServerCommand<T, P>
124where
125 T: LanguageServerCommand<P>,
126 P: for<'de> serde::Deserialize<'de> + 'static,
127{
128 type Out = BoxFuture<Value, Error>;
129 fn call(&self, param: Params) -> Self::Out {
130 let value = match param {
131 Params::Map(map) => Value::Object(map),
132 Params::Array(arr) => Value::Array(arr),
133 Params::None => Value::Null,
134 };
135 let err = match from_value(value) {
136 Ok(value) => {
137 return self
138 .0
139 .execute(value)
140 .map(|result| match result {
141 Ok(value) => {
142 Ok(to_value(&value).expect("result data could not be serialized"))
143 }
144 Err(error) => Err(Error {
145 code: ErrorCode::InternalError,
146 message: error.message,
147 data: error
148 .data
149 .as_ref()
150 .map(|v| to_value(v).expect("error data could not be serialized")),
151 }),
152 })
153 .boxed()
154 }
155 Err(err) => err,
156 };
157 let data = self.0.invalid_params();
158 futures::future::err(Error {
159 code: ErrorCode::InvalidParams,
160 message: format!("Invalid params: {}", err),
161 data: data
162 .as_ref()
163 .map(|v| to_value(v).expect("error data could not be serialized")),
164 })
165 .boxed()
166 }
167}
168
169impl<T, P> RpcNotificationSimple for ServerCommand<T, P>
170where
171 T: LanguageServerNotification<P>,
172 P: for<'de> serde::Deserialize<'de> + 'static,
173{
174 fn execute(&self, param: Params) {
175 match param {
176 Params::Map(map) => match from_value(Value::Object(map)) {
177 Ok(value) => {
178 self.0.execute(value);
179 }
180 Err(err) => error!("{}", err), },
182 _ => (), }
184 }
185}
186
187pub(crate) async fn log_message(sender: mpsc::Sender<String>, message: String) {
188 debug!("{}", message);
189 send_response(
190 sender,
191 notification!("window/logMessage"),
192 LogMessageParams {
193 typ: MessageType::Log,
194 message,
195 },
196 )
197 .await
198}
199
/// Formats a message and sends it through `crate::rpc::log_message`, but only when debug
/// logging is enabled. When it is disabled, neither the formatting nor the send happens.
///
/// Expands to an `async` block, which does nothing until it is awaited. The first argument is
/// the sender passed to `crate::rpc::log_message`. The remaining tokens are handed to `format!`.
macro_rules! log_message {
    ($sender: expr, $($ts: tt)+) => { async {
        // `$sender` is only evaluated inside this branch, i.e. when debug logging is on.
        if log_enabled!(::log::Level::Debug) {
            let msg = format!( $($ts)+ );
            crate::rpc::log_message($sender, msg).await
        }
    } }
}
208
209pub async fn send_response<T>(mut sender: mpsc::Sender<String>, _: Option<T>, value: T::Params)
210where
211 T: notification::Notification,
212 T::Params: serde::Serialize,
213{
214 let r = format!(
215 r#"{{"jsonrpc": "2.0", "method": "{}", "params": {} }}"#,
216 T::METHOD,
217 serde_json::to_value(value).unwrap()
218 );
219 let _ = sender.send(r).await;
220}
221
222pub fn write_message<W, T>(output: W, value: &T) -> io::Result<()>
223where
224 W: Write,
225 T: serde::Serialize,
226{
227 let response = to_string(&value).unwrap();
228 write_message_str(output, &response)
229}
230
231pub fn write_message_str<W>(mut output: W, response: &str) -> io::Result<()>
232where
233 W: Write,
234{
235 debug!("Respond: {}", response);
236 write!(
237 output,
238 "Content-Length: {}\r\n\r\n{}",
239 response.len(),
240 response
241 )?;
242 output.flush()?;
243 Ok(())
244}
245
246pub struct LanguageServerDecoder {
247 state: AnySendPartialState,
248}
249
250impl LanguageServerDecoder {
251 pub fn new() -> LanguageServerDecoder {
252 LanguageServerDecoder {
253 state: Default::default(),
254 }
255 }
256}
257
/// Builds the parser for one framed message. The frame is parsed in this order:
/// 1. any number of leading `\r\n`;
/// 2. a `Content-Length: <digits>` header;
/// 3. the `\r\n\r\n` separator;
/// 4. exactly that many body bytes, returned as an owned `Vec<u8>`.
///
/// Only the `Content-Length` header is recognized; input with any other header in that
/// position (e.g. `Content-Type`) is rejected. The parser is wrapped in `any_send_partial_state`
/// so the caller can resume it on partial input.
fn decode_parser<'a, I>(
) -> impl Parser<I, Output = Vec<u8>, PartialState = AnySendPartialState> + 'a
where
    I: RangeStream<Token = u8, Range = &'a [u8]> + 'a,
    I::Error: ParseError<I::Token, I::Range, I::Position>,
{
    // `from_str` turns the ASCII digits into the numeric length consumed by `take` below.
    let content_length =
        range(&b"Content-Length: "[..]).with(from_str(take_while1(|b: u8| b.is_ascii_digit())));

    any_send_partial_state(
        (
            skip_many(range(&b"\r\n"[..])),
            content_length,
            range(&b"\r\n\r\n"[..]).map(|_| ()),
        )
        // The body length is only known once the header is parsed, so the body parser is
        // built from the header tuple's output.
        .then_partial(|&mut (_, message_length, _)| {
            take(message_length).map(|bytes: &[u8]| bytes.to_owned())
        }),
    )
}
286
287impl Decoder for LanguageServerDecoder {
288 type Item = String;
289 type Error = anyhow::Error;
290
291 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
292 let (opt, removed_len) = combine::stream::decode(
293 decode_parser(),
294 &mut easy::Stream(PartialStream(&src[..])),
295 &mut self.state,
296 )
297 .map_err(|err| {
298 let err = err
299 .map_range(|r| {
300 str::from_utf8(r)
301 .ok()
302 .map_or_else(|| format!("{:?}", r), |s| s.to_string())
303 })
304 .map_position(|p| p.translate_position(&src[..]));
305 anyhow!("{}\nIn input: `{}`", err, str::from_utf8(src).unwrap())
306 })?;
307
308 src.advance(removed_len);
309
310 match opt {
311 None => Ok(None),
312
313 Some(output) => {
314 let value = String::from_utf8(output)?;
315 Ok(Some(value))
316 }
317 }
318 }
319}
320
321#[derive(Debug)]
322pub struct LanguageServerEncoder;
323
324impl Encoder<String> for LanguageServerEncoder {
325 type Error = anyhow::Error;
326 fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> {
327 dst.reserve(item.len() + 60); write_message_str(dst.writer(), &item)?;
329 Ok(())
330 }
331}
332
/// An item passed through a `unique_queue`.
///
/// While an entry is still waiting in the queue, a later entry with an equal `key` replaces it
/// only if that later entry has a strictly greater `version`; otherwise the later one is
/// discarded.
pub struct Entry<K, V, W> {
    /// Identity used to detect duplicates (compared with `PartialEq`).
    pub key: K,
    /// Payload delivered to the consumer.
    pub value: V,
    /// Ordering used to decide which of two same-key entries survives (compared with `Ord`).
    pub version: W,
}
338
339pub struct UniqueSink<K, V, W> {
341 sender: mpsc::UnboundedSender<Entry<K, V, W>>,
342}
343
344impl<K, V, W> Clone for UniqueSink<K, V, W> {
345 fn clone(&self) -> Self {
346 UniqueSink {
347 sender: self.sender.clone(),
348 }
349 }
350}
351
/// Receiving half of a `unique_queue`.
///
/// Yields entries in the order their key first arrived, with queued duplicates merged.
pub struct UniqueStream<K, V, W> {
    // Entries received but not yet yielded; holds at most one entry per key.
    queue: VecDeque<Entry<K, V, W>>,
    receiver: mpsc::UnboundedReceiver<Entry<K, V, W>>,
    // Set once `receiver` reports end-of-stream; the receiver is not polled after that.
    exhausted: bool,
}
357
358pub fn unique_queue<K, V, W>() -> (UniqueSink<K, V, W>, UniqueStream<K, V, W>)
359where
360 K: PartialEq,
361 W: Ord,
362{
363 let (sender, receiver) = mpsc::unbounded();
364 (
365 UniqueSink { sender },
366 UniqueStream {
367 queue: VecDeque::new(),
368 receiver,
369 exhausted: false,
370 },
371 )
372}
373
374impl<K, V, W> Stream for UniqueStream<K, V, W>
375where
376 K: PartialEq,
377 W: Ord,
378 Self: Unpin,
379{
380 type Item = Entry<K, V, W>;
381
382 fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
383 while !self.exhausted {
384 match self.receiver.poll_next_unpin(cx) {
385 Poll::Ready(Some(item)) => {
386 if let Some(entry) = self.queue.iter_mut().find(|entry| entry.key == item.key) {
387 if entry.version < item.version {
388 *entry = item;
389 }
390 continue;
391 }
392 self.queue.push_back(item);
393 }
394 Poll::Ready(None) => {
395 self.exhausted = true;
396 }
397 Poll::Pending => break,
398 }
399 }
400 match self.queue.pop_front() {
401 Some(item) => Poll::Ready(Some(item)),
402 None => {
403 if self.exhausted {
404 Poll::Ready(None)
405 } else {
406 Poll::Pending
407 }
408 }
409 }
410 }
411}
412
413impl<K, V, W> Sink<Entry<K, V, W>> for UniqueSink<K, V, W> {
414 type Error = mpsc::SendError;
415
416 fn poll_ready(
417 mut self: Pin<&mut Self>,
418 cx: &mut task::Context<'_>,
419 ) -> Poll<Result<(), Self::Error>> {
420 Pin::new(&mut self.sender).poll_ready(cx)
421 }
422
423 fn start_send(mut self: Pin<&mut Self>, item: Entry<K, V, W>) -> Result<(), Self::Error> {
424 Pin::new(&mut self.sender).start_send(item)
425 }
426
427 fn poll_flush(
428 mut self: Pin<&mut Self>,
429 cx: &mut task::Context<'_>,
430 ) -> Poll<Result<(), Self::Error>> {
431 Pin::new(&mut self.sender).poll_flush(cx)
432 }
433
434 fn poll_close(
435 mut self: Pin<&mut Self>,
436 cx: &mut task::Context<'_>,
437 ) -> Poll<Result<(), Self::Error>> {
438 Pin::new(&mut self.sender).poll_close(cx)
439 }
440}