use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use parking_lot::RwLock;
/// Identifies a request by the peer that sent it and the invoke ID it carried.
///
/// `ServerTsm` uses this as its table key: a second request with an equal key seen
/// inside the duplicate window is treated as a duplicate of the first.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TransactionKey {
    /// Socket address of the requesting peer.
    pub source: SocketAddr,
    /// Invoke ID taken from the request.
    pub invoke_id: u8,
}

impl TransactionKey {
    /// Builds a key from a peer address and an invoke ID.
    pub fn new(source: SocketAddr, invoke_id: u8) -> Self {
        TransactionKey { invoke_id, source }
    }
}
/// Bookkeeping for one transaction held in the `ServerTsm` table.
#[derive(Debug, Clone)]
pub struct TransactionEntry {
/// When the transaction was first accepted; the duplicate window is measured from here.
pub created_at: Instant,
/// Service-choice byte passed to `begin_transaction`.
/// NOTE(review): stored but never read within this file.
pub service_choice: u8,
/// Response bytes recorded by `complete_transaction`; replayed to later duplicates.
pub cached_response: Option<Vec<u8>>,
/// Number of duplicate requests seen for this key while the entry was live.
pub duplicate_count: u32,
/// Current lifecycle state of the transaction.
pub state: TransactionState,
}
/// Lifecycle of a [`TransactionEntry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
/// Accepted but no response recorded yet; duplicates get `TsmError::DuplicateInProgress`.
AwaitingResponse,
/// A response is available for replay.
/// NOTE(review): never assigned in this file — `complete_transaction` moves straight
/// to `Completed`.
ResponseReady,
/// Response recorded by `complete_transaction`; duplicates receive the cached response.
Completed,
}
/// Tunables for [`ServerTsm`].
#[derive(Debug, Clone)]
pub struct TsmConfig {
    /// How long an entry (in-flight or completed) is remembered for duplicate detection,
    /// measured from when the transaction was begun.
    pub duplicate_window: Duration,
    /// Maximum number of entries kept in the transaction table.
    pub max_transactions: usize,
    /// Delay reported by `ServerTsm::chaos_delay`; the TSM itself never sleeps on it
    /// (presumably applied by the caller for chaos testing).
    pub chaos_delay: Duration,
    /// Probability, nominally in `[0.0, 1.0]`, that a completed response is reported as
    /// dropped; `0.0` disables dropping.
    pub chaos_drop_probability: f64,
}

