async_tftp/server/
server.rs

1use async_executor::Executor;
2use async_io::Async;
3use async_lock::Mutex;
4use log::trace;
5use std::collections::HashSet;
6use std::future::Future;
7use std::net::{IpAddr, SocketAddr, UdpSocket};
8use std::sync::Arc;
9use std::time::Duration;
10
11use super::read_req::*;
12use super::write_req::*;
13use super::Handler;
14use crate::error::*;
15use crate::packet::{Packet, RwReq};
16
17/// TFTP server.
18pub struct TftpServer<H>
19where
20    H: Handler,
21{
22    pub(crate) socket: Async<UdpSocket>,
23    pub(crate) handler: Arc<Mutex<H>>,
24    pub(crate) reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
25    pub(crate) ex: Executor<'static>,
26    pub(crate) config: ServerConfig,
27    pub(crate) local_ip: IpAddr,
28}
29
/// Server settings, cloned into every request task.
///
/// `Debug` is derived so the effective configuration can be logged.
#[derive(Clone, Debug)]
pub(crate) struct ServerConfig {
    /// Timeout used by request tasks (exact use lives in the read/write
    /// request modules — confirm there).
    pub(crate) timeout: Duration,
    /// Optional server-side limit on the block size; `None` presumably means
    /// no limit — TODO confirm in the request modules.
    pub(crate) block_size_limit: Option<u16>,
    /// Optional server-side limit on the window size; `None` presumably means
    /// no limit — TODO confirm in the request modules.
    pub(crate) window_size_limit: Option<u16>,
    /// Maximum number of times a packet is re-sent (NOTE(review): inferred
    /// from the name — confirm in the request modules).
    pub(crate) max_send_retries: u32,
    /// Whether a client-requested timeout option is ignored.
    pub(crate) ignore_client_timeout: bool,
    /// Whether a client-requested block size option is ignored.
    pub(crate) ignore_client_block_size: bool,
    /// Whether a client-requested window size option is ignored.
    pub(crate) ignore_client_window_size: bool,
}
40
/// Default TFTP block size in bytes: 512, the fixed block size of the
/// original TFTP protocol (RFC 1350).
pub(crate) const DEFAULT_BLOCK_SIZE: usize = 512;
42
43impl<H: 'static> TftpServer<H>
44where
45    H: Handler,
46{
47    /// Returns the listenning socket address.
48    pub fn listen_addr(&self) -> Result<SocketAddr> {
49        Ok(self.socket.get_ref().local_addr()?)
50    }
51
52    /// Consume and start the server.
53    pub async fn serve(self) -> Result<()> {
54        self.ex
55            .run(async {
56                let mut buf = [0u8; 4096];
57
58                loop {
59                    let (len, peer) = self.socket.recv_from(&mut buf).await?;
60                    self.handle_req_packet(peer, &buf[..len]).await;
61                }
62            })
63            .await
64    }
65
66    async fn handle_req_packet(&self, peer: SocketAddr, data: &[u8]) {
67        let packet = match Packet::decode(data) {
68            Ok(p @ Packet::Rrq(_)) => p,
69            Ok(p @ Packet::Wrq(_)) => p,
70            // Ignore packets that are not requests
71            Ok(_) => return,
72            // Ignore invalid packets
73            Err(_) => return,
74        };
75
76        if !self.reqs_in_progress.lock().await.insert(peer) {
77            // Ignore pending requests
78            return;
79        }
80
81        match packet {
82            Packet::Rrq(req) => self.handle_rrq(peer, req),
83            Packet::Wrq(req) => self.handle_wrq(peer, req),
84            _ => unreachable!(),
85        }
86    }
87
88    fn handle_rrq(&self, peer: SocketAddr, req: RwReq) {
89        trace!("RRQ recieved (peer: {}, req: {:?})", &peer, &req);
90
91        let handler = Arc::clone(&self.handler);
92        let config = self.config.clone();
93        let local_ip = self.local_ip;
94
95        // Prepare request future
96        let req_fut = async move {
97            let (mut reader, size) = handler
98                .lock()
99                .await
100                .read_req_open(&peer, req.filename.as_ref())
101                .await
102                .map_err(Error::Packet)?;
103
104            let mut read_req = ReadRequest::init(
105                &mut reader,
106                size,
107                peer,
108                &req,
109                config,
110                local_ip,
111            )
112            .await?;
113
114            read_req.handle().await;
115
116            Ok(())
117        };
118
119        let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
120
121        // Run request future in a new task
122        self.ex
123            .spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
124            .detach();
125    }
126
127    fn handle_wrq(&self, peer: SocketAddr, req: RwReq) {
128        trace!("WRQ recieved (peer: {}, req: {:?})", &peer, &req);
129
130        let handler = Arc::clone(&self.handler);
131        let config = self.config.clone();
132        let local_ip = self.local_ip;
133
134        // Prepare request future
135        let req_fut = async move {
136            let mut writer = handler
137                .lock()
138                .await
139                .write_req_open(
140                    &peer,
141                    req.filename.as_ref(),
142                    req.opts.transfer_size,
143                )
144                .await
145                .map_err(Error::Packet)?;
146
147            let mut write_req =
148                WriteRequest::init(&mut writer, peer, &req, config, local_ip)
149                    .await?;
150
151            write_req.handle().await;
152
153            Ok(())
154        };
155
156        let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
157
158        // Run request future in a new task
159        self.ex
160            .spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
161            .detach();
162    }
163}
164
165async fn send_error(
166    error: Error,
167    peer: SocketAddr,
168    local_ip: IpAddr,
169) -> Result<()> {
170    let addr: SocketAddr = SocketAddr::new(local_ip, 0);
171    let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
172
173    let data = Packet::Error(error.into()).to_bytes();
174    socket.send_to(&data[..], peer).await?;
175
176    Ok(())
177}
178
179async fn run_req(
180    req_fut: impl Future<Output = Result<()>>,
181    peer: SocketAddr,
182    reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
183    local_ip: IpAddr,
184) {
185    if let Err(e) = req_fut.await {
186        trace!("Request failed (peer: {}, error: {}", &peer, &e);
187
188        if let Err(e) = send_error(e, peer, local_ip).await {
189            trace!("Failed to send error to peer {}: {}", &peer, &e);
190        }
191    }
192
193    reqs_in_progress.lock().await.remove(&peer);
194}