1use futures::future::join_all;
2use indicatif::{ProgressBar, ProgressStyle};
3use ssh2::{Session, Sftp};
4use std::{
5 ffi::OsStr,
6 fs::{self, File},
7 io::{Read, Write},
8 net::TcpStream,
9 path::PathBuf,
10 time::Duration,
11};
12
13use crate::{error::Result, utils::with_retry};
14
15pub struct Connect {
16 session: Session,
17 ssh_opts: SshOpts,
18 mode: Mode,
19 sftp: Sftp,
20}
21
22impl Connect {
23 pub fn new(ssh_opts: SshOpts, mode: Mode) -> Result<Self> {
24 let session = create_session(&ssh_opts)?;
25 let sftp = session.sftp()?;
26
27 Ok(Self {
28 session,
29 ssh_opts,
30 mode,
31 sftp,
32 })
33 }
34
35 pub async fn receive(&self, from: &PathBuf, to: &PathBuf) -> Result<()> {
36 let is_dir = self.stat(from)?;
37
38 if is_dir {
39 self.handle_dir(from, to).await
40 } else {
41 self.handle_file(from, to).await
42 }
43 }
44
45 async fn handle_file(&self, from: &PathBuf, to: &PathBuf) -> Result<()> {
46 let full_path = to.join(from.file_name().unwrap_or(OsStr::new("unknown")));
47 let result =
48 copy_file_from_remote(&self.ssh_opts, from.clone(), full_path, &self.mode).await;
49
50 println!("✅ File received successfully");
51 result
52 }
53
54 async fn handle_dir(&self, from: &PathBuf, to: &PathBuf) -> Result<()> {
55 let mut files = self.list_files(from)?;
56
57 #[cfg(any(target_os = "linux", target_os = "macos"))]
58 if self.mode != Mode::Replace {
59 let output = std::process::Command::new("find")
60 .arg(to)
61 .arg("-type")
62 .arg("f")
63 .output()?;
64
65 let existing_files = String::from_utf8_lossy(&output.stdout)
66 .lines()
67 .map(|line| PathBuf::from(line))
68 .collect::<Vec<_>>();
69
70 files = files
71 .into_iter()
72 .filter(|file| !existing_files.contains(&to.join(file.strip_prefix(from).unwrap())))
73 .collect::<Vec<_>>();
74 }
75
76 let pb = ProgressBar::new(files.len() as u64);
77 pb.set_style(
78 ProgressStyle::with_template(
79 "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta})\n\n{msg}",
80 )
81 .unwrap()
82 .progress_chars("#>-"),
83 );
84 pb.enable_steady_tick(Duration::from_millis(100));
85
86 let mut handles = Vec::new();
87 for item in &files {
88 let to_path = to.join(item.strip_prefix(from).unwrap());
89 let item_clone = item.clone();
90 let ssh_opts = self.ssh_opts.clone();
91 let pb = pb.clone();
92 let mode = self.mode.clone();
93 let handle = tokio::spawn(async move {
94 let result =
95 copy_file_from_remote(&ssh_opts, item_clone.clone(), to_path, &mode).await;
96 pb.inc(1);
97 result
98 });
99
100 handles.push(handle);
101 }
102
103 let items = join_all(handles).await;
104
105 if items.iter().all(|x| x.is_ok()) {
106 pb.finish_with_message(format!(
107 "✅ All files received successfully ({} files)",
108 files.len()
109 ));
110 Ok(())
111 } else {
112 Err(std::io::Error::new(
113 std::io::ErrorKind::Other,
114 "One or more files failed to copy",
115 )
116 .into())
117 }
118 }
119
120 fn stat(&self, path: &PathBuf) -> Result<bool> {
121 let file = self.sftp.stat(&path)?;
122 Ok(file.is_dir())
123 }
124
125 fn list_files(&self, dir: &PathBuf) -> Result<Vec<PathBuf>> {
126 let mut channel = self.session.channel_session()?;
127
128 channel.exec(&format!("find {} -type f", dir.display()))?;
129
130 let mut buf = String::new();
131 channel.read_to_string(&mut buf)?;
132
133 let files_only = find_files(&buf);
134
135 Ok(files_only)
136 }
137}
138
/// Parses newline-separated `find` output into a list of paths.
///
/// Each line is trimmed of surrounding whitespace. Lines that are empty after
/// trimming are skipped, so callers never receive an empty `PathBuf` (an
/// empty path would fail a later `strip_prefix` against the listing root).
pub fn find_files(buf: &str) -> Vec<PathBuf> {
    buf.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(PathBuf::from)
        .collect()
}
142
/// Parameters needed to open and authenticate an SSH connection.
#[derive(Clone)]
pub struct SshOpts {
    /// Address handed directly to `TcpStream::connect`, so it must include
    /// the port, e.g. `example.com:22`.
    pub host: String,
    /// Remote account name used for public-key authentication.
    pub username: String,
    /// Local path to the private key file used for authentication.
    pub private_key: PathBuf,
}
149
/// Policy for destination files that already exist locally.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Mode {
    /// Overwrite existing local files.
    Replace,
    /// Leave existing local files untouched and skip them.
    Ignore,
}
158
159async fn copy_file_from_remote(
160 ssh_opts: &SshOpts,
161 remote_file_path: PathBuf,
162 local_file_path: PathBuf,
163 mode: &Mode,
164) -> Result<()> {
165 let create_session = || create_session(ssh_opts);
166 let session = with_retry(create_session, 10)?;
167
168 let (mut remote_file, stat) = session.scp_recv(&remote_file_path)?;
170 let mut contents = Vec::with_capacity(stat.size() as usize);
171 remote_file.read_to_end(&mut contents)?;
172
173 fs::create_dir_all(local_file_path.parent().unwrap())?;
175
176 match mode {
177 Mode::Replace => {
178 let mut local_file = File::create(&local_file_path)?;
179 local_file.write_all(&contents)?;
180 }
181 Mode::Ignore => {
182 if local_file_path.exists() {
183 println!(
184 "Skipping already existing file: {}",
185 local_file_path.display()
186 );
187 return Ok(());
188 }
189
190 let mut local_file = File::create(local_file_path)?;
191 local_file.write_all(&contents)?;
192 }
193 }
194
195 session.disconnect(None, "Bye", None)?;
196
197 Ok(())
198}
199
200pub fn create_session(ssh_opts: &SshOpts) -> Result<Session> {
201 let tcp = TcpStream::connect(&ssh_opts.host)?;
203 let mut session = Session::new()?;
204 session.set_tcp_stream(tcp);
205 session.handshake()?;
206
207 session.userauth_pubkey_file(&ssh_opts.username, None, &ssh_opts.private_key, None)?;
209
210 Ok(session)
211}