use std::time::Instant;
use thiserror::Error;
use crate::session::{Event, Session};
/// Compression scheme for file data, either requested by the host or
/// advertised by the device in its `PFT:version` reply.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum Compression {
    /// No compression. This is the default.
    #[default]
    None,
    /// Heatshrink compression with the given window and lookahead parameters.
    Heatshrink {
        /// Window parameter as advertised/requested.
        window: u8,
        /// Lookahead parameter as advertised/requested.
        lookahead: u8,
    },
    /// Only meaningful as a host request: use the heatshrink parameters the
    /// device advertises if it offers heatshrink, otherwise no compression.
    Auto,
}
/// File-level outcome surfaced to the caller by `FileTransfer::poll`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileEvent {
    /// Negotiation finished successfully.
    Negotiated {
        /// Protocol version string the device reported.
        version: String,
        /// Compression selected for this transfer.
        compression: Compression,
    },
    /// The device accepted the open request.
    Opened,
    /// The device acknowledged the most recent data chunk.
    WriteAcked,
    /// The device closed the file successfully.
    Closed,
    /// The device confirmed the abort request.
    AbortAcked,
    /// The transfer failed; the driver accepts no further requests.
    Failed(FileError),
}
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum FileError {
#[error("device busy (PFT:busy)")]
OpenBusy,
#[error("device refused open (PFT:fail)")]
OpenFail,
#[error("device storage I/O error (PFT:ioerror)")]
IoError,
#[error("no open file on device (PFT:invalid)")]
NoOpenFile,
#[error("session timed out waiting for reply")]
SessionTimeout,
#[error("session reported fatal error (fe)")]
SessionFatalError,
#[error("session out of sync: expected {expected}, got {got}")]
SessionOutOfSync {
expected: u8,
got: u8,
},
#[error("device does not support heatshrink compression")]
CompressionUnsupported,
#[error("device sent bare ok in {state} without expected {expected} preamble")]
ProtocolViolation {
state: &'static str,
expected: &'static str,
},
}
/// Lifecycle position of a `FileTransfer`.
///
/// Normal flow: `Idle` → `AwaitingQueryReply` → `Negotiated` →
/// `AwaitingOpenReply` → `Opened`, then any number of
/// `Opened` → `AwaitingWriteAck` → `Opened` round trips, ending either in
/// `AwaitingCloseReply` → `Closed` or in `AwaitingAbortReply` → `Aborted`.
/// Any error moves to `Failed`, which nothing ever leaves.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Freshly created; `query` has not been called.
    Idle,
    /// Query sent; waiting for the `PFT:version` reply and its ack.
    AwaitingQueryReply,
    /// Compression agreed; a file may now be opened.
    Negotiated,
    /// Open sent; waiting for `PFT:success` / `PFT:busy` / `PFT:fail` and its ack.
    AwaitingOpenReply,
    /// A file is open and no request is outstanding.
    Opened,
    /// A data chunk was sent; waiting for its bare ack.
    AwaitingWriteAck,
    /// Close sent; waiting for `PFT:success` / `PFT:ioerror` / `PFT:invalid` and its ack.
    AwaitingCloseReply,
    /// The file was closed successfully.
    Closed,
    /// Abort sent; waiting for `PFT:success` and its ack.
    AwaitingAbortReply,
    /// The device confirmed the abort.
    Aborted,
    /// The transfer failed.
    Failed,
}
/// Host-side driver for the file-transfer protocol, layered on a `Session`.
///
/// Request methods send their packet immediately. Incoming bytes are handed to
/// the session via `feed`, and the resulting replies are turned into
/// `FileEvent`s by `poll`.
#[derive(Debug)]
pub struct FileTransfer<'a> {
    /// Underlying packet session; every request is sent through it.
    session: &'a mut Session,
    /// Current lifecycle position; guards which requests are allowed.
    state: State,
    /// Compression preference passed to `query`.
    requested_compression: Compression,
    /// Compression agreed with the device; `None` until negotiation succeeds.
    negotiated: Option<Compression>,
    /// Version string from the device's `PFT:version` reply.
    /// NOTE(review): written during negotiation but not read in this file.
    advertised_version: Option<String>,
    /// `PFT:` reply seen since the last ack, applied when that ack arrives.
    pending_ascii: Option<PendingAscii>,
    /// File-level events waiting to be returned by `poll`.
    out_events: std::collections::VecDeque<FileEvent>,
    /// Protocol id passed to every `Session::send` call.
    protocol_id: u8,
}
/// A recognised `PFT:` reply line, stashed until the ack that completes the
/// outstanding request arrives.
#[derive(Debug)]
enum PendingAscii {
    /// `PFT:version:…` accepted; negotiation will succeed with these values.
    QueryVersion {
        /// Version string reported by the device.
        version: String,
        /// Compression chosen from the host request and the device offer.
        compression: Compression,
    },
    /// `PFT:version:…` received, but no acceptable compression could be chosen.
    QueryFailed(FileError),
    /// `PFT:success` in reply to an open request.
    OpenSuccess,
    /// `PFT:busy` in reply to an open request.
    OpenBusy,
    /// `PFT:fail` in reply to an open request.
    OpenFail,
    /// `PFT:success` in reply to a close request.
    CloseSuccess,
    /// `PFT:ioerror` in reply to a close request.
    CloseIoError,
    /// `PFT:invalid` in reply to a close request.
    CloseInvalid,
    /// `PFT:success` in reply to an abort request.
    AbortSuccess,
}
/// Session protocol id used for every file-transfer packet.
const PROTOCOL_FILE_TRANSFER: u8 = 1;

