use crate::{
cancel_error,
file::{utility::take_io_slices, File},
lowlevel::{AwaitableDataFuture, AwaitableStatusFuture, Handle},
Buffer, Data, Error, Id, WriteEnd,
};
use std::{
borrow::Cow,
cmp::{max, min},
collections::VecDeque,
convert::TryInto,
future::Future,
io::{self, IoSlice},
mem,
num::{NonZeroU32, NonZeroUsize},
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
use bytes::{Buf, Bytes, BytesMut};
use derive_destructure2::destructure;
use pin_project::{pin_project, pinned_drop};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use tokio_io_utility::ready;
use tokio_util::sync::WaitForCancellationFutureOwned;
/// Default size (in bytes) of the internal read-ahead buffer.
///
/// Using a const `match` + `panic!` instead of
/// `unsafe { NonZeroUsize::new_unchecked(4096) }` removes the `unsafe`
/// block entirely; the panic arm is unreachable since 4096 != 0 and is
/// checked at compile time because this is a `const` item.
pub const DEFAULT_BUFLEN: NonZeroUsize = match NonZeroUsize::new(4096) {
    Some(len) => len,
    None => panic!("4096 is non-zero"),
};
fn sftp_to_io_error(sftp_err: Error) -> io::Error {
match sftp_err {
Error::IOError(io_error) => io_error,
sftp_err => io::Error::new(io::ErrorKind::Other, sftp_err),
}
}
/// Issues a single sftp request against `file` at its current offset.
///
/// `f` receives the write end, a request id, the file handle and the file's
/// current offset; whatever awaitable it produces is returned after the
/// background flush task has been woken so the queued request gets sent.
fn send_request<Func, R>(file: &mut File, f: Func) -> Result<R, Error>
where
    Func: FnOnce(&mut WriteEnd, Id, Cow<'_, Handle>, u64) -> Result<R, Error>,
{
    // Snapshot the id and offset first: `get_inner` takes a mutable borrow
    // of `file`, so these reads must happen before it.
    let id = file.inner.get_id_mut();
    let offset = file.offset;

    let (write_end, handle) = file.get_inner();
    let ret = f(write_end, id, handle, offset)?;

    // Wake the flush task so the request does not sit in the queue.
    write_end.get_auxiliary().wakeup_flush_task();

    Ok(ret)
}
/// Wrapper around [`File`] that implements tokio's [`AsyncRead`],
/// [`AsyncBufRead`], [`AsyncSeek`] and [`AsyncWrite`] by buffering reads
/// and tracking in-flight write requests.
#[derive(Debug, destructure)]
#[pin_project(PinnedDrop)]
pub struct TokioCompatFile {
    // Underlying sftp file; also reachable through Deref/DerefMut.
    inner: File,

    // Preferred size of each read-ahead request (see DEFAULT_BUFLEN).
    buffer_len: NonZeroUsize,

    // Read-ahead buffer served out by poll_fill_buf/poll_read.
    buffer: BytesMut,

    // Total bytes across all in-flight write requests, bounded by the
    // auxiliary write limit.
    write_len: usize,

    // In-flight read request, if any; at most one read is outstanding.
    read_future: Option<AwaitableDataFuture<Buffer>>,

    // In-flight write requests, completed FIFO.
    write_futures: VecDeque<WriteFutureElement>,

    // Resolves when the whole sftp session is cancelled, letting pending
    // reads/writes bail out early.
    #[pin]
    cancellation_future: WaitForCancellationFutureOwned,
}
/// One in-flight write request plus the number of bytes it carries, so
/// `TokioCompatFile::write_len` can be decremented when it completes.
#[derive(Debug)]
struct WriteFutureElement {
    // Awaitable status response of the write request.
    future: AwaitableStatusFuture<Buffer>,

    // Byte count submitted with this request.
    write_len: usize,
}
impl TokioCompatFile {
    /// Creates a [`TokioCompatFile`] using [`DEFAULT_BUFLEN`] as the
    /// read-ahead buffer length.
    pub fn new(inner: File) -> Self {
        Self::with_capacity(inner, DEFAULT_BUFLEN)
    }

    /// Creates a [`TokioCompatFile`] with a custom read-ahead buffer length.
    ///
    /// `buffer_len` controls how many bytes each read-ahead request asks
    /// for at a time.
    pub fn with_capacity(inner: File, buffer_len: NonZeroUsize) -> Self {
        Self {
            // Build the owned cancellation future before `inner` is moved
            // into the struct below.
            cancellation_future: inner.get_auxiliary().cancel_token.clone().cancelled_owned(),
            inner,
            buffer: BytesMut::new(),
            buffer_len,
            write_len: 0,
            read_future: None,
            write_futures: VecDeque::new(),
        }
    }

    /// Returns the inner [`File`].
    ///
    /// NOTE(review): `destructure` bypasses the `PinnedDrop` impl, so any
    /// in-flight read/write futures are dropped without awaiting them and
    /// the cleanup in `do_drop` does not run — presumably intentional so
    /// the returned `File` stays open; confirm with callers.
    pub fn into_inner(self) -> File {
        self.destructure().0
    }

    /// Returns the current capacity of the internal read buffer.
    pub fn capacity(&self) -> usize {
        self.buffer.capacity()
    }

    /// Reserves capacity so the read buffer can hold at least `new_cap`
    /// bytes without reallocating.
    pub fn reserve(&mut self, new_cap: usize) {
        let curr_cap = self.capacity();

        if curr_cap < new_cap {
            self.buffer.reserve(new_cap - curr_cap);
        }
    }

    /// Shrinks the read buffer to `new_cap` bytes of capacity.
    ///
    /// NOTE(review): this replaces the buffer with a fresh allocation, so
    /// any bytes currently buffered are discarded — call only when the
    /// buffered data is disposable.
    pub fn shrink_to(&mut self, new_cap: usize) {
        let curr_cap = self.capacity();

        if curr_cap > new_cap {
            self.buffer = BytesMut::with_capacity(new_cap);
        }
    }

    /// Fills the internal read buffer if it is currently empty, issuing a
    /// read request for up to `buffer_len` bytes.
    pub async fn fill_buf(mut self: Pin<&mut Self>) -> Result<(), Error> {
        let this = self.as_mut().project();

        if this.buffer.is_empty() {
            // buffer_len is non-zero, so after clamping to u32::MAX the
            // value is still non-zero and the unwrap cannot fail.
            let buffer_len = this.buffer_len.get().try_into().unwrap_or(u32::MAX);
            let buffer_len = NonZeroU32::new(buffer_len).unwrap();

            self.read_into_buffer(buffer_len).await?;
        }

        Ok(())
    }

    /// Consumes and returns up to `amt` buffered bytes, advancing the file
    /// offset by the number of bytes actually taken.
    pub fn consume_and_return_buffer(&mut self, amt: usize) -> Bytes {
        let buffer = &mut self.buffer;
        let amt = min(amt, buffer.len());
        let bytes = self.buffer.split_to(amt).freeze();

        self.offset += amt as u64;

        bytes
    }

    /// Polls one read request of up to `amt` bytes (clamped to the server's
    /// maximum read length) into the internal buffer.
    ///
    /// At most one read request is outstanding at a time; a second call
    /// while one is in flight keeps polling the existing request.
    pub fn poll_read_into_buffer(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        amt: NonZeroU32,
    ) -> Poll<Result<(), Error>> {
        let this = self.project();

        this.inner.check_for_readable()?;

        let max_read_len = this.inner.max_read_len_impl();
        let amt = min(amt.get(), max_read_len);

        let future = if let Some(future) = this.read_future {
            // A read request is already in flight; keep polling it.
            future
        } else {
            // Carve the tail of the buffer's spare capacity off as the
            // response buffer so the data can later be `unsplit` back on
            // without copying.
            this.buffer.reserve(amt as usize);
            let cap = this.buffer.capacity();
            let buffer = this.buffer.split_off(cap - (amt as usize));

            let future = send_request(this.inner, |write_end, id, handle, offset| {
                write_end.send_read_request(id, handle, offset, amt, Some(buffer))
            })?
            .wait();

            // Stash the future so a Pending poll can resume it next time.
            *this.read_future = Some(future);
            this.read_future
                .as_mut()
                .expect("FileFuture::Data is just assigned to self.future!")
        };

        // Bail out early if the whole sftp session was cancelled.
        if this.cancellation_future.poll(cx).is_ready() {
            return Poll::Ready(Err(cancel_error()));
        }

        let res = ready!(Pin::new(future).poll(cx));

        // The future completed (possibly with an error); clear it so the
        // next call starts a fresh request.
        *this.read_future = None;

        let (id, data) = res?;

        // Return the request id to the cache for reuse.
        this.inner.inner.cache_id_mut(id);

        match data {
            Data::Buffer(buffer) => {
                // A non-zero request should never yield an empty buffer.
                debug_assert!(!buffer.is_empty());
                // The response should never exceed the maximum read length.
                debug_assert!(buffer.len() <= max_read_len as usize);

                // Reattach the filled tail to the read buffer.
                this.buffer.unsplit(buffer);
            }

            // EOF: leave the buffer untouched and report success; callers
            // detect EOF by the buffer staying empty.
            Data::Eof => return Poll::Ready(Ok(())),

            _ => std::unreachable!("Expect Data::Buffer"),
        };

        Poll::Ready(Ok(()))
    }

    /// Awaitable wrapper around [`Self::poll_read_into_buffer`].
    pub async fn read_into_buffer(self: Pin<&mut Self>, amt: NonZeroU32) -> Result<(), Error> {
        // Tiny adapter future so the poll fn can be `.await`ed.
        #[must_use]
        struct ReadIntoBuffer<'a>(Pin<&'a mut TokioCompatFile>, NonZeroU32);

        impl Future for ReadIntoBuffer<'_> {
            type Output = Result<(), Error>;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let amt = self.1;
                self.0.as_mut().poll_read_into_buffer(cx, amt)
            }
        }

        ReadIntoBuffer(self, amt).await
    }

    /// Returns a mutable reference to the underlying [`File`].
    pub fn as_mut_file(self: Pin<&mut Self>) -> &mut File {
        self.project().inner
    }

    /// Triggers flushing of queued requests if this file still has an
    /// unflushed write, and fails early if the session was cancelled.
    fn flush_pending_requests(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Result<(), std::io::Error> {
        let this = self.project();

        // Only flush if a write issued through this file is still queued.
        if this.inner.need_flush {
            if this.inner.auxiliary().get_pending_requests() != 0 {
                this.inner.auxiliary().trigger_flushing();
            }
            this.inner.need_flush = false;
        }

        if this.cancellation_future.poll(cx).is_ready() {
            return Err(sftp_to_io_error(cancel_error()));
        }

        Ok(())
    }

    /// Completes the oldest in-flight write request, if any, returning its
    /// request id to the cache and decrementing `write_len`.
    fn flush_one(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        self.as_mut().flush_pending_requests(cx)?;

        let this = self.project();

        let res = if let Some(element) = this.write_futures.front_mut() {
            let res = ready!(Pin::new(&mut element.future).poll(cx));
            // Whether the write succeeded or failed, its bytes no longer
            // count against the in-flight write limit.
            *this.write_len -= element.write_len;
            res
        } else {
            // No writes in flight, so the in-flight byte count must be 0.
            debug_assert_eq!(*this.write_len, 0);
            return Poll::Ready(Ok(()));
        };

        // Pop only after the future resolved so a Pending poll above
        // leaves the queue untouched.
        this.write_futures
            .pop_front()
            .expect("futures should have at least one elements in it");

        this.inner
            .inner
            .cache_id_mut(res.map_err(sftp_to_io_error)?.0);

        Poll::Ready(Ok(()))
    }
}
impl From<File> for TokioCompatFile {
fn from(inner: File) -> Self {
Self::new(inner)
}
}
impl From<TokioCompatFile> for File {
    /// Extracts the inner [`File`], discarding the read buffer and any
    /// in-flight request futures.
    fn from(compat_file: TokioCompatFile) -> Self {
        compat_file.into_inner()
    }
}
impl Clone for TokioCompatFile {
    /// Clones the underlying file handle, keeping the same configured
    /// buffer length; the clone starts with an empty read buffer and no
    /// in-flight requests.
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        Self::with_capacity(inner, self.buffer_len)
    }
}
impl Deref for TokioCompatFile {
    type Target = File;

