use crate::LynnError;
use std::net::SocketAddr;
use std::collections::HashMap;
use std::sync::Arc;
use std::net::IpAddr;
use dashmap::DashMap;
use tracing::{warn, error};
/// Number of bytes in one mebibyte; used to spell out the size limits below.
const BYTES_PER_MIB: usize = 1 << 20;

/// Largest message body length, in bytes, accepted by `validate_message_length` (10 MiB).
pub const MAX_MESSAGE_SIZE: usize = 10 * BYTES_PER_MIB;
/// Smallest message body length, in bytes, accepted by `validate_message_length`.
pub const MIN_MESSAGE_SIZE: usize = 3;
/// Default per-IP connection cap; presumably passed as `max_connections_per_ip`
/// to `ConnectionLimiter::new` by callers (not used within this module).
pub const DEFAULT_MAX_CONNECTIONS_PER_IP: usize = 10;
/// Default message rate limit; presumably passed to `RateLimiter::new` by callers
/// (not used within this module).
pub const DEFAULT_MESSAGE_RATE_LIMIT: usize = 1000;
/// Default buffer cap in bytes (16 MiB); presumably passed to `SafeBuffer::new`
/// by callers (not used within this module).
pub const DEFAULT_MAX_BUFFER_SIZE: usize = 16 * BYTES_PER_MIB;
/// Narrows a wire-format body length to `usize` and checks it against the
/// protocol bounds.
///
/// Returns the length as `usize` when
/// `MIN_MESSAGE_SIZE <= len <= MAX_MESSAGE_SIZE`.
///
/// # Errors
/// Returns a protocol error if `len` does not fit in `usize`, is below
/// [`MIN_MESSAGE_SIZE`], or is above [`MAX_MESSAGE_SIZE`].
pub fn validate_message_length(len: u64) -> Result<usize, LynnError> {
    // Only reachable on targets where `usize` is narrower than 64 bits.
    if len > usize::MAX as u64 {
        return Err(LynnError::protocol(format!(
            "Message length {} exceeds usize::MAX",
            len
        )));
    }
    let body_len = len as usize;

    match body_len {
        n if n < MIN_MESSAGE_SIZE => Err(LynnError::protocol(format!(
            "Message too short: {} bytes (minimum {})",
            n, MIN_MESSAGE_SIZE
        ))),
        n if n > MAX_MESSAGE_SIZE => Err(LynnError::protocol(format!(
            "Message too large: {} bytes (maximum {})",
            n, MAX_MESSAGE_SIZE
        ))),
        n => Ok(n),
    }
}
/// Validates one complete wire frame at the start of `data` and returns its body length.
///
/// Frame layout (integers are little-endian):
/// - bytes `0..2`: header mark, must equal `message_header_mark`
/// - bytes `2..10`: body length as `u64`, checked by [`validate_message_length`]
/// - next `body_len` bytes: body (not inspected)
/// - next 2 bytes: tail mark, must equal `message_tail_mark`
///
/// Any bytes after the tail mark are ignored.
///
/// # Errors
/// Returns a protocol error when `data` is shorter than the 10-byte prefix, the
/// header mark differs, the body length is out of bounds, `data` does not hold
/// the whole frame, or the tail mark differs.
pub fn validate_message_format(
    data: &[u8],
    message_header_mark: u16,
    message_tail_mark: u16,
) -> Result<usize, LynnError> {
    // Size of the header mark and of the tail mark.
    const MARK_LEN: usize = 2;
    // Header mark followed by the 8-byte body length.
    const PREFIX_LEN: usize = MARK_LEN + 8;

    if data.len() < PREFIX_LEN {
        return Err(LynnError::protocol(format!(
            "Message too short: {} bytes (minimum 10 for header+length)",
            data.len()
        )));
    }

    let found_header = u16::from_le_bytes([data[0], data[1]]);
    if found_header != message_header_mark {
        return Err(LynnError::protocol(format!(
            "Invalid header mark: 0x{:04X} (expected 0x{:04X})",
            found_header, message_header_mark
        )));
    }

    let mut length_field = [0u8; 8];
    length_field.copy_from_slice(&data[MARK_LEN..PREFIX_LEN]);
    let body_len = validate_message_length(u64::from_le_bytes(length_field))?;

    // `body_len <= MAX_MESSAGE_SIZE`, so these additions cannot overflow.
    let tail_start = PREFIX_LEN + body_len;
    let frame_len = tail_start + MARK_LEN;
    if data.len() < frame_len {
        return Err(LynnError::protocol(format!(
            "Incomplete message: {} bytes (expected {} for complete message)",
            data.len(),
            frame_len
        )));
    }

    let found_tail = u16::from_le_bytes([data[tail_start], data[tail_start + 1]]);
    if found_tail != message_tail_mark {
        return Err(LynnError::protocol(format!(
            "Invalid tail mark: 0x{:04X} (expected 0x{:04X})",
            found_tail, message_tail_mark
        )));
    }

    Ok(body_len)
}
#[derive(Clone)]
pub struct ConnectionLimiter {
max_connections: usize,
max_connections_per_ip: usize,
per_ip_counts: Arc<DashMap<IpAddr, usize>>,
}
impl ConnectionLimiter {
pub fn new(max_connections: usize, max_connections_per_ip: usize) -> Self {
Self {
max_connections,
max_connections_per_ip,
per_ip_counts: Arc::new(DashMap::new()),
}
}
pub fn can_accept_connection(&self, addr: SocketAddr) -> Result<(), LynnError> {
let total_count: usize = self.per_ip_counts.iter().map(|entry| *entry.value()).sum();
if total_count >= self.max_connections {
warn!(
"Rejecting connection from {}: maximum connections reached ({})",
addr, total_count
);
return Err(LynnError::server(format!(
"Maximum connections reached: {}",
self.max_connections
)));
}
let ip = addr.ip();
let ip_count = *self.per_ip_counts.entry(ip).or_insert(0);
if ip_count >= self.max_connections_per_ip {
warn!(
"Rejecting connection from {}: too many connections from this IP ({})",
addr, ip_count
);
return Err(LynnError::server(format!(
"Too many connections from IP: {} (limit: {})",
ip, self.max_connections_per_ip
)));
}
Ok(())
}
pub fn add_connection(&self, addr: SocketAddr) {
let ip = addr.ip();
*self.per_ip_counts.entry(ip).or_insert(0) += 1;
}
pub fn remove_connection(&self, addr: SocketAddr) {
let ip = addr.ip();
if let Some(mut count) = self.per_ip_counts.get_mut(&ip) {
if *count > 0 {
*count -= 1;
}
if *count == 0 {
self.per_ip_counts.remove(&ip);
}
}
}
pub fn total_connections(&self) -> usize {
self.per_ip_counts.iter().map(|entry| *entry.value()).sum()
}
pub fn connections_for_ip(&self, ip: IpAddr) -> usize {
self.per_ip_counts.get(&ip).map(|v| *v).unwrap_or(0)
}
}
/// Message rate limiter configured with a per-second budget.
///
/// NOTE(review): [`RateLimiter::check_rate`] is currently a placeholder that
/// never rejects; the configured limit is stored but not enforced yet.
#[derive(Clone)]
pub struct RateLimiter {
    /// Configured limit, presumably messages per second given the name; not enforced yet.
    messages_per_second: usize,
}

