http_codec/
server.rs

1use std::mem;
2
3use atoi::atoi;
4use bytes::{Bytes, BytesMut};
5use http::{header, request, Request, Response};
6use httparse;
7use tokio_codec::{Decoder, Encoder};
8
9use error;
10
11#[derive(Debug)]
12pub struct HttpCodec {
13    state: State,
14    max_cl: usize,
15}
16
17impl HttpCodec {
18    pub fn new(max_cl: usize) -> HttpCodec {
19        HttpCodec {
20            state: State::ParsingRequest,
21            max_cl,
22        }
23    }
24}
25
26#[derive(Debug)]
27enum State {
28    ParsingRequest,
29    ReadingBody(request::Parts, usize),
30}
31
32impl Encoder for HttpCodec {
33    type Item = Response<Bytes>;
34    type Error = error::Error;
35
36    fn encode(&mut self, item: Response<Bytes>, dst: &mut BytesMut) -> error::Result<()> {
37        dst.extend_from_slice(b"HTTP/1.1 ");
38        dst.extend_from_slice(format!("{}", item.status()).as_bytes());
39        dst.extend_from_slice(b"\r\n");
40
41        for (k, v) in item.headers() {
42            dst.extend_from_slice(k.as_str().as_bytes());
43            dst.extend_from_slice(b": ");
44            dst.extend_from_slice(v.as_bytes());
45            dst.extend_from_slice(b"\r\n");
46        }
47
48        dst.extend_from_slice(b"\r\n");
49        dst.extend_from_slice(item.body());
50
51        Ok(())
52    }
53}
54
55impl Decoder for HttpCodec {
56    type Item = Request<Bytes>;
57    type Error = error::Error;
58
59    fn decode(&mut self, src: &mut BytesMut) -> error::Result<Option<Request<Bytes>>> {
60        use self::State::*;
61
62        if src.len() == 0 {
63            return Ok(None);
64        }
65
66        loop {
67            match mem::replace(&mut self.state, ParsingRequest) {
68                ParsingRequest => {
69                    let amt = {
70                        let mut headers = [httparse::EMPTY_HEADER; 16];
71                        let mut request = httparse::Request::new(&mut headers);
72                        let amt = match request.parse(src)? {
73                            httparse::Status::Complete(amt) => amt,
74                            httparse::Status::Partial => return Ok(None),
75                        };
76                        match request.version.unwrap() {
77                            1 => (),
78                            version => return Err(error::Error::VersionError(version)),
79                        }
80                        let mut builder = Request::builder();
81                        builder.method(request.method.unwrap());
82                        builder.uri(request.path.unwrap());
83                        for header in request.headers.iter() {
84                            builder.header(header.name, header.value);
85                        }
86                        let r = builder.body(()).unwrap();
87                        let cl = match r.headers().get(header::CONTENT_LENGTH) {
88                            Some(cl) => match atoi(cl.as_bytes()) {
89                                Some(cl) => cl,
90                                None => return Err(error::Error::ContentLengthError),
91                            },
92                            None => return Err(error::Error::ContentLengthError),
93                        };
94                        if cl > self.max_cl {
95                            return Err(error::Error::ContentLengthError);
96                        }
97                        let (parts, _) = r.into_parts();
98                        self.state = ReadingBody(parts, cl);
99                        amt
100                    };
101                    src.advance(amt);
102                }
103                ReadingBody(parts, cl) => {
104                    if src.len() < cl {
105                        self.state = ReadingBody(parts, cl);
106                        return Ok(None);
107                    }
108                    let body = src.split_to(cl).freeze();
109                    return Ok(Some(Request::from_parts(parts, body)));
110                }
111            }
112        }
113    }
114}
115
#[cfg(test)]
mod tests {
    extern crate fake_stream;
    extern crate futures;

    use self::fake_stream::FakeStream;
    use self::futures::{Async, Future, Sink, Stream};
    use super::*;
    use http::{status, Method};
    use std::io::{Read, Write};
    use std::str::from_utf8;

    /// A full POST request written to the stream is decoded into one
    /// `Request` with the expected path, method, headers and body.
    #[test]
    fn test_decode() {
        let raw = b"\
                    POST /cgi-bin/process.cgi HTTP/1.1\r\n\
                    connection: Keep-Alive\r\n\
                    content-length: 9\r\n\r\n\
                    something";

        let mut stream = FakeStream::new();
        let written = stream.write(raw).unwrap();
        assert_eq!(raw.len(), written);

        let mut framed = HttpCodec::new(10).framed(stream);
        let decoded = match framed.poll().unwrap() {
            Async::Ready(Some(request)) => request,
            _ => panic!("no request"),
        };

        assert_eq!(decoded.uri().path(), "/cgi-bin/process.cgi");
        assert_eq!(decoded.method(), Method::POST);
        assert_eq!(decoded.headers().len(), 2);
        assert_eq!(decoded.body(), &Bytes::from_static(b"something"));
    }

    /// A response sent through the codec appears on the stream as the
    /// expected status line, headers, blank line and body.
    #[test]
    fn test_encode() {
        let expected = "\
                        HTTP/1.1 200 OK\r\n\
                        content-length: 9\r\n\
                        content-type: text/html\r\n\r\n\
                        something";

        let response = Response::builder()
            .status(status::StatusCode::OK)
            .header("content-length", "9")
            .header("content-type", "text/html")
            .body(Bytes::from_static(b"something"))
            .unwrap();

        // Push the response through the sink, then recover the raw stream.
        let mut stream = HttpCodec::new(10)
            .framed(FakeStream::new())
            .send(response)
            .wait()
            .unwrap()
            .into_inner();

        let mut out = vec![0; expected.len()];
        let read = stream.read(&mut out).unwrap();

        assert_eq!(read, expected.len());
        assert_eq!(from_utf8(&out).unwrap(), expected);
    }
}