1use std::borrow::ToOwned;
3use std::collections::HashMap;
4use std::fmt;
5use std::io::{self, Read, Write};
6use std::net::{SocketAddr, Shutdown};
7use std::sync::{Arc, Mutex};
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use std::time::{Duration, Instant};
11
12use crate::net::{NetworkConnector, NetworkStream, DefaultConnector};
13use crate::client::scheme::Scheme;
14
15use self::stale::{StaleCheck, Stale};
16
/// A connection pool layered over a `NetworkConnector`.
///
/// Idle connections are filed under a `(host, port, scheme)` key and handed
/// out again by `connect` before the wrapped connector is asked to open a
/// new one.
pub struct Pool<C: NetworkConnector> {
    /// Connector used to open a new connection when no idle one is available.
    connector: C,
    /// Idle-connection storage. A clone of this `Arc` is stored in every
    /// `PooledStream` the pool hands out, so a dropped stream can put itself back.
    inner: Arc<Mutex<PoolImpl<<C as NetworkConnector>::Stream>>>,
    /// Optional callback consulted on checkout to reject idle connections it
    /// judges stale; `None` skips the check.
    stale_check: Option<StaleCallback<C::Stream>>,
}
23
/// Public, user-facing configuration for a [`Pool`].
#[derive(Debug)]
pub struct Config {
    /// Upper bound on the number of idle connections kept per
    /// `(host, port, scheme)` key.
    pub max_idle: usize,
}

impl Default for Config {
    /// Keeps at most 5 idle connections per key.
    #[inline]
    fn default() -> Self {
        let max_idle = 5;
        Self { max_idle }
    }
}
39
/// Internal pool settings, built from the public `Config` by
/// `Pool::with_connector` and stored inside the shared `PoolImpl`.
#[derive(Debug)]
struct Config2 {
    /// Maximum time a connection may sit idle before `lookup` discards it.
    /// Starts as `None` (no expiry); changed via `Pool::set_idle_timeout`.
    idle_timeout: Option<Duration>,
    /// Per-key cap on idle connections, enforced by `PoolImpl::reuse`.
    max_idle: usize,
}
47
48
/// Mutable pool state shared (behind a `Mutex`) between a `Pool` and the
/// `PooledStream`s it has handed out.
#[derive(Debug)]
struct PoolImpl<S> {
    /// Idle connections per key. `reuse` pushes onto the end and `lookup`
    /// pops from the end, so the most recently returned connection is tried first.
    conns: HashMap<Key, Vec<PooledStreamInner<S>>>,
    /// Limits applied when storing and retrieving idle connections.
    config: Config2,
}
54
55type Key = (String, u16, Scheme);
56
57fn key<T: Into<Scheme>>(host: &str, port: u16, scheme: T) -> Key {
58 (host.to_owned(), port, scheme.into())
59}
60
61impl Pool<DefaultConnector> {
62 #[inline]
64 pub fn new(config: Config) -> Pool<DefaultConnector> {
65 Pool::with_connector(config, DefaultConnector::default())
66 }
67}
68
69impl<C: NetworkConnector> Pool<C> {
70 #[inline]
72 pub fn with_connector(config: Config, connector: C) -> Pool<C> {
73 Pool {
74 connector: connector,
75 inner: Arc::new(Mutex::new(PoolImpl {
76 conns: HashMap::new(),
77 config: Config2 {
78 idle_timeout: None,
79 max_idle: config.max_idle,
80 }
81 })),
82 stale_check: None,
83 }
84 }
85
86 pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
88 self.inner.lock().unwrap().config.idle_timeout = timeout;
89 }
90
91 pub fn set_stale_check<F>(&mut self, callback: F)
92 where F: Fn(StaleCheck<C::Stream>) -> Stale + Send + Sync + 'static {
93 self.stale_check = Some(Box::new(callback));
94 }
95
96 #[inline]
98 pub fn clear_idle(&mut self) {
99 self.inner.lock().unwrap().conns.clear();
100 }
101
102 fn checkout(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
105 while let Some(mut inner) = self.lookup(key) {
106 if let Some(ref stale_check) = self.stale_check {
107 let dur = inner.idle.expect("idle is never missing inside pool").elapsed();
108 let arg = stale::check(&mut inner.stream, dur);
109 if stale_check(arg).is_stale() {
110 trace!("ejecting stale connection");
111 continue;
112 }
113 }
114 return Some(inner);
115 }
116 None
117 }
118
119
120 fn lookup(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
121 let mut locked = self.inner.lock().unwrap();
122 let mut should_remove = false;
123 let deadline = locked.config.idle_timeout.map(|dur| Instant::now() - dur);
124 let inner = locked.conns.get_mut(key).and_then(|vec| {
125 while let Some(inner) = vec.pop() {
126 should_remove = vec.is_empty();
127 if let Some(deadline) = deadline {
128 if inner.idle.expect("idle is never missing inside pool") < deadline {
129 trace!("ejecting expired connection");
130 continue;
131 }
132 }
133 return Some(inner);
134 }
135 None
136 });
137 if should_remove {
138 locked.conns.remove(key);
139 }
140 inner
141 }
142}
143
144impl<S> PoolImpl<S> {
145 fn reuse(&mut self, key: Key, conn: PooledStreamInner<S>) {
146 trace!("reuse {:?}", key);
147 let conns = self.conns.entry(key).or_insert(vec![]);
148 if conns.len() < self.config.max_idle {
149 conns.push(conn);
150 }
151 }
152}
153
154impl<C: NetworkConnector<Stream=S>, S: NetworkStream + Send> NetworkConnector for Pool<C> {
155 type Stream = PooledStream<S>;
156 fn connect(&self, host: &str, port: u16, scheme: &str) -> crate::Result<PooledStream<S>> {
157 let key = key(host, port, scheme);
158 let inner = match self.checkout(&key) {
159 Some(inner) => {
160 trace!("Pool had connection, using");
161 inner
162 },
163 None => PooledStreamInner {
164 key: key.clone(),
165 idle: None,
166 stream: self.connector.connect(host, port, scheme)?,
167 previous_response_expected_no_content: false,
168 }
169
170 };
171 Ok(PooledStream {
172 has_read: false,
173 inner: Some(inner),
174 is_closed: AtomicBool::new(false),
175 pool: self.inner.clone(),
176 })
177 }
178}
179
/// Boxed user callback installed by `Pool::set_stale_check`: given a
/// `StaleCheck` for an idle connection, it answers whether to discard it.
type StaleCallback<S> = Box<dyn Fn(StaleCheck<S>) -> Stale + Send + Sync + 'static>;
181
mod stale {
    use std::time::Duration;

