cogo_http/server/
mod.rs

1//! HTTP Server
2//!
3//! # Server
4//!
//! A `Server` is created to listen on a port, parse HTTP requests, and hand
6//! them off to a `Handler`. By default, the Server will listen across multiple
7//! threads, but that can be configured to a single thread if preferred.
8//!
9//! # Handling requests
10//!
11//! You must pass a `Handler` to the Server that will handle requests. There is
//! a default implementation for `fn`s and closures, allowing you to pass one of
13//! those easily.
14//!
15//!
16//! ```no_run
17//! use cogo_http::server::{Server, Request, Response};
18//!
19//! fn hello(req: Request, res: Response) {
20//!     // handle things here
21//! }
22//!
23//! Server::http("0.0.0.0:0").unwrap().handle(hello).unwrap();
24//! ```
25//!
26//! As with any trait, you can also define a struct and implement `Handler`
27//! directly on your own type, and pass that to the `Server` instead.
28//!
29//! ```no_run
30//! use std::sync::Mutex;
31//! use std::sync::mpsc::{channel, Sender};
32//! use cogo_http::server::{Handler, Server, Request, Response};
33//!
34//! struct SenderHandler {
35//!     sender: Mutex<Sender<&'static str>>
36//! }
37//!
38//! impl Handler for SenderHandler {
39//!     fn handle(&self, req: Request, res: Response) {
40//!         self.sender.lock().unwrap().send("start").unwrap();
41//!     }
42//! }
43//!
44//!
45//! let (tx, rx) = channel();
46//! Server::http("0.0.0.0:0").unwrap().handle(SenderHandler {
47//!     sender: Mutex::new(tx)
48//! }).unwrap();
49//! ```
50//!
51//! Since the `Server` will be listening on multiple threads, the `Handler`
52//! must implement `Sync`: any mutable state must be synchronized.
53//!
54//! ```no_run
55//! use std::sync::atomic::{AtomicUsize, Ordering};
56//! use cogo_http::server::{Server, Request, Response};
57//!
58//! let counter = AtomicUsize::new(0);
59//! Server::http("0.0.0.0:0").unwrap().handle(move |req: Request, res: Response| {
60//!     counter.fetch_add(1, Ordering::Relaxed);
61//! }).unwrap();
62//! ```
63//!
64//! # The `Request` and `Response` pair
65//!
66//! A `Handler` receives a pair of arguments, a `Request` and a `Response`. The
67//! `Request` includes access to the `method`, `uri`, and `headers` of the
68//! incoming HTTP request. It also implements `std::io::Read`, in order to
69//! read any body, such as with `POST` or `PUT` messages.
70//!
71//! Likewise, the `Response` includes ways to set the `status` and `headers`,
72//! and implements `std::io::Write` to allow writing the response body.
73//!
74//! ```no_run
75//! use std::io;
76//! use cogo_http::server::{Server, Request, Response};
77//! use cogo_http::status::StatusCode;
78//!
79//! Server::http("0.0.0.0:0").unwrap().handle(|mut req: Request, mut res: Response| {
80//!     match req.method {
81//!         cogo_http::Post => {
82//!             io::copy(&mut req, &mut res.start().unwrap()).unwrap();
83//!         }
84//!         _ => *res.status_mut() = StatusCode::MethodNotAllowed
85//!     }
86//! }).unwrap();
87//! ```
88//!
89//! ## An aside: Write Status
90//!
91//! The `Response` uses a phantom type parameter to determine its write status.
92//! What does that mean? In short, it ensures you never write a body before
93//! adding all headers, and never add a header after writing some of the body.
94//!
//! This is often done in most implementations by including a boolean property
96//! on the response, such as `headers_written`, checking that each time the
97//! body has something to write, so as to make sure the headers are sent once,
98//! and only once. But this has 2 downsides:
99//!
100//! 1. You are typically never notified that your late header is doing nothing.
101//! 2. There's a runtime cost to checking on every write.
102//!
//! Instead, `cogo_http` handles this statically, or at compile-time. A
//! `Response<Fresh>` includes a `headers_mut()` method, allowing you to add more
105//! headers. It also does not implement `Write`, so you can't accidentally
106//! write early. Once the "head" of the response is correct, you can "send" it
107//! out by calling `start` on the `Response<Fresh>`. This will return a new
108//! `Response<Streaming>` object, that no longer has `headers_mut()`, but does
109//! implement `Write`.
110use std::fmt;
111use std::io::{self, ErrorKind, BufWriter, Write};
112use std::net::{SocketAddr, ToSocketAddrs, Shutdown};
113use std::sync::Arc;
114use std::time::Duration;
115use cogo::coroutine::yield_now;
116pub use self::request::Request;
117pub use self::response::Response;
118pub use crate::net::{Fresh, Streaming};
119use crate::{Error, runtime};
120use crate::buffer::BufReader;
121use crate::header::{Headers, Expect, Connection};
122use crate::http;
123use crate::method::Method;
124use crate::net::{NetworkListener, NetworkStream, HttpListener, HttpsListener, SslServer};
125use crate::status::StatusCode;
126use crate::uri::RequestUri;
127use crate::version::HttpVersion::Http11;
128
129use self::listener::ListenerPool;
130
131pub mod request;
132pub mod response;
133
134mod listener;
135
136/// A server can listen on a TCP socket.
137///
138/// Once listening, it will create a `Request`/`Response` pair for each
139/// incoming connection, and hand them to the provided handler.
140#[derive(Debug)]
141pub struct Server<L = HttpListener> {
142    pub listener: L,
143    pub timeouts: Timeouts,
144}
145
146#[derive(Clone, Copy, Debug)]
147pub struct Timeouts {
148    read: Option<Duration>,
149    keep_alive: Option<Duration>,
150    keep_alive_type: KeepAliveType,
151}
152
/// Policy used by `Server::handle_stack` to decide when a connection task
/// gives up on its stream.
#[derive(Clone, Copy, Debug)]
pub enum KeepAliveType {
    /// Time based: the task returns once a pass over the connection ends
    /// without keep-alive and at least this much time has elapsed on its
    /// timer.
    WaitTime(Duration),
    /// Count based: the task returns once this many passes over the
    /// connection have ended without keep-alive.
    WaitError(i32),
}
159
160impl Default for Timeouts {
161    fn default() -> Timeouts {
162        Timeouts {
163            read: None,
164            keep_alive: Some(Duration::from_secs(5)),
165            keep_alive_type: KeepAliveType::WaitTime(Duration::from_secs(5)),
166        }
167    }
168}
169
170impl<L: NetworkListener> Server<L> {
171    /// Creates a new server with the provided handler.
172    #[inline]
173    pub fn new(listener: L) -> Server<L> {
174        Server {
175            listener: listener,
176            timeouts: Timeouts::default(),
177        }
178    }
179
180    /// Controls keep-alive for this server.
181    ///
182    /// The timeout duration passed will be used to determine how long
183    /// to keep the connection alive before dropping it.
184    ///
185    /// Passing `None` will disable keep-alive.
186    ///
187    /// Default is enabled with a 5 second timeout.
188    #[inline]
189    pub fn keep_alive(mut self, timeout: Option<Duration>) -> Self {
190        self.timeouts.keep_alive = timeout;
191        self.timeouts.keep_alive_type = KeepAliveType::WaitTime(timeout.unwrap_or(Duration::from_secs(5)));
192        self
193    }
194
195    /// Sets the read timeout for all Request reads.
196    pub fn set_read_timeout(mut self, dur: Option<Duration>) -> Self {
197        self.listener.set_read_timeout(dur);
198        self.timeouts.read = dur;
199        self
200    }
201
202    /// Sets the write timeout for all Response writes.
203    pub fn set_write_timeout(mut self, dur: Option<Duration>) -> Self {
204        self.listener.set_write_timeout(dur);
205        self
206    }
207
208    /// set set_keep_alive_type
209    pub fn set_keep_alive_type(mut self, t: KeepAliveType) -> Self {
210        self.timeouts.keep_alive_type = t;
211        self
212    }
213
214    /// Get the address that the server is listening on.
215    pub fn local_addr(&mut self) -> io::Result<SocketAddr> {
216        self.listener.local_addr()
217    }
218}
219
220impl Server<HttpListener> {
221    /// Creates a new server that will handle `HttpStream`s.
222    pub fn http<To: ToSocketAddrs>(addr: To) -> crate::Result<Server<HttpListener>> {
223        HttpListener::new(addr).map(Server::new)
224    }
225}
226
227impl<S: SslServer + Clone + Send> Server<HttpsListener<S>> {
228    /// Creates a new server that will handle `HttpStream`s over SSL.
229    ///
230    /// You can use any SSL implementation, as long as implements `cogo_http::net::Ssl`.
231    pub fn https<A: ToSocketAddrs>(addr: A, ssl: S) -> crate::Result<Server<HttpsListener<S>>> {
232        HttpsListener::new(addr, ssl).map(Server::new)
233    }
234}
235
/// Unwraps a `Result` inside a loop body.
///
/// Evaluates to the `Ok` value. On `Err`, logs the stringified expression
/// together with the error and `continue`s the enclosing loop, so this macro
/// can only be used inside a loop.
macro_rules! t_c {
    ($e: expr) => {
        match $e {
            Err(failure) => {
                error!("call = {:?}\nerr = {:?}", stringify!($e), failure);
                continue;
            }
            Ok(value) => value,
        }
    };
}
247
248impl<L: NetworkListener + Send + 'static> Server<L> {
249    /// Binds to a socket and starts handling connections.
250    pub fn handle<H: Handler + 'static>(self, handler: H) -> crate::Result<Listening> {
251        Self::handle_stack(self, handler, 0x2000)
252    }
253
254    /// Binds to a socket and starts handling connections.
255    pub fn handle_stack<H: Handler + 'static>(self, handler: H, stack_size: usize) -> crate::Result<Listening> {
256        let worker = Arc::new(Worker::new(handler, self.timeouts));
257        let mut listener = self.listener.clone();
258        let h = runtime::spawn_stack_size(move || {
259            for stream in listener.incoming() {
260                let mut stream = t_c!(stream);
261                let w = worker.clone();
262                runtime::spawn_stack_size(move || {
263                    {
264                        #[cfg(unix)]
265                        stream.set_nonblocking(true);
266                        {
267                            match w.timeouts.keep_alive_type {
268                                KeepAliveType::WaitTime(timeout) => {
269                                    let mut now = std::time::Instant::now();
270                                    loop {
271                                        stream.reset_io();
272                                        let keep_alive = w.handle_connection(&mut stream);
273                                        stream.wait_io();
274                                        if keep_alive == false {
275                                            if now.elapsed() >= timeout {
276                                                return;
277                                            } else {
278                                                yield_now();
279                                                continue;
280                                            }
281                                        } else {
282                                            if now.elapsed() <= timeout {
283                                                now = std::time::Instant::now();
284                                            }
285                                        }
286                                    }
287                                }
288                                KeepAliveType::WaitError(total) => {
289                                    let mut count = 0;
290                                    loop {
291                                        stream.reset_io();
292                                        let keep_alive = w.handle_connection(&mut stream);
293                                        stream.wait_io();
294                                        if keep_alive == false {
295                                            count += 1;
296                                            if count >= total {
297                                                return;
298                                            }
299                                            yield_now();
300                                        }
301                                    }
302                                }
303                            }
304                        }
305                    }
306                }, stack_size);
307            }
308        }, stack_size);
309        let socket = r#try!(self.listener.clone().local_addr());
310        return Ok(Listening {
311            _guard: Some(h),
312            socket: socket,
313        });
314    }
315
316    /// Binds to a socket and starts handling connections with the provided
317    /// number of tasks on pool
318    pub fn handle_tasks<H: Handler + 'static>(self, handler: H, tasks: usize) -> crate::Result<Listening> {
319        handle_task(self, handler, tasks)
320    }
321}
322
323fn handle_task<H, L>(mut server: Server<L>, handler: H, tasks: usize) -> crate::Result<Listening>
324    where H: Handler + 'static, L: NetworkListener + Send + 'static {
325    let socket = r#try!(server.listener.local_addr());
326
327    debug!("tasks = {:?}", tasks);
328    let pool = ListenerPool::new(server.listener);
329    let worker = Worker::new(handler, server.timeouts);
330    let work = move |mut stream| {
331        worker.handle_connection(&mut stream);
332    };
333
334    let guard = runtime::spawn(move || {
335        pool.accept(work, tasks);
336    });
337
338    Ok(Listening {
339        _guard: Some(guard),
340        socket: socket,
341    })
342}
343
344pub struct Worker<H: Handler + 'static> {
345    handler: H,
346    timeouts: Timeouts,
347}
348
349impl<H: Handler + 'static> Worker<H> {
350    pub fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
351        Worker {
352            handler: handler,
353            timeouts: timeouts,
354        }
355    }
356
357    pub fn handle_connection<S>(&self, stream: &mut S) -> bool where S: NetworkStream {
358        debug!("Incoming stream");
359        self.handler.on_connection_start();
360
361        let addr = match stream.peer_addr() {
362            Ok(addr) => addr,
363            Err(e) => {
364                info!("Peer Name error: {:?}", e);
365                return false;
366            }
367        };
368        //safety will forget copy s
369        let mut s: S = unsafe { std::mem::transmute_copy(stream) };
370        let stream2: &mut dyn NetworkStream = &mut s;
371        let mut rdr = BufReader::new(stream2);
372        let mut wrt = BufWriter::new(stream);
373
374        let mut keep_alive = false;
375        while self.keep_alive_loop(&mut rdr, &mut wrt, addr) {
376            if let Err(e) = self.set_read_timeout(*rdr.get_ref(), self.timeouts.keep_alive) {
377                info!("set_read_timeout keep_alive {:?}", e);
378                break;
379            }
380            keep_alive = true;
381        }
382        self.handler.on_connection_end();
383        debug!("keep_alive loop ending for {}", addr);
384
385        std::mem::forget(s);
386        keep_alive
387    }
388
389    fn set_read_timeout(&self, s: &dyn NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
390        s.set_read_timeout(timeout)
391    }
392
393    fn keep_alive_loop<W: Write>(&self, rdr: &mut BufReader<&mut dyn NetworkStream>,
394                                 wrt: &mut W, addr: SocketAddr) -> bool {
395        let req = match Request::new(rdr, addr) {
396            Ok(req) => req,
397            Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
398                trace!("tcp closed, cancelling keep-alive loop");
399                return false;
400            }
401            Err(Error::Io(e)) => {
402                debug!("ioerror in keepalive loop = {:?}", e);
403                return false;
404            }
405            Err(e) => {
406                //TODO: send a 400 response
407                info!("request error = {:?}", e);
408                return false;
409            }
410        };
411
412        if !self.handle_expect(&req, wrt) {
413            return false;
414        }
415
416        if let Err(e) = req.set_read_timeout(self.timeouts.read) {
417            info!("set_read_timeout {:?}", e);
418            return false;
419        }
420
421        let mut keep_alive = self.timeouts.keep_alive.is_some() &&
422            http::should_keep_alive(req.version, &req.headers);
423        let version = req.version;
424        let mut res_headers = Headers::with_capacity(1);
425        if !keep_alive {
426            res_headers.set(Connection::close());
427        }
428        {
429            let mut res = Response::new(wrt, &mut res_headers);
430            res.version = version;
431            self.handler.handle(req, res);
432        }
433
434        // if the request was keep-alive, we need to check that the server agrees
435        // if it wasn't, then the server cannot force it to be true anyways
436        if keep_alive {
437            keep_alive = http::should_keep_alive(version, &res_headers);
438        }
439
440        debug!("keep_alive = {:?} for {}", keep_alive, addr);
441        keep_alive
442    }
443
444    fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
445        if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
446            let status = self.handler.check_continue((&req.method, &req.uri, &req.headers));
447            match write!(wrt, "{} {}\r\n\r\n", Http11, status).and_then(|_| wrt.flush()) {
448                Ok(..) => (),
449                Err(e) => {
450                    info!("error writing 100-continue: {:?}", e);
451                    return false;
452                }
453            }
454
455            if status != StatusCode::Continue {
456                debug!("non-100 status ({}) for Expect 100 request", status);
457                return false;
458            }
459        }
460
461        true
462    }
463}
464
465/// A listening server, which can later be closed.
466pub struct Listening {
467    _guard: Option<runtime::JoinHandle<()>>,
468    /// The socket addresses that the server is bound to.
469    pub socket: SocketAddr,
470}
471
472impl fmt::Debug for Listening {
473    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
474        write!(f, "Listening {{ socket: {:?} }}", self.socket)
475    }
476}
477
478impl Drop for Listening {
479    fn drop(&mut self) {
480        let _ = self._guard.take().map(|g| g.join());
481    }
482}
483
484impl Listening {
485    /// Warning: This function doesn't work. The server remains listening after you called
486    /// it. See https://github.com/hyperium/hyper/issues/338 for more details.
487    ///
488    /// Stop the server from listening to its socket address.
489    pub fn close(&mut self) -> crate::Result<()> {
490        let _ = self._guard.take();
491        debug!("closing server");
492        Ok(())
493    }
494}
495
496/// A handler that can handle incoming requests for a server.
497pub trait Handler: Sync + Send {
498    /// Receives a `Request`/`Response` pair, and should perform some action on them.
499    ///
500    /// This could reading from the request, and writing to the response.
501    fn handle(&self, req: Request, resp: Response<'_, Fresh>);
502
503    /// Called when a Request includes a `Expect: 100-continue` header.
504    ///
505    /// By default, this will always immediately response with a `StatusCode::Continue`,
506    /// but can be overridden with custom behavior.
507    fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
508        StatusCode::Continue
509    }
510
511    /// This is run after a connection is received, on a per-connection basis (not a
512    /// per-request basis, as a connection with keep-alive may handle multiple
513    /// requests)
514    fn on_connection_start(&self) {}
515
516    /// This is run before a connection is closed, on a per-connection basis (not a
517    /// per-request basis, as a connection with keep-alive may handle multiple
518    /// requests)
519    fn on_connection_end(&self) {}
520}
521
522impl<F> Handler for F where F: Fn(Request, Response<Fresh>), F: Sync + Send {
523    fn handle(&self, req: Request, res: Response) {
524        self(req, res)
525    }
526}
527
#[cfg(test)]
mod tests {
    use crate::header::Headers;
    use crate::method::Method;
    use crate::mock::MockStream;
    use crate::status::StatusCode;
    use crate::uri::RequestUri;

