use async_executor::Executor;
use async_io::Async;
use async_lock::Mutex;
use log::trace;
use std::collections::HashSet;
use std::future::Future;
use std::net::{IpAddr, SocketAddr, UdpSocket};
use std::sync::Arc;
use std::time::Duration;
use super::read_req::*;
use super::write_req::*;
use super::Handler;
use crate::error::*;
use crate::packet::{Packet, RwReq};
/// A TFTP server that accepts RRQ/WRQ packets on one UDP socket and serves
/// each accepted request in its own task.
///
/// `H` is the user-supplied [`Handler`] that opens the reader/writer backing
/// each transfer.
pub struct TftpServer<H: Handler> {
    /// Listening socket on which incoming request packets are received.
    pub(crate) socket: Async<UdpSocket>,
    /// User handler, shared (behind an async mutex) with every request task.
    pub(crate) handler: Arc<Mutex<H>>,
    /// Peer addresses that currently have a transfer running; a new request
    /// from an address in this set is ignored.
    pub(crate) reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
    /// Executor driving the receive loop and the spawned per-request tasks.
    pub(crate) ex: Executor<'static>,
    /// Transfer settings, cloned into each request task.
    pub(crate) config: ServerConfig,
    /// Local IP handed to each request and used to bind the socket that
    /// error replies are sent from.
    pub(crate) local_ip: IpAddr,
}
/// Per-server transfer settings; a clone is handed to every request.
///
/// Derives `Debug` so the effective configuration can be traced/inspected.
#[derive(Clone, Debug)]
pub(crate) struct ServerConfig {
    /// Transfer timeout (presumably the per-packet retransmission timeout —
    /// TODO confirm in read_req/write_req).
    pub(crate) timeout: Duration,
    /// Optional server-side cap on the block size (presumably applied to a
    /// client-requested block size; `None` = no cap — TODO confirm).
    pub(crate) block_size_limit: Option<u16>,
    /// Maximum number of send retries (presumably per packet — TODO confirm).
    pub(crate) max_send_retries: u32,
    /// When set, presumably the client's timeout option is not honoured.
    pub(crate) ignore_client_timeout: bool,
    /// When set, presumably the client's block-size option is not honoured.
    pub(crate) ignore_client_block_size: bool,
}

/// Default data block size in bytes: 512, the fixed block size of RFC 1350
/// TFTP (presumably used when no other block size is negotiated).
pub(crate) const DEFAULT_BLOCK_SIZE: usize = 512;
impl<H: 'static> TftpServer<H>
where
H: Handler,
{
pub fn listen_addr(&self) -> Result<SocketAddr> {
Ok(self.socket.get_ref().local_addr()?)
}
pub async fn serve(self) -> Result<()> {
self.ex
.run(async {
let mut buf = [0u8; 4096];
loop {
let (len, peer) = self.socket.recv_from(&mut buf).await?;
self.handle_req_packet(peer, &buf[..len]).await;
}
})
.await
}
async fn handle_req_packet(&self, peer: SocketAddr, data: &[u8]) {
let packet = match Packet::decode(data) {
Ok(p @ Packet::Rrq(_)) => p,
Ok(p @ Packet::Wrq(_)) => p,
Ok(_) => return,
Err(_) => return,
};
if !self.reqs_in_progress.lock().await.insert(peer) {
return;
}
match packet {
Packet::Rrq(req) => self.handle_rrq(peer, req),
Packet::Wrq(req) => self.handle_wrq(peer, req),
_ => unreachable!(),
}
}
fn handle_rrq(&self, peer: SocketAddr, req: RwReq) {
trace!("RRQ recieved (peer: {}, req: {:?})", &peer, &req);
let handler = Arc::clone(&self.handler);
let config = self.config.clone();
let local_ip = self.local_ip;
let req_fut = async move {
let (mut reader, size) = handler
.lock()
.await
.read_req_open(&peer, req.filename.as_ref())
.await
.map_err(Error::Packet)?;
let mut read_req = ReadRequest::init(
&mut reader,
size,
peer,
&req,
config,
local_ip,
)
.await?;
read_req.handle().await;
Ok(())
};
let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
self.ex
.spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
.detach();
}
fn handle_wrq(&self, peer: SocketAddr, req: RwReq) {
trace!("WRQ recieved (peer: {}, req: {:?})", &peer, &req);
let handler = Arc::clone(&self.handler);
let config = self.config.clone();
let local_ip = self.local_ip;
let req_fut = async move {
let mut writer = handler
.lock()
.await
.write_req_open(
&peer,
req.filename.as_ref(),
req.opts.transfer_size,
)
.await
.map_err(Error::Packet)?;
let mut write_req =
WriteRequest::init(&mut writer, peer, &req, config, local_ip)
.await?;
write_req.handle().await;
Ok(())
};
let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
self.ex
.spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
.detach();
}
}
async fn send_error(
error: Error,
peer: SocketAddr,
local_ip: IpAddr,
) -> Result<()> {
let addr: SocketAddr = SocketAddr::new(local_ip, 0);
let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
let data = Packet::Error(error.into()).to_bytes();
socket.send_to(&data[..], peer).await?;
Ok(())
}
async fn run_req(
req_fut: impl Future<Output = Result<()>>,
peer: SocketAddr,
reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
local_ip: IpAddr,
) {
if let Err(e) = req_fut.await {
trace!("Request failed (peer: {}, error: {}", &peer, &e);
if let Err(e) = send_error(e, peer, local_ip).await {
trace!("Failed to send error to peer {}: {}", &peer, &e);
}
}
reqs_in_progress.lock().await.remove(&peer);
}