tcp_pool/
pool.rs

1use crate::TcpStream;
2use net_pool::backend::Address;
3use net_pool::{Error, Strategy, debug};
4use std::io;
5use std::net::SocketAddr;
6use std::sync::Arc;
7
8/// tcp连接池
9/// 连接池中的连接受max_conn数据限制
10/// keepalive数据没有任何效果
11/// 连接池中的连接不复用, get出来后tcp stream不被引用则会导致连接断开
12pub struct Pool {
13    state: net_pool::pool::BaseState,
14}
15
16impl Pool {
17    pub fn new(strategy: Arc<dyn Strategy>) -> Self {
18        Pool {
19            state: net_pool::pool::BaseState::new(strategy),
20        }
21    }
22}
23
24impl Default for Pool {
25    fn default() -> Self {
26        Pool::new(Arc::new(net_pool::strategy::HashStrategy::default()))
27    }
28}
29
30impl<L: Strategy + 'static> From<L> for Pool {
31    fn from(value: L) -> Self {
32        Self::new(Arc::new(value))
33    }
34}
35
36impl net_pool::pool::Pool for Pool {
37    net_pool::macros::base_pool_impl! {state}
38}
39
40pub trait TcpPool {
41    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<TcpStream, Error>> + Send;
42}
43
44impl TcpPool for Pool {
45    async fn get(self: Arc<Self>, key: &str) -> Result<TcpStream, Error> {
46        // 预分配数量
47        net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
48
49        let tcp = {
50            match self
51                .state
52                .lb_strategy
53                .get_backend(key)
54                .ok_or(Error::NoBackend)
55            {
56                Err(e) => Err(e),
57                Ok(bs) => create_tcp_stream(bs.get_address()).await,
58            }
59        };
60
61        if tcp.is_err() {
62            assert!(
63                self.state
64                    .cur_conn
65                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
66                    > 0
67            );
68        } else {
69            debug!(
70                "[tcp pool] [incr] current connection count: {}",
71                self.state
72                    .cur_conn
73                    .load(std::sync::atomic::Ordering::Relaxed)
74            );
75        }
76
77        let pool = self.clone();
78        tcp.map(|t| {
79            TcpStream::new(
80                move || {
81                    assert!(
82                        pool.state
83                            .cur_conn
84                            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
85                            > 0
86                    );
87                    debug!(
88                        "[tcp pool] [desc] current connection count: {}",
89                        pool.state
90                            .cur_conn
91                            .load(std::sync::atomic::Ordering::Relaxed)
92                    );
93                },
94                t,
95            )
96        })
97    }
98}
99
100async fn create_tcp_stream(addrs: &Address) -> Result<tokio::net::TcpStream, Error> {
101    let a = to_socket_addrs(addrs).await?;
102    tokio::net::TcpStream::connect(&a[..])
103        .await
104        .map_err(|e| Error::from_other(e))
105}
106
107async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
108    match addr {
109        Address::Ori(ori) => tokio::net::lookup_host(ori)
110            .await
111            .map(|a| a.into_iter().collect()),
112        Address::Addr(addr) => Ok(vec![addr.clone()]),
113    }
114}
115
116pub async fn get<P: TcpPool + Send>(pool: Arc<P>, key: &str) -> Result<TcpStream, Error> {
117    TcpPool::get(pool.clone(), key).await
118}