use std::collections::HashMap;
use std::pin::Pin;
use std::task::Poll;
use bytes::{Buf, Bytes, BytesMut};
use futures::{Stream, StreamExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use crate::client::Error;
use crate::message::{Message, StatusCode};
/// Outcome handed back to the issuer of a request: the decoded reply message, or an error.
pub(super) type Response = Result<Message, Error>;
/// A message to send to the server, paired with the one-shot sender on which its
/// `Response` is delivered once the matching reply arrives (or the request fails).
pub struct Request(pub(super) Message, pub(super) oneshot::Sender<Response>);
/// Owns the SFTP stream and dispatches traffic over it: requests received on the command
/// channel are written out, and decoded replies are routed back to the matching requester by
/// request id.
pub(super) struct Receiver<S> {
    /// Reply senders for requests already written to `stream`, keyed by request id.
    onflight: HashMap<u32, oneshot::Sender<Response>>,
    /// Id assigned to the most recent request (0 before any request has been handled).
    next_id: u32,
    /// Incoming requests to forward to the server.
    commands: mpsc::UnboundedReceiver<Request>,
    /// Transport carrying SFTP frames in both directions.
    stream: S,
    /// Length of the frame currently being accumulated, once its 4-byte prefix has been read.
    response_size: Option<u32>,
    /// Bytes read from `stream` that have not yet been emitted as a complete frame.
    response_buffer: BytesMut,
}
impl<S> Receiver<S> {
pub(super) fn new(stream: S) -> (Self, mpsc::UnboundedSender<Request>) {
let (tx, rx) = mpsc::unbounded_channel();
(
Self {
onflight: HashMap::new(),
next_id: 0,
commands: rx,
stream,
response_size: None,
response_buffer: Default::default(),
},
tx,
)
}
}
/// Event produced when a `Receiver` is polled as a `Stream`.
pub enum StreamItem {
    /// A request taken from the command channel that has not been sent yet.
    Request(Request),
    /// One complete frame read from the stream, with its 4-byte length prefix removed.
    Response(Bytes),
    /// An I/O error returned while reading from the stream.
    Error(std::io::Error),
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream for Receiver<S> {
type Item = StreamItem;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
match self.commands.poll_recv(cx) {
Poll::Ready(Some(request)) => {
return Poll::Ready(Some(StreamItem::Request(request)));
}
Poll::Ready(None) => {
if self.onflight.is_empty() {
return Poll::Ready(None);
}
}
Poll::Pending => (),
};
loop {
let new_len;
match self.response_size {
Some(response_size) => {
if self.response_buffer.len() >= response_size as usize {
self.response_size = None;
let response = self.response_buffer.split_to(response_size as usize);
return Poll::Ready(Some(StreamItem::Response(response.freeze())));
}
new_len = response_size as usize;
}
None => {
if self.response_buffer.len() >= std::mem::size_of::<u32>() {
let len = self.response_buffer.get_u32();
self.response_size = Some(len);
continue;
}
new_len = std::mem::size_of::<u32>();
}
}
let old_len = self.response_buffer.len();
let mut buffer = std::mem::take(&mut self.response_buffer);
buffer.resize(new_len.max(1024), 0);
let mut read_buf = tokio::io::ReadBuf::new(&mut buffer[old_len..]);
let read = Pin::new(&mut self.stream).poll_read(cx, &mut read_buf);
let len = read_buf.filled().len();
buffer.resize(old_len + len, 0);
self.response_buffer = buffer;
match read {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => {
return Poll::Ready(Some(StreamItem::Error(err)));
}
Poll::Pending => {
return Poll::Pending;
}
}
if len == old_len {
return Poll::Ready(None);
}
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Receiver<S> {
pub(super) async fn run(mut self) {
log::debug!("Start SFTP client");
while let Some(event) = self.next().await {
match event {
StreamItem::Request(Request(message, tx)) => {
self.next_id += 1;
let id = self.next_id;
log::trace!("Request #{id}: {message:?}");
match write_msg(&mut self.stream, message, id).await {
Ok(()) => {
self.onflight.insert(id, tx);
}
Err(err) => {
log::debug!("Could not send request #{id}: {err:?}");
send_response(tx, Err(err));
}
}
}
StreamItem::Response(response) => match Message::decode_raw(response.as_ref()) {
Ok((id, message)) => {
log::trace!("Response #{id}: {message:?}");
if let Some(tx) = self.onflight.remove(&id) {
send_response(tx, Ok(message));
} else {
log::error!("SFTP Error: Received a reply with an invalid id");
}
}
Err(err) => {
log::trace!("Failed to parse message: {response:?}: {err:?}");
if let Some(id) = err.id {
if let Some(tx) = self.onflight.remove(&id) {
send_response(tx, Err(err.into()));
} else {
log::error!("SFTP Error: Received a reply with an invalid id");
}
} else {
log::error!("SFTP Error: Received a bad reply");
}
}
},
StreamItem::Error(err) => {
log::error!("Error while waiting for SFTP response: {err:?}");
match err.kind() {
std::io::ErrorKind::WouldBlock => (),
std::io::ErrorKind::TimedOut => (),
std::io::ErrorKind::WriteZero => (),
std::io::ErrorKind::Interrupted => (),
std::io::ErrorKind::OutOfMemory => (),
_ => break,
}
}
}
}
for (_, tx) in self.onflight {
send_response(
tx,
Err(Error::Sftp(StatusCode::ConnectionLost.to_status(
"Could not receive response: SFTP stream stopped",
))),
);
}
self.commands.close();
if let Err(err) = self.stream.shutdown().await {
log::warn!("Error while closing SSH channel: {err:?}");
}
log::debug!("SFTP client stopped");
}
}
/// Delivers `msg` to the requester waiting on `tx`.
///
/// If the requester has already dropped its end of the channel, the undelivered response is
/// logged and otherwise discarded.
fn send_response(tx: oneshot::Sender<Response>, msg: Response) {
    if let Err(err) = tx.send(msg) {
        log::error!("Could not send back message to client: {err:?}");
    }
}
/// Encodes `msg` under request id `id` and writes the whole frame to `stream`.
///
/// The stream is not flushed here.
///
/// # Errors
/// Fails if the message cannot be encoded or if writing to the stream fails; both errors are
/// converted into the client `Error` type.
pub(super) async fn write_msg(
    stream: &mut (impl AsyncWrite + Unpin),
    msg: Message,
    id: u32,
) -> Result<(), Error> {
    let encoded = msg.encode(id)?;
    stream.write_all(encoded.as_ref()).await?;
    Ok(())
}
/// Reads one length-prefixed SFTP frame from `stream` and decodes it.
///
/// Returns the id carried by the frame together with the decoded message.
///
/// # Errors
/// Fails if the stream errors or ends before a whole frame has been read, or if the frame
/// cannot be decoded.
pub(super) async fn read_msg(
    stream: &mut (impl AsyncRead + Unpin),
) -> Result<(u32, Message), Error> {
    // NOTE(review): the frame length comes from the peer and is allocated in full up front.
    let frame_len = stream.read_u32().await? as usize;
    let mut frame = vec![0u8; frame_len];
    stream.read_exact(&mut frame[..]).await?;
    let decoded = Message::decode_raw(frame.as_slice())?;
    Ok(decoded)
}