use std::sync::Arc;
use async_trait::async_trait;
use russh::ChannelStream;
use russh::{client::Msg, Channel};
use tokio::io::AsyncWrite;
use tokio::task::JoinHandle;
use tokio::{io::AsyncRead, sync::mpsc};
use crate::message::{Init, Message, StatusCode, Version};
mod commands;
mod dir;
mod error;
mod file;
mod receiver;
mod request;
mod stop;
pub use dir::{Dir, DIR_CLOSED};
pub use error::Error;
pub use file::{File, FILE_CLOSED};
pub use request::{SftpFuture, SftpReply, SftpRequest};
use stop::SftpClientStopping;
/// Handle to an SFTP session.
///
/// Cloning is cheap: every clone shares the same command sender and the same
/// background request-processing task (the task handle is held behind an `Arc`).
/// A client built with [`SftpClient::new_stopped`] or `Default` holds neither,
/// i.e. it is a "stopped" client.
#[derive(Default, Clone)]
pub struct SftpClient {
    /// Sender used to submit requests to the background receiver task;
    /// `None` for a stopped client.
    commands: Option<mpsc::UnboundedSender<receiver::Request>>,
    /// Handle of the spawned task running the receiver loop, shared between
    /// clones; `None` for a stopped client.
    request_processor: Option<Arc<JoinHandle<()>>>,
}
pub static SFTP_CLIENT_STOPPED: SftpClient = SftpClient::new_stopped();
impl SftpClient {
    /// Returns a client in the stopped state: no command sender and no
    /// background request-processing task.
    ///
    /// Being a `const fn`, it can be used to initialise statics.
    pub const fn new_stopped() -> Self {
        Self { request_processor: None, commands: None }
    }

    /// Opens an SFTP session over anything convertible into an SFTP byte
    /// stream (see [`IntoSftpStream`]) and performs the protocol handshake.
    ///
    /// # Errors
    /// Fails if converting `ssh` into a stream fails, or for any of the
    /// reasons listed on [`SftpClient::with_stream`].
    pub async fn new<T: IntoSftpStream>(ssh: T) -> Result<Self, Error> {
        let stream = ssh.into_sftp_stream().await?;
        Self::with_stream(stream).await
    }

    /// Starts an SFTP client on an already-established byte stream.
    ///
    /// Sends an `Init` message announcing protocol version 3 (with no
    /// extensions) and reads the peer's reply. If the reply is a `Version`
    /// message for version 3, the stream is handed to a receiver task spawned
    /// with `tokio::spawn`, and the returned client talks to that task
    /// through a channel.
    ///
    /// # Errors
    /// - Errors from writing the `Init` message or reading the reply are
    ///   propagated (converted into [`Error`]).
    /// - A `Version` reply with any version other than 3 yields a
    ///   `BadMessage` status ("Invalid sftp version").
    /// - Any other reply yields a `BadMessage` status ("Bad SFTP init").
    ///
    /// # Panics
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    pub async fn with_stream(
        mut stream: impl AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
    ) -> Result<Self, Error> {
        // Client half of the handshake: announce version 3, no extensions.
        let init = Message::Init(Init {
            version: 3,
            extensions: Default::default(),
        });
        // NOTE(review): the meaning of the trailing `3` is defined by
        // `receiver::write_msg` (not visible here); presumably an id or
        // version value. Confirm.
        receiver::write_msg(&mut stream, init, 3).await?;

        // The peer must answer with a `Version` message carrying version 3.
        let (_, reply) = receiver::read_msg(&mut stream).await?;
        match reply {
            Message::Version(Version { version: 3, .. }) => {}
            Message::Version(_) => {
                return Err(StatusCode::BadMessage
                    .to_status("Invalid sftp version")
                    .into());
            }
            _ => {
                return Err(StatusCode::BadMessage.to_status("Bad SFTP init").into());
            }
        }

        // Move the stream into the receiver loop, which runs as its own task.
        // The client keeps the command sender and a shared handle to the task.
        let (worker, commands) = receiver::Receiver::new(stream);
        let request_processor = Arc::new(tokio::spawn(worker.run()));
        Ok(Self {
            commands: Some(commands),
            request_processor: Some(request_processor),
        })
    }
}
impl std::fmt::Debug for SftpClient {
    /// Formats as the bare type name; the command sender and task handle
    /// are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SftpClient")
    }
}
/// Conversion into a byte stream that SFTP can run over.
///
/// This file implements it for:
/// - an existing [`ChannelStream`], which is returned unchanged;
/// - a session [`Channel`], which gets the `sftp` subsystem requested on it;
/// - a [`russh::client::Handle`], by value or by reference, which first
///   opens a new session channel.
#[async_trait]
pub trait IntoSftpStream {
    /// The stream type produced by the conversion.
    type Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static;
    /// Performs the conversion, including any SSH-level setup it needs.
    ///
    /// # Errors
    /// Returns an [`Error`] if that setup fails.
    async fn into_sftp_stream(self) -> Result<Self::Stream, Error>;
}
/// A [`ChannelStream`] is already a byte stream, so it is returned unchanged.
///
/// No subsystem request is made here. Presumably the caller has already
/// started the `sftp` subsystem on the underlying channel.
#[async_trait]
impl IntoSftpStream for ChannelStream<Msg> {
    type Stream = ChannelStream<Msg>;
    async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
        // Identity conversion; the binding spells out that no wrapping occurs.
        let stream: ChannelStream<Msg> = self;
        Ok(stream)
    }
}
/// Requests the `sftp` subsystem on a session channel, then turns the
/// channel into a byte stream.
#[async_trait]
impl IntoSftpStream for Channel<Msg> {
    type Stream = ChannelStream<Msg>;
    async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
        // NOTE(review): the first argument is `false` (presumably russh's
        // `want_reply`), so a refusal by the server is likely only noticed
        // later, when the SFTP handshake fails. Confirm.
        self.request_subsystem(false, "sftp").await?;
        let stream = self.into_stream();
        Ok(stream)
    }
}
/// Opens a new session channel on the SSH connection, then converts that
/// channel into an SFTP stream (which requests the `sftp` subsystem on it).
#[async_trait]
impl<H: russh::client::Handler> IntoSftpStream for &russh::client::Handle<H> {
    type Stream = ChannelStream<Msg>;
    async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
        let session = self.channel_open_session().await?;
        session.into_sftp_stream().await
    }
}
/// Owned-handle variant: borrows the handle and delegates to the
/// `&Handle` implementation. The handle is owned by this call and is not
/// returned, so it is dropped when the conversion completes.
#[async_trait]
impl<H: russh::client::Handler> IntoSftpStream for russh::client::Handle<H> {
    type Stream = ChannelStream<Msg>;
    async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
        let handle = &self;
        <&russh::client::Handle<H> as IntoSftpStream>::into_sftp_stream(handle).await
    }
}