marlin_binary_transfer/session.rs
1//! Layer 2: sans-I/O session state machine.
2//!
3//! Owns the sync counter, the outbound queue, and the inbound ASCII line
4//! parser. Callers drive it with `feed` / `poll_outbound` / `poll_event` /
5//! `tick`, plumbing the bytes through any I/O of their choice.
6//!
7//! # Lifecycle
8//!
9//! 1. Caller writes the ASCII trigger `b"M28B1\n"` to the device.
10//! 2. Caller calls [`Session::connect`], drains `poll_outbound`, writes
11//! those bytes to the device.
12//! 3. Caller reads bytes from the device and pushes them via
13//! [`Session::feed`].
14//! 4. Caller drains [`Session::poll_event`] until [`Event::Synced`] is
15//! observed — the device has acknowledged the SYNC and reported its
16//! block size and protocol version.
17//! 5. Caller calls [`Session::send`] for each subsequent binary packet,
18//! pumping `poll_outbound` / `feed` / `poll_event` as before, calling
19//! [`Session::tick`] periodically so retransmits fire on timeout.
20//!
21//! # Concurrency model
22//!
23//! Mirrors the Python reference: only one packet is in flight at a time.
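//! Calls to [`Session::send`] while a packet is in flight are queued FIFO
//! and dispatched as each ack arrives (or as soon as the in-flight packet
//! is dropped for exceeding its total timeout in [`Session::tick`]).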
26
27use std::collections::VecDeque;
28use std::time::{Duration, Instant};
29
30use crate::codec::{self, Packet};
31
/// Modulus of the sync counter: sync numbers run `0..SYNC_MOD` (0..=255)
/// and wrap back to 0 after 255.
const SYNC_MOD: u16 = 256;

/// Default per-attempt response timeout: the in-flight packet is
/// retransmitted after this long without being acknowledged.
const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_millis(1000);
/// Default total budget for a single packet across all retransmits
/// (20 attempts at the default per-attempt timeout).
const DEFAULT_TOTAL_TIMEOUT: Duration = Duration::from_secs(20);
39
/// Things the session emits to the caller as parsed bytes arrive.
///
/// Events are queued in arrival order and drained with
/// [`Session::poll_event`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// A well-formed `ss<sync>,<block size>,<version>` line was parsed.
    ///
    /// Normally this completes the SYNC handshake started by
    /// [`Session::connect`], but it is emitted for any well-formed `ss`
    /// line, including one that arrives unsolicited.
    Synced {
        /// Device-advertised maximum payload bytes per packet.
        max_block_size: u16,
        /// Device-advertised protocol version string.
        protocol_version: String,
    },
    /// An `ok<n>` line acknowledged the in-flight packet, whose sync is `n`.
    Ack(u8),
    /// A `rs<n>` line was received: the device is requesting we retransmit
    /// the packet with sync `n`.
    ///
    /// The session only reports this; it does not resend immediately. The
    /// in-flight packet is retransmitted when [`Session::tick`] observes
    /// that the per-attempt response timeout has elapsed.
    ResendRequested(u8),
    /// A non-empty UTF-8 line that did not match any known control token.
    /// The file-transfer layer consumes these to parse `PFT:*` replies.
    ///
    /// An `ok<n>` received while no packet is in flight is also surfaced
    /// here, re-rendered as `ok{n}`. Non-UTF-8 lines and lines starting
    /// with `ss` are never passed through.
    AsciiLine(String),
    /// A `fe` line was received: device reports a fatal protocol error.
    FatalError,
    /// An `ok<m>` arrived that does not acknowledge the in-flight packet:
    /// either `m` differs from that packet's sync, or the in-flight packet
    /// is the SYNC handshake (which is answered with `ss`, never `ok`).
    /// Recovery requires calling [`Session::reset`] then
    /// [`Session::connect`] — the protocol has no way to resynchronise
    /// mid-stream.
    OutOfSync {
        /// Sync number of the in-flight packet.
        expected: u8,
        /// Sync number the device acked.
        got: u8,
    },
    /// The in-flight packet exceeded its total retransmit budget without an
    /// ack. It has been dropped, and the next queued packet (if any) has
    /// been dispatched in its place.
    Timeout {
        /// Sync number of the packet that timed out.
        sync: u8,
    },
}
78
/// The single packet currently on the wire, awaiting its reply.
#[derive(Debug)]
struct InFlight {
    /// Sync number encoded into `bytes`; the `ok<n>` that acknowledges this
    /// packet must carry the same value.
    sync: u8,
    /// Fully encoded packet, re-sent verbatim on every retransmit.
    bytes: Vec<u8>,
    /// When the packet was first dispatched; the total timeout is measured
    /// from here.
    first_sent: Instant,
    /// When the packet was last (re)transmitted; the per-attempt response
    /// timeout is measured from here.
    last_sent: Instant,
    /// True when this packet *is* the SYNC handshake (queued by `connect`).
    /// Such a packet is completed by an `ss` reply rather than `ok<n>`; an
    /// `ok` arriving while it is in flight is reported as out-of-sync.
    is_sync_handshake: bool,
}
89
90#[derive(Debug)]
91struct Queued {
92 /// Will be assigned a sync number at dispatch time so retransmits in
93 /// the meantime don't bump the counter underneath us.
94 bytes_without_sync: BytesBuilder,
95 is_sync_handshake: bool,
96}
97
/// The sync-independent inputs of a packet: protocol id, packet type and
/// payload.
///
/// The packet (header included) is encoded only when it is dispatched, so
/// it carries the session's sync counter as of dispatch time. Once encoded,
/// the bytes are stored on the in-flight record and every retransmit reuses
/// them verbatim, so all attempts carry the same sync number.
#[derive(Debug, Clone)]
struct BytesBuilder {
    /// Protocol id (`send` asserts it is at most 15).
    protocol: u8,
    /// Packet type within the protocol (`send` asserts it is at most 15).
    packet_type: u8,
    /// Raw payload bytes.
    payload: Vec<u8>,
}
108
109impl BytesBuilder {
110 fn build(&self, sync: u8) -> Vec<u8> {
111 let mut out = Vec::with_capacity(codec::HEADER_LEN + self.payload.len() + 2);
112 let pkt = Packet::new(sync, self.protocol, self.packet_type, &self.payload)
113 .expect("session validates protocol/type/length at queue time");
114 codec::encode(&pkt, &mut out).expect("validation already passed");
115 out
116 }
117}
118
/// Sans-I/O session driver.
///
/// Holds all protocol state and performs no I/O itself. Bytes to send are
/// drained with `poll_outbound`, received bytes are pushed with `feed`, and
/// the caller passes the current `Instant` to every time-dependent call.
#[derive(Debug)]
pub struct Session {
    /// Sync number stamped on the next dispatched packet. Advanced (mod
    /// `SYNC_MOD`) on each matching `ok<n>` and overwritten by `ss`.
    sync: u8,
    /// Set once a well-formed `ss` reply is parsed; cleared by `reset`.
    is_synced: bool,
    /// Block size reported by the last `ss` reply.
    max_block_size: Option<u16>,
    /// Protocol version reported by the last `ss` reply.
    protocol_version: Option<String>,

    /// The one packet on the wire awaiting its reply, if any.
    in_flight: Option<InFlight>,
    /// Packets waiting for `in_flight` to clear, dispatched FIFO.
    queued: VecDeque<Queued>,
    /// Encoded byte chunks waiting to be drained by `poll_outbound`.
    outbound: VecDeque<Vec<u8>>,
    /// Parsed events waiting to be drained by `poll_event`.
    events: VecDeque<Event>,
    /// Received bytes after the last `\n` seen by `feed` (a partial line).
    inbound_buf: Vec<u8>,

    /// Per-attempt wait before `tick` retransmits the in-flight packet.
    response_timeout: Duration,
    /// Total budget per packet before `tick` drops it with `Event::Timeout`.
    total_timeout: Duration,
}
136
137impl Default for Session {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143impl Session {
144 /// Construct a fresh, unconnected session with default timeouts.
145 pub fn new() -> Self {
146 Self {
147 sync: 0,
148 is_synced: false,
149 max_block_size: None,
150 protocol_version: None,
151 in_flight: None,
152 queued: VecDeque::new(),
153 outbound: VecDeque::new(),
154 events: VecDeque::new(),
155 inbound_buf: Vec::with_capacity(256),
156 response_timeout: DEFAULT_RESPONSE_TIMEOUT,
157 total_timeout: DEFAULT_TOTAL_TIMEOUT,
158 }
159 }
160
161 /// Set the per-attempt response timeout. Retransmits fire after this
162 /// long without an ack.
163 pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
164 self.response_timeout = timeout;
165 self
166 }
167
168 /// Set the total budget for a single packet across all retransmits.
169 /// When this elapses, the packet is dropped and [`Event::Timeout`] is
170 /// emitted.
171 pub fn with_total_timeout(mut self, timeout: Duration) -> Self {
172 self.total_timeout = timeout;
173 self
174 }
175
176 /// Current per-attempt response timeout. Adapters wrap their inbound
177 /// reads in this so [`Self::tick`] can fire even when the transport
178 /// stays idle.
179 pub fn response_timeout(&self) -> Duration {
180 self.response_timeout
181 }
182
183 /// Current total per-packet budget across all retransmits.
184 pub fn total_timeout(&self) -> Duration {
185 self.total_timeout
186 }
187
188 /// True once an `ss` handshake reply has been received.
189 pub fn is_synced(&self) -> bool {
190 self.is_synced
191 }
192
193 /// Device-advertised maximum payload bytes per packet, set during the
194 /// SYNC handshake.
195 pub fn max_block_size(&self) -> Option<u16> {
196 self.max_block_size
197 }
198
199 /// Device-advertised protocol version, set during the SYNC handshake.
200 pub fn protocol_version(&self) -> Option<&str> {
201 self.protocol_version.as_deref()
202 }
203
204 /// Current sync counter value. Diagnostic accessor used in tests; not
205 /// load-bearing for protocol clients.
206 pub fn current_sync(&self) -> u8 {
207 self.sync
208 }
209
210 /// True if a packet is currently in flight (sent, awaiting ack).
211 pub fn has_pending(&self) -> bool {
212 self.in_flight.is_some()
213 }
214
215 /// Reset the session to its construction baseline.
216 ///
217 /// Drops any in-flight packet, queued packets, pending outbound
218 /// bytes, pending events, and inbound buffer; clears the sync
219 /// counter, sync state, and device-advertised values. Timeouts
220 /// (`response_timeout` / `total_timeout`) are preserved.
221 ///
222 /// Use this after observing [`Event::OutOfSync`] before calling
223 /// [`connect`](Self::connect) again: the BFT protocol has no
224 /// way to resynchronise mid-stream, so the only recovery path is
225 /// to clear local state and redo the handshake from scratch.
226 pub fn reset(&mut self) {
227 self.sync = 0;
228 self.is_synced = false;
229 self.max_block_size = None;
230 self.protocol_version = None;
231 self.in_flight = None;
232 self.queued.clear();
233 self.outbound.clear();
234 self.events.clear();
235 self.inbound_buf.clear();
236 }
237
238 /// Queue the SYNC control packet (protocol=0, packet_type=1).
239 /// The caller should already have written the ASCII trigger
240 /// `b"M28B1\n"` before calling this.
241 pub fn connect(&mut self, now: Instant) {
242 self.queue(0, 1, &[], /* is_sync_handshake = */ true);
243 self.dispatch_if_idle(now);
244 }
245
246 /// Queue a binary packet for transmission.
247 ///
248 /// If the session is idle (no packet in flight), the bytes are pushed
249 /// to the outbound queue immediately. Otherwise the packet waits until
250 /// the in-flight packet is acked.
251 ///
252 /// # Panics
253 ///
254 /// Panics if:
255 /// - the session has not yet observed an `ss` handshake reply
256 /// (`is_synced() == false`) — without a device-confirmed sync
257 /// counter, the packet would go out with `sync=0` and almost
258 /// certainly desynchronise the protocol;
259 /// - `protocol > 15`, `packet_type > 15`, or the payload is longer
260 /// than [`codec::MAX_PAYLOAD`].
261 ///
262 /// All of the above are programmer errors; production callers should
263 /// drive [`connect`](Self::connect) to completion, observe
264 /// [`Event::Synced`], and clamp payload size to
265 /// [`max_block_size`](Self::max_block_size).
266 pub fn send(&mut self, protocol: u8, packet_type: u8, payload: &[u8], now: Instant) {
267 assert!(
268 self.is_synced,
269 "Session::send called before SYNC handshake completed; call connect() and drive feed() until Event::Synced first"
270 );
271 assert!(protocol <= 0xF, "protocol id out of range");
272 assert!(packet_type <= 0xF, "packet type out of range");
273 assert!(
274 payload.len() <= codec::MAX_PAYLOAD,
275 "payload exceeds MAX_PAYLOAD"
276 );
277 self.queue(protocol, packet_type, payload, false);
278 self.dispatch_if_idle(now);
279 }
280
281 fn queue(&mut self, protocol: u8, packet_type: u8, payload: &[u8], is_sync_handshake: bool) {
282 self.queued.push_back(Queued {
283 bytes_without_sync: BytesBuilder {
284 protocol,
285 packet_type,
286 payload: payload.to_vec(),
287 },
288 is_sync_handshake,
289 });
290 }
291
292 fn dispatch_if_idle(&mut self, now: Instant) {
293 if self.in_flight.is_some() {
294 return;
295 }
296 let Some(next) = self.queued.pop_front() else {
297 return;
298 };
299 let bytes = next.bytes_without_sync.build(self.sync);
300 self.outbound.push_back(bytes.clone());
301 self.in_flight = Some(InFlight {
302 sync: self.sync,
303 bytes,
304 first_sent: now,
305 last_sent: now,
306 is_sync_handshake: next.is_sync_handshake,
307 });
308 }
309
310 /// Drain a single chunk of bytes the caller should write to the wire.
311 /// Returns `None` when no more bytes are pending.
312 pub fn poll_outbound(&mut self) -> Option<Vec<u8>> {
313 self.outbound.pop_front()
314 }
315
316 /// Push received bytes from the wire. Bytes are accumulated until a
317 /// newline-terminated ASCII line is recognised, at which point an
318 /// [`Event`] is queued for [`poll_event`](Self::poll_event).
319 ///
320 /// `now` is used to timestamp any queued packet that gets dispatched
321 /// as a side effect of an inbound ack — fully sans-I/O, no internal
322 /// wall-clock reads.
323 pub fn feed(&mut self, bytes: &[u8], now: Instant) {
324 self.inbound_buf.extend_from_slice(bytes);
325 while let Some(pos) = self.inbound_buf.iter().position(|&b| b == b'\n') {
326 let line: Vec<u8> = self.inbound_buf.drain(..=pos).collect();
327 // Strip trailing \r and \n.
328 let trimmed = strip_line_endings(&line);
329 if trimmed.is_empty() {
330 continue;
331 }
332 self.process_line(trimmed, now);
333 }
334 }
335
336 fn process_line(&mut self, line: &[u8], now: Instant) {
337 if let Some(rest) = strip_prefix(line, b"ok") {
338 if let Some(n) = parse_decimal_u8(rest) {
339 self.handle_ok(n, now);
340 return;
341 }
342 }
343 if let Some(rest) = strip_prefix(line, b"rs") {
344 if let Some(n) = parse_decimal_u8(rest) {
345 self.events.push_back(Event::ResendRequested(n));
346 return;
347 }
348 }
349 if let Some(rest) = strip_prefix(line, b"ss") {
350 self.handle_ss(rest, now);
351 return;
352 }
353 if line == b"fe" {
354 self.events.push_back(Event::FatalError);
355 return;
356 }
357 // Anything else: pass through as a UTF-8 string so file_transfer
358 // (or arbitrary callers) can match on PFT:* tokens.
359 match std::str::from_utf8(line) {
360 Ok(s) => self.events.push_back(Event::AsciiLine(s.to_string())),
361 Err(_) => {
362 // Non-UTF-8 garbage: drop, the caller can't usefully parse it.
363 }
364 }
365 }
366
367 fn handle_ok(&mut self, n: u8, now: Instant) {
368 let Some(flight) = self.in_flight.as_ref() else {
369 // No outstanding packet — stray ack. Surface as a passthrough so
370 // callers can debug, but don't crash.
371 self.events.push_back(Event::AsciiLine(format!("ok{n}")));
372 return;
373 };
374 if flight.is_sync_handshake {
375 // SYNC handshake is acked with `ss`, not `ok`. Treat this as
376 // out-of-sync.
377 self.events.push_back(Event::OutOfSync {
378 expected: flight.sync,
379 got: n,
380 });
381 return;
382 }
383 if n != flight.sync {
384 self.events.push_back(Event::OutOfSync {
385 expected: flight.sync,
386 got: n,
387 });
388 return;
389 }
390 self.in_flight = None;
391 self.sync = ((self.sync as u16 + 1) % SYNC_MOD) as u8;
392 self.events.push_back(Event::Ack(n));
393 self.dispatch_if_idle(now);
394 }
395
396 fn handle_ss(&mut self, rest: &[u8], now: Instant) {
397 let s = match std::str::from_utf8(rest) {
398 Ok(s) => s,
399 Err(_) => return,
400 };
401 let mut parts = s.splitn(3, ',');
402 let (Some(sync_str), Some(bsize_str), Some(version_str)) =
403 (parts.next(), parts.next(), parts.next())
404 else {
405 return;
406 };
407 let Ok(new_sync) = sync_str.trim().parse::<u16>() else {
408 return;
409 };
410 let Ok(max_block_size) = bsize_str.trim().parse::<u16>() else {
411 return;
412 };
413 let new_sync = (new_sync % SYNC_MOD) as u8;
414 self.sync = new_sync;
415 self.max_block_size = Some(max_block_size);
416 let protocol_version = version_str.trim().to_string();
417 self.protocol_version = Some(protocol_version.clone());
418 self.is_synced = true;
419 // SS consumes the in-flight SYNC packet (if any).
420 if let Some(flight) = self.in_flight.as_ref() {
421 if flight.is_sync_handshake {
422 self.in_flight = None;
423 }
424 }
425 self.events.push_back(Event::Synced {
426 max_block_size,
427 protocol_version,
428 });
429 self.dispatch_if_idle(now);
430 }
431
432 /// Drain the next queued event. Returns `None` when the queue is empty.
433 pub fn poll_event(&mut self) -> Option<Event> {
434 self.events.pop_front()
435 }
436
437 /// Drive retransmit and total-timeout logic. Callers should call this
438 /// at least as often as the per-attempt response timeout.
439 pub fn tick(&mut self, now: Instant) {
440 let Some(flight) = self.in_flight.as_mut() else {
441 return;
442 };
443 if now.saturating_duration_since(flight.first_sent) >= self.total_timeout {
444 let sync = flight.sync;
445 self.in_flight = None;
446 self.events.push_back(Event::Timeout { sync });
447 self.dispatch_if_idle(now);
448 return;
449 }
450 if now.saturating_duration_since(flight.last_sent) >= self.response_timeout {
451 // Retransmit.
452 self.outbound.push_back(flight.bytes.clone());
453 flight.last_sent = now;
454 }
455 }
456}
457
/// Trim every trailing `\r` and `\n` byte from `line`; interior bytes
/// (including a lone `\r` mid-line) are left untouched.
fn strip_line_endings(line: &[u8]) -> &[u8] {
    match line.iter().rposition(|&byte| byte != b'\n' && byte != b'\r') {
        Some(last_kept) => &line[..=last_kept],
        None => &line[..0],
    }
}
465
/// Return the rest of `line` after `prefix`, or `None` if `line` does not
/// start with `prefix`. Thin wrapper over the standard
/// `<[T]>::strip_prefix`.
fn strip_prefix<'a>(line: &'a [u8], prefix: &[u8]) -> Option<&'a [u8]> {
    line.strip_prefix(prefix)
}
473
/// Parse an ASCII decimal sync number, allowing surrounding whitespace.
///
/// Returns `None` for empty input, non-UTF-8 bytes, non-digit text, or
/// values above 255. `u8`'s own `FromStr` enforces the range, so no manual
/// bound check is needed.
fn parse_decimal_u8(b: &[u8]) -> Option<u8> {
    std::str::from_utf8(b).ok()?.trim().parse().ok()
}