    use super::{Request, Response, Fresh, Handler, Worker};

    /// A `POST` with an `Expect: 100-continue` header and a ten byte body,
    /// shared by both tests.
    const EXPECT_CONTINUE_UPLOAD: &[u8] = b"\
        POST /upload HTTP/1.1\r\n\
        Host: example.domain\r\n\
        Expect: 100-continue\r\n\
        Content-Length: 10\r\n\
        \r\n\
        1234567890\
    ";

    /// The default `check_continue` answers `100 Continue` before the
    /// handler's own response is written.
    #[test]
    fn test_check_continue_default() {
        fn handle(_: Request, res: Response<Fresh>) {
            res.start().unwrap().end().unwrap();
        }

        let mut mock = MockStream::with_input(EXPECT_CONTINUE_UPLOAD);
        Worker::new(handle, Default::default()).handle_connection(&mut mock);

        let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
        let res = b"HTTP/1.1 200 OK\r\n";
        let (interim, rest) = mock.write.split_at(cont.len());
        assert_eq!(interim, cont);
        assert_eq!(&rest[..res.len()], res);
    }

    /// A handler rejecting the expectation makes the server write only the
    /// rejection status.
    #[test]
    fn test_check_continue_reject() {
        struct Reject;
        impl Handler for Reject {
            fn handle(&self, _: Request, res: Response<'_, Fresh>) {
                res.start().unwrap().end().unwrap();
            }

            fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
                StatusCode::ExpectationFailed
            }
        }

        let mut mock = MockStream::with_input(EXPECT_CONTINUE_UPLOAD);
        Worker::new(Reject, Default::default()).handle_connection(&mut mock);

        assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]);
    }
}