use std::collections::VecDeque;
use std::time::{Duration, Instant};
use crate::codec::{self, Packet};
/// Modulus for sync numbers: they are 8-bit and wrap around after 255.
const SYNC_MOD: u16 = 0x100;
/// Default idle time after which the in-flight packet is retransmitted.
const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_secs(1);
/// Default total time after which an unacknowledged packet is abandoned.
const DEFAULT_TOTAL_TIMEOUT: Duration = Duration::from_millis(20_000);
/// Something the session observed while processing device output or its timers.
///
/// Events are queued in the order they occur and drained with `Session::poll_event`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// A well-formed `ss<sync>,<max_block_size>,<version>` reply was processed.
    Synced {
        max_block_size: u16,
        protocol_version: String,
    },
    /// The in-flight packet with this sync number was acknowledged by `ok<sync>`.
    Ack(u8),
    /// The device sent `rs<sync>`.
    ResendRequested(u8),
    /// A line that is not a protocol reply (also used for `ok<n>` with nothing in flight).
    AsciiLine(String),
    /// The device sent `fe`.
    FatalError,
    /// An `ok` did not acknowledge the in-flight packet (wrong sync, or the SYNC
    /// handshake was outstanding).
    OutOfSync { expected: u8, got: u8 },
    /// A packet stayed unacknowledged for the whole total timeout and was dropped.
    Timeout { sync: u8 },
}
/// A packet that has been handed to the outbound queue and is awaiting a reply.
#[derive(Debug)]
struct InFlight {
    /// Sync number the packet was encoded with; an `ok` must echo it.
    sync: u8,
    /// Fully encoded packet bytes, kept so they can be retransmitted.
    bytes: Vec<u8>,
    /// When the packet was first dispatched; compared against the total timeout.
    first_sent: Instant,
    /// When the packet was most recently (re)sent; compared against the response timeout.
    last_sent: Instant,
    /// True for the SYNC handshake packet, which is completed by an `ss` reply, not `ok`.
    is_sync_handshake: bool,
}
/// A packet waiting for the in-flight slot; its sync number is assigned at dispatch time.
#[derive(Debug)]
struct Queued {
    /// Packet fields without the sync number, which is only known when the packet is sent.
    bytes_without_sync: BytesBuilder,
    /// Whether this is the SYNC handshake packet queued by `Session::connect`.
    is_sync_handshake: bool,
}
/// Un-encoded packet contents; `build` encodes them once a sync number is chosen.
#[derive(Debug, Clone)]
struct BytesBuilder {
    /// Protocol id (`Session::send` asserts it is at most `0xF`).
    protocol: u8,
    /// Packet type within the protocol (`Session::send` asserts it is at most `0xF`).
    packet_type: u8,
    /// Raw payload (`Session::send` asserts it fits in `codec::MAX_PAYLOAD`).
    payload: Vec<u8>,
}
impl BytesBuilder {
    /// Encodes these packet fields with the given sync number into a fresh byte vector.
    ///
    /// Panics only if the fields are invalid for `Packet::new` / `codec::encode`, which the
    /// session is expected to have ruled out when the packet was queued.
    fn build(&self, sync: u8) -> Vec<u8> {
        // Pre-size for header + payload + 2 extra bytes (presumably a trailer such as a
        // checksum — confirm against `codec`).
        let expected_len = codec::HEADER_LEN + self.payload.len() + 2;
        let mut encoded = Vec::with_capacity(expected_len);
        let packet = Packet::new(sync, self.protocol, self.packet_type, &self.payload)
            .expect("session validates protocol/type/length at queue time");
        codec::encode(&packet, &mut encoded).expect("validation already passed");
        encoded
    }
}
/// Host-side state machine for the packet protocol.
///
/// It performs no I/O itself: device output is pushed in through `feed`, bytes to transmit are
/// pulled out through `poll_outbound`, and timers advance through `tick`.
#[derive(Debug)]
pub struct Session {
    /// Sync counter; the next dispatched packet is encoded with this value.
    sync: u8,
    /// Set once a valid `ss` reply has been processed.
    is_synced: bool,
    /// Block size from the last `ss` reply.
    max_block_size: Option<u16>,
    /// Version string from the last `ss` reply.
    protocol_version: Option<String>,
    /// The single packet currently awaiting a reply, if any.
    in_flight: Option<InFlight>,
    /// Packets waiting for the in-flight slot, in FIFO order.
    queued: VecDeque<Queued>,
    /// Encoded bytes for the caller to transmit (new packets and retransmissions).
    outbound: VecDeque<Vec<u8>>,
    /// Events not yet taken by `poll_event`.
    events: VecDeque<Event>,
    /// Trailing device output that has not yet been terminated by `\n`.
    inbound_buf: Vec<u8>,
    /// Idle time after which the in-flight packet is retransmitted.
    response_timeout: Duration,
    /// Total time after which an unacknowledged packet is abandoned.
    total_timeout: Duration,
}
impl Default for Session {
fn default() -> Self {
Self::new()
}
}
impl Session {
pub fn new() -> Self {
Self {
sync: 0,
is_synced: false,
max_block_size: None,
protocol_version: None,
in_flight: None,
queued: VecDeque::new(),
outbound: VecDeque::new(),
events: VecDeque::new(),
inbound_buf: Vec::with_capacity(256),
response_timeout: DEFAULT_RESPONSE_TIMEOUT,
total_timeout: DEFAULT_TOTAL_TIMEOUT,
}
}
pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
self.response_timeout = timeout;
self
}
pub fn with_total_timeout(mut self, timeout: Duration) -> Self {
self.total_timeout = timeout;
self
}
pub fn response_timeout(&self) -> Duration {
self.response_timeout
}
pub fn total_timeout(&self) -> Duration {
self.total_timeout
}
pub fn is_synced(&self) -> bool {
self.is_synced
}
pub fn max_block_size(&self) -> Option<u16> {
self.max_block_size
}
pub fn protocol_version(&self) -> Option<&str> {
self.protocol_version.as_deref()
}
pub fn current_sync(&self) -> u8 {
self.sync
}
pub fn has_pending(&self) -> bool {
self.in_flight.is_some()
}
pub fn reset(&mut self) {
self.sync = 0;
self.is_synced = false;
self.max_block_size = None;
self.protocol_version = None;
self.in_flight = None;
self.queued.clear();
self.outbound.clear();
self.events.clear();
self.inbound_buf.clear();
}
pub fn connect(&mut self, now: Instant) {
self.queue(0, 1, &[], true);
self.dispatch_if_idle(now);
}
pub fn send(&mut self, protocol: u8, packet_type: u8, payload: &[u8], now: Instant) {
assert!(
self.is_synced,
"Session::send called before SYNC handshake completed; call connect() and drive feed() until Event::Synced first"
);
assert!(protocol <= 0xF, "protocol id out of range");
assert!(packet_type <= 0xF, "packet type out of range");
assert!(
payload.len() <= codec::MAX_PAYLOAD,
"payload exceeds MAX_PAYLOAD"
);
self.queue(protocol, packet_type, payload, false);
self.dispatch_if_idle(now);
}
fn queue(&mut self, protocol: u8, packet_type: u8, payload: &[u8], is_sync_handshake: bool) {
self.queued.push_back(Queued {
bytes_without_sync: BytesBuilder {
protocol,
packet_type,
payload: payload.to_vec(),
},
is_sync_handshake,
});
}
fn dispatch_if_idle(&mut self, now: Instant) {
if self.in_flight.is_some() {
return;
}
let Some(next) = self.queued.pop_front() else {
return;
};
let bytes = next.bytes_without_sync.build(self.sync);
self.outbound.push_back(bytes.clone());
self.in_flight = Some(InFlight {
sync: self.sync,
bytes,
first_sent: now,
last_sent: now,
is_sync_handshake: next.is_sync_handshake,
});
}
pub fn poll_outbound(&mut self) -> Option<Vec<u8>> {
self.outbound.pop_front()
}
pub fn feed(&mut self, bytes: &[u8], now: Instant) {
self.inbound_buf.extend_from_slice(bytes);
while let Some(pos) = self.inbound_buf.iter().position(|&b| b == b'\n') {
let line: Vec<u8> = self.inbound_buf.drain(..=pos).collect();
let trimmed = strip_line_endings(&line);
if trimmed.is_empty() {
continue;
}
self.process_line(trimmed, now);
}
}
fn process_line(&mut self, line: &[u8], now: Instant) {
if let Some(rest) = strip_prefix(line, b"ok") {
if let Some(n) = parse_decimal_u8(rest) {
self.handle_ok(n, now);
return;
}
}
if let Some(rest) = strip_prefix(line, b"rs") {
if let Some(n) = parse_decimal_u8(rest) {
self.events.push_back(Event::ResendRequested(n));
return;
}
}
if let Some(rest) = strip_prefix(line, b"ss") {
self.handle_ss(rest, now);
return;
}
if line == b"fe" {
self.events.push_back(Event::FatalError);
return;
}
match std::str::from_utf8(line) {
Ok(s) => self.events.push_back(Event::AsciiLine(s.to_string())),
Err(_) => {
}
}
}
fn handle_ok(&mut self, n: u8, now: Instant) {
let Some(flight) = self.in_flight.as_ref() else {
self.events.push_back(Event::AsciiLine(format!("ok{n}")));
return;
};
if flight.is_sync_handshake {
self.events.push_back(Event::OutOfSync {
expected: flight.sync,
got: n,
});
return;
}
if n != flight.sync {
self.events.push_back(Event::OutOfSync {
expected: flight.sync,
got: n,
});
return;
}
self.in_flight = None;
self.sync = ((self.sync as u16 + 1) % SYNC_MOD) as u8;
self.events.push_back(Event::Ack(n));
self.dispatch_if_idle(now);
}
fn handle_ss(&mut self, rest: &[u8], now: Instant) {
let s = match std::str::from_utf8(rest) {
Ok(s) => s,
Err(_) => return,
};
let mut parts = s.splitn(3, ',');
let (Some(sync_str), Some(bsize_str), Some(version_str)) =
(parts.next(), parts.next(), parts.next())
else {
return;
};
let Ok(new_sync) = sync_str.trim().parse::<u16>() else {
return;
};
let Ok(max_block_size) = bsize_str.trim().parse::<u16>() else {
return;
};
let new_sync = (new_sync % SYNC_MOD) as u8;
self.sync = new_sync;
self.max_block_size = Some(max_block_size);
let protocol_version = version_str.trim().to_string();
self.protocol_version = Some(protocol_version.clone());
self.is_synced = true;
if let Some(flight) = self.in_flight.as_ref() {
if flight.is_sync_handshake {
self.in_flight = None;
}
}
self.events.push_back(Event::Synced {
max_block_size,
protocol_version,
});
self.dispatch_if_idle(now);
}
pub fn poll_event(&mut self) -> Option<Event> {
self.events.pop_front()
}
pub fn tick(&mut self, now: Instant) {
let Some(flight) = self.in_flight.as_mut() else {
return;
};
if now.saturating_duration_since(flight.first_sent) >= self.total_timeout {
let sync = flight.sync;
self.in_flight = None;
self.events.push_back(Event::Timeout { sync });
self.dispatch_if_idle(now);
return;
}
if now.saturating_duration_since(flight.last_sent) >= self.response_timeout {
self.outbound.push_back(flight.bytes.clone());
flight.last_sent = now;
}
}
}
/// Returns `line` without any trailing run of `\r` and `\n` bytes.
fn strip_line_endings(line: &[u8]) -> &[u8] {
    let mut rest = line;
    while let [head @ .., b'\r' | b'\n'] = rest {
        rest = head;
    }
    rest
}
/// Returns the part of `line` after `prefix`, or `None` if `line` does not start with it.
///
/// Thin wrapper over the standard `<[u8]>::strip_prefix`, kept so call sites can pass
/// byte-string literals directly.
fn strip_prefix<'a>(line: &'a [u8], prefix: &[u8]) -> Option<&'a [u8]> {
    line.strip_prefix(prefix)
}
/// Parses a decimal number in `0..=255`, ignoring surrounding whitespace.
///
/// Like `str::parse`, a leading `+` is accepted. Returns `None` for empty input, invalid
/// UTF-8, non-numeric text, negative numbers, or values above 255.
fn parse_decimal_u8(b: &[u8]) -> Option<u8> {
    match std::str::from_utf8(b) {
        Ok(text) if !text.is_empty() => text.trim().parse::<u8>().ok(),
        _ => None,
    }
}