use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
/// Tracks in-flight and recently completed transactions keyed by
/// `(xid, client address)`, so that a retransmitted request can be recognised and,
/// once answered, replayed from the cached response.
///
/// All mutable state sits behind a single `Mutex`, which is why every method takes
/// `&self`.
pub struct TransactionTracker {
    /// How long a completed entry is kept after its completion time was last set
    /// (by recording the response or by replaying it).
    retention_period: Duration,
    /// Minimum time that must pass between two expiry sweeps.
    housekeeping_interval: Duration,
    /// The tracked entries plus the time of the last sweep.
    transactions: Mutex<TransactionStore>,
}
/// Lock-protected state of a `TransactionTracker`.
struct TransactionStore {
    /// Per-transaction state, keyed by `(xid, client address)`.
    transactions: HashMap<(u32, String), TransactionState>,
    /// When the last expiry sweep ran; initialised to the tracker's construction time.
    last_housekeeping: SystemTime,
}
/// Outcome of `TransactionTracker::check` for one `(xid, client address)` pair.
#[derive(Debug)]
pub enum TransactionStatus {
    /// First sighting: the transaction has just been marked in progress and the
    /// caller should process it.
    New,
    /// The transaction was seen before and no response has been recorded yet.
    InProgress,
    /// A response was already recorded; this is a shared handle to those bytes.
    Completed(Arc<Vec<u8>>),
}
impl TransactionTracker {
pub fn new(retention_period: Duration) -> Self {
let housekeeping_interval = if retention_period < Duration::from_secs(1) {
retention_period
} else {
std::cmp::max(Duration::from_secs(1), retention_period / 4)
};
Self {
retention_period,
housekeeping_interval,
transactions: Mutex::new(TransactionStore {
transactions: HashMap::new(),
last_housekeeping: SystemTime::now(),
}),
}
}
pub fn check(&self, xid: u32, client_addr: &str) -> TransactionStatus {
let key = (xid, client_addr.to_string());
let now = SystemTime::now();
let mut store = self.transactions.lock().expect("unable to unlock transactions mutex");
if !store.transactions.is_empty()
&& should_housekeep(store.last_housekeeping, now, self.housekeeping_interval)
{
housekeeping(&mut store.transactions, self.retention_period, now);
store.last_housekeeping = now;
}
match store.transactions.entry(key) {
std::collections::hash_map::Entry::Vacant(e) => {
e.insert(TransactionState::InProgress);
TransactionStatus::New
}
std::collections::hash_map::Entry::Occupied(mut entry) => match entry.get_mut() {
TransactionState::InProgress => TransactionStatus::InProgress,
TransactionState::Completed { completion_time, response } => {
*completion_time = now;
TransactionStatus::Completed(Arc::clone(response))
}
},
}
}
pub fn record_response(&self, xid: u32, client_addr: &str, response: Arc<Vec<u8>>) {
let key = (xid, client_addr.to_string());
let completion_time = SystemTime::now();
let mut store = self.transactions.lock().expect("unable to unlock transactions mutex");
store.transactions.insert(key, TransactionState::Completed { completion_time, response });
}
pub fn clear(&self, xid: u32, client_addr: &str) {
let key = (xid, client_addr.to_string());
let mut store = self.transactions.lock().expect("unable to unlock transactions mutex");
store.transactions.remove(&key);
}
}
/// Returns `true` once at least `interval` has passed between `last` and `now`.
///
/// If `now` is earlier than `last` (the wall clock moved backwards), this also
/// returns `true`, so a sweep is never postponed indefinitely.
fn should_housekeep(last: SystemTime, now: SystemTime, interval: Duration) -> bool {
    match now.duration_since(last) {
        Ok(elapsed) => elapsed >= interval,
        Err(_) => true,
    }
}
/// Evicts completed transactions whose completion time is earlier than
/// `now - max_age`. In-progress entries are always kept.
///
/// If `now - max_age` is not representable as a `SystemTime` (for example, when
/// `max_age` is huge), no completed entry can be older than the cutoff, so nothing
/// is removed. This path used to panic on the underflow, and that panic happened
/// while the caller held the tracker mutex.
fn housekeeping(
    transactions: &mut HashMap<(u32, String), TransactionState>,
    max_age: Duration,
    now: SystemTime,
) {
    let cutoff = match now.checked_sub(max_age) {
        Some(cutoff) => cutoff,
        None => return,
    };
    transactions.retain(|_, state| match state {
        TransactionState::InProgress => true,
        TransactionState::Completed { completion_time, .. } => *completion_time >= cutoff,
    });
}
/// Internal per-transaction record stored in a `TransactionStore`.
enum TransactionState {
    /// The request was seen but no response has been recorded yet.
    InProgress,
    /// A response was recorded.
    Completed {
        /// When the response was recorded or last replayed; used for expiry.
        completion_time: SystemTime,
        /// The cached response bytes, shared with every replay.
        response: Arc<Vec<u8>>,
    },
}
#[cfg(test)]
mod tests {
    use super::{TransactionStatus, TransactionTracker};
    use std::sync::Arc;
    use std::time::Duration;

    #[test]
    fn retransmit_in_flight_reports_in_progress() {
        let tracker = TransactionTracker::new(Duration::from_secs(60));
        let xid = 7;
        let client_addr = "127.0.0.1:1234";
        assert!(matches!(tracker.check(xid, client_addr), TransactionStatus::New));
        assert!(matches!(tracker.check(xid, client_addr), TransactionStatus::InProgress));
        let response = Arc::new(vec![1, 2, 3]);
        tracker.record_response(xid, client_addr, Arc::clone(&response));
        match tracker.check(xid, client_addr) {
            TransactionStatus::Completed(replay) => {
                assert_eq!(&*replay, &*response);
            }
            other => panic!("expected Completed, got {other:?}"),
        }
    }

    /// Every retransmit of a completed transaction replays the very same buffer.
    #[test]
    fn completed_response_is_replayed_on_every_retransmit() {
        let tracker = TransactionTracker::new(Duration::from_secs(60));
        let client_addr = "10.0.0.1:700";
        assert!(matches!(tracker.check(1, client_addr), TransactionStatus::New));
        let response = Arc::new(vec![9, 8, 7]);
        tracker.record_response(1, client_addr, Arc::clone(&response));
        for _ in 0..2 {
            match tracker.check(1, client_addr) {
                TransactionStatus::Completed(replay) => assert!(Arc::ptr_eq(&replay, &response)),
                other => panic!("expected Completed, got {other:?}"),
            }
        }
    }

    /// The key includes the client address, so equal xids from different clients
    /// do not collide.
    #[test]
    fn same_xid_from_different_clients_is_tracked_separately() {
        let tracker = TransactionTracker::new(Duration::from_secs(60));
        assert!(matches!(tracker.check(42, "10.0.0.1:700"), TransactionStatus::New));
        assert!(matches!(tracker.check(42, "10.0.0.2:700"), TransactionStatus::New));
        assert!(matches!(tracker.check(42, "10.0.0.1:700"), TransactionStatus::InProgress));
    }

    /// `clear` drops the entry, so the next check treats the request as new.
    #[test]
    fn clear_forgets_transaction() {
        let tracker = TransactionTracker::new(Duration::from_secs(60));
        let client_addr = "10.0.0.1:700";
        assert!(matches!(tracker.check(5, client_addr), TransactionStatus::New));
        tracker.clear(5, client_addr);
        assert!(matches!(tracker.check(5, client_addr), TransactionStatus::New));
    }
}