async_tftp/server/write_req.rs

1use async_io::Async;
2use bytes::{Buf, Bytes, BytesMut};
3use futures_lite::{AsyncWrite, AsyncWriteExt};
4use log::trace;
5use std::cmp;
6use std::io;
7use std::net::{IpAddr, SocketAddr, UdpSocket};
8use std::time::Duration;
9
10use crate::error::{Error, Result};
11use crate::packet::{Opts, Packet, RwReq, PACKET_DATA_HEADER_LEN};
12use crate::server::{ServerConfig, DEFAULT_BLOCK_SIZE};
13use crate::utils::io_timeout;
14
15pub(crate) struct WriteRequest<'w, W>
16where
17    W: AsyncWrite + Send,
18{
19    peer: SocketAddr,
20    socket: Async<UdpSocket>,
21    writer: &'w mut W,
22    // BytesMut reclaims memory only if it is continuous.
23    // Because we always need to keep the previous ACK, we can not use
24    // `buffer` as its storage since it breaks the continuity.
25    // So we keep previous ACK in `ack` buffer.
26    buffer: BytesMut,
27    ack: BytesMut,
28    block_size: usize,
29    timeout: Duration,
30    max_retries: u32,
31    oack_opts: Option<Opts>,
32}
33
34impl<'w, W> WriteRequest<'w, W>
35where
36    W: AsyncWrite + Send + Unpin,
37{
38    pub(crate) async fn init(
39        writer: &'w mut W,
40        peer: SocketAddr,
41        req: &RwReq,
42        config: ServerConfig,
43        local_ip: IpAddr,
44    ) -> Result<WriteRequest<'w, W>> {
45        let oack_opts = build_oack_opts(&config, req);
46
47        let block_size = oack_opts
48            .as_ref()
49            .and_then(|o| o.block_size)
50            .map(usize::from)
51            .unwrap_or(DEFAULT_BLOCK_SIZE);
52
53        let timeout = oack_opts
54            .as_ref()
55            .and_then(|o| o.timeout)
56            .map(|t| Duration::from_secs(u64::from(t)))
57            .unwrap_or(config.timeout);
58
59        let addr = SocketAddr::new(local_ip, 0);
60        let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
61
62        Ok(WriteRequest {
63            peer,
64            socket,
65            writer,
66            buffer: BytesMut::new(),
67            ack: BytesMut::new(),
68            block_size,
69            timeout,
70            max_retries: config.max_send_retries,
71            oack_opts,
72        })
73    }
74
75    pub(crate) async fn handle(&mut self) {
76        if let Err(e) = self.try_handle().await {
77            trace!("WRQ request failed (peer: {}, error: {}", self.peer, &e);
78
79            Packet::Error(e.into()).encode(&mut self.buffer);
80            let buf = self.buffer.split().freeze();
81            // Errors are never retransmitted.
82            // We do not care if `send_to` resulted to an IO error.
83            let _ = self.socket.send_to(&buf[..], self.peer).await;
84        }
85    }
86
87    async fn try_handle(&mut self) -> Result<()> {
88        let mut block_id: u16 = 0;
89
90        // Send first Ack/OAck
91        match self.oack_opts.take() {
92            Some(opts) => Packet::OAck(opts).encode(&mut self.ack),
93            None => Packet::Ack(0).encode(&mut self.ack),
94        }
95
96        self.socket.send_to(&self.ack, self.peer).await?;
97
98        loop {
99            // Recv data
100            block_id = block_id.wrapping_add(1);
101            let data = self.recv_data(block_id).await?;
102
103            // Write data to file
104            self.writer.write_all(&data[..]).await?;
105
106            if data.len() < self.block_size {
107                break;
108            }
109        }
110
111        Ok(())
112    }
113
114    async fn recv_data(&mut self, block_id: u16) -> Result<Bytes> {
115        for _ in 0..=self.max_retries {
116            match self.recv_data_block(block_id).await {
117                Ok(data) => {
118                    // Data received, send ACK
119                    self.ack.clear();
120                    Packet::Ack(block_id).encode(&mut self.ack);
121
122                    self.socket.send_to(&self.ack, self.peer).await?;
123                    return Ok(data);
124                }
125                Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
126                    // On timeout reply with the previous ACK packet
127                    self.socket.send_to(&self.ack, self.peer).await?;
128                    continue;
129                }
130                Err(e) => return Err(e.into()),
131            }
132        }
133
134        Err(Error::MaxSendRetriesReached(self.peer, block_id))
135    }
136
137    async fn recv_data_block(&mut self, block_id: u16) -> io::Result<Bytes> {
138        let socket = &mut self.socket;
139        let peer = self.peer;
140
141        self.buffer.resize(PACKET_DATA_HEADER_LEN + self.block_size, 0);
142        let mut buf = self.buffer.split();
143
144        io_timeout(self.timeout, async move {
145            loop {
146                let (len, recved_peer) = socket.recv_from(&mut buf[..]).await?;
147
148                if recved_peer != peer {
149                    continue;
150                }
151
152                if let Ok(Packet::Data(recved_block_id, _)) =
153                    Packet::decode(&buf[..len])
154                {
155                    if recved_block_id == block_id {
156                        buf.truncate(len);
157                        buf.advance(PACKET_DATA_HEADER_LEN);
158                        break;
159                    }
160                }
161            }
162
163            Ok(buf.freeze())
164        })
165        .await
166    }
167}
168
169fn build_oack_opts(config: &ServerConfig, req: &RwReq) -> Option<Opts> {
170    let mut opts = Opts::default();
171
172    if !config.ignore_client_block_size {
173        opts.block_size = match (req.opts.block_size, config.block_size_limit) {
174            (Some(bsize), Some(limit)) => Some(cmp::min(bsize, limit)),
175            (Some(bsize), None) => Some(bsize),
176            _ => None,
177        };
178    }
179
180    if !config.ignore_client_timeout {
181        opts.timeout = req.opts.timeout;
182    }
183
184    opts.transfer_size = req.opts.transfer_size;
185
186    if opts == Opts::default() {
187        None
188    } else {
189        Some(opts)
190    }
191}