use net_relay::builder::Parts;
use net_relay::{Builder, Error};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use udp_pool::bytes::BytesMut;
use udp_pool::net_pool::{Pool, debug, info, instrument_debug_span, tokio_spawn, warn2};
/// UDP relay server.
///
/// The relay binds a UDP socket to the configured addresses and receives datagrams on it. For
/// each datagram it obtains a `udp_pool::Sender` from the pool and hands it, together with the
/// datagram bytes, to the relay function `F`.
///
/// Type parameters:
/// - `F`: the relay function.
/// - `S`: the future returned by `F`.
/// - `P`: the socket pool.
pub struct Relay<F, S, P = udp_pool::Pool>
where
    F: Fn(udp_pool::Sender, BytesMut) -> S,
    S: Future<Output = ()>,
    P: Pool + udp_pool::UdpPool,
{
    parts: Parts<P, F>,
    pending: Option<Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>>,
    // `S` only appears in the bound on `F`. Bounds do not count as a use, so without this
    // marker the definition is rejected with E0392 ("parameter `S` is never used").
    // `fn() -> S` keeps `Relay` `Send`/`Sync`/`Unpin` regardless of what `S` is.
    _future: std::marker::PhantomData<fn() -> S>,
}

impl<F, S, P> Relay<F, S, P>
where
    F: Fn(udp_pool::Sender, BytesMut) -> S,
    S: Future<Output = ()>,
    P: Pool + udp_pool::UdpPool,
{
    /// Creates a relay by letting `b` configure a fresh [`Builder`] and then building it.
    ///
    /// # Errors
    /// Propagates any error returned by the builder's `build`.
    pub fn build<B: FnOnce(Builder<P, F>) -> Builder<P, F>>(b: B) -> Result<Self, Error> {
        let parts = b(Builder::new()).build()?;
        Ok(Relay {
            parts,
            pending: None,
            _future: std::marker::PhantomData,
        })
    }

    /// Addresses handed to `UdpSocket::bind` when the relay starts running.
    pub fn bind_addrs(&self) -> &Vec<SocketAddr> {
        &self.parts.bind_addrs
    }

    /// Returns a shared handle to the configured relay function.
    ///
    /// # Panics
    /// Panics if the built parts carry no relay function.
    pub fn relay_fn(&self) -> Arc<F> {
        self.parts
            .relay_fn
            .as_ref()
            .expect("[Udp Relay] relay function is not configured")
            .clone()
    }

    /// Returns a shared handle to the first configured pool.
    ///
    /// # Panics
    /// Panics if no pool was configured (the pool list is indexed at 0).
    pub fn pool(&self) -> Arc<P> {
        self.parts.pools[0].clone()
    }

    /// Forwards `max` to the pool's `set_max_conn`.
    pub fn set_max_conn(&self, max: Option<usize>) {
        self.pool().set_max_conn(max)
    }

    /// Forwards `duration` to the pool's `set_keepalive`.
    pub fn set_keepalive(&self, duration: Option<Duration>) {
        self.pool().set_keepalive(duration)
    }
}
impl<F, S, P> net_relay::Relay for Relay<F, S, P>
where
F: Fn(udp_pool::Sender, BytesMut) -> S + Send + Sync + 'static,
S: Future<Output = ()> + Send + 'static,
P: Pool + udp_pool::UdpPool + Send + 'static,
{
fn poll_run(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.pending.is_none() {
let addrs = self.bind_addrs().clone();
let pool = self.pool();
let relay_fn = self.relay_fn();
self.pending = Some(Box::pin(async move {
let udp = tokio::net::UdpSocket::bind(addrs.as_slice())
.await
.map(|u| Arc::new(u))?;
info!("[Udp Relay] listen on: {:?}", udp.local_addr().unwrap());
loop {
let mut buf = BytesMut::with_capacity(1500);
if let Ok((_, addr)) = udp.recv_buf_from(&mut buf).await {
let tuple = (pool.clone(), udp.clone(), relay_fn.clone());
tokio_spawn! {
instrument_debug_span! {
async move {
match tuple.0.get(addr, Some(tuple.1)).await {
Ok(s) => {
debug!("[Udp Relay] recv udp packet: {}", buf.len());
tuple.2(s, buf).await;
},
Err(_e) => {
warn2!("[Udp Relay] get udp socket from pool, error occurred: {:?}", _e);
}
}
},
"udp_socket",
address=addr.to_string()
}
};
}
}
}));
}
self.pending.as_mut().unwrap().as_mut().poll(cx)
}
}
/// Default relay function: forwards `data` unchanged through `sender`.
///
/// A failed send is logged at debug level and is not propagated to the caller.
pub async fn default_relay_fn(sender: udp_pool::Sender, data: BytesMut) {
    let outcome = sender.send(data).await;
    if let Err(_e) = outcome {
        debug!("[Udp Relay] transfer packet, error occurred: {:?}", _e);
    }
}