udp_pool/
pool.rs

1use crate::id::ID;
2use crate::{Sender, UdpTx, connection};
3use net_pool::backend::{Address, BackendState};
4use net_pool::error::Error;
5use net_pool::strategy::LbStrategy;
6use net_pool::{debug, instrument_current_span, tokio_spawn};
7use std::collections::HashMap;
8use std::io;
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use tokio::net::UdpSocket;
12
/// UDP connection pool.
///
/// The number of connections held by the pool is bounded by `max_conn`.
/// The default keepalive is 5 minutes.
/// Two events cause pooled connections to be released:
///     1. `remove_backend` is called to drop a backend address, which
///        invalidates every connection to that address.
///     2. A connection has sent no data for the keepalive duration.
pub struct Pool {
    // Shared pool bookkeeping from net_pool: the load-balancing strategy
    // plus the max_conn / cur_conn counters (and keepalive, via the
    // `net_pool::pool::Pool` trait).
    state: net_pool::pool::BaseState,
    // Cache of idle senders, keyed by an ID derived from the client
    // SocketAddr with the chosen backend's hash folded in (see create_tx).
    free_conn_map: Mutex<HashMap<ID, UdpTx>>,
}
23
24impl Pool {
25    pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
26        let p = Pool {
27            state: net_pool::pool::BaseState::new(strategy),
28            free_conn_map: Mutex::new(HashMap::new()),
29        };
30        <Pool as net_pool::pool::Pool>::set_keepalive(
31            &p,
32            Some(std::time::Duration::from_secs(60 * 5)),
33        );
34        p
35    }
36}
37
38impl Default for Pool {
39    fn default() -> Self {
40        // robin round
41        Pool::new(Arc::new(net_pool::strategy::RRStrategy::default()))
42    }
43}
44
45impl<L: LbStrategy + 'static> From<L> for Pool {
46    fn from(value: L) -> Self {
47        Self::new(Arc::new(value))
48    }
49}
50
51impl net_pool::pool::Pool for Pool {
52    net_pool::macros::base_pool_impl! {state}
53
54    fn remove_backend(&self, addr: &Address) -> bool {
55        if self.state.lb_strategy.remove_backend(addr) {
56            // 清除缓存
57            self.clear_bs_tx(addr);
58            true
59        } else {
60            false
61        }
62    }
63}
64
impl Pool {
    /// Look up a cached sender for client address `a`.
    ///
    /// Returns `None` on a miss, or when the cached entry turns out to be
    /// closed — in that case the entry is evicted and `cur_conn` is
    /// decremented here (the driver task's `remove_tx` will then find
    /// nothing to remove, so the counter is only decremented once).
    fn get_tx(&self, a: &SocketAddr) -> Option<UdpTx> {
        let id = a.into();
        let mut guard = self.free_conn_map.lock().unwrap();

        // Lazy eviction: some closed entries are removed here, but the scheme
        // is delayed — it cannot promptly remove every closed connection,
        // only the ones a lookup happens to hit.
        let tx = guard.get(&id)?;
        if tx.is_closed() {
            guard.remove(&id);
            // Invariant: the counter was incremented when this entry was
            // created, so it must still be positive.
            assert!(
                self.state
                    .cur_conn
                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
                    > 0
            );
            debug!(
                "[udp pool] [desc] current socket count: {}",
                self.state
                    .cur_conn
                    .load(std::sync::atomic::Ordering::Relaxed)
            );
            None
        } else {
            Some(tx.clone())
        }
    }

    /// Evict every cached sender belonging to backend `a`.
    ///
    /// Dropping the cache entry does not necessarily stop the connection
    /// immediately — callers may still hold a `Sender` clone, so the driver
    /// task can outlive this call — but `cur_conn` is decremented here.
    fn clear_bs_tx(&self, a: &Address) {
        let mut guard = self.free_conn_map.lock().unwrap();
        guard.retain(|k, _| {
            if k.get_bid() != a.hash_code() {
                true
            } else {
                // Each evicted entry gives back one connection slot.
                assert!(
                    self.state
                        .cur_conn
                        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
                        > 0
                );
                false
            }
        });

        debug!(
            "[udp pool] [desc] current socket count: {}",
            self.state
                .cur_conn
                .load(std::sync::atomic::Ordering::Relaxed)
        );
    }

    /// Insert a sender into the idle cache.
    ///
    /// NOTE(review): `insert` silently replaces an existing entry with the
    /// same id (possible if two concurrent `get`s miss for the same client);
    /// the replaced connection's slot is reclaimed later by its driver task
    /// calling `remove_tx` — confirm this interleaving is acceptable.
    fn add_tx(&self, id: ID, tx: UdpTx) {
        let mut guard = self.free_conn_map.lock().unwrap();
        guard.insert(id, tx);
    }

