use std::{
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::{Duration, Instant},
};
use tokio::sync::oneshot;
use super::{config::WsConfig, types::RequestId};
use crate::error::{TransportError, TransportResult};
/// Sending half of a pending request's response channel: carries either the response payload
/// (`String`) or a transport error.
type ResponseSender = oneshot::Sender<TransportResult<String>>;
/// A single in-flight request that is waiting for its response.
pub struct PendingRequest {
    /// Delivers the outcome to the receiver that `PendingRequestStore::add` handed back.
    pub response_tx: oneshot::Sender<TransportResult<String>>,
    /// When the request was registered; stale cleanup measures elapsed time from here.
    pub created_at: Instant,
    /// How long the request may stay pending before stale cleanup treats it as expired.
    pub timeout: Duration,
}
/// Registry of in-flight requests shared through `&self`, bounded by
/// `config.max_pending_requests`.
pub struct PendingRequestStore {
    /// Pending requests keyed by their id.
    requests: scc::HashMap<RequestId, PendingRequest>,
    /// Supplies the capacity limit and the default per-request timeout.
    config: Arc<WsConfig>,
    /// Number of reserved slots. It is maintained alongside `requests` (reserved before an
    /// insert, released after a removal) rather than read from the map, so under concurrent
    /// use it can briefly differ from the map's actual size.
    count: AtomicUsize,
}
impl PendingRequestStore {
pub fn new(config: Arc<WsConfig>) -> Self {
Self {
requests: scc::HashMap::new(),
config,
count: AtomicUsize::new(0),
}
}
pub fn add(
&self,
id: RequestId,
timeout: Option<Duration>,
) -> Option<oneshot::Receiver<TransportResult<String>>> {
if !self.reserve_slot() {
return None;
}
let (tx, rx) = oneshot::channel();
let timeout = timeout.unwrap_or(self.config.request_timeout);
let pending = PendingRequest {
response_tx: tx,
created_at: Instant::now(),
timeout,
};
if self.requests.insert_sync(id, pending).is_err() {
self.count.fetch_sub(1, Ordering::AcqRel);
return None;
}
Some(rx)
}
pub fn resolve(&self, id: &RequestId, response: TransportResult<String>) -> bool {
if let Some((_, pending)) = self.requests.remove_sync(id) {
self.count.fetch_sub(1, Ordering::AcqRel);
let _ = pending.response_tx.send(response);
return true;
}
false
}
pub fn remove(&self, id: &RequestId) -> bool {
if self.requests.remove_sync(id).is_some() {
self.count.fetch_sub(1, Ordering::AcqRel);
return true;
}
false
}
pub fn cleanup_stale(&self) {
let now = Instant::now();
let removed = AtomicUsize::new(0);
self.requests.retain_sync(|_, pending| {
let keep = now.duration_since(pending.created_at) < pending.timeout;
if !keep {
removed.fetch_add(1, Ordering::Relaxed);
}
keep
});
let removed = removed.load(Ordering::Relaxed);
if removed > 0 {
self.count.fetch_sub(removed, Ordering::AcqRel);
}
}
pub fn cleanup_stale_with_notify(&self) {
let now = Instant::now();
let mut expired_senders: Vec<(ResponseSender, Duration, RequestId)> = Vec::new();
self.requests.retain_sync(|id, pending| {
if now.duration_since(pending.created_at) >= pending.timeout {
let tx = std::mem::replace(&mut pending.response_tx, oneshot::channel().0);
expired_senders.push((tx, pending.timeout, id.clone()));
false } else {
true }
});
let removed = expired_senders.len();
if removed > 0 {
self.count.fetch_sub(removed, Ordering::AcqRel);
}
for (tx, timeout_dur, id) in expired_senders {
let _ = tx.send(Err(TransportError::request_timeout(
timeout_dur,
id.to_string(),
)));
}
}
pub fn has_capacity(&self) -> bool {
self.count.load(Ordering::Acquire) < self.config.max_pending_requests
}
pub fn len(&self) -> usize {
self.count.load(Ordering::Acquire)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn clear_with_error(&self, error_message: &str) {
let mut senders: Vec<ResponseSender> = Vec::new();
self.requests.retain_sync(|_, pending| {
let tx = std::mem::replace(&mut pending.response_tx, oneshot::channel().0);
senders.push(tx);
false });
let removed = senders.len();
if removed > 0 {
self.count.fetch_sub(removed, Ordering::AcqRel);
}
let error_message = error_message.to_string();
for tx in senders {
let _ = tx.send(Err(TransportError::connection_closed(Some(
error_message.clone(),
))));
}
}
pub fn clear(&self) {
let mut removed = 0usize;
self.requests.retain_sync(|_, _| {
removed += 1;
false
});
if removed > 0 {
self.count.fetch_sub(removed, Ordering::AcqRel);
}
}
fn reserve_slot(&self) -> bool {
let max = self.config.max_pending_requests;
self.count
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
if current >= max {
None
} else {
Some(current + 1)
}
})
.is_ok()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with a small capacity (10) so limit tests stay cheap.
    fn test_config() -> Arc<WsConfig> {
        Arc::new(WsConfig::new("wss://test.com").max_pending_requests(10))
    }

