mco_http/client/
pool.rs

1//! Client Connection Pooling
2use std::borrow::ToOwned;
3use std::collections::HashMap;
4use std::fmt;
5use std::io::{self, Read, Write};
6use std::net::{SocketAddr, Shutdown};
7use std::sync::{Arc, Mutex};
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use std::time::{Duration, Instant};
11
12use crate::net::{NetworkConnector, NetworkStream, DefaultConnector};
13use crate::client::scheme::Scheme;
14
15use self::stale::{StaleCheck, Stale};
16
17/// The `NetworkConnector` that behaves as a connection pool used by mco_http's `Client`.
18pub struct Pool<C: NetworkConnector> {
19    connector: C,
20    inner: Arc<Mutex<PoolImpl<<C as NetworkConnector>::Stream>>>,
21    stale_check: Option<StaleCallback<C::Stream>>,
22}
23
/// Config options for the `Pool`.
#[derive(Debug)]
pub struct Config {
    /// The maximum idle connections *per host* (more precisely, per
    /// `(host, port, scheme)` pool key).
    pub max_idle: usize,
}

impl Default for Config {
    /// Keeps at most 5 idle connections per host.
    #[inline]
    fn default() -> Self {
        Self { max_idle: 5 }
    }
}
39
// Internal, extended form of `Config`.
//
// Because `Config` has all its properties public, it would be a breaking
// change to add new ones. Sigh. Settings added later (such as the idle
// timeout) therefore live here; `Pool::with_connector` copies
// `Config::max_idle` across and starts with no idle timeout.
#[derive(Debug)]
struct Config2 {
    // How long a pooled connection may sit idle before a lookup discards it
    // instead of reusing it; `None` disables the expiry check.
    idle_timeout: Option<Duration>,
    // Maximum number of idle connections retained per pool key.
    max_idle: usize,
}
47
48
// Shared state behind a `Pool`: the idle connections and the pool settings.
#[derive(Debug)]
struct PoolImpl<S> {
    // Idle connections grouped by `(host, port, scheme)`. Returned
    // connections are pushed on the end and lookups pop from the end, so the
    // most recently returned connection is reused first.
    conns: HashMap<Key, Vec<PooledStreamInner<S>>>,
    config: Config2,
}
54
55type Key = (String, u16, Scheme);
56
57fn key<T: Into<Scheme>>(host: &str, port: u16, scheme: T) -> Key {
58    (host.to_owned(), port, scheme.into())
59}
60
61impl Pool<DefaultConnector> {
62    /// Creates a `Pool` with a `DefaultConnector`.
63    #[inline]
64    pub fn new(config: Config) -> Pool<DefaultConnector> {
65        Pool::with_connector(config, DefaultConnector::default())
66    }
67}
68
69impl<C: NetworkConnector> Pool<C> {
70    /// Creates a `Pool` with a specified `NetworkConnector`.
71    #[inline]
72    pub fn with_connector(config: Config, connector: C) -> Pool<C> {
73        Pool {
74            connector: connector,
75            inner: Arc::new(Mutex::new(PoolImpl {
76                conns: HashMap::new(),
77                config: Config2 {
78                    idle_timeout: None,
79                    max_idle: config.max_idle,
80                }
81            })),
82            stale_check: None,
83        }
84    }
85
86    /// Set a duration for how long an idle connection is still valid.
87    pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
88        self.inner.lock().unwrap().config.idle_timeout = timeout;
89    }
90
91    pub fn set_stale_check<F>(&mut self, callback: F)
92    where F: Fn(StaleCheck<C::Stream>) -> Stale + Send + Sync + 'static {
93        self.stale_check = Some(Box::new(callback));
94    }
95
96    /// Clear all idle connections from the Pool, closing them.
97    #[inline]
98    pub fn clear_idle(&mut self) {
99        self.inner.lock().unwrap().conns.clear();
100    }
101
102    // private
103
104    fn checkout(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
105        while let Some(mut inner) = self.lookup(key) {
106            if let Some(ref stale_check) = self.stale_check {
107                let dur = inner.idle.expect("idle is never missing inside pool").elapsed();
108                let arg = stale::check(&mut inner.stream, dur);
109                if stale_check(arg).is_stale() {
110                    trace!("ejecting stale connection");
111                    continue;
112                }
113            }
114            return Some(inner);
115        }
116        None
117    }
118
119
120    fn lookup(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
121        let mut locked = self.inner.lock().unwrap();
122        let mut should_remove = false;
123        let deadline = locked.config.idle_timeout.map(|dur| Instant::now() - dur);
124        let inner = locked.conns.get_mut(key).and_then(|vec| {
125            while let Some(inner) = vec.pop() {
126                should_remove = vec.is_empty();
127                if let Some(deadline) = deadline {
128                    if inner.idle.expect("idle is never missing inside pool") < deadline {
129                        trace!("ejecting expired connection");
130                        continue;
131                    }
132                }
133                return Some(inner);
134            }
135            None
136        });
137        if should_remove {
138            locked.conns.remove(key);
139        }
140        inner
141    }
142}
143
144impl<S> PoolImpl<S> {
145    fn reuse(&mut self, key: Key, conn: PooledStreamInner<S>) {
146        trace!("reuse {:?}", key);
147        let conns = self.conns.entry(key).or_insert(vec![]);
148        if conns.len() < self.config.max_idle {
149            conns.push(conn);
150        }
151    }
152}
153
154impl<C: NetworkConnector<Stream=S>, S: NetworkStream + Send> NetworkConnector for Pool<C> {
155    type Stream = PooledStream<S>;
156    fn connect(&self, host: &str, port: u16, scheme: &str) -> crate::Result<PooledStream<S>> {
157        let key = key(host, port, scheme);
158        let inner = match self.checkout(&key) {
159            Some(inner) => {
160                trace!("Pool had connection, using");
161                inner
162            },
163            None => PooledStreamInner {
164                key: key.clone(),
165                idle: None,
166                stream: self.connector.connect(host, port, scheme)?,
167                previous_response_expected_no_content: false,
168            }
169
170        };
171        Ok(PooledStream {
172            has_read: false,
173            inner: Some(inner),
174            is_closed: AtomicBool::new(false),
175            pool: self.inner.clone(),
176        })
177    }
178}
179
// Boxed form of the callback installed by `Pool::set_stale_check`. It must be
// `Send + Sync + 'static` because it is stored inside the `Pool`.
type StaleCallback<S> = Box<dyn Fn(StaleCheck<S>) -> Stale + Send + Sync + 'static>;
181
182// private on purpose
183//
184// Yes, I know! Shame on me! This hurts docs! And it means it only
185// works with closures! I know!
186//
// The thing is, this is experimental. I'm not certain about the naming.
188// Or other things. So I don't really want it in the docs, yet.
189//
190// As for only working with closures, that's fine. A closure is probably
191// enough, and if it isn't, well you can grab the stream and duration and
192// pass those to a function, and then figure out whether to call stale()
193// or fresh() based on the return value.
194//
195// Point is, it's not that bad. And it's not ready to publicize.
mod stale {
    use std::time::Duration;

