1use std::fmt::Write;
2use std::io;
3
4use bytes::BytesMut;
5use tokio_util::codec::{Decoder, Encoder};
6
7use crate::Error;
8
9type Body = serde_json::Value;
10
11#[derive(Default)]
12pub struct LspCodec {
13 encoder: LspEncoder,
14 decoder: LspDecoder,
15}
16
17impl Encoder<Body> for LspCodec {
18 type Error = <LspEncoder as Encoder<Body>>::Error;
19
20 fn encode(&mut self, item: Body, dst: &mut BytesMut) -> Result<(), Self::Error> {
21 Encoder::encode(&mut self.encoder, item, dst)
22 }
23}
24
25impl Decoder for LspCodec {
26 type Item = Body;
27 type Error = <LspEncoder as Encoder<Body>>::Error;
28
29 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
30 Decoder::decode(&mut self.decoder, buf)
31 }
32}
33
34#[derive(Default)]
35pub struct LspDecoder {
36 state: State,
37}
38
/// Stateless encoder that frames each JSON value with a `Content-Length` header.
#[derive(Debug, Clone, Copy, Default)]
pub struct LspEncoder;
41
42enum State {
43 ReadingHeader {
44 header: HeaderBuilder,
45 cursor: usize,
46 },
47 ReadingBody(Header),
48 Parsed(Body),
49}
50
51impl Default for State {
52 fn default() -> State {
53 State::ReadingHeader {
54 header: HeaderBuilder::default(),
55 cursor: 0,
56 }
57 }
58}
59
/// Reasons a message header can be rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HeaderError {
    /// The same header field appeared more than once.
    DuplicateHeaderField,
    /// The header ended without a `Content-Length` field.
    MissingContentLength,
    /// `Content-Type` declared a charset other than UTF-8.
    UnsupportedCharset,
    /// A line could not be parsed as the expected header field; carries the line.
    HeaderFieldParseError(String),
    /// A header line is not a recognized field; carries the line.
    WrongEntryField(String),
}
68
69#[derive(Debug, Default, PartialEq)]
70pub struct Header {
71 content_length: ContentLength,
72 content_type: Option<ContentType>,
73}
74
75#[derive(Default)]
76struct HeaderBuilder {
77 content_length: Option<ContentLength>,
78 content_type: Option<ContentType>,
79}
80
81impl HeaderBuilder {
82 fn try_field(&mut self, field: HeaderField) -> Result<&mut Self, HeaderError> {
83 match field {
84 HeaderField::ContentLength(len) => {
85 if self.content_length.is_some() {
86 Err(HeaderError::DuplicateHeaderField)
87 } else {
88 self.content_length = Some(len);
89 Ok(self)
90 }
91 }
92 HeaderField::ContentType(typ) => {
93 if self.content_type.is_some() {
94 Err(HeaderError::DuplicateHeaderField)
95 } else {
96 self.content_type = Some(typ);
97 Ok(self)
98 }
99 }
100 }
101 }
102
103 fn try_build(self) -> Result<Header, HeaderError> {
104 if let Some(len) = self.content_length {
105 Ok(Header {
106 content_length: len,
107 content_type: self.content_type,
108 })
109 } else {
110 Err(HeaderError::MissingContentLength)
111 }
112 }
113}
114
/// Body size in bytes, taken from a `Content-Length` header line.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct ContentLength(usize);
/// Raw `Content-Type` header value (the text after `Content-Type: `).
#[derive(Debug, Clone, PartialEq, Eq)]
struct ContentType(String);
119
120impl Default for ContentType {
121 fn default() -> ContentType {
122 ContentType(String::from("application/vscode-jsonrpc; charset=utf-8"))
123 }
124}
125
126enum HeaderField {
127 ContentLength(ContentLength),
128 ContentType(ContentType),
129}
130
131impl std::str::FromStr for HeaderField {
132 type Err = HeaderError;
133 fn from_str(s: &str) -> Result<Self, Self::Err> {
134 ContentLength::from_str(s)
135 .map(HeaderField::ContentLength)
136 .or_else(|_| ContentType::from_str(s).map(HeaderField::ContentType))
137 }
138}
139
140impl std::str::FromStr for ContentLength {
141 type Err = HeaderError;
142 fn from_str(s: &str) -> Result<Self, Self::Err> {
143 if s.starts_with("Content-Length: ") {
144 let len = s["Content-Length: ".len()..]
145 .trim_end()
146 .parse()
147 .map_err(|_| HeaderError::HeaderFieldParseError(s.to_owned()))?;
148 Ok(ContentLength(len))
149 } else {
150 Err(HeaderError::HeaderFieldParseError(s.to_owned()))
151 }
152 }
153}
154
155impl std::str::FromStr for ContentType {
156 type Err = HeaderError;
157 fn from_str(s: &str) -> Result<Self, Self::Err> {
158 if s.starts_with("Content-Type: ") {
159 let typ = &s["Content-Type: ".len()..];
160
161 match typ.find("charset=").map(|i| &typ[i + "charset=".len()..]) {
162 Some(charset)
163 if charset.starts_with("utf8")
164 || charset.starts_with("utf-8")
165 || charset.starts_with("UTF-8") => {}
166 _ => Err(HeaderError::UnsupportedCharset)?,
168 }
169
170 Ok(ContentType(typ.to_owned()))
171 } else {
172 Err(HeaderError::HeaderFieldParseError(s.to_owned()))
173 }
174 }
175}
176
/// Outcome of one `State::try_update` step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UpdateState {
    /// Not enough buffered input to make progress; wait for more bytes.
    NotEnough,
    /// A header line was consumed; another step may make further progress.
    Ready,
    /// A complete body is available in `State::Parsed`.
    Parsed,
}
182
183impl State {
184 fn try_update(&mut self, buf: &mut BytesMut) -> Result<UpdateState, Error> {
185 match self {
186 State::ReadingHeader { header, cursor } => {
187 if let Some(index) = buf[*cursor..].windows(2).position(|w| w == [b'\r', b'\n']) {
188 let index = *cursor + index;
189
190 let line = buf.split_to(index + 2); *cursor = 0;
192
193 let line = &line[..line.len() - 2];
194 let line = std::str::from_utf8(&line).expect("invalid utf8 data");
195
196 if line.is_empty() {
197 let header = std::mem::replace(header, HeaderBuilder::default())
198 .try_build()
199 .map_err(|_| HeaderError::MissingContentLength)?;
200 *self = State::ReadingBody(header);
201 } else {
202 let field = line
203 .parse()
204 .map_err(|_| HeaderError::WrongEntryField(line.to_owned()))?;
205 header
206 .try_field(field)
207 .map_err(|_| HeaderError::DuplicateHeaderField)?;
208 }
209
210 Ok(UpdateState::Ready)
211 } else {
212 *cursor = buf.len();
213
214 Ok(UpdateState::NotEnough)
215 }
216 }
217 State::ReadingBody(header) => {
218 if buf.len() >= header.content_length.0 {
219 let buf = buf.split_to(header.content_length.0);
220
221 let s = std::str::from_utf8(&buf).expect("invalid utf8 data");
222 let body = serde_json::from_str(s).map_err(Error::Serde)?;
223
224 *self = State::Parsed(body);
225
226 Ok(UpdateState::Parsed)
227 } else {
228 Ok(UpdateState::NotEnough)
229 }
230 }
231 State::Parsed(..) => Ok(UpdateState::Parsed),
232 }
233 }
234}
235
236impl Decoder for LspDecoder {
237 type Item = Body;
238 type Error = Error;
239
240 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
241 loop {
242 match self.state.try_update(buf)? {
243 UpdateState::Ready => continue,
244 UpdateState::NotEnough => break Ok(None),
245 UpdateState::Parsed => {
246 break match std::mem::replace(&mut self.state, State::default()) {
247 State::Parsed(body) => Ok(Some(body)),
248 _ => unreachable!(),
249 };
250 }
251 };
252 }
253 }
254}
255
256impl Encoder<Body> for LspEncoder {
257 type Error = Error;
258
259 fn encode(&mut self, item: Body, dst: &mut BytesMut) -> Result<(), Error> {
260 let body = serde_json::to_string(&item).map_err(Error::Serde)?;
261 let body_len: usize = body.chars().map(char::len_utf8).sum();
262
263 let header = format!("Content-Length: {}\r\n\r\n", body_len);
264 let header_len: usize = header.chars().map(char::len_utf8).sum();
265
266 dst.reserve(header_len + body_len);
267 Ok(write!(dst, "{}{}", header, body)
268 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Formatting into buffer failed"))?)
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn content_type() {
        let accepted = [
            "application/vscode-jsonrpc; charset=utf8",
            "application/vscode-jsonrpc; charset=utf-8",
            "application/vscode-jsonrpc; charset=UTF-8",
        ];
        for value in accepted.iter() {
            let ContentType(typ) = format!("Content-Type: {}", value).parse().unwrap();
            assert_eq!(typ, *value);
        }

        let rejected = [
            "application/vscode-jsonrpc; charset=utf-16",
            "application/vscode-jsonrpc; charset=latin1",
        ];
        for value in rejected.iter() {
            let res = format!("Content-Type: {}", value).parse::<ContentType>();
            assert_eq!(res, Err(HeaderError::UnsupportedCharset));
        }
    }

    #[test]
    fn content_length() {
        assert_eq!("Content-Length: 42".parse::<ContentLength>(), Ok(ContentLength(42)));
        assert_eq!("Content-Length: 7\t".parse::<ContentLength>(), Ok(ContentLength(7)));
        assert!("Content-Length: -1".parse::<ContentLength>().is_err());
        assert!("Content-Length: abc".parse::<ContentLength>().is_err());
        assert!("Content-Type: text/plain".parse::<ContentLength>().is_err());
    }

    #[test]
    fn header_builder() {
        let mut builder = HeaderBuilder::default();
        assert!(builder
            .try_field(HeaderField::ContentLength(ContentLength(3)))
            .is_ok());
        assert_eq!(
            builder
                .try_field(HeaderField::ContentLength(ContentLength(4)))
                .err(),
            Some(HeaderError::DuplicateHeaderField)
        );
        assert_eq!(
            builder.try_build(),
            Ok(Header {
                content_length: ContentLength(3),
                content_type: None,
            })
        );

        assert_eq!(
            HeaderBuilder::default().try_build(),
            Err(HeaderError::MissingContentLength)
        );
    }
}