1use crate::Error;
7use anyhow::anyhow;
8use bytes::{BufMut, BytesMut};
9use serde_json::value::Value;
10use std::str::FromStr;
11use std::{io, str};
12use tokio_util::codec::{Decoder, Encoder};
13
14pub use crate::jsonrpc::JsonRpc;
15use crate::{model::Request, notifications::Notification};
16
/// Codec for messages separated by an empty line, i.e. two consecutive
/// `'\n'` bytes.
///
/// The codec remembers how far a previous, unsuccessful `decode` call already
/// scanned, so repeated calls on a growing buffer do not rescan old bytes.
#[derive(Debug, Default)]
pub struct MultiLineCodec {
    /// Offset in the read buffer at which the next separator search starts.
    search_pos: usize,
}
23
/// Interprets `buf` as UTF-8 text, borrowing from it on success.
///
/// Any decoding failure is reported as an [`io::ErrorKind::InvalidData`]
/// error; the position of the offending byte is not preserved.
fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
    match str::from_utf8(buf) {
        Ok(text) => Ok(text),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Unable to decode input as UTF8",
        )),
    }
}
28
29impl Decoder for MultiLineCodec {
30 type Item = String;
31 type Error = Error;
32 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
33 let bytes = &buf[..];
34 let mut i = self.search_pos;
35
36 while i + 1 < bytes.len() {
37 if bytes[i] == b'\n' && bytes[i + 1] == b'\n' {
38 let line = buf.split_to(i + 2);
39 let line = &line[..line.len() - 2];
40
41 self.search_pos = 0;
42
43 return Ok(Some(utf8(line)?.to_owned()));
44 }
45 i += 1;
46 }
47
48 self.search_pos = bytes.len().saturating_sub(1);
49
50 Ok(None)
51 }
52}
53
54impl<T> Encoder<T> for MultiLineCodec
55where
56 T: AsRef<str>,
57{
58 type Error = Error;
59 fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
60 let line = line.as_ref();
61 buf.reserve(line.len() + 2);
62 buf.put(line.as_bytes());
63 buf.put_u8(b'\n');
64 buf.put_u8(b'\n');
65 Ok(())
66 }
67}
68
/// Codec that frames messages with [`MultiLineCodec`] and converts each frame
/// to and from a [`serde_json::Value`].
#[derive(Default)]
pub struct JsonCodec {
    /// Underlying framing codec (messages separated by `"\n\n"`).
    inner: MultiLineCodec,
}
75
76impl<T> Encoder<T> for JsonCodec
77where
78 T: Into<Value>,
79{
80 type Error = Error;
81 fn encode(&mut self, msg: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
82 let s = msg.into().to_string();
83 self.inner.encode(s, buf)
84 }
85}
86
87impl Decoder for JsonCodec {
88 type Item = Value;
89 type Error = Error;
90
91 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
92 match self.inner.decode(buf) {
93 Ok(None) => Ok(None),
94 Err(e) => Err(e),
95 Ok(Some(s)) => {
96 if let Ok(v) = Value::from_str(&s) {
97 Ok(Some(v))
98 } else {
99 Err(anyhow!("failed to parse JSON"))
100 }
101 }
102 }
103 }
104}
105
/// Codec that decodes framed JSON messages into
/// [`JsonRpc<Notification, Request>`](JsonRpc) values by delegating framing and
/// JSON parsing to [`JsonCodec`].
// NOTE(review): `dead_code` is allowed presumably because nothing in the crate
// constructs this codec yet — confirm, and drop the attribute once it is used.
#[allow(dead_code)]
#[derive(Default)]
pub(crate) struct JsonRpcCodec {
    /// Codec producing one `serde_json::Value` per frame.
    inner: JsonCodec,
}
114
115impl Decoder for JsonRpcCodec {
116 type Item = JsonRpc<Notification, Request>;
117 type Error = Error;
118
119 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
120 match self.inner.decode(buf) {
121 Ok(None) => Ok(None),
122 Err(e) => Err(e),
123 Ok(Some(s)) => {
124 let req: Self::Item = serde_json::from_value(s)?;
125 Ok(Some(req))
126 }
127 }
128 }
129}
130
#[cfg(test)]
mod test {
    use super::{JsonCodec, MultiLineCodec};
    use bytes::{BufMut, BytesMut};
    use serde_json::json;
    use tokio_util::codec::{Decoder, Encoder};

