use std::collections::HashMap;
use std::fmt;
use std::net::SocketAddr;
use std::time::Instant;
use crate::VarInt;
/// Tracks the context IDs registered on one endpoint.
///
/// Contexts registered by this endpoint (via `register_uncompressed` /
/// `register_compressed`) are kept in `local_contexts`. Contexts registered via
/// `register_remote` (presumably announced by the peer) are kept in
/// `remote_contexts`. At most one uncompressed context, local or remote, is
/// tracked at a time.
#[derive(Debug)]
pub struct ContextManager {
    /// Contexts registered by this endpoint, keyed by context ID.
    local_contexts: HashMap<VarInt, ContextInfo>,
    /// Contexts registered through `register_remote`, keyed by context ID.
    remote_contexts: HashMap<VarInt, ContextInfo>,
    /// ID of the single uncompressed (target-less) context. Cleared when that
    /// context is closed.
    uncompressed_context: Option<VarInt>,
    /// Next ID handed out by `allocate_local`. It advances by 2 so the parity
    /// never changes.
    next_local_id: u64,
    /// Selects the local ID parity: clients start at 2 (even), servers at 1 (odd).
    is_client: bool,
}
/// Bookkeeping for a single registered context.
#[derive(Debug, Clone)]
pub struct ContextInfo {
    /// Target address of a compressed context; `None` for the uncompressed context.
    pub target: Option<SocketAddr>,
    /// Current lifecycle state.
    pub state: ContextState,
    /// When the context was registered.
    pub created_at: Instant,
    /// Set at registration and refreshed by `handle_ack`, `close` and `touch`.
    /// `cleanup_closed` compares it against `max_age` to drop closed contexts.
    pub last_activity: Instant,
}
/// Lifecycle state of a context.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextState {
    /// Registered locally and waiting for `handle_ack` to move it to `Active`.
    Pending,
    /// Usable. Remote contexts are registered directly in this state.
    Active,
    /// NOTE(review): not set by `ContextManager` in this file. It is presumably
    /// reserved for a graceful-shutdown phase; confirm with callers.
    Closing,
    /// Closed via `close`. The entry is kept until `cleanup_closed` removes it.
    Closed,
}
impl fmt::Display for ContextState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ContextState::Pending => write!(f, "pending"),
ContextState::Active => write!(f, "active"),
ContextState::Closing => write!(f, "closing"),
ContextState::Closed => write!(f, "closed"),
}
}
}
impl ContextManager {
pub fn new(is_client: bool) -> Self {
Self {
local_contexts: HashMap::new(),
remote_contexts: HashMap::new(),
uncompressed_context: None,
next_local_id: if is_client { 2 } else { 1 },
is_client,
}
}
pub fn is_client(&self) -> bool {
self.is_client
}
pub fn allocate_local(&mut self) -> Result<VarInt, ContextError> {
let id = self.next_local_id;
if id > VarInt::MAX.into_inner() {
return Err(ContextError::IdSpaceExhausted);
}
self.next_local_id = self
.next_local_id
.checked_add(2)
.ok_or(ContextError::IdSpaceExhausted)?;
VarInt::from_u64(id).map_err(|_| ContextError::IdSpaceExhausted)
}
pub fn register_uncompressed(&mut self, context_id: VarInt) -> Result<(), ContextError> {
if self.uncompressed_context.is_some() {
return Err(ContextError::DuplicateUncompressed);
}
if context_id.into_inner() == 0 {
return Err(ContextError::ReservedId);
}
let info = ContextInfo {
target: None,
state: ContextState::Pending,
created_at: Instant::now(),
last_activity: Instant::now(),
};
self.local_contexts.insert(context_id, info);
self.uncompressed_context = Some(context_id);
Ok(())
}
pub fn register_compressed(
&mut self,
context_id: VarInt,
target: SocketAddr,
) -> Result<(), ContextError> {
for info in self
.local_contexts
.values()
.chain(self.remote_contexts.values())
{
if info.target == Some(target) && info.state != ContextState::Closed {
return Err(ContextError::DuplicateTarget(target));
}
}
let info = ContextInfo {
target: Some(target),
state: ContextState::Pending,
created_at: Instant::now(),
last_activity: Instant::now(),
};
self.local_contexts.insert(context_id, info);
Ok(())
}
pub fn register_remote(
&mut self,
context_id: VarInt,
target: Option<SocketAddr>,
) -> Result<(), ContextError> {
if target.is_none() && self.uncompressed_context.is_some() {
return Err(ContextError::DuplicateUncompressed);
}
if let Some(t) = target {
for info in self
.local_contexts
.values()
.chain(self.remote_contexts.values())
{
if info.target == Some(t) && info.state != ContextState::Closed {
return Err(ContextError::DuplicateTarget(t));
}
}
}
let info = ContextInfo {
target,
state: ContextState::Active, created_at: Instant::now(),
last_activity: Instant::now(),
};
self.remote_contexts.insert(context_id, info);
if target.is_none() {
self.uncompressed_context = Some(context_id);
}
Ok(())
}
pub fn handle_ack(&mut self, context_id: VarInt) -> Result<(), ContextError> {
let info = self
.local_contexts
.get_mut(&context_id)
.ok_or(ContextError::UnknownContext)?;
if info.state != ContextState::Pending {
return Err(ContextError::InvalidState);
}
info.state = ContextState::Active;
info.last_activity = Instant::now();
Ok(())
}
pub fn close(&mut self, context_id: VarInt) -> Result<(), ContextError> {
if let Some(info) = self.local_contexts.get_mut(&context_id) {
info.state = ContextState::Closed;
info.last_activity = Instant::now();
} else if let Some(info) = self.remote_contexts.get_mut(&context_id) {
info.state = ContextState::Closed;
info.last_activity = Instant::now();
} else {
return Err(ContextError::UnknownContext);
}
if self.uncompressed_context == Some(context_id) {
self.uncompressed_context = None;
}
Ok(())
}
pub fn get_by_target(&self, target: SocketAddr) -> Option<VarInt> {
for (id, info) in self
.local_contexts
.iter()
.chain(self.remote_contexts.iter())
{
if info.target == Some(target) && info.state == ContextState::Active {
return Some(*id);
}
}
None
}
pub fn uncompressed(&self) -> Option<VarInt> {
self.uncompressed_context.filter(|id| {
self.local_contexts
.get(id)
.or_else(|| self.remote_contexts.get(id))
.map(|i| i.state == ContextState::Active)
.unwrap_or(false)
})
}
pub fn get_context(&self, context_id: VarInt) -> Option<&ContextInfo> {
self.local_contexts
.get(&context_id)
.or_else(|| self.remote_contexts.get(&context_id))
}
pub fn get_target(&self, context_id: VarInt) -> Option<SocketAddr> {
self.get_context(context_id).and_then(|info| info.target)
}
pub fn touch(&mut self, context_id: VarInt) -> Result<(), ContextError> {
if let Some(info) = self.local_contexts.get_mut(&context_id) {
info.last_activity = Instant::now();
Ok(())
} else if let Some(info) = self.remote_contexts.get_mut(&context_id) {
info.last_activity = Instant::now();
Ok(())
} else {
Err(ContextError::UnknownContext)
}
}
pub fn active_count(&self) -> usize {
self.local_contexts
.values()
.chain(self.remote_contexts.values())
.filter(|info| info.state == ContextState::Active)
.count()
}
pub fn cleanup_closed(&mut self, max_age: std::time::Duration) {
let now = Instant::now();
self.local_contexts.retain(|_, info| {
info.state != ContextState::Closed || now.duration_since(info.last_activity) < max_age
});
self.remote_contexts.retain(|_, info| {
info.state != ContextState::Closed || now.duration_since(info.last_activity) < max_age
});
}
pub fn local_context_ids(&self) -> impl Iterator<Item = VarInt> + '_ {
self.local_contexts.keys().copied()
}
pub fn remote_context_ids(&self) -> impl Iterator<Item = VarInt> + '_ {
self.remote_contexts.keys().copied()
}
}
/// Errors returned by `ContextManager` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextError {
    /// No further local context IDs can be allocated without exceeding `VarInt::MAX`.
    IdSpaceExhausted,
    /// An uncompressed context is already registered and has not been closed.
    DuplicateUncompressed,
    /// Context ID 0 was supplied where it is not allowed.
    ReservedId,
    /// A context that is not `Closed` already uses this target address.
    DuplicateTarget(SocketAddr),
    /// No local or remote context has the given ID.
    UnknownContext,
    /// The context is not in a state that permits the requested operation
    /// (for example, acknowledging a context that is not `Pending`).
    InvalidState,
}
impl fmt::Display for ContextError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ContextError::IdSpaceExhausted => write!(f, "context ID space exhausted"),
ContextError::DuplicateUncompressed => {
write!(f, "only one uncompressed context allowed")
}
ContextError::ReservedId => write!(f, "context ID 0 is reserved"),
ContextError::DuplicateTarget(addr) => {
write!(f, "duplicate target address: {}", addr)
}
ContextError::UnknownContext => write!(f, "unknown context ID"),
ContextError::InvalidState => write!(f, "invalid context state for operation"),
}
}
}
impl std::error::Error for ContextError {}
/// Unit tests for `ContextManager` allocation, registration and lifecycle.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use std::time::Duration;

    fn addr(a: u8, b: u8, c: u8, d: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), port)
    }

    #[test]
    fn test_context_allocation_client() {
        let mut mgr = ContextManager::new(true);
        assert!(mgr.is_client());
        let id1 = mgr.allocate_local().unwrap();
        assert_eq!(id1.into_inner(), 2);
        let id2 = mgr.allocate_local().unwrap();
        assert_eq!(id2.into_inner(), 4);
        let id3 = mgr.allocate_local().unwrap();
        assert_eq!(id3.into_inner(), 6);
    }

    #[test]
    fn test_context_allocation_server() {
        let mut mgr = ContextManager::new(false);
        assert!(!mgr.is_client());
        let id1 = mgr.allocate_local().unwrap();
        assert_eq!(id1.into_inner(), 1);
        let id2 = mgr.allocate_local().unwrap();
        assert_eq!(id2.into_inner(), 3);
    }

    #[test]
    fn test_uncompressed_context_limit() {
        let mut mgr = ContextManager::new(true);
        let id = mgr.allocate_local().unwrap();
        mgr.register_uncompressed(id).unwrap();
        let id2 = mgr.allocate_local().unwrap();
        let result = mgr.register_uncompressed(id2);
        assert_eq!(result, Err(ContextError::DuplicateUncompressed));
    }

    #[test]
    fn test_reserved_id_zero() {
        let mut mgr = ContextManager::new(true);
        let result = mgr.register_uncompressed(VarInt::from_u32(0));
        assert_eq!(result, Err(ContextError::ReservedId));
    }

    #[test]
    fn test_compressed_context_lifecycle() {
        let mut mgr = ContextManager::new(true);
        let id = mgr.allocate_local().unwrap();
        let target = addr(192, 168, 1, 1, 8080);
        mgr.register_compressed(id, target).unwrap();
        assert_eq!(mgr.get_context(id).unwrap().state, ContextState::Pending);
        mgr.handle_ack(id).unwrap();
        assert_eq!(mgr.get_context(id).unwrap().state, ContextState::Active);
        assert_eq!(mgr.get_by_target(target), Some(id));
        assert_eq!(mgr.get_target(id), Some(target));
        mgr.close(id).unwrap();
        assert_eq!(mgr.get_context(id).unwrap().state, ContextState::Closed);
        assert_eq!(mgr.get_by_target(target), None);
    }

    #[test]
    fn test_duplicate_target() {
        let mut mgr = ContextManager::new(true);
        let target = addr(10, 0, 0, 1, 9000);
        let id1 = mgr.allocate_local().unwrap();
        mgr.register_compressed(id1, target).unwrap();
        mgr.handle_ack(id1).unwrap();
        let id2 = mgr.allocate_local().unwrap();
        let result = mgr.register_compressed(id2, target);
        assert_eq!(result, Err(ContextError::DuplicateTarget(target)));
    }

    #[test]
    fn test_closed_target_can_be_reused() {
        let mut mgr = ContextManager::new(true);
        let target = addr(10, 0, 0, 9, 4433);
        let id1 = mgr.allocate_local().unwrap();
        mgr.register_compressed(id1, target).unwrap();
        mgr.close(id1).unwrap();
        // A closed context no longer reserves its target.
        let id2 = mgr.allocate_local().unwrap();
        assert_eq!(mgr.register_compressed(id2, target), Ok(()));
    }

    #[test]
    fn test_remote_context_registration() {
        let mut mgr = ContextManager::new(true);
        let target = addr(192, 168, 1, 1, 8080);
        mgr.register_remote(VarInt::from_u32(1), Some(target)).unwrap();
        assert_eq!(
            mgr.get_context(VarInt::from_u32(1)).unwrap().state,
            ContextState::Active
        );
        assert_eq!(mgr.get_by_target(target), Some(VarInt::from_u32(1)));
    }

    #[test]
    fn test_uncompressed_lifecycle() {
        let mut mgr = ContextManager::new(true);
        let id = mgr.allocate_local().unwrap();
        mgr.register_uncompressed(id).unwrap();
        // Pending contexts are not reported until acknowledged.
        assert_eq!(mgr.uncompressed(), None);
        mgr.handle_ack(id).unwrap();
        assert_eq!(mgr.uncompressed(), Some(id));
        assert_eq!(mgr.get_target(id), None);
        mgr.close(id).unwrap();
        assert_eq!(mgr.uncompressed(), None);
        // Closing frees the slot for a new uncompressed context.
        let id2 = mgr.allocate_local().unwrap();
        assert_eq!(mgr.register_uncompressed(id2), Ok(()));
    }

    #[test]
    fn test_remote_uncompressed_blocks_local() {
        let mut mgr = ContextManager::new(true);
        let remote_id = VarInt::from_u32(1);
        mgr.register_remote(remote_id, None).unwrap();
        assert_eq!(mgr.uncompressed(), Some(remote_id));
        let local_id = mgr.allocate_local().unwrap();
        assert_eq!(
            mgr.register_uncompressed(local_id),
            Err(ContextError::DuplicateUncompressed)
        );
        assert_eq!(
            mgr.register_remote(VarInt::from_u32(3), None),
            Err(ContextError::DuplicateUncompressed)
        );
    }

    #[test]
    fn test_active_count() {
        let mut mgr = ContextManager::new(true);
        assert_eq!(mgr.active_count(), 0);
        let id1 = mgr.allocate_local().unwrap();
        let target1 = addr(10, 0, 0, 1, 1000);
        mgr.register_compressed(id1, target1).unwrap();
        mgr.handle_ack(id1).unwrap();
        assert_eq!(mgr.active_count(), 1);
        let id2 = mgr.allocate_local().unwrap();
        let target2 = addr(10, 0, 0, 2, 2000);
        mgr.register_compressed(id2, target2).unwrap();
        mgr.handle_ack(id2).unwrap();
        assert_eq!(mgr.active_count(), 2);
        mgr.close(id1).unwrap();
        assert_eq!(mgr.active_count(), 1);
    }

    #[test]
    fn test_cleanup_closed() {
        let mut mgr = ContextManager::new(true);
        let id1 = mgr.allocate_local().unwrap();
        let id2 = mgr.allocate_local().unwrap();
        mgr.register_compressed(id1, addr(10, 0, 0, 1, 1000)).unwrap();
        mgr.register_compressed(id2, addr(10, 0, 0, 2, 2000)).unwrap();
        mgr.handle_ack(id1).unwrap();
        mgr.handle_ack(id2).unwrap();
        mgr.close(id1).unwrap();
        // Recently closed contexts survive a generous max_age.
        mgr.cleanup_closed(Duration::from_secs(3600));
        assert!(mgr.get_context(id1).is_some());
        // A zero max_age drops every closed context but keeps active ones.
        mgr.cleanup_closed(Duration::from_secs(0));
        assert!(mgr.get_context(id1).is_none());
        assert!(mgr.get_context(id2).is_some());
        assert_eq!(mgr.active_count(), 1);
    }

    #[test]
    fn test_touch_known_context() {
        let mut mgr = ContextManager::new(false);
        let id = mgr.allocate_local().unwrap();
        mgr.register_compressed(id, addr(172, 16, 0, 1, 53)).unwrap();
        let before = mgr.get_context(id).unwrap().last_activity;
        assert_eq!(mgr.touch(id), Ok(()));
        assert!(mgr.get_context(id).unwrap().last_activity >= before);
        // touch never changes the state.
        assert_eq!(mgr.get_context(id).unwrap().state, ContextState::Pending);
    }

    #[test]
    fn test_unknown_context_errors() {
        let mut mgr = ContextManager::new(true);
        let unknown_id = VarInt::from_u32(999);
        assert_eq!(mgr.handle_ack(unknown_id), Err(ContextError::UnknownContext));
        assert_eq!(mgr.close(unknown_id), Err(ContextError::UnknownContext));
        assert_eq!(mgr.touch(unknown_id), Err(ContextError::UnknownContext));
    }

    #[test]
    fn test_invalid_state_ack() {
        let mut mgr = ContextManager::new(true);
        let id = mgr.allocate_local().unwrap();
        let target = addr(192, 168, 1, 1, 8080);
        mgr.register_compressed(id, target).unwrap();
        mgr.handle_ack(id).unwrap();
        assert_eq!(mgr.handle_ack(id), Err(ContextError::InvalidState));
    }

    #[test]
    fn test_context_iterators() {
        let mut mgr = ContextManager::new(true);
        let id1 = mgr.allocate_local().unwrap();
        let id2 = mgr.allocate_local().unwrap();
        let target1 = addr(10, 0, 0, 1, 1000);
        let target2 = addr(10, 0, 0, 2, 2000);
        mgr.register_compressed(id1, target1).unwrap();
        mgr.register_compressed(id2, target2).unwrap();
        let local_ids: Vec<_> = mgr.local_context_ids().collect();
        assert_eq!(local_ids.len(), 2);
        assert!(local_ids.contains(&id1));
        assert!(local_ids.contains(&id2));
        let remote_id = VarInt::from_u32(1);
        let remote_target = addr(192, 168, 1, 1, 8080);
        mgr.register_remote(remote_id, Some(remote_target)).unwrap();
        let remote_ids: Vec<_> = mgr.remote_context_ids().collect();
        assert_eq!(remote_ids.len(), 1);
        assert!(remote_ids.contains(&remote_id));
    }
}