// Packet types within the file-transfer protocol.

/// Query the device's protocol version and compression support (empty payload).
const PT_QUERY: u8 = 0;
/// Open a file: payload is dummy flag, compression flag, NUL-terminated name.
const PT_OPEN: u8 = 1;
/// Close the currently open file (empty payload).
const PT_CLOSE: u8 = 2;
/// Write one chunk of file data (payload is the chunk).
const PT_WRITE: u8 = 3;
/// Abort the current transfer (empty payload).
const PT_ABORT: u8 = 4;
impl<'a> FileTransfer<'a> {
    /// Creates a transfer driver on top of `session`, starting in the idle state.
    ///
    /// Nothing is sent yet; call [`query`](Self::query) to begin negotiation.
    pub fn new(session: &'a mut Session) -> Self {
        Self {
            session,
            state: State::Idle,
            requested_compression: Compression::None,
            negotiated: None,
            advertised_version: None,
            pending_ascii: None,
            out_events: std::collections::VecDeque::new(),
            protocol_id: PROTOCOL_FILE_TRANSFER,
        }
    }

    /// Sends the query packet that starts version/compression negotiation.
    ///
    /// `compression` is the host's preference. The outcome is reported by
    /// [`poll`](Self::poll) as [`FileEvent::Negotiated`] or [`FileEvent::Failed`].
    ///
    /// # Panics
    ///
    /// Panics unless the transfer is idle. Nothing returns a transfer to the
    /// idle state, so `query` can be called at most once per `FileTransfer`.
    pub fn query(&mut self, compression: Compression, now: Instant) {
        assert!(
            matches!(self.state, State::Idle),
            "query() requires Idle state, found {:?}",
            self.state
        );
        self.requested_compression = compression;
        self.state = State::AwaitingQueryReply;
        self.session.send(self.protocol_id, PT_QUERY, &[], now);
    }

    /// Asks the device to open `name` for writing.
    ///
    /// The payload is laid out as follows:
    /// 1. `dummy` as 1/0.
    /// 2. 1 if heatshrink was negotiated, otherwise 0.
    /// 3. The name bytes.
    /// 4. A NUL terminator.
    ///
    /// The reply is reported by [`poll`](Self::poll) as [`FileEvent::Opened`]
    /// or [`FileEvent::Failed`].
    ///
    /// # Panics
    ///
    /// Panics unless negotiation has completed. Also panics if `name` contains
    /// a NUL byte, which would end the name early inside the NUL-terminated
    /// payload.
    pub fn open(&mut self, name: &str, dummy: bool, now: Instant) {
        assert!(
            matches!(self.state, State::Negotiated),
            "open() requires Negotiated state (call query() first), found {:?}",
            self.state
        );
        assert!(
            !name.contains('\0'),
            "open() file name must not contain NUL bytes: {:?}",
            name
        );
        let comp_byte = match self.negotiated {
            Some(Compression::Heatshrink { .. }) => 1u8,
            _ => 0u8,
        };
        // dummy flag + compression flag + name + NUL terminator.
        let mut payload = Vec::with_capacity(name.len() + 3);
        payload.push(u8::from(dummy));
        payload.push(comp_byte);
        payload.extend_from_slice(name.as_bytes());
        payload.push(0);
        self.state = State::AwaitingOpenReply;
        self.session.send(self.protocol_id, PT_OPEN, &payload, now);
    }

