use once_cell::sync::Lazy;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Notify};
/// Registry of fetches currently in progress, keyed by cache key.
///
/// `acquire_fetch_singleflight` inserts an entry when a caller becomes the leader for a key,
/// and `publish_fetch_singleflight_result` removes it once that leader's result is stored.
static IN_FLIGHT_FETCHES: Lazy<Mutex<HashMap<String, Arc<InFlightFetch>>>> = Lazy::new(|| {
    let registry: HashMap<String, Arc<InFlightFetch>> = HashMap::new();
    Mutex::new(registry)
});
/// How long a follower waits for its leader's result before giving up.
///
/// Evaluated once, on first use, from `ATHENA_FETCH_SINGLEFLIGHT_WAIT_TIMEOUT_MS`
/// (milliseconds). A missing, unparsable, or zero value falls back to 5000 ms.
static SINGLEFLIGHT_WAIT_TIMEOUT: Lazy<Duration> = Lazy::new(|| {
    const DEFAULT_TIMEOUT_MS: u64 = 5000;

    let configured = std::env::var("ATHENA_FETCH_SINGLEFLIGHT_WAIT_TIMEOUT_MS")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok());

    // Zero would make every follower give up instantly, so it is treated as "unset".
    let millis = match configured {
        Some(ms) if ms > 0 => ms,
        _ => DEFAULT_TIMEOUT_MS,
    };
    Duration::from_millis(millis)
});
/// Shared state for one in-flight fetch: the slot the leader publishes into and the
/// notifier followers wait on.
pub(crate) struct InFlightFetch {
    /// The leader's outcome; `None` until `publish_fetch_singleflight_result` stores it.
    result: Mutex<Option<Result<Vec<Value>, String>>>,
    /// Signalled via `notify_waiters` by `publish_fetch_singleflight_result` after `result`
    /// has been filled in, waking followers blocked in `wait_for_fetch_singleflight_result`.
    notify: Notify,
}
impl InFlightFetch {
    /// Creates an empty flight: no result stored yet and a fresh notifier with no waiters.
    fn new() -> Self {
        let result = Mutex::new(None);
        let notify = Notify::new();
        Self { result, notify }
    }
}
/// Outcome of `acquire_fetch_singleflight` for a cache key.
pub(crate) enum SingleflightRole {
    /// No fetch was registered for the key; a new flight was inserted for this caller.
    /// Only `publish_fetch_singleflight_result` fills the flight's result, so the leader
    /// must publish or followers will time out.
    Leader(Arc<InFlightFetch>),
    /// A flight for the key already existed; pass it to
    /// `wait_for_fetch_singleflight_result` to receive the leader's result.
    Follower(Arc<InFlightFetch>),
}
/// Joins the in-flight fetch for `cache_key`, or registers a new one.
///
/// Returns `Follower` with the existing flight when one is already registered for the key;
/// otherwise inserts a fresh flight and returns `Leader`. The registry lock is held across
/// the lookup and the insert, so two concurrent callers cannot both become leader for the
/// same key.
///
/// NOTE(review): entries are only removed by `publish_fetch_singleflight_result`. If a
/// leader never publishes (e.g. its task panics or is cancelled), later callers for the
/// same key keep joining the abandoned flight and time out. Consider a drop guard.
pub(crate) async fn acquire_fetch_singleflight(cache_key: &str) -> SingleflightRole {
    let mut registry = IN_FLIGHT_FETCHES.lock().await;

    let existing = registry.get(cache_key).cloned();
    match existing {
        Some(current) => SingleflightRole::Follower(current),
        None => {
            let fresh = Arc::new(InFlightFetch::new());
            registry.insert(cache_key.to_string(), Arc::clone(&fresh));
            SingleflightRole::Leader(fresh)
        }
    }
}
/// Publishes the leader's `result` for `flight` and wakes every waiting follower.
///
/// Steps, in order:
/// 1. Stores the result, so a caller that joins from now on reads it without waiting.
/// 2. Removes the registry entry for `cache_key`, but only if it still points at this
///    flight, so a different flight registered under the same key is left alone.
/// 3. Releases the registry lock and calls `notify_waiters`.
pub(crate) async fn publish_fetch_singleflight_result(
    cache_key: &str,
    flight: Arc<InFlightFetch>,
    result: Result<Vec<Value>, String>,
) {
    *flight.result.lock().await = Some(result);

    let mut registry = IN_FLIGHT_FETCHES.lock().await;
    let still_ours = registry
        .get(cache_key)
        .is_some_and(|current| Arc::ptr_eq(current, &flight));
    if still_ours {
        registry.remove(cache_key);
    }
    drop(registry);

    flight.notify.notify_waiters();
}
/// Waits for the leader of `flight` to publish its result.
///
/// Returns `Some(result)` immediately if the leader has already published. Otherwise it
/// waits up to `SINGLEFLIGHT_WAIT_TIMEOUT` for the publish notification and then returns
/// whatever is in the slot: `Some(result)` if the leader published, or `None` if nothing
/// was published before the timeout.
pub(crate) async fn wait_for_fetch_singleflight_result(
    flight: Arc<InFlightFetch>,
) -> Option<Result<Vec<Value>, String>> {
    // Register interest BEFORE checking the slot. `notify_waiters` stores no permit, so a
    // `Notified` created after the check would miss a publish landing between the check
    // and its creation. The follower would then sleep for the whole timeout. A `Notified`
    // receives `notify_waiters` wakeups from the moment it is created, even before it is
    // polled.
    let notified = flight.notify.notified();

    {
        let shared_result = flight.result.lock().await;
        if let Some(result) = shared_result.as_ref() {
            return Some(result.clone());
        }
    }

    // Whether we were woken or timed out, re-read the slot. The leader may have published
    // right as the deadline expired, and an available result beats reporting "none".
    let _ = tokio::time::timeout(*SINGLEFLIGHT_WAIT_TIMEOUT, notified).await;

    let shared_result = flight.result.lock().await;
    shared_result.clone()
}
#[cfg(test)]
mod tests {
    use super::{
        InFlightFetch, SingleflightRole, acquire_fetch_singleflight,
        publish_fetch_singleflight_result, wait_for_fetch_singleflight_result,
    };
    use serde_json::json;
    use std::sync::Arc;

