use async_graphql::Value as GqlValue;
use parking_lot::RwLock; use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{watch, Mutex as TokioMutex};
/// Tunables for request collapsing (single-flight request coalescing).
#[derive(Debug, Clone)]
pub struct RequestCollapsingConfig {
    /// How long after a leader starts that newcomers may still join as followers.
    pub coalesce_window: Duration,
    /// Upper bound on followers attached to one in-flight request.
    pub max_waiters: usize,
    /// Master switch; when `false`, every request passes through untouched.
    pub enabled: bool,
    /// Number of tracked in-flight entries that triggers stale eviction.
    pub max_cache_size: usize,
}

impl Default for RequestCollapsingConfig {
    /// Balanced defaults: 50 ms window, 100 waiters, 10 000 tracked entries.
    fn default() -> Self {
        Self {
            enabled: true,
            coalesce_window: Duration::from_millis(50),
            max_waiters: 100,
            max_cache_size: 10000,
        }
    }
}

impl RequestCollapsingConfig {
    /// Starts from the [`Default`] values; chain the builder methods to tweak.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: sets how long followers may join an in-flight request.
    pub fn coalesce_window(self, window: Duration) -> Self {
        Self {
            coalesce_window: window,
            ..self
        }
    }

    /// Builder: caps the number of followers per in-flight request.
    pub fn max_waiters(self, limit: usize) -> Self {
        Self {
            max_waiters: limit,
            ..self
        }
    }

    /// Builder: turns collapsing on or off.
    pub fn enabled(self, on: bool) -> Self {
        Self { enabled: on, ..self }
    }

    /// Builder: caps how many in-flight entries are tracked at once.
    pub fn max_cache_size(self, capacity: usize) -> Self {
        Self {
            max_cache_size: capacity,
            ..self
        }
    }

    /// Preset tuned for throughput: wider window, more waiters, bigger cache.
    pub fn high_throughput() -> Self {
        Self {
            enabled: true,
            coalesce_window: Duration::from_millis(100),
            max_waiters: 500,
            max_cache_size: 50000,
        }
    }

    /// Preset tuned for latency: narrow window, fewer waiters, smaller cache.
    pub fn low_latency() -> Self {
        Self {
            enabled: true,
            coalesce_window: Duration::from_millis(10),
            max_waiters: 50,
            max_cache_size: 5000,
        }
    }

    /// Preset with collapsing switched off entirely.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }
}
/// Identity of a request for collapsing purposes.
///
/// Two requests collapse into one upstream call exactly when their service
/// name, gRPC path, and serialized request bytes all match; the three parts
/// are hashed (SHA-256, `:`-separated) into a single hex digest.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct RequestKey {
    // Hex-encoded SHA-256 digest of "service:path:bytes".
    hash: String,
}

impl RequestKey {
    /// Hashes the three identity components into a key.
    ///
    /// A `:` is fed between components to mark the boundaries so that
    /// different splits of the same concatenated bytes are less likely to
    /// collide.
    pub fn new(service_name: &str, grpc_path: &str, request_bytes: &[u8]) -> Self {
        let mut hasher = Sha256::new();
        let parts: [&[u8]; 5] = [
            service_name.as_bytes(),
            b":",
            grpc_path.as_bytes(),
            b":",
            request_bytes,
        ];
        for part in parts {
            hasher.update(part);
        }
        Self {
            hash: hex::encode(hasher.finalize()),
        }
    }

    /// The hex digest backing this key.
    pub fn hash(&self) -> &str {
        &self.hash
    }
}
/// Outcome of `RequestCollapsingRegistry::try_collapse`, telling the caller
/// which role it plays for this request.
#[derive(Debug)]
pub enum CollapseResult {
    /// This caller must execute the request and broadcast the outcome.
    Leader(RequestBroadcaster),
    /// An identical request is already in flight; wait for its result.
    Follower(RequestReceiver),
    /// Collapsing is disabled; execute the request normally.
    Passthrough,
}
/// Shared map of in-flight requests keyed by request hash.
///
/// The registry and every leader's [`RequestBroadcaster`] hold clones of the
/// same `Arc`, so completing a request removes its entry from the map that
/// `try_collapse` actually consults.
type InFlightMap = Arc<RwLock<HashMap<RequestKey, Arc<TokioMutex<InFlightRequest>>>>>;

