// openssh_sftp_client/file/tokio_compat_file.rs

use crate::{
    cancel_error,
    file::{utility::take_io_slices, File},
    lowlevel::{AwaitableDataFuture, AwaitableStatusFuture, Handle},
    Buffer, Data, Error, Id, WriteEnd,
};

use std::{
    borrow::Cow,
    cmp::{max, min},
    collections::VecDeque,
    convert::TryInto,
    future::Future,
    io::{self, IoSlice},
    mem,
    num::{NonZeroU32, NonZeroUsize},
    ops::{Deref, DerefMut},
    pin::Pin,
    task::{Context, Poll},
};

use bytes::{Buf, Bytes, BytesMut};
use derive_destructure2::destructure;
use pin_project::{pin_project, pinned_drop};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use tokio_io_utility::ready;
use tokio_util::sync::WaitForCancellationFutureOwned;

/// Default length of the internal buffer used by [`TokioCompatFile`].
pub const DEFAULT_BUFLEN: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(4096) };

fn sftp_to_io_error(sftp_err: Error) -> io::Error {
    match sftp_err {
        Error::IOError(io_error) => io_error,
        sftp_err => io::Error::new(io::ErrorKind::Other, sftp_err),
    }
}

fn send_request<Func, R>(file: &mut File, f: Func) -> Result<R, Error>
where
    Func: FnOnce(&mut WriteEnd, Id, Cow<'_, Handle>, u64) -> Result<R, Error>,
{
    let id = file.inner.get_id_mut();
    let offset = file.offset;

    let (write_end, handle) = file.get_inner();

    let awaitable = f(write_end, id, handle, offset)?;

    // Wake up the flush task so that the queued request actually gets sent.
    write_end.get_auxiliary().wakeup_flush_task();

    Ok(awaitable)
}

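/// Wrapper around [`File`] that implements tokio's [`AsyncRead`],
/// [`AsyncBufRead`], [`AsyncSeek`] and [`AsyncWrite`] by buffering reads and
/// queuing write requests.
///
/// The example below is an illustrative sketch only: it assumes an
/// already-established `Sftp` session named `sftp` (see the crate-level
/// documentation for how to create one) and a remote path `"some-file"`.
///
/// ```ignore
/// use tokio::io::AsyncReadExt;
///
/// let file = sftp.open("some-file").await?;
///
/// // Pin the adapter (e.g. with `Box::pin`) before using the
/// // `AsyncReadExt`/`AsyncWriteExt` extension methods, since it may
/// // not be `Unpin`.
/// let mut file = Box::pin(TokioCompatFile::new(file));
///
/// let mut contents = Vec::new();
/// file.read_to_end(&mut contents).await?;
/// ```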
#[derive(Debug, destructure)]
#[pin_project(PinnedDrop)]
pub struct TokioCompatFile {
    inner: File,

    buffer_len: NonZeroUsize,
    buffer: BytesMut,

    write_len: usize,

    read_future: Option<AwaitableDataFuture<Buffer>>,
    write_futures: VecDeque<WriteFutureElement>,

    #[pin]
    cancellation_future: WaitForCancellationFutureOwned,
}

#[derive(Debug)]
struct WriteFutureElement {
    future: AwaitableStatusFuture<Buffer>,
    write_len: usize,
}

impl TokioCompatFile {
    /// Create a [`TokioCompatFile`] using [`DEFAULT_BUFLEN`].
    pub fn new(inner: File) -> Self {
        Self::with_capacity(inner, DEFAULT_BUFLEN)
    }

    /// Create a [`TokioCompatFile`] with the given buffer length.
    pub fn with_capacity(inner: File, buffer_len: NonZeroUsize) -> Self {
        Self {
            cancellation_future: inner.get_auxiliary().cancel_token.clone().cancelled_owned(),

            inner,

            buffer: BytesMut::new(),
            buffer_len,

            write_len: 0,

            read_future: None,
            write_futures: VecDeque::new(),
        }
    }

    /// Return the inner [`File`].
    pub fn into_inner(self) -> File {
        self.destructure().0
    }

    /// Return the capacity of the internal buffer.
    pub fn capacity(&self) -> usize {
        self.buffer.capacity()
    }

    /// Reserve capacity so that the internal buffer can hold at least
    /// `new_cap` bytes.
    pub fn reserve(&mut self, new_cap: usize) {
        let curr_cap = self.capacity();

        if curr_cap < new_cap {
            self.buffer.reserve(new_cap - curr_cap);
        }
    }

    /// Shrink the capacity of the internal buffer to `new_cap` bytes.
    pub fn shrink_to(&mut self, new_cap: usize) {
        let curr_cap = self.capacity();

        if curr_cap > new_cap {
            self.buffer = BytesMut::with_capacity(new_cap);
        }
    }

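    /// Fill the internal buffer with one read request if it is currently
    /// empty, so that a subsequent read can be served from memory.
    ///
    /// A hedged usage sketch, assuming `file` is a readable
    /// `TokioCompatFile` that has already been pinned with `tokio::pin!`:
    ///
    /// ```ignore
    /// use tokio::io::AsyncReadExt;
    ///
    /// // Pre-fill the buffer; the `read` below is then served from the
    /// // buffer without another request round-trip.
    /// file.as_mut().fill_buf().await?;
    ///
    /// let mut buf = [0u8; 64];
    /// let n = file.read(&mut buf).await?;
    /// ```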
    pub async fn fill_buf(mut self: Pin<&mut Self>) -> Result<(), Error> {
        let this = self.as_mut().project();

        if this.buffer.is_empty() {
            let buffer_len = this.buffer_len.get().try_into().unwrap_or(u32::MAX);
            let buffer_len = NonZeroU32::new(buffer_len).unwrap();

            self.read_into_buffer(buffer_len).await?;
        }

        Ok(())
    }

    /// Consume and return at most `amt` bytes from the internal buffer as
    /// [`Bytes`], advancing the file offset accordingly.
    pub fn consume_and_return_buffer(&mut self, amt: usize) -> Bytes {
        let buffer = &mut self.buffer;
        let amt = min(amt, buffer.len());
        let bytes = self.buffer.split_to(amt).freeze();

        self.offset += amt as u64;

        bytes
    }

    /// Poll to read at most `amt` bytes into the internal buffer.
    ///
    /// This is the low-level primitive behind [`TokioCompatFile::fill_buf`]
    /// and the [`AsyncBufRead`]/[`AsyncRead`] implementations.
    pub fn poll_read_into_buffer(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        amt: NonZeroU32,
    ) -> Poll<Result<(), Error>> {
        let this = self.project();

        this.inner.check_for_readable()?;

        let max_read_len = this.inner.max_read_len_impl();
        let amt = min(amt.get(), max_read_len);

        let future = if let Some(future) = this.read_future {
            // A read request is already in flight, keep polling it.
            future
        } else {
            // Reserve unused capacity at the end of the buffer for the
            // read request, then send it.
            this.buffer.reserve(amt as usize);
            let cap = this.buffer.capacity();
            let buffer = this.buffer.split_off(cap - (amt as usize));

            let future = send_request(this.inner, |write_end, id, handle, offset| {
                write_end.send_read_request(id, handle, offset, amt, Some(buffer))
            })?
            .wait();

            *this.read_future = Some(future);
            this.read_future
                .as_mut()
                .expect("FileFuture::Data is just assigned to self.future!")
        };

        if this.cancellation_future.poll(cx).is_ready() {
            return Poll::Ready(Err(cancel_error()));
        }

        let res = ready!(Pin::new(future).poll(cx));
        *this.read_future = None;
        let (id, data) = res?;

        this.inner.inner.cache_id_mut(id);
        match data {
            Data::Buffer(buffer) => {
                debug_assert!(!buffer.is_empty());
                debug_assert!(buffer.len() <= max_read_len as usize);

                this.buffer.unsplit(buffer);
            }
            Data::Eof => return Poll::Ready(Ok(())),
            _ => std::unreachable!("Expect Data::Buffer"),
        };

        Poll::Ready(Ok(()))
    }

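    /// Read at most `amt` bytes into the internal buffer.
    ///
    /// A hedged sketch, assuming `file` is a readable `TokioCompatFile`
    /// that has been pinned with `tokio::pin!`:
    ///
    /// ```ignore
    /// use std::num::NonZeroU32;
    ///
    /// // Issue a single read request for up to 4096 bytes; the data ends
    /// // up in the internal buffer and is consumed by later reads.
    /// file.as_mut()
    ///     .read_into_buffer(NonZeroU32::new(4096).unwrap())
    ///     .await?;
    /// ```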
    pub async fn read_into_buffer(self: Pin<&mut Self>, amt: NonZeroU32) -> Result<(), Error> {
        #[must_use]
        struct ReadIntoBuffer<'a>(Pin<&'a mut TokioCompatFile>, NonZeroU32);

        impl Future for ReadIntoBuffer<'_> {
            type Output = Result<(), Error>;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let amt = self.1;
                self.0.as_mut().poll_read_into_buffer(cx, amt)
            }
        }

        ReadIntoBuffer(self, amt).await
    }

    /// Return a mutable reference to the underlying [`File`].
    pub fn as_mut_file(self: Pin<&mut Self>) -> &mut File {
        self.project().inner
    }

    fn flush_pending_requests(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Result<(), std::io::Error> {
        let this = self.project();

        if this.inner.need_flush {
            // Only trigger flushing if there are pending requests.
            if this.inner.auxiliary().get_pending_requests() != 0 {
                this.inner.auxiliary().trigger_flushing();
            }
            this.inner.need_flush = false;
        }

        if this.cancellation_future.poll(cx).is_ready() {
            return Err(sftp_to_io_error(cancel_error()));
        }

        Ok(())
    }

    fn flush_one(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        self.as_mut().flush_pending_requests(cx)?;

        let this = self.project();

        let res = if let Some(element) = this.write_futures.front_mut() {
            let res = ready!(Pin::new(&mut element.future).poll(cx));
            *this.write_len -= element.write_len;
            res
        } else {
            // No pending writes left.
            debug_assert_eq!(*this.write_len, 0);
            return Poll::Ready(Ok(()));
        };

        this.write_futures
            .pop_front()
            .expect("futures should have at least one element in it");

        this.inner
            .inner
            .cache_id_mut(res.map_err(sftp_to_io_error)?.0);

        Poll::Ready(Ok(()))
    }
}

impl From<File> for TokioCompatFile {
    fn from(inner: File) -> Self {
        Self::new(inner)
    }
}

impl From<TokioCompatFile> for File {
    fn from(file: TokioCompatFile) -> Self {
        file.into_inner()
    }
}

impl Clone for TokioCompatFile {
    /// Clone the underlying [`File`]; the clone starts with its own,
    /// empty internal buffer.
    fn clone(&self) -> Self {
        Self::with_capacity(self.inner.clone(), self.buffer_len)
    }
}

impl Deref for TokioCompatFile {
    type Target = File;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl DerefMut for TokioCompatFile {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}

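/// Seeking adjusts the file offset used for subsequent reads and writes,
/// reusing already-buffered data when seeking forward within the buffer.
///
/// A hedged sketch, assuming `file` is a pinned `TokioCompatFile`:
///
/// ```ignore
/// use std::io::SeekFrom;
/// use tokio::io::AsyncSeekExt;
///
/// // Skip the first 512 bytes of the remote file.
/// file.seek(SeekFrom::Start(512)).await?;
/// ```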
impl AsyncSeek for TokioCompatFile {
    fn start_seek(mut self: Pin<&mut Self>, position: io::SeekFrom) -> io::Result<()> {
        let this = self.as_mut().project();

        let prev_offset = this.inner.offset();
        Pin::new(&mut *this.inner).start_seek(position)?;
        let new_offset = this.inner.offset();

        if new_offset != prev_offset {
            // Discard any pending read, it is now at the wrong offset.
            *this.read_future = None;

            if new_offset < prev_offset {
                this.buffer.clear();
            } else if let Ok(offset) = (new_offset - prev_offset).try_into() {
                if offset > this.buffer.len() {
                    this.buffer.clear();
                } else {
                    // The seek target is still inside the buffer, just
                    // skip over the buffered bytes.
                    this.buffer.advance(offset);
                }
            } else {
                this.buffer.clear();
            }
        }

        Ok(())
    }

    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        Pin::new(self.project().inner).poll_complete(cx)
    }
}

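/// The buffered-read interface exposes the internal buffer directly, which
/// avoids copying when the caller can work with `&[u8]`.
///
/// A hedged sketch, assuming `file` is a pinned, readable `TokioCompatFile`
/// containing UTF-8 text:
///
/// ```ignore
/// use tokio::io::AsyncBufReadExt;
///
/// let mut line = String::new();
/// // Reads from the internal buffer, refilling it as needed.
/// file.read_line(&mut line).await?;
/// ```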
impl AsyncBufRead for TokioCompatFile {
    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this = self.as_mut().project();

        if this.buffer.is_empty() {
            let buffer_len = this.buffer_len.get().try_into().unwrap_or(u32::MAX);
            let buffer_len = NonZeroU32::new(buffer_len).unwrap();

            ready!(self.as_mut().poll_read_into_buffer(cx, buffer_len))
                .map_err(sftp_to_io_error)?;
        }

        Poll::Ready(Ok(self.project().buffer))
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.project();

        let buffer = this.buffer;

        buffer.advance(amt);
        this.inner.offset += amt as u64;
    }
}

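/// Reads are served from the internal buffer; when it is empty, a read
/// request of `max(remaining, DEFAULT_BUFLEN)` bytes (capped internally by
/// the maximum read length) is issued first.
///
/// A hedged sketch, assuming `file` is a pinned, readable `TokioCompatFile`:
///
/// ```ignore
/// use tokio::io::AsyncReadExt;
///
/// let mut header = [0u8; 16];
/// file.read_exact(&mut header).await?;
/// ```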
impl AsyncRead for TokioCompatFile {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        read_buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        self.check_for_readable_io_err()?;

        let remaining = read_buf.remaining();
        if remaining == 0 {
            return Poll::Ready(Ok(()));
        }

        if self.buffer.is_empty() {
            let n = max(remaining, DEFAULT_BUFLEN.get());
            let n = n.try_into().unwrap_or(u32::MAX);
            let n = NonZeroU32::new(n).unwrap();

            ready!(self.as_mut().poll_read_into_buffer(cx, n)).map_err(sftp_to_io_error)?;
        }

        let n = min(remaining, self.buffer.len());
        read_buf.put_slice(&self.buffer[..n]);
        self.consume(n);

        Poll::Ready(Ok(()))
    }
}

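/// Writes are queued as SFTP write requests and tracked internally;
/// `poll_flush` waits for all of them to complete, so flush (or shutdown)
/// before dropping the file if write errors need to be observed.
///
/// A hedged sketch, assuming `file` is a pinned, writable `TokioCompatFile`:
///
/// ```ignore
/// use tokio::io::AsyncWriteExt;
///
/// file.write_all(b"hello world").await?;
/// // Wait for the queued write requests to complete.
/// file.flush().await?;
/// ```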
impl AsyncWrite for TokioCompatFile {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.check_for_writable_io_err()?;

        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let max_write_len = self.max_write_len_impl();

        let mut n: u32 = buf
            .len()
            .try_into()
            .map(|n| min(n, max_write_len))
            .unwrap_or(max_write_len);

        let write_limit = self.get_auxiliary().tokio_compat_file_write_limit();
        let mut write_len = self.write_len;

        if write_len == write_limit {
            // The write limit is reached, complete one pending write
            // before queuing a new one.
            ready!(self.as_mut().flush_one(cx))?;
            write_len = self.write_len;
        }

        let new_write_len = match write_len.checked_add(n as usize) {
            Some(new_write_len) if new_write_len > write_limit => {
                n = (write_limit - write_len).try_into().unwrap();
                write_limit
            }
            None => {
                n = (write_limit - write_len).try_into().unwrap();
                write_limit
            }
            Some(new_write_len) => new_write_len,
        };

        let buf = &buf[..(n as usize)];

        let this = self.as_mut().project();

        let file = this.inner;

        let future = send_request(file, |write_end, id, handle, offset| {
            write_end.send_write_request_buffered(id, handle, offset, Cow::Borrowed(buf))
        })
        .map_err(sftp_to_io_error)?
        .wait();

        file.need_flush = true;

        this.write_futures.push_back(WriteFutureElement {
            future,
            write_len: n as usize,
        });

        *self.as_mut().project().write_len = new_write_len;

        // Advance the file offset so that the next write starts right
        // after this one.
        Poll::Ready(
            self.start_seek(io::SeekFrom::Current(n as i64))
                .map(|_| n as usize),
        )
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.check_for_writable_io_err()?;

        if self.as_mut().project().write_futures.is_empty() {
            return Poll::Ready(Ok(()));
        }

        self.as_mut().flush_pending_requests(cx)?;

        let this = self.project();

        loop {
            let res = if let Some(element) = this.write_futures.front_mut() {
                let res = ready!(Pin::new(&mut element.future).poll(cx));
                *this.write_len -= element.write_len;
                res
            } else {
                // All pending writes have completed.
                debug_assert_eq!(*this.write_len, 0);
                break Poll::Ready(Ok(()));
            };

            this.write_futures
                .pop_front()
                .expect("futures should have at least one element in it");

            this.inner
                .inner
                .cache_id_mut(res.map_err(sftp_to_io_error)?.0);
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.check_for_writable_io_err()?;

        if bufs.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let max_write_len = self.max_write_len_impl();

        let n = if let Some(res) = take_io_slices(bufs, max_write_len as usize) {
            res.0
        } else {
            return Poll::Ready(Ok(0));
        };

        let mut n: u32 = n.try_into().unwrap();

        let write_limit = self.get_auxiliary().tokio_compat_file_write_limit();
        let mut write_len = self.write_len;

        if write_len == write_limit {
            ready!(self.as_mut().flush_one(cx))?;
            write_len = self.write_len;
        }

        let new_write_len = match write_len.checked_add(n as usize) {
            Some(new_write_len) if new_write_len > write_limit => {
                n = (write_limit - write_len).try_into().unwrap();
                write_limit
            }
            None => {
                n = (write_limit - write_len).try_into().unwrap();
                write_limit
            }
            Some(new_write_len) => new_write_len,
        };

        let (_, bufs, buf) = take_io_slices(bufs, n as usize).unwrap();

        let buffers = [bufs, &buf];

        let this = self.as_mut().project();

        let file = this.inner;

        let future = send_request(file, |write_end, id, handle, offset| {
            write_end.send_write_request_buffered_vectored2(id, handle, offset, &buffers)
        })
        .map_err(sftp_to_io_error)?
        .wait();

        file.need_flush = true;

        this.write_futures.push_back(WriteFutureElement {
            future,
            write_len: n as usize,
        });

        *self.as_mut().project().write_len = new_write_len;

        Poll::Ready(
            self.start_seek(io::SeekFrom::Current(n as i64))
                .map(|_| n as usize),
        )
    }

    fn is_write_vectored(&self) -> bool {
        true
    }
}

impl TokioCompatFile {
    async fn do_drop(
        mut file: File,
        read_future: Option<AwaitableDataFuture<Buffer>>,
        write_futures: VecDeque<WriteFutureElement>,
    ) {
        if let Some(read_future) = read_future {
            if let Ok((id, _)) = read_future.await {
                file.inner.cache_id_mut(id);
            }
        }
        for write_element in write_futures {
            match write_element.future.await {
                Ok((id, _)) => file.inner.cache_id_mut(id),
                Err(_err) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!(?_err, "failed to write to File")
                }
            }
        }
        if let Err(_err) = file.close().await {
            #[cfg(feature = "tracing")]
            tracing::error!(?_err, "failed to close handle");
        }
    }
}

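// Dropping a `TokioCompatFile` does not block: any in-flight read/write
// futures and the handle close are moved into a spawned background task
// (see `do_drop` above), where errors are only logged (with the `tracing`
// feature) or discarded. A hedged sketch of shutting down explicitly
// instead, assuming `file` is a pinned, writable `TokioCompatFile`:
//
//     use tokio::io::AsyncWriteExt;
//
//     file.shutdown().await?; // flush queued writes, surfacing errors
//     drop(file);             // background task then only closes the handle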
#[pinned_drop]
impl PinnedDrop for TokioCompatFile {
    fn drop(mut self: Pin<&mut Self>) {
        let this = self.as_mut().project();

        let file = this.inner.clone();
        let read_future = this.read_future.take();
        let write_futures = mem::take(this.write_futures);

        let cancellation_fut = self.auxiliary().cancel_token.clone().cancelled_owned();

        let do_drop_fut = Self::do_drop(file, read_future, write_futures);

        self.auxiliary().tokio_handle().spawn(async move {
            tokio::select! {
                biased;

                _ = cancellation_fut => (),
                _ = do_drop_fut => (),
            }
        });
    }
}