    // Expose the underlying File so its inherent methods can be called
    // directly on a TokioCompatFile.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
impl DerefMut for TokioCompatFile {
    // Mutable counterpart of Deref; note this bypasses the read buffer, so
    // mixing direct File operations with buffered reads needs care.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
impl AsyncSeek for TokioCompatFile {
    /// Delegates the seek to the inner [`File`], then reconciles read-ahead
    /// state: any in-flight read is dropped, and the buffer is either
    /// advanced (small forward seek) or cleared.
    fn start_seek(mut self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> {
        let this = self.as_mut().project();

        let prev_offset = this.inner.offset();
        Pin::new(&mut *this.inner).start_seek(position)?;
        let new_offset = this.inner.offset();

        if new_offset != prev_offset {
            // Any in-flight read targets the old offset; discard it.
            *this.read_future = None;

            if new_offset < prev_offset {
                // Seeking backwards: the buffered bytes lie ahead of the
                // new offset, so none of them are reusable.
                this.buffer.clear();
            } else if let Ok(offset) = (new_offset - prev_offset).try_into() {
                if offset > this.buffer.len() {
                    // Seeked past everything we had buffered.
                    this.buffer.clear();
                } else {
                    // Small forward seek: skip within the buffer.
                    this.buffer.advance(offset);
                }
            } else {
                // Forward distance does not fit in usize, so it certainly
                // exceeds the buffer; drop it all.
                this.buffer.clear();
            }
        }

        Ok(())
    }

    // Completion is entirely delegated to the inner File.
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        Pin::new(self.project().inner).poll_complete(cx)
    }
}
impl AsyncBufRead for TokioCompatFile {
    /// Returns the buffered bytes, issuing one read request of up to
    /// `buffer_len` bytes first if the buffer is empty.
    ///
    /// An empty returned slice means EOF.
    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this = self.as_mut().project();