impl RateLimiter {
    /// Creates a limiter configured for `messages_per_second`.
    pub fn new(messages_per_second: usize) -> Self {
        Self {
            messages_per_second,
        }
    }

    /// Returns the limit this limiter was configured with in [`RateLimiter::new`].
    pub fn messages_per_second(&self) -> usize {
        self.messages_per_second
    }

    /// Checks whether another message may be processed.
    ///
    /// # Errors
    /// Currently never returns an error: this always returns `Ok(())`.
    /// TODO: enforce `messages_per_second`.
    pub fn check_rate(&self) -> Result<(), LynnError> {
        Ok(())
    }
}
/// Growable byte buffer whose length is capped at a fixed maximum.
///
/// [`SafeBuffer::extend`] rejects any write that would push the buffer past
/// `max_size`, so `len() <= max_size()` always holds.
pub struct SafeBuffer {
    /// Accumulated bytes.
    data: Vec<u8>,
    /// Maximum number of bytes `data` may hold.
    max_size: usize,
}

impl SafeBuffer {
    /// Creates an empty buffer that will accept at most `max_size` bytes.
    pub fn new(max_size: usize) -> Self {
        Self {
            // Start with a 4 KiB allocation, but never reserve more than the
            // buffer is allowed to hold.
            data: Vec::with_capacity(max_size.min(4096)),
            max_size,
        }
    }

