#[cfg(test)]
mod buffer_test;
use std::sync::Arc;
use tokio::sync::{Mutex, Notify};
use tokio::time::{timeout, Duration};
use crate::error::{Error, Result};
/// Smallest backing allocation `BufferInternal::grow` will create, in bytes
/// (unless a smaller configured size limit caps it).
const MIN_SIZE: usize = 2048;
/// Backing-store size, in bytes, at which `BufferInternal::grow` switches from
/// doubling to growing by 25%.
const CUTOFF_SIZE: usize = 128 * 1024;
/// Upper bound, in bytes, on the backing store when no size limit is configured
/// (`limit_size == 0`).
const MAX_SIZE: usize = 4 * 1024 * 1024;
/// Mutable state of a [`Buffer`]: a byte ring holding length-prefixed packets.
///
/// Each stored packet is a 2-byte big-endian length followed by the payload.
#[derive(Debug)]
struct BufferInternal {
/// Backing storage, used as a ring: `head` and `tail` wrap around its end.
data: Vec<u8>,
/// Index of the next byte to read. `head == tail` means the ring is empty.
head: usize,
/// Index at which the next byte will be written.
tail: usize,
/// Set by `Buffer::close`; once set, writes are rejected and reads fail when
/// no data is left.
closed: bool,
/// Set by a reader that found the ring empty and is about to wait; the next
/// successful write notifies once and clears it.
subs: bool,
/// Number of packets currently stored.
count: usize,
/// Maximum number of stored packets; `0` disables the check.
limit_count: usize,
/// Maximum number of stored bytes, length headers included; `0` means no
/// configured limit (the backing store is then capped at `MAX_SIZE`).
limit_size: usize,
}
impl BufferInternal {
    /// Returns `true` when a packet with `size` payload bytes plus its 2-byte
    /// length header fits into the ring as currently allocated.
    ///
    /// The comparison is strict: at least one byte always stays unused so that
    /// `head == tail` can only mean "empty", never "full".
    fn available(&self, size: usize) -> bool {
        // Bytes between `tail` and `head`, walking forward around the ring.
        // An empty ring (`head == tail`) has the whole backing store free.
        let free = if self.head > self.tail {
            self.head - self.tail
        } else {
            self.data.len() - (self.tail - self.head)
        };
        size.saturating_add(2) < free
    }

    /// Reallocates the backing store to the next larger size and linearises the
    /// stored bytes so that they start at index 0.
    ///
    /// The size doubles below `CUTOFF_SIZE` and grows by 25% above it, never
    /// drops below `MIN_SIZE`, and is capped at `MAX_SIZE` (no configured limit)
    /// or at `limit_size + 1` (configured limit; the extra byte is the slot
    /// `available` keeps unused).
    ///
    /// # Errors
    ///
    /// Returns [`Error::ErrBufferFull`] when the cap does not allow a larger
    /// backing store than the current one.
    fn grow(&mut self) -> Result<()> {
        let current = self.data.len();

        let grown = if current < CUTOFF_SIZE {
            2 * current
        } else {
            // Same value as `5 * current / 4`, without the intermediate
            // multiplication that could overflow.
            current + current / 4
        };
        let mut new_size = grown.max(MIN_SIZE);

        if self.limit_size == 0 {
            new_size = new_size.min(MAX_SIZE);
        } else {
            // Saturate so that `limit_size == usize::MAX` ("practically
            // unlimited") neither panics in debug builds nor wraps to a zero
            // cap in release builds.
            new_size = new_size.min(self.limit_size.saturating_add(1));
        }

        if new_size <= current {
            return Err(Error::ErrBufferFull);
        }

        let mut new_data = vec![0u8; new_size];
        let used = if self.head <= self.tail {
            // Contents are contiguous.
            let used = self.tail - self.head;
            new_data[..used].copy_from_slice(&self.data[self.head..self.tail]);
            used
        } else {
            // Contents wrap: copy the part up to the end of the old store first,
            // then the part that continued from index 0.
            let (front, back) = self.data.split_at(self.head);
            let wrapped = &front[..self.tail];
            new_data[..back.len()].copy_from_slice(back);
            new_data[back.len()..back.len() + wrapped.len()].copy_from_slice(wrapped);
            back.len() + wrapped.len()
        };

        self.head = 0;
        self.tail = used;
        self.data = new_data;
        Ok(())
    }

