use super::error;
use super::send::{SendPool, SendUDP};
use net2::{UdpBuilder, UdpSocketExt};
use std::cell::{RefCell, UnsafeCell};
use std::convert::TryFrom;
use std::error::Error;
use std::future::Future;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::sync::mpsc::{channel, Sender, UnboundedSender};
use tokio::sync::Mutex;
#[cfg(not(target_os = "windows"))]
use net2::unix::UnixUdpBuilderExt;
use std::fmt::{Debug, Formatter};
use std::marker::PhantomData;
/// Message delivered to the dispatch loop in `UdpServer::start`.
pub enum RecvType {
    /// A received datagram: the send handle taken from the receiving socket's
    /// `SendPool` (`get_tx`), the peer address, and the payload bytes.
    INPUT(UnboundedSender<(Vec<u8>, SocketAddr)>, SocketAddr, Vec<u8>),
    /// Asks the dispatch loop to run the configured remove callback with this id.
    /// NOTE(review): never constructed in this file; presumably sent by callers
    /// through the sender returned by `UdpServer::get_msg_tx`.
    REMOVE(u32),
}
/// Compact `Debug` output suitable for log and error messages.
///
/// tokio's `SendError<RecvType>` only implements `Error` when `RecvType: Debug`.
/// The receive tasks in `UdpServer::start` rely on this when they box a failed
/// channel send.
impl Debug for RecvType {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            // The reply sender carries no useful diagnostic information, and the
            // payload is summarised by its length so large datagrams do not flood
            // the output.
            RecvType::INPUT(_, addr, data) => f
                .debug_struct("INPUT")
                .field("addr", addr)
                .field("len", &data.len())
                .finish(),
            RecvType::REMOVE(id) => f.debug_tuple("REMOVE").field(id).finish(),
        }
    }
}
/// Size in bytes of the buffer each receive task in `UdpServer::start` reads
/// into. Only the received length is copied into the forwarded `Vec<u8>`.
/// NOTE(review): longer datagrams are presumably truncated by `recv_from` —
/// confirm 4096 bytes is large enough for the expected traffic.
pub const BUFF_MAX_SIZE: usize = 4096;
/// One bound socket of the server, with its receive and send sides.
pub struct UdpContext {
    /// Index assigned in `UdpServer::new_inner`, counting from 1.
    pub id: usize,
    /// Receive handle. `UdpServer::start` takes it (leaving `None`) when it
    /// spawns the receive task for this socket.
    recv: RefCell<Option<Arc<UdpSocket>>>,
    /// Send pool built from the same socket (`SendPool::new`).
    pub send: SendPool,
}
// SAFETY(review): `RefCell` is not `Sync`, so the `Sync` impl is only sound if
// `recv` is never borrowed from two threads at once. The only runtime access to
// `recv` visible here is `borrow_mut().take()` in `UdpServer::start`; two
// concurrent `start` calls on a shared server would race on it. `SendPool`'s
// own thread-safety is not visible here — TODO confirm. Consider
// `std::sync::Mutex<Option<Arc<UdpSocket>>>` so these impls become unnecessary.
unsafe impl Send for UdpContext {}
unsafe impl Sync for UdpContext {}
/// Shared, lock-protected error callback.
///
/// It receives the peer address (when one is known) and the error. Where
/// `UdpServer::start` consults the result, `true` means "stop": a receive task
/// exits after an I/O error, and the dispatch loop ends after a handler error.
pub type ErrorInput = Arc<Mutex<dyn Fn(Option<SocketAddr>, Box<dyn Error>) -> bool + Send>>;
/// UDP server that binds several sockets to one address and dispatches every
/// received datagram to a single input handler.
///
/// `I` is the handler, `R` the future it returns (tracked only via `phantom`),
/// and `S` the shared state passed to each handler call.
pub struct UdpServer<I, R, S> {
    /// Shared state; an `Arc` clone is passed to every handler call.
    inner: Arc<S>,
    /// One context per bound socket.
    udp_contexts: Vec<UdpContext>,
    /// Handler called as `input(inner, send_handle, addr, data)`; `start`
    /// panics if it was never set.
    input: Option<Arc<I>>,
    /// Optional error callback; `start` falls back to printing errors.
    error_input: Option<ErrorInput>,
    /// Callback run for `RecvType::REMOVE` messages.
    remove_event: Option<Box<dyn Fn(u32)>>,
    /// Sender into the dispatch channel. `start` writes it through `&self`;
    /// `get_msg_tx` reads it.
    msg_tx: UnsafeCell<Option<Sender<RecvType>>>,
    phantom:PhantomData<R>
}
// SAFETY(review): these impls assert thread-safety for every `I`, `R` and `S`
// with no bounds. Two fields are not thread-safe on their own:
// - `msg_tx` is an `UnsafeCell` that `start` writes and `get_msg_tx` reads,
//   both through `&self`;
// - `remove_event` is a `Box<dyn Fn(u32)>` with no `Sync` bound, called from
//   `start` through `&self`.
// Soundness therefore relies on callers never overlapping `start` with another
// `start`/`get_msg_tx` call from a different thread. TODO: consider a `Mutex`
// for `msg_tx` and `Send + Sync` bounds on `remove_event`.
unsafe impl<I, R, S> Send for UdpServer<I, R, S> {}
unsafe impl<I, R, S> Sync for UdpServer<I, R, S> {}
/// Optional token slot; `Peer` keeps one behind an async mutex.
#[derive(Debug)]
pub struct TokenStore<T>(pub Option<T>);

