async_tftp/server/
write_req.rs1use async_io::Async;
2use bytes::{Buf, Bytes, BytesMut};
3use futures_lite::{AsyncWrite, AsyncWriteExt};
4use log::trace;
5use std::cmp;
6use std::io;
7use std::net::{IpAddr, SocketAddr, UdpSocket};
8use std::time::Duration;
9
10use crate::error::{Error, Result};
11use crate::packet::{Opts, Packet, RwReq, PACKET_DATA_HEADER_LEN};
12use crate::server::{ServerConfig, DEFAULT_BLOCK_SIZE};
13use crate::utils::io_timeout;
14
15pub(crate) struct WriteRequest<'w, W>
16where
17 W: AsyncWrite + Send,
18{
19 peer: SocketAddr,
20 socket: Async<UdpSocket>,
21 writer: &'w mut W,
22 buffer: BytesMut,
27 ack: BytesMut,
28 block_size: usize,
29 timeout: Duration,
30 max_retries: u32,
31 oack_opts: Option<Opts>,
32}
33
34impl<'w, W> WriteRequest<'w, W>
35where
36 W: AsyncWrite + Send + Unpin,
37{
38 pub(crate) async fn init(
39 writer: &'w mut W,
40 peer: SocketAddr,
41 req: &RwReq,
42 config: ServerConfig,
43 local_ip: IpAddr,
44 ) -> Result<WriteRequest<'w, W>> {
45 let oack_opts = build_oack_opts(&config, req);
46
47 let block_size = oack_opts
48 .as_ref()
49 .and_then(|o| o.block_size)
50 .map(usize::from)
51 .unwrap_or(DEFAULT_BLOCK_SIZE);
52
53 let timeout = oack_opts
54 .as_ref()
55 .and_then(|o| o.timeout)
56 .map(|t| Duration::from_secs(u64::from(t)))
57 .unwrap_or(config.timeout);
58
59 let addr = SocketAddr::new(local_ip, 0);
60 let socket = Async::<UdpSocket>::bind(addr).map_err(Error::Bind)?;
61
62 Ok(WriteRequest {
63 peer,
64 socket,
65 writer,
66 buffer: BytesMut::new(),
67 ack: BytesMut::new(),
68 block_size,
69 timeout,
70 max_retries: config.max_send_retries,
71 oack_opts,
72 })
73 }
74
75 pub(crate) async fn handle(&mut self) {
76 if let Err(e) = self.try_handle().await {
77 trace!("WRQ request failed (peer: {}, error: {}", self.peer, &e);
78
79 Packet::Error(e.into()).encode(&mut self.buffer);
80 let buf = self.buffer.split().freeze();
81 let _ = self.socket.send_to(&buf[..], self.peer).await;
84 }
85 }
86
87 async fn try_handle(&mut self) -> Result<()> {
88 let mut block_id: u16 = 0;
89
90 match self.oack_opts.take() {
92 Some(opts) => Packet::OAck(opts).encode(&mut self.ack),
93 None => Packet::Ack(0).encode(&mut self.ack),
94 }
95
96 self.socket.send_to(&self.ack, self.peer).await?;
97
98 loop {
99 block_id = block_id.wrapping_add(1);
101 let data = self.recv_data(block_id).await?;
102
103 self.writer.write_all(&data[..]).await?;
105
106 if data.len() < self.block_size {
107 break;
108 }
109 }
110
111 Ok(())
112 }
113
114 async fn recv_data(&mut self, block_id: u16) -> Result<Bytes> {
115 for _ in 0..=self.max_retries {
116 match self.recv_data_block(block_id).await {
117 Ok(data) => {
118 self.ack.clear();
120 Packet::Ack(block_id).encode(&mut self.ack);
121
122 self.socket.send_to(&self.ack, self.peer).await?;
123 return Ok(data);
124 }
125 Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
126 self.socket.send_to(&self.ack, self.peer).await?;
128 continue;
129 }
130 Err(e) => return Err(e.into()),
131 }
132 }
133
134 Err(Error::MaxSendRetriesReached(self.peer, block_id))
135 }
136
137 async fn recv_data_block(&mut self, block_id: u16) -> io::Result<Bytes> {
138 let socket = &mut self.socket;
139 let peer = self.peer;
140
141 self.buffer.resize(PACKET_DATA_HEADER_LEN + self.block_size, 0);
142 let mut buf = self.buffer.split();
143
144 io_timeout(self.timeout, async move {
145 loop {
146 let (len, recved_peer) = socket.recv_from(&mut buf[..]).await?;
147
148 if recved_peer != peer {
149 continue;
150 }
151
152 if let Ok(Packet::Data(recved_block_id, _)) =
153 Packet::decode(&buf[..len])
154 {
155 if recved_block_id == block_id {
156 buf.truncate(len);
157 buf.advance(PACKET_DATA_HEADER_LEN);
158 break;
159 }
160 }
161 }
162
163 Ok(buf.freeze())
164 })
165 .await
166 }
167}
168
169fn build_oack_opts(config: &ServerConfig, req: &RwReq) -> Option<Opts> {
170 let mut opts = Opts::default();
171
172 if !config.ignore_client_block_size {
173 opts.block_size = match (req.opts.block_size, config.block_size_limit) {
174 (Some(bsize), Some(limit)) => Some(cmp::min(bsize, limit)),
175 (Some(bsize), None) => Some(bsize),
176 _ => None,
177 };
178 }
179
180 if !config.ignore_client_timeout {
181 opts.timeout = req.opts.timeout;
182 }
183
184 opts.transfer_size = req.opts.transfer_size;
185
186 if opts == Opts::default() {
187 None
188 } else {
189 Some(opts)
190 }
191}