        if this.buffer.is_empty() {
            // buffer_len is non-zero, so after clamping to u32::MAX the
            // value is still non-zero and the unwrap cannot fail.
            let buffer_len = this.buffer_len.get().try_into().unwrap_or(u32::MAX);
            let buffer_len = NonZeroU32::new(buffer_len).unwrap();

            ready!(self.as_mut().poll_read_into_buffer(cx, buffer_len))
                .map_err(sftp_to_io_error)?;
        }

        Poll::Ready(Ok(self.project().buffer))
    }

    /// Drops `amt` bytes from the front of the buffer and advances the
    /// file offset to match.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.project();

        let buffer = this.buffer;

        buffer.advance(amt);
        this.inner.offset += amt as u64;
    }
}
impl AsyncRead for TokioCompatFile {
    /// Serves reads from the internal buffer, issuing a read-ahead request
    /// when the buffer is empty.
    ///
    /// A return of `Ok(())` with nothing written into `read_buf` means EOF.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        read_buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        self.check_for_readable_io_err()?;

        let remaining = read_buf.remaining();
        if remaining == 0 {
            return Poll::Ready(Ok(()));
        }

        if self.buffer.is_empty() {
            // Request at least the CONFIGURED buffer length, not the
            // DEFAULT_BUFLEN constant: `fill_buf` and `poll_fill_buf`
            // already honour `with_capacity`'s `buffer_len`, and ignoring
            // it here silently discarded the user's configuration.
            let n = max(remaining, self.buffer_len.get());
            let n = n.try_into().unwrap_or(u32::MAX);
            // n >= remaining > 0, so the unwrap cannot fail.
            let n = NonZeroU32::new(n).unwrap();

            ready!(self.as_mut().poll_read_into_buffer(cx, n)).map_err(sftp_to_io_error)?;
        }