    #[test]
    fn test_add_and_resolve() {
        let store = PendingRequestStore::new(test_config());
        let id = RequestId::new();
        let mut rx = store.add(id.clone(), None).expect("store has capacity");
        assert_eq!(store.len(), 1);
        let resolved = store.resolve(&id, Ok("response".to_string()));
        assert!(resolved);
        assert_eq!(store.len(), 0);
        // The response must actually reach the waiting receiver.
        assert!(matches!(rx.try_recv(), Ok(Ok(ref s)) if s.as_str() == "response"));
    }

    #[test]
    fn test_capacity_limit() {
        let store = PendingRequestStore::new(test_config());
        for _ in 0..10 {
            let rx = store.add(RequestId::new(), None);
            assert!(rx.is_some());
        }
        assert_eq!(store.len(), 10);
        let rx = store.add(RequestId::new(), None);
        assert!(rx.is_none());
        assert!(!store.has_capacity());
    }

    #[test]
    fn test_duplicate_id_releases_slot() {
        let store = PendingRequestStore::new(test_config());
        let id = RequestId::new();
        let _rx = store.add(id.clone(), None).expect("store has capacity");
        assert!(store.add(id.clone(), None).is_none());
        // The rejected insert must not leak its reserved slot.
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_resolve_nonexistent() {
        let store = PendingRequestStore::new(test_config());
        let id = RequestId::new();
        let resolved = store.resolve(&id, Ok("response".to_string()));
        assert!(!resolved);
    }

    #[test]
    fn test_remove_closes_channel() {
        let store = PendingRequestStore::new(test_config());
        let id = RequestId::new();
        let mut rx = store.add(id.clone(), None).expect("store has capacity");
        assert!(store.remove(&id));
        assert!(!store.remove(&id));
        assert!(store.is_empty());
        // Removing without resolving drops the sender, so the receiver sees a closed channel.
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_cleanup_stale() {
        let config =
            Arc::new(WsConfig::new("wss://test.com").request_timeout(Duration::from_millis(1)));
        let store = PendingRequestStore::new(config);
        let _rx = store.add(RequestId::new(), None);
        assert_eq!(store.len(), 1);
        std::thread::sleep(Duration::from_millis(10));
        store.cleanup_stale();
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_cleanup_stale_keeps_fresh_requests() {
        let store = PendingRequestStore::new(test_config());
        let _fresh = store.add(RequestId::new(), Some(Duration::from_secs(60)));
        // A zero timeout is expired immediately, so no sleep is needed.
        let _stale = store.add(RequestId::new(), Some(Duration::ZERO));
        store.cleanup_stale();
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_cleanup_stale_with_notify_sends_timeout() {
        let store = PendingRequestStore::new(test_config());
        let mut rx = store
            .add(RequestId::new(), Some(Duration::ZERO))
            .expect("store has capacity");
        store.cleanup_stale_with_notify();
        assert!(store.is_empty());
        assert!(matches!(rx.try_recv(), Ok(Err(_))));
    }

    #[test]
    fn test_clear_with_error_notifies_all() {
        let store = PendingRequestStore::new(test_config());
        let mut receivers: Vec<_> = (0..3)
            .map(|_| store.add(RequestId::new(), None).expect("store has capacity"))
            .collect();
        store.clear_with_error("shutdown");
        assert!(store.is_empty());
        assert!(store.has_capacity());
        for rx in &mut receivers {
            assert!(matches!(rx.try_recv(), Ok(Err(_))));
        }
    }

    #[test]
    fn test_clear_frees_capacity() {
        let store = PendingRequestStore::new(test_config());
        let _receivers: Vec<_> = (0..10).map(|_| store.add(RequestId::new(), None)).collect();
        assert!(!store.has_capacity());
        store.clear();
        assert!(store.is_empty());
        assert!(store.has_capacity());
    }
}