    /// Appends `data` to the end of the buffer.
    ///
    /// # Errors
    /// Returns a buffer error, leaving the contents unchanged, if `data` alone is
    /// larger than `max_size` or if appending it would exceed `max_size`.
    pub fn extend(&mut self, data: &[u8]) -> Result<(), LynnError> {
        if data.len() > self.max_size {
            return Err(LynnError::buffer(format!(
                "Single data chunk too large: {} bytes (maximum {})",
                data.len(),
                self.max_size
            )));
        }

        // Compare against the remaining room instead of computing
        // `self.data.len() + data.len()`, which can overflow when `max_size` is
        // close to `usize::MAX`.
        let remaining = self.max_size.saturating_sub(self.data.len());
        if data.len() > remaining {
            return Err(LynnError::buffer(format!(
                "Buffer overflow: current={} bytes, adding={} bytes, maximum={} bytes",
                self.data.len(),
                data.len(),
                self.max_size
            )));
        }

        self.data.extend_from_slice(data);
        Ok(())
    }

    /// Removes all bytes; the existing allocation is kept for reuse.
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Returns the number of bytes currently stored.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if no bytes are stored.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Borrows the stored bytes.
    pub fn as_slice(&self) -> &[u8] {
        &self.data
    }

    /// Returns the configured maximum size in bytes.
    pub fn max_size(&self) -> usize {
        self.max_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const HEADER: u16 = 0x23E9;
    const TAIL: u16 = 0x1E27;

    /// Builds a well-formed frame: header mark, LE `u64` body length, body, tail mark.
    fn build_frame(body: &[u8]) -> Vec<u8> {
        let mut frame = Vec::with_capacity(10 + body.len() + 2);
        frame.extend_from_slice(&HEADER.to_le_bytes());
        frame.extend_from_slice(&(body.len() as u64).to_le_bytes());
        frame.extend_from_slice(body);
        frame.extend_from_slice(&TAIL.to_le_bytes());
        frame
    }

    #[test]
    fn test_validate_message_length_valid() {
        assert_eq!(validate_message_length(100).unwrap(), 100);
        assert_eq!(validate_message_length(1024).unwrap(), 1024);
    }

    #[test]
    fn test_validate_message_length_boundaries() {
        assert_eq!(
            validate_message_length(MIN_MESSAGE_SIZE as u64).unwrap(),
            MIN_MESSAGE_SIZE
        );
        assert_eq!(
            validate_message_length(MAX_MESSAGE_SIZE as u64).unwrap(),
            MAX_MESSAGE_SIZE
        );
        assert!(validate_message_length(u64::MAX).is_err());
    }

    #[test]
    fn test_validate_message_length_too_small() {
        assert!(validate_message_length(2).is_err());
        assert!(validate_message_length(0).is_err());
    }

    #[test]
    fn test_validate_message_length_too_large() {
        assert!(validate_message_length(MAX_MESSAGE_SIZE as u64 + 1).is_err());
    }

    #[test]
    fn test_validate_message_format() {
        let header: u16 = 0x23E9;
        let tail: u16 = 0x1E27;
        let mut data = vec![0u8; 20];
        data[0..2].copy_from_slice(&header.to_le_bytes());
        data[2..10].copy_from_slice(&6u64.to_le_bytes());
        data[16..18].copy_from_slice(&tail.to_le_bytes());
        assert!(validate_message_format(&data, header, tail).is_ok());
    }

    #[test]
    fn test_validate_message_format_exact_frame() {
        let frame = build_frame(&[1, 2, 3, 4, 5, 6]);
        assert_eq!(validate_message_format(&frame, HEADER, TAIL).unwrap(), 6);
    }

    #[test]
    fn test_validate_message_format_rejects_malformed() {
        // Shorter than the 10-byte header+length prefix.
        assert!(validate_message_format(&[0u8; 9], HEADER, TAIL).is_err());

        // Wrong header mark.
        let mut bad_header = build_frame(&[1, 2, 3]);
        bad_header[0] ^= 0xFF;
        assert!(validate_message_format(&bad_header, HEADER, TAIL).is_err());

        // Wrong tail mark.
        let mut bad_tail = build_frame(&[1, 2, 3]);
        let last = bad_tail.len() - 1;
        bad_tail[last] ^= 0xFF;
        assert!(validate_message_format(&bad_tail, HEADER, TAIL).is_err());

        // Truncated inside the tail mark.
        let mut truncated = build_frame(&[1, 2, 3]);
        truncated.pop();
        assert!(validate_message_format(&truncated, HEADER, TAIL).is_err());

        // Body shorter than MIN_MESSAGE_SIZE.
        assert!(validate_message_format(&build_frame(&[1, 2]), HEADER, TAIL).is_err());
    }

    #[test]
    fn test_connection_limiter() {
        let limiter = ConnectionLimiter::new(100, 5);
        let addr = "127.0.0.1:8080".parse().unwrap();
        for _ in 0..5 {
            assert!(limiter.can_accept_connection(addr).is_ok());
            limiter.add_connection(addr);
        }
        assert!(limiter.can_accept_connection(addr).is_err());
        limiter.remove_connection(addr);
        assert!(limiter.can_accept_connection(addr).is_ok());
    }

    #[test]
    fn test_connection_limiter_global_limit() {
        let limiter = ConnectionLimiter::new(2, 5);
        let a: SocketAddr = "10.0.0.1:1000".parse().unwrap();
        let b: SocketAddr = "10.0.0.2:1000".parse().unwrap();
        let c: SocketAddr = "10.0.0.3:1000".parse().unwrap();
        limiter.add_connection(a);
        limiter.add_connection(b);
        assert_eq!(limiter.total_connections(), 2);
        assert_eq!(limiter.connections_for_ip(a.ip()), 1);
        assert!(limiter.can_accept_connection(c).is_err());
    }

    #[test]
    fn test_safe_buffer() {
        let mut buffer = SafeBuffer::new(100);
        assert!(buffer.extend(&[1, 2, 3]).is_ok());
        assert_eq!(buffer.len(), 3);
        assert!(buffer.extend(&vec![0u8; 200]).is_err());
    }

    #[test]
    fn test_safe_buffer_exact_capacity() {
        let mut buffer = SafeBuffer::new(4);
        assert!(buffer.extend(&[1, 2, 3, 4]).is_ok());
        assert_eq!(buffer.as_slice(), &[1u8, 2, 3, 4][..]);
        assert!(buffer.extend(&[5]).is_err());
        assert_eq!(buffer.len(), 4);
        buffer.clear();
        assert!(buffer.is_empty());
        assert!(buffer.extend(&[5]).is_ok());
        assert_eq!(buffer.max_size(), 4);
    }
}