/// Handle held by the single "leader" of a collapsed request group.
///
/// The leader performs the real upstream call, then publishes the outcome to
/// all followers via [`broadcast`](Self::broadcast).
#[derive(Debug)]
pub struct RequestBroadcaster {
    key: RequestKey,
    sender: watch::Sender<Option<Arc<Result<GqlValue, String>>>>,
    // The exact entry this leader inserted; used to guard removal so a stale
    // leader never evicts a newer leader's entry for the same key.
    entry: Arc<TokioMutex<InFlightRequest>>,
    // The registry's live in-flight map (shared, not a private copy).
    in_flight: InFlightMap,
}

impl RequestBroadcaster {
    /// Publishes `result` to every follower and retires this request's
    /// in-flight entry.
    ///
    /// Fix: previously the broadcaster held an `Arc` to a freshly created
    /// (empty) registry, so completed entries were never removed from the real
    /// map — they leaked until eviction and could keep serving an already
    /// broadcast result to new followers within the coalesce window. The
    /// entry is removed only if it is still *this* leader's entry
    /// (`Arc::ptr_eq`), so an expired leader cannot remove a replacement
    /// leader's entry.
    pub fn broadcast(self, result: Result<GqlValue, String>) {
        let _ = self.sender.send(Some(Arc::new(result)));
        let mut in_flight = self.in_flight.write();
        let is_current = in_flight
            .get(&self.key)
            .map_or(false, |current| Arc::ptr_eq(current, &self.entry));
        if is_current {
            in_flight.remove(&self.key);
        }
    }

    /// The key identifying this collapsed request group.
    pub fn key(&self) -> &RequestKey {
        &self.key
    }
}

/// Handle held by a "follower" awaiting the leader's broadcast.
#[derive(Debug)]
pub struct RequestReceiver {
    receiver: watch::Receiver<Option<Arc<Result<GqlValue, String>>>>,
}

impl RequestReceiver {
    /// Waits until the leader broadcasts and returns a clone of its result.
    ///
    /// Errors if every sender is dropped before a value arrives.
    /// NOTE(review): the map entry itself keeps a sender alive, so if a leader
    /// is dropped without broadcasting, followers wait until the stale entry
    /// is evicted — callers should wrap `recv` in a timeout.
    pub async fn recv(mut self) -> Result<GqlValue, String> {
        loop {
            // Check first: the value may already be set when we subscribe late.
            if let Some(result) = &*self.receiver.borrow() {
                return (**result).clone();
            }
            if self.receiver.changed().await.is_err() {
                return Err("Request leader dropped without sending result".to_string());
            }
        }
    }
}

/// Book-keeping for one request currently in flight.
#[derive(Debug)]
struct InFlightRequest {
    // Kept in the map so late-arriving followers can `subscribe()`.
    sender: watch::Sender<Option<Arc<Result<GqlValue, String>>>>,
    // Leader start time; entries older than the coalesce window stop
    // accepting followers.
    started_at: Instant,
    // Followers attached so far, capped by `max_waiters`.
    waiter_count: usize,
}

/// Registry that collapses identical concurrent requests into a single
/// upstream call.
pub struct RequestCollapsingRegistry {
    config: RequestCollapsingConfig,
    // Shared with every leader's broadcaster; see `InFlightMap`.
    in_flight: InFlightMap,
}

impl std::fmt::Debug for RequestCollapsingRegistry {
    // Manual impl: summarize the map as a count instead of dumping entries.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let in_flight_count = self.in_flight.read().len();
        f.debug_struct("RequestCollapsingRegistry")
            .field("config", &self.config)
            .field("in_flight_count", &in_flight_count)
            .finish()
    }
}

