// udp_server/udp_serv.rs

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::error::Error;
use std::future::Future;
use std::io;
use std::marker::PhantomData;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use std::time::Duration;

use async_lock::Mutex;
use net2::{UdpBuilder, UdpSocketExt};
use tokio::net::UdpSocket;
use tokio::sync::mpsc::unbounded_channel;

use crate::peer::{UDPPeer, UdpPeer, UdpReader};

/// Size in bytes of the buffer each receive loop passes to `recv_from`.
///
/// Datagrams crossing the Internet rarely exceed the ~1500-byte Ethernet MTU,
/// and 4096 bytes leaves headroom for LANs with larger MTUs. A datagram longer
/// than this does not fit in the buffer, and its excess bytes may be discarded.
pub const BUFF_MAX_SIZE: usize = 4 * 1024;

21/// UDP Context
22/// each bind will create a
23pub struct UdpContext {
24    pub id: usize,
25    recv: Arc<UdpSocket>,
26    pub peers: Mutex<HashMap<SocketAddr, UDPPeer>>,
27}
28
29unsafe impl Send for UdpContext {}
30unsafe impl Sync for UdpContext {}
31
32/// UDP Server listen
33pub struct UdpServer<I, T> {
34    udp_contexts: Vec<Arc<UdpContext>>,
35    input: Arc<I>,
36    _ph: PhantomData<T>,
37    clean_sec: Option<u64>,
38}
39
40impl<I, R, T> UdpServer<I, T>
41where
42    I: Fn(UDPPeer, UdpReader, T) -> R + Send + Sync + 'static,
43    R: Future<Output = Result<(), Box<dyn Error>>> + Send + 'static,
44    T: Sync + Send + Clone + 'static,
45{
46    /// new udp server
47    pub fn new<A: ToSocketAddrs>(addr: A, input: I) -> io::Result<Self> {
48        let udp_list = create_udp_socket_list(&addr, get_cpu_count())?;
49        let udp_contexts = udp_list
50            .into_iter()
51            .enumerate()
52            .map(|(id, socket)| {
53                Arc::new(UdpContext {
54                    id,
55                    recv: Arc::new(socket),
56                    peers: Default::default(),
57                })
58            })
59            .collect();
60        Ok(UdpServer {
61            udp_contexts,
62            input: Arc::new(input),
63            _ph: Default::default(),
64            clean_sec: None,
65        })
66    }
67
68    /// set how long the packet is not obtained and close the udp peer
69    #[inline]
70    pub fn set_peer_timeout_sec(mut self, sec: u64) -> UdpServer<I, T> {
71        assert!(sec > 0);
72        self.clean_sec = Some(sec);
73        self
74    }
75
76    /// start server
77    #[inline]
78    pub async fn start(&self, inner: T) -> io::Result<()> {
79        let need_check_timeout = {
80            if let Some(clean_sec) = self.clean_sec {
81                let clean_sec = clean_sec as i64;
82                let contexts = self.udp_contexts.clone();
83                tokio::spawn(async move {
84                    loop {
85                        let current = chrono::Utc::now().timestamp();
86                        for context in contexts.iter() {
87                            context.peers.lock().await.values().for_each(|peer| {
88                                if current - peer.get_last_recv_sec() > clean_sec {
89                                    peer.close();
90                                }
91                            });
92                        }
93                        tokio::time::sleep(Duration::from_secs(1)).await
94                    }
95                });
96                true
97            } else {
98                false
99            }
100        };
101
102        let (tx, mut rx) = unbounded_channel();
103        for (index, udp_listen) in self.udp_contexts.iter().enumerate() {
104            let create_peer_tx = tx.clone();
105            let udp_context = udp_listen.clone();
106            tokio::spawn(async move {
107                log::debug!("start udp listen:{index}");
108                let mut buff = [0; BUFF_MAX_SIZE];
109                loop {
110                    match udp_context.recv.recv_from(&mut buff).await {
111                        Ok((size, addr)) => {
112                            let peer = {
113                                udp_context
114                                    .peers
115                                    .lock()
116                                    .await
117                                    .entry(addr)
118                                    .or_insert_with(|| {
119                                        let (peer, reader) =
120                                            UdpPeer::new(index, udp_context.recv.clone(), addr);
121                                        log::trace!("create udp listen:{index} udp peer:{addr}");
122                                        if let Err(err) =
123                                            create_peer_tx.send((peer.clone(), reader, index, addr))
124                                        {
125                                            panic!("create_peer_tx err:{}", err);
126                                        }
127                                        peer
128                                    })
129                                    .clone()
130                            };
131
132                            if need_check_timeout {
133                                if let Err(err) = peer
134                                    .push_data_and_update_instant(buff[..size].to_vec())
135                                    .await
136                                {
137                                    log::error!("peer push data and update instant is error:{err}");
138                                }
139                            } else if let Err(err) = peer.push_data(buff[..size].to_vec()) {
140                                log::error!("peer push data is error:{err}");
141                            }
142                        }
143                        Err(err) => {
144                            log::trace!("udp:{index} recv_from error:{err}");
145                        }
146                    }
147                }
148            });
149        }
150        drop(tx);
151
152        while let Some((peer, reader, index, addr)) = rx.recv().await {
153            let inner = inner.clone();
154            let input_fn = self.input.clone();
155            let context = self
156                .udp_contexts
157                .get(index)
158                .expect("not found context")
159                .clone();
160            tokio::spawn(async move {
161                if let Err(err) = (input_fn)(peer, reader, inner).await {
162                    log::error!("udp input error:{err}")
163                }
164                context.peers.lock().await.remove(&addr);
165            });
166        }
167        Ok(())
168    }
169}
170
171///Create udp socket for windows
172#[cfg(target_os = "windows")]
173fn make_udp_client(addr: SocketAddr) -> io::Result<std::net::UdpSocket> {
174    if addr.is_ipv4() {
175        Ok(UdpBuilder::new_v4()?.reuse_address(true)?.bind(addr)?)
176    } else if addr.is_ipv6() {
177        Ok(UdpBuilder::new_v6()?.reuse_address(true)?.bind(addr)?)
178    } else {
179        Err(io::Error::new(io::ErrorKind::Other, "not address AF_INET"))
180    }
181}
182
183///It is used to create udp sockets for non-windows. The difference from windows is that reuse_port
184#[cfg(not(target_os = "windows"))]
185fn make_udp_client(addr: SocketAddr) -> io::Result<std::net::UdpSocket> {
186    use net2::unix::UnixUdpBuilderExt;
187    if addr.is_ipv4() {
188        Ok(UdpBuilder::new_v4()?
189            .reuse_address(true)?
190            .reuse_port(true)?
191            .bind(addr)?)
192    } else if addr.is_ipv6() {
193        Ok(UdpBuilder::new_v6()?
194            .reuse_address(true)?
195            .reuse_port(true)?
196            .bind(addr)?)
197    } else {
198        Err(io::Error::new(io::ErrorKind::Other, "not address AF_INET"))
199    }
200}
201
202///Create a udp socket and set the buffer size
203fn create_udp_socket<A: ToSocketAddrs>(addr: &A) -> io::Result<std::net::UdpSocket> {
204    let addr = {
205        let mut addrs = addr.to_socket_addrs()?;
206        let addr = match addrs.next() {
207            Some(addr) => addr,
208            None => {
209                return Err(io::Error::new(
210                    io::ErrorKind::Other,
211                    "no socket addresses could be resolved",
212                ))
213            }
214        };
215        if addrs.next().is_none() {
216            Ok(addr)
217        } else {
218            Err(io::Error::new(
219                io::ErrorKind::Other,
220                "more than one address resolved",
221            ))
222        }
223    };
224    let res = make_udp_client(addr?)?;
225    res.set_send_buffer_size(1784 * 10000)?;
226    res.set_recv_buffer_size(1784 * 10000)?;
227    Ok(res)
228}
229
230/// From std socket create tokio udp socket
231fn create_async_udp_socket<A: ToSocketAddrs>(addr: &A) -> io::Result<UdpSocket> {
232    let std_sock = create_udp_socket(&addr)?;
233    std_sock.set_nonblocking(true)?;
234    let sock = UdpSocket::try_from(std_sock)?;
235    Ok(sock)
236}
237
238/// create tokio UDP socket list
239/// listen_count indicates how many UDP SOCKETS to listen
240fn create_udp_socket_list<A: ToSocketAddrs>(
241    addr: &A,
242    listen_count: usize,
243) -> io::Result<Vec<UdpSocket>> {
244    log::debug!("cpus:{listen_count}");
245    let mut listens = Vec::with_capacity(listen_count);
246    for _ in 0..listen_count {
247        let sock = create_async_udp_socket(addr)?;
248        listens.push(sock);
249    }
250    Ok(listens)
251}
252
253#[cfg(not(target_os = "windows"))]
254fn get_cpu_count() -> usize {
255    num_cpus::get()
256}
257
258#[cfg(target_os = "windows")]
259fn get_cpu_count() -> usize {
260    1
261}