use crate::buffer::Buffer;
use crate::circular::CircularBuffer;
use crate::pool::BufferPool;
use std::collections::VecDeque;
use std::sync::Arc;
/// Limits enforced on a connection's packet queue by `ConnectionBuffers::queue_packet`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionBufferConfig {
    /// Maximum number of packets that may sit in the queue at once.
    pub max_packet_queue_size: usize,
    /// Maximum sum of `len()` over all queued packets, in bytes.
    pub max_packet_queue_bytes: usize,
}

impl Default for ConnectionBufferConfig {
    /// 100 packets and 10 MiB of queued packet data.
    fn default() -> Self {
        Self {
            max_packet_queue_size: 100,
            // 10 MiB == 10_485_760 bytes.
            max_packet_queue_bytes: 10 * 1024 * 1024,
        }
    }
}
/// Per-connection buffer set: optional read/write buffers, any number of
/// circular stream buffers, and a bounded queue of pending packets.
///
/// The `Drop` impl calls `burn()` on this set.
// NOTE(review): field order is also drop order; keep it stable.
pub struct ConnectionBuffers {
    /// Read buffer; `None` until `init_read_buf` is called.
    pub read_buf: Option<Buffer>,
    /// Write buffer; `None` until `init_write_buf` is called.
    pub write_buf: Option<Buffer>,
    /// Stream buffers appended by `add_stream_buf`.
    pub stream_bufs: Vec<CircularBuffer>,
    /// FIFO of pending packets, bounded by `config`.
    pub packet_queue: VecDeque<Buffer>,
    /// Limits enforced by `queue_packet`.
    config: ConnectionBufferConfig,
    /// Running sum of `len()` over `packet_queue`, kept in step by queue/dequeue.
    packet_queue_bytes: usize,
    /// Set by `burn()` so repeated calls (including the one from `Drop`) return
    /// early; cleared by `reset()`.
    // NOTE(review): writes made directly through the pub fields do not clear this flag.
    burned: bool,
}
impl ConnectionBuffers {
    /// Creates an empty buffer set using [`ConnectionBufferConfig::default`] limits.
    pub fn new() -> Self {
        Self::with_config(ConnectionBufferConfig::default())
    }

    /// Creates an empty buffer set with the given packet-queue limits.
    ///
    /// No buffers are allocated until `init_read_buf`, `init_write_buf` or
    /// `add_stream_buf` is called.
    pub fn with_config(config: ConnectionBufferConfig) -> Self {
        Self {
            read_buf: None,
            write_buf: None,
            stream_bufs: Vec::new(),
            packet_queue: VecDeque::new(),
            config,
            packet_queue_bytes: 0,
            burned: false,
        }
    }

    /// Installs a fresh read buffer created with `Buffer::new(size)`.
    ///
    /// A previously installed read buffer is burned before it is dropped.
    /// The burned flag is cleared so the new buffer is burned on drop even if
    /// `burn()` was called earlier.
    pub fn init_read_buf(&mut self, size: usize) {
        if let Some(mut old) = self.read_buf.replace(Buffer::new(size)) {
            old.burn();
        }
        self.burned = false;
    }

    /// Installs a fresh write buffer created with `Buffer::new(size)`.
    ///
    /// A previously installed write buffer is burned before it is dropped.
    /// The burned flag is cleared so the new buffer is burned on drop even if
    /// `burn()` was called earlier.
    pub fn init_write_buf(&mut self, size: usize) {
        if let Some(mut old) = self.write_buf.replace(Buffer::new(size)) {
            old.burn();
        }
        self.burned = false;
    }

    /// Appends a new stream buffer created with `CircularBuffer::new(size)`.
    ///
    /// Clears the burned flag so the new buffer is covered by the next `burn()`.
    pub fn add_stream_buf(&mut self, size: usize) {
        self.stream_bufs.push(CircularBuffer::new(size));
        self.burned = false;
    }

