use russh_sftp::{client::SftpSession, protocol::OpenFlags};
use std::path::Path;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::connection::Client;
use crate::utils::buffer_pool::global;
impl Client {
/// Uploads a single local file to the remote host over SFTP.
///
/// A channel is obtained from `self.get_channel()`, the `sftp` subsystem is requested on
/// it, and the whole of `src_file_path` is read into memory before being written to
/// `dest_file_path`. The remote file is created if it is missing and truncated if it
/// already exists.
///
/// # Errors
///
/// Returns an error if the channel or SFTP session cannot be set up, if the local file
/// cannot be read, or if opening, writing, flushing, or shutting down the remote file
/// fails. Local and remote I/O failures are wrapped in `Error::IoError`.
pub async fn upload_file<T: AsRef<Path>, U: Into<String>>(
    &self,
    src_file_path: T,
    dest_file_path: U,
) -> Result<(), super::Error> {
    let channel = self.get_channel().await?;
    channel.request_subsystem(true, "sftp").await?;
    let sftp = SftpSession::new(channel.into_stream()).await?;

    // Read the source before touching the remote side, so an unreadable local file
    // never causes the destination to be truncated.
    let file_contents = tokio::fs::read(src_file_path)
        .await
        .map_err(super::Error::IoError)?;

    // An upload only writes: do not request read access to the destination. This also
    // matches the flags used for files uploaded by `upload_dir_recursive`.
    let mut file = sftp
        .open_with_flags(
            dest_file_path,
            OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE,
        )
        .await?;

    file.write_all(&file_contents)
        .await
        .map_err(super::Error::IoError)?;
    file.flush().await.map_err(super::Error::IoError)?;
    // Shut the remote file down explicitly so that a failure here is reported to the caller.
    file.shutdown().await.map_err(super::Error::IoError)?;
    Ok(())
}
/// Downloads a single remote file over SFTP and stores it at `local_file_path`.
///
/// A channel is obtained from `self.get_channel()`, the `sftp` subsystem is requested on
/// it, and the remote file is read to its end into a buffer taken from the global buffer
/// pool. Only after the remote read has succeeded is the local file created (or
/// truncated) and written.
///
/// # Errors
///
/// Returns an error if the channel or SFTP session cannot be set up, if the remote file
/// cannot be opened or read, or if creating, writing, or flushing the local file fails.
pub async fn download_file<T: AsRef<Path>, U: Into<String>>(
    &self,
    remote_file_path: U,
    local_file_path: T,
) -> Result<(), super::Error> {
    let channel = self.get_channel().await?;
    channel.request_subsystem(true, "sftp").await?;
    let sftp = SftpSession::new(channel.into_stream()).await?;

    let mut remote_file = sftp
        .open_with_flags(remote_file_path, OpenFlags::READ)
        .await?;

    // `read_to_end` appends to the vector it is given.
    // NOTE(review): this presumably relies on the pool handing out an empty buffer — confirm.
    let mut pooled_buffer = global::get_large_buffer();
    remote_file.read_to_end(pooled_buffer.as_mut_vec()).await?;

    // Borrow the downloaded bytes straight from the pooled buffer instead of cloning
    // the whole file into a second allocation.
    let contents: &[u8] = pooled_buffer.as_vec();

    let mut local_file = tokio::fs::File::create(local_file_path.as_ref())
        .await
        .map_err(super::Error::IoError)?;
    local_file
        .write_all(contents)
        .await
        .map_err(super::Error::IoError)?;
    local_file.flush().await.map_err(super::Error::IoError)?;
    Ok(())
}
pub async fn upload_dir<T: AsRef<Path>, U: Into<String>>(
&self,
local_dir_path: T,
remote_dir_path: U,
) -> Result<(), super::Error> {
let local_dir = local_dir_path.as_ref();
let remote_dir = remote_dir_path.into();
if !local_dir.is_dir() {
return Err(super::Error::IoError(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Local directory does not exist: {local_dir:?}"),
)));
}
let channel = self.get_channel().await?;
channel.request_subsystem(true, "sftp").await?;
let sftp = SftpSession::new(channel.into_stream()).await?;
let _ = sftp.create_dir(&remote_dir).await;
self.upload_dir_recursive(&sftp, local_dir, &remote_dir)
.await?;
Ok(())
}
#[allow(clippy::only_used_in_recursion)]
fn upload_dir_recursive<'a>(
&'a self,
sftp: &'a SftpSession,
local_dir: &'a Path,
remote_dir: &'a str,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), super::Error>> + Send + 'a>>
{
Box::pin(async move {
let entries = tokio::fs::read_dir(local_dir)
.await
.map_err(super::Error::IoError)?;
let mut entries = entries;
while let Some(entry) = entries.next_entry().await.map_err(super::Error::IoError)? {
let path = entry.path();
let file_name = entry.file_name();
let file_name_str = file_name.to_string_lossy();
let remote_path = format!("{remote_dir}/{file_name_str}");
let metadata = entry.metadata().await.map_err(super::Error::IoError)?;
if metadata.is_dir() {
let _ = sftp.create_dir(&remote_path).await; self.upload_dir_recursive(sftp, &path, &remote_path).await?;
} else if metadata.is_file() {
let file_contents = tokio::fs::read(&path)
.await
.map_err(super::Error::IoError)?;
let mut remote_file = sftp
.open_with_flags(
&remote_path,
OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE,
)
.await?;
remote_file
.write_all(&file_contents)
.await
.map_err(super::Error::IoError)?;
remote_file.flush().await.map_err(super::Error::IoError)?;
remote_file
.shutdown()
.await
.map_err(super::Error::IoError)?;
}
}
Ok(())
})
}
/// Recursively downloads the remote directory `remote_dir_path` into `local_dir_path`.
///
/// A channel is obtained from `self.get_channel()` and turned into an SFTP session, the
/// local root directory (including missing parents) is created, and the tree is then
/// copied by `download_dir_recursive`.
///
/// # Errors
///
/// Returns an error if the channel or SFTP session cannot be set up, if the local root
/// directory cannot be created (`Error::IoError`), or if downloading any entry fails.
pub async fn download_dir<T: AsRef<Path>, U: Into<String>>(
    &self,
    remote_dir_path: U,
    local_dir_path: T,
) -> Result<(), super::Error> {
    let local_root = local_dir_path.as_ref();
    let remote_root: String = remote_dir_path.into();

    // Set up the SFTP session first; the local directory is only created once the
    // session is available.
    let sftp = {
        let channel = self.get_channel().await?;
        channel.request_subsystem(true, "sftp").await?;
        SftpSession::new(channel.into_stream()).await?
    };

    tokio::fs::create_dir_all(local_root)
        .await
        .map_err(super::Error::IoError)?;

    self.download_dir_recursive(&sftp, &remote_root, local_root)
        .await
}
#[allow(clippy::only_used_in_recursion)]
fn download_dir_recursive<'a>(
&'a self,
sftp: &'a SftpSession,
remote_dir: &'a str,
local_dir: &'a Path,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), super::Error>> + Send + 'a>>
{
Box::pin(async move {
let entries = sftp.read_dir(remote_dir).await?;
for entry in entries {
let name = entry.file_name();
let metadata = entry.metadata();
if name == "." || name == ".." {
continue;
}
let remote_path = format!("{remote_dir}/{name}");
let local_path = local_dir.join(&name);
if metadata.file_type().is_dir() {
tokio::fs::create_dir_all(&local_path)
.await
.map_err(super::Error::IoError)?;
self.download_dir_recursive(sftp, &remote_path, &local_path)
.await?;
} else if metadata.file_type().is_file() {
let mut remote_file =
sftp.open_with_flags(&remote_path, OpenFlags::READ).await?;
let mut pooled_buffer = global::get_large_buffer();
remote_file.read_to_end(pooled_buffer.as_mut_vec()).await?;
let contents = pooled_buffer.as_vec().clone();
tokio::fs::write(&local_path, contents)
.await
.map_err(super::Error::IoError)?;
}
}
Ok(())
})
}
}