1use async_lock::Mutex;
2use std::collections::HashMap;
3use std::convert::TryFrom;
4use std::error::Error;
5use std::future::Future;
6use std::io;
7use std::marker::PhantomData;
8use std::net::{SocketAddr, ToSocketAddrs};
9use std::sync::Arc;
10use std::time::Duration;
11
12use crate::peer::{UDPPeer, UdpPeer, UdpReader};
13use net2::{UdpBuilder, UdpSocketExt};
14use tokio::net::UdpSocket;
15use tokio::sync::mpsc::unbounded_channel;
16
/// Size in bytes of the stack buffer each listener task passes to `recv_from`.
///
/// NOTE(review): a datagram larger than this is presumably truncated by the OS / Tokio
/// `recv_from` — confirm against the target platforms.
pub const BUFF_MAX_SIZE: usize = 4096;
20
/// Per-socket state: `UdpServer::new` creates one context for each bound socket.
pub struct UdpContext {
    /// Position of this context in the server's context list (the `enumerate` index
    /// assigned in `UdpServer::new`).
    pub id: usize,
    /// The bound socket that datagrams are received on. Every peer created from this
    /// context receives a clone of this `Arc`.
    recv: Arc<UdpSocket>,
    /// Live peers on this socket, keyed by remote address. An entry is inserted on the
    /// first datagram from an address and removed once that peer's `input` task finishes.
    pub peers: Mutex<HashMap<SocketAddr, UDPPeer>>,
}
28
// SAFETY(review): these impls assert that `UdpContext` can be moved to and shared between
// threads. It is captured by `tokio::spawn`ed tasks through `Arc` in `UdpServer::start`.
// `usize`, `Arc<UdpSocket>` and `async_lock::Mutex<HashMap<SocketAddr, UDPPeer>>` are
// presumably already Send + Sync whenever `UDPPeer: Send`. In that case these manual impls
// are redundant. If `UDPPeer` is NOT `Send`, these impls override the compiler's check and
// must be justified by an invariant of `crate::peer`.
// TODO: confirm against `crate::peer` and remove these impls if they are redundant.
unsafe impl Send for UdpContext {}
unsafe impl Sync for UdpContext {}
31
/// A UDP server that demultiplexes incoming datagrams by remote address into per-peer
/// streams and runs the user-supplied `input` handler once for each new peer.
pub struct UdpServer<I, T> {
    /// One context per bound socket.
    udp_contexts: Vec<Arc<UdpContext>>,
    /// Handler that `start` invokes as `input(peer, reader, inner.clone())` for each new peer.
    input: Arc<I>,
    /// Records the type `T` of the `inner` value passed to `input`; no `T` is stored here.
    _ph: PhantomData<T>,
    /// Idle timeout in seconds; `None` disables the background expiry sweep.
    clean_sec: Option<u64>,
}
39
40impl<I, R, T> UdpServer<I, T>
41where
42 I: Fn(UDPPeer, UdpReader, T) -> R + Send + Sync + 'static,
43 R: Future<Output = Result<(), Box<dyn Error>>> + Send + 'static,
44 T: Sync + Send + Clone + 'static,
45{
46 pub fn new<A: ToSocketAddrs>(addr: A, input: I) -> io::Result<Self> {
48 let udp_list = create_udp_socket_list(&addr, get_cpu_count())?;
49 let udp_contexts = udp_list
50 .into_iter()
51 .enumerate()
52 .map(|(id, socket)| {
53 Arc::new(UdpContext {
54 id,
55 recv: Arc::new(socket),
56 peers: Default::default(),
57 })
58 })
59 .collect();
60 Ok(UdpServer {
61 udp_contexts,
62 input: Arc::new(input),
63 _ph: Default::default(),
64 clean_sec: None,
65 })
66 }
67
68 #[inline]
70 pub fn set_peer_timeout_sec(mut self, sec: u64) -> UdpServer<I, T> {
71 assert!(sec > 0);
72 self.clean_sec = Some(sec);
73 self
74 }
75
76 #[inline]
78 pub async fn start(&self, inner: T) -> io::Result<()> {
79 let need_check_timeout = {
80 if let Some(clean_sec) = self.clean_sec {
81 let clean_sec = clean_sec as i64;
82 let contexts = self.udp_contexts.clone();
83 tokio::spawn(async move {
84 loop {
85 let current = chrono::Utc::now().timestamp();
86 for context in contexts.iter() {
87 context.peers.lock().await.values().for_each(|peer| {
88 if current - peer.get_last_recv_sec() > clean_sec {
89 peer.close();
90 }
91 });
92 }
93 tokio::time::sleep(Duration::from_secs(1)).await
94 }
95 });
96 true
97 } else {
98 false
99 }
100 };
101
102 let (tx, mut rx) = unbounded_channel();
103 for (index, udp_listen) in self.udp_contexts.iter().enumerate() {
104 let create_peer_tx = tx.clone();
105 let udp_context = udp_listen.clone();
106 tokio::spawn(async move {
107 log::debug!("start udp listen:{index}");
108 let mut buff = [0; BUFF_MAX_SIZE];
109 loop {
110 match udp_context.recv.recv_from(&mut buff).await {
111 Ok((size, addr)) => {
112 let peer = {
113 udp_context
114 .peers
115 .lock()
116 .await
117 .entry(addr)
118 .or_insert_with(|| {
119 let (peer, reader) =
120 UdpPeer::new(index, udp_context.recv.clone(), addr);
121 log::trace!("create udp listen:{index} udp peer:{addr}");
122 if let Err(err) =
123 create_peer_tx.send((peer.clone(), reader, index, addr))
124 {
125 panic!("create_peer_tx err:{}", err);
126 }
127 peer
128 })
129 .clone()
130 };
131
132 if need_check_timeout {
133 if let Err(err) = peer
134 .push_data_and_update_instant(buff[..size].to_vec())
135 .await
136 {
137 log::error!("peer push data and update instant is error:{err}");
138 }
139 } else if let Err(err) = peer.push_data(buff[..size].to_vec()) {
140 log::error!("peer push data is error:{err}");
141 }
142 }
143 Err(err) => {
144 log::trace!("udp:{index} recv_from error:{err}");
145 }
146 }
147 }
148 });
149 }
150 drop(tx);
151
152 while let Some((peer, reader, index, addr)) = rx.recv().await {
153 let inner = inner.clone();
154 let input_fn = self.input.clone();
155 let context = self
156 .udp_contexts
157 .get(index)
158 .expect("not found context")
159 .clone();
160 tokio::spawn(async move {
161 if let Err(err) = (input_fn)(peer, reader, inner).await {
162 log::error!("udp input error:{err}")
163 }
164 context.peers.lock().await.remove(&addr);
165 });
166 }
167 Ok(())
168 }
169}
170
/// Creates a blocking UDP socket with `SO_REUSEADDR` enabled and binds it to `addr`.
///
/// `SO_REUSEPORT` is not set on this target.
#[cfg(target_os = "windows")]
fn make_udp_client(addr: SocketAddr) -> io::Result<std::net::UdpSocket> {
    // `SocketAddr` has exactly two variants, so this match is exhaustive. The former
    // "not address AF_INET" fallback branch could never be reached.
    let builder = match addr {
        SocketAddr::V4(_) => UdpBuilder::new_v4()?,
        SocketAddr::V6(_) => UdpBuilder::new_v6()?,
    };
    builder.reuse_address(true)?;
    builder.bind(addr)
}
182
183#[cfg(not(target_os = "windows"))]
185fn make_udp_client(addr: SocketAddr) -> io::Result<std::net::UdpSocket> {
186 use net2::unix::UnixUdpBuilderExt;
187 if addr.is_ipv4() {
188 Ok(UdpBuilder::new_v4()?
189 .reuse_address(true)?
190 .reuse_port(true)?
191 .bind(addr)?)
192 } else if addr.is_ipv6() {
193 Ok(UdpBuilder::new_v6()?
194 .reuse_address(true)?
195 .reuse_port(true)?
196 .bind(addr)?)
197 } else {
198 Err(io::Error::new(io::ErrorKind::Other, "not address AF_INET"))
199 }
200}
201
202fn create_udp_socket<A: ToSocketAddrs>(addr: &A) -> io::Result<std::net::UdpSocket> {
204 let addr = {
205 let mut addrs = addr.to_socket_addrs()?;
206 let addr = match addrs.next() {
207 Some(addr) => addr,
208 None => {
209 return Err(io::Error::new(
210 io::ErrorKind::Other,
211 "no socket addresses could be resolved",
212 ))
213 }
214 };
215 if addrs.next().is_none() {
216 Ok(addr)
217 } else {
218 Err(io::Error::new(
219 io::ErrorKind::Other,
220 "more than one address resolved",
221 ))
222 }
223 };
224 let res = make_udp_client(addr?)?;
225 res.set_send_buffer_size(1784 * 10000)?;
226 res.set_recv_buffer_size(1784 * 10000)?;
227 Ok(res)
228}
229
230fn create_async_udp_socket<A: ToSocketAddrs>(addr: &A) -> io::Result<UdpSocket> {
232 let std_sock = create_udp_socket(&addr)?;
233 std_sock.set_nonblocking(true)?;
234 let sock = UdpSocket::try_from(std_sock)?;
235 Ok(sock)
236}
237
238fn create_udp_socket_list<A: ToSocketAddrs>(
241 addr: &A,
242 listen_count: usize,
243) -> io::Result<Vec<UdpSocket>> {
244 log::debug!("cpus:{listen_count}");
245 let mut listens = Vec::with_capacity(listen_count);
246 for _ in 0..listen_count {
247 let sock = create_async_udp_socket(addr)?;
248 listens.push(sock);
249 }
250 Ok(listens)
251}
252
/// Number of listener sockets `UdpServer::new` opens on this target: one per CPU, as
/// reported by `num_cpus::get()`.
#[cfg(not(target_os = "windows"))]
fn get_cpu_count() -> usize {
    num_cpus::get()
}
257
/// Number of listener sockets `UdpServer::new` opens on Windows: always one.
///
/// NOTE(review): presumably this is because the Windows `make_udp_client` sets only
/// `SO_REUSEADDR` (no `SO_REUSEPORT`), so several sockets on one address would not share
/// load predictably — confirm.
#[cfg(target_os = "windows")]
fn get_cpu_count() -> usize {
    1
}