    /// Argument handed to a stale-check callback: mutable access to the idle
    /// stream plus how long it has been idle. The callback answers by
    /// consuming it via `stale()` or `fresh()`.
    pub struct StaleCheck<'a, S: 'a> {
        stream: &'a mut S,
        duration: Duration,
    }

    /// Bundles `stream` and its idle duration `dur` into a `StaleCheck`.
    #[inline]
    pub fn check<'a, S: 'a>(stream: &'a mut S, dur: Duration) -> StaleCheck<'a, S> {
        StaleCheck {
            stream,
            duration: dur,
        }
    }

    impl<'a, S: 'a> StaleCheck<'a, S> {
        /// Mutable access to the connection being checked.
        pub fn stream(&mut self) -> &mut S {
            &mut *self.stream
        }

        /// Time since the connection was returned to the pool.
        pub fn idle_duration(&self) -> Duration {
            self.duration
        }

        /// Verdict: discard this connection.
        pub fn stale(self) -> Stale {
            let verdict = true;
            Stale(verdict)
        }

        /// Verdict: the connection may be reused.
        pub fn fresh(self) -> Stale {
            let verdict = false;
            Stale(verdict)
        }
    }

    /// Opaque verdict produced by `StaleCheck::stale` / `StaleCheck::fresh`.
    pub struct Stale(bool);

