async_tftp/server/
server.rs1use async_executor::Executor;
2use async_io::Async;
3use async_lock::Mutex;
4use log::trace;
5use std::collections::HashSet;
6use std::future::Future;
7use std::net::{IpAddr, SocketAddr, UdpSocket};
8use std::sync::Arc;
9use std::time::Duration;
10
11use super::read_req::*;
12use super::write_req::*;
13use super::Handler;
14use crate::error::*;
15use crate::packet::{Packet, RwReq};
16
17pub struct TftpServer<H>
19where
20 H: Handler,
21{
22 pub(crate) socket: Async<UdpSocket>,
23 pub(crate) handler: Arc<Mutex<H>>,
24 pub(crate) reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
25 pub(crate) ex: Executor<'static>,
26 pub(crate) config: ServerConfig,
27 pub(crate) local_ip: IpAddr,
28}
29
/// Per-server settings; a clone is handed to every read and write request.
///
/// NOTE(review): the fields are consumed by `ReadRequest` / `WriteRequest`, which are
/// defined outside this file. The per-field notes below are inferred from names and
/// types only and should be confirmed against those modules.
#[derive(Clone)]
pub(crate) struct ServerConfig {
    /// Presumably how long to wait for the peer before giving up or retrying — confirm.
    pub(crate) timeout: Duration,
    /// Presumably an upper bound on the negotiated block size (`None` = no limit) — confirm.
    pub(crate) block_size_limit: Option<u16>,
    /// Presumably an upper bound on the negotiated window size (`None` = no limit) — confirm.
    pub(crate) window_size_limit: Option<u16>,
    /// Presumably the maximum number of retransmissions per packet — confirm.
    pub(crate) max_send_retries: u32,
    /// Presumably: disregard the timeout option requested by the client — confirm.
    pub(crate) ignore_client_timeout: bool,
    /// Presumably: disregard the block-size option requested by the client — confirm.
    pub(crate) ignore_client_block_size: bool,
    /// Presumably: disregard the window-size option requested by the client — confirm.
    pub(crate) ignore_client_window_size: bool,
}
40
/// Default block size; 512 matches the fixed data-block length of RFC 1350.
/// NOTE(review): presumably used when no block-size option is negotiated — confirm at
/// the use sites, which are outside this file.
pub(crate) const DEFAULT_BLOCK_SIZE: usize = 512;
42
43impl<H: 'static> TftpServer<H>
44where
45 H: Handler,
46{
47 pub fn listen_addr(&self) -> Result<SocketAddr> {
49 Ok(self.socket.get_ref().local_addr()?)
50 }
51
52 pub async fn serve(self) -> Result<()> {
54 self.ex
55 .run(async {
56 let mut buf = [0u8; 4096];
57
58 loop {
59 let (len, peer) = self.socket.recv_from(&mut buf).await?;
60 self.handle_req_packet(peer, &buf[..len]).await;
61 }
62 })
63 .await
64 }
65
66 async fn handle_req_packet(&self, peer: SocketAddr, data: &[u8]) {
67 let packet = match Packet::decode(data) {
68 Ok(p @ Packet::Rrq(_)) => p,
69 Ok(p @ Packet::Wrq(_)) => p,
70 Ok(_) => return,
72 Err(_) => return,
74 };
75
76 if !self.reqs_in_progress.lock().await.insert(peer) {
77 return;
79 }
80
81 match packet {
82 Packet::Rrq(req) => self.handle_rrq(peer, req),
83 Packet::Wrq(req) => self.handle_wrq(peer, req),
84 _ => unreachable!(),
85 }
86 }
87
88 fn handle_rrq(&self, peer: SocketAddr, req: RwReq) {
89 trace!("RRQ recieved (peer: {}, req: {:?})", &peer, &req);
90
91 let handler = Arc::clone(&self.handler);
92 let config = self.config.clone();
93 let local_ip = self.local_ip;
94
95 let req_fut = async move {
97 let (mut reader, size) = handler
98 .lock()
99 .await
100 .read_req_open(&peer, req.filename.as_ref())
101 .await
102 .map_err(Error::Packet)?;
103
104 let mut read_req = ReadRequest::init(
105 &mut reader,
106 size,
107 peer,
108 &req,
109 config,
110 local_ip,
111 )
112 .await?;
113
114 read_req.handle().await;
115
116 Ok(())
117 };
118
119 let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
120
121 self.ex
123 .spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
124 .detach();
125 }
126
127 fn handle_wrq(&self, peer: SocketAddr, req: RwReq) {
128 trace!("WRQ recieved (peer: {}, req: {:?})", &peer, &req);
129
130 let handler = Arc::clone(&self.handler);
131 let config = self.config.clone();
132 let local_ip = self.local_ip;
133
134 let req_fut = async move {
136 let mut writer = handler
137 .lock()
138 .await
139 .write_req_open(
140 &peer,
141 req.filename.as_ref(),
142 req.opts.transfer_size,
143 )
144 .await
145 .map_err(Error::Packet)?;
146
147 let mut write_req =
148 WriteRequest::init(&mut writer, peer, &req, config, local_ip)
149 .await?;
150
151 write_req.handle().await;
152
153 Ok(())
154 };
155
156 let reqs_in_progress = Arc::clone(&self.reqs_in_progress);
157
158 self.ex
160 .spawn(run_req(req_fut, peer, reqs_in_progress, local_ip))
161 .detach();
162 }
163}
164
165async fn send_error(
166 error: Error,
167 peer: SocketAddr,
168 local_ip: IpAddr,
169) -> Result<()> {
170 let addr: SocketAddr = SocketAddr::new(local_ip, 0);
171 let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
172
173 let data = Packet::Error(error.into()).to_bytes();
174 socket.send_to(&data[..], peer).await?;
175
176 Ok(())
177}
178
179async fn run_req(
180 req_fut: impl Future<Output = Result<()>>,
181 peer: SocketAddr,
182 reqs_in_progress: Arc<Mutex<HashSet<SocketAddr>>>,
183 local_ip: IpAddr,
184) {
185 if let Err(e) = req_fut.await {
186 trace!("Request failed (peer: {}, error: {}", &peer, &e);
187
188 if let Err(e) = send_error(e, peer, local_ip).await {
189 trace!("Failed to send error to peer {}: {}", &peer, &e);
190 }
191 }
192
193 reqs_in_progress.lock().await.remove(&peer);
194}