    /// Appends `buf` to the packet queue.
    ///
    /// # Errors
    ///
    /// - [`QueueFullError::TooManyPackets`] if the queue already holds
    ///   `max_packet_queue_size` packets.
    /// - [`QueueFullError::TooManyBytes`] if adding `buf.len()` would push the
    ///   queued total above `max_packet_queue_bytes` (or overflow `usize`).
    ///
    /// A rejected packet is burned before it is dropped, matching how `burn()`
    /// treats queued packets.
    pub fn queue_packet(&mut self, mut buf: Buffer) -> Result<(), QueueFullError> {
        if self.packet_queue.len() >= self.config.max_packet_queue_size {
            buf.burn();
            return Err(QueueFullError::TooManyPackets);
        }
        // `checked_add` guards against wrap-around on pathological lengths.
        let new_total = match self.packet_queue_bytes.checked_add(buf.len()) {
            Some(total) if total <= self.config.max_packet_queue_bytes => total,
            _ => {
                buf.burn();
                return Err(QueueFullError::TooManyBytes);
            }
        };
        self.packet_queue_bytes = new_total;
        self.packet_queue.push_back(buf);
        // New data has entered the set; make sure the next burn (e.g. on drop) runs.
        self.burned = false;
        Ok(())
    }

    /// Removes and returns the oldest queued packet, or `None` if the queue is empty.
    ///
    /// Ownership of the packet moves to the caller; it is not burned here.
    pub fn dequeue_packet(&mut self) -> Option<Buffer> {
        let buf = self.packet_queue.pop_front()?;
        self.packet_queue_bytes = self.packet_queue_bytes.saturating_sub(buf.len());
        Some(buf)
    }

    /// Number of packets currently queued.
    pub fn packet_queue_len(&self) -> usize {
        self.packet_queue.len()
    }

    /// Sum of `len()` over all currently queued packets.
    pub fn packet_queue_bytes(&self) -> usize {
        self.packet_queue_bytes
    }

    /// Returns `true` when either the packet count or the queued byte total is
    /// strictly above 80% (rounded down) of its configured limit.
    pub fn is_queue_near_full(&self) -> bool {
        self.packet_queue.len() > Self::near_full_threshold(self.config.max_packet_queue_size)
            || self.packet_queue_bytes
                > Self::near_full_threshold(self.config.max_packet_queue_bytes)
    }

    /// Returns `floor(limit * 80 / 100)` without the overflow that multiplying
    /// `limit` by 80 first would risk for large limits.
    fn near_full_threshold(limit: usize) -> usize {
        // limit = 5q + r  =>  floor(4 * limit / 5) = 4q + floor(4r / 5)
        limit / 5 * 4 + limit % 5 * 4 / 5
    }

    /// Empties the set for reuse and clears the burned flag.
    ///
    /// It resets the read/write buffers and clears the stream buffers.
    /// Queued packets are discarded; each is burned before being dropped,
    /// since nothing will read them again.
    pub fn reset(&mut self) {
        if let Some(ref mut buf) = self.read_buf {
            buf.reset();
        }
        if let Some(ref mut buf) = self.write_buf {
            buf.reset();
        }
        for buf in &mut self.stream_bufs {
            buf.clear();
        }
        for mut pkt in self.packet_queue.drain(..) {
            pkt.burn();
        }
        self.packet_queue_bytes = 0;
        self.burned = false;
    }

    /// Burns everything this set holds, then marks it as burned.
    ///
    /// This burns the read/write buffers (`Buffer::burn`) and frees every
    /// stream buffer (`CircularBuffer::free`). It also burns and removes
    /// every queued packet.
    ///
    /// Returns immediately if the burned flag is already set. Adding data
    /// through this type's methods or calling `reset()` clears the flag.
    pub fn burn(&mut self) {
        if self.burned {
            return;
        }
        if let Some(ref mut buf) = self.read_buf {
            buf.burn();
        }
        if let Some(ref mut buf) = self.write_buf {
            buf.burn();
        }
        for buf in &mut self.stream_bufs {
            buf.free();
        }
        for mut pkt in self.packet_queue.drain(..) {
            pkt.burn();
        }
        self.packet_queue_bytes = 0;
        self.burned = true;
    }

