1use crate::{
2 auxiliary,
3 file::{File, OpenOptions},
4 fs::Fs,
5 lowlevel, tasks,
6 utils::{ErrorExt, ResultExt},
7 Error, MpscQueue, SftpOptions, SharedData, WriteEnd, WriteEndWithCachedId,
8};
9
10use auxiliary::Auxiliary;
11use lowlevel::{connect, Extensions};
12use tasks::{create_flush_task, create_read_task};
13
14use std::{
15 any::Any, convert::TryInto, fmt, future::Future, ops::Deref, path::Path, pin::Pin, sync::Arc,
16};
17
18use derive_destructure2::destructure;
19use tokio::{
20 io::{AsyncRead, AsyncWrite},
21 runtime::Handle,
22 sync::oneshot::Receiver,
23 task::JoinHandle,
24};
25use tokio_io_utility::assert_send;
26
27#[cfg(feature = "openssh")]
28mod openssh_session;
29
30#[cfg(feature = "openssh")]
31pub use openssh_session::{CheckOpensshConnection, OpensshSession};
32
33#[derive(Debug, destructure)]
34pub(super) struct SftpHandle(SharedData);
35
36impl Deref for SftpHandle {
37 type Target = SharedData;
38
39 fn deref(&self) -> &Self::Target {
40 &self.0
41 }
42}
43
44impl SftpHandle {
45 fn new(shared_data: &SharedData) -> Self {
46 shared_data.get_auxiliary().inc_active_user_count();
48
49 Self(shared_data.clone())
50 }
51
52 pub(super) fn write_end(self) -> WriteEndWithCachedId {
54 WriteEndWithCachedId::new(WriteEnd::new(self.destructure().0))
57 }
58}
59
60impl Clone for SftpHandle {
61 fn clone(&self) -> Self {
62 self.0.get_auxiliary().inc_active_user_count();
63 Self(self.0.clone())
64 }
65}
66
67impl Drop for SftpHandle {
68 fn drop(&mut self) {
69 self.0.get_auxiliary().dec_active_user_count();
70 }
71}
72
73#[derive(Debug)]
75pub struct Sftp {
76 handle: SftpHandle,
77 flush_task: JoinHandle<Result<(), Error>>,
78 read_task: JoinHandle<Result<(), Error>>,
79}
80
/// Extra data stored in the connection's shared state; it is handed to
/// `Auxiliary::new` when the connection is set up.
#[non_exhaustive]
pub enum SftpAuxiliaryData {
    /// No auxiliary data.
    None,
    /// An arbitrary boxed value.
    Boxed(Box<dyn Any + Send + Sync + 'static>),
    /// A pinned, boxed future.
    /// NOTE(review): whether and when it is polled is decided outside this
    /// file — confirm before relying on it.
    PinnedFuture(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    /// An arbitrary reference-counted value.
    Arced(Arc<dyn Any + Send + Sync + 'static>),
    /// An `OpensshSession`; `Sftp::close` recovers this session's error after
    /// the background tasks have finished.
    #[cfg(feature = "openssh")]
    ArcedOpensshSession(Arc<OpensshSession>),
}

impl fmt::Debug for SftpAuxiliaryData {
    /// Prints the variant name; type-erased payloads are shown as fixed
    /// placeholders because they have no `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::None => "None",
            Self::Boxed(_) => "Boxed(boxed_any)",
            Self::PinnedFuture(_) => "PinnedFuture",
            Self::Arced(_) => "Arced(arced_any)",
            #[cfg(feature = "openssh")]
            Self::ArcedOpensshSession(session) => {
                return write!(f, "ArcedOpensshSession({session:?})");
            }
        };

        f.write_str(label)
    }
}
111
112impl Sftp {
113 pub async fn new<W: AsyncWrite + Send + 'static, R: AsyncRead + Send + 'static>(
115 stdin: W,
116 stdout: R,
117 options: SftpOptions,
118 ) -> Result<Self, Error> {
119 Self::new_with_auxiliary(stdin, stdout, options, SftpAuxiliaryData::None).await
120 }
121
122 pub async fn new_with_auxiliary<
134 W: AsyncWrite + Send + 'static,
135 R: AsyncRead + Send + 'static,
136 >(
137 stdin: W,
138 stdout: R,
139 options: SftpOptions,
140 auxiliary: SftpAuxiliaryData,
141 ) -> Result<Self, Error> {
142 assert_send(async move {
143 let write_end_buffer_size = options.get_write_end_buffer_size();
144
145 let write_end = assert_send(Self::connect(
146 write_end_buffer_size.get(),
147 options.get_max_pending_requests(),
148 auxiliary,
149 options.get_tokio_compat_file_write_limit(),
150 ))?;
151
152 let flush_task = create_flush_task(
153 stdin,
154 SharedData::clone(&write_end),
155 write_end_buffer_size,
156 options.get_flush_interval(),
157 );
158
159 let (rx, read_task) = create_read_task(
160 stdout,
161 options.get_read_end_buffer_size(),
162 SharedData::clone(&write_end),
163 );
164
165 Self::init(flush_task, read_task, write_end, rx, &options).await
166 })
167 .await
168 }
169
170 fn connect(
171 write_end_buffer_size: usize,
172 max_pending_requests: u16,
173 auxiliary: SftpAuxiliaryData,
174 tokio_compat_file_write_limit: usize,
175 ) -> Result<WriteEnd, Error> {
176 connect(
177 MpscQueue::with_capacity(write_end_buffer_size),
178 Auxiliary::new(
179 max_pending_requests,
180 auxiliary,
181 tokio_compat_file_write_limit,
182 Handle::current(),
183 ),
184 )
185 }
186
187 async fn init(
188 flush_task: JoinHandle<Result<(), Error>>,
189 read_task: JoinHandle<Result<(), Error>>,
190 write_end: WriteEnd,
191 rx: Receiver<Extensions>,
192 options: &SftpOptions,
193 ) -> Result<Self, Error> {
194 let sftp = Self {
199 handle: SftpHandle::new(&write_end),
200 flush_task,
201 read_task,
202 };
203
204 let write_end = WriteEndWithCachedId::new(write_end);
205
206 let extensions = if let Ok(extensions) = rx.await {
207 extensions
208 } else {
209 drop(write_end);
210
211 sftp.close().await?;
213 std::unreachable!("Error must have occurred in either read_task or flush_task")
214 };
215
216 match Self::set_limits(write_end, options, extensions).await {
217 Err(Error::BackgroundTaskFailure(_)) => {
218 sftp.close().await?;
220 std::unreachable!("Error must have occurred in either read_task or flush_task")
221 }
222 res => res?,
223 }
224
225 Ok(sftp)
226 }
227
228 async fn set_limits(
229 mut write_end: WriteEndWithCachedId,
230 options: &SftpOptions,
231 extensions: Extensions,
232 ) -> Result<(), Error> {
233 let default_download_buflen = lowlevel::OPENSSH_PORTABLE_DEFAULT_DOWNLOAD_BUFLEN as u64;
234 let default_upload_buflen = lowlevel::OPENSSH_PORTABLE_DEFAULT_UPLOAD_BUFLEN as u64;
235
236 let default_max_packet_len = u32::MAX - 9;
239
240 let (read_len, write_len, packet_len) = if extensions.contains(Extensions::LIMITS) {
241 let mut limits = write_end
242 .send_request(|write_end, id| Ok(write_end.send_limits_request(id)?.wait()))
243 .await?;
244
245 if limits.read_len == 0 {
246 limits.read_len = default_download_buflen;
247 }
248
249 if limits.write_len == 0 {
250 limits.write_len = default_upload_buflen;
251 }
252
253 (
254 limits.read_len,
255 limits.write_len,
256 limits
257 .packet_len
258 .try_into()
259 .unwrap_or(default_max_packet_len),
260 )
261 } else {
262 (
263 default_download_buflen,
264 default_upload_buflen,
265 default_max_packet_len,
266 )
267 };
268
269 let read_len = read_len.try_into().unwrap_or(packet_len - 300);
274 let read_len = options
275 .get_max_read_len()
276 .map(|v| v.min(read_len))
277 .unwrap_or(read_len);
278
279 let write_len = write_len.try_into().unwrap_or(packet_len - 300);
280 let write_len = options
281 .get_max_write_len()
282 .map(|v| v.min(write_len))
283 .unwrap_or(write_len);
284
285 let limits = auxiliary::Limits {
286 read_len,
287 write_len,
288 };
289
290 write_end
291 .get_auxiliary()
292 .conn_info
293 .set(auxiliary::ConnInfo { limits, extensions })
294 .expect("auxiliary.conn_info shall be uninitialized");
295
296 Ok(())
297 }
298
299 pub async fn close(self) -> Result<(), Error> {
306 let Self {
307 handle,
308 flush_task,
309 read_task,
310 } = self;
311
312 let session = match &handle.get_auxiliary().auxiliary_data {
313 #[cfg(feature = "openssh")]
314 SftpAuxiliaryData::ArcedOpensshSession(session) => Some(Arc::clone(session)),
315 _ => None,
316 };
317
318 #[cfg(not(feature = "openssh"))]
319 {
320 let _: Option<()> = session;
322 }
323
324 drop(handle);
326
327 let read_task_error = read_task.await.flatten().err();
329
330 let flush_task_error = flush_task.await.flatten().err();
333
334 let session_error: Option<Error> = match session {
335 #[cfg(feature = "openssh")]
336 Some(session) => Arc::try_unwrap(session)
337 .unwrap()
338 .recover_session_err()
339 .await
340 .err(),
341
342 #[cfg(not(feature = "openssh"))]
343 Some(_) => unreachable!(),
344
345 None => None,
346 };
347
348 match (read_task_error, flush_task_error, session_error) {
349 (Some(err1), Some(err2), Some(err3)) => Err(err1.error_on_cleanup3(err2, err3)),
350 (Some(err1), Some(err2), None)
351 | (Some(err1), None, Some(err2))
352 | (None, Some(err1), Some(err2)) => Err(err1.error_on_cleanup(err2)),
353 (Some(err), None, None) | (None, Some(err), None) | (None, None, Some(err)) => Err(err),
354 (None, None, None) => Ok(()),
355 }
356 }
357
358 pub fn options(&self) -> OpenOptions {
360 OpenOptions::new(self.handle.clone())
361 }
362
363 pub async fn create(&self, path: impl AsRef<Path>) -> Result<File, Error> {
368 async fn inner(this: &Sftp, path: &Path) -> Result<File, Error> {
369 this.options()
370 .write(true)
371 .create(true)
372 .truncate(true)
373 .open(path)
374 .await
375 }
376
377 inner(self, path.as_ref()).await
378 }
379
380 pub async fn open(&self, path: impl AsRef<Path>) -> Result<File, Error> {
382 async fn inner(this: &Sftp, path: &Path) -> Result<File, Error> {
383 this.options().read(true).open(path).await
384 }
385
386 inner(self, path.as_ref()).await
387 }
388
389 pub fn fs(&self) -> Fs {
392 Fs::new(self.handle.clone().write_end(), "".into())
393 }
394
395 pub fn support_expand_path(&self) -> bool {
399 self.handle
400 .get_auxiliary()
401 .extensions()
402 .contains(Extensions::EXPAND_PATH)
403 }
404
405 pub fn support_fsync(&self) -> bool {
409 self.handle
410 .get_auxiliary()
411 .extensions()
412 .contains(Extensions::FSYNC)
413 }
414
415 pub fn support_hardlink(&self) -> bool {
419 self.handle
420 .get_auxiliary()
421 .extensions()
422 .contains(Extensions::HARDLINK)
423 }
424
425 pub fn support_posix_rename(&self) -> bool {
429 self.handle
430 .get_auxiliary()
431 .extensions()
432 .contains(Extensions::POSIX_RENAME)
433 }
434
435 pub fn support_copy(&self) -> bool {
439 self.handle
440 .get_auxiliary()
441 .extensions()
442 .contains(Extensions::COPY_DATA)
443 }
444}
445
#[cfg(feature = "__ci-tests")]
impl Sftp {
    /// Returns the write length recorded in the connection's auxiliary limits.
    /// Only compiled with the `__ci-tests` feature.
    pub fn max_write_len(&self) -> u32 {
        let auxiliary = self.handle.get_auxiliary();
        auxiliary.limits().write_len
    }

    /// Returns the read length recorded in the connection's auxiliary limits.
    /// Only compiled with the `__ci-tests` feature.
    pub fn max_read_len(&self) -> u32 {
        let auxiliary = self.handle.get_auxiliary();
        auxiliary.limits().read_len
    }

    /// Asks the auxiliary state to trigger flushing.
    /// Only compiled with the `__ci-tests` feature.
    pub fn manual_flush(&self) {
        let auxiliary = self.handle.get_auxiliary();
        auxiliary.trigger_flushing();
    }
}
468}