    /// Sends one chunk of file data.
    ///
    /// Only one chunk may be outstanding. Wait for [`FileEvent::WriteAcked`]
    /// from [`poll`](Self::poll) before writing the next one. The bytes are
    /// forwarded unchanged; this type does not compress them.
    ///
    /// # Panics
    ///
    /// Panics unless a file is open with no request outstanding.
    pub fn write(&mut self, chunk: &[u8], now: Instant) {
        assert!(
            matches!(self.state, State::Opened),
            "write() requires Opened state (pump until WriteAcked between writes), found {:?}",
            self.state
        );
        self.state = State::AwaitingWriteAck;
        self.session.send(self.protocol_id, PT_WRITE, chunk, now);
    }

    /// Asks the device to close the open file.
    ///
    /// The reply is reported by [`poll`](Self::poll) as [`FileEvent::Closed`]
    /// or [`FileEvent::Failed`].
    ///
    /// # Panics
    ///
    /// Panics unless a file is open with no request outstanding.
    pub fn close(&mut self, now: Instant) {
        assert!(
            matches!(self.state, State::Opened),
            "close() requires Opened state, found {:?}",
            self.state
        );
        self.state = State::AwaitingCloseReply;
        self.session.send(self.protocol_id, PT_CLOSE, &[], now);
    }

    /// Asks the device to abort the current transfer.
    ///
    /// The reply is reported by [`poll`](Self::poll) as
    /// [`FileEvent::AbortAcked`] or [`FileEvent::Failed`].
    ///
    /// # Panics
    ///
    /// Panics unless a file is open with no request outstanding.
    pub fn abort(&mut self, now: Instant) {
        assert!(
            matches!(self.state, State::Opened),
            "abort() requires Opened state, found {:?}",
            self.state
        );
        self.state = State::AwaitingAbortReply;
        self.session.send(self.protocol_id, PT_ABORT, &[], now);
    }

    /// Returns the compression chosen during negotiation. Returns `None` if
    /// negotiation has not (yet) succeeded.
    pub fn negotiated_compression(&self) -> Option<&Compression> {
        self.negotiated.as_ref()
    }

    /// Returns the next chunk of outbound bytes from the session, if any.
    pub fn poll_outbound(&mut self) -> Option<Vec<u8>> {
        self.session.poll_outbound()
    }

    /// Hands bytes received from the device to the session.
    pub fn feed(&mut self, bytes: &[u8], now: Instant) {
        self.session.feed(bytes, now);
    }

    /// Forwards the current time to the session (presumably for its timeout handling).
    pub fn tick(&mut self, now: Instant) {
        self.session.tick(now);
    }

    /// Returns the session's response timeout.
    pub fn response_timeout(&self) -> std::time::Duration {
        self.session.response_timeout()
    }

    /// Processes every pending session event. Then returns the oldest queued
    /// file event, if any.
    pub fn poll(&mut self) -> Option<FileEvent> {
        while let Some(event) = self.session.poll_event() {
            self.handle_session_event(event);
        }
        self.out_events.pop_front()
    }

