use core::{cmp, mem};
use alloc::{collections::VecDeque, vec::Vec};
/// Buffers and counters describing the reading and writing state of a byte stream.
///
/// Incoming bytes accumulate in `incoming_buffer` and are consumed from its front by the
/// `incoming_bytes_take*` methods. Outgoing bytes are queued in `write_buffers` by the
/// `write_*` methods, bounded by `write_bytes_queueable`.
#[must_use]
pub struct ReadWrite<TNow> {
    /// Current time. NOTE(review): presumably supplied by the owner of this struct;
    /// `wake_up_asap` copies it into `wake_up_after`.
    pub now: TNow,
    /// Bytes that have been received but not consumed yet. The `incoming_bytes_take*` methods
    /// and `discard_all_incoming` remove bytes from its front.
    pub incoming_buffer: Vec<u8>,
    /// `None` if the reading side is closed. In that case the `incoming_bytes_take*` methods
    /// return a `ReadClosed` error when the buffer doesn't hold enough data.
    /// Otherwise, the number of bytes, counted from the start of `incoming_buffer`, that the
    /// consumer needs. The `incoming_bytes_take*` methods set it when data is missing and
    /// decrease it (saturating) as bytes are consumed.
    pub expected_incoming_bytes: Option<usize>,
    /// Running total of bytes consumed from `incoming_buffer`.
    pub read_bytes: usize,
    /// Outgoing data queued by the `write_*` methods, in the order it was queued.
    pub write_buffers: Vec<Vec<u8>>,
    /// Running total of bytes pushed onto `write_buffers`.
    pub write_bytes_queued: usize,
    /// `None` if the writing side is closed (see `close_write`). Otherwise, how many more
    /// bytes the `write_*` methods may queue; decreased as bytes are queued.
    pub write_bytes_queueable: Option<usize>,
    /// If `Some`, a point in time requested through `wake_up_after` (which keeps the earliest
    /// requested value) or `wake_up_asap`. NOTE(review): presumably the owner should poll again
    /// once this time is reached — confirm against the caller.
    pub wake_up_after: Option<TNow>,
}
impl<TNow> ReadWrite<TNow> {
    /// Returns `true` if both directions are closed, i.e. if `expected_incoming_bytes` and
    /// `write_bytes_queueable` are both `None`.
    pub fn is_dead(&self) -> bool {
        self.expected_incoming_bytes.is_none() && self.write_bytes_queueable.is_none()
    }

    /// Closes the writing side by setting `write_bytes_queueable` to `None`.
    ///
    /// Afterwards, [`ReadWrite::write_from_vec`] and [`ReadWrite::write_from_vec_deque`] queue
    /// nothing, and [`ReadWrite::write_out`] panics if given non-empty data.
    pub fn close_write(&mut self) {
        self.write_bytes_queueable = None;
    }

    /// Returns the number of bytes currently in `incoming_buffer`, i.e. received but not yet
    /// consumed.
    pub fn incoming_buffer_available(&self) -> usize {
        self.incoming_buffer.len()
    }

    /// Consumes and throws away every byte currently in `incoming_buffer`.
    ///
    /// `read_bytes` and `expected_incoming_bytes` are updated as if the bytes had been taken.
    pub fn discard_all_incoming(&mut self) {
        let discarded = self.incoming_buffer.len();
        self.incoming_buffer.clear();
        self.account_consumed_incoming(discarded);
    }

