Skip to main content

remotefs_ssh/ssh/backend/
russh.rs

1//! [russh](https://docs.rs/russh/latest/russh/) backend for `remotefs-ssh`.
2
3mod auth;
4mod scp;
5
6use std::borrow::Cow;
7use std::io::{Read, Seek, Write};
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use remotefs::fs::{Metadata, ReadStream, WriteStream};
12use remotefs::{File, RemoteError, RemoteErrorType, RemoteResult};
13use russh::client::{Handle, Handler};
14use russh::keys::{Algorithm, PublicKey};
15use russh::{Disconnect, client};
16use russh_sftp::client::SftpSession;
17use tokio::runtime::Runtime;
18
19use super::{SshSession, WriteMode};
20use crate::SshOpts;
21use crate::ssh::backend::Sftp;
22use crate::ssh::config::Config;
23use crate::ssh::key_method::MethodType;
24
25/// The default SSH client handler for russh.
26///
27/// Accepts all server host keys. Host key verification should be implemented
28/// by the caller if stricter security is required.
29///
30/// You can implement your own [`Handler`] and use it with [`RusshSession`] if you want a different behaviour.
31#[derive(Default)]
32pub struct NoCheckServerKey;
33
34impl Handler for NoCheckServerKey {
35    type Error = russh::Error;
36
37    async fn check_server_key(
38        &mut self,
39        _server_public_key: &PublicKey,
40    ) -> Result<bool, Self::Error> {
41        Ok(true)
42    }
43}
44
45/// [`russh`](https://docs.rs/russh/latest/russh) session.
46pub struct RusshSession<T>
47where
48    T: Handler + Default + Send + 'static,
49{
50    runtime: Arc<Runtime>,
51    session: Handle<T>,
52}
53
54/// SFTP handle for russh.
55pub struct RusshSftp {
56    runtime: Arc<Runtime>,
57    session: Arc<SftpSession>,
58}
59
60impl<T> SshSession for RusshSession<T>
61where
62    T: Handler + Default + Send + 'static,
63{
64    type Sftp = RusshSftp;
65
66    fn connect(opts: &SshOpts) -> RemoteResult<Self> {
67        let runtime = opts.runtime.as_ref().cloned().ok_or_else(|| {
68            RemoteError::new_ex(
69                RemoteErrorType::UnsupportedFeature,
70                "RusshSession requires a Tokio runtime",
71            )
72        })?;
73
74        let ssh_config = Config::try_from(opts)?;
75        debug!("Connecting to '{}'", ssh_config.address);
76
77        let mut config = client::Config {
78            inactivity_timeout: Some(ssh_config.connection_timeout),
79            ..Default::default()
80        };
81
82        // Apply algorithm preferences from ssh config
83        apply_config_algo_prefs(&mut config, &ssh_config);
84
85        // Apply algorithm preferences from opts
86        apply_opts_algo_prefs(&mut config, opts);
87
88        let config = Arc::new(config);
89
90        let mut session = runtime
91            .block_on(async {
92                client::connect(config, ssh_config.address.as_str(), T::default()).await
93            })
94            .map_err(|err| {
95                let msg = format!("SSH connection failed: {err:?}");
96                error!("{msg}");
97                RemoteError::new_ex(RemoteErrorType::ConnectionError, msg)
98            })?;
99
100        // Authenticate
101        auth::authenticate(&mut session, &runtime, opts, &ssh_config)?;
102
103        Ok(Self { runtime, session })
104    }
105
106    fn disconnect(&self) -> RemoteResult<()> {
107        self.runtime
108            .block_on(async {
109                self.session
110                    .disconnect(Disconnect::ByApplication, "Closed by user", "en_US")
111                    .await
112            })
113            .map_err(|err| {
114                log::error!("failed to disconnect {err}");
115                RemoteError::new_ex(RemoteErrorType::ConnectionError, err.to_string())
116            })
117    }
118
119    fn banner(&self) -> RemoteResult<Option<String>> {
120        // russh delivers the auth banner via the Handler::auth_banner callback
121        // during authentication, but does not expose it from the Handle after the fact.
122        // <https://docs.rs/russh/latest/russh/client/struct.Handle.html>
123        // <https://docs.rs/russh/latest/russh/client/trait.Handler.html#method.auth_banner>
124        Ok(None)
125    }
126
127    fn authenticated(&self) -> RemoteResult<bool> {
128        Ok(!self.session.is_closed())
129    }
130
131    fn cmd<S>(&mut self, cmd: S) -> RemoteResult<(u32, String)>
132    where
133        S: AsRef<str>,
134    {
135        let cmd = cmd.as_ref();
136        trace!("Running command: {cmd}");
137
138        // Escape single quotes and wrap in sh -c for consistent shell behavior.
139        // Without this, commands like "cd /some/dir; somecommand" fail if the
140        // remote user's login shell is fish or another non-POSIX shell.
141        let escaped = cmd.replace('\'', r#"'\''"#);
142        let wrapped = format!("sh -c '{escaped}'");
143
144        self.runtime
145            .block_on(async { perform_shell_cmd(&self.session, &wrapped).await })
146    }
147
148    fn scp_recv(&self, path: &Path) -> RemoteResult<Box<dyn Read + Send>> {
149        self.runtime
150            .block_on(async { scp::recv(&self.session, path).await })
151    }
152
153    fn scp_send(
154        &self,
155        remote_path: &Path,
156        mode: i32,
157        size: u64,
158        _times: Option<(u64, u64)>,
159    ) -> RemoteResult<Box<dyn Write + Send>> {
160        let runtime = self.runtime.clone();
161        self.runtime
162            .block_on(async { scp::send(&self.session, remote_path, mode, size, runtime).await })
163    }
164
165    fn sftp(&self) -> RemoteResult<Self::Sftp> {
166        let channel = self
167            .runtime
168            .block_on(async {
169                let channel = self.session.channel_open_session().await?;
170                channel.request_subsystem(true, "sftp").await?;
171                Ok(channel)
172            })
173            .map_err(|err: russh::Error| {
174                error!("Failed to init SFTP session: {err}");
175                RemoteError::new_ex(RemoteErrorType::ProtocolError, err.to_string())
176            })?;
177
178        self.runtime
179            .block_on(async { SftpSession::new(channel.into_stream()).await })
180            .map(|session| RusshSftp {
181                runtime: self.runtime.clone(),
182                session: Arc::new(session),
183            })
184            .map_err(|err| {
185                error!("Failed to init SFTP session: {err}");
186                RemoteError::new_ex(RemoteErrorType::ProtocolError, err.to_string())
187            })
188    }
189}
190
191impl Sftp for RusshSftp {
192    fn mkdir(&self, path: &Path, mode: i32) -> RemoteResult<()> {
193        let path_str = path.to_string_lossy().to_string();
194        self.runtime.block_on(async {
195            self.session.create_dir(&path_str).await.map_err(|err| {
196                RemoteError::new_ex(
197                    RemoteErrorType::FileCreateDenied,
198                    format!("Could not create directory '{}': {err}", path.display()),
199                )
200            })?;
201            // create_dir does not set permissions; apply them separately
202            let mut attrs = russh_sftp::protocol::FileAttributes::empty();
203            attrs.permissions = Some(mode as u32 & 0o7777);
204            self.session
205                .set_metadata(&path_str, attrs)
206                .await
207                .map_err(|err| {
208                    RemoteError::new_ex(
209                        RemoteErrorType::ProtocolError,
210                        format!("Could not set permissions on '{}': {err}", path.display()),
211                    )
212                })
213        })
214    }
215
216    fn open_read(&self, path: &Path) -> RemoteResult<ReadStream> {
217        let path_str = path.to_string_lossy().to_string();
218        let reader = PipelinedSftpReader::new(self.runtime.clone(), self.session.clone(), path_str)
219            .map_err(|err| {
220                RemoteError::new_ex(
221                    RemoteErrorType::ProtocolError,
222                    format!("Could not read file at '{}': {err}", path.display()),
223                )
224            })?;
225        Ok(ReadStream::from(Box::new(reader) as Box<dyn Read + Send>))
226    }
227
228    fn open_write(&self, path: &Path, flags: WriteMode, mode: i32) -> RemoteResult<WriteStream> {
229        let path_str = path.to_string_lossy().to_string();
230        self.runtime.block_on(async {
231            let open_flags = match flags {
232                WriteMode::Append => {
233                    russh_sftp::protocol::OpenFlags::WRITE
234                        | russh_sftp::protocol::OpenFlags::APPEND
235                        | russh_sftp::protocol::OpenFlags::CREATE
236                }
237                WriteMode::Truncate => {
238                    russh_sftp::protocol::OpenFlags::WRITE
239                        | russh_sftp::protocol::OpenFlags::CREATE
240                        | russh_sftp::protocol::OpenFlags::TRUNCATE
241                }
242            };
243
244            let mut attrs = russh_sftp::protocol::FileAttributes::empty();
245            attrs.permissions = Some(mode as u32 & 0o7777);
246
247            let file = self
248                .session
249                .open_with_flags_and_attributes(&path_str, open_flags, attrs)
250                .await
251                .map_err(|err| {
252                    RemoteError::new_ex(
253                        RemoteErrorType::ProtocolError,
254                        format!("Could not open file at '{}': {err}", path.display()),
255                    )
256                })?;
257
258            let writer = SftpFileWriter {
259                file,
260                runtime: self.runtime.clone(),
261            };
262            Ok(WriteStream::from(
263                Box::new(writer) as Box<dyn remotefs::fs::stream::WriteAndSeek>
264            ))
265        })
266    }
267
268    fn readdir<T>(&self, dirname: T) -> RemoteResult<Vec<File>>
269    where
270        T: AsRef<Path>,
271    {
272        let dirname = dirname.as_ref();
273        let dir_str = dirname.to_string_lossy().to_string();
274        self.runtime.block_on(async {
275            let entries = self.session.read_dir(&dir_str).await.map_err(|err| {
276                RemoteError::new_ex(
277                    RemoteErrorType::ProtocolError,
278                    format!("Could not read directory: {err}"),
279                )
280            })?;
281
282            let mut files = Vec::new();
283            for entry in entries {
284                let entry_path = dirname.join(entry.file_name());
285                let symlink = if entry.file_type().is_symlink() {
286                    match self
287                        .session
288                        .read_link(entry_path.to_string_lossy().as_ref())
289                        .await
290                    {
291                        Ok(target) => Some(PathBuf::from(target)),
292                        Err(err) => {
293                            error!(
294                                "Failed to read link of {} (even though it's a symlink): {err}",
295                                entry_path.display()
296                            );
297                            None
298                        }
299                    }
300                } else {
301                    None
302                };
303                files.push(make_fsentry(&entry_path, &entry.metadata(), symlink));
304            }
305
306            Ok(files)
307        })
308    }
309
310    fn realpath(&self, path: &Path) -> RemoteResult<PathBuf> {
311        let path_str = path.to_string_lossy().to_string();
312        self.runtime.block_on(async {
313            self.session
314                .canonicalize(&path_str)
315                .await
316                .map(PathBuf::from)
317                .map_err(|err| {
318                    RemoteError::new_ex(
319                        RemoteErrorType::ProtocolError,
320                        format!(
321                            "Could not resolve real path for '{}': {err}",
322                            path.display()
323                        ),
324                    )
325                })
326        })
327    }
328
329    fn rename(&self, src: &Path, dest: &Path) -> RemoteResult<()> {
330        let src_str = src.to_string_lossy().to_string();
331        let dest_str = dest.to_string_lossy().to_string();
332        self.runtime.block_on(async {
333            self.session
334                .rename(&src_str, &dest_str)
335                .await
336                .map_err(|err| {
337                    RemoteError::new_ex(
338                        RemoteErrorType::ProtocolError,
339                        format!("Could not rename file '{}': {err}", src.display()),
340                    )
341                })
342        })
343    }
344
345    fn rmdir(&self, path: &Path) -> RemoteResult<()> {
346        let path_str = path.to_string_lossy().to_string();
347        self.runtime.block_on(async {
348            self.session.remove_dir(&path_str).await.map_err(|err| {
349                RemoteError::new_ex(
350                    RemoteErrorType::CouldNotRemoveFile,
351                    format!("Could not remove directory '{}': {err}", path.display()),
352                )
353            })
354        })
355    }
356
357    fn setstat(&self, path: &Path, metadata: Metadata) -> RemoteResult<()> {
358        let path_str = path.to_string_lossy().to_string();
359        let attrs = metadata_to_file_attributes(metadata);
360        self.runtime.block_on(async {
361            self.session
362                .set_metadata(&path_str, attrs)
363                .await
364                .map_err(|err| {
365                    RemoteError::new_ex(
366                        RemoteErrorType::ProtocolError,
367                        format!(
368                            "Could not set file attributes for '{}': {err}",
369                            path.display()
370                        ),
371                    )
372                })
373        })
374    }
375
376    fn stat(&self, filename: &Path) -> RemoteResult<File> {
377        let path_str = filename.to_string_lossy().to_string();
378        self.runtime.block_on(async {
379            let attrs = self.session.metadata(&path_str).await.map_err(|err| {
380                RemoteError::new_ex(
381                    RemoteErrorType::ProtocolError,
382                    format!(
383                        "Could not get file attributes for '{}': {err}",
384                        filename.display()
385                    ),
386                )
387            })?;
388
389            let symlink = if attrs.is_symlink() {
390                match self.session.read_link(&path_str).await {
391                    Ok(target) => Some(PathBuf::from(target)),
392                    Err(err) => {
393                        error!(
394                            "Failed to read link of {} (even though it's a symlink): {err}",
395                            filename.display()
396                        );
397                        None
398                    }
399                }
400            } else {
401                None
402            };
403
404            Ok(make_fsentry(filename, &attrs, symlink))
405        })
406    }
407
408    fn symlink(&self, path: &Path, target: &Path) -> RemoteResult<()> {
409        let path_str = path.to_string_lossy().to_string();
410        let target_str = target.to_string_lossy().to_string();
411        self.runtime.block_on(async {
412            self.session
413                .symlink(&path_str, &target_str)
414                .await
415                .map_err(|err| {
416                    RemoteError::new_ex(
417                        RemoteErrorType::FileCreateDenied,
418                        format!("Could not create symlink '{}': {err}", path.display()),
419                    )
420                })
421        })
422    }
423
424    fn unlink(&self, path: &Path) -> RemoteResult<()> {
425        let path_str = path.to_string_lossy().to_string();
426        self.runtime.block_on(async {
427            self.session.remove_file(&path_str).await.map_err(|err| {
428                RemoteError::new_ex(
429                    RemoteErrorType::CouldNotRemoveFile,
430                    format!("Could not remove file '{}': {err}", path.display()),
431                )
432            })
433        })
434    }
435}
436
437/// Convert `remotefs::fs::Metadata` to `russh_sftp::protocol::FileAttributes`.
438fn metadata_to_file_attributes(metadata: Metadata) -> russh_sftp::protocol::FileAttributes {
439    let atime = metadata
440        .accessed
441        .and_then(|x| x.duration_since(std::time::UNIX_EPOCH).ok())
442        .map(|x| x.as_secs() as u32);
443    let mtime = metadata
444        .modified
445        .and_then(|x| x.duration_since(std::time::UNIX_EPOCH).ok())
446        .map(|x| x.as_secs() as u32);
447    russh_sftp::protocol::FileAttributes {
448        size: Some(metadata.size),
449        uid: metadata.uid,
450        user: None,
451        gid: metadata.gid,
452        group: None,
453        permissions: metadata.mode.map(u32::from),
454        atime,
455        mtime,
456    }
457}
458
459/// Build a `remotefs::File` from a path and russh-sftp `FileAttributes`.
460fn make_fsentry(
461    path: &Path,
462    attrs: &russh_sftp::protocol::FileAttributes,
463    symlink: Option<PathBuf>,
464) -> File {
465    let name = match path.file_name() {
466        None => "/".to_string(),
467        Some(name) => name.to_string_lossy().to_string(),
468    };
469    debug!("Found file {name}");
470
471    let uid = attrs.uid;
472    let gid = attrs.gid;
473    let mode = attrs.permissions.map(remotefs::fs::UnixPex::from);
474    let size = attrs.size.unwrap_or(0);
475    let accessed = attrs.atime.map(|x| {
476        std::time::UNIX_EPOCH
477            .checked_add(std::time::Duration::from_secs(u64::from(x)))
478            .unwrap_or(std::time::UNIX_EPOCH)
479    });
480    let modified = attrs.mtime.map(|x| {
481        std::time::UNIX_EPOCH
482            .checked_add(std::time::Duration::from_secs(u64::from(x)))
483            .unwrap_or(std::time::UNIX_EPOCH)
484    });
485
486    let file_type = if symlink.is_some() {
487        remotefs::fs::FileType::Symlink
488    } else if attrs.is_dir() {
489        remotefs::fs::FileType::Directory
490    } else {
491        remotefs::fs::FileType::File
492    };
493
494    let entry_metadata = Metadata {
495        accessed,
496        created: None,
497        file_type,
498        gid,
499        mode,
500        modified,
501        size,
502        symlink,
503        uid,
504    };
505    trace!("Metadata for {}: {:?}", path.display(), entry_metadata);
506    File {
507        path: path.to_path_buf(),
508        metadata: entry_metadata,
509    }
510}
511
512/// Synchronous writer wrapping a russh-sftp [`russh_sftp::client::fs::File`].
513///
514/// Stores the full `Arc<Runtime>` rather than just a `Handle` so that
515/// `Runtime::block_on` drives IO and background tasks on a current-thread
516/// runtime.
517struct SftpFileWriter {
518    file: russh_sftp::client::fs::File,
519    runtime: Arc<Runtime>,
520}
521
522impl Write for SftpFileWriter {
523    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
524        use tokio::io::AsyncWriteExt as _;
525        self.runtime.block_on(self.file.write(buf))
526    }
527
528    fn flush(&mut self) -> std::io::Result<()> {
529        use tokio::io::AsyncWriteExt as _;
530        self.runtime.block_on(self.file.flush())
531    }
532}
533
534impl Seek for SftpFileWriter {
535    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
536        use tokio::io::AsyncSeekExt as _;
537        self.runtime.block_on(self.file.seek(pos))
538    }
539}
540
541impl remotefs::fs::stream::WriteAndSeek for SftpFileWriter {}
542
543impl Drop for SftpFileWriter {
544    fn drop(&mut self) {
545        use tokio::io::AsyncWriteExt as _;
546        // Close the handle with an awaited close. russh-sftp's `File::drop`
547        // uses `close_nowait`, which never decrements the client's open-handle
548        // counter and would leak a handle per upload until the negotiated limit
549        // is reached ("Handle limit reached").
550        let _ = self.runtime.block_on(self.file.shutdown());
551    }
552}
553
/// Maximum number of concurrent SFTP read tasks used to fetch one batch in
/// [`PipelinedSftpReader`]; each task opens its own file handle.
const SFTP_PIPELINE_DEPTH: usize = 4;

