1use std::borrow::ToOwned;
3use std::collections::HashMap;
4use std::fmt;
5use std::io::{self, Read, Write};
6use std::net::{SocketAddr, Shutdown};
7use std::sync::{Arc};
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use std::time::{Duration, Instant};
11
12use crate::net::{NetworkConnector, NetworkStream, DefaultConnector};
13use crate::client::scheme::Scheme;
14use crate::runtime;
15
16use self::stale::{StaleCheck, Stale};
17
/// A connection pool wrapping a `NetworkConnector`.
///
/// Connections are handed out as `PooledStream`s and, when returned, kept idle for reuse,
/// bucketed by `(host, port, scheme)`.
pub struct Pool<C: NetworkConnector> {
    /// Opens new connections when no usable idle one is available.
    connector: C,
    /// Idle connections plus pool settings; a clone of this `Arc` is held by every
    /// `PooledStream` handed out, so streams can return themselves on drop.
    inner: Arc<runtime::Mutex<PoolImpl<<C as NetworkConnector>::Stream>>>,
    /// Optional callback consulted before an idle connection is reused.
    stale_check: Option<StaleCallback<C::Stream>>,
}
24
/// User-facing configuration for a [`Pool`].
///
/// Derives the common traits so callers can copy a configuration between pools and
/// compare configurations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Maximum number of idle connections kept for each `(host, port, scheme)` key.
    pub max_idle: usize,
}

impl Default for Config {
    /// Returns a configuration keeping at most 5 idle connections per key.
    #[inline]
    fn default() -> Config {
        Config { max_idle: 5 }
    }
}
40
/// Internal pool settings.
///
/// Unlike the public `Config`, this also carries the idle timeout, which is changed via
/// `Pool::set_idle_timeout`.
#[derive(Debug)]
struct Config2 {
    /// Idle connections older than this are discarded at checkout; `None` disables expiry.
    idle_timeout: Option<Duration>,
    /// Maximum number of idle connections kept per key; extra returned connections are dropped.
    max_idle: usize,
}
48
49
/// Mutable pool state shared between a `Pool` and the `PooledStream`s it hands out.
#[derive(Debug)]
struct PoolImpl<S> {
    /// Idle connections per key. Returned connections are pushed onto the end and
    /// checkouts pop from the end, so the most recently returned one is reused first.
    conns: HashMap<Key, Vec<PooledStreamInner<S>>>,
    /// Pool limits and timeout.
    config: Config2,
}
55
56type Key = (String, u16, Scheme);
57
58fn key<T: Into<Scheme>>(host: &str, port: u16, scheme: T) -> Key {
59 (host.to_owned(), port, scheme.into())
60}
61
62impl Pool<DefaultConnector> {
63 #[inline]
65 pub fn new(config: Config) -> Pool<DefaultConnector> {
66 Pool::with_connector(config, DefaultConnector::default())
67 }
68}
69
70impl<C: NetworkConnector> Pool<C> {
71 #[inline]
73 pub fn with_connector(config: Config, connector: C) -> Pool<C> {
74 Pool {
75 connector: connector,
76 inner: Arc::new(runtime::Mutex::new(PoolImpl {
77 conns: HashMap::new(),
78 config: Config2 {
79 idle_timeout: None,
80 max_idle: config.max_idle,
81 },
82 })),
83 stale_check: None,
84 }
85 }
86
87 pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
89 self.inner.lock().unwrap().config.idle_timeout = timeout;
90 }
91
92 pub fn set_stale_check<F>(&mut self, callback: F)
93 where F: Fn(StaleCheck<C::Stream>) -> Stale + Send + Sync + 'static {
94 self.stale_check = Some(Box::new(callback));
95 }
96
97 #[inline]
99 pub fn clear_idle(&mut self) {
100 self.inner.lock().unwrap().conns.clear();
101 }
102
103 fn checkout(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
106 while let Some(mut inner) = self.lookup(key) {
107 if let Some(ref stale_check) = self.stale_check {
108 let dur = inner.idle.expect("idle is never missing inside pool").elapsed();
109 let arg = stale::check(&mut inner.stream, dur);
110 if stale_check(arg).is_stale() {
111 trace!("ejecting stale connection");
112 continue;
113 }
114 }
115 return Some(inner);
116 }
117 None
118 }
119
120
121 fn lookup(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
122 let mut locked = self.inner.lock().unwrap();
123 let mut should_remove = false;
124 let deadline = locked.config.idle_timeout.map(|dur| Instant::now() - dur);
125 let inner = locked.conns.get_mut(key).and_then(|vec| {
126 while let Some(inner) = vec.pop() {
127 should_remove = vec.is_empty();
128 if let Some(deadline) = deadline {
129 if inner.idle.expect("idle is never missing inside pool") < deadline {
130 trace!("ejecting expired connection");
131 continue;
132 }
133 }
134 return Some(inner);
135 }
136 None
137 });
138 if should_remove {
139 locked.conns.remove(key);
140 }
141 inner
142 }
143}
144
145impl<S> PoolImpl<S> {
146 fn reuse(&mut self, key: Key, conn: PooledStreamInner<S>) {
147 trace!("reuse {:?}", key);
148 let conns = self.conns.entry(key).or_insert(vec![]);
149 if conns.len() < self.config.max_idle {
150 conns.push(conn);
151 }
152 }
153}
154
155impl<C: NetworkConnector<Stream=S>, S: NetworkStream + Send> NetworkConnector for Pool<C> {
156 type Stream = PooledStream<S>;
157 fn connect(&self, host: &str, port: u16, scheme: &str) -> crate::Result<PooledStream<S>> {
158 let key = key(host, port, scheme);
159 let inner = match self.checkout(&key) {
160 Some(inner) => {
161 trace!("Pool had connection, using");
162 inner
163 }
164 None => PooledStreamInner {
165 key: key.clone(),
166 idle: None,
167 stream: r#try!(self.connector.connect(host, port, scheme)),
168 previous_response_expected_no_content: false,
169 }
170 };
171 Ok(PooledStream {
172 has_read: false,
173 inner: Some(inner),
174 is_closed: AtomicBool::new(false),
175 pool: self.inner.clone(),
176 })
177 }
178}
179
/// Boxed stale-check callback, as stored by `Pool::set_stale_check`.
type StaleCallback<S> = Box<dyn Fn(StaleCheck<S>) -> Stale + Send + Sync + 'static>;
181
mod stale {
    //! Types exchanged with a user-supplied stale-connection callback.
    use std::time::Duration;

