use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Poll};
use tokio::sync::oneshot;
use crate::client::{Error, SftpClient};
use crate::message::{self, Message, Status, StatusCode};
impl SftpClient {
pub fn request<R: SftpRequest>(&self, request: R) -> SftpFuture<R::Reply> {
self.request_with(
request.to_request_message(),
(),
stateless_from_reply_message::<R::Reply>,
)
}
pub fn request_with<S, T>(
&self,
request: Result<Message, Error>,
state: S,
f: fn(S, Message) -> Result<T, Error>,
) -> SftpFuture<T, S> {
if let Some(commands) = &self.commands {
match request {
Ok(Message::Status(Status {
code: StatusCode::Ok,
..
})) => SftpFuture::Error(
StatusCode::BadMessage
.to_status("Tried to send an OK status message to the server")
.into(),
),
Ok(Message::Status(status)) => SftpFuture::Error(status.into()),
Ok(msg) => {
let (tx, rx) = oneshot::channel();
log::trace!("Sending: {msg:?}");
match commands.send(super::receiver::Request(msg, tx)) {
Ok(()) => SftpFuture::Pending {
future: rx,
state,
f,
},
Err(err) => {
SftpFuture::Error(StatusCode::Failure.to_status(err.to_string()).into())
}
}
}
Err(err) => SftpFuture::Error(err),
}
} else {
SftpFuture::Error(
std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"SFTP client has been stopped",
)
.into(),
)
}
}
}
/// Future returned by [`SftpClient::request`] and [`SftpClient::request_with`].
///
/// It resolves to `Result<Output, Error>`. Once it has produced its result, polling moves
/// it into the `Polled` state.
pub enum SftpFuture<Output = (), State = ()> {
    /// The request failed before any reply was awaited; polling yields this error.
    Error(Error),
    /// The request was handed to the client and its reply is awaited.
    Pending {
        /// Channel on which the reply message (or an error) for this request arrives.
        future: tokio::sync::oneshot::Receiver<Result<Message, Error>>,
        /// Caller-provided value passed to `f` together with the reply message.
        state: State,
        /// Converts the reply message into the future's output.
        f: fn(State, Message) -> Result<Output, Error>,
    },
    /// The result has already been returned; polling again panics.
    Polled,
}
impl<Output, State> Future for SftpFuture<Output, State>
where
State: Unpin,
{
type Output = Result<Output, Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
match &mut *self {
SftpFuture::Error(_) => {
let SftpFuture::Error(err) = std::mem::replace(&mut *self, SftpFuture::Polled)
else {
unreachable!()
};
Poll::Ready(Err(err))
}
SftpFuture::Pending { future, .. } => {
let result = match ready!(Pin::new(future).poll(cx)) {
Ok(Ok(msg)) => {
let SftpFuture::Pending { state, f, .. } =
std::mem::replace(&mut *self, SftpFuture::Polled)
else {
unreachable!()
};
f(state, msg)
}
Ok(Err(err)) => Err(err),
Err(_) => Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"Could not get reply from SFTP client",
))),
};
*self = SftpFuture::Polled;
Poll::Ready(result)
}
SftpFuture::Polled => panic!("Duplicated poll"),
}
}
}
/// A typed request that can be sent with [`SftpClient::request`].
pub trait SftpRequest {
    /// Typed reply that the answer to this request is decoded into.
    type Reply: SftpReply;
    /// Converts this request into the [`Message`] that is sent.
    ///
    /// # Errors
    /// Returns an error if the request cannot be turned into a message. The `Message` impl
    /// and the `request_impl!`-generated impls always return `Ok`.
    fn to_request_message(self) -> Result<Message, Error>;
}
/// A typed reply decoded from the [`Message`] received in answer to a request.
pub trait SftpReply: Sized {
    /// Decodes `msg` into `Self`.
    ///
    /// # Errors
    /// Implementation-specific. For example, the `reply_impl!`-generated impls return a
    /// received `Status` as the error, and return a `BadMessage` status for any other
    /// unexpected message.
    fn from_reply_message(msg: Message) -> Result<Self, Error>;
}
impl SftpRequest for Message {
type Reply = Message;
fn to_request_message(self) -> Result<Message, Error> {
Ok(self)
}
}
impl SftpReply for Message {
fn from_reply_message(msg: Message) -> Result<Self, Error> {
Ok(msg)
}
}
/// The reply for requests that are answered only by a status.
impl SftpReply for () {
    /// Succeeds only on a status with code `Ok`.
    ///
    /// # Errors
    /// Any other status is returned as the error itself. A non-status message yields a
    /// `BadMessage` status error.
    fn from_reply_message(msg: Message) -> Result<Self, Error> {
        let Message::Status(status) = msg else {
            return Err(StatusCode::BadMessage.to_status("Expected a status").into());
        };
        if matches!(status.code, StatusCode::Ok) {
            Ok(())
        } else {
            Err(status.into())
        }
    }
}
/// Implements [`SftpRequest`] for `message::$input`.
///
/// `request_impl!(Input)` declares a request answered by a bare status (`Reply = ()`).
/// `request_impl!(Input -> Output)` declares a request whose reply decodes to
/// `message::Output`.
macro_rules! request_impl {
    // Internal arm shared by both public forms; `$reply` is the reply type.
    (@impl $input:ident, $reply:ty) => {
        impl SftpRequest for message::$input {
            type Reply = $reply;
            fn to_request_message(self) -> Result<Message, Error> {
                Ok(self.into())
            }
        }
    };
    ($input:ident -> $output:ident) => {
        request_impl!(@impl $input, message::$output);
    };
    ($input:ident) => {
        request_impl!(@impl $input, ());
    };
}
/// Implements [`SftpReply`] for `message::$output`.
///
/// A `Message::$output` reply is unwrapped. A `Status` reply becomes the error, and any
/// other message becomes a `BadMessage` status error.
macro_rules! reply_impl {
    ($output:ident) => {
        impl SftpReply for message::$output {
            fn from_reply_message(msg: Message) -> Result<Self, Error> {
                let status = match msg {
                    Message::$output(response) => return Ok(response),
                    Message::Status(status) => status,
                    _ => StatusCode::BadMessage
                        .to_status(std::stringify!(Expected a $output or a Status)),
                };
                Err(status.into())
            }
        }
    };
}
// Decoders for the data-carrying reply messages; a `Status` reply becomes an error.
reply_impl!(Attrs);
reply_impl!(Data);
reply_impl!(ExtendedReply);
reply_impl!(Handle);
reply_impl!(Name);

// Requests answered with a data-carrying reply.
request_impl!(Open -> Handle);
request_impl!(OpenDir -> Handle);
request_impl!(Read -> Data);
request_impl!(Stat -> Attrs);
request_impl!(LStat -> Attrs);
request_impl!(FStat -> Attrs);
request_impl!(ReadDir -> Name);
request_impl!(RealPath -> Name);
request_impl!(ReadLink -> Name);
request_impl!(Extended -> ExtendedReply);

// Requests answered only with a status (`Reply = ()`).
request_impl!(Close);
request_impl!(Write);
request_impl!(SetStat);
request_impl!(FSetStat);
request_impl!(Remove);
request_impl!(MkDir);
request_impl!(RmDir);
request_impl!(Rename);
request_impl!(Symlink);
fn stateless_from_reply_message<R: SftpReply>(_: (), msg: Message) -> Result<R, Error> {
R::from_reply_message(msg)
}