use std::time::Duration;
use std::io::{Read, Write};
use thiserror::Error;
/// Ctrl-A: sent by `enter_raw_mode` to switch the device into raw REPL mode.
const CTRL_A: u8 = b'\x01';
/// Ctrl-B: sent by `exit_raw_mode` to leave raw REPL mode; also the last byte of
/// the normalization sequence `enter_raw_mode` sends first.
const CTRL_B: u8 = b'\x02';
/// Ctrl-C: interrupt; sent twice at the start of `enter_raw_mode`.
const CTRL_C: u8 = b'\x03';
/// Ctrl-D: terminates a code submission in `exec_code` and delimits the
/// stdout / stderr sections of a response frame.
const CTRL_D: u8 = b'\x04';
/// Text the device prints once raw REPL mode is active; `enter_raw_mode` waits
/// until the received bytes end with it.
const RAW_MODE_PROMPT: &[u8] = b"raw REPL; CTRL-B to exit\r\n>";
/// Upper bound (128 KiB) on the size of one response frame read by `exec_code`.
const MAX_FRAME_BYTES: usize = 128 * 1024;
/// Pause, in milliseconds, between the Ctrl-C/Ctrl-C/Ctrl-B sequence and the
/// Ctrl-A in `enter_raw_mode`.
const NORMALIZE_DRAIN_MS: u64 = 100;
/// How many bytes `enter_raw_mode` scans for `RAW_MODE_PROMPT` before giving up;
/// generous so that leftover output ahead of the prompt is tolerated.
const MAX_PROMPT_BYTES_RELAXED: usize = 4096;
/// Errors produced while talking to a MicroPython device over its raw REPL.
#[derive(Debug, Error)]
pub enum ReplError {
    /// Reading from or writing to the port failed.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The awaited byte sequence never appeared within the byte budget of
    /// `read_until_suffix` (used by `enter_raw_mode` to wait for
    /// `RAW_MODE_PROMPT`). `got` holds everything read up to that point.
    #[error("expected raw-mode prompt 'raw REPL; CTRL-B to exit\\r\\n>', got: {got:?}")]
    HandshakeMismatch { got: Vec<u8> },
    /// A response frame did not begin with the `OK` acknowledgement.
    #[error("response frame did not start with 'OK', got: {got:?}")]
    FrameMissingOk { got: Vec<u8> },
    /// A response frame had no 0x04 byte terminating the stdout section.
    #[error("response frame missing 0x04 separator after stdout, got: {got:?}")]
    FrameMissingFirstSeparator { got: Vec<u8> },
    /// A response frame had no 0x04 byte terminating the stderr section, or the
    /// frame exceeded its size cap before its end was seen.
    #[error("response frame missing 0x04 end-of-frame marker, got: {got:?}")]
    FrameMissingEndMarker { got: Vec<u8> },
    /// Code run by `ping` wrote to stderr. Note that `exec_code` itself does not
    /// use this variant: it returns device-side stderr inside `ReplResponse`.
    #[error("MicroPython error: {stderr}")]
    PythonError { stdout: String, stderr: String },
    /// `ping` ran cleanly but printed something other than the expected value.
    #[error("ping failed: expected stdout {expected:?}, got {got:?}")]
    PingFailed { expected: String, got: String },
}
/// Decoded output of one code submission through the raw REPL.
///
/// Both fields are decoded from the device's bytes with
/// `String::from_utf8_lossy`, so invalid UTF-8 shows up as U+FFFD instead of
/// failing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplResponse {
    /// Bytes the code wrote to stdout (the section between `OK` and the first 0x04).
    pub stdout: String,
    /// Bytes the device reported on stderr (the section between the two 0x04
    /// separators); presumably a traceback when the code raised.
    pub stderr: String,
}

