udp_pool/
pool.rs

1use crate::id::ID;
2use crate::{Sender, UdpTx, connection};
3use net_pool::backend::{Address, BackendState};
4use net_pool::error::Error;
5use net_pool::strategy::LbStrategy;
6use net_pool::{debug, instrument_current_span, tokio_spawn, trace};
7use std::collections::HashMap;
8use std::io;
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use tokio::net::UdpSocket;
12
/// UDP connection pool.
///
/// The number of pooled connections is bounded by `max_conn`
/// (see `state.max_conn`). The default keepalive is five minutes.
/// Two events cause pooled connections to be released:
///     1. `remove_backend` clears a backend address, which invalidates every
///        connection to that address.
///     2. A connection sends no data for the keepalive period.
pub struct Pool {
    // Shared pool bookkeeping: load-balancing strategy, max/current
    // connection counters and the keepalive duration.
    state: net_pool::pool::BaseState,
    // Cached senders keyed by connection ID. Guarded by a blocking mutex;
    // the guard is only held for short, non-await critical sections.
    free_conn_map: Mutex<HashMap<ID, UdpTx>>,
}
23
24impl Pool {
25    pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
26        let p = Pool {
27            state: net_pool::pool::BaseState::new(strategy),
28            free_conn_map: Mutex::new(HashMap::new()),
29        };
30        <Pool as net_pool::pool::Pool>::set_keepalive(
31            &p,
32            Some(std::time::Duration::from_secs(60 * 5)),
33        );
34        p
35    }
36}
37
38impl Default for Pool {
39    fn default() -> Self {
40        // robin round
41        Pool::new(Arc::new(net_pool::strategy::RRStrategy::default()))
42    }
43}
44
45impl net_pool::pool::Pool for Pool {
46    net_pool::macros::base_pool_impl! {state}
47
48    fn remove_backend(&self, addr: &Address) -> bool {
49        if self.state.lb_strategy.remove_backend(addr) {
50            // 清除缓存
51            self.clear_bs_tx(addr);
52            true
53        } else {
54            false
55        }
56    }
57}
58
impl Pool {
    /// Look up a cached sender for client address `a`.
    ///
    /// Returns `None` when nothing is cached for `a`, or when the cached
    /// connection has already closed — in that case the stale entry is
    /// evicted and `cur_conn` is decremented here.
    fn get_tx(&self, a: &SocketAddr) -> Option<UdpTx> {
        let id = a.into();
        let mut guard = self.free_conn_map.lock().unwrap();

        // Closed entries are evicted lazily: only an entry that is looked up
        // again gets removed here, so some closed entries can linger until
        // their driver task calls `remove_tx` (or they are looked up).
        let tx = guard.get(&id)?;
        if tx.is_closed() {
            guard.remove(&id);
            // The counter must have been positive — it was incremented when
            // this connection was created.
            assert!(
                self.state
                    .cur_conn
                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
                    > 0
            );
            trace!(
                "[udp pool] [desc] current socket count: {}",
                self.state
                    .cur_conn
                    .load(std::sync::atomic::Ordering::Relaxed)
            );
            None
        } else {
            Some(tx.clone())
        }
    }

    /// Drop every cached connection bound to backend `a` (matched by backend
    /// hash code) and decrement `cur_conn` for each one removed.
    ///
    /// Only the cache entry is cleared here: a connection whose sender is
    /// still held by an external caller may not exit immediately, even though
    /// `cur_conn` has already been adjusted.
    fn clear_bs_tx(&self, a: &Address) {
        let mut guard = self.free_conn_map.lock().unwrap();
        guard.retain(|k, _| {
            if k.get_bid() != a.hash_code() {
                true
            } else {
                assert!(
                    self.state
                        .cur_conn
                        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
                        > 0
                );
                false
            }
        });

        trace!(
            "[udp pool] [desc] current socket count: {}",
            self.state
                .cur_conn
                .load(std::sync::atomic::Ordering::Relaxed)
        );
    }

    /// Cache `tx` under `id`.
    ///
    /// NOTE(review): `insert` silently overwrites an existing live entry with
    /// the same `id` (possible when two tasks create a connection for the
    /// same client address concurrently); the orphaned connection's driver
    /// task will later call `remove_tx` and evict the *new* entry, which can
    /// skew `cur_conn` — confirm whether callers serialize per address.
    fn add_tx(&self, id: ID, tx: UdpTx) {
        let mut guard = self.free_conn_map.lock().unwrap();
        guard.insert(id, tx);
    }