    /// Removes the first `num` bytes of `incoming_buffer` and returns them.
    ///
    /// If fewer than `num` bytes are buffered and the reading side is still open, returns
    /// `Ok(None)`, consumes nothing, and sets `expected_incoming_bytes` to `num`.
    ///
    /// # Errors
    ///
    /// Returns [`IncomingBytesTakeError::ReadClosed`] if fewer than `num` bytes are buffered
    /// and the reading side is closed (`expected_incoming_bytes` is `None`).
    pub fn incoming_bytes_take(
        &mut self,
        num: usize,
    ) -> Result<Option<Vec<u8>>, IncomingBytesTakeError> {
        if self.incoming_buffer.len() < num {
            return match self.expected_incoming_bytes.as_mut() {
                Some(expected) => {
                    *expected = num;
                    Ok(None)
                }
                None => Err(IncomingBytesTakeError::ReadClosed),
            };
        }

        self.account_consumed_incoming(num);

        let remaining = self.incoming_buffer.len() - num;
        if remaining == 0 {
            // Hand over the whole buffer without copying anything.
            Ok(Some(mem::take(&mut self.incoming_buffer)))
        } else if remaining < num.saturating_mul(2) {
            // Comparatively small remainder: copy the remainder out and hand over the existing
            // allocation, truncated to the `num` requested bytes.
            let remains = self.incoming_buffer[num..].to_vec();
            self.incoming_buffer.truncate(num);
            Ok(Some(mem::replace(&mut self.incoming_buffer, remains)))
        } else {
            // Large remainder: copy the requested bytes out and shift the rest to the front.
            let taken = self.incoming_buffer[..num].to_vec();
            self.incoming_buffer.copy_within(num.., 0);
            self.incoming_buffer.truncate(remaining);
            Ok(Some(taken))
        }
    }

    /// Same as [`ReadWrite::incoming_bytes_take`] with `num == N`, but returns the bytes as an
    /// array.
    ///
    /// # Errors
    ///
    /// Same as [`ReadWrite::incoming_bytes_take`].
    pub fn incoming_bytes_take_array<const N: usize>(
        &mut self,
    ) -> Result<Option<[u8; N]>, IncomingBytesTakeError> {
        let Some(taken) = self.incoming_bytes_take(N)? else {
            return Ok(None);
        };
        let array = <&[u8; N]>::try_from(&taken[..])
            .expect("incoming_bytes_take returns exactly the number of bytes requested");
        Ok(Some(*array))
    }

    /// Decodes an unsigned LEB128 number at the front of `incoming_buffer` and consumes the
    /// bytes it occupies.
    ///
    /// If the buffer doesn't contain a complete number yet and the reading side is still
    /// open, returns `Ok(None)`, consumes nothing, and sets `expected_incoming_bytes` to the
    /// buffer length plus the number of missing bytes reported by the decoder (or plus one
    /// if the decoder can't tell).
    ///
    /// # Errors
    ///
    /// - [`IncomingBytesTakeLeb128Error::TooLarge`] if the decoded number is greater than
    ///   `max_decoded_number`. Nothing is consumed in that case.
    /// - [`IncomingBytesTakeLeb128Error::ReadClosed`] if the number is incomplete and the
    ///   reading side is closed.
    /// - [`IncomingBytesTakeLeb128Error::InvalidLeb128`] for any other decoder error.
    pub fn incoming_bytes_take_leb128(
        &mut self,
        max_decoded_number: usize,
    ) -> Result<Option<usize>, IncomingBytesTakeLeb128Error> {
        match crate::util::leb128::nom_leb128_usize::<nom::error::Error<&[u8]>>(
            &self.incoming_buffer,
        ) {
            Ok((rest, decoded)) => {
                if decoded > max_decoded_number {
                    return Err(IncomingBytesTakeLeb128Error::TooLarge);
                }

                let consumed_bytes = self.incoming_buffer.len() - rest.len();
                // Removes the encoded number and shifts what follows it to the front.
                self.incoming_buffer.drain(..consumed_bytes);
                self.account_consumed_incoming(consumed_bytes);
                Ok(Some(decoded))
            }
            Err(nom::Err::Incomplete(nom::Needed::Size(missing))) => {
                let Some(expected) = self.expected_incoming_bytes.as_mut() else {
                    return Err(IncomingBytesTakeLeb128Error::ReadClosed);
                };
                *expected = self.incoming_buffer.len().saturating_add(missing.get());
                Ok(None)
            }
            Err(nom::Err::Incomplete(nom::Needed::Unknown)) => {
                // The decoder can't tell how much is missing: ask for at least one more byte.
                let Some(expected) = self.expected_incoming_bytes.as_mut() else {
                    return Err(IncomingBytesTakeLeb128Error::ReadClosed);
                };
                *expected = self.incoming_buffer.len().saturating_add(1);
                Ok(None)
            }
            Err(_) => Err(IncomingBytesTakeLeb128Error::InvalidLeb128),
        }
    }