        // Copy out as much as fits; if the buffer is still empty after the
        // read, that was EOF and read_buf stays untouched.
        let n = min(remaining, self.buffer.len());
        read_buf.put_slice(&self.buffer[..n]);
        self.consume(n);

        Poll::Ready(Ok(()))
    }
}
impl AsyncWrite for TokioCompatFile {
    /// Queues one write request of up to `max_write_len` bytes and returns
    /// immediately; completion is only observed via `poll_flush`.
    ///
    /// If the in-flight byte count has reached the configured write limit,
    /// the oldest pending write is completed first.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.check_for_writable_io_err()?;

        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        // Clamp to what a single sftp write request may carry.
        let max_write_len = self.max_write_len_impl();

        let mut n: u32 = buf
            .len()
            .try_into()
            .map(|n| min(n, max_write_len))
            .unwrap_or(max_write_len);

        // Upper bound on total bytes in flight across queued writes.
        let write_limit = self.get_auxiliary().tokio_compat_file_write_limit();

        let mut write_len = self.write_len;
        if write_len == write_limit {
            // At the limit: wait for the oldest pending write to finish.
            ready!(self.as_mut().flush_one(cx))?;
            write_len = self.write_len;
        }

        // Shrink n if accepting all of it would exceed the write limit;
        // the None arm handles usize overflow of the sum the same way.
        let new_write_len = match write_len.checked_add(n as usize) {
            Some(new_write_len) if new_write_len > write_limit => {
                n = (write_limit - write_len).try_into().unwrap();
                write_limit
            }
            None => {
                n = (write_limit - write_len).try_into().unwrap();
                write_limit
            }
            Some(new_write_len) => new_write_len,
        };

