cogo_http/client/
pool.rs

1//! Client Connection Pooling
2use std::borrow::ToOwned;
3use std::collections::HashMap;
4use std::fmt;
5use std::io::{self, Read, Write};
6use std::net::{SocketAddr, Shutdown};
7use std::sync::{Arc};
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use std::time::{Duration, Instant};
11
12use crate::net::{NetworkConnector, NetworkStream, DefaultConnector};
13use crate::client::scheme::Scheme;
14use crate::runtime;
15
16use self::stale::{StaleCheck, Stale};
17
18/// The `NetworkConnector` that behaves as a connection pool used by hyper's `Client`.
19pub struct Pool<C: NetworkConnector> {
20    connector: C,
21    inner: Arc<runtime::Mutex<PoolImpl<<C as NetworkConnector>::Stream>>>,
22    stale_check: Option<StaleCallback<C::Stream>>,
23}
24
/// Config options for the `Pool`.
#[derive(Debug)]
pub struct Config {
    /// The maximum idle connections *per host*.
    pub max_idle: usize,
}

impl Default for Config {
    /// Returns the default configuration: at most 5 idle connections per host.
    #[inline]
    fn default() -> Self {
        let max_idle = 5;
        Self { max_idle }
    }
}
40
// Because `Config` has all its properties public, it would be a breaking
// change to add new ones. Sigh.
//
// `Config2` is the private settings struct the pool actually stores; it is
// filled from `Config` when the pool is created.
#[derive(Debug)]
struct Config2 {
    // How long a connection may sit idle before `lookup` ejects it.
    // `None` means idle connections never expire.
    idle_timeout: Option<Duration>,
    // Upper bound on the number of idle connections kept per key.
    max_idle: usize,
}
48
49
// The shared, lock-protected state of a `Pool`.
#[derive(Debug)]
struct PoolImpl<S> {
    // Idle connections, grouped by the (host, port, scheme) key they were
    // opened for. Connections are pushed and popped at the end of each list.
    conns: HashMap<Key, Vec<PooledStreamInner<S>>>,
    // Limits applied when connections are returned or looked up.
    config: Config2,
}
55
56type Key = (String, u16, Scheme);
57
58fn key<T: Into<Scheme>>(host: &str, port: u16, scheme: T) -> Key {
59    (host.to_owned(), port, scheme.into())
60}
61
62impl Pool<DefaultConnector> {
63    /// Creates a `Pool` with a `DefaultConnector`.
64    #[inline]
65    pub fn new(config: Config) -> Pool<DefaultConnector> {
66        Pool::with_connector(config, DefaultConnector::default())
67    }
68}
69
70impl<C: NetworkConnector> Pool<C> {
71    /// Creates a `Pool` with a specified `NetworkConnector`.
72    #[inline]
73    pub fn with_connector(config: Config, connector: C) -> Pool<C> {
74        Pool {
75            connector: connector,
76            inner: Arc::new(runtime::Mutex::new(PoolImpl {
77                conns: HashMap::new(),
78                config: Config2 {
79                    idle_timeout: None,
80                    max_idle: config.max_idle,
81                },
82            })),
83            stale_check: None,
84        }
85    }
86
87    /// Set a duration for how long an idle connection is still valid.
88    pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
89        self.inner.lock().unwrap().config.idle_timeout = timeout;
90    }
91
92    pub fn set_stale_check<F>(&mut self, callback: F)
93        where F: Fn(StaleCheck<C::Stream>) -> Stale + Send + Sync + 'static {
94        self.stale_check = Some(Box::new(callback));
95    }
96
97    /// Clear all idle connections from the Pool, closing them.
98    #[inline]
99    pub fn clear_idle(&mut self) {
100        self.inner.lock().unwrap().conns.clear();
101    }
102
103    // private
104
105    fn checkout(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
106        while let Some(mut inner) = self.lookup(key) {
107            if let Some(ref stale_check) = self.stale_check {
108                let dur = inner.idle.expect("idle is never missing inside pool").elapsed();
109                let arg = stale::check(&mut inner.stream, dur);
110                if stale_check(arg).is_stale() {
111                    trace!("ejecting stale connection");
112                    continue;
113                }
114            }
115            return Some(inner);
116        }
117        None
118    }
119
120
121    fn lookup(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
122        let mut locked = self.inner.lock().unwrap();
123        let mut should_remove = false;
124        let deadline = locked.config.idle_timeout.map(|dur| Instant::now() - dur);
125        let inner = locked.conns.get_mut(key).and_then(|vec| {
126            while let Some(inner) = vec.pop() {
127                should_remove = vec.is_empty();
128                if let Some(deadline) = deadline {
129                    if inner.idle.expect("idle is never missing inside pool") < deadline {
130                        trace!("ejecting expired connection");
131                        continue;
132                    }
133                }
134                return Some(inner);
135            }
136            None
137        });
138        if should_remove {
139            locked.conns.remove(key);
140        }
141        inner
142    }
143}
144
145impl<S> PoolImpl<S> {
146    fn reuse(&mut self, key: Key, conn: PooledStreamInner<S>) {
147        trace!("reuse {:?}", key);
148        let conns = self.conns.entry(key).or_insert(vec![]);
149        if conns.len() < self.config.max_idle {
150            conns.push(conn);
151        }
152    }
153}
154
155impl<C: NetworkConnector<Stream=S>, S: NetworkStream + Send> NetworkConnector for Pool<C> {
156    type Stream = PooledStream<S>;
157    fn connect(&self, host: &str, port: u16, scheme: &str) -> crate::Result<PooledStream<S>> {
158        let key = key(host, port, scheme);
159        let inner = match self.checkout(&key) {
160            Some(inner) => {
161                trace!("Pool had connection, using");
162                inner
163            }
164            None => PooledStreamInner {
165                key: key.clone(),
166                idle: None,
167                stream: r#try!(self.connector.connect(host, port, scheme)),
168                previous_response_expected_no_content: false,
169            }
170        };
171        Ok(PooledStream {
172            has_read: false,
173            inner: Some(inner),
174            is_closed: AtomicBool::new(false),
175            pool: self.inner.clone(),
176        })
177    }
178}
179
// Boxed callback stored by `Pool::set_stale_check`; it inspects a
// `StaleCheck` and answers with a `Stale` verdict.
type StaleCallback<S> = Box<dyn Fn(StaleCheck<S>) -> Stale + Send + Sync + 'static>;
181
182// private on purpose
183//
184// Yes, I know! Shame on me! This hurts docs! And it means it only
185// works with closures! I know!
186//
// The thing is, this is experimental. I'm not certain about the naming.
188// Or other things. So I don't really want it in the docs, yet.
189//
190// As for only working with closures, that's fine. A closure is probably
191// enough, and if it isn't, well you can grab the stream and duration and
192// pass those to a function, and then figure out whether to call stale()
193// or fresh() based on the return value.
194//
195// Point is, it's not that bad. And it's not ready to publicize.
mod stale {
    use std::time::Duration;

