use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Notify};
/// Outcome of a gateway row fetch, shared by a leader with its followers: the fetched
/// rows as JSON values, or an error described as a `String`.
pub type GatewayFetchRowsResult = Result<Vec<Value>, String>;
/// Shared state for a single in-flight fetch.
///
/// It holds the slot into which the leader publishes its result and the notifier that
/// waiting followers subscribe to.
pub struct GatewayInFlightFetch {
    /// Published outcome; stays `None` until the leader publishes.
    result: Mutex<Option<GatewayFetchRowsResult>>,
    /// Signalled by the coordinator once `result` has been filled in.
    notify: Notify,
}

impl GatewayInFlightFetch {
    /// Builds a flight with an empty result slot and a fresh notifier.
    fn new() -> Self {
        let result = Mutex::new(None);
        let notify = Notify::new();
        Self { result, notify }
    }
}
/// Role handed out by [`GatewayFetchSingleflight::acquire`] for a cache key.
pub enum GatewayFetchSingleflightRole {
/// This caller registered a new flight for the key and should publish the outcome via
/// [`GatewayFetchSingleflight::publish_result`].
Leader(Arc<GatewayInFlightFetch>),
/// A flight for the key was already registered; obtain its outcome via
/// [`GatewayFetchSingleflight::wait_for_result`].
Follower(Arc<GatewayInFlightFetch>),
}
/// Coalesces concurrent fetches that share a cache key.
///
/// The first caller for a key becomes the leader. Callers that arrive while the leader's
/// flight is still registered become followers and wait for the leader's published result.
pub struct GatewayFetchSingleflight {
    /// Flights that have not published yet, keyed by cache key.
    in_flight: Mutex<HashMap<String, Arc<GatewayInFlightFetch>>>,
    /// Upper bound on how long [`Self::wait_for_result`] waits for a publish.
    wait_timeout: Duration,
}

impl GatewayFetchSingleflight {
    /// Creates an empty coordinator whose followers wait at most `wait_timeout`.
    pub fn new(wait_timeout: Duration) -> Self {
        Self {
            in_flight: Mutex::new(HashMap::new()),
            wait_timeout,
        }
    }

    /// Returns the maximum time a follower waits in [`Self::wait_for_result`].
    pub fn wait_timeout(&self) -> Duration {
        self.wait_timeout
    }

    /// Registers interest in `cache_key` and reports whether this caller leads or follows.
    ///
    /// If no flight is registered for the key, a new one is inserted and returned as
    /// [`GatewayFetchSingleflightRole::Leader`]. Otherwise the registered flight is returned
    /// as [`GatewayFetchSingleflightRole::Follower`].
    ///
    /// [`Self::publish_result`] is the only place an entry is removed. A leader that never
    /// publishes (for example, because its future is dropped) therefore leaves the key
    /// registered, and later callers become followers of a flight that never completes.
    pub async fn acquire(&self, cache_key: &str) -> GatewayFetchSingleflightRole {
        let mut in_flight = self.in_flight.lock().await;
        // Look up before inserting so the follower path never allocates an owned key.
        if let Some(existing) = in_flight.get(cache_key) {
            return GatewayFetchSingleflightRole::Follower(Arc::clone(existing));
        }
        let flight = Arc::new(GatewayInFlightFetch::new());
        in_flight.insert(cache_key.to_owned(), Arc::clone(&flight));
        GatewayFetchSingleflightRole::Leader(flight)
    }

    /// Stores `result` on `flight`, unregisters the flight, and wakes waiting followers.
    ///
    /// The result is stored before the flight is unregistered and before waiters are
    /// notified. Any follower that observes the wakeup, or that still finds the flight via
    /// [`Self::acquire`], will therefore see the published value.
    pub async fn publish_result(
        &self,
        cache_key: &str,
        flight: Arc<GatewayInFlightFetch>,
        result: GatewayFetchRowsResult,
    ) {
        *flight.result.lock().await = Some(result);
        {
            let mut in_flight = self.in_flight.lock().await;
            // Remove the entry only if it still refers to this flight, so a newer flight
            // registered under the same key is never evicted by a stale publish.
            if in_flight
                .get(cache_key)
                .is_some_and(|existing| Arc::ptr_eq(existing, &flight))
            {
                in_flight.remove(cache_key);
            }
        }
        flight.notify.notify_waiters();
    }