    /// Evict the cache entry for `id` (if still present) and decrement
    /// `cur_conn`. Called by the driver task when a connection terminates;
    /// a missing entry means it was already evicted elsewhere.
    fn remove_tx(&self, id: ID) {
        let mut guard = self.free_conn_map.lock().unwrap();
        if guard.remove(&id).is_some() {
            assert!(
                self.state
                    .cur_conn
                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
                    > 0
            );
            trace!(
                "[udp pool] [desc] current socket count: {}",
                self.state
                    .cur_conn
                    .load(std::sync::atomic::Ordering::Relaxed)
            );
        }
    }

    /// Create a new proxied UDP connection for client address `a`.
    ///
    /// Picks a backend via the load-balancing strategy, opens and connects a
    /// fresh UDP socket to it, spawns a driver task for the connection, and
    /// caches the resulting sender. Returns `Error::NoBackend` when the
    /// strategy has no backend for this key; socket errors are propagated.
    ///
    /// The caller (`UdpPool::get`) has already incremented `cur_conn` before
    /// calling this, and decrements it if this returns `Err`.
    async fn create_tx(
        self: Arc<Self>,
        a: SocketAddr,
        client: Option<Arc<UdpSocket>>,
    ) -> Result<UdpTx, Error> {
        let mut id = ID::new(&a);
        let bs = self
            .state
            .lb_strategy
            .get_backend(&id.to_string())
            .ok_or(Error::NoBackend)?;
        id.set_bid(bs.hash_code());

        let proxy = create_udp_socket(&bs).await?;
        let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&self);

        let (tx, conn) = io(a, client, proxy, ka);

        // Drive the connection to completion on its own task.
        let pool = self.clone();
        tokio_spawn! {
            instrument_current_span! {
                async move {
                    let _res = conn.await;
                    debug!("[udp pool] udp socket close: {:?}", _res);
                    // The connection may exit due to keepalive timeout while
                    // an external caller still holds a sender; evict its
                    // cache entry and release its slot.
                    pool.remove_tx(id);
                }
            }
        };

        self.add_tx(id, tx.clone());
        Ok(tx)
    }
}
169
/// Abstraction over a pool that hands out UDP senders.
pub trait UdpPool {
    /// Obtain a [`Sender`] for client address `a`.
    ///
    /// `send` is an optional already-bound client socket forwarded to the
    /// underlying connection. Implementations may reuse a cached connection
    /// or create a new one; errors surface pool limits or socket failures.
    fn get(
        self: Arc<Self>,
        a: SocketAddr,
        send: Option<Arc<UdpSocket>>,
    ) -> impl Future<Output = Result<Sender, Error>> + Send;
}
177
178impl UdpPool for Pool {
179    async fn get(
180        self: Arc<Self>,
181        a: SocketAddr,
182        send: Option<Arc<UdpSocket>>,
183    ) -> Result<Sender, Error> {
184        let tx = match self.get_tx(&a) {
185            Some(s) => Ok(s),
186            None => {
187                // 预分配数量
188                net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
189                self.clone().create_tx(a, send).await.map(|s| {
190                    trace!(
191                        "[udp pool] [incr] current socket count: {}",
192                        self.state
193                            .cur_conn
194                            .load(std::sync::atomic::Ordering::Relaxed)
195                    );
196                    s
197                })
198            }
199        };
200
201        if tx.is_err() {
202            assert!(
203                self.state
204                    .cur_conn
205                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
206                    > 0
207            );
208        }
209
210        tx.map(|tx| Sender::new(tx))
211    }
212}
213
214pub async fn get<P: UdpPool>(
215    pool: Arc<P>,
216    a: SocketAddr,
217    send: Option<Arc<UdpSocket>>,
218) -> Result<Sender, Error> {
219    UdpPool::get(pool, a, send).await
220}
221
222fn io(
223    client_addr: SocketAddr,
224    client: Option<Arc<UdpSocket>>,
225    proxy: UdpSocket,
226    keepalive: Option<std::time::Duration>,
227) -> (UdpTx, connection::Connection) {
228    let (tx, rx) = tokio::sync::mpsc::channel(20);
229    (
230        tx,
231        connection::Connection::new(client_addr, client, proxy, rx, keepalive),
232    )
233}
234
235async fn create_udp_socket(bs: &BackendState) -> Result<UdpSocket, Error> {
236    let recv = UdpSocket::bind("0.0.0.0:0").await?;
237
238    let a = to_socket_addrs(bs.get_address()).await?;
239    recv.connect(&a[..]).await?;
240
241    Ok(recv)
242}
243
244async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
245    async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
246        tokio::net::lookup_host(host)
247            .await
248            .map(|a| a.into_iter().collect())
249    }
250
251    match addr {
252        Address::Ori(ori) => inner(ori).await,
253        Address::Addr(addr) => inner(addr).await,
254    }
255}