    /// Remove a sender from the cache, releasing its connection slot.
    ///
    /// Called by the driver task when the connection terminates. If the
    /// entry was already evicted (by `get_tx` or `clear_bs_tx`) this is a
    /// no-op, so the slot is never double-released.
    fn remove_tx(&self, id: ID) {
        let mut guard = self.free_conn_map.lock().unwrap();
        if guard.remove(&id).is_some() {
            assert!(
                self.state
                    .cur_conn
                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
                    > 0
            );
            debug!(
                "[udp pool] [desc] current socket count: {}",
                self.state
                    .cur_conn
                    .load(std::sync::atomic::Ordering::Relaxed)
            );
        }
    }

    /// Create a new connection for client `a`: pick a backend, open a proxy
    /// socket to it, spawn the driver task, and cache the resulting sender.
    ///
    /// The caller (`UdpPool::get`) has already reserved a slot in `cur_conn`;
    /// on error the caller releases it.
    async fn create_tx(
        self: Arc<Self>,
        a: SocketAddr,
        client: Option<Arc<UdpSocket>>,
    ) -> Result<UdpTx, Error> {
        let mut id = ID::new(&a);
        let bs = self
            .state
            .lb_strategy
            .get_backend(&id.to_string())
            .ok_or(Error::NoBackend)?;
        // Fold the chosen backend into the id so clear_bs_tx can match it.
        id.set_bid(bs.hash_code());

        let proxy = create_udp_socket(&bs).await?;
        let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&self);

        let (tx, conn) = io(a, client, proxy, ka);

        // Drive the connection to completion on its own task.
        let pool = self.clone();
        tokio_spawn! {
            instrument_current_span! {
                async move {
                    let _res = conn.await;
                    debug!("[udp pool] udp socket close: {:?}", _res);
                    // The connection may exit on keepalive timeout while
                    // callers still hold a Sender; clean up the cache entry.
                    pool.remove_tx(id);
                }
            }
        };

        // By the time we cache the tx, the backend may already have been
        // removed. Because the cache holds the tx, a connection with no
        // keepalive configured could keep running indefinitely and pin a
        // slot of the connection limit. The cache and the backend set are
        // guarded by different locks, so no cross-consistency is enforced
        // between them for now.
        self.add_tx(id, tx.clone());
        Ok(tx)
    }
}
178
/// Abstraction over a pool that hands out UDP [`Sender`]s.
pub trait UdpPool {
    /// Return a sender associated with client address `a`, creating a new
    /// backend connection if no usable cached one exists.
    ///
    /// `send` optionally supplies an already-bound UDP socket — presumably
    /// the client-facing socket used to write responses back; confirm
    /// against `connection::Connection`.
    fn get(
        self: Arc<Self>,
        a: SocketAddr,
        send: Option<Arc<UdpSocket>>,
    ) -> impl Future<Output = Result<Sender, Error>> + Send;
}
186
187impl UdpPool for Pool {
188    async fn get(
189        self: Arc<Self>,
190        a: SocketAddr,
191        send: Option<Arc<UdpSocket>>,
192    ) -> Result<Sender, Error> {
193        let tx = match self.get_tx(&a) {
194            Some(s) => Ok(s),
195            None => {
196                // 预分配数量
197                net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
198                self.clone().create_tx(a, send).await.map(|s| {
199                    debug!(
200                        "[udp pool] [incr] current socket count: {}",
201                        self.state
202                            .cur_conn
203                            .load(std::sync::atomic::Ordering::Relaxed)
204                    );
205                    s
206                })
207            }
208        };
209
210        if tx.is_err() {
211            assert!(
212                self.state
213                    .cur_conn
214                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
215                    > 0
216            );
217        }
218
219        tx.map(|tx| Sender::new(tx))
220    }
221}
222
223pub async fn get<P: UdpPool>(
224    pool: Arc<P>,
225    a: SocketAddr,
226    send: Option<Arc<UdpSocket>>,
227) -> Result<Sender, Error> {
228    UdpPool::get(pool, a, send).await
229}
230
231fn io(
232    client_addr: SocketAddr,
233    client: Option<Arc<UdpSocket>>,
234    proxy: UdpSocket,
235    keepalive: Option<std::time::Duration>,
236) -> (UdpTx, connection::Connection) {
237    let (tx, rx) = tokio::sync::mpsc::channel(20);
238    (
239        tx,
240        connection::Connection::new(client_addr, client, proxy, rx, keepalive),
241    )
242}
243
244async fn create_udp_socket(bs: &BackendState) -> Result<UdpSocket, Error> {
245    let recv = UdpSocket::bind("0.0.0.0:0").await?;
246
247    let a = to_socket_addrs(bs.get_address()).await?;
248    recv.connect(&a[..]).await?;
249
250    Ok(recv)
251}
252
253async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
254    async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
255        tokio::net::lookup_host(host)
256            .await
257            .map(|a| a.into_iter().collect())
258    }
259
260    match addr {
261        Address::Ori(ori) => inner(ori).await,
262        Address::Addr(addr) => inner(addr).await,
263    }
264}