jsonrpc_codec/
httpcodec.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use tokio_codec::{Decoder, Encoder};
3
4use super::request::Request;
5use super::response::{Error, Response};
6
7#[derive(Debug, Clone)]
8pub enum HTTP {
9    Request(Request),
10    Response(Response),
11    Error(Error),
12    NeedMore(usize, usize, Bytes),
13}
14
15impl HTTP {
16    fn parse(src: &mut BytesMut, mut had_body: BytesMut) -> Option<HTTP> {
17        if had_body.len() == 0 {
18            let (version, amt, length) = {
19                let mut req_parsed_headers = [httparse::EMPTY_HEADER; 16];
20                let mut res_parsed_headers = [httparse::EMPTY_HEADER; 16];
21                let mut req = httparse::Request::new(&mut req_parsed_headers);
22                let mut res = httparse::Response::new(&mut res_parsed_headers);
23                let req_status = req.parse(&src);
24                let res_status = res.parse(&src);
25
26                if req_status.is_err() && res_status.is_err() {
27                    println!("failed to parse http");
28                    return Some(HTTP::Error(Error::ParseError(None)));
29                }
30
31                let (status, version, length) = if req_status.is_err() {
32                    let content_length_headers: Vec<httparse::Header> = res
33                        .headers
34                        .iter()
35                        .filter(|header| header.name == "Content-Length")
36                        .cloned()
37                        .collect();
38
39                    if content_length_headers.len() != 1 {
40                        return Some(HTTP::Error(Error::ParseError(None)));
41                    }
42
43                    let length_bytes = content_length_headers.first().unwrap().value;
44                    let mut length_string = String::new();
45
46                    for b in length_bytes {
47                        length_string.push(*b as char);
48                    }
49
50                    let length = length_string.parse::<usize>();
51                    if length.is_err() {
52                        return Some(HTTP::Error(Error::ParseError(None)));
53                    };
54
55                    (res_status.unwrap(), res.version.unwrap(), length.unwrap())
56                } else {
57                    let content_length_headers: Vec<httparse::Header> = req
58                        .headers
59                        .iter()
60                        .filter(|header| header.name == "Content-Length")
61                        .cloned()
62                        .collect();
63
64                    if content_length_headers.len() != 1 {
65                        return Some(HTTP::Error(Error::ParseError(None)));
66                    }
67
68                    let length_bytes = content_length_headers.first().unwrap().value;
69                    let mut length_string = String::new();
70
71                    for b in length_bytes {
72                        length_string.push(*b as char);
73                    }
74
75                    let length = length_string.parse::<usize>();
76                    if length.is_err() {
77                        return Some(HTTP::Error(Error::ParseError(None)));
78                    };
79
80                    (req_status.unwrap(), req.version.unwrap(), length.unwrap())
81                };
82
83                let amt = match status {
84                    httparse::Status::Complete(amt) => amt,
85                    httparse::Status::Partial => return Some(HTTP::Error(Error::ParseError(None))),
86                };
87
88                (version, amt, length)
89            };
90            if version != 1 {
91                println!("only HTTP/1.1 accepted");
92                return Some(HTTP::Error(Error::ParseError(None)));
93            }
94
95            had_body = src.split_off(amt);
96
97            if had_body.len() < length {
98                return Some(HTTP::NeedMore(amt, length, had_body.freeze()));
99            }
100        }
101
102        let json = had_body.freeze();
103
104        let request_result = Request::parse_from_json_bytes(json.clone());
105        if request_result.is_err() {
106            let response_result = Response::parse_from_json_bytes(json);
107            if response_result.is_err() {
108                Some(HTTP::Error(request_result.err().unwrap()))
109            } else {
110                Some(HTTP::Response(response_result.unwrap()))
111            }
112        } else {
113            Some(HTTP::Request(request_result.unwrap()))
114        }
115    }
116
117    fn deparse(&self) -> Bytes {
118        match self {
119            HTTP::Request(meta) => meta.deparse(),
120            HTTP::Response(meta) => meta.deparse(),
121            HTTP::Error(meta) => meta.deparse(),
122            _ => Bytes::new(),
123        }
124    }
125}
126
127// cache, body length, header_length, is_receiving
128#[derive(Default)]
129pub struct HTTPCodec(pub BytesMut, pub usize, pub usize, pub bool);
130
131impl HTTPCodec {
132    pub fn new() -> Self {
133        HTTPCodec(BytesMut::new(), 0, 0, true)
134    }
135}
136
137impl Decoder for HTTPCodec {
138    type Item = HTTP;
139    type Error = std::io::Error;
140
141    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
142        if !self.3 {
143            return Ok(None);
144        }
145        if self.2 > 0 {
146            let bytes = src.split_off(self.2);
147            self.0.reserve(bytes.len());
148            self.0.put(bytes);
149
150            if self.0.len() < self.1 {
151                return Ok(None);
152            }
153        }
154
155        let http = HTTP::parse(src, self.0.clone());
156        match http {
157            Some(HTTP::NeedMore(amt, length, bytes)) => {
158                self.0.reserve(bytes.len());
159                self.0.put(bytes);
160                self.1 = length; // body leangth
161                self.2 = amt; // header length
162                Ok(None)
163            }
164            Some(h) => {
165                self.3 = false;
166                self.0.clear();
167                self.0.reserve(0);
168                Ok(Some(h))
169            }
170            None => Ok(None),
171        }
172    }
173}
174
175impl Encoder for HTTPCodec {
176    type Item = HTTP;
177    type Error = std::io::Error;
178
179    fn encode(&mut self, msg: HTTP, dst: &mut BytesMut) -> Result<(), Self::Error> {
180        let bytes = msg.deparse();
181        dst.reserve(bytes.len());
182        dst.put(bytes);
183        Ok(())
184    }
185}