    /// Argument handed to a stale-check callback: mutable access to the
    /// candidate stream plus how long it has been idle.
    pub struct StaleCheck<'a, S: 'a> {
        stream: &'a mut S,
        duration: Duration,
    }

    /// Bundles `stream` and its idle time `dur` into a `StaleCheck`.
    #[inline]
    pub fn check<'a, S: 'a>(stream: &'a mut S, dur: Duration) -> StaleCheck<'a, S> {
        let duration = dur;
        StaleCheck { stream, duration }
    }

    impl<'a, S: 'a> StaleCheck<'a, S> {
        /// Mutable access to the stream under inspection.
        pub fn stream(&mut self) -> &mut S {
            &mut *self.stream
        }

        /// How long the stream has been idle.
        pub fn idle_duration(&self) -> Duration {
            self.duration
        }

        /// Consumes the check, declaring the stream stale.
        pub fn stale(self) -> Stale {
            let is_stale = true;
            Stale(is_stale)
        }

        /// Consumes the check, declaring the stream still usable.
        pub fn fresh(self) -> Stale {
            let is_stale = false;
            Stale(is_stale)
        }
    }

    /// Verdict of a stale check; `true` inside means "stale".
    pub struct Stale(bool);

    impl Stale {
        /// Unwraps the verdict into a plain `bool`.
        #[inline]
        pub fn is_stale(self) -> bool {
            let Stale(flag) = self;
            flag
        }
    }
}
240
241
/// A Stream that will try to be returned to the Pool when dropped.
///
/// A stream that has been marked closed is not returned.
pub struct PooledStream<S> {
    // Set once a read has produced at least one byte on this checkout.
    has_read: bool,
    // The wrapped connection; `None` only after it has been taken out
    // (by `into_inner` or by `drop`).
    inner: Option<PooledStreamInner<S>>,
    // mutated in &self methods
    is_closed: AtomicBool,
    // Handle to the shared pool state the connection is returned to.
    pool: Arc<runtime::Mutex<PoolImpl<S>>>,
}
250
251// manual impl to add the 'static bound for 1.7 compat
252impl<S> fmt::Debug for PooledStream<S> where S: fmt::Debug + 'static {
253    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
254        fmt.debug_struct("PooledStream")
255            .field("inner", &self.inner)
256            .field("has_read", &self.has_read)
257            .field("is_closed", &self.is_closed.load(Ordering::Relaxed))
258            .field("pool", &self.pool)
259            .finish()
260    }
261}
262
263impl<S: NetworkStream> PooledStream<S> {
264    /// Take the wrapped stream out of the pool completely.
265    pub fn into_inner(mut self) -> S {
266        self.inner.take().expect("PooledStream lost its inner stream").stream
267    }
268
269    /// Gets a borrowed reference to the underlying stream.
270    pub fn get_ref(&self) -> &S {
271        &self.inner.as_ref().expect("PooledStream lost its inner stream").stream
272    }
273
274    #[cfg(test)]
275    fn get_mut(&mut self) -> &mut S {
276        &mut self.inner.as_mut().expect("PooledStream lost its inner stream").stream
277    }
278}
279
// A connection together with the bookkeeping the pool needs for it.
#[derive(Debug)]
struct PooledStreamInner<S> {
    // The pool bucket this connection belongs to.
    key: Key,
    // When the connection was last returned to the pool; `None` for a
    // connection that has never been pooled.
    idle: Option<Instant>,
    // The wrapped network stream.
    stream: S,
    // Flag stored and read back through the `NetworkStream` accessors of
    // the same name.
    previous_response_expected_no_content: bool,
}
287
288impl<S: NetworkStream> Read for PooledStream<S> {
289    #[inline]
290    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
291        let inner = self.inner.as_mut().unwrap();
292        let n = r#try!(inner.stream.read(buf));
293        if n == 0 {
294            // if the wrapped stream returns EOF (Ok(0)), that means the
295            // server has closed the stream. we must be sure this stream
296            // is dropped and not put back into the pool.
297            self.is_closed.store(true, Ordering::Relaxed);
298
299            // if the stream has never read bytes before, then the pooled
300            // stream may have been disconnected by the server while
301            // we checked it back out
302            if !self.has_read && inner.idle.is_some() {
303                // idle being some means this is a reused stream
304                Err(io::Error::new(
305                    io::ErrorKind::ConnectionAborted,
306                    "Pooled stream disconnected",
307                ))
308            } else {
309                Ok(0)
310            }
311        } else {
312            self.has_read = true;
313            Ok(n)
314        }
315    }
316}
317
318impl<S: NetworkStream> Write for PooledStream<S> {
319    #[inline]
320    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
321        self.inner.as_mut().unwrap().stream.write(buf)
322    }
323
324    #[inline]
325    fn flush(&mut self) -> io::Result<()> {
326        self.inner.as_mut().unwrap().stream.flush()
327    }
328}
329
330impl<S: NetworkStream> NetworkStream for PooledStream<S> {
331    #[inline]
332    fn peer_addr(&mut self) -> io::Result<SocketAddr> {
333        self.inner.as_mut().unwrap().stream.peer_addr()
334            .map_err(|e| {
335                self.is_closed.store(true, Ordering::Relaxed);
336                e
337            })
338    }
339
340    #[inline]
341    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
342        self.inner.as_ref().unwrap().stream.set_read_timeout(dur)
343            .map_err(|e| {
344                self.is_closed.store(true, Ordering::Relaxed);
345                e
346            })
347    }
348
349    #[inline]
350    fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
351        self.inner.as_ref().unwrap().stream.set_write_timeout(dur)
352            .map_err(|e| {
353                self.is_closed.store(true, Ordering::Relaxed);
354                e
355            })
356    }
357
358    #[inline]
359    fn close(&mut self, how: Shutdown) -> io::Result<()> {
360        self.is_closed.store(true, Ordering::Relaxed);
361        self.inner.as_mut().unwrap().stream.close(how)
362    }
363
364    #[inline]
365    fn set_previous_response_expected_no_content(&mut self, expected: bool) {
366        trace!("set_previous_response_expected_no_content {}", expected);
367        self.inner.as_mut().unwrap().previous_response_expected_no_content = expected;
368    }
369
370    #[inline]
371    fn previous_response_expected_no_content(&self) -> bool {
372        let answer = self.inner.as_ref().unwrap().previous_response_expected_no_content;
373        trace!("previous_response_expected_no_content {}", answer);
374        answer
375    }
376
377    fn set_nonblocking(&self, b: bool) {
378        self.inner.as_ref().unwrap().stream.set_nonblocking(b);
379    }
380
381    fn reset_io(&self) {
382        self.inner.as_ref().unwrap().stream.reset_io();
383    }
384
385    fn wait_io(&self) {
386        self.inner.as_ref().unwrap().stream.wait_io();
387    }
388}
389
390impl<S> Drop for PooledStream<S> {
391    fn drop(&mut self) {
392        let is_closed = self.is_closed.load(Ordering::Relaxed);
393        trace!("PooledStream.drop, is_closed={}", is_closed);
394        if !is_closed {
395            self.inner.take().map(|mut inner| {
396                let now = Instant::now();
397                inner.idle = Some(now);
398                if let Ok(mut pool) = self.pool.lock() {
399                    pool.reuse(inner.key.clone(), inner);
400                }
401                // else poisoned, give up
402            });
403        }
404    }
405}
406
#[cfg(test)]
mod tests {
    use std::io::Read;
    use std::net::Shutdown;
    use std::time::Duration;

