// openssh_sftp_client/sftp.rs

1use crate::{
2    auxiliary,
3    file::{File, OpenOptions},
4    fs::Fs,
5    lowlevel, tasks,
6    utils::{ErrorExt, ResultExt},
7    Error, MpscQueue, SftpOptions, SharedData, WriteEnd, WriteEndWithCachedId,
8};
9
10use auxiliary::Auxiliary;
11use lowlevel::{connect, Extensions};
12use tasks::{create_flush_task, create_read_task};
13
14use std::{
15    any::Any, convert::TryInto, fmt, future::Future, ops::Deref, path::Path, pin::Pin, sync::Arc,
16};
17
18use derive_destructure2::destructure;
19use tokio::{
20    io::{AsyncRead, AsyncWrite},
21    runtime::Handle,
22    sync::oneshot::Receiver,
23    task::JoinHandle,
24};
25use tokio_io_utility::assert_send;
26
27#[cfg(feature = "openssh")]
28mod openssh_session;
29
30#[cfg(feature = "openssh")]
31pub use openssh_session::{CheckOpensshConnection, OpensshSession};
32
/// Refcounted handle to the connection's [`SharedData`].
///
/// Every live handle counts as one "active user" of the connection:
/// `new`/`clone` increment `active_user_count` and `drop` decrements it
/// (see the impls below).
#[derive(Debug, destructure)]
pub(super) struct SftpHandle(SharedData);
35
36impl Deref for SftpHandle {
37    type Target = SharedData;
38
39    fn deref(&self) -> &Self::Target {
40        &self.0
41    }
42}
43
44impl SftpHandle {
45    fn new(shared_data: &SharedData) -> Self {
46        // Inc active_user_count for the same reason as Self::clone
47        shared_data.get_auxiliary().inc_active_user_count();
48
49        Self(shared_data.clone())
50    }
51
52    /// Takes `self` by value to ensure active_user_count get inc/dec properly.
53    pub(super) fn write_end(self) -> WriteEndWithCachedId {
54        // WriteEndWithCachedId also inc/dec active_user_count, so it's ok
55        // to destructure self here.
56        WriteEndWithCachedId::new(WriteEnd::new(self.destructure().0))
57    }
58}
59
60impl Clone for SftpHandle {
61    fn clone(&self) -> Self {
62        self.0.get_auxiliary().inc_active_user_count();
63        Self(self.0.clone())
64    }
65}
66
67impl Drop for SftpHandle {
68    fn drop(&mut self) {
69        self.0.get_auxiliary().dec_active_user_count();
70    }
71}
72
/// A file-oriented channel to a remote host.
#[derive(Debug)]
pub struct Sftp {
    // Keeps the connection's active_user_count non-zero for as long as
    // this `Sftp` is alive.
    handle: SftpHandle,
    // Background task writing buffered requests out (create_flush_task).
    flush_task: JoinHandle<Result<(), Error>>,
    // Background task reading and dispatching responses (create_read_task).
    read_task: JoinHandle<Result<(), Error>>,
}
80
/// Auxiliary data for [`Sftp`].
#[non_exhaustive]
pub enum SftpAuxiliaryData {
    /// No auxiliary data.
    None,
    /// Store any `Box`ed value.
    Boxed(Box<dyn Any + Send + Sync + 'static>),
    /// Store any `Pin`ed `Future`.
    PinnedFuture(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    /// Store any `Arc`ed value.
    Arced(Arc<dyn Any + Send + Sync + 'static>),
    /// Store [`OpensshSession`] within an `Arc`.
    #[cfg(feature = "openssh")]
    ArcedOpensshSession(Arc<OpensshSession>),
}
96
97impl fmt::Debug for SftpAuxiliaryData {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        use SftpAuxiliaryData::*;
100
101        match self {
102            None => f.write_str("None"),
103            Boxed(_) => f.write_str("Boxed(boxed_any)"),
104            PinnedFuture(_) => f.write_str("PinnedFuture"),
105            Arced(_) => f.write_str("Arced(arced_any)"),
106            #[cfg(feature = "openssh")]
107            ArcedOpensshSession(session) => write!(f, "ArcedOpensshSession({session:?})"),
108        }
109    }
110}
111
impl Sftp {
    /// Create [`Sftp`].
    ///
    /// `stdin` and `stdout` are the write/read halves of the transport to
    /// the remote sftp server.
    pub async fn new<W: AsyncWrite + Send + 'static, R: AsyncRead + Send + 'static>(
        stdin: W,
        stdout: R,
        options: SftpOptions,
    ) -> Result<Self, Error> {
        Self::new_with_auxiliary(stdin, stdout, options, SftpAuxiliaryData::None).await
    }

    /// Create [`Sftp`] with some auxiliary data.
    ///
    /// The auxiliary data will be dropped after all sftp requests has been
    /// sent(flush_task), all responses processed (read_task) and [`Sftp`] has
    /// been dropped.
    ///
    /// If you want to get back the data, you can simply use
    /// [`SftpAuxiliaryData::Arced`] and then stores an [`Arc`] elsewhere.
    ///
    /// Once the sftp tasks is completed and [`Sftp`] is dropped, you can call
    /// [`Arc::try_unwrap`] to get back the exclusive ownership of it.
    pub async fn new_with_auxiliary<
        W: AsyncWrite + Send + 'static,
        R: AsyncRead + Send + 'static,
    >(
        stdin: W,
        stdout: R,
        options: SftpOptions,
        auxiliary: SftpAuxiliaryData,
    ) -> Result<Self, Error> {
        // assert_send is a compile-time guard that the future is `Send`.
        assert_send(async move {
            let write_end_buffer_size = options.get_write_end_buffer_size();

            // Build the shared state (queue + auxiliary) that both
            // background tasks operate on.
            let write_end = assert_send(Self::connect(
                write_end_buffer_size.get(),
                options.get_max_pending_requests(),
                auxiliary,
                options.get_tokio_compat_file_write_limit(),
            ))?;

            // flush_task takes ownership of `stdin` and writes out
            // buffered requests.
            let flush_task = create_flush_task(
                stdin,
                SharedData::clone(&write_end),
                write_end_buffer_size,
                options.get_flush_interval(),
            );

            // read_task takes ownership of `stdout`; `rx` resolves with
            // the server's `Extensions` once they are known.
            let (rx, read_task) = create_read_task(
                stdout,
                options.get_read_end_buffer_size(),
                SharedData::clone(&write_end),
            );

            Self::init(flush_task, read_task, write_end, rx, &options).await
        })
        .await
    }

    /// Build the [`WriteEnd`] over a fresh [`MpscQueue`] and [`Auxiliary`].
    fn connect(
        write_end_buffer_size: usize,
        max_pending_requests: u16,
        auxiliary: SftpAuxiliaryData,
        tokio_compat_file_write_limit: usize,
    ) -> Result<WriteEnd, Error> {
        connect(
            MpscQueue::with_capacity(write_end_buffer_size),
            Auxiliary::new(
                max_pending_requests,
                auxiliary,
                tokio_compat_file_write_limit,
                // Handle to the current tokio runtime.
                Handle::current(),
            ),
        )
    }

    /// Finish connection setup: wait for the server's extensions on `rx`
    /// and record the negotiated limits via [`Self::set_limits`].
    ///
    /// On a background-task failure, this awaits the tasks (through
    /// [`Sftp::close`]) so the caller gets the underlying error rather
    /// than a generic channel-closed one.
    async fn init(
        flush_task: JoinHandle<Result<(), Error>>,
        read_task: JoinHandle<Result<(), Error>>,
        write_end: WriteEnd,
        rx: Receiver<Extensions>,
        options: &SftpOptions,
    ) -> Result<Self, Error> {
        // Create sftp here.
        //
        // It would also gracefully shutdown `flush_task` and `read_task` if
        // the future is cancelled or error is encounted.
        let sftp = Self {
            handle: SftpHandle::new(&write_end),
            flush_task,
            read_task,
        };

        let write_end = WriteEndWithCachedId::new(write_end);

        let extensions = if let Ok(extensions) = rx.await {
            extensions
        } else {
            // The sender was dropped without sending, i.e. read_task died.
            drop(write_end);

            // Wait on flush_task and read_task to get a more detailed error message.
            sftp.close().await?;
            std::unreachable!("Error must have occurred in either read_task or flush_task")
        };

        match Self::set_limits(write_end, options, extensions).await {
            Err(Error::BackgroundTaskFailure(_)) => {
                // Wait on flush_task and read_task to get a more detailed error message.
                sftp.close().await?;
                std::unreachable!("Error must have occurred in either read_task or flush_task")
            }
            res => res?,
        }

        Ok(sftp)
    }

    /// Determine the read/write/packet length limits — from the server's
    /// `limits` extension when available, otherwise openssh-portable
    /// defaults — cap them by `options`, and store them together with
    /// `extensions` in `auxiliary.conn_info`.
    async fn set_limits(
        mut write_end: WriteEndWithCachedId,
        options: &SftpOptions,
        extensions: Extensions,
    ) -> Result<(), Error> {
        let default_download_buflen = lowlevel::OPENSSH_PORTABLE_DEFAULT_DOWNLOAD_BUFLEN as u64;
        let default_upload_buflen = lowlevel::OPENSSH_PORTABLE_DEFAULT_UPLOAD_BUFLEN as u64;

        // sftp can accept packet as large as u32::MAX, the header itself
        // is at least 9 bytes long.
        let default_max_packet_len = u32::MAX - 9;

        let (read_len, write_len, packet_len) = if extensions.contains(Extensions::LIMITS) {
            let mut limits = write_end
                .send_request(|write_end, id| Ok(write_end.send_limits_request(id)?.wait()))
                .await?;

            // A zero limit means "unspecified"; fall back to the defaults.
            if limits.read_len == 0 {
                limits.read_len = default_download_buflen;
            }

            if limits.write_len == 0 {
                limits.write_len = default_upload_buflen;
            }

            (
                limits.read_len,
                limits.write_len,
                limits
                    .packet_len
                    .try_into()
                    .unwrap_or(default_max_packet_len),
            )
        } else {
            (
                default_download_buflen,
                default_upload_buflen,
                default_max_packet_len,
            )
        };

        // Each read/write request also has a header and contains a handle,
        // which is 4-byte long for openssh but can be at most 256 bytes long
        // for other implementations.
        // Hence the 300-byte headroom subtracted from packet_len below.

        let read_len = read_len.try_into().unwrap_or(packet_len - 300);
        let read_len = options
            .get_max_read_len()
            .map(|v| v.min(read_len))
            .unwrap_or(read_len);

        let write_len = write_len.try_into().unwrap_or(packet_len - 300);
        let write_len = options
            .get_max_write_len()
            .map(|v| v.min(write_len))
            .unwrap_or(write_len);

        let limits = auxiliary::Limits {
            read_len,
            write_len,
        };

        write_end
            .get_auxiliary()
            .conn_info
            .set(auxiliary::ConnInfo { limits, extensions })
            .expect("auxiliary.conn_info shall be uninitialized");

        Ok(())
    }

    /// Close sftp connection
    ///
    /// If sftp is created using `Sftp::from_session`, then calling this
    /// function would also await on `openssh::RemoteChild::wait` and
    /// `openssh::Session::close` and propagate their error in
    /// [`Sftp::close`].
    pub async fn close(self) -> Result<(), Error> {
        let Self {
            handle,
            flush_task,
            read_task,
        } = self;

        // Keep the openssh session (if any) alive until both tasks are
        // done, so its error can be recovered afterwards.
        let session = match &handle.get_auxiliary().auxiliary_data {
            #[cfg(feature = "openssh")]
            SftpAuxiliaryData::ArcedOpensshSession(session) => Some(Arc::clone(session)),
            _ => None,
        };

        #[cfg(not(feature = "openssh"))]
        {
            // Help session infer generic T in Option<T>
            let _: Option<()> = session;
        }

        // Drop handle, releasing our active-user reference so the
        // background tasks can shut down.
        drop(handle);

        // Wait for responses for all requests buffered and sent.
        let read_task_error = read_task.await.flatten().err();

        // read_task would order the shutdown of flush_task,
        // so we just need to wait for it here.
        let flush_task_error = flush_task.await.flatten().err();

        let session_error: Option<Error> = match session {
            #[cfg(feature = "openssh")]
            Some(session) => Arc::try_unwrap(session)
                // NOTE(review): assumes this is the last strong reference
                // once both tasks completed — confirm no other clone of
                // the session Arc can outlive them.
                .unwrap()
                .recover_session_err()
                .await
                .err(),

            #[cfg(not(feature = "openssh"))]
            Some(_) => unreachable!(),

            None => None,
        };

        // Combine up to three cleanup errors into a single Error.
        match (read_task_error, flush_task_error, session_error) {
            (Some(err1), Some(err2), Some(err3)) => Err(err1.error_on_cleanup3(err2, err3)),
            (Some(err1), Some(err2), None)
            | (Some(err1), None, Some(err2))
            | (None, Some(err1), Some(err2)) => Err(err1.error_on_cleanup(err2)),
            (Some(err), None, None) | (None, Some(err), None) | (None, None, Some(err)) => Err(err),
            (None, None, None) => Ok(()),
        }
    }

    /// Return a new [`OpenOptions`] object.
    pub fn options(&self) -> OpenOptions {
        OpenOptions::new(self.handle.clone())
    }

    /// Opens a file in write-only mode.
    ///
    /// This function will create a file if it does not exist, and will truncate
    /// it if it does.
    pub async fn create(&self, path: impl AsRef<Path>) -> Result<File, Error> {
        // Monomorphization guard: the generic wrapper stays tiny and the
        // real work is compiled once.
        async fn inner(this: &Sftp, path: &Path) -> Result<File, Error> {
            this.options()
                .write(true)
                .create(true)
                .truncate(true)
                .open(path)
                .await
        }

        inner(self, path.as_ref()).await
    }

    /// Attempts to open a file in read-only mode.
    pub async fn open(&self, path: impl AsRef<Path>) -> Result<File, Error> {
        async fn inner(this: &Sftp, path: &Path) -> Result<File, Error> {
            this.options().read(true).open(path).await
        }

        inner(self, path.as_ref()).await
    }

    /// [`Fs`] defaults to the current working dir set by remote `sftp-server`,
    /// which usually is the home directory.
    pub fn fs(&self) -> Fs {
        Fs::new(self.handle.clone().write_end(), "".into())
    }

    /// Check if the remote server supports the expand path extension.
    ///
    /// If it returns true, then [`Fs::canonicalize`] with expand path is supported.
    pub fn support_expand_path(&self) -> bool {
        self.handle
            .get_auxiliary()
            .extensions()
            .contains(Extensions::EXPAND_PATH)
    }

    /// Check if the remote server supports the fsync extension.
    ///
    /// If it returns true, then [`File::sync_all`] is supported.
    pub fn support_fsync(&self) -> bool {
        self.handle
            .get_auxiliary()
            .extensions()
            .contains(Extensions::FSYNC)
    }

    /// Check if the remote server supports the hardlink extension.
    ///
    /// If it returns true, then [`Fs::hard_link`] is supported.
    pub fn support_hardlink(&self) -> bool {
        self.handle
            .get_auxiliary()
            .extensions()
            .contains(Extensions::HARDLINK)
    }

    /// Check if the remote server supports the posix rename extension.
    ///
    /// If it returns true, then [`Fs::rename`] will use posix rename.
    pub fn support_posix_rename(&self) -> bool {
        self.handle
            .get_auxiliary()
            .extensions()
            .contains(Extensions::POSIX_RENAME)
    }

    /// Check if the remote server supports the copy data extension.
    ///
    /// If it returns true, then [`File::copy_to`] and [`File::copy_all_to`] are supported.
    pub fn support_copy(&self) -> bool {
        self.handle
            .get_auxiliary()
            .extensions()
            .contains(Extensions::COPY_DATA)
    }
}
445
#[cfg(feature = "__ci-tests")]
impl Sftp {
    /// Largest number of bytes a single write request may carry; writes
    /// exceeding this are split into multiple requests.
    ///
    /// If [`Sftp::max_buffered_write`] is less than [`max_atomic_write_len`],
    /// then the direct write is enabled and [`Sftp::max_write_len`] must be
    /// less than [`max_atomic_write_len`].
    pub fn max_write_len(&self) -> u32 {
        let limits = self.handle.get_auxiliary().limits();
        limits.write_len
    }

    /// Largest number of bytes a single read request may carry; reads
    /// exceeding this are split into multiple requests.
    pub fn max_read_len(&self) -> u32 {
        let limits = self.handle.get_auxiliary().limits();
        limits.read_len
    }

    /// Manually kick the background flush task.
    pub fn manual_flush(&self) {
        let auxiliary = self.handle.get_auxiliary();
        auxiliary.trigger_flushing()
    }
}