    /// Returns the number of bytes currently stored, length headers included.
    fn size(&self) -> usize {
        if self.tail >= self.head {
            self.tail - self.head
        } else {
            self.data.len() - self.head + self.tail
        }
    }
}
/// Asynchronous FIFO buffer of discrete packets.
///
/// Cloning is cheap: clones share the same underlying state and notifier
/// through the `Arc`s, so a packet written through one handle can be read
/// through another.
#[derive(Debug, Clone)]
pub struct Buffer {
/// Ring-buffer state, guarded by an async mutex.
buffer: Arc<Mutex<BufferInternal>>,
/// Signals a waiting reader when a packet arrives or the buffer is closed.
notify: Arc<Notify>,
}
impl Buffer {
pub fn new(limit_count: usize, limit_size: usize) -> Self {
Buffer {
buffer: Arc::new(Mutex::new(BufferInternal {
data: vec![],
head: 0,
tail: 0,
closed: false,
subs: false,
count: 0,
limit_count,
limit_size,
})),
notify: Arc::new(Notify::new()),
}
}
/// Appends `packet` to the buffer as one unit and returns its length.
///
/// The packet is stored as a 2-byte big-endian length followed by the payload,
/// growing the backing store when necessary. If a reader registered itself as
/// waiting, it is notified once.
///
/// # Errors
///
/// * [`Error::ErrPacketTooBig`] - the length does not fit the 2-byte header
///   (`packet.len() >= 0x10000`).
/// * [`Error::ErrBufferClosed`] - the buffer has been closed.
/// * [`Error::ErrBufferFull`] - the packet-count or byte-size limit would be
///   exceeded, or the backing store cannot grow any further.
pub async fn write(&self, packet: &[u8]) -> Result<usize> {
    let len = packet.len();
    if len >= 0x10000 {
        return Err(Error::ErrPacketTooBig);
    }

    let mut guard = self.buffer.lock().await;
    let inner = &mut *guard;

    if inner.closed {
        return Err(Error::ErrBufferClosed);
    }

    let count_exceeded = inner.limit_count > 0 && inner.count >= inner.limit_count;
    let size_exceeded = inner.limit_size > 0 && inner.size() + 2 + len > inner.limit_size;
    if count_exceeded || size_exceeded {
        return Err(Error::ErrBufferFull);
    }

    // Enlarge the ring until header + payload fit.
    while !inner.available(len) {
        inner.grow()?;
    }
    let cap = inner.data.len();

    // Length header, high byte first; the cursor may wrap after either byte.
    for &byte in &[(len >> 8) as u8, len as u8] {
        let at = inner.tail;
        inner.data[at] = byte;
        inner.tail = if at + 1 >= cap { 0 } else { at + 1 };
    }

    // Payload: as much as fits before the end of the store, the remainder
    // from index 0.
    let start = inner.tail;
    let first = len.min(cap - start);
    inner.data[start..start + first].copy_from_slice(&packet[..first]);
    if start + first >= cap {
        let rest = len - first;
        inner.data[..rest].copy_from_slice(&packet[first..]);
        inner.tail = rest;
    } else {
        inner.tail = start + first;
    }

    inner.count += 1;

    if inner.subs {
        self.notify.notify_one();
        inner.subs = false;
    }

    Ok(len)
}
/// Removes the oldest packet from the buffer, copies it into `packet` and
/// returns the number of bytes copied.
///
/// When the buffer is empty the call waits until a packet is written or the
/// buffer is closed. With `duration` set, each wait is bounded by that
/// duration.
///
/// # Errors
///
/// * [`Error::ErrBufferShort`] - `packet` is smaller than the stored packet.
///   The stored packet is consumed anyway; only its first `packet.len()`
///   bytes have been copied.
/// * [`Error::ErrBufferClosed`] - the buffer is closed and holds no more data.
/// * [`Error::ErrTimeout`] - `duration` elapsed while waiting.
pub async fn read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize> {
    loop {
        // The wait future is created while the lock is still held. `close`
        // wakes readers with `notify_waiters`, which stores no permit: a
        // future created after the lock is released could miss a `close`
        // that ran in between and wait forever. A future that already exists
        // is guaranteed to receive that wakeup even before its first poll.
        let notified = {
            let mut guard = self.buffer.lock().await;
            let b = &mut *guard;

            if b.head != b.tail {
                let cap = b.data.len();

                // 2-byte big-endian length header; the cursor may wrap after
                // either byte.
                let n1 = b.data[b.head];
                b.head += 1;
                if b.head >= cap {
                    b.head = 0;
                }
                let n2 = b.data[b.head];
                b.head += 1;
                if b.head >= cap {
                    b.head = 0;
                }
                let count = ((n1 as usize) << 8) | n2 as usize;

                // Copy no more than the caller's slice can take.
                let copied = count.min(packet.len());
                if b.head + copied < cap {
                    packet[..copied].copy_from_slice(&b.data[b.head..b.head + copied]);
                } else {
                    // Payload wraps around the end of the backing store.
                    let k = cap - b.head;
                    packet[..k].copy_from_slice(&b.data[b.head..]);
                    packet[k..copied].copy_from_slice(&b.data[..copied - k]);
                }

                // Skip the whole stored packet, even if it was truncated.
                b.head += count;
                if b.head >= cap {
                    b.head -= cap;
                }

                // An empty ring restarts at index 0 so the next packets are
                // stored contiguously.
                if b.head == b.tail {
                    b.head = 0;
                    b.tail = 0;
                }

                b.count -= 1;

                if copied < count {
                    return Err(Error::ErrBufferShort);
                }
                return Ok(copied);
            }

            if b.closed {
                return Err(Error::ErrBufferClosed);
            }

            // Tell the next writer that a reader is waiting.
            b.subs = true;
            self.notify.notified()
        };

        if let Some(d) = duration {
            timeout(d, notified)
                .await
                .map_err(|_| Error::ErrTimeout)?;
        } else {
            notified.await;
        }
    }
}
/// Marks the buffer as closed and wakes every reader that is currently waiting.
///
/// Calling it on an already closed buffer does nothing.
pub async fn close(&self) {
    let mut state = self.buffer.lock().await;
    if !state.closed {
        state.closed = true;
        self.notify.notify_waiters();
    }
}
/// Returns `true` once the buffer has been closed.
pub async fn is_closed(&self) -> bool {
    self.buffer.lock().await.closed
}
/// Returns the number of packets currently stored in the buffer.
pub async fn count(&self) -> usize {
    self.buffer.lock().await.count
}
/// Sets the maximum number of packets the buffer may hold; `0` disables the
/// limit.
pub async fn set_limit_count(&self, limit: usize) {
    self.buffer.lock().await.limit_count = limit;
}
/// Returns the number of bytes currently stored, including the 2-byte length
/// header of every packet.
pub async fn size(&self) -> usize {
    self.buffer.lock().await.size()
}
/// Sets the maximum number of bytes (length headers included) the buffer may
/// hold; `0` removes the configured limit.
pub async fn set_limit_size(&self, limit: usize) {
    self.buffer.lock().await.limit_size = limit;
}
}