    impl Stale {
        /// `true` when the verdict was `stale()`.
        #[inline]
        pub fn is_stale(self) -> bool {
            let Stale(flag) = self;
            flag
        }
    }
}
240
241
/// A connection handed out by `Pool::connect`.
///
/// When dropped, the wrapped connection is put back into the pool unless the
/// stream has been marked closed (by EOF, `close`, or an error from
/// `peer_addr` / `set_read_timeout` / `set_write_timeout`).
pub struct PooledStream<S> {
    /// Set once any `read` has returned data on this handle.
    has_read: bool,
    /// The connection itself; taken (left `None`) by `into_inner` and by `Drop`.
    inner: Option<PooledStreamInner<S>>,
    /// When `true`, `Drop` does not return the connection to the pool.
    /// Atomic so that `&self` methods such as `set_read_timeout` can set it.
    is_closed: AtomicBool,
    /// Shared pool state that `Drop` hands the connection back to.
    pool: Arc<Mutex<PoolImpl<S>>>,
}
250
251impl<S> fmt::Debug for PooledStream<S> where S: fmt::Debug + 'static {
253 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
254 fmt.debug_struct("PooledStream")
255 .field("inner", &self.inner)
256 .field("has_read", &self.has_read)
257 .field("is_closed", &self.is_closed.load(Ordering::Relaxed))
258 .field("pool", &self.pool)
259 .finish()
260 }
261}
262
263impl<S: NetworkStream> PooledStream<S> {
264 pub fn into_inner(mut self) -> S {
266 self.inner.take().expect("PooledStream lost its inner stream").stream
267 }
268
269 pub fn get_ref(&self) -> &S {
271 &self.inner.as_ref().expect("PooledStream lost its inner stream").stream
272 }
273
274 #[cfg(test)]
275 fn get_mut(&mut self) -> &mut S {
276 &mut self.inner.as_mut().expect("PooledStream lost its inner stream").stream
277 }
278}
279
/// The connection and its pool bookkeeping, moved in and out of the pool as a unit.
#[derive(Debug)]
struct PooledStreamInner<S> {
    /// Pool entry this connection is filed under when it is returned.
    key: Key,
    /// `None` for a freshly opened connection. Set to the return time by
    /// `PooledStream`'s `Drop`, and never reset on checkout, so `Some` also
    /// means "this connection has been through the pool".
    idle: Option<Instant>,
    /// The underlying network stream.
    stream: S,
    /// Flag stored and read via the `NetworkStream`
    /// `(set_)previous_response_expected_no_content` methods. Presumably
    /// used by the HTTP layer for body framing; confirm against callers.
    previous_response_expected_no_content: bool,
}
287
288impl<S: NetworkStream> Read for PooledStream<S> {
289 #[inline]
290 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
291 let inner = self.inner.as_mut().unwrap();
292 let n = inner.stream.read(buf)?;
293 if n == 0 {
294 self.is_closed.store(true, Ordering::Relaxed);
298
299 if !self.has_read && inner.idle.is_some() {
303 Err(io::Error::new(
305 io::ErrorKind::ConnectionAborted,
306 "Pooled stream disconnected"
307 ))
308 } else {
309 Ok(0)
310 }
311 } else {
312 self.has_read = true;
313 Ok(n)
314 }
315 }
316}
317
318impl<S: NetworkStream> Write for PooledStream<S> {
319 #[inline]
320 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
321 self.inner.as_mut().unwrap().stream.write(buf)
322 }
323
324 #[inline]
325 fn flush(&mut self) -> io::Result<()> {
326 self.inner.as_mut().unwrap().stream.flush()
327 }
328}
329
330impl<S: NetworkStream> NetworkStream for PooledStream<S> {
331 #[inline]
332 fn peer_addr(&mut self) -> io::Result<SocketAddr> {
333 self.inner.as_mut().unwrap().stream.peer_addr()
334 .map_err(|e| {
335 self.is_closed.store(true, Ordering::Relaxed);
336 e
337 })
338 }
339
340 #[inline]
341 fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
342 self.inner.as_ref().unwrap().stream.set_read_timeout(dur)
343 .map_err(|e| {
344 self.is_closed.store(true, Ordering::Relaxed);
345 e
346 })
347 }
348
349 #[inline]
350 fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
351 self.inner.as_ref().unwrap().stream.set_write_timeout(dur)
352 .map_err(|e| {
353 self.is_closed.store(true, Ordering::Relaxed);
354 e
355 })
356 }
357
358 #[inline]
359 fn close(&mut self, how: Shutdown) -> io::Result<()> {
360 self.is_closed.store(true, Ordering::Relaxed);
361 self.inner.as_mut().unwrap().stream.close(how)
362 }
363
364 #[inline]
365 fn set_previous_response_expected_no_content(&mut self, expected: bool) {
366 trace!("set_previous_response_expected_no_content {}", expected);
367 self.inner.as_mut().unwrap().previous_response_expected_no_content = expected;
368 }
369
370 #[inline]
371 fn previous_response_expected_no_content(&self) -> bool {
372 let answer = self.inner.as_ref().unwrap().previous_response_expected_no_content;
373 trace!("previous_response_expected_no_content {}", answer);
374 answer
375 }
376}
377
378impl<S> Drop for PooledStream<S> {
379 fn drop(&mut self) {
380 let is_closed = self.is_closed.load(Ordering::Relaxed);
381 trace!("PooledStream.drop, is_closed={}", is_closed);
382 if !is_closed {
383 self.inner.take().map(|mut inner| {
384 let now = Instant::now();
385 inner.idle = Some(now);
386 if let Ok(mut pool) = self.pool.lock() {
387 pool.reuse(inner.key.clone(), inner);
388 }
389 });
391 }
392 }
393}
394
// Unit tests for the pool, driven by `MockConnector`. The assertions below
// imply that each mock stream starts with `id == 0` and that reading from it
// returns EOF; confirm against `crate::mock`.
#[cfg(test)]
mod tests {
    use std::net::Shutdown;
    use std::io::Read;
    use std::time::Duration;
    use crate::mock::{MockConnector};
    use crate::net::{NetworkConnector, NetworkStream};

