use crate::{
auxiliary,
file::{File, OpenOptions},
fs::Fs,
lowlevel, tasks, Error, MpscQueue, SftpOptions, SharedData, WriteEnd, WriteEndWithCachedId,
};
use auxiliary::Auxiliary;
use lowlevel::{connect, Extensions};
use tasks::{create_flush_task, create_read_task};
use std::{cmp::min, convert::TryInto, path::Path, sync::atomic::Ordering};
use derive_destructure2::destructure;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::oneshot::Receiver,
task::JoinHandle,
};
use tokio_io_utility::assert_send;
/// High-level SFTP client.
///
/// Owns the connection state shared with the two background tokio tasks
/// (one flushing requests out, one reading responses in) plus the
/// `JoinHandle`s of those tasks.
///
/// The `destructure` derive (from `derive_destructure2`) allows
/// [`Sftp::close`] to move the fields out without running the `Drop`
/// impl, which would otherwise call `order_shutdown` again.
#[derive(Debug, destructure)]
pub struct Sftp {
// Connection state shared with the background tasks.
shared_data: SharedData,
// Task created by `create_flush_task`; writes buffered requests to the
// transport's write half (`stdin` passed to `Sftp::new`).
flush_task: JoinHandle<Result<(), Error>>,
// Task created by `create_read_task`; reads responses from the
// transport's read half (`stdout` passed to `Sftp::new`).
read_task: JoinHandle<Result<(), Error>>,
}
impl Sftp {
/// Create an [`Sftp`] over the given transport halves.
///
/// `stdin` is the write half (requests flow out through it) and
/// `stdout` the read half (responses flow in). The whole future is
/// wrapped in `assert_send` to guarantee at compile time that it is
/// `Send`.
pub async fn new<W: AsyncWrite + Send + 'static, R: AsyncRead + Send + 'static>(
    stdin: W,
    stdout: R,
    options: SftpOptions,
) -> Result<Self, Error> {
    assert_send(async move {
        let buffer_size = options.get_write_end_buffer_size();

        // The write end must exist before either background task can
        // be spawned, since both hold a clone of its shared data.
        let write_end = assert_send(Self::connect(
            buffer_size.get(),
            options.get_max_pending_requests(),
        ))
        .await?;

        // Task that flushes buffered requests into `stdin`.
        let flush_task = create_flush_task(
            stdin,
            SharedData::clone(&write_end),
            buffer_size,
            options.get_flush_interval(),
        );

        // Task that reads responses from `stdout`; `rx` resolves with
        // the server's `Extensions` once they are known.
        let (rx, read_task) = create_read_task(
            stdout,
            options.get_read_end_buffer_size(),
            SharedData::clone(&write_end),
        );

        Self::init(flush_task, read_task, write_end, rx, &options).await
    })
    .await
}
/// Build the request queue and shared bookkeeping state, then run the
/// low-level connect routine to obtain a [`WriteEnd`].
async fn connect(
    write_end_buffer_size: usize,
    max_pending_requests: u16,
) -> Result<WriteEnd, Error> {
    let queue = MpscQueue::with_capacity(write_end_buffer_size);
    let auxiliary = Auxiliary::new(max_pending_requests);

    connect(queue, auxiliary).await
}
/// Wait for the server's extensions, then negotiate and store the
/// connection limits via [`Sftp::set_limits`].
///
/// On failure of a background task, `close` is awaited so the task's
/// error propagates through `?`; reaching the `unreachable!` afterwards
/// would mean both tasks finished cleanly despite the failure signal.
async fn init(
flush_task: JoinHandle<Result<(), Error>>,
read_task: JoinHandle<Result<(), Error>>,
write_end: WriteEnd,
rx: Receiver<Extensions>,
options: &SftpOptions,
) -> Result<Self, Error> {
// Construct `Self` up front so `close` can be used for cleanup on
// every failure path below.
let sftp = Self {
shared_data: SharedData::clone(&write_end),
flush_task,
read_task,
};
let extensions = if let Ok(extensions) = rx.await {
extensions
} else {
// The sender was dropped without sending, so the read task must
// have terminated early. Drop our write end first, then let
// `close` join the tasks and surface their error via `?`.
drop(write_end);
sftp.close().await?;
std::unreachable!("Error must have occurred in either read_task or flush_task")
};
match sftp.set_limits(write_end, options, extensions).await {
Err(Error::BackgroundTaskFailure(_)) => {
// A background task died while negotiating limits; `close`
// joins the tasks to retrieve the underlying error.
sftp.close().await?;
std::unreachable!("Error must have occurred in either read_task or flush_task")
}
// Any other error propagates; `Ok` falls through.
res => res?,
}
Ok(sftp)
}
/// Query the server's limits extension (when advertised) and store the
/// resulting read/write limits in the connection's auxiliary data.
///
/// Falls back to openssh-portable's default buffer lengths when the
/// extension is unavailable or reports `0`, and honours the caller's
/// `max_read_len` / `max_write_len` overrides from [`SftpOptions`].
///
/// # Panics
///
/// Panics if `conn_info` was already initialized (checked invariant).
async fn set_limits(
    &self,
    write_end: WriteEnd,
    options: &SftpOptions,
    extensions: Extensions,
) -> Result<(), Error> {
    let mut write_end = WriteEndWithCachedId::new(self, write_end);

    let default_download_buflen = lowlevel::OPENSSH_PORTABLE_DEFAULT_DOWNLOAD_BUFLEN as u64;
    let default_upload_buflen = lowlevel::OPENSSH_PORTABLE_DEFAULT_UPLOAD_BUFLEN as u64;

    // Largest packet length we are willing to assume when the server
    // does not provide a usable one (presumably u32::MAX minus 9 bytes
    // of packet header — TODO confirm against the protocol spec).
    let default_max_packet_len = u32::MAX - 9;

    let (read_len, write_len, packet_len) = if extensions.contains(Extensions::LIMITS) {
        let mut limits = write_end
            .send_request(|write_end, id| Ok(write_end.send_limits_request(id)?.wait()))
            .await?;

        // A reported 0 means "unlimited"; substitute the defaults so
        // the rest of the math has concrete values to work with.
        if limits.read_len == 0 {
            limits.read_len = default_download_buflen;
        }

        if limits.write_len == 0 {
            limits.write_len = default_upload_buflen;
        }

        (
            limits.read_len,
            limits.write_len,
            limits
                .packet_len
                .try_into()
                .unwrap_or(default_max_packet_len),
        )
    } else {
        (
            default_download_buflen,
            default_upload_buflen,
            default_max_packet_len,
        )
    };

    // If the server's u64 value does not fit in u32, cap it just below
    // packet_len, leaving 300 bytes of headroom. `unwrap_or_else` +
    // `saturating_sub` fix two defects of the previous eager
    // `unwrap_or(packet_len - 300)`: the subtraction was evaluated even
    // when the conversion succeeded, and it underflowed (panicking in
    // debug builds) whenever a server reported packet_len < 300.
    let read_len = read_len
        .try_into()
        .unwrap_or_else(|_| packet_len.saturating_sub(300));
    let read_len = options
        .get_max_read_len()
        .map(|v| min(v, read_len))
        .unwrap_or(read_len);

    let write_len = write_len
        .try_into()
        .unwrap_or_else(|_| packet_len.saturating_sub(300));
    let write_len = options
        .get_max_write_len()
        .map(|v| min(v, write_len))
        .unwrap_or(write_len);

    let limits = auxiliary::Limits {
        read_len,
        write_len,
    };

    write_end
        .get_auxiliary()
        .conn_info
        .set(auxiliary::ConnInfo { limits, extensions })
        .expect("auxiliary.conn_info shall be uninitialized");

    Ok(())
}
/// Gracefully shut the connection down: signal both background tasks
/// to stop, then wait for them to finish.
///
/// # Errors
///
/// Returns an error if either task panicked (join error) or returned
/// an [`Error`] of its own — `??` surfaces both layers.
pub async fn close(self) -> Result<(), Error> {
// `destructure` moves the fields out WITHOUT running `Drop`, which
// would otherwise call `order_shutdown` a second time.
let (shared_data, flush_task, read_task) = self.destructure();
shared_data.get_auxiliary().order_shutdown();
// Join the read task before the flush task (original ordering,
// preserved deliberately).
read_task.await??;
flush_task.await??;
Ok(())
}
/// Returns an [`OpenOptions`] builder for opening files with
/// fine-grained options (read/write/create/truncate, …).
pub fn options(&self) -> OpenOptions<'_> {
OpenOptions::new(self)
}
pub async fn create(&self, path: impl AsRef<Path>) -> Result<File<'_>, Error> {
async fn inner<'s>(this: &'s Sftp, path: &Path) -> Result<File<'s>, Error> {
this.options()
.write(true)
.create(true)
.truncate(true)
.open(path)
.await
}
inner(self, path.as_ref()).await
}
pub async fn open(&self, path: impl AsRef<Path>) -> Result<File<'_>, Error> {
async fn inner<'s>(this: &'s Sftp, path: &Path) -> Result<File<'s>, Error> {
this.options().read(true).open(path).await
}
inner(self, path.as_ref()).await
}
/// Returns an [`Fs`] handle for filesystem operations, initialized
/// with an empty cwd string ("").
pub fn fs(&self) -> Fs<'_> {
Fs::new(self.write_end(), "".into())
}
}
impl Sftp {
/// Create a [`WriteEndWithCachedId`] wrapping a fresh [`WriteEnd`]
/// that shares this connection's state.
pub(super) fn write_end(&self) -> WriteEndWithCachedId<'_> {
WriteEndWithCachedId::new(self, WriteEnd::new(self.shared_data.clone()))
}
/// Access the per-connection auxiliary bookkeeping data.
pub(super) fn auxiliary(&self) -> &Auxiliary {
self.shared_data.get_auxiliary()
}
/// Notify the `flush_immediately` event — presumably this wakes the
/// flush task to send buffered requests without waiting for the
/// flush interval (see `create_flush_task`).
pub(super) fn trigger_flushing(&self) {
self.auxiliary().flush_immediately.notify_one();
}
/// Current value of the `pending_requests` counter. Relaxed load:
/// the value may already be stale when used; suitable for
/// heuristics only.
pub(super) fn get_pending_requests(&self) -> usize {
self.auxiliary().pending_requests.load(Ordering::Relaxed)
}
}
// Accessors compiled only for CI tests, exposing the limits negotiated
// in `set_limits`.
#[cfg(feature = "ci-tests")]
impl Sftp {
/// The negotiated `write_len` limit (see `set_limits`).
pub fn max_write_len(&self) -> u32 {
self.shared_data.get_auxiliary().limits().write_len
}
/// The negotiated `read_len` limit (see `set_limits`).
pub fn max_read_len(&self) -> u32 {
self.shared_data.get_auxiliary().limits().read_len
}
}
impl Drop for Sftp {
// Best-effort cleanup: signal the background tasks to shut down.
// Unlike `close`, `drop` cannot await the tasks or surface their
// errors — prefer calling `close` for a graceful shutdown.
fn drop(&mut self) {
self.shared_data.get_auxiliary().order_shutdown();
}
}