use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use turbomcp_protocol::MessageId;
use crate::error::{ProxyError, ProxyResult};
/// Default upper bound on simultaneously pending frontend/backend ID mappings.
const MAX_MAPPINGS: usize = 10_000;
/// Default age after which a pending mapping is considered expired (five minutes).
const MAPPING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Value stored in the forward (frontend -> backend) map: the backend ID that
/// was assigned to a frontend request, plus when the assignment was made.
#[derive(Clone, Debug)]
struct MappingEntry {
    /// ID substituted for the frontend's ID when talking to the backend.
    backend_id: MessageId,
    /// Allocation time; compared against the translator's timeout on eviction.
    created_at: Instant,
}
/// Bidirectional translator between frontend request IDs and proxy-generated
/// backend request IDs.
///
/// All mutable state sits behind `Arc`s, so cloning is cheap and every clone
/// shares the same maps and ID counter.
#[derive(Clone, Debug)]
pub struct IdTranslator {
    /// Frontend ID -> assigned backend ID and allocation time.
    frontend_to_backend: Arc<DashMap<MessageId, MappingEntry>>,
    /// Backend ID -> frontend ID; consulted by `get_frontend_id`.
    backend_to_frontend: Arc<DashMap<MessageId, MessageId>>,
    /// Source of the next backend ID to hand out.
    next_backend_id: Arc<AtomicU64>,
    /// Maximum number of pending mappings before allocation is refused.
    max_mappings: usize,
    /// Age after which a mapping becomes eligible for eviction.
    mapping_timeout: Duration,
}
impl IdTranslator {
    /// Creates a translator using [`MAX_MAPPINGS`] and [`MAPPING_TIMEOUT`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_limits(MAX_MAPPINGS, MAPPING_TIMEOUT)
    }

    /// Creates a translator with explicit limits.
    ///
    /// * `max_mappings` - number of pending mappings at which
    ///   [`allocate`](Self::allocate) starts rejecting requests.
    /// * `mapping_timeout` - age at which a mapping is treated as expired and
    ///   may be removed by [`evict_expired`](Self::evict_expired).
    ///
    /// Backend IDs are handed out starting at 1.
    #[must_use]
    pub fn with_limits(max_mappings: usize, mapping_timeout: Duration) -> Self {
        Self {
            frontend_to_backend: Arc::new(DashMap::new()),
            backend_to_frontend: Arc::new(DashMap::new()),
            next_backend_id: Arc::new(AtomicU64::new(1)),
            max_mappings,
            mapping_timeout,
        }
    }

    /// Allocates a fresh numeric backend ID for `frontend_id` and records the
    /// mapping in both directions.
    ///
    /// If `frontend_id` already has a pending mapping, it is replaced and the
    /// superseded backend ID is unregistered. A late response carrying the old
    /// backend ID therefore no longer resolves to this frontend ID.
    ///
    /// # Errors
    ///
    /// Returns a rate-limit error (`ProxyError::rate_limit_exceeded`) when
    /// `max_mappings` mappings are still pending after expired ones have been
    /// evicted.
    pub fn allocate(&self, frontend_id: MessageId) -> ProxyResult<MessageId> {
        // Sweeping is O(n) and write-locks every shard of both maps, so only do
        // it when the limit would otherwise reject the request. Routine cleanup
        // is the job of `spawn_eviction_task`.
        let mut pending = self.frontend_to_backend.len();
        if pending >= self.max_mappings {
            self.evict_expired();
            pending = self.frontend_to_backend.len();
        }
        // Soft limit: concurrent callers can pass this check together and
        // overshoot `max_mappings` by a few entries.
        if pending >= self.max_mappings {
            return Err(ProxyError::rate_limit_exceeded(format!(
                "Too many pending requests ({}/{}), server overloaded",
                pending, self.max_mappings
            )));
        }

        let backend_id_num = self.next_backend_id.fetch_add(1, Ordering::SeqCst);
        // Wrapping into negative i64 would need 2^63 allocations; not a
        // practical concern.
        #[allow(clippy::cast_possible_wrap)]
        let backend_id = MessageId::Number(backend_id_num as i64);
        let entry = MappingEntry {
            backend_id: backend_id.clone(),
            created_at: Instant::now(),
        };
        if let Some(stale) = self.frontend_to_backend.insert(frontend_id.clone(), entry) {
            // The frontend reused an ID that was still pending. Drop the reverse
            // entry of the superseded request so it can neither leak nor route a
            // late backend response to this frontend ID.
            self.backend_to_frontend
                .remove_if(&stale.backend_id, |_k, v| v == &frontend_id);
        }
        self.backend_to_frontend.insert(backend_id.clone(), frontend_id);
        Ok(backend_id)
    }

    /// Returns the frontend ID that `backend_id` was allocated for, or `None`
    /// if it is unknown, released, or already evicted.
    #[must_use]
    pub fn get_frontend_id(&self, backend_id: &MessageId) -> Option<MessageId> {
        self.backend_to_frontend
            .get(backend_id)
            .map(|entry| entry.value().clone())
    }

    /// Removes the mapping for `frontend_id`, if any.
    ///
    /// The reverse entry is removed only if it still points at `frontend_id`,
    /// so repeated or concurrent releases of the same ID are harmless.
    pub fn release(&self, frontend_id: &MessageId) {
        if let Some((_, entry)) = self.frontend_to_backend.remove(frontend_id) {
            self.backend_to_frontend
                .remove_if(&entry.backend_id, |_k, v| v == frontend_id);
        }
    }

    /// Removes every forward mapping whose age is at least `mapping_timeout`.
    /// It then removes every reverse mapping whose frontend ID no longer has a
    /// forward entry.
    ///
    /// Both passes scan their whole map. The passes are not atomic with
    /// respect to each other or to concurrent `allocate`/`release` calls.
    fn evict_expired(&self) {
        let now = Instant::now();
        self.frontend_to_backend
            .retain(|_k, v| now.duration_since(v.created_at) < self.mapping_timeout);
        self.backend_to_frontend
            .retain(|_backend_id, frontend_id| self.frontend_to_backend.contains_key(frontend_id));
    }

    /// Spawns a Tokio task that calls `evict_expired` every 60 seconds. The
    /// first run happens immediately, because a Tokio interval completes its
    /// first tick at once.
    ///
    /// The task loops forever and owns `self`. It keeps the translator alive
    /// until the returned handle is aborted.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (a `tokio::spawn` requirement).
    #[must_use]
    pub fn spawn_eviction_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(60));
            loop {
                interval.tick().await;
                self.evict_expired();
            }
        })
    }

    /// Number of pending forward mappings. This may include expired entries
    /// that have not been evicted yet.
    #[must_use]
    pub fn mapping_count(&self) -> usize {
        self.frontend_to_backend.len()
    }

    /// Drops all mappings in both directions.
    ///
    /// The backend ID counter is not reset, so IDs issued afterwards keep
    /// increasing.
    pub fn clear(&self) {
        self.frontend_to_backend.clear();
        self.backend_to_frontend.clear();
    }
}
impl Default for IdTranslator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Shorthand for a string-valued message ID.
    fn sid(text: &str) -> MessageId {
        MessageId::String(text.to_string())
    }

    #[test]
    fn test_allocate_and_lookup() {
        let ids = IdTranslator::new();
        let client_id = sid("frontend-123");

        let upstream_id = ids.allocate(client_id.clone()).unwrap();
        assert!(matches!(upstream_id, MessageId::Number(1)));
        assert_eq!(ids.get_frontend_id(&upstream_id), Some(client_id.clone()));

        ids.release(&client_id);
        assert_eq!(ids.get_frontend_id(&upstream_id), None);
    }

    #[test]
    fn test_multiple_allocations() {
        let ids = IdTranslator::new();
        let clients = [sid("req-1"), sid("req-2"), MessageId::Number(999)];

        let upstream: Vec<MessageId> = clients
            .iter()
            .map(|client| ids.allocate(client.clone()).unwrap())
            .collect();
        assert_eq!(
            upstream,
            vec![
                MessageId::Number(1),
                MessageId::Number(2),
                MessageId::Number(3)
            ]
        );

        for (up, client) in upstream.iter().zip(&clients) {
            assert_eq!(ids.get_frontend_id(up), Some(client.clone()));
        }
        assert_eq!(ids.mapping_count(), 3);

        ids.clear();
        assert_eq!(ids.mapping_count(), 0);
    }

    #[test]
    fn test_sequential_backend_ids() {
        let ids = IdTranslator::new();
        for n in 1..=10_i64 {
            let got = ids.allocate(sid(&format!("req-{n}"))).unwrap();
            assert_eq!(got, MessageId::Number(n));
        }
    }

    #[test]
    fn test_max_mappings_limit() {
        let ids = IdTranslator::with_limits(5, Duration::from_secs(300));
        for n in 1..=5 {
            assert!(
                ids.allocate(sid(&format!("req-{n}"))).is_ok(),
                "Should allocate within limit"
            );
        }

        let overflow = ids.allocate(sid("req-overflow"));
        assert!(overflow.is_err(), "Should fail when exceeding limit");
        assert!(
            matches!(overflow, Err(ProxyError::RateLimitExceeded { .. })),
            "Expected RateLimitExceeded error"
        );
    }

    #[test]
    fn test_timeout_eviction() {
        let ids = IdTranslator::with_limits(10, Duration::from_millis(100));
        for name in ["req-1", "req-2"] {
            ids.allocate(sid(name)).unwrap();
        }
        assert_eq!(ids.mapping_count(), 2);

        thread::sleep(Duration::from_millis(150));
        ids.evict_expired();
        assert_eq!(ids.mapping_count(), 0);
    }

    #[test]
    fn test_release_race_condition() {
        let shared = Arc::new(IdTranslator::new());
        let client_id = sid("concurrent-test");
        let upstream_id = shared.allocate(client_id.clone()).unwrap();
        assert!(shared.get_frontend_id(&upstream_id).is_some());

        // Ten threads race to release the same mapping; exactly one wins and
        // the rest must be no-ops.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let ids = Arc::clone(&shared);
                let id = client_id.clone();
                thread::spawn(move || ids.release(&id))
            })
            .collect();
        workers.into_iter().for_each(|w| w.join().unwrap());

        assert!(shared.get_frontend_id(&upstream_id).is_none());
        assert_eq!(shared.mapping_count(), 0);
    }

    #[test]
    fn test_eviction_after_limit_makes_room() {
        let ids = IdTranslator::with_limits(3, Duration::from_millis(100));
        (1..=3).for_each(|n| {
            ids.allocate(sid(&format!("req-{n}"))).unwrap();
        });
        assert!(ids.allocate(sid("overflow")).is_err());

        thread::sleep(Duration::from_millis(150));
        let retried = ids.allocate(sid("after-timeout"));
        assert!(
            retried.is_ok(),
            "Should succeed after expired entries are evicted"
        );
    }
}