lsp_codec/
proto.rs

1use std::fmt::Write;
2use std::io;
3
4use bytes::BytesMut;
5use tokio_util::codec::{Decoder, Encoder};
6
7use crate::Error;
8
9type Body = serde_json::Value;
10
11#[derive(Default)]
12pub struct LspCodec {
13    encoder: LspEncoder,
14    decoder: LspDecoder,
15}
16
17impl Encoder<Body> for LspCodec {
18    type Error = <LspEncoder as Encoder<Body>>::Error;
19
20    fn encode(&mut self, item: Body, dst: &mut BytesMut) -> Result<(), Self::Error> {
21        Encoder::encode(&mut self.encoder, item, dst)
22    }
23}
24
25impl Decoder for LspCodec {
26    type Item = Body;
27    type Error = <LspEncoder as Encoder<Body>>::Error;
28
29    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
30        Decoder::decode(&mut self.decoder, buf)
31    }
32}
33
34#[derive(Default)]
35pub struct LspDecoder {
36    state: State,
37}
38
39#[derive(Default)]
40pub struct LspEncoder;
41
42enum State {
43    ReadingHeader {
44        header: HeaderBuilder,
45        cursor: usize,
46    },
47    ReadingBody(Header),
48    Parsed(Body),
49}
50
51impl Default for State {
52    fn default() -> State {
53        State::ReadingHeader {
54            header: HeaderBuilder::default(),
55            cursor: 0,
56        }
57    }
58}
59
/// Errors that can occur while parsing the header part of an LSP message.
#[derive(Debug, PartialEq)]
pub enum HeaderError {
    /// A header field occurred more than once.
    DuplicateHeaderField,
    /// The header ended without a `Content-Length` field.
    MissingContentLength,
    /// `Content-Type` named a charset other than UTF-8.
    UnsupportedCharset,
    /// A line could not be parsed as the field being tried; holds the offending line.
    HeaderFieldParseError(String),
    /// A line is not a supported header field; holds the offending line.
    WrongEntryField(String),
}

/// The parsed header part of an LSP message.
#[derive(Debug, Default, PartialEq)]
pub struct Header {
    /// Length of the body that follows, in bytes.
    content_length: ContentLength,
    /// The `Content-Type` field, if the header had one.
    content_type: Option<ContentType>,
}

/// Accumulates header fields line by line until the header is complete.
#[derive(Default)]
struct HeaderBuilder {
    content_length: Option<ContentLength>,
    content_type: Option<ContentType>,
}

impl HeaderBuilder {
    /// Records `field`.
    ///
    /// # Errors
    ///
    /// `HeaderError::DuplicateHeaderField` if that field was already recorded; the earlier
    /// value is kept.
    fn try_field(&mut self, field: HeaderField) -> Result<&mut Self, HeaderError> {
        match field {
            HeaderField::ContentLength(len) => set_once(&mut self.content_length, len)?,
            HeaderField::ContentType(typ) => set_once(&mut self.content_type, typ)?,
        }
        Ok(self)
    }

    /// Finishes the header.
    ///
    /// # Errors
    ///
    /// `HeaderError::MissingContentLength` if no `Content-Length` field was recorded.
    fn try_build(self) -> Result<Header, HeaderError> {
        let content_length = self
            .content_length
            .ok_or(HeaderError::MissingContentLength)?;
        Ok(Header {
            content_length,
            content_type: self.content_type,
        })
    }
}

/// Stores `value` in an empty `slot`, refusing to overwrite a value already there.
fn set_once<T>(slot: &mut Option<T>, value: T) -> Result<(), HeaderError> {
    if slot.is_some() {
        return Err(HeaderError::DuplicateHeaderField);
    }
    *slot = Some(value);
    Ok(())
}

#[derive(Debug, Default, PartialEq)]
struct ContentLength(usize);
#[derive(Debug, PartialEq)]
struct ContentType(String);

impl Default for ContentType {
    fn default() -> ContentType {
        ContentType(String::from("application/vscode-jsonrpc; charset=utf-8"))
    }
}

/// A single recognised header line.
enum HeaderField {
    ContentLength(ContentLength),
    ContentType(ContentType),
}

impl std::str::FromStr for HeaderField {
    type Err = HeaderError;

    /// Parses a header line as whichever supported field it names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        ContentLength::from_str(s)
            .map(HeaderField::ContentLength)
            .or_else(|_| ContentType::from_str(s).map(HeaderField::ContentType))
    }
}

impl std::str::FromStr for ContentLength {
    type Err = HeaderError;

    /// Parses a `Content-Length: <n>` line.
    ///
    /// # Errors
    ///
    /// `HeaderError::HeaderFieldParseError` (holding `s`) if the line names another field or
    /// its value is not a non-negative integer.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        header_value(s, "Content-Length")
            .and_then(|value| value.parse::<usize>().ok())
            .map(ContentLength)
            .ok_or_else(|| HeaderError::HeaderFieldParseError(s.to_owned()))
    }
}

impl std::str::FromStr for ContentType {
    type Err = HeaderError;

    /// Parses a `Content-Type: <media type>` line, keeping the media type verbatim.
    ///
    /// # Errors
    ///
    /// * `HeaderError::HeaderFieldParseError` (holding `s`) if the line names another field.
    /// * `HeaderError::UnsupportedCharset` if a `charset` parameter names anything but UTF-8.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let typ = header_value(s, "Content-Type")
            .ok_or_else(|| HeaderError::HeaderFieldParseError(s.to_owned()))?;