    /// Translates one session-layer event into file-transfer state changes.
    fn handle_session_event(&mut self, event: Event) {
        match event {
            Event::AsciiLine(line) => self.handle_ascii_line(line),
            Event::Ack(_) => self.handle_ack(),
            // These carry no file-level meaning here; presumably the session
            // deals with them itself.
            Event::Synced { .. } | Event::ResendRequested(_) => {}
            Event::FatalError => self.fail(FileError::SessionFatalError),
            Event::OutOfSync { expected, got } => {
                self.fail(FileError::SessionOutOfSync { expected, got })
            }
            Event::Timeout { .. } => self.fail(FileError::SessionTimeout),
        }
    }

    /// Stashes a recognised `PFT:` reply for the request currently in flight.
    ///
    /// The textual reply is expected to arrive before the packet's ack. It is
    /// only recorded here and is acted upon in [`handle_ack`](Self::handle_ack).
    /// Lines that do not fit the current state are ignored.
    fn handle_ascii_line(&mut self, line: String) {
        // Tolerate a trailing CR/whitespace on status lines, just as the
        // compression spec is trimmed in `parse_compression_spec`.
        let line = line.trim_end();
        if let Some(rest) = line.strip_prefix("PFT:version:") {
            if !matches!(self.state, State::AwaitingQueryReply) {
                return;
            }
            let (version, comp) = match rest.split_once(':') {
                Some(parts) => parts,
                // Malformed version line: nothing is stashed, so the following
                // ack is reported as a protocol violation.
                None => return,
            };
            self.advertised_version = Some(version.to_string());
            let device_compression = parse_compression_spec(comp);
            self.pending_ascii = Some(match self.choose_compression(&device_compression) {
                Ok(chosen) => PendingAscii::QueryVersion {
                    version: version.to_string(),
                    compression: chosen,
                },
                Err(err) => PendingAscii::QueryFailed(err),
            });
            return;
        }
        let pending = match (line, self.state) {
            ("PFT:success", State::AwaitingOpenReply) => PendingAscii::OpenSuccess,
            ("PFT:success", State::AwaitingCloseReply) => PendingAscii::CloseSuccess,
            ("PFT:success", State::AwaitingAbortReply) => PendingAscii::AbortSuccess,
            ("PFT:busy", State::AwaitingOpenReply) => PendingAscii::OpenBusy,
            ("PFT:fail", State::AwaitingOpenReply) => PendingAscii::OpenFail,
            ("PFT:ioerror", State::AwaitingCloseReply) => PendingAscii::CloseIoError,
            ("PFT:invalid", State::AwaitingCloseReply) => PendingAscii::CloseInvalid,
            // Unknown line, or a known line arriving in the wrong state.
            _ => return,
        };
        self.pending_ascii = Some(pending);
    }

    /// Applies the device's ack to the request currently in flight.
    ///
    /// The stashed ASCII reply, if any, decides the outcome. For query, open,
    /// close and abort, a missing reply is a protocol violation. A write is
    /// the opposite case: it must be acknowledged by a bare ack with no reply.
    fn handle_ack(&mut self) {
        let pending = self.pending_ascii.take();
        match (self.state, pending) {
            (
                State::AwaitingQueryReply,
                Some(PendingAscii::QueryVersion {
                    version,
                    compression,
                }),
            ) => {
                self.negotiated = Some(compression.clone());
                self.state = State::Negotiated;
                self.out_events.push_back(FileEvent::Negotiated {
                    version,
                    compression,
                });
            }
            (State::AwaitingQueryReply, Some(PendingAscii::QueryFailed(err))) => self.fail(err),
            (State::AwaitingQueryReply, None) => self.fail(FileError::ProtocolViolation {
                state: "AwaitingQueryReply",
                expected: "PFT:version:<v>:<compression-spec>",
            }),
            (State::AwaitingOpenReply, Some(PendingAscii::OpenSuccess)) => {
                self.state = State::Opened;
                self.out_events.push_back(FileEvent::Opened);
            }
            (State::AwaitingOpenReply, Some(PendingAscii::OpenBusy)) => {
                self.fail(FileError::OpenBusy)
            }
            (State::AwaitingOpenReply, Some(PendingAscii::OpenFail)) => {
                self.fail(FileError::OpenFail)
            }
            (State::AwaitingOpenReply, None) => self.fail(FileError::ProtocolViolation {
                state: "AwaitingOpenReply",
                expected: "PFT:success | PFT:busy | PFT:fail",
            }),
            (State::AwaitingWriteAck, None) => {
                self.state = State::Opened;
                self.out_events.push_back(FileEvent::WriteAcked);
            }
            (State::AwaitingWriteAck, Some(_)) => self.fail(FileError::ProtocolViolation {
                state: "AwaitingWriteAck",
                expected: "bare ok<n> (no PFT preamble)",
            }),
            (State::AwaitingCloseReply, Some(PendingAscii::CloseSuccess)) => {
                self.state = State::Closed;
                self.out_events.push_back(FileEvent::Closed);
            }
            (State::AwaitingCloseReply, Some(PendingAscii::CloseIoError)) => {
                self.fail(FileError::IoError)
            }
            (State::AwaitingCloseReply, Some(PendingAscii::CloseInvalid)) => {
                self.fail(FileError::NoOpenFile)
            }
            (State::AwaitingCloseReply, None) => self.fail(FileError::ProtocolViolation {
                state: "AwaitingCloseReply",
                expected: "PFT:success | PFT:ioerror | PFT:invalid",
            }),
            (State::AwaitingAbortReply, Some(PendingAscii::AbortSuccess)) => {
                self.state = State::Aborted;
                self.out_events.push_back(FileEvent::AbortAcked);
            }
            (State::AwaitingAbortReply, None) => self.fail(FileError::ProtocolViolation {
                state: "AwaitingAbortReply",
                expected: "PFT:success",
            }),
            // Any other combination (e.g. acks while idle, negotiated, opened,
            // closed, aborted or failed) is ignored. Any stashed reply was
            // already discarded by the `take()` above.
            _ => {}
        }
    }

