Skip to main content

marlin_binary_transfer/
session.rs

1//! Layer 2: sans-I/O session state machine.
2//!
3//! Owns the sync counter, the outbound queue, and the inbound ASCII line
4//! parser. Callers drive it with `feed` / `poll_outbound` / `poll_event` /
5//! `tick`, plumbing the bytes through any I/O of their choice.
6//!
7//! # Lifecycle
8//!
9//! 1. Caller writes the ASCII trigger `b"M28B1\n"` to the device.
10//! 2. Caller calls [`Session::connect`], drains `poll_outbound`, writes
11//!    those bytes to the device.
12//! 3. Caller reads bytes from the device and pushes them via
13//!    [`Session::feed`].
14//! 4. Caller drains [`Session::poll_event`] until [`Event::Synced`] is
15//!    observed — the device has acknowledged the SYNC and reported its
16//!    block size and protocol version.
17//! 5. Caller calls [`Session::send`] for each subsequent binary packet,
18//!    pumping `poll_outbound` / `feed` / `poll_event` as before, calling
19//!    [`Session::tick`] periodically so retransmits fire on timeout.
20//!
21//! # Concurrency model
22//!
23//! Mirrors the Python reference: only one packet is in flight at a time.
24//! Calls to [`Session::send`] while a packet is in flight are queued FIFO
25//! and dispatched as each ack arrives.
26
27use std::collections::VecDeque;
28use std::time::{Duration, Instant};
29
30use crate::codec::{self, Packet};
31
/// Modulus of the 8-bit sync counter: sync numbers live in `0..=255`, and
/// stepping past 255 wraps back to 0.
const SYNC_MOD: u16 = u8::MAX as u16 + 1;

