pub use uefi_raw::protocol::console::serial::{
ControlBits, Parity, SerialIoMode as IoMode, StopBits,
};
use crate::proto::unsafe_protocol;
use crate::{Error, Result, ResultExt, Status, StatusExt, boot};
use core::time::Duration;
use core::{cmp, fmt};
use log::error;
use uefi_raw::protocol::console::serial::{
SerialIoProtocol, SerialIoProtocol_1_1, SerialIoProtocolRevision,
};
use uguid::Guid;
/// Estimates how long it takes to transmit a single character with the
/// line settings in `mode`.
///
/// The estimate is deliberately pessimistic. Any setting that is unknown
/// (a zero baud rate, `DEFAULT` parity, `DEFAULT` or unrecognized stop bits)
/// is resolved to the slower alternative. Callers use the result to decide
/// how long to stall, and over-waiting is preferable to spinning.
fn duration_per_byte_estimate(mode: &IoMode) -> Duration {
    // Unknown baud rate: fall back to a large fixed estimate.
    if mode.baud_rate == 0 {
        return Duration::from_millis(100);
    }
    // A data-bit count of 0 is treated as the common 8-bit framing.
    let data_bits = if mode.data_bits == 0 {
        8
    } else {
        mode.data_bits
    };
    // Only an explicit NONE means "no parity bit". DEFAULT is unknown, so
    // assume one parity bit, matching how DEFAULT stop bits are handled below.
    let parity_bits: u32 = if mode.parity == Parity::NONE { 0 } else { 1 };
    let stop_bits: u32 = match mode.stop_bits {
        StopBits::ONE => 1,
        // ONE_FIVE rounds up to 2; TWO is 2; DEFAULT and unrecognized values
        // are unknown and therefore treated pessimistically as 2.
        _ => 2,
    };
    // Start bit + data + parity + stop. Widen to u64 so an out-of-range
    // firmware-reported `data_bits` cannot overflow the addition.
    let bits_per_char =
        1 + u64::from(data_bits) + u64::from(parity_bits) + u64::from(stop_bits);
    // Round the per-bit time up so the estimate never undershoots.
    let us_per_bit = 1_000_000_u64.div_ceil(mode.baud_rate);
    Duration::from_micros(us_per_bit.saturating_mul(bits_per_char))
}
/// Estimates how long to wait for the device FIFO to make room for, or
/// deliver, up to `remaining` more bytes.
///
/// The byte count is capped at the receive FIFO depth reported in `mode`.
/// A reported depth of 0 is treated as a one-byte FIFO. The result is that
/// many character times, as given by `duration_per_byte_estimate`.
fn duration_fifo_estimate(mode: &IoMode, remaining: usize) -> Duration {
    let fifo_depth = cmp::max(mode.receive_fifo_depth, 1);
    // Saturate rather than truncate when `remaining` exceeds u32::MAX.
    let pending = match u32::try_from(remaining) {
        Ok(count) => count,
        Err(_) => u32::MAX,
    };
    let bytes_to_wait_for = pending.min(fifo_depth);
    duration_per_byte_estimate(mode) * bytes_to_wait_for
}
/// Serial I/O protocol.
///
/// Safe wrapper around the raw [`SerialIoProtocol`] function table provided
/// by firmware. `#[repr(transparent)]` gives `Serial` exactly the layout of
/// the wrapped raw protocol struct. Presumably this is what lets a firmware
/// protocol pointer be reinterpreted as a `Serial`; confirm against
/// `unsafe_protocol`.
#[derive(Debug)]
#[repr(transparent)]
#[unsafe_protocol(SerialIoProtocol::GUID)]
pub struct Serial(SerialIoProtocol);
impl Serial {
#[must_use]
pub const fn revision(&self) -> SerialIoProtocolRevision {
self.0.revision
}
pub fn reset(&mut self) -> Result {
unsafe { (self.0.reset)(&mut self.0) }.to_result()
}
#[must_use]
pub const fn io_mode(&self) -> &IoMode {
unsafe { &*self.0.mode }
}
pub fn set_attributes(&mut self, mode: &IoMode) -> Result {
unsafe {
(self.0.set_attributes)(
&mut self.0,
mode.baud_rate,
mode.receive_fifo_depth,
mode.timeout,
mode.parity,
mode.data_bits as u8,
mode.stop_bits,
)
}
.to_result()
}
pub fn get_control_bits(&self) -> Result<ControlBits> {
let mut bits = ControlBits::empty();
unsafe { (self.0.get_control_bits)(&self.0, &mut bits) }.to_result_with_val(|| bits)
}
pub fn set_control_bits(&mut self, bits: ControlBits) -> Result {
unsafe { (self.0.set_control_bits)(&mut self.0, bits) }.to_result()
}
pub fn read(&mut self, buffer: &mut [u8]) -> Result<(), usize > {
let mut buffer_size = buffer.len();
unsafe { (self.0.read)(&mut self.0, &mut buffer_size, buffer.as_mut_ptr()) }.to_result_with(
|| {
assert_eq!(buffer_size, buffer.len())
},
|_| buffer_size,
)
}
pub fn read_exact(&mut self, buffer: &mut [u8]) -> Result<()> {
const MAX_ZERO_PROGRESS: usize = 16;
let mut remaining_buffer = buffer;
let mut zero_progress_count = 0;
while !remaining_buffer.is_empty() {
match self.read(remaining_buffer) {
Ok(_) => return Ok(()),
Err(err) if err.status() == Status::TIMEOUT => {
let n = *err.data();
if n == 0 {
zero_progress_count += 1;
if zero_progress_count >= MAX_ZERO_PROGRESS {
return Err(Error::from(Status::TIMEOUT));
}
} else {
zero_progress_count = 0;
}
remaining_buffer = &mut remaining_buffer[n..];
let fifo_stall_duration =
duration_fifo_estimate(self.io_mode(), remaining_buffer.len());
boot::stall(fifo_stall_duration);
}
err => {
return Err(Error::from(err.status()));
}
}
}
Ok(())
}
pub fn write(&mut self, data: &[u8]) -> Result<(), usize > {
let mut buffer_size = data.len();
unsafe { (self.0.write)(&mut self.0, &mut buffer_size, data.as_ptr()) }.to_result_with(
|| {
assert_eq!(buffer_size, data.len())
},
|_| buffer_size,
)
}
pub fn write_exact(&mut self, data: &[u8]) -> Result<()> {
const MAX_ZERO_PROGRESS: usize = 16;
let mut remaining_bytes = data;
let mut zero_progress_count = 0;
while !remaining_bytes.is_empty() {
match self.write(remaining_bytes) {
Ok(_) => return Ok(()),
Err(err) if err.status() == Status::TIMEOUT => {
let n = *err.data();
if n == 0 {
zero_progress_count += 1;
if zero_progress_count >= MAX_ZERO_PROGRESS {
return Err(Error::from(Status::TIMEOUT));
}
} else {
zero_progress_count = 0;
}
remaining_bytes = &remaining_bytes[n..];
let fifo_stall_duration =
duration_fifo_estimate(self.io_mode(), remaining_bytes.len());
boot::stall(fifo_stall_duration);
}
Err(err) => return Err(Error::from(err.status())),
}
}
Ok(())
}
pub fn device_type_guid(&self) -> Result<Option<&'_ Guid>> {
let proto = self.as_revision_1_1()?;
let device_type_guid = unsafe { proto.device_type_guid.as_ref() };
Ok(device_type_guid)
}
fn as_revision_1_1(&self) -> Result<&'_ SerialIoProtocol_1_1> {
if self.revision() < SerialIoProtocolRevision::REVISION_1_1 {
return Err(Error::from(Status::UNSUPPORTED));
}
let ptr = &raw const self.0;
let protocol = unsafe {
ptr.cast::<SerialIoProtocol_1_1>()
.as_ref()
.unwrap_unchecked()
};
Ok(protocol)
}
}
impl fmt::Write for Serial {
    /// Writes all of `s` to the device via `write_exact`.
    ///
    /// On failure, the error is logged and `fmt::Error` is returned. The
    /// logging is skipped when `s` itself contains the failure message,
    /// presumably to avoid recursing when the logger writes to this device.
    fn write_str(&mut self, s: &str) -> fmt::Result {
        let Err(e) = self.write_exact(s.as_bytes()) else {
            return Ok(());
        };
        let msg = "failed to write to serial device";
        if !s.contains(msg) {
            error!("{msg}: {e}");
        }
        Err(fmt::Error)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `IoMode` with the given line settings. All other fields
    /// are zero/empty.
    fn make_mode(baud_rate: u64, data_bits: u32, parity: Parity, stop_bits: StopBits) -> IoMode {
        IoMode {
            control_mask: ControlBits::empty(),
            timeout: 0,
            baud_rate,
            receive_fifo_depth: 0,
            data_bits,
            parity,
            stop_bits,
        }
    }

    #[test]
    fn unknown_baud_rate_returns_large_fallback() {
        let mode = make_mode(0, 8, Parity::NONE, StopBits::ONE);
        let duration = duration_per_byte_estimate(&mode);
        assert!(
            duration >= Duration::from_millis(100),
            "fallback should be at least 100ms, got {duration:?}"
        );
    }

    #[test]
    fn higher_baud_rate_gives_shorter_duration() {
        let slow = make_mode(9_600, 8, Parity::NONE, StopBits::ONE);
        let fast = make_mode(115_200, 8, Parity::NONE, StopBits::ONE);
        assert!(
            duration_per_byte_estimate(&slow) > duration_per_byte_estimate(&fast),
            "9600 baud should take longer per byte than 115200 baud"
        );
    }

    #[test]
    fn parity_bit_increases_duration() {
        let no_parity = make_mode(9_600, 8, Parity::NONE, StopBits::ONE);
        let with_parity = make_mode(9_600, 8, Parity::EVEN, StopBits::ONE);
        assert!(
            duration_per_byte_estimate(&with_parity) > duration_per_byte_estimate(&no_parity),
            "a parity bit should increase the estimated duration"
        );
    }

    #[test]
    fn two_stop_bits_increases_duration() {
        let one_stop = make_mode(9_600, 8, Parity::NONE, StopBits::ONE);
        let two_stop = make_mode(9_600, 8, Parity::NONE, StopBits::TWO);
        assert!(
            duration_per_byte_estimate(&two_stop) > duration_per_byte_estimate(&one_stop),
            "two stop bits should increase the estimated duration"
        );
    }

    #[test]
    fn one_five_stop_bits_same_as_two() {
        let one_five = make_mode(9_600, 8, Parity::NONE, StopBits::ONE_FIVE);
        let two = make_mode(9_600, 8, Parity::NONE, StopBits::TWO);
        assert_eq!(
            duration_per_byte_estimate(&one_five),
            duration_per_byte_estimate(&two),
            "ONE_FIVE should be treated as 2 stop bits"
        );
    }

    #[test]
    fn more_data_bits_increases_duration() {
        let seven = make_mode(9_600, 7, Parity::NONE, StopBits::ONE);
        let eight = make_mode(9_600, 8, Parity::NONE, StopBits::ONE);
        assert!(
            duration_per_byte_estimate(&eight) > duration_per_byte_estimate(&seven),
            "more data bits should increase the estimated duration"
        );
    }

    #[test]
    fn zero_data_bits_treated_as_eight() {
        let zero = make_mode(9_600, 0, Parity::NONE, StopBits::ONE);
        let eight = make_mode(9_600, 8, Parity::NONE, StopBits::ONE);
        assert_eq!(
            duration_per_byte_estimate(&zero),
            duration_per_byte_estimate(&eight),
            "a data-bit count of 0 should be estimated as 8 data bits"
        );
    }

    #[test]
    fn default_stop_bits_conservative() {
        let default_stop = make_mode(9_600, 8, Parity::NONE, StopBits::DEFAULT);
        let two_stop = make_mode(9_600, 8, Parity::NONE, StopBits::TWO);
        assert!(
            duration_per_byte_estimate(&default_stop) >= duration_per_byte_estimate(&two_stop),
            "DEFAULT stop bits should be at least as conservative as TWO"
        );
    }

    #[test]
    fn default_parity_conservative() {
        let default_parity = make_mode(9_600, 8, Parity::DEFAULT, StopBits::ONE);
        let no_parity = make_mode(9_600, 8, Parity::NONE, StopBits::ONE);
        assert!(
            duration_per_byte_estimate(&default_parity) >= duration_per_byte_estimate(&no_parity),
            "DEFAULT parity should be at least as conservative as NONE"
        );
    }

    #[test]
    fn fifo_estimate_capped_by_depth() {
        let mut mode = make_mode(9_600, 8, Parity::NONE, StopBits::ONE);
        mode.receive_fifo_depth = 4;
        let per_byte = duration_per_byte_estimate(&mode);
        assert_eq!(duration_fifo_estimate(&mode, 10), per_byte * 4);
        assert_eq!(duration_fifo_estimate(&mode, 2), per_byte * 2);
        assert_eq!(duration_fifo_estimate(&mode, 0), Duration::ZERO);
    }

    #[test]
    fn fifo_estimate_zero_depth_treated_as_one() {
        let mode = make_mode(9_600, 8, Parity::NONE, StopBits::ONE);
        assert_eq!(
            duration_fifo_estimate(&mode, 10),
            duration_per_byte_estimate(&mode)
        );
    }
}