    /// Picks the compression to use, given the host request and the device's
    /// advertised capability.
    ///
    /// - `None` requested: always no compression.
    /// - Explicit `Heatshrink` requested: succeeds only if the device
    ///   advertises heatshrink with the *same* parameters. Otherwise it fails
    ///   with [`FileError::CompressionUnsupported`].
    /// - `Auto` requested: adopts the device's heatshrink parameters if it
    ///   offers heatshrink, otherwise no compression.
    fn choose_compression(&self, device: &Compression) -> Result<Compression, FileError> {
        match (&self.requested_compression, device) {
            (Compression::None, _) => Ok(Compression::None),
            // The device advertises exactly one parameter set (and `Auto`
            // adopts it verbatim). Accepting different host parameters would
            // report success for a configuration the device never offered.
            (Compression::Heatshrink { .. }, Compression::Heatshrink { .. })
                if self.requested_compression == *device =>
            {
                Ok(device.clone())
            }
            (Compression::Heatshrink { .. }, _) => Err(FileError::CompressionUnsupported),
            (Compression::Auto, Compression::Heatshrink { .. }) => Ok(device.clone()),
            (Compression::Auto, _) => Ok(Compression::None),
        }
    }

    /// Moves to the `Failed` state (which nothing leaves) and reports `err`.
    fn fail(&mut self, err: FileError) {
        self.state = State::Failed;
        self.out_events.push_back(FileEvent::Failed(err));
    }
}
/// Parses the compression spec from a `PFT:version:<v>:<spec>` reply.
///
/// Surrounding whitespace is ignored. `heatshrink,<window>,<lookahead>`
/// (each field trimmed, each parsed as `u8`, extra fields ignored) yields
/// `Compression::Heatshrink`. Everything else yields `Compression::None`:
/// `none`, unknown schemes, and malformed heatshrink specs.
fn parse_compression_spec(spec: &str) -> Compression {
    let heatshrink_params = match spec.trim().strip_prefix("heatshrink,") {
        Some(params) => params,
        // "none" and any unrecognised scheme both mean no compression.
        None => return Compression::None,
    };
    let mut numbers = heatshrink_params
        .split(',')
        .map(|field| field.trim().parse::<u8>().ok());
    // Tuple fields evaluate left to right: window first, then lookahead.
    match (numbers.next().flatten(), numbers.next().flatten()) {
        (Some(window), Some(lookahead)) => Compression::Heatshrink { window, lookahead },
        _ => Compression::None,
    }
}