    /// Argument given to a stale-check callback. It provides mutable access
    /// to the idle stream and how long that stream sat in the pool. Consume
    /// it with `stale()` or `fresh()` to produce the verdict.
    pub struct StaleCheck<'a, S: 'a> {
        stream: &'a mut S,
        duration: Duration,
    }

    /// Wraps `stream` and its idle time `dur` for the callback.
    #[inline]
    pub fn check<'a, S: 'a>(stream: &'a mut S, dur: Duration) -> StaleCheck<'a, S> {
        StaleCheck { stream, duration: dur }
    }

    impl<'a, S: 'a> StaleCheck<'a, S> {
        /// Mutable access to the pooled stream under inspection.
        pub fn stream(&mut self) -> &mut S {
            &mut *self.stream
        }

        /// How long the stream has been idle in the pool.
        pub fn idle_duration(&self) -> Duration {
            self.duration
        }

        /// Verdict: the connection is stale and must be discarded.
        pub fn stale(self) -> Stale {
            Stale(true)
        }

        /// Verdict: the connection may be reused.
        pub fn fresh(self) -> Stale {
            Stale(false)
        }
    }

    /// Verdict returned by a stale-check callback; the flag is `true` when
    /// the connection is stale.
    pub struct Stale(bool);

    impl Stale {
        #[inline]
        pub fn is_stale(self) -> bool {
            let Stale(flag) = self;
            flag
        }
    }
}
240
241
/// A Stream that will try to be returned to the Pool when dropped.
///
/// It is not returned in three cases:
///
/// * it was marked closed: by an explicit `close`, by EOF on read, or by an
///   error from `peer_addr`, `set_read_timeout` or `set_write_timeout`;
/// * `into_inner` took the stream out;
/// * the pool's mutex is poisoned.
pub struct PooledStream<S> {
    // Set once any read has returned data. Used to tell a reused connection
    // that died while idle apart from a normal end of stream.
    has_read: bool,
    // `Some` until `into_inner` or `drop` takes it.
    inner: Option<PooledStreamInner<S>>,
    // mutated in &self methods
    is_closed: AtomicBool,
    // The pool this stream is returned to on drop.
    pool: Arc<Mutex<PoolImpl<S>>>,
}
250
251// manual impl to add the 'static bound for 1.7 compat
252impl<S> fmt::Debug for PooledStream<S> where S: fmt::Debug + 'static {
253    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
254        fmt.debug_struct("PooledStream")
255           .field("inner", &self.inner)
256           .field("has_read", &self.has_read)
257           .field("is_closed", &self.is_closed.load(Ordering::Relaxed))
258           .field("pool", &self.pool)
259           .finish()
260    }
261}
262
263impl<S: NetworkStream> PooledStream<S> {
264    /// Take the wrapped stream out of the pool completely.
265    pub fn into_inner(mut self) -> S {
266        self.inner.take().expect("PooledStream lost its inner stream").stream
267    }
268
269    /// Gets a borrowed reference to the underlying stream.
270    pub fn get_ref(&self) -> &S {
271        &self.inner.as_ref().expect("PooledStream lost its inner stream").stream
272    }
273
274    #[cfg(test)]
275    fn get_mut(&mut self) -> &mut S {
276        &mut self.inner.as_mut().expect("PooledStream lost its inner stream").stream
277    }
278}
279
// The pooled connection itself, plus the bookkeeping that travels with it
// between the pool and a `PooledStream`.
#[derive(Debug)]
struct PooledStreamInner<S> {
    // Pool key this connection is filed under when it is returned.
    key: Key,
    // When the connection was last returned to the pool. `None` for a
    // freshly opened connection that has never been pooled.
    idle: Option<Instant>,
    stream: S,
    // Backs `NetworkStream::previous_response_expected_no_content`. It starts
    // `false` for a new connection and is carried along when the connection
    // is pooled and reused.
    previous_response_expected_no_content: bool,
}
287
288impl<S: NetworkStream> Read for PooledStream<S> {
289    #[inline]
290    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
291        let inner = self.inner.as_mut().unwrap();
292        let n = inner.stream.read(buf)?;
293        if n == 0 {
294            // if the wrapped stream returns EOF (Ok(0)), that means the
295            // server has closed the stream. we must be sure this stream
296            // is dropped and not put back into the pool.
297            self.is_closed.store(true, Ordering::Relaxed);
298
299            // if the stream has never read bytes before, then the pooled
300            // stream may have been disconnected by the server while
301            // we checked it back out
302            if !self.has_read && inner.idle.is_some() {
303                // idle being some means this is a reused stream
304                Err(io::Error::new(
305                    io::ErrorKind::ConnectionAborted,
306                    "Pooled stream disconnected"
307                ))
308            } else {
309                Ok(0)
310            }
311        } else {
312            self.has_read = true;
313            Ok(n)
314        }
315    }
316}
317
318impl<S: NetworkStream> Write for PooledStream<S> {
319    #[inline]
320    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
321        self.inner.as_mut().unwrap().stream.write(buf)
322    }
323
324    #[inline]
325    fn flush(&mut self) -> io::Result<()> {
326        self.inner.as_mut().unwrap().stream.flush()
327    }
328}
329
330impl<S: NetworkStream> NetworkStream for PooledStream<S> {
331    #[inline]
332    fn peer_addr(&mut self) -> io::Result<SocketAddr> {
333        self.inner.as_mut().unwrap().stream.peer_addr()
334            .map_err(|e| {
335                self.is_closed.store(true, Ordering::Relaxed);
336                e
337            })
338    }
339
340    #[inline]
341    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
342        self.inner.as_ref().unwrap().stream.set_read_timeout(dur)
343            .map_err(|e| {
344                self.is_closed.store(true, Ordering::Relaxed);
345                e
346            })
347    }
348
349    #[inline]
350    fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
351        self.inner.as_ref().unwrap().stream.set_write_timeout(dur)
352            .map_err(|e| {
353                self.is_closed.store(true, Ordering::Relaxed);
354                e
355            })
356    }
357
358    #[inline]
359    fn close(&mut self, how: Shutdown) -> io::Result<()> {
360        self.is_closed.store(true, Ordering::Relaxed);
361        self.inner.as_mut().unwrap().stream.close(how)
362    }
363
364    #[inline]
365    fn set_previous_response_expected_no_content(&mut self, expected: bool) {
366        trace!("set_previous_response_expected_no_content {}", expected);
367        self.inner.as_mut().unwrap().previous_response_expected_no_content = expected;
368    }
369
370    #[inline]
371    fn previous_response_expected_no_content(&self) -> bool {
372        let answer = self.inner.as_ref().unwrap().previous_response_expected_no_content;
373        trace!("previous_response_expected_no_content {}", answer);
374        answer
375    }
376}
377
378impl<S> Drop for PooledStream<S> {
379    fn drop(&mut self) {
380        let is_closed = self.is_closed.load(Ordering::Relaxed);
381        trace!("PooledStream.drop, is_closed={}", is_closed);
382        if !is_closed {
383            self.inner.take().map(|mut inner| {
384                let now = Instant::now();
385                inner.idle = Some(now);
386                if let Ok(mut pool) = self.pool.lock() {
387                    pool.reuse(inner.key.clone(), inner);
388                }
389                // else poisoned, give up
390            });
391        }
392    }
393}
394
#[cfg(test)]
mod tests {
    use std::net::Shutdown;
    use std::io::Read;
    use std::time::Duration;
    use crate::mock::{MockConnector};
    use crate::net::{NetworkConnector, NetworkStream};