/// Size of each chunk read by a single pipeline task (4 MiB).
const SFTP_CHUNK_SIZE: usize = 4 * 1024 * 1024;

/// Upper bound on completed batches buffered ahead of the read position before
/// [`PipelinedSftpReader`] stops starting new pre-fetches.
///
/// NOTE: the reader keeps at most one pre-fetch in flight and pops a consumed
/// batch before pushing the next, so in practice at most one batch is buffered
/// plus one being fetched, and this bound is not reached. Resident memory is
/// therefore about `2 * BATCH_SIZE`, plus transient per-chunk buffers while a
/// batch is being assembled.
const MAX_PREFETCH: usize = 2;

/// Batch size: [`SFTP_PIPELINE_DEPTH`] * [`SFTP_CHUNK_SIZE`] = 16 MiB.
const BATCH_SIZE: usize = SFTP_PIPELINE_DEPTH * SFTP_CHUNK_SIZE;
566
567/// A streaming SFTP reader that pipelines reads in batches.
568///
569/// Each batch spawns [`SFTP_PIPELINE_DEPTH`] concurrent SFTP read tasks of
570/// [`SFTP_CHUNK_SIZE`] bytes. Up to [`MAX_PREFETCH`] batches are fetched ahead
571/// of the current read position so the caller receives data immediately while
572/// keeping memory bounded.
573struct PipelinedSftpReader {
574    runtime: Arc<Runtime>,
575    session: Arc<SftpSession>,
576    path: String,
577    file_size: usize,
578    /// Next byte offset to start fetching from the remote file.
579    fetch_offset: usize,
580    /// Completed batches ready for consumption, front = current.
581    batches: std::collections::VecDeque<Vec<u8>>,
582    /// Read cursor within `batches[0]`.
583    buf_cursor: usize,
584    /// Background pre-fetch task, if any.
585    pending: Option<PrefetchTask>,
586}
587
588/// In-flight background batch fetch.
589struct PrefetchTask {
590    /// The byte offset this batch starts at — used to roll back
591    /// `fetch_offset` on failure.
592    batch_offset: usize,
593    handle: tokio::task::JoinHandle<Result<Vec<u8>, std::io::Error>>,
594}
595
596impl PipelinedSftpReader {
597    /// Creates a new streaming reader.
598    ///
599    /// Eagerly fetches the first batch and starts a background pre-fetch for
600    /// the second batch so the caller can start reading immediately.
601    fn new(
602        runtime: Arc<Runtime>,
603        session: Arc<SftpSession>,
604        path: String,
605    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
606        let metadata = runtime.block_on(session.metadata(&path))?;
607        let file_size = metadata.size.unwrap_or(0) as usize;
608
609        let mut reader = Self {
610            runtime,
611            session,
612            path,
613            file_size,
614            fetch_offset: 0,
615            batches: std::collections::VecDeque::new(),
616            buf_cursor: 0,
617            pending: None,
618        };
619
620        if file_size == 0 {
621            return Ok(reader);
622        }
623
624        // Eagerly fetch the first batch so data is available immediately.
625        let first_batch = reader.fetch_batch_blocking()?;
626        reader.batches.push_back(first_batch);
627
628        // Start background pre-fetch for the next batch.
629        reader.maybe_start_prefetch();
630
631        Ok(reader)
632    }
633
634    /// Fetches the next batch synchronously by blocking on the runtime.
635    fn fetch_batch_blocking(
636        &mut self,
637    ) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
638        let remaining = self.file_size.saturating_sub(self.fetch_offset);
639        if remaining == 0 {
640            return Ok(Vec::new());
641        }
642
643        let batch_len = remaining.min(BATCH_SIZE);
644        let offset = self.fetch_offset;
645        let batch = self
646            .runtime
647            .block_on(Self::fetch_batch(
648                self.session.clone(),
649                self.path.clone(),
650                offset,
651                batch_len,
652            ))
653            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
654
655        self.fetch_offset += batch_len;
656        Ok(batch)
657    }
658
659    /// Spawns a background batch fetch if there is more data and the prefetch
660    /// queue is not full.
661    fn maybe_start_prefetch(&mut self) {
662        if self.pending.is_some() {
663            return;
664        }
665        if self.batches.len() > MAX_PREFETCH {
666            return;
667        }
668        let remaining = self.file_size.saturating_sub(self.fetch_offset);
669        if remaining == 0 {
670            return;
671        }
672
673        let batch_len = remaining.min(BATCH_SIZE);
674        let session = self.session.clone();
675        let path = self.path.clone();
676        let offset = self.fetch_offset;
677        // Speculatively advance; rolled back in collect_pending on failure.
678        self.fetch_offset += batch_len;
679
680        let handle = self
681            .runtime
682            .spawn(async move { Self::fetch_batch(session, path, offset, batch_len).await });
683
684        self.pending = Some(PrefetchTask {
685            batch_offset: offset,
686            handle,
687        });
688    }
689
690    /// Collects the result of a pending pre-fetch task.
691    ///
692    /// On failure, rolls back `fetch_offset` so the batch can be retried.
693    fn collect_pending(&mut self) -> std::io::Result<Option<Vec<u8>>> {
694        let task = match self.pending.take() {
695            Some(t) => t,
696            None => return Ok(None),
697        };
698
699        match self
700            .runtime
701            .block_on(task.handle)
702            .map_err(std::io::Error::other)?
703        {
704            Ok(batch) if batch.is_empty() => Ok(None),
705            Ok(batch) => Ok(Some(batch)),
706            Err(err) => {
707                // Roll back so the caller (or a retry) can re-fetch this range.
708                self.fetch_offset = task.batch_offset;
709                Err(std::io::Error::other(err))
710            }
711        }
712    }
713
714    /// Fetches a single batch: spawns [`SFTP_PIPELINE_DEPTH`] concurrent reads
715    /// and assembles the result into a contiguous buffer.
716    async fn fetch_batch(
717        session: Arc<SftpSession>,
718        path: String,
719        batch_offset: usize,
720        batch_len: usize,
721    ) -> Result<Vec<u8>, std::io::Error> {
722        use tokio::io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _};
723
724        let chunk_count = batch_len.div_ceil(SFTP_CHUNK_SIZE);
725        let mut tasks = Vec::with_capacity(chunk_count);
726
727        for i in 0..chunk_count {
728            let chunk_offset = i * SFTP_CHUNK_SIZE;
729            let len = SFTP_CHUNK_SIZE.min(batch_len - chunk_offset);
730            let abs_offset = batch_offset + chunk_offset;
731            let session = Arc::clone(&session);
732            let path = path.clone();
733
734            tasks.push(tokio::spawn(async move {
735                let mut file = session.open(&path).await.map_err(std::io::Error::other)?;
736                file.seek(std::io::SeekFrom::Start(abs_offset as u64))
737                    .await?;
738                let mut buf = vec![0_u8; len];
739                let read_res = file.read_exact(&mut buf).await;
740                // Explicitly close the handle with an awaited close. russh-sftp's
741                // `File::drop` uses `close_nowait`, which frees the handle
742                // server-side but never decrements the client's open-handle
743                // counter. Relying on it leaks handles until the negotiated
744                // limit is hit ("Handle limit reached") after many opens.
745                let _ = file.shutdown().await;
746                read_res?;
747                Ok::<(usize, Vec<u8>), std::io::Error>((chunk_offset, buf))
748            }));
749        }
750
751        let mut result = vec![0_u8; batch_len];
752        for task in tasks {
753            let (chunk_offset, chunk) = task
754                .await
755                .map_err(std::io::Error::other)?
756                .map_err(std::io::Error::other)?;
757            result[chunk_offset..chunk_offset + chunk.len()].copy_from_slice(&chunk);
758        }
759
760        Ok(result)
761    }
762}
763
764impl Read for PipelinedSftpReader {
765    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
766        loop {
767            // Try to serve from the current front batch.
768            if let Some(front) = self.batches.front() {
769                let available = &front[self.buf_cursor..];
770                if !available.is_empty() {
771                    let to_copy = available.len().min(buf.len());
772                    buf[..to_copy].copy_from_slice(&available[..to_copy]);
773                    self.buf_cursor += to_copy;
774                    return Ok(to_copy);
775                }
776
777                // Current batch fully consumed — pop it.
778                self.batches.pop_front();
779                self.buf_cursor = 0;
780
781                // Collect the pending pre-fetch if any.
782                if let Some(batch) = self.collect_pending()? {
783                    self.batches.push_back(batch);
784                }
785
786                // Kick off next pre-fetch.
787                self.maybe_start_prefetch();
788
789                continue;
790            }
791
792            // No batches buffered — try to collect pending.
793            if let Some(batch) = self.collect_pending()? {
794                self.batches.push_back(batch);
795                self.maybe_start_prefetch();
796                continue;
797            }
798
799            // Nothing left — EOF.
800            return Ok(0);
801        }
802    }
803}
804
805/// Apply algorithm preferences from SSH config to the russh [`client::Config`].
806fn apply_config_algo_prefs(config: &mut client::Config, ssh_config: &Config) {
807    let params = &ssh_config.params;
808
809    // KEX algorithms
810    let kex: Vec<russh::kex::Name> = params
811        .kex_algorithms
812        .algorithms()
813        .iter()
814        .filter_map(|name| {
815            russh::kex::Name::try_from(name.as_str())
816                .map_err(|()| warn!("Unsupported KEX algorithm: {name}"))
817                .ok()
818        })
819        .collect();
820    if !kex.is_empty() {
821        config.preferred.kex = Cow::Owned(kex);
822    }
823
824    // Host key algorithms
825    let host_keys: Vec<Algorithm> = params
826        .host_key_algorithms
827        .algorithms()
828        .iter()
829        .filter_map(|name| {
830            name.parse::<Algorithm>()
831                .map_err(|err| warn!("Unsupported host key algorithm '{name}': {err}"))
832                .ok()
833        })
834        .collect();
835    if !host_keys.is_empty() {
836        config.preferred.key = Cow::Owned(host_keys);
837    }
838
839    // Cipher algorithms
840    let ciphers: Vec<russh::cipher::Name> = params
841        .ciphers
842        .algorithms()
843        .iter()
844        .filter_map(|name| {
845            russh::cipher::Name::try_from(name.as_str())
846                .map_err(|()| warn!("Unsupported cipher algorithm: {name}"))
847                .ok()
848        })
849        .collect();
850    if !ciphers.is_empty() {
851        config.preferred.cipher = Cow::Owned(ciphers);
852    }
853
854    // MAC algorithms
855    let macs: Vec<russh::mac::Name> = params
856        .mac
857        .algorithms()
858        .iter()
859        .filter_map(|name| {
860            russh::mac::Name::try_from(name.as_str())
861                .map_err(|()| warn!("Unsupported MAC algorithm: {name}"))
862                .ok()
863        })
864        .collect();
865    if !macs.is_empty() {
866        config.preferred.mac = Cow::Owned(macs);
867    }
868}
869
870/// Apply algorithm preferences from [`SshOpts`] methods to the russh [`client::Config`].
871///
872/// Options from `SshOpts::methods` override those from the SSH config file.
873fn apply_opts_algo_prefs(config: &mut client::Config, opts: &SshOpts) {
874    for method in opts.methods.iter() {
875        let algos = method.prefs();
876        let names: Vec<&str> = algos.split(',').collect();
877
878        match method.method_type {
879            MethodType::Kex => {
880                let kex: Vec<russh::kex::Name> = names
881                    .iter()
882                    .filter_map(|name| {
883                        russh::kex::Name::try_from(*name)
884                            .map_err(|()| warn!("Unsupported KEX algorithm: {name}"))
885                            .ok()
886                    })
887                    .collect();
888                if !kex.is_empty() {
889                    config.preferred.kex = Cow::Owned(kex);
890                }
891            }
892            MethodType::HostKey => {
893                let keys: Vec<Algorithm> = names
894                    .iter()
895                    .filter_map(|name| {
896                        name.parse::<Algorithm>()
897                            .map_err(|err| warn!("Unsupported host key algorithm '{name}': {err}"))
898                            .ok()
899                    })
900                    .collect();
901                if !keys.is_empty() {
902                    config.preferred.key = Cow::Owned(keys);
903                }
904            }
905            MethodType::CryptClientServer | MethodType::CryptServerClient => {
906                let ciphers: Vec<russh::cipher::Name> = names
907                    .iter()
908                    .filter_map(|name| {
909                        russh::cipher::Name::try_from(*name)
910                            .map_err(|()| warn!("Unsupported cipher algorithm: {name}"))
911                            .ok()
912                    })
913                    .collect();
914                if !ciphers.is_empty() {
915                    config.preferred.cipher = Cow::Owned(ciphers);
916                }
917            }
918            MethodType::MacClientServer | MethodType::MacServerClient => {
919                let macs: Vec<russh::mac::Name> = names
920                    .iter()
921                    .filter_map(|name| {
922                        russh::mac::Name::try_from(*name)
923                            .map_err(|()| warn!("Unsupported MAC algorithm: {name}"))
924                            .ok()
925                    })
926                    .collect();
927                if !macs.is_empty() {
928                    config.preferred.mac = Cow::Owned(macs);
929                }
930            }
931            _ => {
932                trace!(
933                    "Ignoring unsupported method type {:?} for russh backend",
934                    method.method_type
935                );
936            }
937        }
938    }
939}
940
941/// Execute a shell command on the remote server via a russh channel.
942///
943/// Opens a session channel, executes the command, collects stdout,
944/// and returns the exit code with the output.
945async fn perform_shell_cmd<T>(session: &Handle<T>, cmd: &str) -> RemoteResult<(u32, String)>
946where
947    T: Handler,
948{
949    let mut channel = open_channel(session).await?;
950
951    channel.exec(true, cmd).await.map_err(|err| {
952        RemoteError::new_ex(
953            RemoteErrorType::ProtocolError,
954            format!("Could not execute command \"{cmd}\": {err}"),
955        )
956    })?;
957
958    let mut output = String::new();
959    let mut exit_code: Option<u32> = None;
960
961    while let Some(msg) = channel.wait().await {
962        match msg {
963            russh::ChannelMsg::Data { data } => {
964                output.push_str(&String::from_utf8_lossy(&data));
965            }
966            russh::ChannelMsg::ExitStatus { exit_status } => {
967                exit_code = Some(exit_status);
968            }
969            russh::ChannelMsg::Close => break,
970            russh::ChannelMsg::Eof => {}
971            _ => {}
972        }
973    }
974
975    let rc = exit_code.unwrap_or_else(|| {
976        warn!("No exit status received for command \"{cmd}\", defaulting to 1");
977        1
978    });
979
980    trace!("Command output: {output}");
981    debug!(r#"Command output: "{output}"; exit code: {rc}"#);
982
983    Ok((rc, output))
984}
985
986/// Open a session channel on the given handle.
987async fn open_channel<T>(session: &Handle<T>) -> RemoteResult<russh::Channel<russh::client::Msg>>
988where
989    T: Handler,
990{
991    session.channel_open_session().await.map_err(|err| {
992        RemoteError::new_ex(
993            RemoteErrorType::ProtocolError,
994            format!("Could not open channel: {err}"),
995        )
996    })
997}
998
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use ssh2_config::ParseRule;

    use super::*;
    use crate::mock::ssh as ssh_mock;

    /// Build a single-threaded tokio runtime, with all drivers enabled, for one test.
    fn test_runtime() -> Arc<Runtime> {
        Arc::new(
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap(),
        )
    }

    #[test]
    fn should_connect_to_ssh_server_auth_user_password() {
        use crate::ssh::container::OpensshServer;

        let container = OpensshServer::start();
        let port = container.port();

        crate::mock::logger();
        let runtime = test_runtime();
        let config_file = ssh_mock::create_ssh_config(port);
        let opts = SshOpts::new("sftp")
            .config_file(config_file.path(), ParseRule::ALLOW_UNKNOWN_FIELDS)
            .password("password")
            .runtime(runtime);

        // Connect once and keep the session; report the error text on failure.
        let session = RusshSession::<NoCheckServerKey>::connect(&opts)
            .unwrap_or_else(|err| panic!("Could not connect to server: {err}"));
        assert!(session.authenticated().unwrap());

        drop(container);
    }

    #[test]
    fn should_connect_to_ssh_server_auth_key() {
        use crate::ssh::container::OpensshServer;

        let container = OpensshServer::start();
        let port = container.port();

        crate::mock::logger();
        let runtime = test_runtime();
        let config_file = ssh_mock::create_ssh_config(port);
        let opts = SshOpts::new("sftp")
            .config_file(config_file.path(), ParseRule::ALLOW_UNKNOWN_FIELDS)
            .key_storage(Box::new(ssh_mock::MockSshKeyStorage::default()))
            .runtime(runtime);
        let session = RusshSession::<NoCheckServerKey>::connect(&opts).unwrap();
        assert!(session.authenticated().unwrap());
    }

    #[test]
    fn should_connect_to_ssh_server_auth_key_from_ssh_config() {
        use crate::ssh::container::OpensshServer;

        let container = OpensshServer::start();
        let port = container.port();

        crate::mock::logger();
        let runtime = test_runtime();
        // Authenticate purely via the `IdentityFile` directive of the ssh config,
        // with no key storage configured.
        let key_file = ssh_mock::create_key_file();
        let config_file = ssh_mock::create_ssh_config_with_identity(port, key_file.path());
        let opts = SshOpts::new("sftp")
            .config_file(config_file.path(), ParseRule::ALLOW_UNKNOWN_FIELDS)
            .runtime(runtime);
        let session = RusshSession::<NoCheckServerKey>::connect(&opts).unwrap();
        assert!(session.authenticated().unwrap());
    }

    #[test]
    #[cfg(unix)]
    fn should_connect_to_ssh_server_auth_ssh_agent() {
        use std::os::unix::fs::PermissionsExt;
        use std::process::Command;

        use crate::SshAgentIdentity;
        use crate::ssh::container::OpensshServer;

        crate::mock::logger();

        // Spawn a dedicated ssh-agent. The guard kills it and restores `SSH_AUTH_SOCK`
        // when it goes out of scope, including when an assertion below panics.
        let agent = SshAgentGuard::spawn();

        let key_file = ssh_mock::create_key_file();
        // ssh-add refuses keys with loose permissions.
        std::fs::set_permissions(key_file.path(), std::fs::Permissions::from_mode(0o600))
            .expect("could not restrict permissions of the mock key file");
        let added = Command::new("ssh-add")
            .arg(key_file.path())
            .env("SSH_AUTH_SOCK", &agent.auth_sock)
            .status()
            .expect("ssh-add failed to run");
        assert!(added.success(), "ssh-add could not load the mock key");

        // Point the russh agent client at our agent. No key storage, no password:
        // authentication must succeed through the agent alone.
        agent.export_auth_sock();

        let container = OpensshServer::start();
        let port = container.port();
        let runtime = test_runtime();
        let config_file = ssh_mock::create_ssh_config(port);
        let opts = SshOpts::new("sftp")
            .config_file(config_file.path(), ParseRule::ALLOW_UNKNOWN_FIELDS)
            .ssh_agent_identity(Some(SshAgentIdentity::All))
            .runtime(runtime);

        let session = RusshSession::<NoCheckServerKey>::connect(&opts)
            .expect("could not authenticate via ssh agent");
        assert!(session.authenticated().unwrap());
    }

    /// A dedicated `ssh-agent` process owned by a single test.
    ///
    /// On drop, the value of `SSH_AUTH_SOCK` captured at spawn time is restored (or the
    /// variable is removed if it was unset), and the agent is killed. Cleanup therefore also
    /// happens when the test panics.
    #[cfg(unix)]
    struct SshAgentGuard {
        /// Socket path reported by `ssh-agent -s`.
        auth_sock: String,
        /// PID reported by `ssh-agent -s`.
        pid: String,
        /// `SSH_AUTH_SOCK` as it was before the guard was created.
        previous_auth_sock: Option<std::ffi::OsString>,
    }

    #[cfg(unix)]
    impl SshAgentGuard {
        /// Start `ssh-agent -s` and parse its socket path and PID from the output.
        fn spawn() -> Self {
            let previous_auth_sock = std::env::var_os("SSH_AUTH_SOCK");
            let out = std::process::Command::new("ssh-agent")
                .arg("-s")
                .output()
                .expect("failed to spawn ssh-agent (is openssh installed?)");
            let stdout = String::from_utf8_lossy(&out.stdout);
            let auth_sock = parse_agent_var(&stdout, "SSH_AUTH_SOCK")
                .expect("ssh-agent did not report SSH_AUTH_SOCK");
            let pid = parse_agent_var(&stdout, "SSH_AGENT_PID")
                .expect("ssh-agent did not report SSH_AGENT_PID");
            Self {
                auth_sock,
                pid,
                previous_auth_sock,
            }
        }

        /// Set `SSH_AUTH_SOCK` for this process to the agent's socket.
        fn export_auth_sock(&self) {
            // SAFETY: tests in this module run single-threaded (`--test-threads=1`), so no
            // other thread reads or writes the environment concurrently.
            unsafe {
                std::env::set_var("SSH_AUTH_SOCK", &self.auth_sock);
            }
        }
    }

    #[cfg(unix)]
    impl Drop for SshAgentGuard {
        fn drop(&mut self) {
            // SAFETY: see `export_auth_sock`.
            unsafe {
                match &self.previous_auth_sock {
                    Some(value) => std::env::set_var("SSH_AUTH_SOCK", value),
                    None => std::env::remove_var("SSH_AUTH_SOCK"),
                }
            }
            // Best effort: failing to kill the agent must not mask the test outcome.
            let _ = std::process::Command::new("kill").arg(&self.pid).status();
        }
    }

    /// Parse a `NAME=value;` assignment from `ssh-agent -s` output.
    #[cfg(unix)]
    fn parse_agent_var(output: &str, name: &str) -> Option<String> {
        let needle = format!("{name}=");
        let start = output.find(&needle)? + needle.len();
        let rest = &output[start..];
        let end = rest.find(';')?;
        Some(rest[..end].to_string())
    }

    #[test]
    fn should_perform_shell_command_on_server() {
        crate::mock::logger();
        let container = crate::ssh::container::OpensshServer::start();
        let port = container.port();

        let runtime = test_runtime();
        let opts = SshOpts::new("127.0.0.1")
            .port(port)
            .username("sftp")
            .password("password")
            .runtime(runtime);
        let mut session = RusshSession::<NoCheckServerKey>::connect(&opts).unwrap();
        assert!(session.authenticated().unwrap());
        assert!(session.cmd("pwd").is_ok());
    }

    #[test]
    fn should_perform_shell_command_on_server_and_return_exit_code() {
        crate::mock::logger();
        let container = crate::ssh::container::OpensshServer::start();
        let port = container.port();

        let runtime = test_runtime();
        let opts = SshOpts::new("127.0.0.1")
            .port(port)
            .username("sftp")
            .password("password")
            .runtime(runtime);
        let mut session = RusshSession::<NoCheckServerKey>::connect(&opts).unwrap();
        assert!(session.authenticated().unwrap());
        assert_eq!(
            session.cmd_at("pwd", Path::new("/tmp")).unwrap(),
            (0, String::from("/tmp\n"))
        );
        assert_eq!(
            session
                .cmd_at("pippopluto", Path::new("/tmp"))
                .unwrap()
                .0,
            127
        );
    }

    #[test]
    fn should_fail_authentication() {
        crate::mock::logger();
        let container = crate::ssh::container::OpensshServer::start();
        let port = container.port();

        let runtime = test_runtime();
        let opts = SshOpts::new("127.0.0.1")
            .port(port)
            .username("sftp")
            .password("ippopotamo")
            .runtime(runtime);
        assert!(RusshSession::<NoCheckServerKey>::connect(&opts).is_err());
    }

    #[test]
    fn test_filetransfer_sftp_bad_server() {
        crate::mock::logger();
        let runtime = test_runtime();
        let opts = SshOpts::new("myverybad.verybad.server")
            .port(10022)
            .username("sftp")
            .password("ippopotamo")
            .runtime(runtime);
        assert!(RusshSession::<NoCheckServerKey>::connect(&opts).is_err());
    }
}