    /// Moves as many bytes as `write_bytes_queueable` allows from the front of `data` to a new
    /// buffer at the end of `write_buffers`. Bytes that don't fit are left in `data`.
    ///
    /// Does nothing if the writing side is closed or if nothing can be moved.
    pub fn write_from_vec(&mut self, data: &mut Vec<u8>) {
        let Some(queueable) = self.write_bytes_queueable.as_mut() else {
            return;
        };

        let to_copy = cmp::min(data.len(), *queueable);
        if to_copy == 0 {
            return;
        }

        let buffer = if to_copy == data.len() {
            // Everything fits: move the whole allocation instead of copying it.
            mem::take(data)
        } else {
            data.drain(..to_copy).collect()
        };
        self.write_buffers.push(buffer);
        self.write_bytes_queued += to_copy;
        *queueable -= to_copy;
    }

    /// Moves as many bytes as `write_bytes_queueable` allows from the front of `data` to the
    /// end of `write_buffers`. Bytes that don't fit are left in `data`.
    ///
    /// Does nothing if the writing side is closed or if nothing can be moved. No empty buffer
    /// is ever pushed onto `write_buffers`.
    pub fn write_from_vec_deque(&mut self, data: &mut VecDeque<u8>) {
        let Some(queueable) = self.write_bytes_queueable.as_mut() else {
            return;
        };

        let to_copy = cmp::min(data.len(), *queueable);
        if to_copy == 0 {
            return;
        }

        // A `VecDeque` may hold its content as two discontiguous slices. Each non-empty part
        // that is copied becomes its own buffer.
        let (front, back) = data.as_slices();
        let from_front = cmp::min(front.len(), to_copy);
        let from_back = to_copy - from_front;
        if from_front != 0 {
            self.write_buffers.push(front[..from_front].to_vec());
        }
        if from_back != 0 {
            self.write_buffers.push(back[..from_back].to_vec());
        }
        data.drain(..to_copy);

        self.write_bytes_queued += to_copy;
        *queueable -= to_copy;
    }

    /// Pushes `data` onto `write_buffers` as a single buffer. Does nothing if `data` is empty.
    ///
    /// # Panics
    ///
    /// Panics if `data` is non-empty and either the writing side is closed or `data` is longer
    /// than `write_bytes_queueable`.
    pub fn write_out(&mut self, data: Vec<u8>) {
        if data.is_empty() {
            return;
        }

        let Some(queueable) = self.write_bytes_queueable.as_mut() else {
            panic!("write_out called with non-empty data while the writing side is closed");
        };
        assert!(
            data.len() <= *queueable,
            "write_out called with {} bytes while only {} bytes are queueable",
            data.len(),
            *queueable
        );

        self.write_bytes_queued += data.len();
        *queueable -= data.len();
        self.write_buffers.push(data);
    }

    /// Lowers `wake_up_after` to `after` if it is `None` or not earlier than `after`.
    /// An already-earlier value is left untouched.
    pub fn wake_up_after(&mut self, after: &TNow)
    where
        TNow: Clone + Ord,
    {
        match &mut self.wake_up_after {
            Some(current) if *current < *after => {}
            slot => *slot = Some(after.clone()),
        }
    }

    /// Sets `wake_up_after` to a clone of `now`, overwriting any previous value.
    pub fn wake_up_asap(&mut self)
    where
        TNow: Clone,
    {
        self.wake_up_after = Some(self.now.clone());
    }

    /// Bookkeeping for `num` bytes having been removed from the front of `incoming_buffer`.
    /// It adds them to `read_bytes` and subtracts them (saturating) from
    /// `expected_incoming_bytes`.
    fn account_consumed_incoming(&mut self, num: usize) {
        self.read_bytes += num;
        if let Some(expected) = self.expected_incoming_bytes.as_mut() {
            *expected = expected.saturating_sub(num);
        }
    }
}
#[derive(Debug, Clone, derive_more::Display, derive_more::Error)]
pub enum IncomingBytesTakeError {
ReadClosed,
}
#[derive(Debug, Clone, derive_more::Display, derive_more::Error)]
pub enum IncomingBytesTakeLeb128Error {
InvalidLeb128,
ReadClosed,
TooLarge,
}
#[cfg(test)]
mod tests {
    use super::{IncomingBytesTakeError, ReadWrite};

