use crate::conn::Connection;
use crate::internal::BeBytes;
use crate::protocol::SpecificIdKind;
use crate::protocol::SpecificThreadId;
#[cfg(feature = "trace-pkt")]
use alloc::string::String;
#[cfg(feature = "trace-pkt")]
use alloc::vec::Vec;
use num_traits::identities::one;
use num_traits::CheckedRem;
use num_traits::PrimInt;
/// Error type returned by the writer's operations.
///
/// A thin newtype around a wrapped error value `C`; within this file it is
/// always constructed from a connection's `C::Error` via `map_err(Error)`.
#[derive(Debug, Clone)]
pub struct Error<C>(pub C);
/// Writes a single outgoing packet to a [`Connection`], byte by byte.
///
/// The writer emits the leading `$` before the first byte it sends, keeps a
/// running (wrapping) checksum of every byte it sends, and can optionally
/// run-length encode repeated bytes before they reach the connection.
pub struct ResponseWriter<'a, C: Connection> {
/// Connection that receives the encoded bytes.
inner: &'a mut C,
/// Set once the leading `$` has been (attempted to be) written.
started: bool,
/// Wrapping sum of every byte sent to the connection after the `$`.
checksum: u8,
/// Whether repeated bytes are run-length encoded before being sent.
rle_enabled: bool,
/// Byte of the run currently being accumulated (only used with RLE).
rle_char: u8,
/// Length of the run currently being accumulated; 0 until the first byte.
rle_repeat: u8,
/// Copy of the packet contents (with RLE runs expanded) kept for trace
/// logging when the packet is flushed.
#[cfg(feature = "trace-pkt")]
msg: Vec<u8>,
}
impl<'a, C: Connection + 'a> ResponseWriter<'a, C> {
pub fn new(inner: &'a mut C, rle_enabled: bool) -> Self {
Self {
inner,
started: false,
checksum: 0,
rle_enabled,
rle_char: 0,
rle_repeat: 0,
#[cfg(feature = "trace-pkt")]
msg: Vec::new(),
}
}
pub fn flush(mut self) -> Result<(), Error<C::Error>> {
let checksum = if self.rle_enabled {
self.write(b'#')?;
self.checksum
} else {
let checksum = self.checksum;
self.write(b'#')?;
checksum
};
self.write_hex(checksum)?;
if self.rle_enabled {
self.write(0)?;
}
#[cfg(feature = "trace-pkt")]
trace!("--> ${}", String::from_utf8_lossy(&self.msg));
self.inner.flush().map_err(Error)?;
Ok(())
}
/// Returns a mutable reference to the underlying connection.
///
/// Bytes written directly through the returned reference do not pass
/// through this writer, so they are neither checksummed nor run-length
/// encoded by it.
pub fn as_conn(&mut self) -> &mut C {
    &mut *self.inner
}
/// Sends one already-encoded byte to the connection.
///
/// Emits the leading `$` before the very first byte, and adds `byte` to the
/// running checksum using wrapping addition. With the `trace-pkt` feature
/// and trace logging enabled, the byte is also recorded in `msg`, with RLE
/// runs expanded back into their repeated form.
///
/// # Errors
///
/// Returns the connection's error, wrapped in [`Error`], if a write fails.
fn inner_write(&mut self, byte: u8) -> Result<(), Error<C::Error>> {
    #[cfg(feature = "trace-pkt")]
    if log_enabled!(log::Level::Trace) {
        // With RLE on, a log ending in `<c>*` means `byte` is a repeat
        // count for `c` rather than a literal byte.
        let run_char = match self.msg.as_slice() {
            [.., c, b'*'] if self.rle_enabled => Some(*c),
            _ => None,
        };
        match run_char {
            Some(c) => {
                // Drop the `*` and append the extra copies it stands for.
                self.msg.pop();
                for _ in 0..(byte - 29) {
                    self.msg.push(c);
                }
            }
            None => self.msg.push(byte),
        }
    }

    // The flag is raised before the `$` is written, so a failed write is not
    // retried on the next call.
    let first_byte = !self.started;
    self.started = true;
    if first_byte {
        self.inner.write(b'$').map_err(Error)?;
    }

    self.checksum = self.checksum.wrapping_add(byte);
    self.inner.write(byte).map_err(Error)
}
/// Queues `byte` for output, run-length encoding it when RLE is enabled.
///
/// Without RLE the byte is sent straight away. With RLE the writer buffers
/// one run (`rle_char` repeated `rle_repeat` times); a byte that extends the
/// run only bumps the counter, while any other byte first emits the
/// buffered run and then becomes the start of a new one. As a consequence
/// the most recent run is always still buffered when this returns.
///
/// # Errors
///
/// Returns the connection's error, wrapped in [`Error`], if a write fails.
fn write(&mut self, byte: u8) -> Result<(), Error<C::Error>> {
    const ASCII_FIRST_PRINT: u8 = b' ';
    const ASCII_LAST_PRINT: u8 = b'~';

    if !self.rle_enabled {
        return self.inner_write(byte);
    }

    // A run of length `n` is sent as `<char>*<count>` where the count byte
    // is `ASCII_FIRST_PRINT - 4 + n`. The run may only grow while that count
    // byte would stay at or below `~`.
    let count_if_extended = ASCII_FIRST_PRINT - 4 + (self.rle_repeat + 1);
    if byte == self.rle_char && count_if_extended <= ASCII_LAST_PRINT {
        self.rle_repeat += 1;
        return Ok(());
    }

    // Run lengths 7 and 8 would need `#` and `$` as the count byte. Send
    // leading copies literally until the remaining run has a usable length.
    while matches!(self.rle_repeat, 7 | 8) {
        self.inner_write(self.rle_char)?;
        self.rle_repeat -= 1;
    }

    match self.rle_repeat {
        // Nothing buffered yet: only the case before the first byte.
        0 => {}
        // Too short for the three-byte encoding to save anything.
        1..=3 => {
            for _ in 0..self.rle_repeat {
                self.inner_write(self.rle_char)?;
            }
        }
        // Long enough: emit the byte, the `*` marker and the count byte.
        run_len => {
            self.inner_write(self.rle_char)?;
            self.inner_write(b'*')?;
            self.inner_write(ASCII_FIRST_PRINT - 4 + run_len)?;
        }
    }

    // `byte` opens the next run.
    self.rle_char = byte;
    self.rle_repeat = 1;
    Ok(())
}
/// Writes every byte of `s` to the packet, in order.
///
/// The bytes are written as-is; no escaping is applied.
///
/// # Errors
///
/// Stops at, and returns, the first write error.
pub fn write_str(&mut self, s: &str) -> Result<(), Error<C::Error>> {
    s.bytes().try_for_each(|b| self.write(b))
}
/// Writes `byte` as two lowercase hex digits, high nibble first.
///
/// # Errors
///
/// Returns the first write error encountered.
fn write_hex(&mut self, byte: u8) -> Result<(), Error<C::Error>> {
    /// Maps a 4-bit value to its lowercase ASCII hex digit.
    fn hex_digit(nibble: u8) -> u8 {
        match nibble {
            0..=9 => b'0' + nibble,
            10..=15 => b'a' + (nibble - 10),
            // Cannot occur for a 4-bit value; passed through unchanged
            // rather than introducing a panic.
            _ => nibble,
        }
    }

    self.write(hex_digit(byte >> 4))?;
    self.write(hex_digit(byte & 0x0f))
}
/// Writes each byte of `data` as two lowercase hex digits, in order.
///
/// # Errors
///
/// Stops at, and returns, the first write error.
pub fn write_hex_buf(&mut self, data: &[u8]) -> Result<(), Error<C::Error>> {
    data.iter().try_for_each(|&b| self.write_hex(b))
}
/// Writes `data` as raw bytes, escaping the bytes `#`, `$`, `}` and `*`.
///
/// Each of those four bytes is written as `}` followed by the original byte
/// XOR `0x20`; every other byte is written unchanged.
///
/// # Errors
///
/// Stops at, and returns, the first write error.
pub fn write_binary(&mut self, data: &[u8]) -> Result<(), Error<C::Error>> {
    for &raw in data {
        let needs_escape = matches!(raw, b'#' | b'$' | b'}' | b'*');
        if needs_escape {
            self.write(b'}')?;
            self.write(raw ^ 0x20)?;
        } else {
            self.write(raw)?;
        }
    }
    Ok(())
}
/// Writes `digit` as big-endian hex with leading zero *bytes* removed.
///
/// Every remaining byte contributes two hex digits, so the output always
/// has an even number of digits (e.g. a value of 5 is written as `05`).
/// Zero is written as `00`.
///
/// NOTE(review): the `unwrap` panics if `to_be_bytes` reports failure,
/// presumably only when the value does not fit the 16-byte scratch buffer.
/// Confirm against the `BeBytes` implementation.
///
/// # Errors
///
/// Stops at, and returns, the first write error.
pub fn write_num<D: BeBytes + PrimInt>(&mut self, digit: D) -> Result<(), Error<C::Error>> {
    if digit.is_zero() {
        // All bytes would be skipped below, so zero is special-cased.
        return self.write_hex(0);
    }

    let mut scratch = [0u8; 16];
    let used = digit.to_be_bytes(&mut scratch).unwrap();

    scratch[..used]
        .iter()
        .copied()
        .skip_while(|&b| b == 0)
        .try_for_each(|b| self.write_hex(b))
}
/// Writes `digit` as ASCII decimal digits, most significant digit first.
///
/// Zero is written as a single `0`.
///
/// NOTE(review): no sign is ever emitted, so negative values of signed
/// types do not appear to be supported. Confirm callers only pass
/// non-negative values.
///
/// # Errors
///
/// Stops at, and returns, the first write error.
pub fn write_dec<D: PrimInt + CheckedRem>(
    &mut self,
    mut digit: D,
) -> Result<(), Error<C::Error>> {
    if digit.is_zero() {
        return self.write(b'0');
    }

    let unit: D = one();
    // 10 == 8 + 2, built from shifts because only generic operations exist.
    let ten = (unit << 3) + (unit << 1);

    // Find the place value of the leading digit: the largest power of ten
    // that does not exceed `digit`.
    let mut place = unit;
    let mut rest = digit;
    while rest >= ten {
        rest = rest / ten;
        place = place * ten;
    }

    // Peel off one decimal digit per iteration, from the left.
    while !place.is_zero() {
        let leading = digit / place;

        // Copy the low four bits of the generic value into a `u8`, bit by
        // bit; a single decimal digit fits in those four bits.
        let mut value = 0u8;
        for bit in 0..4 {
            if !(leading & (unit << bit)).is_zero() {
                value |= 1u8 << bit;
            }
        }

        self.write(b'0' + value)?;
        digit = digit % place;
        place = place / ten;
    }
    Ok(())
}
/// Writes a single id component: `-1` for [`SpecificIdKind::All`],
/// otherwise the id's numeric value via [`Self::write_num`].
///
/// # Errors
///
/// Returns the first write error encountered.
#[inline]
fn write_specific_id_kind(&mut self, tid: SpecificIdKind) -> Result<(), Error<C::Error>> {
    match tid {
        SpecificIdKind::All => self.write_str("-1"),
        SpecificIdKind::WithId(id) => self.write_num(id.get()),
    }
}
/// Writes a thread id, optionally qualified by a process id.
///
/// When `tid.pid` is present the output is `p<pid>.<tid>`; otherwise it is
/// just `<tid>`. Each component is written by `write_specific_id_kind`.
///
/// # Errors
///
/// Returns the first write error encountered.
pub fn write_specific_thread_id(
    &mut self,
    tid: SpecificThreadId,
) -> Result<(), Error<C::Error>> {
    if let Some(pid) = tid.pid {
        self.write_str("p")?;
        self.write_specific_id_kind(pid)?;
        self.write_str(".")?;
    }
    self.write_specific_id_kind(tid.tid)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    /// In-memory [`Connection`] that records every byte written to it.
    struct MockConnection {
        data: Vec<u8>,
    }

    impl MockConnection {
        fn new() -> Self {
            MockConnection { data: Vec::new() }
        }
    }

    impl Connection for MockConnection {
        type Error = ();

        fn write(&mut self, byte: u8) -> Result<(), Self::Error> {
            self.data.push(byte);
            Ok(())
        }

        fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
            self.data.extend_from_slice(buf);
            Ok(())
        }

        fn flush(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Asserts that nothing between the leading byte and the final `#` of
    /// `data` is a `$` or `#`.
    fn assert_no_special_chars_in_body(data: &[u8]) {
        let hash_pos = data.iter().rposition(|&b| b == b'#').unwrap();
        let body = &data[1..hash_pos];
        for byte in body.iter().copied() {
            assert!(
                !matches!(byte, b'$' | b'#'),
                "found {:?} in packet body",
                byte as char
            );
        }
    }

    /// Sends `payload` as one RLE-enabled packet and returns the raw bytes
    /// that reached the connection.
    fn send_with_rle(payload: &str) -> Vec<u8> {
        let mut conn = MockConnection::new();
        let mut writer = ResponseWriter::new(&mut conn, true);
        writer.write_str(payload).unwrap();
        writer.flush().unwrap();
        conn.data
    }

    /// A run of 7 would need `#` as its count byte.
    #[test]
    fn rle_avoids_hash() {
        let sent = send_with_rle("0000000");
        assert_no_special_chars_in_body(&sent);
    }

    /// A run of 8 would need `$` as its count byte.
    #[test]
    fn rle_avoids_dollar() {
        let sent = send_with_rle("00000000");
        assert_no_special_chars_in_body(&sent);
    }
}