async_tftp/server/
server.rs

1use async_executor::Executor;
2use async_io::Async;
3use async_lock::Mutex;
4use log::trace;
5use std::collections::HashSet;
6use std::future::Future;
7use std::net::{IpAddr, SocketAddr, UdpSocket};
8use std::sync::Arc;
9use std::time::Duration;
10
11use super::read_req::*;
12use super::write_req::*;
13use super::Handler;
14use crate::error::*;
15use crate::packet::{Packet, RwReq};
16
17/// TFTP server.
18pub struct TftpServer<H>
19where
20    H: Handler,
21{
22    pub(crate) socket: Async<UdpSocket>,
23    pub(crate) handler: Arc<Mutex<H>>,
24    pub(crate) reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
25    pub(crate) ex: Executor<'static>,
26    pub(crate) config: ServerConfig,
27    pub(crate) local_ip: IpAddr,
28}
29
/// Per-server transfer settings; a clone is handed to every request task.
#[derive(Clone, Debug)]
pub(crate) struct ServerConfig {
    /// Transfer timeout. NOTE(review): presumably the per-packet
    /// retransmission timeout — confirm in the read/write request code.
    pub(crate) timeout: Duration,
    /// Optional upper bound on the block size; presumably `None` means no
    /// server-side limit — confirm.
    pub(crate) block_size_limit: Option<u16>,
    /// Maximum number of re-sends of a packet (presumably before the
    /// transfer is abandoned — confirm).
    pub(crate) max_send_retries: u32,
    /// When set, a timeout requested by the client is presumably ignored in
    /// favour of `timeout` — confirm.
    pub(crate) ignore_client_timeout: bool,
    /// When set, a block size requested by the client is presumably
    /// ignored — confirm.
    pub(crate) ignore_client_block_size: bool,
}
38
/// Default block size. 512 is the fixed data-block size defined by RFC 1350;
/// presumably used whenever no other block size is negotiated — confirm.
pub(crate) const DEFAULT_BLOCK_SIZE: usize = 512;
40
41impl<H: 'static> TftpServer<H>
42where
43    H: Handler,
44{
45    /// Returns the listenning socket address.
46    pub fn listen_addr(&self) -> Result<SocketAddr> {
47        Ok(self.socket.get_ref().local_addr()?)
48    }
49
50    /// Consume and start the server.
51    pub async fn serve(self) -> Result<()> {
52        self.ex
53            .run(async {
54                let mut buf = [0u8; 4096];
55
56                loop {
57                    let (len, peer) = self.socket.recv_from(&mut buf).await?;
58                    self.handle_req_packet(peer, &buf[..len]).await;
59                }
60            })
61            .await
62    }
63
64    async fn handle_req_packet(&self, peer: SocketAddr, data: &[u8]) {
65        let packet = match Packet::decode(data) {
66            Ok(p @ Packet::Rrq(_)) => p,
67            Ok(p @ Packet::Wrq(_)) => p,
68            // Ignore packets that are not requests
69            Ok(_) => return,
70            // Ignore invalid packets
71            Err(_) => return,
72        };
73
74        if !self.reqs_in_progress.lock().await.insert(peer) {
75            // Ignore pending requests
76            return;
77        }
78
79        match packet {
80            Packet::Rrq(req) => self.handle_rrq(peer, req),
81            Packet::Wrq(req) => self.handle_wrq(peer, req),
82            _ => unreachable!(),
83        }
84    }
85
86    fn handle_rrq(&self, peer: SocketAddr, req: RwReq) {
87        trace!("RRQ recieved (peer: {}, req: {:?})", &peer, &req);
88
89        let handler = Arc::clone(&self.handler);
90        let config = self.config.clone();
91        let local_ip = self.local_ip;
92
93        // Prepare request future
94        let req_fut = async move {
95            let (mut reader, size) = handler
96                .lock()
97                .await
98                .read_req_open(&peer, req.filename.as_ref())
99                .await
100                .map_err(Error::Packet)?;
101
102            let mut read_req = ReadRequest::init(
103                &mut reader,
104                size,
105                peer,
106                &req,
107                config,
108                local_ip,
109            )
110            .await?;
111
112            read_req.handle().await;
113
114            Ok(())
115        };
116
117        let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
118
119        // Run request future in a new task
120        self.ex
121            .spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
122            .detach();
123    }
124
125    fn handle_wrq(&self, peer: SocketAddr, req: RwReq) {
126        trace!("WRQ recieved (peer: {}, req: {:?})", &peer, &req);
127
128        let handler = Arc::clone(&self.handler);
129        let config = self.config.clone();
130        let local_ip = self.local_ip;
131
132        // Prepare request future
133        let req_fut = async move {
134            let mut writer = handler
135                .lock()
136                .await
137                .write_req_open(
138                    &peer,
139                    req.filename.as_ref(),
140                    req.opts.transfer_size,
141                )
142                .await
143                .map_err(Error::Packet)?;
144
145            let mut write_req =
146                WriteRequest::init(&mut writer, peer, &req, config, local_ip)
147                    .await?;
148
149            write_req.handle().await;
150
151            Ok(())
152        };
153
154        let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
155
156        // Run request future in a new task
157        self.ex
158            .spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
159            .detach();
160    }
161}
162
163async fn send_error(
164    error: Error,
165    peer: SocketAddr,
166    local_ip: IpAddr,
167) -> Result<()> {
168    let addr: SocketAddr = SocketAddr::new(local_ip, 0);
169    let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
170
171    let data = Packet::Error(error.into()).to_bytes();
172    socket.send_to(&data[..], peer).await?;
173
174    Ok(())
175}
176
177async fn run_req(
178    req_fut: impl Future<Output = Result<()>>,
179    peer: SocketAddr,
180    reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
181    local_ip: IpAddr,
182) {
183    if let Err(e) = req_fut.await {
184        trace!("Request failed (peer: {}, error: {}", &peer, &e);
185
186        if let Err(e) = send_error(e, peer, local_ip).await {
187            trace!("Failed to send error to peer {}: {}", &peer, &e);
188        }
189    }
190
191    reqs_in_progress.lock().await.remove(&peer);
192}