use std::sync::Arc;
use async_trait::async_trait;
use russh::{client::Msg, Channel, ChannelMsg};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use crate::message::{Init, Message, StatusCode, Version};
mod commands;
mod dir;
mod error;
mod file;
mod receiver;
mod request;
mod stop;
pub use dir::{Dir, DIR_CLOSED};
pub use error::Error;
pub use file::{File, FILE_CLOSED};
pub use request::SftpRequest;
use stop::SftpClientStopping;
/// Handle to an SFTP session running over an SSH channel.
///
/// Clones share the same command channel and the same background task
/// (the task handle is kept behind an [`Arc`]). A client whose fields are
/// both `None` — as produced by [`SftpClient::new_stopped`] or `Default` —
/// is not attached to any session.
#[derive(Clone, Default)]
pub struct SftpClient {
    /// Sender used to submit requests to the background receiver task;
    /// `None` in a stopped client.
    commands: Option<mpsc::UnboundedSender<receiver::Request>>,
    /// Handle of the spawned receiver task, shared between clones;
    /// `None` in a stopped client.
    request_processor: Option<Arc<JoinHandle<()>>>,
}

/// A statically available client value that is not attached to any session.
pub static SFTP_CLIENT_STOPPED: SftpClient = SftpClient::new_stopped();
impl SftpClient {
pub const fn new_stopped() -> Self {
Self {
commands: None,
request_processor: None,
}
}
pub async fn new<T: ToSftpChannel>(ssh: T) -> Result<Self, Error> {
Self::with_channel(ssh.to_sftp_channel().await?).await
}
pub async fn with_channel(mut channel: Channel<Msg>) -> Result<Self, Error> {
channel.request_subsystem(false, "sftp").await?;
let init_message = Message::Init(Init {
version: 3,
extensions: Default::default(),
});
let init_frame = init_message.encode(0)?;
channel.data(init_frame.as_ref()).await?;
loop {
match channel.wait().await {
Some(ChannelMsg::Data { data }) => {
match Message::decode(data.as_ref()) {
Ok((
_,
Message::Version(Version {
version: 3,
extensions: _,
}),
)) => break,
Ok((_, Message::Version(_))) => {
return Err(StatusCode::BadMessage
.to_status("Invalid sftp version".into())
.into());
}
Ok(_) => {
return Err(StatusCode::BadMessage
.to_status("Bad SFTP init".into())
.into());
}
Err(err) => {
return Err(err.into());
}
}
}
Some(_) => (),
None => {
return Err(StatusCode::BadMessage
.to_status("Failed to start SFTP subsystem".into())
.into());
}
}
}
let (receiver, tx) = receiver::Receiver::new(channel);
let request_processor = tokio::spawn(receiver.run());
Ok(Self {
commands: Some(tx),
request_processor: Some(Arc::new(request_processor)),
})
}
}
impl std::fmt::Debug for SftpClient {
    /// Writes only the type name; the command sender and task handle are
    /// deliberately not shown.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str("SftpClient")
    }
}
#[async_trait]
pub trait ToSftpChannel {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error>;
}
#[async_trait]
impl ToSftpChannel for Channel<Msg> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
Ok(self)
}
}
#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for &russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
self.channel_open_session().await.map_err(Into::into)
}
}
#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
(&self).to_sftp_channel().await
}
}