    /// Waits up to [`Self::wait_timeout`] for the leader of `flight` to publish.
    ///
    /// Returns a clone of the published result. Returns `None` if nothing was published
    /// before the timeout elapsed.
    pub async fn wait_for_result(
        &self,
        flight: Arc<GatewayInFlightFetch>,
    ) -> Option<GatewayFetchRowsResult> {
        // Subscribe BEFORE inspecting the slot. `notify_waiters` wakes only `Notified`
        // futures that already exist when it is called, and it stores no permit. If we
        // checked first and subscribed afterwards, a publish landing in between would be
        // lost, and the follower would sit out the whole timeout despite a ready result.
        let notified = flight.notify.notified();
        if let Some(result) = flight.result.lock().await.as_ref() {
            return Some(result.clone());
        }
        // Whether we were woken or timed out, report what is in the slot now. A result
        // published right at the deadline is returned instead of being discarded.
        let _ = tokio::time::timeout(self.wait_timeout, notified).await;
        let slot = flight.result.lock().await;
        slot.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::{GatewayFetchSingleflight, GatewayFetchSingleflightRole, GatewayInFlightFetch};
    use serde_json::json;
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds a coordinator whose followers wait at most `timeout_ms` milliseconds.
    fn coordinator(timeout_ms: u64) -> GatewayFetchSingleflight {
        GatewayFetchSingleflight::new(Duration::from_millis(timeout_ms))
    }

    /// Acquires `key` and returns its flight, panicking unless this caller is the leader.
    async fn expect_leader(
        coordinator: &GatewayFetchSingleflight,
        key: &str,
    ) -> Arc<GatewayInFlightFetch> {
        match coordinator.acquire(key).await {
            GatewayFetchSingleflightRole::Leader(flight) => flight,
            GatewayFetchSingleflightRole::Follower(_) => panic!("expected leader role"),
        }
    }

    /// Acquires `key` and returns its flight, panicking unless this caller is a follower.
    async fn expect_follower(
        coordinator: &GatewayFetchSingleflight,
        key: &str,
    ) -> Arc<GatewayInFlightFetch> {
        match coordinator.acquire(key).await {
            GatewayFetchSingleflightRole::Follower(flight) => flight,
            GatewayFetchSingleflightRole::Leader(_) => panic!("expected follower role"),
        }
    }

    /// Fast path: the result is already published when the follower starts waiting.
    #[tokio::test(flavor = "current_thread")]
    async fn follower_receives_leader_result() {
        let key = "singleflight-test-key";
        let coordinator = coordinator(100);
        let leader_flight = expect_leader(&coordinator, key).await;
        let follower_flight = expect_follower(&coordinator, key).await;
        coordinator
            .publish_result(key, leader_flight, Ok(vec![json!({ "from": "leader" })]))
            .await;
        let waited = coordinator.wait_for_result(follower_flight).await;
        assert_eq!(waited, Some(Ok(vec![json!({ "from": "leader" })])));
    }

    /// Wake-up path: the follower starts waiting before the leader publishes.
    #[tokio::test(flavor = "current_thread")]
    async fn waiting_follower_is_woken_by_publish() {
        let key = "singleflight-test-key-wake";
        let coordinator = coordinator(1_000);
        let leader_flight = expect_leader(&coordinator, key).await;
        let follower_flight = expect_follower(&coordinator, key).await;
        let (waited, ()) = tokio::join!(
            coordinator.wait_for_result(follower_flight),
            async {
                // Give the follower a chance to subscribe before the publish happens.
                tokio::task::yield_now().await;
                coordinator
                    .publish_result(key, leader_flight, Ok(vec![json!({ "from": "leader" })]))
                    .await;
            }
        );
        assert_eq!(waited, Some(Ok(vec![json!({ "from": "leader" })])));
    }

    /// Publishing unregisters the flight, so the next caller becomes a fresh leader.
    #[tokio::test(flavor = "current_thread")]
    async fn entry_is_removed_after_publish() {
        let key = "singleflight-test-key-cleanup";
        let coordinator = coordinator(100);
        let leader_flight = expect_leader(&coordinator, key).await;
        coordinator
            .publish_result(key, leader_flight, Ok(vec![json!({ "ok": true })]))
            .await;
        assert!(
            matches!(
                coordinator.acquire(key).await,
                GatewayFetchSingleflightRole::Leader(_)
            ),
            "expected a new leader after cleanup"
        );
    }

    /// A follower whose leader never publishes gives up after the configured timeout.
    #[tokio::test(flavor = "current_thread")]
    async fn follower_times_out_without_result() {
        let key = "singleflight-test-key-timeout";
        let coordinator = coordinator(1);
        // Keep the leader's flight registered but never publish on it.
        let _leader_flight = expect_leader(&coordinator, key).await;
        let follower_flight = expect_follower(&coordinator, key).await;
        let waited = coordinator.wait_for_result(follower_flight).await;
        assert_eq!(waited, None);
    }
}