// No `T: Send` bound: none of these accessors need it, so non-`Send` payloads
// can use them too. Existing callers are unaffected.
impl<T> TokenStore<T> {
    /// Returns `true` if a token is currently stored.
    pub fn have(&self) -> bool {
        self.0.is_some()
    }

    /// Returns a mutable reference to the stored token, or `None` if empty.
    pub fn get(&mut self) -> Option<&mut T> {
        self.0.as_mut()
    }

    /// Replaces the stored token (pass `None` to clear it); the previous value
    /// is dropped.
    pub fn set(&mut self, v: Option<T>) {
        self.0 = v;
    }
}
/// A remote endpoint together with the handle used to send datagrams to it.
#[derive(Debug)]
pub struct Peer<T> {
    /// NOTE(review): presumably the `UdpContext::id` of the socket this peer was
    /// seen on; it is not assigned in this file — confirm with callers.
    pub socket_id: usize,
    /// Destination address used by [`Peer::send`].
    pub addr: SocketAddr,
    /// Token slot shared behind a `tokio::sync::Mutex`.
    pub token: Arc<Mutex<TokenStore<T>>>,
    /// Outgoing sender; each message is a `(payload, address)` pair.
    pub udp_sock: SendUDP,
}

impl<T: Send> Peer<T> {
    /// Queues `data` for delivery to `self.addr` through `udp_sock`.
    ///
    /// The body contains no `.await`; the message is handed over synchronously.
    ///
    /// # Errors
    /// Returns the boxed error from `udp_sock.send` if the message is rejected.
    pub async fn send(&self, data: Vec<u8>) -> Result<(), Box<dyn Error>> {
        let outgoing = (data, self.addr);
        match self.udp_sock.send(outgoing) {
            Ok(_) => Ok(()),
            Err(send_err) => Err(send_err.into()),
        }
    }
}
impl<I, R> UdpServer<I, R, ()>
where
    I: Fn(Arc<()>, SendUDP, SocketAddr, Vec<u8>) -> R + Send + Sync + 'static,
    R: Future<Output = Result<(), Box<dyn Error>>> + Send,
{
    /// Creates a server bound to `addr` that carries no shared state (`S = ()`).
    ///
    /// Socket setup and error conditions are those of `new_inner`.
    pub async fn new<A: ToSocketAddrs>(addr: A) -> Result<Self, Box<dyn Error>> {
        let no_state = Arc::new(());
        Self::new_inner(addr, no_state).await
    }
}
impl<I, R, S> UdpServer<I, R, S>
where
I: Fn(Arc<S>, SendUDP, SocketAddr, Vec<u8>) -> R + Send + Sync + 'static,
R: Future<Output = Result<(), Box<dyn Error>>> + Send,
S: Sync + Send + 'static,
{
#[cfg(not(target_os = "windows"))]
fn make_udp_client<A: ToSocketAddrs>(addr: &A) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = UdpBuilder::new_v4()?
.reuse_address(true)?
.reuse_port(true)?
.bind(addr)?;
Ok(res)
}
#[cfg(target_os = "windows")]
fn make_udp_client<A: ToSocketAddrs>(addr: &A) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = UdpBuilder::new_v4()?.reuse_address(true)?.bind(addr)?;
Ok(res)
}
fn create_udp_socket<A: ToSocketAddrs>(
addr: &A,
) -> Result<std::net::UdpSocket, Box<dyn Error>> {
let res = Self::make_udp_client(addr)?;
res.set_send_buffer_size(1784 * 10000)?;
res.set_recv_buffer_size(1784 * 10000)?;
Ok(res)
}
fn create_async_udp_socket<A: ToSocketAddrs>(addr: &A) -> Result<UdpSocket, Box<dyn Error>> {
let std_sock = Self::create_udp_socket(&addr)?;
let sock = UdpSocket::try_from(std_sock)?;
Ok(sock)
}
fn create_udp_socket_list<A: ToSocketAddrs>(
addr: &A,
listen_count: usize,
) -> Result<Vec<UdpSocket>, Box<dyn Error>> {
println!("cpus:{}", listen_count);
let mut listens = vec![];
for _ in 0..listen_count {
let sock = Self::create_async_udp_socket(addr)?;
listens.push(sock);
}
Ok(listens)
}
#[cfg(not(target_os = "windows"))]
fn get_cpu_count() -> usize {
num_cpus::get()
}
#[cfg(target_os = "windows")]
fn get_cpu_count() -> usize {
1
}
pub async fn new_inner<A: ToSocketAddrs>(
addr: A,
inner: Arc<S>,
) -> Result<Self, Box<dyn Error>> {
let udp_list = Self::create_udp_socket_list(&addr, Self::get_cpu_count())?;
let mut udp_map = vec![];
let mut id = 1;
for udp in udp_list {
let udp_socket_ptr = Arc::new(udp);
udp_map.push(UdpContext {
id,
recv: RefCell::new(Some(udp_socket_ptr.clone())),
send: SendPool::new(udp_socket_ptr),
});
id += 1;
}
Ok(UdpServer {
inner,
udp_contexts: udp_map,
input: None,
error_input: None,
msg_tx: UnsafeCell::new(None),
remove_event: None,
phantom:PhantomData::default()
})
}
pub fn get_msg_tx(&self) -> Option<Sender<RecvType>> {
unsafe {
if let Some(ref tx) = *self.msg_tx.get() {
return Some(tx.clone());
}
None
}
}
pub fn set_input(&mut self, input: I) {
self.input = Some(Arc::new(input));
}
pub fn set_err_input<P: Fn(Option<SocketAddr>, Box<dyn Error>) -> bool + Send + 'static>(
&mut self,
err_input: P,
) {
self.error_input = Some(Arc::new(Mutex::new(err_input)));
}
pub fn set_remove_input<F: Fn(u32) + Send + 'static>(&mut self, remove_input: F) {
self.remove_event = Some(Box::new(remove_input));
}
pub async fn start(&self) -> Result<(), Box<dyn Error>> {
if let Some(ref input) = self.input {
let err_input = {
if let Some(err) = &self.error_input {
let x = err;
x.clone()
} else {
Arc::new(Mutex::new(
|addr: Option<SocketAddr>, err: Box<dyn Error>| {
match addr {
Some(addr) => {
println!("{}-{}", addr, err);
}
None => {
println!("{}", err);
}
}
true
},
))
}
};
let (tx, mut rx) = channel(self.udp_contexts.len() * 2048);
for udp_sock in self.udp_contexts.iter() {
let recv_sock = udp_sock.recv.borrow_mut().take();
let send_sock = udp_sock.send.get_tx();
if let Some(recv_sock) = recv_sock {
let move_data_tx = tx.clone();
let error_input = err_input.clone();
tokio::spawn(async move {
let mut buff = [0; BUFF_MAX_SIZE];
loop {
match recv_sock.recv_from(&mut buff).await{
Ok((size, addr)) => {
if let Err(er) = move_data_tx
.send(RecvType::INPUT(
send_sock.clone(),
addr,
buff[..size].to_vec(),
))
.await
{
let error = error_input.lock().await;
let _ = error(Some(addr), Box::new(er));
break;
}
},
Err(er) => {
let error = error_input.lock().await;
let stop = error(None, error::Error::IOError(er).into());
if stop {
return;
}
}
}
}
});
}
}
unsafe {
(*self.msg_tx.get()).replace(tx);
}
while let Some(recv_type) = rx.recv().await {
match recv_type {
RecvType::INPUT(send_sock, addr, data) => {
let err = {
let res = input(self.inner.clone(), send_sock, addr, data).await;
match res {
Err(er) => Some(format!("{}->{:?}", er, er)),
Ok(_) => None,
}
};
if let Some(er_msg) = err {
let error = err_input.lock().await;
let stop = error(Some(addr), er_msg.into());
if stop {
break;
}
}
}
RecvType::REMOVE(ids) => {
if let Some(ref remove_input) = self.remove_event {
remove_input(ids);
}
}
}
}
Ok(())
} else {
panic!("not found input")
}
}
}