    /// Reports memory held by this set.
    ///
    /// Read/write buffers count their `capacity()`, or 0 when absent.
    /// Stream buffers count their `size()`, and the packet queue counts the
    /// sum of queued `len()`.
    pub fn memory_usage(&self) -> ConnectionMemoryStats {
        let read_buf_bytes = self.read_buf.as_ref().map_or(0, |b| b.capacity());
        let write_buf_bytes = self.write_buf.as_ref().map_or(0, |b| b.capacity());
        let stream_buf_bytes: usize = self.stream_bufs.iter().map(|b| b.size()).sum();
        ConnectionMemoryStats {
            read_buf_bytes,
            write_buf_bytes,
            stream_buf_bytes,
            packet_queue_bytes: self.packet_queue_bytes,
            total_bytes: read_buf_bytes
                + write_buf_bytes
                + stream_buf_bytes
                + self.packet_queue_bytes,
        }
    }
}
impl Default for ConnectionBuffers {
    /// Equivalent to `ConnectionBuffers::new()`: default queue limits, nothing allocated.
    fn default() -> Self {
        Self::with_config(Default::default())
    }
}
impl Drop for ConnectionBuffers {
fn drop(&mut self) {
self.burn();
}
}
/// Snapshot of the memory held by a `ConnectionBuffers`, as returned by
/// `ConnectionBuffers::memory_usage`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConnectionMemoryStats {
    /// Capacity of the read buffer, or 0 if none is installed.
    pub read_buf_bytes: usize,
    /// Capacity of the write buffer, or 0 if none is installed.
    pub write_buf_bytes: usize,
    /// Sum of `size()` over all stream buffers.
    pub stream_buf_bytes: usize,
    /// Sum of `len()` over all queued packets.
    pub packet_queue_bytes: usize,
    /// Sum of the four fields above.
    pub total_bytes: usize,
}
/// Reason `ConnectionBuffers::queue_packet` refused a packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueueFullError {
    /// The queue already holds `max_packet_queue_size` packets.
    TooManyPackets,
    /// Queuing the packet would push the queued byte total above `max_packet_queue_bytes`.
    TooManyBytes,
}

impl std::fmt::Display for QueueFullError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::TooManyPackets => "Packet queue full (too many packets)",
            Self::TooManyBytes => "Packet queue full (too many bytes)",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for QueueFullError {}
