async_tftp/server/
server.rs1use async_executor::Executor;
2use async_io::Async;
3use async_lock::Mutex;
4use log::trace;
5use std::collections::HashSet;
6use std::future::Future;
7use std::net::{IpAddr, SocketAddr, UdpSocket};
8use std::sync::Arc;
9use std::time::Duration;
10
11use super::read_req::*;
12use super::write_req::*;
13use super::Handler;
14use crate::error::*;
15use crate::packet::{Packet, RwReq};
16
17pub struct TftpServer<H>
19where
20 H: Handler,
21{
22 pub(crate) socket: Async<UdpSocket>,
23 pub(crate) handler: Arc<Mutex<H>>,
24 pub(crate) reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
25 pub(crate) ex: Executor<'static>,
26 pub(crate) config: ServerConfig,
27 pub(crate) local_ip: IpAddr,
28}
29
/// Server-wide settings, cloned into every read/write request the server
/// spawns.
///
/// Derives `Debug` so the active configuration can be logged or inspected.
#[derive(Clone, Debug)]
pub(crate) struct ServerConfig {
    /// Timeout used by transfers.
    /// NOTE(review): exact use (e.g. per-packet ack wait) lives in the
    /// read/write request modules — confirm there.
    pub(crate) timeout: Duration,
    /// Optional cap on the block size; `None` presumably means no server-side
    /// cap — confirm against the option negotiation code.
    pub(crate) block_size_limit: Option<u16>,
    /// Maximum number of send retries.
    /// NOTE(review): confirm retry semantics in the transfer code.
    pub(crate) max_send_retries: u32,
    /// When set, the timeout requested by the client is presumably ignored.
    pub(crate) ignore_client_timeout: bool,
    /// When set, the block size requested by the client is presumably ignored.
    pub(crate) ignore_client_block_size: bool,
}

/// Default TFTP data block size (512, the block size defined by RFC 1350).
/// NOTE(review): its use lives in the read/write request code.
pub(crate) const DEFAULT_BLOCK_SIZE: usize = 512;
40
41impl<H: 'static> TftpServer<H>
42where
43 H: Handler,
44{
45 pub fn listen_addr(&self) -> Result<SocketAddr> {
47 Ok(self.socket.get_ref().local_addr()?)
48 }
49
50 pub async fn serve(self) -> Result<()> {
52 self.ex
53 .run(async {
54 let mut buf = [0u8; 4096];
55
56 loop {
57 let (len, peer) = self.socket.recv_from(&mut buf).await?;
58 self.handle_req_packet(peer, &buf[..len]).await;
59 }
60 })
61 .await
62 }
63
64 async fn handle_req_packet(&self, peer: SocketAddr, data: &[u8]) {
65 let packet = match Packet::decode(data) {
66 Ok(p @ Packet::Rrq(_)) => p,
67 Ok(p @ Packet::Wrq(_)) => p,
68 Ok(_) => return,
70 Err(_) => return,
72 };
73
74 if !self.reqs_in_progress.lock().await.insert(peer) {
75 return;
77 }
78
79 match packet {
80 Packet::Rrq(req) => self.handle_rrq(peer, req),
81 Packet::Wrq(req) => self.handle_wrq(peer, req),
82 _ => unreachable!(),
83 }
84 }
85
86 fn handle_rrq(&self, peer: SocketAddr, req: RwReq) {
87 trace!("RRQ recieved (peer: {}, req: {:?})", &peer, &req);
88
89 let handler = Arc::clone(&self.handler);
90 let config = self.config.clone();
91 let local_ip = self.local_ip;
92
93 let req_fut = async move {
95 let (mut reader, size) = handler
96 .lock()
97 .await
98 .read_req_open(&peer, req.filename.as_ref())
99 .await
100 .map_err(Error::Packet)?;
101
102 let mut read_req = ReadRequest::init(
103 &mut reader,
104 size,
105 peer,
106 &req,
107 config,
108 local_ip,
109 )
110 .await?;
111
112 read_req.handle().await;
113
114 Ok(())
115 };
116
117 let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
118
119 self.ex
121 .spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
122 .detach();
123 }
124
125 fn handle_wrq(&self, peer: SocketAddr, req: RwReq) {
126 trace!("WRQ recieved (peer: {}, req: {:?})", &peer, &req);
127
128 let handler = Arc::clone(&self.handler);
129 let config = self.config.clone();
130 let local_ip = self.local_ip;
131
132 let req_fut = async move {
134 let mut writer = handler
135 .lock()
136 .await
137 .write_req_open(
138 &peer,
139 req.filename.as_ref(),
140 req.opts.transfer_size,
141 )
142 .await
143 .map_err(Error::Packet)?;
144
145 let mut write_req =
146 WriteRequest::init(&mut writer, peer, &req, config, local_ip)
147 .await?;
148
149 write_req.handle().await;
150
151 Ok(())
152 };
153
154 let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
155
156 self.ex
158 .spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
159 .detach();
160 }
161}
162
163async fn send_error(
164 error: Error,
165 peer: SocketAddr,
166 local_ip: IpAddr,
167) -> Result<()> {
168 let addr: SocketAddr = SocketAddr::new(local_ip, 0);
169 let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
170
171 let data = Packet::Error(error.into()).to_bytes();
172 socket.send_to(&data[..], peer).await?;
173
174 Ok(())
175}
176
177async fn run_req(
178 req_fut: impl Future<Output = Result<()>>,
179 peer: SocketAddr,
180 reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
181 local_ip: IpAddr,
182) {
183 if let Err(e) = req_fut.await {
184 trace!("Request failed (peer: {}, error: {}", &peer, &e);
185
186 if let Err(e) = send_error(e, peer, local_ip).await {
187 trace!("Failed to send error to peer {}: {}", &peer, &e);
188 }
189 }
190
191 reqs_in_progress.lock().await.remove(&peer);
192}