    /// Acquires `key` and panics with `context` unless this caller became the leader.
    async fn expect_leader(key: &str, context: &str) -> Arc<InFlightFetch> {
        match acquire_fetch_singleflight(key).await {
            SingleflightRole::Leader(flight) => flight,
            SingleflightRole::Follower(_) => panic!("{context}"),
        }
    }

    /// Acquires `key` and panics with `context` unless this caller became a follower.
    async fn expect_follower(key: &str, context: &str) -> Arc<InFlightFetch> {
        match acquire_fetch_singleflight(key).await {
            SingleflightRole::Follower(flight) => flight,
            SingleflightRole::Leader(_) => panic!("{context}"),
        }
    }

    #[tokio::test]
    async fn follower_receives_leader_result() {
        let key = "singleflight-test-key";
        let leader = expect_leader(key, "expected leader role").await;
        let follower = expect_follower(key, "expected follower role").await;

        let publisher = tokio::spawn(async move {
            publish_fetch_singleflight_result(key, leader, Ok(vec![json!({ "from": "leader" })]))
                .await;
        });

        let waited = wait_for_fetch_singleflight_result(follower).await;
        let _ = publisher.await;

        assert_eq!(waited, Some(Ok(vec![json!({ "from": "leader" })])));
    }

    #[tokio::test]
    async fn entry_is_removed_after_publish() {
        let key = "singleflight-test-key-cleanup";
        let leader = expect_leader(key, "expected leader role").await;

        publish_fetch_singleflight_result(key, leader, Ok(vec![json!({ "ok": true })])).await;

        let _ = expect_leader(key, "expected a new leader after cleanup").await;
    }
}