    /// Returns a `ReadWrite` with empty buffers, zeroed counters and both directions closed.
    /// Each test overrides the fields it cares about with struct update syntax.
    fn blank() -> ReadWrite<u32> {
        ReadWrite {
            now: 0,
            incoming_buffer: Vec::new(),
            expected_incoming_bytes: None,
            read_bytes: 0,
            write_buffers: Vec::new(),
            write_bytes_queued: 0,
            write_bytes_queueable: None,
            wake_up_after: None,
        }
    }

    #[test]
    fn take_bytes() {
        let mut rw = ReadWrite {
            incoming_buffer: vec![0x80; 64],
            expected_incoming_bytes: Some(12),
            read_bytes: 2,
            ..blank()
        };

        // A few bytes out of a large buffer.
        let taken = rw.incoming_bytes_take(5).unwrap().unwrap();
        assert_eq!(taken, &[0x80; 5]);
        assert_eq!(rw.incoming_buffer.len(), 59);
        assert_eq!(rw.read_bytes, 7);
        assert_eq!(rw.expected_incoming_bytes, Some(7));

        // More than what is buffered: nothing consumed, the need is recorded.
        assert!(matches!(rw.incoming_bytes_take(1000), Ok(None)));
        assert_eq!(rw.read_bytes, 7);
        assert_eq!(rw.expected_incoming_bytes, Some(1000));

        // Almost everything, leaving a small remainder behind.
        let taken = rw.incoming_bytes_take(57).unwrap().unwrap();
        assert_eq!(taken.len(), 57);
        assert_eq!(rw.incoming_buffer.len(), 2);
        assert_eq!(rw.read_bytes, 64);
        assert_eq!(rw.expected_incoming_bytes, Some(1000 - 57));
    }

    #[test]
    fn take_bytes_closed() {
        let mut rw = ReadWrite {
            incoming_buffer: vec![0x80; 64],
            read_bytes: 2,
            ..blank()
        };

        // Not enough data while reading is closed: error, `expected_incoming_bytes` stays `None`.
        assert!(matches!(
            rw.incoming_bytes_take(1000),
            Err(IncomingBytesTakeError::ReadClosed)
        ));
        assert_eq!(rw.expected_incoming_bytes, None);

        // Data already buffered can still be taken.
        let taken = rw.incoming_bytes_take(5).unwrap().unwrap();
        assert_eq!(taken, &[0x80; 5]);
        assert_eq!(rw.incoming_buffer.len(), 59);
        assert_eq!(rw.read_bytes, 7);
        assert_eq!(rw.expected_incoming_bytes, None);

        assert!(matches!(
            rw.incoming_bytes_take(1000),
            Err(IncomingBytesTakeError::ReadClosed)
        ));
        assert_eq!(rw.expected_incoming_bytes, None);
    }

    #[test]
    fn write_out() {
        let mut rw = ReadWrite {
            write_bytes_queued: 11,
            write_bytes_queueable: Some(10),
            ..blank()
        };

        rw.write_out(b"hello".to_vec());

        assert_eq!(rw.write_buffers.len(), 1);
        assert_eq!(rw.write_bytes_queued, 16);
        assert_eq!(rw.write_bytes_queueable, Some(5));
    }

    #[test]
    fn write_from_vec_deque_smaller() {
        let mut rw = ReadWrite {
            write_bytes_queued: 5,
            write_bytes_queueable: Some(5),
            ..blank()
        };
        let mut input = [1, 2, 3, 4].iter().copied().collect();

        rw.write_from_vec_deque(&mut input);

        assert!(input.is_empty());
        assert_eq!(rw.write_bytes_queued, 9);
        assert_eq!(rw.write_bytes_queueable, Some(1));
    }

    #[test]
    fn write_from_vec_deque_larger() {
        let mut rw = ReadWrite {
            write_bytes_queued: 5,
            write_bytes_queueable: Some(5),
            ..blank()
        };
        let mut input = [1, 2, 3, 4, 5, 6].iter().copied().collect();

        rw.write_from_vec_deque(&mut input);

        assert_eq!(input.into_iter().collect::<Vec<_>>(), &[6]);
        assert_eq!(rw.write_bytes_queued, 10);
        assert_eq!(rw.write_bytes_queueable, Some(0));
    }
}