/// How long a single transmission attempt waits for a reply by default.
const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_secs(1);
/// Default overall budget for one packet, retransmits included
/// (20 attempts at the default per-attempt timeout).
const DEFAULT_TOTAL_TIMEOUT: Duration = Duration::from_secs(20);
39
/// Notifications the session hands back to its caller while it parses the
/// inbound byte stream and tracks the packet in flight.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// The device answered the SYNC packet; the handshake is complete.
    Synced {
        /// Largest payload, in bytes, the device accepts in one packet.
        max_block_size: u16,
        /// Protocol version string as reported by the device.
        protocol_version: String,
    },
    /// The device sent `ok<n>`, acknowledging the packet whose sync is `n`.
    Ack(u8),
    /// The device sent `rs<n>`, asking for the packet whose sync is `n` to
    /// be transmitted again.
    ResendRequested(u8),
    /// A line that matched none of the control tokens, passed through
    /// verbatim. The file-transfer layer reads its `PFT:*` replies here.
    AsciiLine(String),
    /// The device sent `fe`, signalling a fatal protocol error.
    FatalError,
    /// An `ok<m>` arrived whose number differs from the sync of the packet
    /// in flight. The protocol cannot resynchronise mid-stream, so the
    /// caller has to run [`Session::reset`] followed by
    /// [`Session::connect`] to recover.
    OutOfSync {
        /// Sync number the session was waiting to see acknowledged.
        expected: u8,
        /// Sync number the device actually acknowledged.
        got: u8,
    },
    /// The packet in flight used up its total retransmit budget without
    /// being acknowledged and has been discarded.
    Timeout {
        /// Sync number carried by the discarded packet.
        sync: u8,
    },
}
78
/// Bookkeeping for the single packet that has been sent and is still
/// waiting for its acknowledgement.
#[derive(Debug)]
struct InFlight {
    /// Sync number the packet was encoded with.
    sync: u8,
    /// Encoded packet bytes, retained so a retransmit can resend them as-is.
    bytes: Vec<u8>,
    /// When the packet was first dispatched; the total budget counts from here.
    first_sent: Instant,
    /// When the packet was most recently (re)sent; the per-attempt timeout
    /// counts from here.
    last_sent: Instant,
    /// Marks the SYNC handshake packet, which is completed by an `ss` reply
    /// instead of `ok<n>`.
    is_sync_handshake: bool,
}
89
90#[derive(Debug)]
91struct Queued {
92    /// Will be assigned a sync number at dispatch time so retransmits in
93    /// the meantime don't bump the counter underneath us.
94    bytes_without_sync: BytesBuilder,
95    is_sync_handshake: bool,
96}
97
/// Everything needed to encode a packet except its sync number.
///
/// The header is produced at dispatch time so the packet picks up the
/// counter value most recently `Synced` from the device, and so every
/// retransmit of that packet reuses the same number.
#[derive(Debug, Clone)]
struct BytesBuilder {
    /// Protocol id placed in the packet header.
    protocol: u8,
    /// Packet type placed in the packet header.
    packet_type: u8,
    /// Payload bytes carried by the packet.
    payload: Vec<u8>,
}
108
109impl BytesBuilder {
110    fn build(&self, sync: u8) -> Vec<u8> {
111        let mut out = Vec::with_capacity(codec::HEADER_LEN + self.payload.len() + 2);
112        let pkt = Packet::new(sync, self.protocol, self.packet_type, &self.payload)
113            .expect("session validates protocol/type/length at queue time");
114        codec::encode(&pkt, &mut out).expect("validation already passed");
115        out
116    }
117}
118
119/// Sans-I/O session driver.
120#[derive(Debug)]
121pub struct Session {
122    sync: u8,
123    is_synced: bool,
124    max_block_size: Option<u16>,
125    protocol_version: Option<String>,
126
127    in_flight: Option<InFlight>,
128    queued: VecDeque<Queued>,
129    outbound: VecDeque<Vec<u8>>,
130    events: VecDeque<Event>,
131    inbound_buf: Vec<u8>,
132
133    response_timeout: Duration,
134    total_timeout: Duration,
135}
136
137impl Default for Session {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143impl Session {
144    /// Construct a fresh, unconnected session with default timeouts.
145    pub fn new() -> Self {
146        Self {
147            sync: 0,
148            is_synced: false,
149            max_block_size: None,
150            protocol_version: None,
151            in_flight: None,
152            queued: VecDeque::new(),
153            outbound: VecDeque::new(),
154            events: VecDeque::new(),
155            inbound_buf: Vec::with_capacity(256),
156            response_timeout: DEFAULT_RESPONSE_TIMEOUT,
157            total_timeout: DEFAULT_TOTAL_TIMEOUT,
158        }
159    }
160
161    /// Set the per-attempt response timeout. Retransmits fire after this
162    /// long without an ack.
163    pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
164        self.response_timeout = timeout;
165        self
166    }
167
168    /// Set the total budget for a single packet across all retransmits.
169    /// When this elapses, the packet is dropped and [`Event::Timeout`] is
170    /// emitted.
171    pub fn with_total_timeout(mut self, timeout: Duration) -> Self {
172        self.total_timeout = timeout;
173        self
174    }
175
176    /// Current per-attempt response timeout. Adapters wrap their inbound
177    /// reads in this so [`Self::tick`] can fire even when the transport
178    /// stays idle.
179    pub fn response_timeout(&self) -> Duration {
180        self.response_timeout
181    }
182
183    /// Current total per-packet budget across all retransmits.
184    pub fn total_timeout(&self) -> Duration {
185        self.total_timeout
186    }
187
188    /// True once an `ss` handshake reply has been received.
189    pub fn is_synced(&self) -> bool {
190        self.is_synced
191    }
192
193    /// Device-advertised maximum payload bytes per packet, set during the
194    /// SYNC handshake.
195    pub fn max_block_size(&self) -> Option<u16> {
196        self.max_block_size
197    }
198
199    /// Device-advertised protocol version, set during the SYNC handshake.
200    pub fn protocol_version(&self) -> Option<&str> {
201        self.protocol_version.as_deref()
202    }
203
204    /// Current sync counter value. Diagnostic accessor used in tests; not
205    /// load-bearing for protocol clients.
206    pub fn current_sync(&self) -> u8 {
207        self.sync
208    }
209
210    /// True if a packet is currently in flight (sent, awaiting ack).
211    pub fn has_pending(&self) -> bool {
212        self.in_flight.is_some()
213    }
214
215    /// Reset the session to its construction baseline.
216    ///
217    /// Drops any in-flight packet, queued packets, pending outbound
218    /// bytes, pending events, and inbound buffer; clears the sync
219    /// counter, sync state, and device-advertised values. Timeouts
220    /// (`response_timeout` / `total_timeout`) are preserved.
221    ///
222    /// Use this after observing [`Event::OutOfSync`] before calling
223    /// [`connect`](Self::connect) again: the BFT protocol has no
224    /// way to resynchronise mid-stream, so the only recovery path is
225    /// to clear local state and redo the handshake from scratch.
226    pub fn reset(&mut self) {
227        self.sync = 0;
228        self.is_synced = false;
229        self.max_block_size = None;
230        self.protocol_version = None;
231        self.in_flight = None;
232        self.queued.clear();
233        self.outbound.clear();
234        self.events.clear();
235        self.inbound_buf.clear();
236    }
237
238    /// Queue the SYNC control packet (protocol=0, packet_type=1).
239    /// The caller should already have written the ASCII trigger
240    /// `b"M28B1\n"` before calling this.
241    pub fn connect(&mut self, now: Instant) {
242        self.queue(0, 1, &[], /* is_sync_handshake = */ true);
243        self.dispatch_if_idle(now);
244    }
245
246    /// Queue a binary packet for transmission.
247    ///
248    /// If the session is idle (no packet in flight), the bytes are pushed
249    /// to the outbound queue immediately. Otherwise the packet waits until
250    /// the in-flight packet is acked.
251    ///
252    /// # Panics
253    ///
254    /// Panics if:
255    /// - the session has not yet observed an `ss` handshake reply
256    ///   (`is_synced() == false`) — without a device-confirmed sync
257    ///   counter, the packet would go out with `sync=0` and almost
258    ///   certainly desynchronise the protocol;
259    /// - `protocol > 15`, `packet_type > 15`, or the payload is longer
260    ///   than [`codec::MAX_PAYLOAD`].
261    ///
262    /// All of the above are programmer errors; production callers should
263    /// drive [`connect`](Self::connect) to completion, observe
264    /// [`Event::Synced`], and clamp payload size to
265    /// [`max_block_size`](Self::max_block_size).
266    pub fn send(&mut self, protocol: u8, packet_type: u8, payload: &[u8], now: Instant) {
267        assert!(
268            self.is_synced,
269            "Session::send called before SYNC handshake completed; call connect() and drive feed() until Event::Synced first"
270        );
271        assert!(protocol <= 0xF, "protocol id out of range");
272        assert!(packet_type <= 0xF, "packet type out of range");
273        assert!(
274            payload.len() <= codec::MAX_PAYLOAD,
275            "payload exceeds MAX_PAYLOAD"
276        );
277        self.queue(protocol, packet_type, payload, false);
278        self.dispatch_if_idle(now);
279    }
280
281    fn queue(&mut self, protocol: u8, packet_type: u8, payload: &[u8], is_sync_handshake: bool) {
282        self.queued.push_back(Queued {
283            bytes_without_sync: BytesBuilder {
284                protocol,
285                packet_type,
286                payload: payload.to_vec(),
287            },
288            is_sync_handshake,
289        });
290    }
291
292    fn dispatch_if_idle(&mut self, now: Instant) {
293        if self.in_flight.is_some() {
294            return;
295        }
296        let Some(next) = self.queued.pop_front() else {
297            return;
298        };
299        let bytes = next.bytes_without_sync.build(self.sync);
300        self.outbound.push_back(bytes.clone());
301        self.in_flight = Some(InFlight {
302            sync: self.sync,
303            bytes,
304            first_sent: now,
305            last_sent: now,
306            is_sync_handshake: next.is_sync_handshake,
307        });
308    }
309
310    /// Drain a single chunk of bytes the caller should write to the wire.
311    /// Returns `None` when no more bytes are pending.
312    pub fn poll_outbound(&mut self) -> Option<Vec<u8>> {
313        self.outbound.pop_front()
314    }
315
316    /// Push received bytes from the wire. Bytes are accumulated until a
317    /// newline-terminated ASCII line is recognised, at which point an
318    /// [`Event`] is queued for [`poll_event`](Self::poll_event).
319    ///
320    /// `now` is used to timestamp any queued packet that gets dispatched
321    /// as a side effect of an inbound ack — fully sans-I/O, no internal
322    /// wall-clock reads.
323    pub fn feed(&mut self, bytes: &[u8], now: Instant) {
324        self.inbound_buf.extend_from_slice(bytes);
325        while let Some(pos) = self.inbound_buf.iter().position(|&b| b == b'\n') {
326            let line: Vec<u8> = self.inbound_buf.drain(..=pos).collect();
327            // Strip trailing \r and \n.
328            let trimmed = strip_line_endings(&line);
329            if trimmed.is_empty() {
330                continue;
331            }
332            self.process_line(trimmed, now);
333        }
334    }
335
336    fn process_line(&mut self, line: &[u8], now: Instant) {
337        if let Some(rest) = strip_prefix(line, b"ok") {
338            if let Some(n) = parse_decimal_u8(rest) {
339                self.handle_ok(n, now);
340                return;
341            }
342        }
343        if let Some(rest) = strip_prefix(line, b"rs") {
344            if let Some(n) = parse_decimal_u8(rest) {
345                self.events.push_back(Event::ResendRequested(n));
346                return;
347            }
348        }
349        if let Some(rest) = strip_prefix(line, b"ss") {
350            self.handle_ss(rest, now);
351            return;
352        }
353        if line == b"fe" {
354            self.events.push_back(Event::FatalError);
355            return;
356        }
357        // Anything else: pass through as a UTF-8 string so file_transfer
358        // (or arbitrary callers) can match on PFT:* tokens.
359        match std::str::from_utf8(line) {
360            Ok(s) => self.events.push_back(Event::AsciiLine(s.to_string())),
361            Err(_) => {
362                // Non-UTF-8 garbage: drop, the caller can't usefully parse it.
363            }
364        }
365    }
366
367    fn handle_ok(&mut self, n: u8, now: Instant) {
368        let Some(flight) = self.in_flight.as_ref() else {
369            // No outstanding packet — stray ack. Surface as a passthrough so
370            // callers can debug, but don't crash.
371            self.events.push_back(Event::AsciiLine(format!("ok{n}")));
372            return;
373        };
374        if flight.is_sync_handshake {
375            // SYNC handshake is acked with `ss`, not `ok`. Treat this as
376            // out-of-sync.
377            self.events.push_back(Event::OutOfSync {
378                expected: flight.sync,
379                got: n,
380            });
381            return;
382        }
383        if n != flight.sync {
384            self.events.push_back(Event::OutOfSync {
385                expected: flight.sync,
386                got: n,
387            });
388            return;
389        }
390        self.in_flight = None;
391        self.sync = ((self.sync as u16 + 1) % SYNC_MOD) as u8;
392        self.events.push_back(Event::Ack(n));
393        self.dispatch_if_idle(now);
394    }
395
396    fn handle_ss(&mut self, rest: &[u8], now: Instant) {
397        let s = match std::str::from_utf8(rest) {
398            Ok(s) => s,
399            Err(_) => return,
400        };
401        let mut parts = s.splitn(3, ',');
402        let (Some(sync_str), Some(bsize_str), Some(version_str)) =
403            (parts.next(), parts.next(), parts.next())
404        else {
405            return;
406        };
407        let Ok(new_sync) = sync_str.trim().parse::<u16>() else {
408            return;
409        };
410        let Ok(max_block_size) = bsize_str.trim().parse::<u16>() else {
411            return;
412        };
413        let new_sync = (new_sync % SYNC_MOD) as u8;
414        self.sync = new_sync;
415        self.max_block_size = Some(max_block_size);
416        let protocol_version = version_str.trim().to_string();
417        self.protocol_version = Some(protocol_version.clone());
418        self.is_synced = true;
419        // SS consumes the in-flight SYNC packet (if any).
420        if let Some(flight) = self.in_flight.as_ref() {
421            if flight.is_sync_handshake {
422                self.in_flight = None;
423            }
424        }
425        self.events.push_back(Event::Synced {
426            max_block_size,
427            protocol_version,
428        });
429        self.dispatch_if_idle(now);
430    }
431
432    /// Drain the next queued event. Returns `None` when the queue is empty.
433    pub fn poll_event(&mut self) -> Option<Event> {
434        self.events.pop_front()
435    }
436
437    /// Drive retransmit and total-timeout logic. Callers should call this
438    /// at least as often as the per-attempt response timeout.
439    pub fn tick(&mut self, now: Instant) {
440        let Some(flight) = self.in_flight.as_mut() else {
441            return;
442        };
443        if now.saturating_duration_since(flight.first_sent) >= self.total_timeout {
444            let sync = flight.sync;
445            self.in_flight = None;
446            self.events.push_back(Event::Timeout { sync });
447            self.dispatch_if_idle(now);
448            return;
449        }
450        if now.saturating_duration_since(flight.last_sent) >= self.response_timeout {
451            // Retransmit.
452            self.outbound.push_back(flight.bytes.clone());
453            flight.last_sent = now;
454        }
455    }
456}
457
/// Return `line` without its trailing run of `\n` and `\r` bytes.
///
/// Line-ending bytes elsewhere in the slice are left untouched; a slice made
/// up only of line-ending bytes yields an empty slice.
fn strip_line_endings(line: &[u8]) -> &[u8] {
    // Index just past the last byte that is not a line-ending character,
    // or 0 when there is no such byte.
    let kept = line
        .iter()
        .rposition(|&byte| byte != b'\n' && byte != b'\r')
        .map_or(0, |last| last + 1);
    &line[..kept]
}
465
/// Return the part of `line` that follows `prefix`, or `None` when `line`
/// does not begin with `prefix`. An empty `prefix` matches every line.
fn strip_prefix<'a>(line: &'a [u8], prefix: &[u8]) -> Option<&'a [u8]> {
    // Delegates to the standard slice method of the same name.
    line.strip_prefix(prefix)
}
473
/// Parse `b` as a decimal number in `0..=255`.
///
/// Surrounding whitespace is ignored. Returns `None` for empty input,
/// non-UTF-8 bytes, anything that is not a decimal integer, and values
/// above 255.
fn parse_decimal_u8(b: &[u8]) -> Option<u8> {
    // Parsing straight into `u8` rejects out-of-range values by itself, and
    // an empty (or all-whitespace) string fails to parse as well.
    std::str::from_utf8(b).ok()?.trim().parse::<u8>().ok()
}