        let buf = &buf[..(n as usize)];

        let this = self.as_mut().project();
        let file = this.inner;

        let future = send_request(file, |write_end, id, handle, offset| {
            write_end.send_write_request_buffered(id, handle, offset, Cow::Borrowed(buf))
        })
        .map_err(sftp_to_io_error)?
        .wait();

        // Mark that the queued request still needs flushing before its
        // completion can be awaited.
        file.need_flush = true;
        this.write_futures.push_back(WriteFutureElement {
            future,
            write_len: n as usize,
        });

        *self.as_mut().project().write_len = new_write_len;

        // Advance the file offset past the queued bytes; reusing the seek
        // logic also keeps the read buffer consistent.
        Poll::Ready(
            self.start_seek(io::SeekFrom::Current(n as i64))
                .map(|_| n as usize),
        )
    }

    /// Drives every queued write request to completion, oldest first.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.check_for_writable_io_err()?;

        if self.as_mut().project().write_futures.is_empty() {
            return Poll::Ready(Ok(()));
        }

        self.as_mut().flush_pending_requests(cx)?;

        let this = self.project();

        loop {
            let res = if let Some(element) = this.write_futures.front_mut() {
                let res = ready!(Pin::new(&mut element.future).poll(cx));
                // Completed (or failed) writes no longer count against the
                // in-flight write limit.
                *this.write_len -= element.write_len;
                res
            } else {
                // Queue drained: no bytes can remain in flight.
                debug_assert_eq!(*this.write_len, 0);
                break Poll::Ready(Ok(()));
            };

            // Pop only after the future resolved so a Pending poll above
            // leaves the queue untouched.
            this.write_futures
                .pop_front()
                .expect("futures should have at least one elements in it");

            this.inner
                .inner
                .cache_id_mut(res.map_err(sftp_to_io_error)?.0);
        }
    }

    // There is no shutdown handshake for an sftp file; flushing pending
    // writes is all shutdown needs to do.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }

    /// Vectored variant of `poll_write`: queues one write request covering
    /// as many bytes of `bufs` as fit within `max_write_len` and the
    /// in-flight write limit.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.check_for_writable_io_err()?;

        if bufs.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let max_write_len = self.max_write_len_impl();

        // How many bytes of bufs fit into one write request.
        let n = if let Some(res) = take_io_slices(bufs, max_write_len as usize) {
            res.0
        } else {
            return Poll::Ready(Ok(0));
        };

        // n <= max_write_len (a u32), so the conversion cannot fail.
        let mut n: u32 = n.try_into().unwrap();

        let write_limit = self.get_auxiliary().tokio_compat_file_write_limit();

        let mut write_len = self.write_len;
        if write_len == write_limit {
            // At the limit: wait for the oldest pending write to finish.
            ready!(self.as_mut().flush_one(cx))?;
            write_len = self.write_len;
        }

        // Shrink n if accepting all of it would exceed the write limit;
        // the None arm handles usize overflow of the sum the same way.
        let new_write_len = match write_len.checked_add(n as usize) {
            Some(new_write_len) if new_write_len > write_limit => {
                n = (write_limit - write_len).try_into().unwrap();
                write_limit
            }
            None => {
                n = (write_limit - write_len).try_into().unwrap();
                write_limit
            }
            Some(new_write_len) => new_write_len,
        };

        // Re-split bufs at the (possibly reduced) byte count.
        let (_, bufs, buf) = take_io_slices(bufs, n as usize).unwrap();

        let buffers = [bufs, &buf];

        let this = self.as_mut().project();
        let file = this.inner;

        let future = send_request(file, |write_end, id, handle, offset| {
            write_end.send_write_request_buffered_vectored2(id, handle, offset, &buffers)
        })
        .map_err(sftp_to_io_error)?
        .wait();

        // Mark that the queued request still needs flushing before its
        // completion can be awaited.
        file.need_flush = true;
        this.write_futures.push_back(WriteFutureElement {
            future,
            write_len: n as usize,
        });

        *self.as_mut().project().write_len = new_write_len;

        // Advance the file offset past the queued bytes.
        Poll::Ready(
            self.start_seek(io::SeekFrom::Current(n as i64))
                .map(|_| n as usize),
        )
    }

    // poll_write_vectored genuinely coalesces multiple buffers into one
    // request, so advertise it.
    fn is_write_vectored(&self) -> bool {
        true
    }
}
impl TokioCompatFile {
    /// Best-effort async cleanup spawned from the `Drop` impl: awaits any
    /// in-flight read/write requests (recycling their request ids on
    /// success) and then closes the file handle.
    ///
    /// Errors are only logged (when the `tracing` feature is on) — there is
    /// no caller left to report them to.
    async fn do_drop(
        mut file: File,
        read_future: Option<AwaitableDataFuture<Buffer>>,
        write_futures: VecDeque<WriteFutureElement>,
    ) {
        if let Some(read_future) = read_future {
            if let Ok((id, _)) = read_future.await {
                // The data itself is discarded; only the id is recycled.
                file.inner.cache_id_mut(id);
            }
        }
        for write_element in write_futures {
            match write_element.future.await {
                Ok((id, _)) => file.inner.cache_id_mut(id),
                Err(_err) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!(?_err, "failed to write to File")
                }
            }
        }
        if let Err(_err) = file.close().await {
            #[cfg(feature = "tracing")]
            tracing::error!(?_err, "failed to close handle");
        }
    }
}
#[pinned_drop]
impl PinnedDrop for TokioCompatFile {
    /// `Drop` cannot be async, so the cleanup in
    /// [`TokioCompatFile::do_drop`] is spawned onto the tokio runtime,
    /// racing against session cancellation (`biased` prefers the
    /// cancellation branch so a cancelled session aborts cleanup).
    fn drop(mut self: Pin<&mut Self>) {
        let this = self.as_mut().project();

        // Move the pending futures out of self; the fields left behind
        // are inert and safe to drop.
        let file = this.inner.clone();
        let read_future = this.read_future.take();
        let write_futures = mem::take(this.write_futures);

        let cancellation_fut = self.auxiliary().cancel_token.clone().cancelled_owned();
        let do_drop_fut = Self::do_drop(file, read_future, write_futures);

        self.auxiliary().tokio_handle().spawn(async move {
            tokio::select! {
                biased;

                _ = cancellation_fut => (),
                _ = do_drop_fut => (),
            }
        });
    }
}