use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::adapters::common::resolve_chunk_size;
use crate::file_transfer::{Compression, FileEvent, FileTransfer};
use crate::session::Session;
pub use crate::adapters::common::{
Progress, ProgressCallback, UploadError, UploadOptions, UploadStats,
};
/// Uploads everything readable from `src` to the device behind `transport`.
///
/// The steps run strictly in sequence:
///
/// 1. Write the literal line `M28B1\n` to the transport (presumably the request that
///    switches the device into its binary transfer mode — confirm against the firmware).
/// 2. Connect a fresh [`Session`] and pump it until it reports `Synced`.
/// 3. Query compression support through a [`FileTransfer`] and wait for the negotiated mode.
/// 4. Open `options.dest_filename` (passing `options.dummy` through) and wait for `Opened`.
/// 5. Read `src` to its end into memory and, if Heatshrink was negotiated, compress it.
/// 6. Send the payload chunk by chunk, waiting for `WriteAcked` after every chunk and
///    invoking `options.progress` (when set) with the running totals.
/// 7. Close the file, wait for `Closed`, then send one final session packet and wait for
///    its acknowledgement.
///
/// On success the returned [`UploadStats`] carries the negotiated compression, the number
/// of bytes read from `src`, and the number of payload bytes and chunks that were sent.
/// `bytes_sent` counts payload bytes, i.e. the compressed size when compression is active.
///
/// # Errors
///
/// * I/O failures on `transport` or `src` are propagated through `?`.
/// * Errors from the `drive_*` helpers in this file are propagated unchanged.
/// * [`UploadError::CompressionFeatureDisabled`] is returned when Heatshrink is negotiated
///   but the crate was built without the `heatshrink` feature.
///
/// # Panics
///
/// Panics if the negotiated compression is still `Compression::Auto`.
pub async fn upload<T, S>(
    transport: &mut T,
    src: &mut S,
    mut options: UploadOptions,
) -> Result<UploadStats, UploadError>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S: AsyncRead + Unpin,
{
    transport.write_all(b"M28B1\n").await?;

    let mut session = Session::new();
    session.connect(Instant::now());
    drive_until_synced(transport, &mut session).await?;

    // Has to be read before `FileTransfer` takes its mutable borrow of `session`.
    let device_max = session.max_block_size().unwrap_or(0);

    let mut ft = FileTransfer::new(&mut session);
    ft.query(options.compression.clone(), Instant::now());
    let negotiated = drive_until_negotiated(transport, &mut ft).await?;

    ft.open(&options.dest_filename, options.dummy, Instant::now());
    drive_until(transport, &mut ft, |e| matches!(e, FileEvent::Opened)).await?;

    let mut stats = UploadStats {
        compression: negotiated.clone(),
        ..UploadStats::default()
    };

    // The whole source is buffered in memory; compression works on the complete input.
    let mut source_bytes = Vec::new();
    src.read_to_end(&mut source_bytes).await?;
    stats.source_bytes = source_bytes.len() as u64;

    let payload: Vec<u8> = match &negotiated {
        Compression::None => source_bytes,
        Compression::Heatshrink { window, lookahead } => {
            #[cfg(feature = "heatshrink")]
            {
                crate::compression::compress(&source_bytes, *window, *lookahead)?
            }
            #[cfg(not(feature = "heatshrink"))]
            {
                // Touch the bindings so the feature-less build has no unused variables.
                let _ = (window, lookahead);
                return Err(UploadError::CompressionFeatureDisabled);
            }
        }
        Compression::Auto => unreachable!("FileTransfer resolves Auto during query"),
    };

    // `slice::chunks` panics on a chunk size of zero. `device_max` falls back to 0 above,
    // so clamp the resolved size instead of relying on the helper never returning 0.
    let chunk_size = resolve_chunk_size(options.chunk_size, device_max).max(1);

    for chunk in payload.chunks(chunk_size) {
        ft.write(chunk, Instant::now());
        drive_until(transport, &mut ft, |e| matches!(e, FileEvent::WriteAcked)).await?;

        stats.bytes_sent += chunk.len() as u64;
        stats.chunks_sent += 1;

        if let Some(cb) = options.progress.as_mut() {
            cb(Progress {
                bytes_sent: stats.bytes_sent,
                chunks_sent: stats.chunks_sent,
                source_bytes: stats.source_bytes,
            });
        }
    }

    ft.close(Instant::now());
    drive_until(transport, &mut ft, |e| matches!(e, FileEvent::Closed)).await?;

    // Release the borrow of `session` so it can be driven directly again.
    drop(ft);

    // NOTE(review): presumably the control-channel "close" packet (the stall message of
    // `drive_session_until_idle` calls it that) — confirm the meaning of `0, 2`.
    session.send(0, 2, &[], Instant::now());
    drive_session_until_idle(transport, &mut session).await?;

    Ok(stats)
}
/// Pumps `session` over `transport` until the session reports an `Ack` event.
///
/// Each of the at most 200 rounds flushes all queued outbound data, waits up to
/// `session.response_timeout()` for inbound bytes, feeds whatever arrived into the
/// session, handles the pending events and finally ticks the session.
///
/// # Errors
///
/// * I/O errors from the transport are propagated.
/// * `FatalError`, `Timeout` and `OutOfSync` events become the matching
///   `UploadError::Transfer(FileError::Session…)` error.
/// * `UploadError::Stalled("control close not acked")` is returned when no `Ack` shows up
///   within the 200 rounds.
async fn drive_session_until_idle<T>(
    transport: &mut T,
    session: &mut Session,
) -> Result<(), UploadError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use crate::file_transfer::FileError;
    use crate::session::Event;

    let mut scratch = [0u8; 1024];

    for _round in 0..200 {
        // Flush everything the session wants to send before waiting for a reply.
        while let Some(frame) = session.poll_outbound() {
            transport.write_all(&frame).await?;
        }

        let wait = session.response_timeout();
        let received = read_with_timeout(transport, &mut scratch, wait).await?;
        if received != 0 {
            session.feed(&scratch[..received], Instant::now());
        }

        while let Some(event) = session.poll_event() {
            // Translate terminal failures into a `FileError`; everything else is skipped.
            let failure = match event {
                Event::Ack(_) => return Ok(()),
                Event::FatalError => FileError::SessionFatalError,
                Event::Timeout { .. } => FileError::SessionTimeout,
                Event::OutOfSync { expected, got } => {
                    FileError::SessionOutOfSync { expected, got }
                }
                _ => continue,
            };
            return Err(UploadError::Transfer(failure));
        }

        session.tick(Instant::now());
    }

    Err(UploadError::Stalled("control close not acked"))
}
/// Pumps `session` over `transport` until the session reports a `Synced` event.
///
/// Each of the at most 200 rounds flushes all queued outbound data, waits up to
/// `session.response_timeout()` for inbound bytes, feeds whatever arrived into the
/// session, handles the pending events and finally ticks the session.
///
/// # Errors
///
/// * I/O errors from the transport are propagated.
/// * `FatalError`, `Timeout` and `OutOfSync` events become the matching
///   `UploadError::Transfer(FileError::Session…)` error.
/// * `UploadError::HandshakeFailed` is returned when no `Synced` event shows up within
///   the 200 rounds.
async fn drive_until_synced<T>(transport: &mut T, session: &mut Session) -> Result<(), UploadError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use crate::file_transfer::FileError;
    use crate::session::Event;

    let mut scratch = [0u8; 1024];

    for _round in 0..200 {
        // Flush everything the session wants to send before waiting for a reply.
        while let Some(frame) = session.poll_outbound() {
            transport.write_all(&frame).await?;
        }

        let wait = session.response_timeout();
        let received = read_with_timeout(transport, &mut scratch, wait).await?;
        if received != 0 {
            session.feed(&scratch[..received], Instant::now());
        }

        while let Some(event) = session.poll_event() {
            // Translate terminal failures into a `FileError`; everything else is skipped.
            let failure = match event {
                Event::Synced { .. } => return Ok(()),
                Event::FatalError => FileError::SessionFatalError,
                Event::Timeout { .. } => FileError::SessionTimeout,
                Event::OutOfSync { expected, got } => {
                    FileError::SessionOutOfSync { expected, got }
                }
                _ => continue,
            };
            return Err(UploadError::Transfer(failure));
        }

        session.tick(Instant::now());
    }

    Err(UploadError::HandshakeFailed)
}
/// Pumps `ft` over `transport` until it reports the negotiated compression.
///
/// Each of the at most 200 rounds flushes all queued outbound data, waits up to
/// `ft.response_timeout()` for inbound bytes, feeds whatever arrived into the transfer,
/// handles the pending events and finally ticks the transfer.
///
/// Returns the `compression` carried by the first `FileEvent::Negotiated` event.
///
/// # Errors
///
/// * I/O errors from the transport are propagated.
/// * A `FileEvent::Failed` event is returned as `UploadError::Transfer`.
/// * `UploadError::Stalled("negotiation did not complete")` is returned when neither
///   event shows up within the 200 rounds.
async fn drive_until_negotiated<T>(
    transport: &mut T,
    ft: &mut FileTransfer<'_>,
) -> Result<Compression, UploadError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut scratch = [0u8; 1024];

    for _round in 0..200 {
        // Flush everything the transfer wants to send before waiting for a reply.
        while let Some(frame) = ft.poll_outbound() {
            transport.write_all(&frame).await?;
        }

        let wait = ft.response_timeout();
        let received = read_with_timeout(transport, &mut scratch, wait).await?;
        if received != 0 {
            ft.feed(&scratch[..received], Instant::now());
        }

        while let Some(event) = ft.poll() {
            match event {
                FileEvent::Failed(cause) => return Err(UploadError::Transfer(cause)),
                FileEvent::Negotiated { compression, .. } => return Ok(compression),
                _ => {}
            }
        }

        ft.tick(Instant::now());
    }

    Err(UploadError::Stalled("negotiation did not complete"))
}
/// Pumps `ft` over `transport` until an event satisfying `pred` is produced.
///
/// Each of the at most 200 rounds flushes all queued outbound data, waits up to
/// `ft.response_timeout()` for inbound bytes, feeds whatever arrived into the transfer,
/// handles the pending events and finally ticks the transfer.
///
/// Events that are neither `FileEvent::Failed` nor accepted by `pred` are discarded.
/// A `Failed` event is checked first, so it ends the wait even if `pred` would accept it.
///
/// # Errors
///
/// * I/O errors from the transport are propagated.
/// * A `FileEvent::Failed` event is returned as `UploadError::Transfer`.
/// * `UploadError::Stalled("event did not arrive in time")` is returned when no matching
///   event shows up within the 200 rounds.
async fn drive_until<T, F>(
    transport: &mut T,
    ft: &mut FileTransfer<'_>,
    pred: F,
) -> Result<(), UploadError>
where
    T: AsyncRead + AsyncWrite + Unpin,
    F: Fn(&FileEvent) -> bool,
{
    let mut buf = [0u8; 1024];

    for _ in 0..200 {
        // Flush everything the transfer wants to send before waiting for a reply.
        while let Some(out) = ft.poll_outbound() {
            transport.write_all(&out).await?;
        }

        let n = read_with_timeout(transport, &mut buf, ft.response_timeout()).await?;
        if n > 0 {
            ft.feed(&buf[..n], Instant::now());
        }

        while let Some(evt) = ft.poll() {
            match evt {
                // The event is owned, so the error is moved out rather than cloned.
                FileEvent::Failed(err) => return Err(UploadError::Transfer(err)),
                other => {
                    if pred(&other) {
                        return Ok(());
                    }
                }
            }
        }

        ft.tick(Instant::now());
    }

    Err(UploadError::Stalled("event did not arrive in time"))
}
/// Performs a single `read` on `transport` into `buf`, giving up after `timeout`.
///
/// Returns the number of bytes read. An elapsed timeout is not an error: it is reported
/// as `Ok(0)`, so callers cannot tell it apart from a read that itself returned zero bytes.
///
/// # Errors
///
/// An I/O error from the read is wrapped in `UploadError::Io`.
async fn read_with_timeout<T>(
    transport: &mut T,
    buf: &mut [u8],
    timeout: Duration,
) -> Result<usize, UploadError>
where
    T: AsyncRead + Unpin,
{
    let bounded_read = tokio::time::timeout(timeout, transport.read(buf));

    match bounded_read.await {
        // Deadline hit before the read completed: report "nothing received".
        Err(_elapsed) => Ok(0),
        Ok(Ok(count)) => Ok(count),
        Ok(Err(io_err)) => Err(UploadError::Io(io_err)),
    }
}