    #[test]
    fn test_ml_decoder() {
        // (input, expected frame, bytes expected to remain in the buffer)
        struct Test(String, Option<String>, String);
        let tests = [
            Test("".to_string(), None, "".to_string()),
            // A lone newline is not a separator; nothing may be consumed.
            Test("abc\n".to_string(), None, "abc\n".to_string()),
            // An empty frame is still a frame.
            Test("\n\n".to_string(), Some("".to_string()), "".to_string()),
            Test(
                "{\"hello\":\"world\"}\n\nremainder".to_string(),
                Some("{\"hello\":\"world\"}".to_string()),
                "remainder".to_string(),
            ),
            Test(
                "{\"hello\":\"world\"}\n\n{}\n\nremainder".to_string(),
                Some("{\"hello\":\"world\"}".to_string()),
                "{}\n\nremainder".to_string(),
            ),
        ];

        for t in tests.iter() {
            let mut buf = BytesMut::new();
            buf.put_slice(t.0.as_bytes());

            let mut codec = MultiLineCodec::default();
            let mut remainder = BytesMut::new();
            remainder.put_slice(t.2.as_bytes());

            assert_eq!(codec.decode(&mut buf).unwrap(), t.1);
            assert_eq!(buf, remainder);
        }
    }

    #[test]
    fn test_ml_decoder_incremental() {
        let mut codec = MultiLineCodec::default();
        let mut buf = BytesMut::new();

        buf.put_slice(b"{\"a\":1}\n");
        assert_eq!(codec.decode(&mut buf).unwrap(), None);

        // The separator is completed by the next chunk.
        buf.put_slice(b"\nnext");
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some("{\"a\":1}".to_string())
        );
        assert_eq!(&buf[..], b"next");

        buf.put_slice(b"\n\n");
        assert_eq!(codec.decode(&mut buf).unwrap(), Some("next".to_string()));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_ml_decoder_invalid_utf8() {
        let mut codec = MultiLineCodec::default();
        let mut buf = BytesMut::new();
        buf.put_slice(b"\xff\xfe\n\nok\n\n");

        assert!(codec.decode(&mut buf).is_err());
        // The bad frame was consumed, so decoding continues with the next one.
        assert_eq!(codec.decode(&mut buf).unwrap(), Some("ok".to_string()));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_ml_encoder() {
        let tests = ["test"];

        for t in tests.iter() {
            let mut buf = BytesMut::new();
            let mut codec = MultiLineCodec::default();
            let mut expected = BytesMut::new();
            expected.put_slice(t.as_bytes());
            expected.put_u8(b'\n');
            expected.put_u8(b'\n');
            codec.encode(t, &mut buf).unwrap();
            assert_eq!(buf, expected);
        }
    }

    #[test]
    fn test_json_codec() {
        let tests = [json!({"hello": "world"})];

        for t in tests.iter() {
            let mut codec = JsonCodec::default();
            let mut buf = BytesMut::new();
            codec.encode(t.clone(), &mut buf).unwrap();
            let decoded = codec.decode(&mut buf).unwrap().unwrap();
            assert_eq!(&decoded, t);
        }
    }

    #[test]
    fn test_json_codec_invalid() {
        let mut codec = JsonCodec::default();
        let mut buf = BytesMut::new();
        buf.put_slice(b"not json\n\n");

        assert!(codec.decode(&mut buf).is_err());
        assert!(buf.is_empty());
    }
}