impl Default for TsmConfig {
    /// Ten-second duplicate window, room for 256 transactions, chaos features disabled.
    fn default() -> Self {
        let duplicate_window = Duration::from_millis(10_000);
        let max_transactions = 256;
        TsmConfig {
            duplicate_window,
            max_transactions,
            chaos_delay: Duration::from_secs(0),
            chaos_drop_probability: 0.0,
        }
    }
}
/// Server-side transaction state machine: remembers in-flight and recently completed
/// requests so repeated requests (same source and invoke ID) can be detected and, once a
/// response exists, answered from cache.
pub struct ServerTsm {
/// Configuration; only the `&mut self` chaos setters modify it after construction.
config: TsmConfig,
/// Transaction table keyed by (source, invoke ID), guarded by one read/write lock.
transactions: RwLock<HashMap<TransactionKey, TransactionEntry>>,
/// Number of new transactions accepted by `begin_transaction`.
total_transactions: AtomicU64,
/// Number of requests recognised as duplicates inside the window.
duplicate_count: AtomicU64,
/// Number of responses suppressed by chaos dropping.
dropped_count: AtomicU64,
/// Incremented when expired entries are evicted from the table.
timeout_count: AtomicU64,
}
impl ServerTsm {
pub fn new() -> Self {
Self::with_config(TsmConfig::default())
}
pub fn with_config(config: TsmConfig) -> Self {
Self {
transactions: RwLock::new(HashMap::with_capacity(config.max_transactions)),
config,
total_transactions: AtomicU64::new(0),
duplicate_count: AtomicU64::new(0),
dropped_count: AtomicU64::new(0),
timeout_count: AtomicU64::new(0),
}
}
pub fn begin_transaction(
&self,
key: TransactionKey,
service_choice: u8,
) -> Result<Option<Vec<u8>>, TsmError> {
let now = Instant::now();
{
let mut txns = self.transactions.write();
if let Some(entry) = txns.get_mut(&key) {
if now.duration_since(entry.created_at) > self.config.duplicate_window {
txns.remove(&key);
} else {
entry.duplicate_count += 1;
self.duplicate_count.fetch_add(1, Ordering::Relaxed);
return match entry.state {
TransactionState::AwaitingResponse => {
Err(TsmError::DuplicateInProgress)
}
TransactionState::ResponseReady | TransactionState::Completed => {
Ok(entry.cached_response.clone())
}
};
}
}
if txns.len() >= self.config.max_transactions {
let expired_keys: Vec<TransactionKey> = txns
.iter()
.filter(|(_, e)| {
now.duration_since(e.created_at) > self.config.duplicate_window
})
.map(|(k, _)| *k)
.collect();
for k in &expired_keys {
txns.remove(k);
}
self.timeout_count
.fetch_add(expired_keys.len() as u64, Ordering::Relaxed);
if txns.len() >= self.config.max_transactions {
return Err(TsmError::AtCapacity);
}
}
txns.insert(
key,
TransactionEntry {
created_at: now,
service_choice,
cached_response: None,
duplicate_count: 0,
state: TransactionState::AwaitingResponse,
},
);
}
self.total_transactions.fetch_add(1, Ordering::Relaxed);
Ok(None)
}
pub fn complete_transaction(&self, key: &TransactionKey, response: Vec<u8>) -> bool {
let mut txns = self.transactions.write();
if let Some(entry) = txns.get_mut(key) {
entry.cached_response = Some(response);
entry.state = TransactionState::Completed;
}
if self.config.chaos_drop_probability > 0.0 {
let counter = self.total_transactions.load(Ordering::Relaxed);
let threshold = (self.config.chaos_drop_probability * 1000.0) as u64;
let pseudo_random = (counter * 7 + 13) % 1000;
if pseudo_random < threshold {
self.dropped_count.fetch_add(1, Ordering::Relaxed);
return false;
}
}
true
}
pub fn chaos_delay(&self) -> Duration {
self.config.chaos_delay
}
pub fn cleanup_expired(&self) -> usize {
let now = Instant::now();
let mut txns = self.transactions.write();
let before = txns.len();
txns.retain(|_, entry| {
now.duration_since(entry.created_at) <= self.config.duplicate_window
});
let removed = before - txns.len();
if removed > 0 {
self.timeout_count
.fetch_add(removed as u64, Ordering::Relaxed);
}
removed
}
pub fn active_count(&self) -> usize {
self.transactions.read().len()
}
pub fn statistics(&self) -> TsmStatistics {
TsmStatistics {
total_transactions: self.total_transactions.load(Ordering::Relaxed),
active_transactions: self.active_count(),
duplicate_count: self.duplicate_count.load(Ordering::Relaxed),
dropped_count: self.dropped_count.load(Ordering::Relaxed),
timeout_count: self.timeout_count.load(Ordering::Relaxed),
}
}
pub fn set_chaos_delay(&mut self, delay: Duration) {
self.config.chaos_delay = delay;
}
pub fn set_chaos_drop_probability(&mut self, probability: f64) {
self.config.chaos_drop_probability = probability.clamp(0.0, 1.0);
}
}
impl Default for ServerTsm {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of `ServerTsm` counters, produced by `ServerTsm::statistics`.
///
/// Counters start at zero when the TSM is constructed and are never reset in this file.
#[derive(Debug, Clone)]
pub struct TsmStatistics {
/// New transactions accepted by `begin_transaction`.
pub total_transactions: u64,
/// Entries currently in the table; may include expired entries not yet evicted.
pub active_transactions: usize,
/// Requests recognised as duplicates inside the window.
pub duplicate_count: u64,
/// Responses suppressed by chaos dropping.
pub dropped_count: u64,
/// Incremented when expired entries are evicted from the table.
pub timeout_count: u64,
}
/// Errors returned by `ServerTsm::begin_transaction`.
#[derive(Debug, thiserror::Error)]
pub enum TsmError {
/// An entry with the same key is still awaiting its response, so there is nothing to replay.
#[error("Duplicate request still in progress")]
DuplicateInProgress,
/// The table is at `max_transactions` even after evicting expired entries.
#[error("TSM at maximum capacity")]
AtCapacity,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Port shared by most tests; only the independence test needs a second one.
    const PORT: u16 = 47808;

