async_tftp/server/
read_req.rs1use async_io::Async;
2use bytes::{BufMut, Bytes, BytesMut};
3use futures_lite::{AsyncRead, AsyncReadExt};
4use log::trace;
5use std::cmp;
6use std::io;
7use std::net::{IpAddr, SocketAddr, UdpSocket};
8use std::slice;
9use std::time::Duration;
10
11use crate::error::{Error, Result};
12use crate::packet::{Opts, Packet, RwReq, PACKET_DATA_HEADER_LEN};
13use crate::server::{ServerConfig, DEFAULT_BLOCK_SIZE};
14use crate::utils::io_timeout;
15
16pub(crate) struct ReadRequest<'r, R>
17where
18 R: AsyncRead + Send,
19{
20 peer: SocketAddr,
21 socket: Async<UdpSocket>,
22 reader: &'r mut R,
23 buffer: BytesMut,
24 block_size: usize,
25 timeout: Duration,
26 max_send_retries: u32,
27 oack_opts: Option<Opts>,
28}
29
30impl<'r, R> ReadRequest<'r, R>
31where
32 R: AsyncRead + Send + Unpin,
33{
34 pub(crate) async fn init(
35 reader: &'r mut R,
36 file_size: Option<u64>,
37 peer: SocketAddr,
38 req: &RwReq,
39 config: ServerConfig,
40 local_ip: IpAddr,
41 ) -> Result<ReadRequest<'r, R>> {
42 let oack_opts = build_oack_opts(&config, req, file_size);
43
44 let block_size = oack_opts
45 .as_ref()
46 .and_then(|o| o.block_size)
47 .map(usize::from)
48 .unwrap_or(DEFAULT_BLOCK_SIZE);
49
50 let timeout = oack_opts
51 .as_ref()
52 .and_then(|o| o.timeout)
53 .map(|t| Duration::from_secs(u64::from(t)))
54 .unwrap_or(config.timeout);
55
56 let addr = SocketAddr::new(local_ip, 0);
57 let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
58
59 Ok(ReadRequest {
60 peer,
61 socket,
62 reader,
63 buffer: BytesMut::with_capacity(
64 PACKET_DATA_HEADER_LEN + block_size,
65 ),
66 block_size,
67 timeout,
68 max_send_retries: config.max_send_retries,
69 oack_opts,
70 })
71 }
72
73 pub(crate) async fn handle(&mut self) {
74 if let Err(e) = self.try_handle().await {
75 trace!("RRQ request failed (peer: {}, error: {})", &self.peer, &e);
76
77 Packet::Error(e.into()).encode(&mut self.buffer);
78 let buf = self.buffer.split().freeze();
79 let _ = self.socket.send_to(&buf[..], self.peer).await;
82 }
83 }
84
85 async fn try_handle(&mut self) -> Result<()> {
86 let mut block_id: u16 = 0;
87
88 loop {
90 let is_last_block;
91
92 self.buffer.reserve(PACKET_DATA_HEADER_LEN + self.block_size);
94
95 block_id = block_id.wrapping_add(1);
97 Packet::encode_data_head(block_id, &mut self.buffer);
98
99 let buf = unsafe {
101 let uninit_buf = self.buffer.chunk_mut();
102
103 let data_buf = slice::from_raw_parts_mut(
104 uninit_buf.as_mut_ptr(),
105 uninit_buf.len(),
106 );
107
108 let len = self.read_block(data_buf).await?;
109 is_last_block = len < self.block_size;
110
111 self.buffer.advance_mut(len);
112 self.buffer.split().freeze()
113 };
114
115 if let Some(opts) = self.oack_opts.take() {
120 trace!("RRQ OACK (peer: {}, opts: {:?}", &self.peer, &opts);
121
122 let mut buf = BytesMut::new();
123 Packet::OAck(opts.to_owned()).encode(&mut buf);
124
125 self.send(buf.split().freeze(), 0).await?;
126 }
127
128 self.send(buf, block_id).await?;
130
131 if is_last_block {
132 break;
133 }
134 }
135
136 trace!("RRQ request served (peer: {})", &self.peer);
137 Ok(())
138 }
139
140 async fn send(&mut self, packet: Bytes, block_id: u16) -> Result<()> {
141 for _ in 0..=self.max_send_retries {
143 self.socket.send_to(&packet[..], self.peer).await?;
144
145 match self.recv_ack(block_id).await {
146 Ok(_) => {
147 trace!(
148 "RRQ (peer: {}, block_id: {}) - Received ACK",
149 &self.peer,
150 block_id
151 );
152 return Ok(());
153 }
154 Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
155 trace!(
156 "RRQ (peer: {}, block_id: {}) - Timeout",
157 &self.peer,
158 block_id
159 );
160 continue;
161 }
162 Err(e) => return Err(e.into()),
163 }
164 }
165
166 Err(Error::MaxSendRetriesReached(self.peer, block_id))
167 }
168
169 async fn recv_ack(&mut self, block_id: u16) -> io::Result<()> {
170 let socket = &mut self.socket;
173 let peer = self.peer;
174
175 io_timeout(self.timeout, async {
176 let mut buf = [0u8; 1024];
177
178 loop {
179 let (len, recved_peer) = socket.recv_from(&mut buf[..]).await?;
180
181 if recved_peer != peer {
183 continue;
184 }
185
186 if let Ok(Packet::Ack(recved_block_id)) =
188 Packet::decode(&buf[..len])
189 {
190 if recved_block_id == block_id {
191 return Ok(());
192 }
193 }
194 }
195 })
196 .await?;
197
198 Ok(())
199 }
200
201 async fn read_block(&mut self, buf: &mut [u8]) -> Result<usize> {
202 let mut len = 0;
203
204 while len < buf.len() {
205 match self.reader.read(&mut buf[len..]).await? {
206 0 => break,
207 x => len += x,
208 }
209 }
210
211 Ok(len)
212 }
213}
214
215fn build_oack_opts(
216 config: &ServerConfig,
217 req: &RwReq,
218 file_size: Option<u64>,
219) -> Option<Opts> {
220 let mut opts = Opts::default();
221
222 if !config.ignore_client_block_size {
223 opts.block_size = match (req.opts.block_size, config.block_size_limit) {
224 (Some(bsize), Some(limit)) => Some(cmp::min(bsize, limit)),
225 (Some(bsize), None) => Some(bsize),
226 _ => None,
227 };
228 }
229
230 if !config.ignore_client_timeout {
231 opts.timeout = req.opts.timeout;
232 }
233
234 if let (Some(0), Some(file_size)) = (req.opts.transfer_size, file_size) {
235 opts.transfer_size = Some(file_size);
236 }
237
238 if opts == Opts::default() {
239 None
240 } else {
241 Some(opts)
242 }
243}