impl RequestCollapsingRegistry {
    /// Creates a registry with the given configuration.
    pub fn new(config: RequestCollapsingConfig) -> Self {
        Self {
            config,
            in_flight: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Decides the caller's role for `key`:
    ///
    /// * [`CollapseResult::Passthrough`] — collapsing is disabled.
    /// * [`CollapseResult::Follower`] — a fresh, non-full entry exists; await
    ///   the leader's result.
    /// * [`CollapseResult::Leader`] — caller must execute the request and
    ///   call [`RequestBroadcaster::broadcast`] with the outcome.
    pub async fn try_collapse(&self, key: RequestKey) -> CollapseResult {
        if !self.config.enabled {
            return CollapseResult::Passthrough;
        }
        // Fast path: look up an existing entry under the read lock only.
        let entry_opt = {
            let in_flight = self.in_flight.read();
            in_flight.get(&key).cloned()
        };
        if let Some(entry) = entry_opt {
            let mut guard = entry.lock().await;
            if guard.started_at.elapsed() < self.config.coalesce_window
                && guard.waiter_count < self.config.max_waiters
            {
                guard.waiter_count += 1;
                let receiver = guard.sender.subscribe();
                return CollapseResult::Follower(RequestReceiver { receiver });
            }
            // Entry expired or full: fall through and become a new leader,
            // replacing the stale entry below.
        }
        let (sender, _) = watch::channel(None);
        let entry = Arc::new(TokioMutex::new(InFlightRequest {
            sender: sender.clone(),
            started_at: Instant::now(),
            waiter_count: 0,
        }));
        {
            let mut in_flight = self.in_flight.write();
            if in_flight.len() >= self.config.max_cache_size {
                self.evict_stale_entries(&mut in_flight);
            }
            in_flight.insert(key.clone(), Arc::clone(&entry));
        }
        CollapseResult::Leader(RequestBroadcaster {
            key,
            sender,
            entry,
            // Share the registry's actual map so `broadcast` removes the entry
            // callers see (the previous code handed out a brand-new registry,
            // leaking every completed entry until eviction).
            in_flight: Arc::clone(&self.in_flight),
        })
    }

    /// Drops entries older than 10x the coalesce window.
    ///
    /// Entries whose mutex is momentarily locked are kept (`try_lock`) rather
    /// than blocking while the map's write lock is held.
    fn evict_stale_entries(
        &self,
        in_flight: &mut HashMap<RequestKey, Arc<TokioMutex<InFlightRequest>>>,
    ) {
        let stale_threshold = self.config.coalesce_window * 10;
        let now = Instant::now();
        // `retain` removes in place — no intermediate key list needed.
        in_flight.retain(|_, entry| match entry.try_lock() {
            Ok(guard) => now.duration_since(guard.started_at) <= stale_threshold,
            Err(_) => true,
        });
    }

    /// Point-in-time statistics for monitoring.
    pub fn stats(&self) -> CollapsingStats {
        let in_flight = self.in_flight.read();
        CollapsingStats {
            in_flight_count: in_flight.len(),
            max_cache_size: self.config.max_cache_size,
            enabled: self.config.enabled,
        }
    }
}
/// Shared handle to a `RequestCollapsingRegistry`.
pub type SharedRequestCollapsingRegistry = Arc<RequestCollapsingRegistry>;

/// Builds a registry from `config` and wraps it in an [`Arc`] for sharing.
pub fn create_request_collapsing_registry(
    config: RequestCollapsingConfig,
) -> SharedRequestCollapsingRegistry {
    Arc::new(RequestCollapsingRegistry::new(config))
}
/// Point-in-time view of registry state, as returned by
/// `RequestCollapsingRegistry::stats`.
#[derive(Debug, Clone)]
pub struct CollapsingStats {
    /// Number of requests currently tracked as in flight.
    pub in_flight_count: usize,
    /// Configured cap on tracked in-flight entries.
    pub max_cache_size: usize,
    /// Whether collapsing is currently enabled.
    pub enabled: bool,
}
/// Lock-free counters tracking collapsing activity; safe to share across tasks.
#[derive(Debug, Default)]
pub struct CollapsingMetrics {
    /// Every recorded request, regardless of role.
    pub total_requests: std::sync::atomic::AtomicU64,
    /// Requests that executed the upstream call on behalf of a group.
    pub leader_requests: std::sync::atomic::AtomicU64,
    /// Requests served by waiting on another request's result.
    pub collapsed_requests: std::sync::atomic::AtomicU64,
    /// Requests that bypassed collapsing entirely.
    pub passthrough_requests: std::sync::atomic::AtomicU64,
}
impl CollapsingMetrics {
    /// Creates a fresh set of zeroed counters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Bumps `counter` and the running total.
    ///
    /// Relaxed ordering is sufficient: these are independent monotonic
    /// counters with no cross-thread ordering requirements.
    fn bump(&self, counter: &std::sync::atomic::AtomicU64) {
        use std::sync::atomic::Ordering;
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Records a request that became the leader of a collapsed group.
    pub fn record_leader(&self) {
        self.bump(&self.leader_requests);
    }

    /// Records a request served by waiting on a leader's result.
    pub fn record_collapsed(&self) {
        self.bump(&self.collapsed_requests);
    }

    /// Records a request that bypassed collapsing.
    pub fn record_passthrough(&self) {
        self.bump(&self.passthrough_requests);
    }

    /// Fraction of recorded requests that were collapsed (followers), or
    /// `0.0` when nothing has been recorded yet.
    pub fn collapse_ratio(&self) -> f64 {
        use std::sync::atomic::Ordering;
        let total = self.total_requests.load(Ordering::Relaxed);
        let collapsed = self.collapsed_requests.load(Ordering::Relaxed);
        match total {
            0 => 0.0,
            t => collapsed as f64 / t as f64,
        }
    }

    /// Copies all counters into a plain snapshot.
    ///
    /// Counters are read one at a time with relaxed loads, so a snapshot
    /// taken under concurrent updates may be slightly inconsistent across
    /// fields.
    pub fn snapshot(&self) -> CollapsingMetricsSnapshot {
        use std::sync::atomic::Ordering;
        let read = |counter: &std::sync::atomic::AtomicU64| counter.load(Ordering::Relaxed);
        CollapsingMetricsSnapshot {
            total_requests: read(&self.total_requests),
            leader_requests: read(&self.leader_requests),
            collapsed_requests: read(&self.collapsed_requests),
            passthrough_requests: read(&self.passthrough_requests),
            collapse_ratio: self.collapse_ratio(),
        }
    }
}
/// Serializable copy of `CollapsingMetrics` counters at a point in time.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CollapsingMetricsSnapshot {
    /// Every recorded request, regardless of role.
    pub total_requests: u64,
    /// Requests that executed the upstream call for a group.
    pub leader_requests: u64,
    /// Requests served by another request's result.
    pub collapsed_requests: u64,
    /// Requests that bypassed collapsing.
    pub passthrough_requests: u64,
    /// `collapsed_requests / total_requests` (0.0 when no requests recorded).
    pub collapse_ratio: f64,
}
/// Shared handle to `CollapsingMetrics`.
pub type SharedCollapsingMetrics = Arc<CollapsingMetrics>;

/// Creates zeroed metrics wrapped in an [`Arc`] for sharing.
pub fn create_collapsing_metrics() -> SharedCollapsingMetrics {
    Arc::new(CollapsingMetrics::new())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Equal inputs hash to equal keys; any differing component changes the key.
    #[test]
    fn test_request_key_creation() {
        let key1 = RequestKey::new("service", "/path", b"request1");
        let key2 = RequestKey::new("service", "/path", b"request1");
        let key3 = RequestKey::new("service", "/path", b"request2");
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
        assert_ne!(key1.hash(), key3.hash());
    }

    /// Default config is enabled with the documented limits.
    #[test]
    fn test_config_defaults() {
        let config = RequestCollapsingConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_waiters, 100);
        assert_eq!(config.max_cache_size, 10000);
    }

    /// Presets differ from the defaults in the expected directions.
    #[test]
    fn test_config_presets() {
        let high_throughput = RequestCollapsingConfig::high_throughput();
        assert!(high_throughput.coalesce_window > Duration::from_millis(50));
        assert_eq!(high_throughput.max_waiters, 500);
        let low_latency = RequestCollapsingConfig::low_latency();
        assert!(low_latency.coalesce_window < Duration::from_millis(50));
        assert_eq!(low_latency.max_waiters, 50);
        let disabled = RequestCollapsingConfig::disabled();
        assert!(!disabled.enabled);
    }

    /// The first arrival for a key becomes the leader.
    #[tokio::test]
    async fn test_leader_selection() {
        let config = RequestCollapsingConfig::default();
        let registry = RequestCollapsingRegistry::new(config);
        let key = RequestKey::new("service", "/path", b"request");
        let result = registry.try_collapse(key).await;
        match result {
            CollapseResult::Leader(_) => {}
            _ => panic!("Expected leader result"),
        }
    }

    /// A second arrival within the coalesce window becomes a follower.
    #[tokio::test]
    async fn test_follower_selection() {
        let config = RequestCollapsingConfig::default();
        let registry = Arc::new(RequestCollapsingRegistry::new(config));
        let key = RequestKey::new("service", "/path", b"request");
        let result1 = registry.try_collapse(key.clone()).await;
        let broadcaster = match result1 {
            CollapseResult::Leader(b) => b,
            _ => panic!("Expected leader result"),
        };
        let result2 = registry.try_collapse(key).await;
        match result2 {
            CollapseResult::Follower(_) => {}
            _ => panic!("Expected follower result"),
        }
        drop(broadcaster);
    }

    /// A disabled registry never collapses anything.
    #[tokio::test]
    async fn test_disabled_collapsing() {
        let config = RequestCollapsingConfig::disabled();
        let registry = RequestCollapsingRegistry::new(config);
        let key = RequestKey::new("service", "/path", b"request");
        let result = registry.try_collapse(key).await;
        match result {
            CollapseResult::Passthrough => {}
            _ => panic!("Expected passthrough result"),
        }
    }

    /// A follower receives exactly the value the leader broadcasts.
    #[tokio::test]
    async fn test_result_broadcasting() {
        let config = RequestCollapsingConfig::default();
        let registry = Arc::new(RequestCollapsingRegistry::new(config));
        let key = RequestKey::new("service", "/path", b"request");
        let result1 = registry.try_collapse(key.clone()).await;
        let broadcaster = match result1 {
            CollapseResult::Leader(b) => b,
            _ => panic!("Expected leader result"),
        };
        let result2 = registry.try_collapse(key).await;
        let receiver = match result2 {
            CollapseResult::Follower(r) => r,
            _ => panic!("Expected follower result"),
        };
        let expected = GqlValue::String("test".to_string());
        broadcaster.broadcast(Ok(expected.clone()));
        let received = receiver.recv().await.unwrap();
        assert_eq!(received, expected);
    }

    /// Errors propagate to followers the same way success values do.
    #[tokio::test]
    async fn test_error_broadcasting() {
        let config = RequestCollapsingConfig::default();
        let registry = Arc::new(RequestCollapsingRegistry::new(config));
        let key = RequestKey::new("service", "/path", b"request_err");
        let broadcaster = match registry.try_collapse(key.clone()).await {
            CollapseResult::Leader(b) => b,
            _ => panic!("Expected leader"),
        };
        let receiver = match registry.try_collapse(key).await {
            CollapseResult::Follower(r) => r,
            _ => panic!("Expected follower"),
        };
        let error_msg = "test error".to_string();
        broadcaster.broadcast(Err(error_msg.clone()));
        let received = receiver.recv().await;
        assert_eq!(received, Err(error_msg));
    }

    /// After the coalesce window passes, the same key elects a new leader.
    #[tokio::test]
    async fn test_expired_requests_create_new_leader() {
        let config = RequestCollapsingConfig::new().coalesce_window(Duration::from_millis(10));
        let registry = Arc::new(RequestCollapsingRegistry::new(config));
        let key = RequestKey::new("service", "/path", b"request");
        let _ = registry.try_collapse(key.clone()).await;
        sleep(Duration::from_millis(20)).await;
        let result2 = registry.try_collapse(key).await;
        match result2 {
            CollapseResult::Leader(_) => {}
            _ => panic!("Expected new leader after expiry"),
        }
    }

    /// Once `max_waiters` is reached, the next arrival becomes a new leader.
    #[tokio::test]
    async fn test_max_waiters_reached() {
        let config = RequestCollapsingConfig::default().max_waiters(1);
        let registry = Arc::new(RequestCollapsingRegistry::new(config));
        let key = RequestKey::new("param", "waiters", b"");
        let _ = registry.try_collapse(key.clone()).await;
        match registry.try_collapse(key.clone()).await {
            CollapseResult::Follower(_) => {}
            _ => panic!("Second should be follower"),
        }
        match registry.try_collapse(key).await {
            CollapseResult::Leader(_) => {}
            res => panic!("Third should be new leader. Got: {:?}", res),
        }
    }

    /// When the cache is full, stale entries are evicted to make room.
    #[tokio::test]
    async fn test_eviction() {
        let config = RequestCollapsingConfig::default()
            .max_cache_size(1)
            .coalesce_window(Duration::from_millis(1));
        let registry = RequestCollapsingRegistry::new(config);
        let key1 = RequestKey::new("s", "p", b"1");
        let _ = registry.try_collapse(key1.clone()).await;
        assert_eq!(registry.stats().in_flight_count, 1);
        // Wait past the 10x-window stale threshold so key1 is evictable.
        sleep(Duration::from_millis(15)).await;
        let key2 = RequestKey::new("s", "p", b"2");
        let _ = registry.try_collapse(key2).await;
        assert_eq!(registry.stats().in_flight_count, 1);
        match registry.try_collapse(key1).await {
            CollapseResult::Leader(_) => {}
            _ => panic!("Key1 should have been evicted"),
        }
    }

    /// Counters and the derived collapse ratio add up.
    #[test]
    fn test_metrics() {
        let metrics = CollapsingMetrics::new();
        metrics.record_leader();
        metrics.record_collapsed();
        metrics.record_collapsed();
        metrics.record_passthrough();
        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.total_requests, 4);
        assert_eq!(snapshot.leader_requests, 1);
        assert_eq!(snapshot.collapsed_requests, 2);
        assert_eq!(snapshot.passthrough_requests, 1);
        assert!((snapshot.collapse_ratio - 0.5).abs() < 0.01);
    }