    use super::{Pool, key};

    // Builds a `Pool` with the default `Config` that opens new connections
    // through `MockConnector`.
    macro_rules! mocked {
        () => ({
            Pool::with_connector(Default::default(), MockConnector)
        })
    }

    // A dropped stream goes back into the pool and is handed out again by the
    // next `connect` to the same host/port/scheme. The `id` stamped on the
    // mock stream proves it is the same connection. After each drop the pool
    // holds exactly one idle connection for the key.
    #[test]
    fn test_connect_and_drop() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let key = key("127.0.0.1", 3000, "http");
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
        stream.get_mut().id = 9;
        drop(stream);
        // Scoped so the lock guard is released before the next `connect`,
        // which locks the same mutex.
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
        let stream = pool.connect("127.0.0.1", 3000, "http").unwrap(); //reused
        assert_eq!(stream.get_ref().id, 9);
        drop(stream);
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
    }

    // Two connections checked out at once are both pooled on drop (the
    // default `max_idle` of 5 leaves room). Reconnecting takes one of them
    // back out, leaving exactly one idle for the key.
    #[test]
    fn test_double_connect_reuse() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let key = key("127.0.0.1", 3000, "http");
        let stream1 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        let stream2 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        drop(stream1);
        drop(stream2);
        let stream1 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
        // `let _ = ...` does not move `stream1`. It stays checked out until
        // the end of the test, i.e. throughout the assertions above.
        let _ = stream1;
    }

    // An explicitly closed stream is not returned to the pool.
    #[test]
    fn test_closed() {
        let pool = mocked!();
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        stream.close(Shutdown::Both).unwrap();
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // A fresh (never pooled) stream that reads EOF reports `Ok(0)` and is
    // not returned to the pool.
    #[test]
    fn test_eof_closes() {
        let pool = mocked!();

        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.read(&mut [0]).unwrap(), 0);
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // A stream reused from the pool that hits EOF before reading anything
    // reports `ConnectionAborted` and is discarded rather than pooled again.
    #[test]
    fn test_read_conn_aborted() {
        let pool = mocked!();

        // The returned stream is dropped at the end of this statement, which
        // puts it in the pool so the next `connect` reuses it.
        pool.connect("127.0.0.1", 3000, "http").unwrap();
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        let err = stream.read(&mut [0]).unwrap_err();
        assert_eq!(err.kind(), ::std::io::ErrorKind::ConnectionAborted);
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // A connection idle longer than the idle timeout is discarded, so
    // `connect` opens a new one (mock `id` back to 0 instead of 1337).
    #[test]
    fn test_idle_timeout() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(10)));
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
        stream.get_mut().id = 1337;
        drop(stream);
        ::std::thread::sleep(Duration::from_millis(100));
        let stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
    }
}