1use std::mem;
2
3use atoi::atoi;
4use bytes::{Bytes, BytesMut};
5use http::{header, request, Request, Response};
6use httparse;
7use tokio_codec::{Decoder, Encoder};
8
9use error;
10
11#[derive(Debug)]
12pub struct HttpCodec {
13 state: State,
14 max_cl: usize,
15}
16
17impl HttpCodec {
18 pub fn new(max_cl: usize) -> HttpCodec {
19 HttpCodec {
20 state: State::ParsingRequest,
21 max_cl,
22 }
23 }
24}
25
26#[derive(Debug)]
27enum State {
28 ParsingRequest,
29 ReadingBody(request::Parts, usize),
30}
31
32impl Encoder for HttpCodec {
33 type Item = Response<Bytes>;
34 type Error = error::Error;
35
36 fn encode(&mut self, item: Response<Bytes>, dst: &mut BytesMut) -> error::Result<()> {
37 dst.extend_from_slice(b"HTTP/1.1 ");
38 dst.extend_from_slice(format!("{}", item.status()).as_bytes());
39 dst.extend_from_slice(b"\r\n");
40
41 for (k, v) in item.headers() {
42 dst.extend_from_slice(k.as_str().as_bytes());
43 dst.extend_from_slice(b": ");
44 dst.extend_from_slice(v.as_bytes());
45 dst.extend_from_slice(b"\r\n");
46 }
47
48 dst.extend_from_slice(b"\r\n");
49 dst.extend_from_slice(item.body());
50
51 Ok(())
52 }
53}
54
55impl Decoder for HttpCodec {
56 type Item = Request<Bytes>;
57 type Error = error::Error;
58
59 fn decode(&mut self, src: &mut BytesMut) -> error::Result<Option<Request<Bytes>>> {
60 use self::State::*;
61
62 if src.len() == 0 {
63 return Ok(None);
64 }
65
66 loop {
67 match mem::replace(&mut self.state, ParsingRequest) {
68 ParsingRequest => {
69 let amt = {
70 let mut headers = [httparse::EMPTY_HEADER; 16];
71 let mut request = httparse::Request::new(&mut headers);
72 let amt = match request.parse(src)? {
73 httparse::Status::Complete(amt) => amt,
74 httparse::Status::Partial => return Ok(None),
75 };
76 match request.version.unwrap() {
77 1 => (),
78 version => return Err(error::Error::VersionError(version)),
79 }
80 let mut builder = Request::builder();
81 builder.method(request.method.unwrap());
82 builder.uri(request.path.unwrap());
83 for header in request.headers.iter() {
84 builder.header(header.name, header.value);
85 }
86 let r = builder.body(()).unwrap();
87 let cl = match r.headers().get(header::CONTENT_LENGTH) {
88 Some(cl) => match atoi(cl.as_bytes()) {
89 Some(cl) => cl,
90 None => return Err(error::Error::ContentLengthError),
91 },
92 None => return Err(error::Error::ContentLengthError),
93 };
94 if cl > self.max_cl {
95 return Err(error::Error::ContentLengthError);
96 }
97 let (parts, _) = r.into_parts();
98 self.state = ReadingBody(parts, cl);
99 amt
100 };
101 src.advance(amt);
102 }
103 ReadingBody(parts, cl) => {
104 if src.len() < cl {
105 self.state = ReadingBody(parts, cl);
106 return Ok(None);
107 }
108 let body = src.split_to(cl).freeze();
109 return Ok(Some(Request::from_parts(parts, body)));
110 }
111 }
112 }
113 }
114}
115
#[cfg(test)]
mod tests {
    extern crate fake_stream;
    extern crate futures;

    use self::fake_stream::FakeStream;
    use self::futures::{Async, Future, Sink, Stream};
    use super::*;
    use http::{status, Method};
    use std::io::{Read, Write};
    use std::str::from_utf8;

    /// A complete POST request with a 9-byte body decodes into one `Request`.
    #[test]
    fn test_decode() {
        let raw_request = b"\
            POST /cgi-bin/process.cgi HTTP/1.1\r\n\
            connection: Keep-Alive\r\n\
            content-length: 9\r\n\r\n\
            something";

        let mut stream = FakeStream::new();
        let written = stream.write(raw_request).unwrap();
        assert_eq!(raw_request.len(), written);

        let mut framed = HttpCodec::new(10).framed(stream);
        let request = if let Async::Ready(Some(decoded)) = framed.poll().unwrap() {
            decoded
        } else {
            panic!("no request")
        };

        assert_eq!(request.uri().path(), "/cgi-bin/process.cgi");
        assert_eq!(request.method(), Method::POST);
        assert_eq!(request.headers().len(), 2);
        assert_eq!(request.body(), &Bytes::from_static(b"something"));
    }

    /// A 200 response is written as status line, headers, blank line and body.
    #[test]
    fn test_encode() {
        let expected = "\
            HTTP/1.1 200 OK\r\n\
            content-length: 9\r\n\
            content-type: text/html\r\n\r\n\
            something";

        let response = Response::builder()
            .status(status::StatusCode::OK)
            .header("content-length", "9")
            .header("content-type", "text/html")
            .body(Bytes::from_static(b"something"))
            .unwrap();

        let framed = HttpCodec::new(10).framed(FakeStream::new());
        let mut stream = framed.send(response).wait().unwrap().into_inner();

        let mut received = vec![0u8; expected.len()];
        let read = stream.read(&mut received).unwrap();

        assert_eq!(read, expected.len());
        assert_eq!(from_utf8(&received).unwrap(), expected);
    }
}