    /// A fresh registry reports an empty, enabled state.
    #[test]
    fn test_stats() {
        let config = RequestCollapsingConfig::default();
        let registry = RequestCollapsingRegistry::new(config);
        let stats = registry.stats();
        assert_eq!(stats.in_flight_count, 0);
        assert!(stats.enabled);
    }
}
#[cfg(test)]
mod proptest_checks {
    use super::*;
    use proptest::prelude::*;
    use tokio::runtime::Runtime;

    proptest! {
        // Fuzz: a burst of staggered requests over a handful of keys must not
        // deadlock or panic — every spawned task either leads (broadcasting a
        // result after a short delay), follows (guarded by a timeout so a
        // missing broadcast cannot hang the test), or passes through.
        #[test]
        fn fuzz_collapsing_mixed_traffic(
            requests in proptest::collection::vec(("[a-c]", 0..5u64), 10..50)
        ) {
            let rt = Runtime::new().unwrap();
            rt.block_on(async {
                let config = RequestCollapsingConfig::new()
                    .coalesce_window(Duration::from_millis(20));
                let registry = Arc::new(RequestCollapsingRegistry::new(config));
                let mut handles = vec![];
                for (suffix, delay_ms) in requests {
                    let registry = registry.clone();
                    let key = RequestKey::new("svc", &format!("path_{}", suffix), b"data");
                    handles.push(tokio::spawn(async move {
                        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                        match registry.try_collapse(key).await {
                            CollapseResult::Leader(broadcaster) => {
                                tokio::time::sleep(Duration::from_millis(10)).await;
                                broadcaster.broadcast(Ok(async_graphql::Value::String("result".to_string())));
                            }
                            CollapseResult::Follower(receiver) => {
                                let _ = tokio::time::timeout(Duration::from_millis(100), receiver.recv()).await;
                            }
                            CollapseResult::Passthrough => {}
                        }
                    }));
                }
                futures::future::join_all(handles).await;
            });
        }
    }
}