    /// An idle connection about to be reused, offered to the stale-check callback.
    ///
    /// The callback consumes it by answering with [`StaleCheck::stale`] or
    /// [`StaleCheck::fresh`].
    pub struct StaleCheck<'a, S: 'a> {
        stream: &'a mut S,
        idle_for: Duration,
    }

    /// Wraps `stream`, which has been idle for `dur`, into a [`StaleCheck`].
    #[inline]
    pub fn check<'a, S: 'a>(stream: &'a mut S, dur: Duration) -> StaleCheck<'a, S> {
        StaleCheck {
            stream,
            idle_for: dur,
        }
    }

    impl<'a, S: 'a> StaleCheck<'a, S> {
        /// Mutable access to the connection, e.g. to probe it.
        pub fn stream(&mut self) -> &mut S {
            &mut *self.stream
        }

        /// How long the connection sat idle in the pool.
        pub fn idle_duration(&self) -> Duration {
            self.idle_for
        }

        /// Verdict: discard this connection.
        pub fn stale(self) -> Stale {
            Stale(true)
        }

        /// Verdict: reuse this connection.
        pub fn fresh(self) -> Stale {
            Stale(false)
        }
    }

    /// Verdict produced by a stale-check callback.
    pub struct Stale(bool);

    impl Stale {
        /// `true` when the callback asked for the connection to be discarded.
        #[inline]
        pub fn is_stale(self) -> bool {
            let Stale(flag) = self;
            flag
        }
    }
}
240
241
/// A connection checked out of a `Pool`.
///
/// When dropped, it is returned to the pool unless it has been marked closed.
pub struct PooledStream<S> {
    /// Set once a read has returned data. An EOF before that, on a connection that came
    /// from the pool, is reported as `ConnectionAborted` rather than a normal EOF.
    has_read: bool,
    /// The connection; `None` once taken (by `into_inner` or on drop).
    inner: Option<PooledStreamInner<S>>,
    /// Set on EOF, on `close`, or when certain `NetworkStream` calls fail. Closed streams
    /// are not returned to the pool. Atomic so it can be set through `&self`.
    is_closed: AtomicBool,
    /// Pool the connection is returned to on drop.
    pool: Arc<runtime::Mutex<PoolImpl<S>>>,
}
250
251impl<S> fmt::Debug for PooledStream<S> where S: fmt::Debug + 'static {
253 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
254 fmt.debug_struct("PooledStream")
255 .field("inner", &self.inner)
256 .field("has_read", &self.has_read)
257 .field("is_closed", &self.is_closed.load(Ordering::Relaxed))
258 .field("pool", &self.pool)
259 .finish()
260 }
261}
262
263impl<S: NetworkStream> PooledStream<S> {
264 pub fn into_inner(mut self) -> S {
266 self.inner.take().expect("PooledStream lost its inner stream").stream
267 }
268
269 pub fn get_ref(&self) -> &S {
271 &self.inner.as_ref().expect("PooledStream lost its inner stream").stream
272 }
273
274 #[cfg(test)]
275 fn get_mut(&mut self) -> &mut S {
276 &mut self.inner.as_mut().expect("PooledStream lost its inner stream").stream
277 }
278}
279
/// A connection plus the bookkeeping that travels with it in and out of the pool.
#[derive(Debug)]
struct PooledStreamInner<S> {
    /// Pool bucket this connection belongs to.
    key: Key,
    /// When the connection was last returned to the pool. `None` for a freshly opened
    /// connection that has never been pooled.
    idle: Option<Instant>,
    /// The underlying network stream.
    stream: S,
    /// Flag read and written through the `previous_response_expected_no_content`
    /// methods of `NetworkStream`.
    previous_response_expected_no_content: bool,
}
287
288impl<S: NetworkStream> Read for PooledStream<S> {
289 #[inline]
290 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
291 let inner = self.inner.as_mut().unwrap();
292 let n = r#try!(inner.stream.read(buf));
293 if n == 0 {
294 self.is_closed.store(true, Ordering::Relaxed);
298
299 if !self.has_read && inner.idle.is_some() {
303 Err(io::Error::new(
305 io::ErrorKind::ConnectionAborted,
306 "Pooled stream disconnected",
307 ))
308 } else {
309 Ok(0)
310 }
311 } else {
312 self.has_read = true;
313 Ok(n)
314 }
315 }
316}
317
318impl<S: NetworkStream> Write for PooledStream<S> {
319 #[inline]
320 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
321 self.inner.as_mut().unwrap().stream.write(buf)
322 }
323
324 #[inline]
325 fn flush(&mut self) -> io::Result<()> {
326 self.inner.as_mut().unwrap().stream.flush()
327 }
328}
329
330impl<S: NetworkStream> NetworkStream for PooledStream<S> {
331 #[inline]
332 fn peer_addr(&mut self) -> io::Result<SocketAddr> {
333 self.inner.as_mut().unwrap().stream.peer_addr()
334 .map_err(|e| {
335 self.is_closed.store(true, Ordering::Relaxed);
336 e
337 })
338 }
339
340 #[inline]
341 fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
342 self.inner.as_ref().unwrap().stream.set_read_timeout(dur)
343 .map_err(|e| {
344 self.is_closed.store(true, Ordering::Relaxed);
345 e
346 })
347 }
348
349 #[inline]
350 fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
351 self.inner.as_ref().unwrap().stream.set_write_timeout(dur)
352 .map_err(|e| {
353 self.is_closed.store(true, Ordering::Relaxed);
354 e
355 })
356 }
357
358 #[inline]
359 fn close(&mut self, how: Shutdown) -> io::Result<()> {
360 self.is_closed.store(true, Ordering::Relaxed);
361 self.inner.as_mut().unwrap().stream.close(how)
362 }
363
364 #[inline]
365 fn set_previous_response_expected_no_content(&mut self, expected: bool) {
366 trace!("set_previous_response_expected_no_content {}", expected);
367 self.inner.as_mut().unwrap().previous_response_expected_no_content = expected;
368 }
369
370 #[inline]
371 fn previous_response_expected_no_content(&self) -> bool {
372 let answer = self.inner.as_ref().unwrap().previous_response_expected_no_content;
373 trace!("previous_response_expected_no_content {}", answer);
374 answer
375 }
376
377 fn set_nonblocking(&self, b: bool) {
378 self.inner.as_ref().unwrap().stream.set_nonblocking(b);
379 }
380
381 fn reset_io(&self) {
382 self.inner.as_ref().unwrap().stream.reset_io();
383 }
384
385 fn wait_io(&self) {
386 self.inner.as_ref().unwrap().stream.wait_io();
387 }
388}
389
390impl<S> Drop for PooledStream<S> {
391 fn drop(&mut self) {
392 let is_closed = self.is_closed.load(Ordering::Relaxed);
393 trace!("PooledStream.drop, is_closed={}", is_closed);
394 if !is_closed {
395 self.inner.take().map(|mut inner| {
396 let now = Instant::now();
397 inner.idle = Some(now);
398 if let Ok(mut pool) = self.pool.lock() {
399 pool.reuse(inner.key.clone(), inner);
400 }
401 });
403 }
404 }
405}
406
/// Unit tests for the pool, driven by the mock connector.
#[cfg(test)]
mod tests {
    use std::net::Shutdown;
    use std::io::Read;
    use std::time::Duration;
    use crate::mock::{MockConnector};
    use crate::net::{NetworkConnector, NetworkStream};