impl ReplResponse {
    /// Returns `true` when the device reported any stderr output.
    ///
    /// Marked `#[must_use]`: calling this without inspecting the result is
    /// always a mistake.
    #[must_use]
    pub fn is_error(&self) -> bool {
        !self.stderr.is_empty()
    }
}
/// Splits a raw-REPL response frame into its stdout and stderr sections.
///
/// The expected layout is `OK` + stdout + 0x04 + stderr + 0x04, optionally
/// followed by more bytes (normally the `>` prompt), which are ignored. Both
/// sections are decoded lossily as UTF-8.
///
/// # Errors
/// Each variant carries the complete input frame in `got`:
/// - [`ReplError::FrameMissingOk`] if `bytes` does not start with `OK`;
/// - [`ReplError::FrameMissingFirstSeparator`] if no 0x04 follows the `OK`;
/// - [`ReplError::FrameMissingEndMarker`] if there is no second 0x04.
pub fn parse_response_frame(bytes: &[u8]) -> Result<ReplResponse, ReplError> {
    let whole = || bytes.to_vec();

    let body = match bytes.strip_prefix(b"OK") {
        Some(body) => body,
        None => return Err(ReplError::FrameMissingOk { got: whole() }),
    };

    // At most three fields: stdout, stderr, and whatever trails the end marker.
    let mut fields = body.splitn(3, |&b| b == CTRL_D);
    // `splitn` always yields at least one field, so this never falls back.
    let stdout_bytes = fields.next().unwrap_or_default();
    let stderr_bytes = fields
        .next()
        .ok_or_else(|| ReplError::FrameMissingFirstSeparator { got: whole() })?;
    // A third field exists only if the stderr section was terminated by 0x04.
    fields
        .next()
        .ok_or_else(|| ReplError::FrameMissingEndMarker { got: whole() })?;

    let decode = |raw: &[u8]| String::from_utf8_lossy(raw).into_owned();
    Ok(ReplResponse {
        stdout: decode(stdout_bytes),
        stderr: decode(stderr_bytes),
    })
}
/// Reads one byte at a time until the bytes received so far end with `needle`.
///
/// Bytes are read individually so nothing past the match is consumed from
/// `port`. On success the returned buffer holds every byte read, including any
/// prefix before `needle`.
///
/// # Errors
/// - [`ReplError::Io`] if a read fails (including end of stream);
/// - [`ReplError::HandshakeMismatch`] carrying the received bytes once
///   `max_bytes` have been read without a match.
fn read_until_suffix<P: Read + ?Sized>(
    port: &mut P,
    needle: &[u8],
    max_bytes: usize,
) -> Result<Vec<u8>, ReplError> {
    let mut received = Vec::with_capacity(needle.len() * 2);
    let mut one = [0u8; 1];
    loop {
        port.read_exact(&mut one).map_err(ReplError::Io)?;
        received.extend_from_slice(&one);
        // `ends_with` is simply false while `received` is shorter than `needle`.
        if received.ends_with(needle) {
            break Ok(received);
        }
        if received.len() >= max_bytes {
            break Err(ReplError::HandshakeMismatch { got: received });
        }
    }
}
/// Reads exactly one raw-REPL response frame from `port`.
///
/// A frame is `OK` + stdout + 0x04 + stderr + 0x04 + `>`. The end is detected
/// by counting the two 0x04 separators and stopping at the `>` that follows the
/// second one. Simply waiting for the first `"\x04>"` would cut the frame short
/// whenever stderr itself begins with `>`. That would leave the rest of the frame
/// unread and out of sync with the next command. Bytes are read one at a time, so
/// nothing after the closing `>` is consumed.
///
/// # Errors
/// - [`ReplError::Io`] if a read fails (including end of stream);
/// - [`ReplError::FrameMissingEndMarker`] carrying the bytes read so far once
///   `max_bytes` bytes have arrived without a complete frame.
fn read_until_frame_end<P: Read + ?Sized>(
    port: &mut P,
    max_bytes: usize,
) -> Result<Vec<u8>, ReplError> {
    let mut frame = Vec::new();
    let mut byte = [0u8; 1];
    let mut separators_seen = 0usize;
    loop {
        port.read_exact(&mut byte).map_err(ReplError::Io)?;
        frame.push(byte[0]);
        match byte[0] {
            CTRL_D => separators_seen += 1,
            // The prompt only closes the frame once both separators are in.
            b'>' if separators_seen >= 2 => return Ok(frame),
            _ => {}
        }
        if frame.len() >= max_bytes {
            return Err(ReplError::FrameMissingEndMarker { got: frame });
        }
    }
}
pub fn enter_raw_mode<P: Read + Write + ?Sized>(port: &mut P) -> Result<(), ReplError> {
port.write_all(&[CTRL_C, CTRL_C, CTRL_B])
.map_err(ReplError::Io)?;
port.flush().map_err(ReplError::Io)?;
std::thread::sleep(std::time::Duration::from_millis(NORMALIZE_DRAIN_MS));
port.write_all(&[CTRL_A]).map_err(ReplError::Io)?;
port.flush().map_err(ReplError::Io)?;
let _received = read_until_suffix(port, RAW_MODE_PROMPT, MAX_PROMPT_BYTES_RELAXED)?;
Ok(())
}
/// Leaves raw REPL mode by sending a single Ctrl-B and flushing the port.
///
/// Nothing is read back. Whatever the device prints in response is left in
/// the port's receive buffer.
///
/// # Errors
/// [`ReplError::Io`] if the write or the flush fails. The flush is skipped if
/// the write already failed.
pub fn exit_raw_mode<P: Write + ?Sized>(port: &mut P) -> Result<(), ReplError> {
    port.write_all(&[CTRL_B])
        .and_then(|()| port.flush())
        .map_err(ReplError::Io)
}
pub fn exec_code<P: Read + Write + ?Sized>(
port: &mut P,
code: &str,
) -> Result<ReplResponse, ReplError> {
port.write_all(code.as_bytes()).map_err(ReplError::Io)?;
port.write_all(&[CTRL_D]).map_err(ReplError::Io)?;
port.flush().map_err(ReplError::Io)?;
let frame = read_until_frame_end(port, MAX_FRAME_BYTES)?;
parse_response_frame(&frame)
}
pub fn ping<P: Read + Write + ?Sized>(port: &mut P) -> Result<(), ReplError> {
let resp = exec_code(port, "print(1+1)")?;
if resp.stdout.trim() == "2" {
Ok(())
} else if resp.is_error() {
Err(ReplError::PythonError {
stdout: resp.stdout,
stderr: resp.stderr,
})
} else {
Err(ReplError::PingFailed {
expected: "2".into(),
got: resp.stdout.trim().to_string(),
})
}
}
/// Discards whatever is currently waiting in the port's receive buffer.
///
/// Temporarily switches the port to a short per-read timeout and reads until a
/// read times out, reports `WouldBlock`, or returns 0 bytes. In effect, it stops
/// once the line has been quiet for that long. The port's original timeout is
/// restored before returning, on both the success and error paths.
///
/// Total draining time is capped at `MAX_DRAIN_TIME`, so a device that never
/// stops transmitting (e.g. a program printing in a loop) cannot hang the
/// caller. In that case the count drained so far is returned and the rest is
/// left unread.
///
/// Returns the number of bytes discarded.
///
/// # Errors
/// Any error from changing the timeout, or any read error other than
/// `TimedOut` / `WouldBlock` (end of drain) and `Interrupted` (retried).
pub fn drain_read_buffer(port: &mut dyn serialport::SerialPort) -> std::io::Result<usize> {
    // How long the line must stay silent before it counts as drained.
    const QUIET_PERIOD: Duration = Duration::from_millis(50);
    // Hard cap on total drain time (see the function docs).
    const MAX_DRAIN_TIME: Duration = Duration::from_secs(2);

    let original_timeout = port.timeout();
    port.set_timeout(QUIET_PERIOD)?;

    let deadline = std::time::Instant::now() + MAX_DRAIN_TIME;
    let mut total = 0usize;
    let mut buf = [0u8; 256];
    let outcome: std::io::Result<usize> = loop {
        if std::time::Instant::now() >= deadline {
            break Ok(total);
        }
        match port.read(&mut buf) {
            Ok(0) => break Ok(total),
            Ok(n) => total += n,
            // A signal interrupted the read before any data was consumed; retry.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock
                ) =>
            {
                break Ok(total)
            }
            Err(e) => break Err(e),
        }
    };

    // Restore the caller's timeout no matter what happened. A read error takes
    // precedence over a failure to restore.
    let restored = port.set_timeout(original_timeout);
    let total = outcome?;
    restored?;
    Ok(total)
}
pub fn send_ctrl_c(port: &mut dyn serialport::SerialPort) -> std::io::Result<()> {
port.write_all(&[0x03])?;
port.flush()?;
Ok(())
}
// Unit tests for the raw-REPL helpers. `MockPort` drives the generic
// `Read + Write` functions with scripted device output; `MockSerialPort`
// additionally implements `serialport::SerialPort` for the helpers that take
// `&mut dyn serialport::SerialPort`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::io::{self, Read, Write};
    use std::time::Duration;
    // Scripted stand-in for a device connection: reads replay `read_data` in
    // order, and every write is appended to `write_data` for later assertions.
    struct MockPort {
        read_data: VecDeque<u8>,
        pub write_data: Vec<u8>,
    }
    impl MockPort {
        fn new(read_bytes: impl Into<Vec<u8>>) -> Self {
            MockPort {
                read_data: VecDeque::from(read_bytes.into()),
                write_data: Vec::new(),
            }
        }
    }
    impl Read for MockPort {
        // Copies up to `buf.len()` scripted bytes. When nothing can be copied
        // (script exhausted or empty `buf`), fails with `UnexpectedEof` instead
        // of returning `Ok(0)`.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let n = buf.len().min(self.read_data.len());
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "MockPort: no more scripted bytes",
                ));
            }
            for (dst, src) in buf[..n].iter_mut().zip(self.read_data.drain(..n)) {
                *dst = src;
            }
            Ok(n)
        }
    }
    impl Write for MockPort {
        // Records the bytes; never fails.
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.write_data.extend_from_slice(buf);
            Ok(buf.len())
        }
        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }
    // ---- parse_response_frame ----
    #[test]
    fn parse_simple_ok_response() {
        let bytes = b"OK2\x04\x04>";
        let r = parse_response_frame(bytes).unwrap();
        assert_eq!(r.stdout, "2");
        assert_eq!(r.stderr, "");
        assert!(!r.is_error());
    }
    #[test]
    fn parse_response_with_stderr() {
        let bytes = b"OK\x04Traceback...\x04>";
        let r = parse_response_frame(bytes).unwrap();
        assert_eq!(r.stdout, "");
        assert_eq!(r.stderr, "Traceback...");
        assert!(r.is_error());
    }
    #[test]
    fn parse_response_missing_ok_returns_error() {
        let bytes = b"ERR\x04\x04>";
        let err = parse_response_frame(bytes).unwrap_err();
        assert!(
            matches!(err, ReplError::FrameMissingOk { .. }),
            "expected FrameMissingOk, got: {err:?}"
        );
    }
    #[test]
    fn parse_response_empty_bytes_returns_missing_ok() {
        let bytes = b"";
        let err = parse_response_frame(bytes).unwrap_err();
        assert!(matches!(err, ReplError::FrameMissingOk { .. }));
    }
    #[test]
    fn parse_response_missing_first_separator_returns_error() {
        let bytes = b"OKsome output without any ctrl-d";
        let err = parse_response_frame(bytes).unwrap_err();
        assert!(
            matches!(err, ReplError::FrameMissingFirstSeparator { .. }),
            "expected FrameMissingFirstSeparator, got: {err:?}"
        );
    }
    #[test]
    fn parse_response_missing_end_marker_returns_error() {
        let bytes = b"OKsome output\x04stderr but no end marker";
        let err = parse_response_frame(bytes).unwrap_err();
        assert!(
            matches!(err, ReplError::FrameMissingEndMarker { .. }),
            "expected FrameMissingEndMarker, got: {err:?}"
        );
    }
    #[test]
    fn parse_response_with_print_output() {
        let bytes = b"OKhello\r\n\x04\x04>";
        let r = parse_response_frame(bytes).unwrap();
        assert_eq!(r.stdout, "hello\r\n");
        assert_eq!(r.stderr, "");
    }
    #[test]
    fn parse_response_empty_stdout_and_stderr() {
        let bytes = b"OK\x04\x04>";
        let r = parse_response_frame(bytes).unwrap();
        assert_eq!(r.stdout, "");
        assert_eq!(r.stderr, "");
        assert!(!r.is_error());
    }
    #[test]
    fn parse_response_with_multiple_lines_and_error() {
        let bytes = b"OKline1\r\nline2\r\n\x04Traceback (most recent call last):\r\n File \"<stdin>\", line 2\r\nNameError: name 'x' is not defined\r\n\x04>";
        let r = parse_response_frame(bytes).unwrap();
        assert!(r.stdout.contains("line1"), "stdout: {:?}", r.stdout);
        assert!(r.stdout.contains("line2"), "stdout: {:?}", r.stdout);
        assert!(r.stderr.contains("NameError"), "stderr: {:?}", r.stderr);
        assert!(r.is_error());
    }
    #[test]
    fn parse_response_with_only_stderr() {
        let bytes = b"OK\x04ZeroDivisionError: division by zero\r\n\x04>";
        let r = parse_response_frame(bytes).unwrap();
        assert_eq!(r.stdout, "");
        assert!(r.stderr.contains("ZeroDivisionError"));
    }
    // ---- enter_raw_mode (each call sleeps NORMALIZE_DRAIN_MS internally) ----
    #[test]
    fn enter_raw_mode_succeeds_on_correct_prompt() {
        let mut port = MockPort::new(b"raw REPL; CTRL-B to exit\r\n>".to_vec());
        enter_raw_mode(&mut port).unwrap();
        assert_eq!(port.write_data, &[CTRL_C, CTRL_C, CTRL_B, CTRL_A]);
    }
    #[test]
    fn enter_raw_mode_succeeds_when_prompt_has_prefix() {
        let mut scripted = b"\r\n>>>".to_vec();
        scripted.extend_from_slice(b"raw REPL; CTRL-B to exit\r\n>");
        let mut port = MockPort::new(scripted);
        enter_raw_mode(&mut port).unwrap();
    }
    #[test]
    fn enter_raw_mode_fails_on_wrong_prompt() {
        // One byte more than the prompt-search budget, none of it the prompt.
        let mut port = MockPort::new(vec![b'X'; MAX_PROMPT_BYTES_RELAXED + 1]);
        let err = enter_raw_mode(&mut port).unwrap_err();
        assert!(
            matches!(err, ReplError::HandshakeMismatch { .. }),
            "expected HandshakeMismatch, got: {err:?}"
        );
    }
    #[test]
    fn enter_raw_mode_handles_device_already_in_raw_mode() {
        let mut port = MockPort::new(b"raw REPL; CTRL-B to exit\r\n>".to_vec());
        enter_raw_mode(&mut port).unwrap();
    }
    #[test]
    fn enter_raw_mode_handles_device_running_code() {
        // Simulated KeyboardInterrupt traceback and friendly prompt ahead of the raw prompt.
        let mut scripted = Vec::new();
        scripted.extend_from_slice(b"\r\nTraceback (most recent call last):\r\n");
        scripted.extend_from_slice(b" File \"<stdin>\", line 1, in <module>\r\n");
        scripted.extend_from_slice(b"KeyboardInterrupt: \r\n");
        scripted.extend_from_slice(b">>> ");
        scripted.extend_from_slice(b"raw REPL; CTRL-B to exit\r\n>");
        let mut port = MockPort::new(scripted);
        enter_raw_mode(&mut port).unwrap();
    }
    #[test]
    fn enter_raw_mode_handles_idle_friendly_repl() {
        let mut scripted = Vec::new();
        scripted.extend_from_slice(b"\r\n>>> \r\n>>> ");
        scripted.extend_from_slice(b"raw REPL; CTRL-B to exit\r\n>");
        let mut port = MockPort::new(scripted);
        enter_raw_mode(&mut port).unwrap();
    }
    // ---- exec_code / ping ----
    #[test]
    fn exec_code_sends_code_and_ctrl_d() {
        let mut port = MockPort::new(b"OK2\x04\x04>".to_vec());
        let resp = exec_code(&mut port, "1+1").unwrap();
        assert_eq!(resp.stdout, "2");
        assert_eq!(port.write_data, b"1+1\x04");
        // The whole scripted frame, including the trailing '>', must be consumed.
        assert!(
            port.read_data.is_empty(),
            "MockPort had unconsumed bytes after exec_code: {:?}",
            port.read_data
        );
    }
    #[test]
    fn exec_code_returns_python_stderr_without_erroring() {
        let mut port = MockPort::new(b"OK\x04NameError: x\r\n\x04>".to_vec());
        let resp = exec_code(&mut port, "x").unwrap();
        assert_eq!(resp.stdout, "");
        assert!(resp.stderr.contains("NameError"));
        assert!(resp.is_error());
    }
    #[test]
    fn ping_succeeds_on_correct_response() {
        let mut port = MockPort::new(b"OK2\r\n\x04\x04>".to_vec());
        ping(&mut port).unwrap();
    }
    #[test]
    fn ping_fails_on_wrong_output() {
        let mut port = MockPort::new(b"OK3\r\n\x04\x04>".to_vec());
        let err = ping(&mut port).unwrap_err();
        assert!(
            matches!(err, ReplError::PingFailed { .. }),
            "expected PingFailed, got: {err:?}"
        );
    }
    #[test]
    fn ping_returns_python_error_on_exception() {
        let mut port =
            MockPort::new(b"OK\x04ZeroDivisionError: division by zero\r\n\x04>".to_vec());
        let err = ping(&mut port).unwrap_err();
        match err {
            ReplError::PythonError { stdout, stderr } => {
                assert_eq!(stdout, "");
                assert!(stderr.contains("ZeroDivisionError"));
            }
            other => panic!("expected PythonError, got: {other:?}"),
        }
    }
    #[test]
    fn exit_raw_mode_sends_ctrl_b() {
        let mut port = MockPort::new(vec![]); exit_raw_mode(&mut port).unwrap();
        assert_eq!(port.write_data, &[CTRL_B]);
    }
    // ---- ceremony scripts (constants come from crate::ceremony, not shown here);
    // the mock answers each with an empty, error-free frame ----
    #[test]
    fn connect_script_round_trips_through_mock_port() {
        use crate::ceremony::CEREMONY_CONNECT_SCRIPT;
        let mock_response = b"OK\x04\x04>".to_vec();
        let mut port = MockPort::new(mock_response);
        let resp = exec_code(&mut port, CEREMONY_CONNECT_SCRIPT).unwrap();
        assert!(
            !resp.is_error(),
            "connect ceremony script raised a Python exception in mock: stderr={:?}",
            resp.stderr
        );
        let mut expected_write = CEREMONY_CONNECT_SCRIPT.as_bytes().to_vec();
        expected_write.push(0x04); assert_eq!(
            port.write_data, expected_write,
            "exec_code did not send connect ceremony script + Ctrl-D"
        );
        assert!(
            port.read_data.is_empty(),
            "MockPort had unconsumed bytes: {:?}",
            port.read_data
        );
    }
    #[test]
    fn disconnect_script_round_trips_through_mock_port() {
        use crate::ceremony::CEREMONY_DISCONNECT_SCRIPT;
        let mock_response = b"OK\x04\x04>".to_vec();
        let mut port = MockPort::new(mock_response);
        let resp = exec_code(&mut port, CEREMONY_DISCONNECT_SCRIPT).unwrap();
        assert!(
            !resp.is_error(),
            "disconnect ceremony script raised a Python exception in mock: stderr={:?}",
            resp.stderr
        );
        let mut expected_write = CEREMONY_DISCONNECT_SCRIPT.as_bytes().to_vec();
        expected_write.push(0x04); assert_eq!(
            port.write_data, expected_write,
            "exec_code did not send disconnect ceremony script + Ctrl-D"
        );
        assert!(
            port.read_data.is_empty(),
            "MockPort had unconsumed bytes: {:?}",
            port.read_data
        );
    }
    // Like MockPort, but reports `TimedOut` (not `UnexpectedEof`) once the script
    // is exhausted, and implements `serialport::SerialPort`. Only `timeout` /
    // `set_timeout`, the fixed line-setting getters and the byte counters are
    // functional; every other setter or query returns `mock_sp_err()`.
    struct MockSerialPort {
        read_data: VecDeque<u8>,
        pub write_data: Vec<u8>,
        pub timeout_value: Duration,
    }
    impl MockSerialPort {
        fn new(read_bytes: impl Into<Vec<u8>>) -> Self {
            MockSerialPort {
                read_data: VecDeque::from(read_bytes.into()),
                write_data: Vec::new(),
                timeout_value: Duration::from_millis(500),
            }
        }
    }
    impl Read for MockSerialPort {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let n = buf.len().min(self.read_data.len());
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "mock: buffer empty",
                ));
            }
            for (dst, src) in buf[..n].iter_mut().zip(self.read_data.drain(..n)) {
                *dst = src;
            }
            Ok(n)
        }
    }
    impl Write for MockSerialPort {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.write_data.extend_from_slice(buf);
            Ok(buf.len())
        }
        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }
    // Error returned by the SerialPort methods the tests never need.
    fn mock_sp_err() -> serialport::Error {
        serialport::Error::new(serialport::ErrorKind::Unknown, "mock: not implemented")
    }
    impl serialport::SerialPort for MockSerialPort {
        fn name(&self) -> Option<String> {
            Some("mock".to_string())
        }
        fn baud_rate(&self) -> serialport::Result<u32> {
            Ok(115_200)
        }
        fn data_bits(&self) -> serialport::Result<serialport::DataBits> {
            Ok(serialport::DataBits::Eight)
        }
        fn flow_control(&self) -> serialport::Result<serialport::FlowControl> {
            Ok(serialport::FlowControl::None)
        }
        fn parity(&self) -> serialport::Result<serialport::Parity> {
            Ok(serialport::Parity::None)
        }
        fn stop_bits(&self) -> serialport::Result<serialport::StopBits> {
            Ok(serialport::StopBits::One)
        }
        fn timeout(&self) -> Duration {
            self.timeout_value
        }
        fn set_baud_rate(&mut self, _: u32) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn set_data_bits(&mut self, _: serialport::DataBits) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn set_flow_control(&mut self, _: serialport::FlowControl) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn set_parity(&mut self, _: serialport::Parity) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn set_stop_bits(&mut self, _: serialport::StopBits) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn set_timeout(&mut self, timeout: Duration) -> serialport::Result<()> {
            self.timeout_value = timeout;
            Ok(())
        }
        fn write_request_to_send(&mut self, _: bool) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn write_data_terminal_ready(&mut self, _: bool) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn read_clear_to_send(&mut self) -> serialport::Result<bool> {
            Err(mock_sp_err())
        }
        fn read_data_set_ready(&mut self) -> serialport::Result<bool> {
            Err(mock_sp_err())
        }
        fn read_ring_indicator(&mut self) -> serialport::Result<bool> {
            Err(mock_sp_err())
        }
        fn read_carrier_detect(&mut self) -> serialport::Result<bool> {
            Err(mock_sp_err())
        }
        fn bytes_to_read(&self) -> serialport::Result<u32> {
            Ok(self.read_data.len() as u32)
        }
        fn bytes_to_write(&self) -> serialport::Result<u32> {
            Ok(0)
        }
        fn clear(&self, _: serialport::ClearBuffer) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn try_clone(&self) -> serialport::Result<Box<dyn serialport::SerialPort>> {
            Err(mock_sp_err())
        }
        fn set_break(&self) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
        fn clear_break(&self) -> serialport::Result<()> {
            Err(mock_sp_err())
        }
    }
    // ---- drain_read_buffer / send_ctrl_c ----
    #[test]
    fn drain_read_buffer_consumes_all_bytes() {
        let mut port = MockSerialPort::new(b"stale garbage bytes".to_vec());
        let drained = drain_read_buffer(&mut port).unwrap();
        assert_eq!(drained, 19, "should drain all 19 bytes");
        assert!(
            port.read_data.is_empty(),
            "read_data should be empty after drain"
        );
    }
    #[test]
    fn drain_read_buffer_empty_returns_zero() {
        let mut port = MockSerialPort::new(b"".to_vec());
        let drained = drain_read_buffer(&mut port).unwrap();
        assert_eq!(drained, 0, "empty buffer should drain 0 bytes");
    }
    #[test]
    fn drain_read_buffer_restores_timeout() {
        let mut port = MockSerialPort::new(b"abc".to_vec());
        port.timeout_value = Duration::from_millis(500);
        drain_read_buffer(&mut port).unwrap();
        assert_eq!(
            port.timeout_value,
            Duration::from_millis(500),
            "timeout must be restored to original 500ms after drain"
        );
    }
    #[test]
    fn send_ctrl_c_writes_ctrl_c_byte() {
        let mut port = MockSerialPort::new(b"".to_vec());
        send_ctrl_c(&mut port).unwrap();
        assert_eq!(
            port.write_data,
            &[0x03],
            "send_ctrl_c must write exactly 0x03 (Ctrl-C)"
        );
    }
}