/// A `ConnectionBuffers` paired with an optional shared handle to a `BufferPool`.
///
/// The `Drop` impl burns `buffers`.
// NOTE(review): field order is drop order: `buffers` is dropped before `pool`.
pub struct PooledConnectionBuffers {
    /// The wrapped per-connection buffers.
    buffers: ConnectionBuffers,
    /// Shared pool handle; set by the constructors, cleared by `burn_and_release`.
    pool: Option<Arc<BufferPool>>,
}
impl PooledConnectionBuffers {
pub fn new(pool: Arc<BufferPool>) -> Self {
Self {
buffers: ConnectionBuffers::new(),
pool: Some(pool),
}
}
pub fn with_config(pool: Arc<BufferPool>, config: ConnectionBufferConfig) -> Self {
Self {
buffers: ConnectionBuffers::with_config(config),
pool: Some(pool),
}
}
pub fn buffers(&mut self) -> &mut ConnectionBuffers {
&mut self.buffers
}
pub fn pool(&self) -> Option<&Arc<BufferPool>> {
self.pool.as_ref()
}
pub fn burn_and_release(&mut self) {
self.buffers.burn(); self.pool = None; }
}
impl Drop for PooledConnectionBuffers {
    /// Burns the wrapped buffers; fields (including the pool handle) are dropped afterwards.
    fn drop(&mut self) {
        let inner = &mut self.buffers;
        inner.burn();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_buffers_cleanup() {
        let mut conn = ConnectionBuffers::new();
        conn.init_read_buf(1024);
        conn.init_write_buf(1024);
        conn.add_stream_buf(512);
        if let Some(ref mut buf) = conn.read_buf {
            buf.put_byte(0xAB).unwrap();
            assert_eq!(buf.len(), 1);
        }
        // Dropping runs `burn()` through the `Drop` impl; it must not panic.
        drop(conn);
    }

    #[test]
    fn test_burn_is_idempotent() {
        let mut conn = ConnectionBuffers::new();
        conn.init_read_buf(256);
        let mut pkt = Buffer::new(64);
        pkt.put_u32(0xDEAD_BEEF).unwrap();
        conn.queue_packet(pkt).unwrap();
        conn.burn();
        // Burning drains the packet queue and resets the byte counter.
        assert_eq!(conn.packet_queue_len(), 0);
        assert_eq!(conn.packet_queue_bytes(), 0);
        conn.burn();
        drop(conn);
    }

    #[test]
    fn test_packet_queue_limits() {
        let mut conn = ConnectionBuffers::with_config(ConnectionBufferConfig {
            max_packet_queue_size: 2,
            max_packet_queue_bytes: 1024,
        });
        conn.queue_packet(Buffer::new(256)).unwrap();
        conn.queue_packet(Buffer::new(256)).unwrap();
        assert!(matches!(
            conn.queue_packet(Buffer::new(256)),
            Err(QueueFullError::TooManyPackets)
        ));
        assert_eq!(conn.packet_queue_len(), 2);
    }

    #[test]
    fn test_packet_queue_byte_limit() {
        let mut conn = ConnectionBuffers::with_config(ConnectionBufferConfig {
            max_packet_queue_size: 10,
            max_packet_queue_bytes: 4,
        });
        let mut first = Buffer::new(64);
        first.put_u32(1).unwrap();
        // Exactly at the byte limit is allowed.
        conn.queue_packet(first).unwrap();
        assert_eq!(conn.packet_queue_bytes(), 4);

        let mut second = Buffer::new(64);
        second.put_u32(2).unwrap();
        assert!(matches!(
            conn.queue_packet(second),
            Err(QueueFullError::TooManyBytes)
        ));
        assert_eq!(conn.packet_queue_len(), 1);
        assert_eq!(conn.packet_queue_bytes(), 4);
    }

    #[test]
    fn test_queue_near_full() {
        let mut conn = ConnectionBuffers::with_config(ConnectionBufferConfig {
            max_packet_queue_size: 5,
            max_packet_queue_bytes: 1024,
        });
        // Threshold is 80% of 5 == 4; near-full means strictly more than that.
        for _ in 0..4 {
            conn.queue_packet(Buffer::new(16)).unwrap();
        }
        assert!(!conn.is_queue_near_full());
        conn.queue_packet(Buffer::new(16)).unwrap();
        assert!(conn.is_queue_near_full());
    }

    #[test]
    fn test_memory_stats() {
        let mut conn = ConnectionBuffers::new();
        conn.init_read_buf(1024);
        conn.init_write_buf(2048);
        conn.add_stream_buf(512);
        let stats = conn.memory_usage();
        assert_eq!(stats.read_buf_bytes, 1024);
        assert_eq!(stats.write_buf_bytes, 2048);
        assert_eq!(stats.stream_buf_bytes, 512);
        assert_eq!(stats.packet_queue_bytes, 0);
        assert_eq!(stats.total_bytes, 1024 + 2048 + 512);
    }

    #[test]
    fn test_pooled_connection_burn_and_release() {
        use crate::pool::{BufferPool, PoolConfig};
        let pool = Arc::new(BufferPool::new(PoolConfig::default()));
        let mut pc = PooledConnectionBuffers::new(Arc::clone(&pool));
        assert!(pc.pool().is_some());
        pc.buffers().init_read_buf(512);
        pc.burn_and_release();
        assert!(pc.pool().is_none());
        drop(pc);
    }

    #[test]
    fn test_dequeue_packet() {
        let mut conn = ConnectionBuffers::new();
        let mut buf = Buffer::new(64);
        buf.put_u32(0xDEAD_BEEF).unwrap();
        conn.queue_packet(buf).unwrap();
        assert_eq!(conn.packet_queue_bytes(), 4);
        let pkt = conn.dequeue_packet().unwrap();
        assert_eq!(pkt.len(), 4);
        assert_eq!(conn.packet_queue_len(), 0);
        assert_eq!(conn.packet_queue_bytes(), 0);
        assert!(conn.dequeue_packet().is_none());
    }
}