openssh_sftp_client_lowlevel/
read_end.rs1#![forbid(unsafe_code)]
2
3use super::{
4 awaitable_responses::ArenaArc, awaitable_responses::Response, connection::SharedData,
5 reader_buffered::ReaderBuffered, Error, Extensions, ToBuffer,
6};
7
8use std::{io, num::NonZeroUsize, pin::Pin};
9
10use openssh_sftp_error::RecursiveError;
11use openssh_sftp_protocol::{
12 constants::SSH2_FILEXFER_VERSION,
13 response::{self, ServerVersion},
14 serde::de::DeserializeOwned,
15 ssh_format::{self, from_bytes},
16};
17use pin_project::pin_project;
18use tokio::io::{copy_buf, sink, AsyncBufReadExt, AsyncRead, AsyncReadExt};
19use tokio_io_utility::{read_exact_to_bytes, read_exact_to_vec};
20
21#[derive(Debug)]
23#[pin_project]
24pub struct ReadEnd<R, Buffer, Q, Auxiliary = ()> {
25 #[pin]
26 reader: ReaderBuffered<R>,
27 shared_data: SharedData<Buffer, Q, Auxiliary>,
28}
29
30impl<R, Buffer, Q, Auxiliary> ReadEnd<R, Buffer, Q, Auxiliary>
31where
32 R: AsyncRead,
33 Buffer: ToBuffer + 'static + Send + Sync,
34{
35 pub fn new(
39 reader: R,
40 reader_buffer_len: NonZeroUsize,
41 shared_data: SharedData<Buffer, Q, Auxiliary>,
42 ) -> Self {
43 Self {
44 reader: ReaderBuffered::new(reader, reader_buffer_len),
45 shared_data,
46 }
47 }
48
49 pub async fn receive_server_hello_pinned(
52 mut self: Pin<&mut Self>,
53 ) -> Result<Extensions, Error> {
54 let len: u32 = self.as_mut().read_and_deserialize(4).await?;
56 if (len as usize) > 4096 {
57 return Err(Error::SftpServerHelloMsgTooLong { len });
58 }
59
60 let drain = self
61 .project()
62 .reader
63 .read_exact_into_buffer(len as usize)
64 .await?;
65 let server_version =
66 ServerVersion::deserialize(&mut ssh_format::Deserializer::from_bytes(&drain))?;
67
68 if server_version.version != SSH2_FILEXFER_VERSION {
69 Err(Error::UnsupportedSftpProtocol {
70 version: server_version.version,
71 })
72 } else {
73 Ok(server_version.extensions)
74 }
75 }
76
77 async fn read_and_deserialize<T: DeserializeOwned>(
78 self: Pin<&mut Self>,
79 size: usize,
80 ) -> Result<T, Error> {
81 let drain = self.project().reader.read_exact_into_buffer(size).await?;
82 Ok(from_bytes(&drain)?.0)
83 }
84
85 async fn consume_packet(self: Pin<&mut Self>, len: u32, err: Error) -> Result<(), Error> {
86 let reader = self.project().reader;
87 if let Err(consumption_err) = copy_buf(&mut reader.take(len as u64), &mut sink()).await {
88 Err(Error::RecursiveErrors(Box::new(RecursiveError {
89 original_error: err,
90 occuring_error: consumption_err.into(),
91 })))
92 } else {
93 Err(err)
94 }
95 }
96
97 async fn read_into_box(self: Pin<&mut Self>, len: usize) -> Result<Box<[u8]>, Error> {
98 let mut vec = Vec::new();
99 read_exact_to_vec(&mut self.project().reader, &mut vec, len).await?;
100
101 Ok(vec.into_boxed_slice())
102 }
103
104 async fn read_in_data_packet_fallback(
105 self: Pin<&mut Self>,
106 len: usize,
107 ) -> Result<Response<Buffer>, Error> {
108 self.read_into_box(len).await.map(Response::AllocatedBox)
109 }
110
111 async fn read_in_data_packet(
113 mut self: Pin<&mut Self>,
114 len: u32,
115 buffer: Option<Buffer>,
116 ) -> Result<Response<Buffer>, Error> {
117 self.as_mut()
119 .project()
120 .reader
121 .read_exact_into_buffer(4)
122 .await?;
123
124 let len = (len - 4) as usize;
125
126 if let Some(mut buffer) = buffer {
127 match buffer.get_buffer() {
128 super::Buffer::Vector(vec) => {
129 read_exact_to_vec(&mut self.project().reader, vec, len).await?;
130 Ok(Response::Buffer(buffer))
131 }
132 super::Buffer::Slice(slice) => {
133 if slice.len() >= len {
134 self.project().reader.read_exact(slice).await?;
135 Ok(Response::Buffer(buffer))
136 } else {
137 self.read_in_data_packet_fallback(len).await
138 }
139 }
140 super::Buffer::Bytes(bytes) => {
141 read_exact_to_bytes(&mut self.project().reader, bytes, len).await?;
142 Ok(Response::Buffer(buffer))
143 }
144 }
145 } else {
146 self.read_in_data_packet_fallback(len).await
147 }
148 }
149
150 async fn read_in_packet(self: Pin<&mut Self>, len: u32) -> Result<Response<Buffer>, Error> {
152 let response: response::Response = self.read_and_deserialize(len as usize).await?;
153
154 Ok(Response::Header(response.response_inner))
155 }
156
157 async fn read_in_extended_reply(
159 self: Pin<&mut Self>,
160 len: u32,
161 ) -> Result<Response<Buffer>, Error> {
162 self.read_into_box(len as usize)
163 .await
164 .map(Response::ExtendedReply)
165 }
166
167 pub async fn read_in_one_packet_pinned(mut self: Pin<&mut Self>) -> Result<(), Error> {
199 let this = self.as_mut().project();
200 let drain = this.reader.read_exact_into_buffer(9).await?;
201 let (len, packet_type, response_id): (u32, u8, u32) = from_bytes(&drain)?.0;
202
203 let len = len - 5;
204
205 let res = this.shared_data.responses().get(response_id);
206
207 let callback = match res {
208 Ok(callback) => callback,
209
210 Err(err) => {
212 drop(drain);
213
214 return self.consume_packet(len, err).await;
217 }
218 };
219
220 let response = if response::Response::is_data(packet_type) {
221 drop(drain);
222
223 let buffer = match callback.take_input() {
224 Ok(buffer) => buffer,
225 Err(err) => {
226 return self.consume_packet(len, err.into()).await;
229 }
230 };
231 self.read_in_data_packet(len, buffer).await?
232 } else if response::Response::is_extended_reply(packet_type) {
233 drop(drain);
234
235 self.read_in_extended_reply(len).await?
236 } else {
237 drain.subdrain(4);
240
241 self.read_in_packet(len + 5).await?
242 };
243
244 let res = callback.done(response);
245
246 if ArenaArc::strong_count(&callback) == 2 {
262 ArenaArc::remove(&callback);
263 }
264
265 Ok(res?)
266 }
267
268 pub async fn ready_for_read_pinned(self: Pin<&mut Self>) -> Result<(), io::Error> {
277 if self.project().reader.fill_buf().await?.is_empty() {
278 Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
280 } else {
281 Ok(())
282 }
283 }
284}
285
286impl<R, Buffer, Q, Auxiliary> ReadEnd<R, Buffer, Q, Auxiliary>
287where
288 Self: Unpin,
289 R: AsyncRead,
290 Buffer: ToBuffer + 'static + Send + Sync,
291{
292 pub async fn receive_server_hello(&mut self) -> Result<Extensions, Error> {
295 Pin::new(self).receive_server_hello_pinned().await
296 }
297
298 pub async fn read_in_one_packet(&mut self) -> Result<(), Error> {
313 Pin::new(self).read_in_one_packet_pinned().await
314 }
315
316 pub async fn ready_for_read(&mut self) -> Result<(), io::Error> {
325 Pin::new(self).ready_for_read_pinned().await
326 }
327}
328
329impl<R, Buffer, Q, Auxiliary> ReadEnd<R, Buffer, Q, Auxiliary> {
330 pub fn get_shared_data(&self) -> &SharedData<Buffer, Q, Auxiliary> {
332 &self.shared_data
333 }
334}