    use super::{Pool, key};

    // Builds a pool with the default `Config`, backed by the mock connector.
    macro_rules! mocked {
        () => ({
            Pool::with_connector(Default::default(), MockConnector)
        })
    }

    // A dropped stream is returned to the pool and handed back, with its state intact,
    // by the next connect to the same key.
    #[test]
    fn test_connect_and_drop() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let key = key("127.0.0.1", 3000, "http");
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
        // Tag the connection so its reuse can be detected below.
        stream.get_mut().id = 9;
        drop(stream);
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
        let stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 9);
        drop(stream);
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
    }

    // Two concurrently open connections are both pooled on drop; the next connect takes
    // one, leaving one idle.
    #[test]
    fn test_double_connect_reuse() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let key = key("127.0.0.1", 3000, "http");
        let stream1 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        let stream2 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        drop(stream1);
        drop(stream2);
        let stream1 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
        let _ = stream1;
    }

    // An explicitly closed stream is not returned to the pool.
    #[test]
    fn test_closed() {
        let pool = mocked!();
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        stream.close(Shutdown::Both).unwrap();
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // EOF on a fresh (never pooled) stream is a plain `Ok(0)` and keeps it out of the pool.
    #[test]
    fn test_eof_closes() {
        let pool = mocked!();

        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.read(&mut [0]).unwrap(), 0);
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // The first connection is dropped immediately and pooled. The second connect reuses
    // it, and EOF on its first read surfaces as `ConnectionAborted`.
    #[test]
    fn test_read_conn_aborted() {
        let pool = mocked!();

        pool.connect("127.0.0.1", 3000, "http").unwrap();
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        let err = stream.read(&mut [0]).unwrap_err();
        assert_eq!(err.kind(), ::std::io::ErrorKind::ConnectionAborted);
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // A connection idle longer than the timeout is discarded, so a new one (id 0) is opened.
    #[test]
    fn test_idle_timeout() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(10)));
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
        stream.get_mut().id = 1337;
        drop(stream);
        ::std::thread::sleep(Duration::from_millis(100));
        let stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
    }
}