use embedded_io::{Error as _, ErrorType};
use embedded_io_async::Write;
use super::chunked::write_chunked_header;
/// Zero-length chunk that ends a chunked body: `0\r\n` followed by the blank line `\r\n`.
const EMPTY_CHUNK: &[u8; 5] = b"0\r\n\r\n";
/// CRLF terminator written after each chunk header and after each chunk payload.
const NEWLINE: &[u8; 2] = b"\r\n";
/// Writes a chunked transfer-encoded body through a caller-supplied buffer.
///
/// Payload bytes are staged in `buf` behind space reserved for the chunk
/// header. The buffered bytes are written to `conn` when the buffer fills up,
/// on `flush`, and on `terminate`.
pub struct BufferingChunkedBodyWriter<'a, C>
where
    C: Write,
{
    /// Connection the framed body is written to.
    conn: C,
    /// Staging buffer; every byte before `header_pos` is ready to be sent.
    buf: &'a mut [u8],
    /// Offset of the header slot for the chunk currently being filled.
    header_pos: usize,
    /// Number of bytes reserved at `header_pos` for that chunk's header.
    allocated_header: usize,
    /// Offset of the next free payload byte.
    pos: usize,
    /// Set once `terminate` has emitted the final zero-length chunk.
    terminated: bool,
}
impl<'a, C> BufferingChunkedBodyWriter<'a, C>
where
    C: Write,
{
    /// Creates a writer over `buf`, whose first `written` bytes already contain
    /// data that must be sent before the first chunk.
    ///
    /// # Panics
    ///
    /// Panics if `written > buf.len()`, or if `buf.len()` is not larger than the
    /// header space reserved for the first chunk plus one CRLF.
    pub fn new_with_data(conn: C, buf: &'a mut [u8], written: usize) -> Self {
        assert!(written <= buf.len());
        let allocated_header = get_max_chunk_header_size(buf.len() - written);
        assert!(buf.len() > allocated_header + NEWLINE.len());
        Self {
            conn,
            buf,
            header_pos: written,
            pos: written + allocated_header,
            allocated_header,
            terminated: false,
        }
    }

    /// Closes any partially filled chunk, appends the terminating zero-length
    /// chunk (`0\r\n\r\n`) and writes everything buffered to the connection.
    ///
    /// # Errors
    ///
    /// Propagates errors from the connection's `write_all`; the writer is only
    /// marked as terminated once everything has been written.
    ///
    /// # Panics
    ///
    /// Panics if called more than once, or if the buffer is shorter than the
    /// 5-byte terminator.
    pub async fn terminate(&mut self) -> Result<(), C::Error> {
        assert!(!self.terminated);
        if self.has_pending_payload() {
            self.finish_current_chunk();
        }
        // Make room for the terminator if it does not fit behind the buffered bytes.
        if self.header_pos + EMPTY_CHUNK.len() > self.buf.len() {
            self.emit_buffered().await?;
        }
        self.buf[self.header_pos..self.header_pos + EMPTY_CHUNK.len()]
            .copy_from_slice(EMPTY_CHUNK);
        self.header_pos += EMPTY_CHUNK.len();
        self.allocated_header = 0;
        self.pos = self.header_pos + self.allocated_header;
        self.emit_buffered().await?;
        self.terminated = true;
        Ok(())
    }

    /// Returns `true` when payload bytes have been buffered for the current chunk.
    fn has_pending_payload(&self) -> bool {
        self.pos > self.header_pos + self.allocated_header
    }

    /// Copies as much of `buf` as fits into the current chunk and returns how
    /// many bytes were taken. Space for the chunk's trailing CRLF is always kept.
    ///
    /// Returns 0 when no header space is reserved for the current chunk (fewer
    /// than four bytes remained when the slot was allocated). Payload buffered
    /// there could never receive a size header and would go out unframed.
    fn append_current_chunk(&mut self, buf: &[u8]) -> usize {
        if self.allocated_header == 0 {
            return 0;
        }
        let space = self.buf.len().saturating_sub(NEWLINE.len() + self.pos);
        let buffered = buf.len().min(space);
        if buffered > 0 {
            self.buf[self.pos..self.pos + buffered].copy_from_slice(&buf[..buffered]);
            self.pos += buffered;
        }
        buffered
    }

    /// Writes the real size header into the current chunk's reserved slot and
    /// appends the trailing CRLF. It then opens a new header slot directly
    /// after the finished chunk.
    ///
    /// If the header is shorter than the reserved slot, the payload is shifted
    /// left so that no gap is left between header and data.
    fn finish_current_chunk(&mut self) {
        let chunk_len = self.pos - self.header_pos - self.allocated_header;
        let header_buf = &mut self.buf[self.header_pos..self.header_pos + self.allocated_header];
        // NOTE(review): `write_chunked_header` lives in `super::chunked`; presumably it
        // writes the hex length plus CRLF and returns the byte count — confirm there.
        let header_len = write_chunked_header(header_buf, chunk_len);
        let spacing = self.allocated_header - header_len;
        if spacing > 0 {
            self.buf.copy_within(
                self.header_pos + self.allocated_header..self.pos,
                self.header_pos + header_len,
            );
            self.pos -= spacing;
        }
        self.buf[self.pos..self.pos + NEWLINE.len()].copy_from_slice(NEWLINE);
        self.pos += NEWLINE.len();
        self.header_pos = self.pos;
        self.allocated_header = get_max_chunk_header_size(self.buf.len() - self.header_pos);
        self.pos = self.header_pos + self.allocated_header;
    }

    /// Returns `true` when the current chunk holds payload and no further byte
    /// can be added without losing room for its trailing CRLF.
    ///
    /// Payload is required so that an empty chunk is never finished. An empty
    /// chunk would be framed as a zero length, which reads as a premature body
    /// terminator. Without the check, this can happen when the reserved header
    /// alone leaves exactly two free bytes.
    fn current_chunk_is_full(&self) -> bool {
        self.has_pending_payload() && self.pos + NEWLINE.len() >= self.buf.len()
    }

    /// Writes every byte before `header_pos` (pre-existing data and finished
    /// chunks) to the connection. It then restarts the buffer with a fresh
    /// header slot at offset 0.
    ///
    /// Bytes after `header_pos` are dropped, so callers only invoke this when
    /// the current chunk holds no payload.
    async fn emit_buffered(&mut self) -> Result<(), C::Error> {
        self.conn.write_all(&self.buf[..self.header_pos]).await?;
        self.header_pos = 0;
        self.allocated_header = get_max_chunk_header_size(self.buf.len());
        self.pos = self.allocated_header;
        Ok(())
    }
}
/// Errors reported through `Write` are reduced to an `embedded_io::ErrorKind`,
/// so the connection's concrete error type is not exposed.
impl<C: Write> ErrorType for BufferingChunkedBodyWriter<'_, C> {
    type Error = embedded_io::ErrorKind;
}
/// Streams body bytes into the staging buffer, framing them as chunks.
impl<C> Write for BufferingChunkedBodyWriter<'_, C>
where
    C: Write,
{
    /// Buffers as much of `buf` as fits into the current chunk and returns how
    /// many bytes were accepted.
    ///
    /// If nothing fits, the already-buffered bytes are sent first. A chunk that
    /// fills the buffer is closed and sent immediately.
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        if buf.is_empty() {
            return Ok(0);
        }

        let accepted = match self.append_current_chunk(buf) {
            0 => {
                // No room left: push out what is buffered, then retry in the fresh buffer.
                self.emit_buffered().await.map_err(|e| e.kind())?;
                self.append_current_chunk(buf)
            }
            n => n,
        };

        if self.current_chunk_is_full() {
            self.finish_current_chunk();
            self.emit_buffered().await.map_err(|e| e.kind())?;
        }

        Ok(accepted)
    }

    /// Closes the current chunk if it holds payload, sends everything buffered,
    /// and flushes the connection.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        let has_payload = self.pos > self.header_pos + self.allocated_header;
        if has_payload {
            self.finish_current_chunk();
        }
        // Finishing a chunk always moves `header_pos` past it, so this single check
        // covers both "just finished a chunk" and "only pre-existing/finished bytes".
        if self.header_pos > 0 {
            self.emit_buffered().await.map_err(|e| e.kind())?;
        }
        self.conn.flush().await.map_err(|e| e.kind())
    }
}
/// Number of hexadecimal digits needed to print `number`; zero takes one digit.
const fn get_num_hex_chars(number: usize) -> usize {
    // The lowest nibble always produces one digit; each further non-zero
    // nibble group adds another.
    let mut remaining = number >> 4;
    let mut digits = 1;
    while remaining != 0 {
        remaining >>= 4;
        digits += 1;
    }
    digits
}
/// Bytes to reserve for a chunk header (hex length plus CRLF) when the whole
/// chunk may occupy at most `buffer_size` bytes.
///
/// Returns 0 when `buffer_size` cannot even hold the header CRLF and the
/// trailing CRLF.
const fn get_max_chunk_header_size(buffer_size: usize) -> usize {
    let framing = 2 * NEWLINE.len();
    if buffer_size < framing {
        return 0;
    }
    // With both CRLFs removed, the remainder bounds the payload length, so its
    // hex width bounds the number of digits the header can need.
    NEWLINE.len() + get_num_hex_chars(buffer_size - framing)
}
/// Unit tests for the hex-width helpers and for the exact bytes the writer
/// sends for various buffer sizes and amounts of pre-written data.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn can_get_hex_chars() {
        assert_eq!(1, get_num_hex_chars(0));
        assert_eq!(1, get_num_hex_chars(1));
        assert_eq!(1, get_num_hex_chars(0xF));
        assert_eq!(2, get_num_hex_chars(0x10));
        assert_eq!(2, get_num_hex_chars(0xFF));
        assert_eq!(3, get_num_hex_chars(0x100));
    }

    #[test]
    fn can_get_max_chunk_header_size() {
        assert_eq!(0, get_max_chunk_header_size(0));
        assert_eq!(0, get_max_chunk_header_size(1));
        assert_eq!(0, get_max_chunk_header_size(2));
        assert_eq!(0, get_max_chunk_header_size(3));
        assert_eq!(3, get_max_chunk_header_size(0x00 + 2 + 2));
        assert_eq!(3, get_max_chunk_header_size(0x01 + 2 + 2));
        assert_eq!(3, get_max_chunk_header_size(0x0F + 2 + 2));
        assert_eq!(4, get_max_chunk_header_size(0x10 + 2 + 2));
        assert_eq!(4, get_max_chunk_header_size(0x11 + 2 + 2));
        assert_eq!(4, get_max_chunk_header_size(0x12 + 2 + 2));
    }

    #[tokio::test]
    async fn preserves_already_written_bytes_in_the_buffer_without_any_chunks() {
        let mut conn = Vec::new();
        let mut buf = [0; 1024];
        buf[..5].copy_from_slice(b"HELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
        writer.terminate().await.unwrap();
        assert_eq!(b"HELLO0\r\n\r\n", conn.as_slice());
    }

    #[tokio::test]
    async fn preserves_already_written_bytes_in_the_buffer_with_chunks() {
        let mut conn = Vec::new();
        let mut buf = [0; 1024];
        buf[..5].copy_from_slice(b"HELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
        writer.write_all(b"BODY").await.unwrap();
        writer.terminate().await.unwrap();
        assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
    }

    #[tokio::test]
    async fn write_when_entire_buffer_is_prewritten() {
        let mut conn = Vec::new();
        let mut buf = [0; 10];
        buf.copy_from_slice(b"HELLOHELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
        writer.write_all(b"BODY").await.unwrap();
        writer.terminate().await.unwrap();
        assert_eq!(b"HELLOHELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
    }

    #[tokio::test]
    async fn flush_empty_body_when_entire_buffer_is_prewritten() {
        let mut conn = Vec::new();
        let mut buf = [0; 10];
        buf.copy_from_slice(b"HELLOHELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
        writer.flush().await.unwrap();
        assert_eq!(b"HELLOHELLO", conn.as_slice());
    }

    #[tokio::test]
    async fn terminate_empty_body_when_entire_buffer_is_prewritten() {
        let mut conn = Vec::new();
        let mut buf = [0; 10];
        buf.copy_from_slice(b"HELLOHELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
        writer.terminate().await.unwrap();
        assert_eq!(b"HELLOHELLO0\r\n\r\n", conn.as_slice());
    }

    #[tokio::test]
    async fn flush_when_entire_buffer_is_nearly_prewritten() {
        let mut conn = Vec::new();
        let mut buf = [0; 11];
        buf[..10].copy_from_slice(b"HELLOHELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
        writer.flush().await.unwrap();
        assert_eq!(b"HELLOHELLO", conn.as_slice());
    }

    #[tokio::test]
    async fn flushes_already_written_bytes_if_first_cannot_fit() {
        let mut conn = Vec::new();
        let mut buf = [0; 10];
        buf[..5].copy_from_slice(b"HELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
        writer.write_all(b"BODY").await.unwrap();
        writer.terminate().await.unwrap();
        assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
    }

    #[tokio::test]
    async fn written_bytes_fit_exactly() {
        let mut conn = Vec::new();
        let mut buf = [0; 14];
        buf[..5].copy_from_slice(b"HELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
        writer.write_all(b"BODY").await.unwrap();
        writer.write_all(b"BODY").await.unwrap();
        writer.terminate().await.unwrap();
        assert_eq!(b"HELLO4\r\nBODY\r\n4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
    }

    #[tokio::test]
    async fn current_chunk_is_emitted_on_termination_before_empty_chunk_is_emitted() {
        let mut conn = Vec::new();
        let mut buf = [0; 14];
        buf[..5].copy_from_slice(b"HELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
        writer.write_all(b"BOD").await.unwrap();
        writer.terminate().await.unwrap();
        assert_eq!(b"HELLO3\r\nBOD\r\n0\r\n\r\n", conn.as_slice());
    }

    #[tokio::test]
    async fn write_emits_chunks() {
        // A 12-byte buffer with 5 pre-written bytes only has room for a 2-byte
        // first chunk, so "BODY" is split across two chunks.
        let mut conn = Vec::new();
        let mut buf = [0; 12];
        buf[..5].copy_from_slice(b"HELLO");
        let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
        writer.write_all(b"BODY").await.unwrap();
        writer.terminate().await.unwrap();
        assert_eq!(b"HELLO2\r\nBO\r\n2\r\nDY\r\n0\r\n\r\n", conn.as_slice());
    }
}