fast_scp/
scp.rs

1use futures::future::join_all;
2use indicatif::{ProgressBar, ProgressStyle};
3use ssh2::{Session, Sftp};
4use std::{
5    ffi::OsStr,
6    fs::{self, File},
7    io::{Read, Write},
8    net::TcpStream,
9    path::PathBuf,
10    time::Duration,
11};
12
13use crate::{error::Result, utils::with_retry};
14
15pub struct Connect {
16    session: Session,
17    ssh_opts: SshOpts,
18    mode: Mode,
19    sftp: Sftp,
20}
21
22impl Connect {
23    pub fn new(ssh_opts: SshOpts, mode: Mode) -> Result<Self> {
24        let session = create_session(&ssh_opts)?;
25        let sftp = session.sftp()?;
26
27        Ok(Self {
28            session,
29            ssh_opts,
30            mode,
31            sftp,
32        })
33    }
34
35    pub async fn receive(&self, from: &PathBuf, to: &PathBuf) -> Result<()> {
36        let is_dir = self.stat(from)?;
37
38        if is_dir {
39            self.handle_dir(from, to).await
40        } else {
41            self.handle_file(from, to).await
42        }
43    }
44
45    async fn handle_file(&self, from: &PathBuf, to: &PathBuf) -> Result<()> {
46        let full_path = to.join(from.file_name().unwrap_or(OsStr::new("unknown")));
47        let result =
48            copy_file_from_remote(&self.ssh_opts, from.clone(), full_path, &self.mode).await;
49
50        println!("✅ File received successfully");
51        result
52    }
53
54    async fn handle_dir(&self, from: &PathBuf, to: &PathBuf) -> Result<()> {
55        let mut files = self.list_files(from)?;
56
57        #[cfg(any(target_os = "linux", target_os = "macos"))]
58        if self.mode != Mode::Replace {
59            let output = std::process::Command::new("find")
60                .arg(to)
61                .arg("-type")
62                .arg("f")
63                .output()?;
64
65            let existing_files = String::from_utf8_lossy(&output.stdout)
66                .lines()
67                .map(|line| PathBuf::from(line))
68                .collect::<Vec<_>>();
69
70            files = files
71                .into_iter()
72                .filter(|file| !existing_files.contains(&to.join(file.strip_prefix(from).unwrap())))
73                .collect::<Vec<_>>();
74        }
75
76        let pb = ProgressBar::new(files.len() as u64);
77        pb.set_style(
78            ProgressStyle::with_template(
79                "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta})\n\n{msg}",
80            )
81            .unwrap()
82            .progress_chars("#>-"),
83        );
84        pb.enable_steady_tick(Duration::from_millis(100));
85
86        let mut handles = Vec::new();
87        for item in &files {
88            let to_path = to.join(item.strip_prefix(from).unwrap());
89            let item_clone = item.clone();
90            let ssh_opts = self.ssh_opts.clone();
91            let pb = pb.clone();
92            let mode = self.mode.clone();
93            let handle = tokio::spawn(async move {
94                let result =
95                    copy_file_from_remote(&ssh_opts, item_clone.clone(), to_path, &mode).await;
96                pb.inc(1);
97                result
98            });
99
100            handles.push(handle);
101        }
102
103        let items = join_all(handles).await;
104
105        if items.iter().all(|x| x.is_ok()) {
106            pb.finish_with_message(format!(
107                "✅ All files received successfully ({} files)",
108                files.len()
109            ));
110            Ok(())
111        } else {
112            Err(std::io::Error::new(
113                std::io::ErrorKind::Other,
114                "One or more files failed to copy",
115            )
116            .into())
117        }
118    }
119
120    fn stat(&self, path: &PathBuf) -> Result<bool> {
121        let file = self.sftp.stat(&path)?;
122        Ok(file.is_dir())
123    }
124
125    fn list_files(&self, dir: &PathBuf) -> Result<Vec<PathBuf>> {
126        let mut channel = self.session.channel_session()?;
127
128        channel.exec(&format!("find {} -type f", dir.display()))?;
129
130        let mut buf = String::new();
131        channel.read_to_string(&mut buf)?;
132
133        let files_only = find_files(&buf);
134
135        Ok(files_only)
136    }
137}
138
/// Parses newline-separated `find` output into a list of paths.
///
/// Each line is trimmed of surrounding whitespace. Lines that are empty after
/// trimming are skipped, since an empty path cannot name a remote file.
pub fn find_files(buf: &str) -> Vec<PathBuf> {
    buf.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(PathBuf::from)
        .collect()
}
142
/// Options needed to open an authenticated SSH session.
#[derive(Clone, Debug)]
pub struct SshOpts {
    /// Address handed to `TcpStream::connect`; must include the port,
    /// e.g. `example.com:22`.
    pub host: String,
    /// Remote user to authenticate as.
    pub username: String,
    /// Path to the private key file used for public-key authentication.
    pub private_key: PathBuf,
}
149
/// How to handle destination files that already exist when copying.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Overwrite the local file if it already exists.
    Replace,
    /// Leave an existing local file untouched and skip it.
    Ignore,
}
158
159async fn copy_file_from_remote(
160    ssh_opts: &SshOpts,
161    remote_file_path: PathBuf,
162    local_file_path: PathBuf,
163    mode: &Mode,
164) -> Result<()> {
165    let create_session = || create_session(ssh_opts);
166    let session = with_retry(create_session, 10)?;
167
168    // Create a SCP channel for receiving the file
169    let (mut remote_file, stat) = session.scp_recv(&remote_file_path)?;
170    let mut contents = Vec::with_capacity(stat.size() as usize);
171    remote_file.read_to_end(&mut contents)?;
172
173    // make the dir if not exists
174    fs::create_dir_all(local_file_path.parent().unwrap())?;
175
176    match mode {
177        Mode::Replace => {
178            let mut local_file = File::create(&local_file_path)?;
179            local_file.write_all(&contents)?;
180        }
181        Mode::Ignore => {
182            if local_file_path.exists() {
183                println!(
184                    "Skipping already existing file: {}",
185                    local_file_path.display()
186                );
187                return Ok(());
188            }
189
190            let mut local_file = File::create(local_file_path)?;
191            local_file.write_all(&contents)?;
192        }
193    }
194
195    session.disconnect(None, "Bye", None)?;
196
197    Ok(())
198}
199
200pub fn create_session(ssh_opts: &SshOpts) -> Result<Session> {
201    // Connect to the host
202    let tcp = TcpStream::connect(&ssh_opts.host)?;
203    let mut session = Session::new()?;
204    session.set_tcp_stream(tcp);
205    session.handshake()?;
206
207    // Authenticate using a private key
208    session.userauth_pubkey_file(&ssh_opts.username, None, &ssh_opts.private_key, None)?;
209
210    Ok(session)
211}