// udp_pool/pool.rs

1use crate::id::ID;
2use crate::{Sender, UdpTx, connection};
3use net_pool::backend::{Address, BackendState};
4use net_pool::error::Error;
5use net_pool::{debug, instrument_current_span, Strategy, tokio_spawn};
6use std::collections::HashMap;
7use std::io;
8use std::net::SocketAddr;
9use std::sync::{Arc, Mutex};
10use tokio::net::UdpSocket;
11
/// UDP connection pool.
/// The number of connections in the pool is limited by `max_conn`.
/// The default keepalive is 5 minutes.
/// Two behaviors cause pooled connections to be released:
///     1. `remove_backend` is called to clear a backend address, which invalidates every connection to that address.
///     2. A connection sends no data for the keepalive duration.
pub struct Pool {
    state: net_pool::pool::BaseState,
    // Cached idle senders keyed by connection ID, guarded by a blocking mutex.
    free_conn_map: Mutex<HashMap<ID, UdpTx>>,
}
22
23impl Pool {
24    pub fn new(strategy: Arc<dyn Strategy>) -> Self {
25        let p = Pool {
26            state: net_pool::pool::BaseState::new(strategy),
27            free_conn_map: Mutex::new(HashMap::new()),
28        };
29        <Pool as net_pool::pool::Pool>::set_keepalive(
30            &p,
31            Some(std::time::Duration::from_secs(60 * 5)),
32        );
33        p
34    }
35}
36
37impl Default for Pool {
38    fn default() -> Self {
39        // robin round
40        Pool::new(Arc::new(net_pool::strategy::RRStrategy::default()))
41    }
42}
43
44impl<L: Strategy + 'static> From<L> for Pool {
45    fn from(value: L) -> Self {
46        Self::new(Arc::new(value))
47    }
48}
49
50impl net_pool::pool::Pool for Pool {
51    net_pool::macros::base_pool_impl! {state}
52
53    fn remove_backend(&self, addr: &Address) -> bool {
54        if self.state.lb_strategy.remove_backend(addr) {
55            // 清除缓存
56            self.clear_bs_tx(addr);
57            true
58        } else {
59            false
60        }
61    }
62}
63
64impl Pool {
65    fn get_tx(&self, a: &SocketAddr) -> Option<UdpTx> {
66        let id = a.into();
67        let mut guard = self.free_conn_map.lock().unwrap();
68
69        // 这里虽然删除了一些closed状态的, 但算法却有延迟性,它无法及时删除所有的closed
70        let tx = guard.get(&id)?;
71        if tx.is_closed() {
72            guard.remove(&id);
73            assert!(
74                self.state
75                    .cur_conn
76                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
77                    > 0
78            );
79            debug!(
80                "[udp pool] [desc] current socket count: {}",
81                self.state
82                    .cur_conn
83                    .load(std::sync::atomic::Ordering::Relaxed)
84            );
85            None
86        } else {
87            Some(tx.clone())
88        }
89    }
90
91    /// 清除所有的后端, 因为清除了连接缓存, 但可能因为外界仍持有sender所以会导致conn还无法及时退出.但cur_conn数值已发生了变化
92    fn clear_bs_tx(&self, a: &Address) {
93        let mut guard = self.free_conn_map.lock().unwrap();
94        guard.retain(|k, _| {
95            if k.get_bid() != a.hash_code() {
96                true
97            } else {
98                assert!(
99                    self.state
100                        .cur_conn
101                        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
102                        > 0
103                );
104                false
105            }
106        });
107
108        debug!(
109            "[udp pool] [desc] current socket count: {}",
110            self.state
111                .cur_conn
112                .load(std::sync::atomic::Ordering::Relaxed)
113        );
114    }
115
116    fn add_tx(&self, id: ID, tx: UdpTx) {
117        let mut guard = self.free_conn_map.lock().unwrap();
118        guard.insert(id, tx);
119    }
120
121    fn remove_tx(&self, id: ID) {
122        let mut guard = self.free_conn_map.lock().unwrap();
123        if guard.remove(&id).is_some() {
124            assert!(
125                self.state
126                    .cur_conn
127                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
128                    > 0
129            );
130            debug!(
131                "[udp pool] [desc] current socket count: {}",
132                self.state
133                    .cur_conn
134                    .load(std::sync::atomic::Ordering::Relaxed)
135            );
136        }
137    }
138
139    async fn create_tx(
140        self: Arc<Self>,
141        a: SocketAddr,
142        client: Option<Arc<UdpSocket>>,
143    ) -> Result<UdpTx, Error> {
144        let mut id = ID::new(&a);
145        let bs = self
146            .state
147            .lb_strategy
148            .get_backend(&id.to_string())
149            .ok_or(Error::NoBackend)?;
150        id.set_bid(bs.hash_code());
151
152        let proxy = create_udp_socket(&bs).await?;
153        let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&self);
154
155        let (tx, conn) = io(a, client, proxy, ka);
156
157        // 驱动
158        let pool = self.clone();
159        tokio_spawn! {
160            instrument_current_span! {
161                async move {
162                    let _res = conn.await;
163                    debug!("[udp pool] udp socket close: {:?}", _res);
164                    // conn可能会因为超时而退出, 但外界仍持有sender
165                    pool.remove_tx(id);
166                }
167            }
168        };
169
170        // 添加tx进去缓存时, 可能此时bs已经被删除了
171        // 由于缓存中持有tx, 所以有可能会导致在connection没有设置超时的情况下一直在运行,从而一直占用最大连接数
172        // 由于缓存和bs是通过不同的锁持有的, 所以暂时无法对它们进行一致性限制
173        self.add_tx(id, tx.clone());
174        Ok(tx)
175    }
176}
177
/// Abstraction over a UDP pool that hands out senders for client addresses.
pub trait UdpPool {
    /// Returns a [`Sender`] associated with client address `a`.
    /// NOTE(review): `send` is forwarded as the connection's client socket —
    /// presumably used to reply to the client directly; confirm against
    /// `connection::Connection`.
    fn get(
        self: Arc<Self>,
        a: SocketAddr,
        send: Option<Arc<UdpSocket>>,
    ) -> impl Future<Output = Result<Sender, Error>> + Send;
}
185
186impl UdpPool for Pool {
187    async fn get(
188        self: Arc<Self>,
189        a: SocketAddr,
190        send: Option<Arc<UdpSocket>>,
191    ) -> Result<Sender, Error> {
192        let tx = match self.get_tx(&a) {
193            Some(s) => Ok(s),
194            None => {
195                // 预分配数量
196                net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
197                self.clone().create_tx(a, send).await.map(|s| {
198                    debug!(
199                        "[udp pool] [incr] current socket count: {}",
200                        self.state
201                            .cur_conn
202                            .load(std::sync::atomic::Ordering::Relaxed)
203                    );
204                    s
205                })
206            }
207        };
208
209        if tx.is_err() {
210            assert!(
211                self.state
212                    .cur_conn
213                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
214                    > 0
215            );
216        }
217
218        tx.map(|tx| Sender::new(tx))
219    }
220}
221
222pub async fn get<P: UdpPool>(
223    pool: Arc<P>,
224    a: SocketAddr,
225    send: Option<Arc<UdpSocket>>,
226) -> Result<Sender, Error> {
227    UdpPool::get(pool, a, send).await
228}
229
230fn io(
231    client_addr: SocketAddr,
232    client: Option<Arc<UdpSocket>>,
233    proxy: UdpSocket,
234    keepalive: Option<std::time::Duration>,
235) -> (UdpTx, connection::Connection) {
236    let (tx, rx) = tokio::sync::mpsc::channel(20);
237    (
238        tx,
239        connection::Connection::new(client_addr, client, proxy, rx, keepalive),
240    )
241}
242
243async fn create_udp_socket(bs: &BackendState) -> Result<UdpSocket, Error> {
244    let recv = UdpSocket::bind("0.0.0.0:0").await?;
245
246    let a = to_socket_addrs(bs.get_address()).await?;
247    recv.connect(&a[..]).await?;
248
249    Ok(recv)
250}
251
252async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
253    match addr {
254        Address::Ori(ori) => tokio::net::lookup_host(ori)
255            .await
256            .map(|a| a.into_iter().collect()),
257        Address::Addr(addr) => Ok(vec![addr.clone()]),
258    }
259}