    use super::{Pool, key};

    // Builds a pool with the default `Config` over the mock connector.
    macro_rules! mocked {
        () => ({
            Pool::with_connector(Default::default(), MockConnector)
        })
    }

    // A dropped stream is returned to the pool, and the next `connect` for the
    // same key gets that very stream back (its `id` was tagged to 9).
    #[test]
    fn test_connect_and_drop() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let key = key("127.0.0.1", 3000, "http");
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
        stream.get_mut().id = 9;
        drop(stream);
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
        let stream = pool.connect("127.0.0.1", 3000, "http").unwrap(); assert_eq!(stream.get_ref().id, 9);
        drop(stream);
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
    }

    // Two streams are returned to the pool; checking one out leaves exactly one idle.
    #[test]
    fn test_double_connect_reuse() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let key = key("127.0.0.1", 3000, "http");
        let stream1 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        let stream2 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        drop(stream1);
        drop(stream2);
        let stream1 = pool.connect("127.0.0.1", 3000, "http").unwrap();
        {
            let locked = pool.inner.lock().unwrap();
            assert_eq!(locked.conns.len(), 1);
            assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
        }
        let _ = stream1;
    }

    // An explicitly closed stream is not returned to the pool.
    #[test]
    fn test_closed() {
        let pool = mocked!();
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        stream.close(Shutdown::Both).unwrap();
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // A fresh (never pooled) stream that hits EOF reports Ok(0) and is not pooled.
    #[test]
    fn test_eof_closes() {
        let pool = mocked!();

        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.read(&mut [0]).unwrap(), 0);
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // The first connect's stream is dropped immediately and pooled. Reading EOF
    // from it on reuse, before any data, yields ConnectionAborted and it is not re-pooled.
    #[test]
    fn test_read_conn_aborted() {
        let pool = mocked!();

        pool.connect("127.0.0.1", 3000, "http").unwrap();
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        let err = stream.read(&mut [0]).unwrap_err();
        assert_eq!(err.kind(), ::std::io::ErrorKind::ConnectionAborted);
        drop(stream);
        let locked = pool.inner.lock().unwrap();
        assert_eq!(locked.conns.len(), 0);
    }

    // A pooled stream idle longer than the timeout is discarded, so the next
    // connect gets a new stream (id 0) instead of the tagged one (1337).
    #[test]
    fn test_idle_timeout() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(10)));
        let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
        stream.get_mut().id = 1337;
        drop(stream);
        ::std::thread::sleep(Duration::from_millis(100));
        let stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
        assert_eq!(stream.get_ref().id, 0);
    }
}