hyper_http_connector/
lib.rs

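//! A copy of the default `HttpConnector` that ships with `hyper`, modified to
//! resolve host names with c-ares A-record queries, cache the answers for their
//! DNS TTL, and rotate round-robin through the resolved addresses.
//!
//! Keeping a local copy makes it easy to make further modifications.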
4#[macro_use]
5extern crate log;
6#[macro_use]
7extern crate futures;
8#[macro_use]
9extern crate lazy_static;
10
11extern crate antidote;
12extern crate c_ares;
13extern crate c_ares_resolver;
14extern crate futures_cpupool;
15extern crate http;
16extern crate hyper;
17extern crate net2;
18extern crate tokio;
19extern crate tokio_reactor;
20extern crate tokio_tcp;
21
22use c_ares::AResults;
23use c_ares_resolver::CAresFuture;
24use std::net::Ipv4Addr;
25
26use std::borrow::Cow;
27use std::collections::HashMap;
28use std::error::Error as StdError;
29use std::fmt;
30use std::io;
31use std::net::{IpAddr, SocketAddr};
32use std::sync::Arc;
33use std::time::Duration;
34
35use antidote::{Mutex, RwLock};
36use futures::future::{ExecuteError, Executor};
37use futures::sync::oneshot;
38use futures::{Async, Future, Poll};
39use futures_cpupool::Builder as CpuPoolBuilder;
40use http::uri::Scheme;
41use hyper::client::connect::{Connect, Connected, Destination};
42use net2::TcpBuilder;
43use tokio_reactor::Handle;
44use tokio_tcp::{ConnectFuture, TcpStream};
45
46mod dns;
47mod timed_cache;
48
49use self::dns::GLOBAL_RESOLVER;
50use self::http_connector::HttpConnectorBlockingTask;
51use self::timed_cache::TimedCache;
52
53fn connect(
54    addr: &SocketAddr,
55    local_addr: &Option<IpAddr>,
56    handle: &Option<Handle>,
57) -> io::Result<ConnectFuture> {
58    let builder = match addr {
59        &SocketAddr::V4(_) => TcpBuilder::new_v4()?,
60        &SocketAddr::V6(_) => TcpBuilder::new_v6()?,
61    };
62
63    if let Some(ref local_addr) = *local_addr {
64        // Caller has requested this socket be bound before calling connect
65        builder.bind(SocketAddr::new(local_addr.clone(), 0))?;
66    } else if cfg!(windows) {
67        // Windows requires a socket be bound before calling connect
68        let any: SocketAddr = match addr {
69            &SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
70            &SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
71        };
72        builder.bind(any)?;
73    }
74
75    let handle = match *handle {
76        Some(ref handle) => Cow::Borrowed(handle),
77        None => Cow::Owned(Handle::current()),
78    };
79
80    Ok(TcpStream::connect_std(
81        builder.to_tcp_stream()?,
82        addr,
83        &handle,
84    ))
85}
86
/// Per-host cache of resolved IPv4 addresses.
///
/// Entries are written after an A-record lookup with the smallest record TTL
/// as their duration (see `TimedCache` for the expiry semantics).
type ResultCache = TimedCache<Arc<String>, Vec<Ipv4Addr>>;
/// Per-host counters used to choose which resolved address is tried first,
/// so repeated connections to the same host rotate through its addresses.
struct RoundRobinMap(HashMap<Arc<String>, usize>);

impl RoundRobinMap {
    /// Creates an empty map: every host starts with no counter.
    fn new() -> RoundRobinMap {
        RoundRobinMap(HashMap::new())
    }

    /// Advances and returns the rotation counter for `host`.
    ///
    /// The first call for a host stores and returns 0. Each later call
    /// increments the stored counter (wrapping at `usize::MAX`) and returns
    /// the new value.
    fn get_and_incr(&mut self, host: Arc<String>) -> usize {
        // `host` is already owned, so it can be moved into the entry directly.
        *self
            .0
            .entry(host)
            .and_modify(|counter| *counter = counter.wrapping_add(1))
            .or_insert(0)
    }
}
103
104/// A connector for the `http` scheme.
105///
106/// Performs DNS resolution in a thread pool, and then connects over TCP.
107#[derive(Clone)]
108pub struct HttpConnector {
109    executor: HttpConnectExecutor,
110    enforce_http: bool,
111    handle: Option<Handle>,
112    keep_alive_timeout: Option<Duration>,
113    nodelay: bool,
114    local_address: Option<IpAddr>,
115    round_robin_map: Arc<Mutex<RoundRobinMap>>,
116    result_cache: Arc<RwLock<ResultCache>>,
117}
118
119impl HttpConnector {
120    /// Construct a new HttpConnector.
121    ///
122    /// Takes number of DNS worker threads.
123    #[inline]
124    pub fn new(threads: usize) -> HttpConnector {
125        HttpConnector::new_with_handle_opt(threads, None)
126    }
127
128    /// Construct a new HttpConnector with a specific Tokio handle.
129    pub fn new_with_handle(threads: usize, handle: Handle) -> HttpConnector {
130        HttpConnector::new_with_handle_opt(threads, Some(handle))
131    }
132
133    fn new_with_handle_opt(threads: usize, handle: Option<Handle>) -> HttpConnector {
134        let pool = CpuPoolBuilder::new()
135            .name_prefix("hyper-dns")
136            .pool_size(threads)
137            .create();
138        HttpConnector::new_with_executor(pool, handle)
139    }
140
141    /// Construct a new HttpConnector.
142    ///
143    /// Takes an executor to run blocking tasks on.
144    pub fn new_with_executor<E: 'static>(executor: E, handle: Option<Handle>) -> HttpConnector
145    where
146        E: Executor<HttpConnectorBlockingTask> + Send + Sync,
147    {
148        HttpConnector {
149            executor: HttpConnectExecutor(Arc::new(executor)),
150            enforce_http: true,
151            handle,
152            keep_alive_timeout: None,
153            nodelay: false,
154            local_address: None,
155            round_robin_map: Arc::new(Mutex::new(RoundRobinMap::new())),
156            result_cache: Arc::new(RwLock::new(TimedCache::new())),
157        }
158    }
159
160    /// Option to enforce all `Uri`s have the `http` scheme.
161    ///
162    /// Enabled by default.
163    #[inline]
164    pub fn enforce_http(&mut self, is_enforced: bool) {
165        self.enforce_http = is_enforced;
166    }
167
168    /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
169    ///
170    /// If `None`, the option will not be set.
171    ///
172    /// Default is `None`.
173    #[inline]
174    pub fn set_keepalive(&mut self, dur: Option<Duration>) {
175        self.keep_alive_timeout = dur;
176    }
177
178    /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
179    ///
180    /// Default is `false`.
181    #[inline]
182    pub fn set_nodelay(&mut self, nodelay: bool) {
183        self.nodelay = nodelay;
184    }
185
186    /// Set that all sockets are bound to the configured address before connection.
187    ///
188    /// If `None`, the sockets will not be bound.
189    ///
190    /// Default is `None`.
191    #[inline]
192    pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
193        self.local_address = addr;
194    }
195}
196
197impl fmt::Debug for HttpConnector {
198    #[inline]
199    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
200        f.debug_struct("HttpConnector").finish()
201    }
202}
203
204impl Connect for HttpConnector {
205    type Transport = TcpStream;
206    type Error = io::Error;
207    type Future = HttpConnecting;
208
209    fn connect(&self, dst: Destination) -> Self::Future {
210        let scheme = dst.scheme();
211        let host = dst.host();
212        let port = dst.port();
213        trace!(
214            "Http::connect; scheme={}, host={}, port={:?}",
215            scheme,
216            host,
217            port,
218        );
219
220        if self.enforce_http {
221            if scheme != &Scheme::HTTP {
222                return invalid_url(InvalidUrl::NotHttp, &self.handle);
223            }
224        } else if scheme.is_empty() {
225            return invalid_url(InvalidUrl::MissingScheme, &self.handle);
226        }
227
228        if host.is_empty() {
229            return invalid_url(InvalidUrl::MissingAuthority, &self.handle);
230        }
231
232        let port = match port {
233            Some(port) => port,
234            None => if scheme == &Scheme::HTTPS {
235                443
236            } else {
237                80
238            },
239        };
240
241        let host = Arc::new(host.into());
242
243        HttpConnecting {
244            state: State::Lazy(
245                self.executor.clone(),
246                host,
247                port,
248                self.local_address,
249                Arc::clone(&self.round_robin_map),
250                Arc::clone(&self.result_cache),
251            ),
252            handle: self.handle.clone(),
253            keep_alive_timeout: self.keep_alive_timeout,
254            nodelay: self.nodelay,
255        }
256    }
257}
258
259#[inline]
260fn invalid_url(err: InvalidUrl, handle: &Option<Handle>) -> HttpConnecting {
261    HttpConnecting {
262        state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))),
263        handle: handle.clone(),
264        keep_alive_timeout: None,
265        nodelay: false,
266    }
267}
268
/// Reasons a destination is rejected before any connection attempt.
#[derive(Debug, Clone, Copy)]
enum InvalidUrl {
    /// Scheme enforcement is off and the destination has no scheme.
    MissingScheme,
    /// Scheme enforcement is on and the scheme is not `http`.
    NotHttp,
    /// The destination has no host.
    MissingAuthority,
}

impl InvalidUrl {
    /// Human-readable message for this error.
    ///
    /// Shared by `Display` and `Error::description`, so `Display` does not
    /// depend on the deprecated `description` method.
    fn as_str(&self) -> &'static str {
        match *self {
            InvalidUrl::MissingScheme => "invalid URL, missing scheme",
            InvalidUrl::NotHttp => "invalid URL, scheme must be http",
            InvalidUrl::MissingAuthority => "invalid URL, missing domain",
        }
    }
}

impl fmt::Display for InvalidUrl {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.as_str())
    }
}

impl StdError for InvalidUrl {
    // Kept so existing callers of `description()` still get the message.
    fn description(&self) -> &str {
        self.as_str()
    }
}
291/// A Future representing work to connect to a URL.
292#[must_use = "futures do nothing unless polled"]
293pub struct HttpConnecting {
294    state: State,
295    handle: Option<Handle>,
296    keep_alive_timeout: Option<Duration>,
297    nodelay: bool,
298}
299
/// Steps of the `HttpConnecting` state machine.
enum State {
    /// Not started. On first poll the host is parsed as an IP, looked up in
    /// the result cache, or handed to the resolver.
    Lazy(
        HttpConnectExecutor,       // runs the DNS task if a lookup is needed
        Arc<String>,               // host
        u16,                       // port
        Option<IpAddr>,            // optional local bind address
        Arc<Mutex<RoundRobinMap>>, // shared rotation counters
        Arc<RwLock<ResultCache>>,  // shared resolved-address cache
    ),
    /// Waiting for the spawned A-record query to finish.
    Resolving(
        oneshot::SpawnHandle<AResults, c_ares::Error>, // handle to the spawned query
        Arc<String>,                                   // host
        u16,                                           // port
        Option<IpAddr>,                                // optional local bind address
        Arc<Mutex<RoundRobinMap>>,                     // shared rotation counters
        Arc<RwLock<ResultCache>>,                      // shared resolved-address cache
    ),
    /// Trying candidate addresses one after another.
    Connecting(ConnectingTcp),
    /// Failed before connecting. The error is taken (leaving `None`) the
    /// first time it is polled.
    Error(Option<io::Error>),
}
320
321impl Future for HttpConnecting {
322    type Item = (TcpStream, Connected);
323    type Error = io::Error;
324
325    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
326        loop {
327            let state = match self.state {
328                State::Lazy(
329                    ref executor,
330                    ref host,
331                    port,
332                    local_addr,
333                    ref round_robin_map,
334                    ref result_cache,
335                ) => {
336                    // If the host is already an IP addr (v4 or v6),
337                    // skip resolving the dns and start connecting right away.
338                    if let Some(addrs) = dns::IpAddrs::try_parse(host, port) {
339                        State::Connecting(ConnectingTcp {
340                            addrs,
341                            local_addr,
342                            current: None,
343                        })
344                    } else {
345                        if let Some(ip_addrs) = result_cache.read().get(host) {
346                            trace!("ResultCache - got cached ips!");
347                            let shift_index =
348                                round_robin_map.lock().get_and_incr(Arc::clone(&host));
349                            State::Connecting(ConnectingTcp {
350                                addrs: dns::IpAddrs::new(port, ip_addrs.clone(), shift_index),
351                                local_addr,
352                                current: None,
353                            })
354                        } else {
355                            trace!("ResultCache - no cached ips!");
356                            let work = GLOBAL_RESOLVER.query_a(host);
357                            State::Resolving(
358                                oneshot::spawn(work, executor),
359                                Arc::clone(host),
360                                port,
361                                local_addr,
362                                Arc::clone(round_robin_map),
363                                Arc::clone(result_cache),
364                            )
365                        }
366                    }
367                }
368                State::Resolving(
369                    ref mut future,
370                    ref host,
371                    port,
372                    local_addr,
373                    ref round_robin_map,
374                    ref result_cache,
375                ) => match future
376                    .poll()
377                    .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
378                {
379                    Async::NotReady => return Ok(Async::NotReady),
380                    Async::Ready(a_results) => {
381                        let min_ttl = a_results.iter().map(|res| res.ttl()).min();
382                        let ips = a_results.into_iter().map(|res| res.ipv4()).collect::<Vec<_>>();
383                        let shift_index = round_robin_map.lock().get_and_incr(Arc::clone(&host));
384
385                        trace!("ResultCache - putting in the cache for {}!", host);
386                        // min_ttl will be None if no ip records were found
387                        if let Some(min_ttl) = min_ttl {
388                            result_cache.write().set(Arc::clone(&host), ips.clone(), Duration::from_secs(min_ttl as u64));
389                        }
390                        State::Connecting(ConnectingTcp {
391                            addrs: dns::IpAddrs::new(port, ips, shift_index),
392                            local_addr,
393                            current: None,
394                        })
395                    }
396                },
397                State::Connecting(ref mut c) => {
398                    let sock = try_ready!(c.poll(&self.handle));
399
400                    if let Some(dur) = self.keep_alive_timeout {
401                        sock.set_keepalive(Some(dur))?;
402                    }
403
404                    sock.set_nodelay(self.nodelay)?;
405
406                    return Ok(Async::Ready((sock, Connected::new())));
407                }
408                State::Error(ref mut e) => return Err(e.take().expect("polled more than once")),
409            };
410            self.state = state;
411        }
412    }
413}
414
415impl fmt::Debug for HttpConnecting {
416    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
417        f.pad("HttpConnecting")
418    }
419}
420
421struct ConnectingTcp {
422    addrs: dns::IpAddrs,
423    local_addr: Option<IpAddr>,
424    current: Option<ConnectFuture>,
425}
426
427impl ConnectingTcp {
428    // not a Future, since passing a &Handle to poll
429    fn poll(&mut self, handle: &Option<Handle>) -> Poll<TcpStream, io::Error> {
430        let mut err = None;
431        loop {
432            if let Some(ref mut current) = self.current {
433                match current.poll() {
434                    Ok(ok) => return Ok(ok),
435                    Err(e) => {
436                        trace!("connect error {:?}", e);
437                        err = Some(e);
438                        if let Some(addr) = self.addrs.next() {
439                            debug!("connecting to {}", addr);
440                            *current = connect(&addr, &self.local_addr, handle)?;
441                            continue;
442                        }
443                    }
444                }
445            } else if let Some(addr) = self.addrs.next() {
446                debug!("connecting to {}", addr);
447                self.current = Some(connect(&addr, &self.local_addr, handle)?);
448                continue;
449            }
450
451            return Err(err.take().expect("missing connect error"));
452        }
453    }
454}
455
456// Make this Future unnameable outside of this crate.
457mod http_connector {
458    use super::*;
459    // Blocking task to be executed on a thread pool.
460    pub struct HttpConnectorBlockingTask {
461        pub(super) work: oneshot::Execute<CAresFuture<AResults>>,
462    }
463
464    impl fmt::Debug for HttpConnectorBlockingTask {
465        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
466            f.pad("HttpConnectorBlockingTask")
467        }
468    }
469
470    impl Future for HttpConnectorBlockingTask {
471        type Item = ();
472        type Error = ();
473
474        fn poll(&mut self) -> Poll<(), ()> {
475            self.work.poll()
476        }
477    }
478}
479
480#[derive(Clone)]
481struct HttpConnectExecutor(Arc<Executor<HttpConnectorBlockingTask> + Send + Sync>);
482
483impl Executor<oneshot::Execute<CAresFuture<AResults>>> for HttpConnectExecutor {
484    fn execute(
485        &self,
486        future: oneshot::Execute<CAresFuture<AResults>>,
487    ) -> Result<(), ExecuteError<oneshot::Execute<CAresFuture<AResults>>>> {
488        self.0
489            .execute(HttpConnectorBlockingTask { work: future })
490            .map_err(|err| ExecuteError::new(err.kind(), err.into_future().work))
491    }
492}
493
// These tests are disabled: they construct `Destination` through its private
// `uri` field, which is not accessible outside of `hyper`.
496/*
497#[cfg(test)]
498mod tests {
499    use std::io;
500    use futures::Future;
501    use super::{Connect, Destination, HttpConnector};
502
503    #[test]
504    fn test_errors_missing_authority() {
505        let uri = "/foo/bar?baz".parse().unwrap();
506        let dst = Destination {
507            uri,
508        };
509        let connector = HttpConnector::new(1);
510
511        assert_eq!(connector.connect(dst).wait().unwrap_err().kind(), io::ErrorKind::InvalidInput);
512    }
513
514    #[test]
515    fn test_errors_enforce_http() {
516        let uri = "https://example.domain/foo/bar?baz".parse().unwrap();
517        let dst = Destination {
518            uri,
519        };
520        let connector = HttpConnector::new(1);
521
522        assert_eq!(connector.connect(dst).wait().unwrap_err().kind(), io::ErrorKind::InvalidInput);
523    }
524
525
526    #[test]
527    fn test_errors_missing_scheme() {
528        let uri = "example.domain".parse().unwrap();
529        let dst = Destination {
530            uri,
531        };
532        let connector = HttpConnector::new(1);
533
534        assert_eq!(connector.connect(dst).wait().unwrap_err().kind(), io::ErrorKind::InvalidInput);
535    }
536}
537*/