1use crate::id::ID;
2use crate::{Sender, UdpTx, connection};
3use net_pool::backend::{Address, BackendState};
4use net_pool::error::Error;
5use net_pool::strategy::LbStrategy;
6use net_pool::{debug, instrument_current_span, tokio_spawn, trace};
7use std::collections::HashMap;
8use std::io;
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use tokio::net::UdpSocket;
12
13pub struct Pool {
20 state: net_pool::pool::BaseState,
21 free_conn_map: Mutex<HashMap<ID, UdpTx>>,
22}
23
24impl Pool {
25 pub fn new(strategy: Arc<dyn LbStrategy>) -> Self {
26 let p = Pool {
27 state: net_pool::pool::BaseState::new(strategy),
28 free_conn_map: Mutex::new(HashMap::new()),
29 };
30 <Pool as net_pool::pool::Pool>::set_keepalive(
31 &p,
32 Some(std::time::Duration::from_secs(60 * 5)),
33 );
34 p
35 }
36}
37
38impl Default for Pool {
39 fn default() -> Self {
40 Pool::new(Arc::new(net_pool::strategy::RRStrategy::default()))
42 }
43}
44
45impl net_pool::pool::Pool for Pool {
46 net_pool::macros::base_pool_impl! {state}
47
48 fn remove_backend(&self, addr: &Address) -> bool {
49 if self.state.lb_strategy.remove_backend(addr) {
50 self.clear_bs_tx(addr);
52 true
53 } else {
54 false
55 }
56 }
57}
58
59impl Pool {
60 fn get_tx(&self, a: &SocketAddr) -> Option<UdpTx> {
61 let id = a.into();
62 let mut guard = self.free_conn_map.lock().unwrap();
63
64 let tx = guard.get(&id)?;
66 if tx.is_closed() {
67 guard.remove(&id);
68 assert!(
69 self.state
70 .cur_conn
71 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
72 > 0
73 );
74 trace!(
75 "[udp pool] [desc] current socket count: {}",
76 self.state
77 .cur_conn
78 .load(std::sync::atomic::Ordering::Relaxed)
79 );
80 None
81 } else {
82 Some(tx.clone())
83 }
84 }
85
86 fn clear_bs_tx(&self, a: &Address) {
88 let mut guard = self.free_conn_map.lock().unwrap();
89 guard.retain(|k, _| {
90 if k.get_bid() != a.hash_code() {
91 true
92 } else {
93 assert!(
94 self.state
95 .cur_conn
96 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
97 > 0
98 );
99 false
100 }
101 });
102
103 trace!(
104 "[udp pool] [desc] current socket count: {}",
105 self.state
106 .cur_conn
107 .load(std::sync::atomic::Ordering::Relaxed)
108 );
109 }
110
111 fn add_tx(&self, id: ID, tx: UdpTx) {
112 let mut guard = self.free_conn_map.lock().unwrap();
113 guard.insert(id, tx);
114 }
115
116 fn remove_tx(&self, id: ID) {
117 let mut guard = self.free_conn_map.lock().unwrap();
118 if guard.remove(&id).is_some() {
119 assert!(
120 self.state
121 .cur_conn
122 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
123 > 0
124 );
125 trace!(
126 "[udp pool] [desc] current socket count: {}",
127 self.state
128 .cur_conn
129 .load(std::sync::atomic::Ordering::Relaxed)
130 );
131 }
132 }
133
134 async fn create_tx(
135 self: Arc<Self>,
136 a: SocketAddr,
137 client: Option<Arc<UdpSocket>>,
138 ) -> Result<UdpTx, Error> {
139 let mut id = ID::new(&a);
140 let bs = self
141 .state
142 .lb_strategy
143 .get_backend(&id.to_string())
144 .ok_or(Error::NoBackend)?;
145 id.set_bid(bs.hash_code());
146
147 let proxy = create_udp_socket(&bs).await?;
148 let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&self);
149
150 let (tx, conn) = io(a, client, proxy, ka);
151
152 let pool = self.clone();
154 tokio_spawn! {
155 instrument_current_span! {
156 async move {
157 let _res = conn.await;
158 debug!("[udp pool] udp socket close: {:?}", _res);
159 pool.remove_tx(id);
161 }
162 }
163 };
164
165 self.add_tx(id, tx.clone());
166 Ok(tx)
167 }
168}
169
170pub trait UdpPool {
171 fn get(
172 self: Arc<Self>,
173 a: SocketAddr,
174 send: Option<Arc<UdpSocket>>,
175 ) -> impl Future<Output = Result<Sender, Error>> + Send;
176}
177
178impl UdpPool for Pool {
179 async fn get(
180 self: Arc<Self>,
181 a: SocketAddr,
182 send: Option<Arc<UdpSocket>>,
183 ) -> Result<Sender, Error> {
184 let tx = match self.get_tx(&a) {
185 Some(s) => Ok(s),
186 None => {
187 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
189 self.clone().create_tx(a, send).await.map(|s| {
190 trace!(
191 "[udp pool] [incr] current socket count: {}",
192 self.state
193 .cur_conn
194 .load(std::sync::atomic::Ordering::Relaxed)
195 );
196 s
197 })
198 }
199 };
200
201 if tx.is_err() {
202 assert!(
203 self.state
204 .cur_conn
205 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
206 > 0
207 );
208 }
209
210 tx.map(|tx| Sender::new(tx))
211 }
212}
213
214pub async fn get<P: UdpPool>(
215 pool: Arc<P>,
216 a: SocketAddr,
217 send: Option<Arc<UdpSocket>>,
218) -> Result<Sender, Error> {
219 UdpPool::get(pool, a, send).await
220}
221
222fn io(
223 client_addr: SocketAddr,
224 client: Option<Arc<UdpSocket>>,
225 proxy: UdpSocket,
226 keepalive: Option<std::time::Duration>,
227) -> (UdpTx, connection::Connection) {
228 let (tx, rx) = tokio::sync::mpsc::channel(20);
229 (
230 tx,
231 connection::Connection::new(client_addr, client, proxy, rx, keepalive),
232 )
233}
234
235async fn create_udp_socket(bs: &BackendState) -> Result<UdpSocket, Error> {
236 let recv = UdpSocket::bind("0.0.0.0:0").await?;
237
238 let a = to_socket_addrs(bs.get_address()).await?;
239 recv.connect(&a[..]).await?;
240
241 Ok(recv)
242}
243
244async fn to_socket_addrs(addr: &Address) -> io::Result<Vec<SocketAddr>> {
245 async fn inner<T: tokio::net::ToSocketAddrs>(host: &T) -> io::Result<Vec<SocketAddr>> {
246 tokio::net::lookup_host(host)
247 .await
248 .map(|a| a.into_iter().collect())
249 }
250
251 match addr {
252 Address::Ori(ori) => inner(ori).await,
253 Address::Addr(addr) => inner(addr).await,
254 }
255}