    /// Loopback socket address with the given port.
    fn localhost(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Builds a TSM from the default config after applying `tweak` to it.
    fn tsm_with(tweak: impl FnOnce(&mut TsmConfig)) -> ServerTsm {
        let mut config = TsmConfig::default();
        tweak(&mut config);
        ServerTsm::with_config(config)
    }

    #[test]
    fn test_new_transaction() {
        let tsm = ServerTsm::new();
        let outcome = tsm.begin_transaction(TransactionKey::new(localhost(PORT), 1), 12);
        assert!(matches!(outcome, Ok(None)));
        assert_eq!(tsm.active_count(), 1);
    }

    #[test]
    fn test_duplicate_detection_in_progress() {
        let tsm = ServerTsm::new();
        let key = TransactionKey::new(localhost(PORT), 1);
        tsm.begin_transaction(key, 12).unwrap();
        assert!(matches!(
            tsm.begin_transaction(key, 12),
            Err(TsmError::DuplicateInProgress)
        ));
    }

    #[test]
    fn test_duplicate_returns_cached_response() {
        let tsm = ServerTsm::new();
        let key = TransactionKey::new(localhost(PORT), 1);
        let response = vec![0x30, 1, 12, 0xAA, 0xBB];
        tsm.begin_transaction(key, 12).unwrap();
        tsm.complete_transaction(&key, response.clone());
        assert_eq!(tsm.begin_transaction(key, 12).unwrap(), Some(response));
    }

    #[test]
    fn test_different_invoke_ids_are_independent() {
        let tsm = ServerTsm::new();
        let addr = localhost(PORT);
        tsm.begin_transaction(TransactionKey::new(addr, 1), 12).unwrap();
        let second = tsm.begin_transaction(TransactionKey::new(addr, 2), 15);
        assert!(matches!(second, Ok(None)));
        assert_eq!(tsm.active_count(), 2);
    }

    #[test]
    fn test_different_sources_are_independent() {
        let tsm = ServerTsm::new();
        tsm.begin_transaction(TransactionKey::new(localhost(PORT), 1), 12)
            .unwrap();
        let other = tsm.begin_transaction(TransactionKey::new(localhost(PORT + 1), 1), 12);
        assert!(matches!(other, Ok(None)));
        assert_eq!(tsm.active_count(), 2);
    }

    #[test]
    fn test_capacity_limit() {
        let tsm = tsm_with(|c| c.max_transactions = 2);
        let addr = localhost(PORT);
        for invoke_id in 1..=2 {
            tsm.begin_transaction(TransactionKey::new(addr, invoke_id), 12)
                .unwrap();
        }
        let third = tsm.begin_transaction(TransactionKey::new(addr, 3), 12);
        assert!(matches!(third, Err(TsmError::AtCapacity)));
    }

    #[test]
    fn test_cleanup_expired() {
        let tsm = tsm_with(|c| c.duplicate_window = Duration::from_millis(10));
        let key = TransactionKey::new(localhost(PORT), 1);
        tsm.begin_transaction(key, 12).unwrap();
        tsm.complete_transaction(&key, vec![0x20, 1, 12]);
        assert_eq!(tsm.active_count(), 1);
        std::thread::sleep(Duration::from_millis(15));
        assert_eq!(tsm.cleanup_expired(), 1);
        assert_eq!(tsm.active_count(), 0);
    }

    #[test]
    fn test_expired_entry_treated_as_new() {
        let tsm = tsm_with(|c| c.duplicate_window = Duration::from_millis(10));
        let key = TransactionKey::new(localhost(PORT), 1);
        tsm.begin_transaction(key, 12).unwrap();
        tsm.complete_transaction(&key, vec![0x20, 1, 12]);
        std::thread::sleep(Duration::from_millis(15));
        assert!(matches!(tsm.begin_transaction(key, 12), Ok(None)));
    }

    #[test]
    fn test_statistics() {
        let tsm = ServerTsm::new();
        let key = TransactionKey::new(localhost(PORT), 1);
        tsm.begin_transaction(key, 12).unwrap();
        tsm.complete_transaction(&key, vec![0x20, 1, 12]);
        let _ = tsm.begin_transaction(key, 12);
        let TsmStatistics {
            total_transactions,
            active_transactions,
            duplicate_count,
            ..
        } = tsm.statistics();
        assert_eq!(total_transactions, 1);
        assert_eq!(active_transactions, 1);
        assert_eq!(duplicate_count, 1);
    }

    #[test]
    fn test_chaos_drop() {
        let tsm = tsm_with(|c| c.chaos_drop_probability = 1.0);
        let key = TransactionKey::new(localhost(PORT), 1);
        tsm.begin_transaction(key, 12).unwrap();
        assert!(!tsm.complete_transaction(&key, vec![0x20, 1, 12]));
        assert_eq!(tsm.statistics().dropped_count, 1);
    }

    #[test]
    fn test_no_chaos_drop() {
        let tsm = tsm_with(|c| c.chaos_drop_probability = 0.0);
        let key = TransactionKey::new(localhost(PORT), 1);
        tsm.begin_transaction(key, 12).unwrap();
        assert!(tsm.complete_transaction(&key, vec![0x20, 1, 12]));
    }

    #[test]
    fn test_chaos_delay() {
        let delay = Duration::from_millis(500);
        let tsm = tsm_with(|c| c.chaos_delay = delay);
        assert_eq!(tsm.chaos_delay(), delay);
    }

    #[test]
    fn test_capacity_evicts_expired() {
        let tsm = tsm_with(|c| {
            c.max_transactions = 2;
            c.duplicate_window = Duration::from_millis(10);
        });
        let addr = localhost(PORT);
        for invoke_id in 1..=2 {
            tsm.begin_transaction(TransactionKey::new(addr, invoke_id), 12)
                .unwrap();
        }
        std::thread::sleep(Duration::from_millis(15));
        assert!(tsm.begin_transaction(TransactionKey::new(addr, 3), 12).is_ok());
    }

    #[test]
    fn test_transaction_states() {
        let tsm = ServerTsm::new();
        let key = TransactionKey::new(localhost(PORT), 1);
        tsm.begin_transaction(key, 12).unwrap();
        {
            let table = tsm.transactions.read();
            let entry = table.get(&key).unwrap();
            assert_eq!(entry.state, TransactionState::AwaitingResponse);
        }
        tsm.complete_transaction(&key, vec![0x20, 1, 12]);
        let table = tsm.transactions.read();
        let entry = table.get(&key).unwrap();
        assert_eq!(entry.state, TransactionState::Completed);
        assert!(entry.cached_response.is_some());
    }
}