openssh_sftp_client_lowlevel/
read_end.rs

1#![forbid(unsafe_code)]
2
3use super::{
4    awaitable_responses::ArenaArc, awaitable_responses::Response, connection::SharedData,
5    reader_buffered::ReaderBuffered, Error, Extensions, ToBuffer,
6};
7
8use std::{io, num::NonZeroUsize, pin::Pin};
9
10use openssh_sftp_error::RecursiveError;
11use openssh_sftp_protocol::{
12    constants::SSH2_FILEXFER_VERSION,
13    response::{self, ServerVersion},
14    serde::de::DeserializeOwned,
15    ssh_format::{self, from_bytes},
16};
17use pin_project::pin_project;
18use tokio::io::{copy_buf, sink, AsyncBufReadExt, AsyncRead, AsyncReadExt};
19use tokio_io_utility::{read_exact_to_bytes, read_exact_to_vec};
20
/// The ReadEnd for the lowlevel API.
///
/// Reads response packets from the server and hands each one to the pending
/// request registered under its response id.
#[derive(Debug)]
#[pin_project]
pub struct ReadEnd<R, Buffer, Q, Auxiliary = ()> {
    /// Buffered reader over the incoming byte stream. It is structurally
    /// pinned (`#[pin]`), so the `*_pinned` methods work without `R: Unpin`.
    #[pin]
    reader: ReaderBuffered<R>,
    /// Connection state. Its `responses()` table is used to look up the
    /// request that matches each incoming response id.
    shared_data: SharedData<Buffer, Q, Auxiliary>,
}
29
30impl<R, Buffer, Q, Auxiliary> ReadEnd<R, Buffer, Q, Auxiliary>
31where
32    R: AsyncRead,
33    Buffer: ToBuffer + 'static + Send + Sync,
34{
35    /// Must call [`ReadEnd::receive_server_hello_pinned`]
36    /// or [`ReadEnd::receive_server_hello`] after this
37    /// function call.
38    pub fn new(
39        reader: R,
40        reader_buffer_len: NonZeroUsize,
41        shared_data: SharedData<Buffer, Q, Auxiliary>,
42    ) -> Self {
43        Self {
44            reader: ReaderBuffered::new(reader, reader_buffer_len),
45            shared_data,
46        }
47    }
48
49    /// Must be called once right after [`ReadEnd::new`]
50    /// to receive the hello message from the server.
51    pub async fn receive_server_hello_pinned(
52        mut self: Pin<&mut Self>,
53    ) -> Result<Extensions, Error> {
54        // Receive server version
55        let len: u32 = self.as_mut().read_and_deserialize(4).await?;
56        if (len as usize) > 4096 {
57            return Err(Error::SftpServerHelloMsgTooLong { len });
58        }
59
60        let drain = self
61            .project()
62            .reader
63            .read_exact_into_buffer(len as usize)
64            .await?;
65        let server_version =
66            ServerVersion::deserialize(&mut ssh_format::Deserializer::from_bytes(&drain))?;
67
68        if server_version.version != SSH2_FILEXFER_VERSION {
69            Err(Error::UnsupportedSftpProtocol {
70                version: server_version.version,
71            })
72        } else {
73            Ok(server_version.extensions)
74        }
75    }
76
77    async fn read_and_deserialize<T: DeserializeOwned>(
78        self: Pin<&mut Self>,
79        size: usize,
80    ) -> Result<T, Error> {
81        let drain = self.project().reader.read_exact_into_buffer(size).await?;
82        Ok(from_bytes(&drain)?.0)
83    }
84
85    async fn consume_packet(self: Pin<&mut Self>, len: u32, err: Error) -> Result<(), Error> {
86        let reader = self.project().reader;
87        if let Err(consumption_err) = copy_buf(&mut reader.take(len as u64), &mut sink()).await {
88            Err(Error::RecursiveErrors(Box::new(RecursiveError {
89                original_error: err,
90                occuring_error: consumption_err.into(),
91            })))
92        } else {
93            Err(err)
94        }
95    }
96
97    async fn read_into_box(self: Pin<&mut Self>, len: usize) -> Result<Box<[u8]>, Error> {
98        let mut vec = Vec::new();
99        read_exact_to_vec(&mut self.project().reader, &mut vec, len).await?;
100
101        Ok(vec.into_boxed_slice())
102    }
103
104    async fn read_in_data_packet_fallback(
105        self: Pin<&mut Self>,
106        len: usize,
107    ) -> Result<Response<Buffer>, Error> {
108        self.read_into_box(len).await.map(Response::AllocatedBox)
109    }
110
111    /// * `len` - excludes packet_type and request_id.
112    async fn read_in_data_packet(
113        mut self: Pin<&mut Self>,
114        len: u32,
115        buffer: Option<Buffer>,
116    ) -> Result<Response<Buffer>, Error> {
117        // Since the data is sent as a string, we need to consume the 4-byte length first.
118        self.as_mut()
119            .project()
120            .reader
121            .read_exact_into_buffer(4)
122            .await?;
123
124        let len = (len - 4) as usize;
125
126        if let Some(mut buffer) = buffer {
127            match buffer.get_buffer() {
128                super::Buffer::Vector(vec) => {
129                    read_exact_to_vec(&mut self.project().reader, vec, len).await?;
130                    Ok(Response::Buffer(buffer))
131                }
132                super::Buffer::Slice(slice) => {
133                    if slice.len() >= len {
134                        self.project().reader.read_exact(slice).await?;
135                        Ok(Response::Buffer(buffer))
136                    } else {
137                        self.read_in_data_packet_fallback(len).await
138                    }
139                }
140                super::Buffer::Bytes(bytes) => {
141                    read_exact_to_bytes(&mut self.project().reader, bytes, len).await?;
142                    Ok(Response::Buffer(buffer))
143                }
144            }
145        } else {
146            self.read_in_data_packet_fallback(len).await
147        }
148    }
149
150    /// * `len` - includes packet_type and request_id.
151    async fn read_in_packet(self: Pin<&mut Self>, len: u32) -> Result<Response<Buffer>, Error> {
152        let response: response::Response = self.read_and_deserialize(len as usize).await?;
153
154        Ok(Response::Header(response.response_inner))
155    }
156
157    /// * `len` - excludes packet_type and request_id.
158    async fn read_in_extended_reply(
159        self: Pin<&mut Self>,
160        len: u32,
161    ) -> Result<Response<Buffer>, Error> {
162        self.read_into_box(len as usize)
163            .await
164            .map(Response::ExtendedReply)
165    }
166
167    /// # Restart on Error
168    ///
169    /// Only when the returned error is [`Error::InvalidResponseId`] or
170    /// [`Error::AwaitableError`], can the function be restarted.
171    ///
172    /// Upon other errors [`Error::IOError`], [`Error::FormatError`] and
173    /// [`Error::RecursiveErrors`], the sftp session has to be discarded.
174    ///
175    /// # Example
176    ///
177    /// ```rust,ignore
178    /// let readend = ...;
179    /// loop {
180    ///     let new_requests_submit = readend.wait_for_new_request().await;
181    ///     if new_requests_submit == 0 {
182    ///         break;
183    ///     }
184    ///
185    ///     // If attempt to read in more than new_requests_submit, then
186    ///     // `read_in_one_packet` might block forever.
187    ///     for _ in 0..new_requests_submit {
188    ///         readend.read_in_one_packet().await.unwrap();
189    ///     }
190    /// }
191    /// ```
192    /// # Cancel Safety
193    ///
194    /// This function is not cancel safe.
195    ///
196    /// Dropping the future might cause the response packet to be partially read,
197    /// and the next read would treat the partial response as a new response.
198    pub async fn read_in_one_packet_pinned(mut self: Pin<&mut Self>) -> Result<(), Error> {
199        let this = self.as_mut().project();
200        let drain = this.reader.read_exact_into_buffer(9).await?;
201        let (len, packet_type, response_id): (u32, u8, u32) = from_bytes(&drain)?.0;
202
203        let len = len - 5;
204
205        let res = this.shared_data.responses().get(response_id);
206
207        let callback = match res {
208            Ok(callback) => callback,
209
210            // Invalid response_id
211            Err(err) => {
212                drop(drain);
213
214                // Consume the invalid data to return self to a valid state
215                // where read_in_one_packet can be called again.
216                return self.consume_packet(len, err).await;
217            }
218        };
219
220        let response = if response::Response::is_data(packet_type) {
221            drop(drain);
222
223            let buffer = match callback.take_input() {
224                Ok(buffer) => buffer,
225                Err(err) => {
226                    // Consume the invalid data to return self to a valid state
227                    // where read_in_one_packet can be called again.
228                    return self.consume_packet(len, err.into()).await;
229                }
230            };
231            self.read_in_data_packet(len, buffer).await?
232        } else if response::Response::is_extended_reply(packet_type) {
233            drop(drain);
234
235            self.read_in_extended_reply(len).await?
236        } else {
237            // Consumes 4 bytes and put back the rest, since
238            // read_in_packet needs the packet_type and response_id.
239            drain.subdrain(4);
240
241            self.read_in_packet(len + 5).await?
242        };
243
244        let res = callback.done(response);
245
246        // If counter == 2, then it must be one of the following situation:
247        //  - `ReadEnd` is the only holder other than the `Arena` itself;
248        //  - `ReadEnd` and the `AwaitableInner` is the holder and `AwaitableInner::drop`
249        //    has already `ArenaArc::remove`d it.
250        //
251        // In case 1, since there is no `AwaitableInner` holding reference to it,
252        // it can be removed safely.
253        //
254        // In case 2, since it is already removed, remove it again is a no-op.
255        //
256        // NOTE that if the arc is dropped after this call while having the
257        // `Awaitable*::drop` executed before `callback.done`, then the callback
258        // would not be removed.
259        //
260        // Though this kind of situation is rare.
261        if ArenaArc::strong_count(&callback) == 2 {
262            ArenaArc::remove(&callback);
263        }
264
265        Ok(res?)
266    }
267
268    /// Wait for next packet to be readable.
269    ///
270    /// Return `Ok(())` if next packet is ready and readable, `Error::IOError(io_error)`
271    /// where `io_error.kind() == ErrorKind::UnexpectedEof` if `EOF` is met.
272    ///
273    /// # Cancel Safety
274    ///
275    /// This function is cancel safe.
276    pub async fn ready_for_read_pinned(self: Pin<&mut Self>) -> Result<(), io::Error> {
277        if self.project().reader.fill_buf().await?.is_empty() {
278            // Empty buffer means EOF
279            Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
280        } else {
281            Ok(())
282        }
283    }
284}
285
286impl<R, Buffer, Q, Auxiliary> ReadEnd<R, Buffer, Q, Auxiliary>
287where
288    Self: Unpin,
289    R: AsyncRead,
290    Buffer: ToBuffer + 'static + Send + Sync,
291{
292    /// Must be called once right after [`super::connect`]
293    /// to receive the hello message from the server.
294    pub async fn receive_server_hello(&mut self) -> Result<Extensions, Error> {
295        Pin::new(self).receive_server_hello_pinned().await
296    }
297
298    /// # Restart on Error
299    ///
300    /// Only when the returned error is [`Error::InvalidResponseId`] or
301    /// [`Error::AwaitableError`], can the function be restarted.
302    ///
303    /// Upon other errors [`Error::IOError`], [`Error::FormatError`] and
304    /// [`Error::RecursiveErrors`], the sftp session has to be discarded.
305    ///
306    /// # Cancel Safety
307    ///
308    /// This function is not cancel safe.
309    ///
310    /// Dropping the future might cause the response packet to be partially read,
311    /// and the next read would treat the partial response as a new response.
312    pub async fn read_in_one_packet(&mut self) -> Result<(), Error> {
313        Pin::new(self).read_in_one_packet_pinned().await
314    }
315
316    /// Wait for next packet to be readable.
317    ///
318    /// Return `Ok(())` if next packet is ready and readable, `Error::IOError(io_error)`
319    /// where `io_error.kind() == ErrorKind::UnexpectedEof` if `EOF` is met.
320    ///
321    /// # Cancel Safety
322    ///
323    /// This function is cancel safe.
324    pub async fn ready_for_read(&mut self) -> Result<(), io::Error> {
325        Pin::new(self).ready_for_read_pinned().await
326    }
327}
328
329impl<R, Buffer, Q, Auxiliary> ReadEnd<R, Buffer, Q, Auxiliary> {
330    /// Return the [`SharedData`] held by [`ReadEnd`].
331    pub fn get_shared_data(&self) -> &SharedData<Buffer, Q, Auxiliary> {
332        &self.shared_data
333    }
334}