    use crate::mock::MockConnector;
    use crate::net::{NetworkConnector, NetworkStream};

    use super::{key, Pool};

    // Endpoint shared by every test below.
    const HOST: &str = "127.0.0.1";
    const PORT: u16 = 3000;
    const SCHEME: &str = "http";

    /// Builds a pool over `MockConnector` with the default configuration.
    macro_rules! mocked {
        () => {
            Pool::with_connector(Default::default(), MockConnector)
        };
    }

    /// A dropped stream lands in the pool and is handed out again.
    #[test]
    fn test_connect_and_drop() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let bucket = key(HOST, PORT, SCHEME);

        let mut first = pool.connect(HOST, PORT, SCHEME).unwrap();
        assert_eq!(first.get_ref().id, 0);
        // Tag the mock stream so that reuse can be recognised later.
        first.get_mut().id = 9;
        drop(first);
        {
            let state = pool.inner.lock().unwrap();
            assert_eq!(state.conns.len(), 1);
            assert_eq!(state.conns.get(&bucket).unwrap().len(), 1);
        }

        let second = pool.connect(HOST, PORT, SCHEME).unwrap(); //reused
        assert_eq!(second.get_ref().id, 9);
        drop(second);
        {
            let state = pool.inner.lock().unwrap();
            assert_eq!(state.conns.len(), 1);
            assert_eq!(state.conns.get(&bucket).unwrap().len(), 1);
        }
    }

    /// With two idle connections, checking one out leaves the other pooled.
    #[test]
    fn test_double_connect_reuse() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let bucket = key(HOST, PORT, SCHEME);

        let a = pool.connect(HOST, PORT, SCHEME).unwrap();
        let b = pool.connect(HOST, PORT, SCHEME).unwrap();
        drop(a);
        drop(b);

        let checked_out = pool.connect(HOST, PORT, SCHEME).unwrap();
        {
            let state = pool.inner.lock().unwrap();
            assert_eq!(state.conns.len(), 1);
            assert_eq!(state.conns.get(&bucket).unwrap().len(), 1);
        }
        // Kept alive until here so the assertions above see it checked out.
        drop(checked_out);
    }

    /// An explicitly closed stream is not returned to the pool.
    #[test]
    fn test_closed() {
        let pool = mocked!();
        let mut stream = pool.connect(HOST, PORT, SCHEME).unwrap();
        stream.close(Shutdown::Both).unwrap();
        drop(stream);

        let state = pool.inner.lock().unwrap();
        assert_eq!(state.conns.len(), 0);
    }

    /// Reading EOF on a fresh stream marks it closed.
    #[test]
    fn test_eof_closes() {
        let pool = mocked!();

        let mut stream = pool.connect(HOST, PORT, SCHEME).unwrap();
        let mut byte = [0];
        assert_eq!(stream.read(&mut byte).unwrap(), 0);
        drop(stream);

        let state = pool.inner.lock().unwrap();
        assert_eq!(state.conns.len(), 0);
    }

    /// EOF as the first read on a reused stream surfaces as `ConnectionAborted`.
    #[test]
    fn test_read_conn_aborted() {
        let pool = mocked!();

        // Connect and drop right away so the pool holds one idle stream.
        drop(pool.connect(HOST, PORT, SCHEME).unwrap());
        let mut reused = pool.connect(HOST, PORT, SCHEME).unwrap();
        let mut byte = [0];
        let err = reused.read(&mut byte).unwrap_err();
        assert_eq!(err.kind(), ::std::io::ErrorKind::ConnectionAborted);
        drop(reused);

        let state = pool.inner.lock().unwrap();
        assert_eq!(state.conns.len(), 0);
    }

    /// A connection idle for longer than the timeout is not reused.
    #[test]
    fn test_idle_timeout() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(10)));

        let mut stream = pool.connect(HOST, PORT, SCHEME).unwrap();
        assert_eq!(stream.get_ref().id, 0);
        stream.get_mut().id = 1337;
        drop(stream);

        ::std::thread::sleep(Duration::from_millis(100));

        // A fresh mock stream (id 0) shows the tagged one was ejected.
        let fresh = pool.connect(HOST, PORT, SCHEME).unwrap();
        assert_eq!(fresh.get_ref().id, 0);
    }
}