async_tftp/server/read_req.rs

1use async_io::Async;
2use bytes::{BufMut, Bytes, BytesMut};
3use futures_lite::{AsyncRead, AsyncReadExt};
4use log::trace;
5use std::cmp;
6use std::io;
7use std::net::{IpAddr, SocketAddr, UdpSocket};
8use std::slice;
9use std::time::Duration;
10
11use crate::error::{Error, Result};
12use crate::packet::{Opts, Packet, RwReq, PACKET_DATA_HEADER_LEN};
13use crate::server::{ServerConfig, DEFAULT_BLOCK_SIZE};
14use crate::utils::io_timeout;
15
16pub(crate) struct ReadRequest<'r, R>
17where
18    R: AsyncRead + Send,
19{
20    peer: SocketAddr,
21    socket: Async<UdpSocket>,
22    reader: &'r mut R,
23    buffer: BytesMut,
24    block_size: usize,
25    timeout: Duration,
26    max_send_retries: u32,
27    oack_opts: Option<Opts>,
28}
29
30impl<'r, R> ReadRequest<'r, R>
31where
32    R: AsyncRead + Send + Unpin,
33{
34    pub(crate) async fn init(
35        reader: &'r mut R,
36        file_size: Option<u64>,
37        peer: SocketAddr,
38        req: &RwReq,
39        config: ServerConfig,
40        local_ip: IpAddr,
41    ) -> Result<ReadRequest<'r, R>> {
42        let oack_opts = build_oack_opts(&config, req, file_size);
43
44        let block_size = oack_opts
45            .as_ref()
46            .and_then(|o| o.block_size)
47            .map(usize::from)
48            .unwrap_or(DEFAULT_BLOCK_SIZE);
49
50        let timeout = oack_opts
51            .as_ref()
52            .and_then(|o| o.timeout)
53            .map(|t| Duration::from_secs(u64::from(t)))
54            .unwrap_or(config.timeout);
55
56        let addr = SocketAddr::new(local_ip, 0);
57        let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
58
59        Ok(ReadRequest {
60            peer,
61            socket,
62            reader,
63            buffer: BytesMut::with_capacity(
64                PACKET_DATA_HEADER_LEN + block_size,
65            ),
66            block_size,
67            timeout,
68            max_send_retries: config.max_send_retries,
69            oack_opts,
70        })
71    }
72
73    pub(crate) async fn handle(&mut self) {
74        if let Err(e) = self.try_handle().await {
75            trace!("RRQ request failed (peer: {}, error: {})", &self.peer, &e);
76
77            Packet::Error(e.into()).encode(&mut self.buffer);
78            let buf = self.buffer.split().freeze();
79            // Errors are never retransmitted.
80            // We do not care if `send_to` resulted to an IO error.
81            let _ = self.socket.send_to(&buf[..], self.peer).await;
82        }
83    }
84
85    async fn try_handle(&mut self) -> Result<()> {
86        let mut block_id: u16 = 0;
87
88        // Send file to client
89        loop {
90            let is_last_block;
91
92            // Reclaim buffer
93            self.buffer.reserve(PACKET_DATA_HEADER_LEN + self.block_size);
94
95            // Encode head of Data packet
96            block_id = block_id.wrapping_add(1);
97            Packet::encode_data_head(block_id, &mut self.buffer);
98
99            // Read block in self.buffer
100            let buf = unsafe {
101                let uninit_buf = self.buffer.chunk_mut();
102
103                let data_buf = slice::from_raw_parts_mut(
104                    uninit_buf.as_mut_ptr(),
105                    uninit_buf.len(),
106                );
107
108                let len = self.read_block(data_buf).await?;
109                is_last_block = len < self.block_size;
110
111                self.buffer.advance_mut(len);
112                self.buffer.split().freeze()
113            };
114
115            // Send OACK after we manage to read the first block from reader.
116            //
117            // We do this because we want to give the developers the option to
118            // produce an error after they construct a reader.
119            if let Some(opts) = self.oack_opts.take() {
120                trace!("RRQ OACK (peer: {}, opts: {:?}", &self.peer, &opts);
121
122                let mut buf = BytesMut::new();
123                Packet::OAck(opts.to_owned()).encode(&mut buf);
124
125                self.send(buf.split().freeze(), 0).await?;
126            }
127
128            // Send Data packet
129            self.send(buf, block_id).await?;
130
131            if is_last_block {
132                break;
133            }
134        }
135
136        trace!("RRQ request served (peer: {})", &self.peer);
137        Ok(())
138    }
139
140    async fn send(&mut self, packet: Bytes, block_id: u16) -> Result<()> {
141        // Send packet until we receive an ack
142        for _ in 0..=self.max_send_retries {
143            self.socket.send_to(&packet[..], self.peer).await?;
144
145            match self.recv_ack(block_id).await {
146                Ok(_) => {
147                    trace!(
148                        "RRQ (peer: {}, block_id: {}) - Received ACK",
149                        &self.peer,
150                        block_id
151                    );
152                    return Ok(());
153                }
154                Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
155                    trace!(
156                        "RRQ (peer: {}, block_id: {}) - Timeout",
157                        &self.peer,
158                        block_id
159                    );
160                    continue;
161                }
162                Err(e) => return Err(e.into()),
163            }
164        }
165
166        Err(Error::MaxSendRetriesReached(self.peer, block_id))
167    }
168
169    async fn recv_ack(&mut self, block_id: u16) -> io::Result<()> {
170        // We can not use `self` within `async_std::io::timeout` because not all
171        // struct members implement `Sync`. So we borrow only what we need.
172        let socket = &mut self.socket;
173        let peer = self.peer;
174
175        io_timeout(self.timeout, async {
176            let mut buf = [0u8; 1024];
177
178            loop {
179                let (len, recved_peer) = socket.recv_from(&mut buf[..]).await?;
180
181                // if the packet do not come from the client we are serving, then ignore it
182                if recved_peer != peer {
183                    continue;
184                }
185
186                // parse only valid Ack packets, the rest are ignored
187                if let Ok(Packet::Ack(recved_block_id)) =
188                    Packet::decode(&buf[..len])
189                {
190                    if recved_block_id == block_id {
191                        return Ok(());
192                    }
193                }
194            }
195        })
196        .await?;
197
198        Ok(())
199    }
200
201    async fn read_block(&mut self, buf: &mut [u8]) -> Result<usize> {
202        let mut len = 0;
203
204        while len < buf.len() {
205            match self.reader.read(&mut buf[len..]).await? {
206                0 => break,
207                x => len += x,
208            }
209        }
210
211        Ok(len)
212    }
213}
214
215fn build_oack_opts(
216    config: &ServerConfig,
217    req: &RwReq,
218    file_size: Option<u64>,
219) -> Option<Opts> {
220    let mut opts = Opts::default();
221
222    if !config.ignore_client_block_size {
223        opts.block_size = match (req.opts.block_size, config.block_size_limit) {
224            (Some(bsize), Some(limit)) => Some(cmp::min(bsize, limit)),
225            (Some(bsize), None) => Some(bsize),
226            _ => None,
227        };
228    }
229
230    if !config.ignore_client_timeout {
231        opts.timeout = req.opts.timeout;
232    }
233
234    if let (Some(0), Some(file_size)) = (req.opts.transfer_size, file_size) {
235        opts.transfer_size = Some(file_size);
236    }
237
238    if opts == Opts::default() {
239        None
240    } else {
241        Some(opts)
242    }
243}