1use crate::Error;
7use anyhow::anyhow;
8use bytes::{BufMut, BytesMut};
9use serde_json::value::Value;
10use std::str::FromStr;
11use std::{io, str};
12use tokio_util::codec::{Decoder, Encoder};
13
14use crate::messages::JsonRpc;
15use crate::messages::{Notification, Request};
16
/// Framing codec that splits a byte stream into UTF-8 frames terminated by a
/// blank line (`\n\n`), and appends that terminator to every encoded frame.
#[derive(Debug, Default, Clone)]
pub struct MultiLineCodec {}
21
/// Returns the offset of the first `\n\n` separator in `buf`, or `None` if
/// `buf` does not contain one.
///
/// The returned index points at the first of the two newline bytes. The
/// payload before the separator is therefore `&buf[..offset]`, and the whole
/// frame including the separator is `offset + 2` bytes long.
///
/// Takes a shared slice because the buffer is only read. Callers holding a
/// `&mut BytesMut` can still pass it directly via deref coercion.
fn find_separator(buf: &[u8]) -> Option<usize> {
    // NOTE(review): the scan always restarts at the beginning of `buf`, so a
    // large frame that arrives across many reads is rescanned each time.
    buf.windows(2).position(|pair| pair == b"\n\n")
}
29
/// Borrows `buf` as a `&str`.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind [`io::ErrorKind::InvalidData`] when `buf`
/// is not valid UTF-8. The original decoding error detail is not kept.
fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
    match str::from_utf8(buf) {
        Ok(text) => Ok(text),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Unable to decode input as UTF8",
        )),
    }
}
34
35impl Decoder for MultiLineCodec {
36 type Item = String;
37 type Error = Error;
38 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
39 if let Some(newline_offset) = find_separator(buf) {
40 let line = buf.split_to(newline_offset + 2);
41 let line = &line[..line.len() - 2];
42 let line = utf8(line)?;
43 Ok(Some(line.to_string()))
44 } else {
45 Ok(None)
46 }
47 }
48}
49
50impl<T> Encoder<T> for MultiLineCodec
51where
52 T: AsRef<str>,
53{
54 type Error = Error;
55 fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
56 let line = line.as_ref();
57 buf.reserve(line.len() + 2);
58 buf.put(line.as_bytes());
59 buf.put_u8(b'\n');
60 buf.put_u8(b'\n');
61 Ok(())
62 }
63}
64
/// JSON codec layered on [`MultiLineCodec`]: each `\n\n`-delimited frame
/// carries exactly one JSON value.
#[derive(Default)]
pub struct JsonCodec {
    /// Framing layer that splits the byte stream into `\n\n`-terminated strings.
    inner: MultiLineCodec,
}
71
72impl<T> Encoder<T> for JsonCodec
73where
74 T: Into<Value>,
75{
76 type Error = Error;
77 fn encode(&mut self, msg: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
78 let s = msg.into().to_string();
79 self.inner.encode(s, buf)
80 }
81}
82
83impl Decoder for JsonCodec {
84 type Item = Value;
85 type Error = Error;
86
87 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
88 match self.inner.decode(buf) {
89 Ok(None) => Ok(None),
90 Err(e) => Err(e),
91 Ok(Some(s)) => {
92 if let Ok(v) = Value::from_str(&s) {
93 Ok(Some(v))
94 } else {
95 Err(anyhow!("failed to parse JSON"))
96 }
97 }
98 }
99 }
100}
101
/// Decoder for JSON-RPC messages. Each frame is parsed with [`JsonCodec`], and
/// the resulting value is deserialized into a
/// `JsonRpc<Notification, Request>`.
#[derive(Default)]
pub(crate) struct JsonRpcCodec {
    /// JSON layer that yields one `serde_json::Value` per frame.
    inner: JsonCodec,
}
109
110impl Decoder for JsonRpcCodec {
111 type Item = JsonRpc<Notification, Request>;
112 type Error = Error;
113
114 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
115 match self.inner.decode(buf) {
116 Ok(None) => Ok(None),
117 Err(e) => Err(e),
118 Ok(Some(s)) => {
119 let req: Self::Item = serde_json::from_value(s)?;
120 Ok(Some(req))
121 }
122 }
123 }
124}
125
#[cfg(test)]
mod test {
    use super::{find_separator, JsonCodec, MultiLineCodec};
    use bytes::{BufMut, BytesMut};
    use serde_json::json;
    use tokio_util::codec::{Decoder, Encoder};