        match charset_param(typ) {
            // No charset parameter: the charset defaults to utf-8 (see `ContentType::default`).
            None => {}
            Some(charset)
                if charset.eq_ignore_ascii_case("utf-8") || charset.eq_ignore_ascii_case("utf8") => {}
            // https://github.com/Microsoft/language-server-protocol/issues/600
            Some(_) => return Err(HeaderError::UnsupportedCharset),
        }

        Ok(ContentType(typ.to_owned()))
    }
}

/// Returns the value of a `name: value` header line if its field name is `name`.
///
/// As in HTTP, which the LSP header format follows, the field name is compared
/// ASCII-case-insensitively and whitespace around the value is ignored.
fn header_value<'a>(line: &'a str, name: &str) -> Option<&'a str> {
    let colon = line.find(':')?;
    if line[..colon].eq_ignore_ascii_case(name) {
        Some(line[colon + 1..].trim())
    } else {
        None
    }
}

/// Extracts the (unquoted) `charset` parameter from a media type such as
/// `application/vscode-jsonrpc; charset=utf-8`, if present.
fn charset_param(media_type: &str) -> Option<&str> {
    media_type.split(';').skip(1).find_map(|param| {
        let mut pair = param.splitn(2, '=');
        let name = pair.next()?.trim();
        let value = pair.next()?.trim();
        if name.eq_ignore_ascii_case("charset") {
            Some(value.trim_matches('"'))
        } else {
            None
        }
    })
}
176
177enum UpdateState {
178    NotEnough,
179    Ready,
180    Parsed,
181}
182
183impl State {
184    fn try_update(&mut self, buf: &mut BytesMut) -> Result<UpdateState, Error> {
185        match self {
186            State::ReadingHeader { header, cursor } => {
187                if let Some(index) = buf[*cursor..].windows(2).position(|w| w == [b'\r', b'\n']) {
188                    let index = *cursor + index;
189
190                    let line = buf.split_to(index + 2); // consume \r *and* trailing \n
191                    *cursor = 0;
192
193                    let line = &line[..line.len() - 2];
194                    let line = std::str::from_utf8(&line).expect("invalid utf8 data");
195
196                    if line.is_empty() {
197                        let header = std::mem::replace(header, HeaderBuilder::default())
198                            .try_build()
199                            .map_err(|_| HeaderError::MissingContentLength)?;
200                        *self = State::ReadingBody(header);
201                    } else {
202                        let field = line
203                            .parse()
204                            .map_err(|_| HeaderError::WrongEntryField(line.to_owned()))?;
205                        header
206                            .try_field(field)
207                            .map_err(|_| HeaderError::DuplicateHeaderField)?;
208                    }
209
210                    Ok(UpdateState::Ready)
211                } else {
212                    *cursor = buf.len();
213
214                    Ok(UpdateState::NotEnough)
215                }
216            }
217            State::ReadingBody(header) => {
218                if buf.len() >= header.content_length.0 {
219                    let buf = buf.split_to(header.content_length.0);
220
221                    let s = std::str::from_utf8(&buf).expect("invalid utf8 data");
222                    let body = serde_json::from_str(s).map_err(Error::Serde)?;
223
224                    *self = State::Parsed(body);
225
226                    Ok(UpdateState::Parsed)
227                } else {
228                    Ok(UpdateState::NotEnough)
229                }
230            }
231            State::Parsed(..) => Ok(UpdateState::Parsed),
232        }
233    }
234}
235
236impl Decoder for LspDecoder {
237    type Item = Body;
238    type Error = Error;
239
240    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
241        loop {
242            match self.state.try_update(buf)? {
243                UpdateState::Ready => continue,
244                UpdateState::NotEnough => break Ok(None),
245                UpdateState::Parsed => {
246                    break match std::mem::replace(&mut self.state, State::default()) {
247                        State::Parsed(body) => Ok(Some(body)),
248                        _ => unreachable!(),
249                    };
250                }
251            };
252        }
253    }
254}
255
256impl Encoder<Body> for LspEncoder {
257    type Error = Error;
258
259    fn encode(&mut self, item: Body, dst: &mut BytesMut) -> Result<(), Error> {
260        let body = serde_json::to_string(&item).map_err(Error::Serde)?;
261        let body_len: usize = body.chars().map(char::len_utf8).sum();
262
263        let header = format!("Content-Length: {}\r\n\r\n", body_len);
264        let header_len: usize = header.chars().map(char::len_utf8).sum();
265
266        dst.reserve(header_len + body_len);
267        Ok(write!(dst, "{}{}", header, body)
268            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Formatting into buffer failed"))?)
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn content_type() {
        // Backwards compatibility, see https://github.com/Microsoft/language-server-protocol/pull/199.
        let supported = [
            "application/vscode-jsonrpc; charset=utf8",
            "application/vscode-jsonrpc; charset=utf-8",
            "application/vscode-jsonrpc; charset=UTF-8",
        ];
        for &media_type in supported.iter() {
            let ContentType(typ) = format!("Content-Type: {}", media_type).parse().unwrap();
            assert_eq!(typ, media_type);
        }

        // Only UTF-8 is accepted.
        let unsupported = ["utf-16", "latin1"];
        for &charset in unsupported.iter() {
            let line = format!("Content-Type: application/vscode-jsonrpc; charset={}", charset);
            assert!(line.parse::<ContentType>().is_err());
        }
    }
}