    #[test]
    fn test_separator() {
        struct Test(String, Option<usize>);
        let tests = vec![
            Test("".to_string(), None),
            Test("}\n\n".to_string(), Some(1)),
            Test("\"hello\"},\n\"world\"}\n\n".to_string(), Some(18)),
            // A lone newline, or two non-adjacent newlines, is not a separator.
            Test("\n".to_string(), None),
            Test("a\nb\n".to_string(), None),
        ];

        for t in tests.iter() {
            let mut buf = BytesMut::new();
            buf.put_slice(t.0.as_bytes());
            assert_eq!(find_separator(&mut buf), t.1);
        }
    }

    #[test]
    fn test_ml_decoder() {
        struct Test(String, Option<String>, String);
        let tests = vec![
            Test("".to_string(), None, "".to_string()),
            Test(
                "{\"hello\":\"world\"}\n\nremainder".to_string(),
                Some("{\"hello\":\"world\"}".to_string()),
                "remainder".to_string(),
            ),
            Test(
                "{\"hello\":\"world\"}\n\n{}\n\nremainder".to_string(),
                Some("{\"hello\":\"world\"}".to_string()),
                "{}\n\nremainder".to_string(),
            ),
        ];

        for t in tests.iter() {
            let mut buf = BytesMut::new();
            buf.put_slice(t.0.as_bytes());

            let mut codec = MultiLineCodec::default();
            let mut remainder = BytesMut::new();
            remainder.put_slice(t.2.as_bytes());

            assert_eq!(codec.decode(&mut buf).unwrap(), t.1);
            assert_eq!(buf, remainder);
        }
    }

    #[test]
    fn test_ml_decoder_partial_frame() {
        // A frame whose `\n\n` terminator has not fully arrived must not be consumed.
        let mut codec = MultiLineCodec::default();
        let mut buf = BytesMut::new();
        buf.put_slice(b"{\"hello\":\"world\"}\n");

        assert_eq!(codec.decode(&mut buf).unwrap(), None);
        assert_eq!(&buf[..], &b"{\"hello\":\"world\"}\n"[..]);

        // Once the second newline arrives the frame becomes available.
        buf.put_u8(b'\n');
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some("{\"hello\":\"world\"}".to_string())
        );
        assert!(buf.is_empty());
    }

    #[test]
    fn test_ml_decoder_invalid_utf8() {
        let mut codec = MultiLineCodec::default();
        let mut buf = BytesMut::new();
        buf.put_slice(&[0xff, 0xfe, b'\n', b'\n']);
        assert!(codec.decode(&mut buf).is_err());
    }

    #[test]
    fn test_ml_encoder() {
        let tests = vec!["test"];

        for t in tests.iter() {
            let mut buf = BytesMut::new();
            let mut codec = MultiLineCodec::default();
            let mut expected = BytesMut::new();
            expected.put_slice(t.as_bytes());
            expected.put_u8(b'\n');
            expected.put_u8(b'\n');
            codec.encode(t, &mut buf).unwrap();
            assert_eq!(buf, expected);
        }
    }

    #[test]
    fn test_json_codec() {
        let tests = vec![json!({"hello": "world"})];

        for t in tests.iter() {
            let mut codec = JsonCodec::default();
            let mut buf = BytesMut::new();
            codec.encode(t.clone(), &mut buf).unwrap();
            let decoded = codec.decode(&mut buf).unwrap().unwrap();
            assert_eq!(&decoded, t);
        }
    }

    #[test]
    fn test_json_decoder_invalid_json() {
        let mut codec = JsonCodec::default();
        let mut buf = BytesMut::new();
        buf.put_slice(b"{not json}\n\n");
        assert!(codec.decode(&mut buf).is_err());
    }

    #[test]
    fn test_json_codec_multiple_frames() {
        // The last message contains a literal "\n\n" inside a JSON string. JSON
        // escapes it, so it must not be mistaken for a frame separator.
        let msgs = [
            json!({"a": 1}),
            json!([1, 2, 3]),
            json!("text\n\nwith separator chars"),
        ];

        let mut codec = JsonCodec::default();
        let mut buf = BytesMut::new();
        for m in msgs.iter() {
            codec.encode(m.clone(), &mut buf).unwrap();
        }
        for m in msgs.iter() {
            assert_eq!(&codec.decode(&mut buf).unwrap().unwrap(), m);
        }
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
        assert!(buf.is_empty());
    }
}