athena_rs 3.3.0

Database gateway API
Documentation
use once_cell::sync::Lazy;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Notify};

/// Process-wide registry of fetches that are currently in progress, keyed by
/// cache key. Lazily initialised to an empty map on first access.
static IN_FLIGHT_FETCHES: Lazy<Mutex<HashMap<String, Arc<InFlightFetch>>>> =
    Lazy::new(|| Mutex::new(HashMap::default()));

/// Maximum time a follower waits for a leader's notification.
///
/// Taken from the `ATHENA_FETCH_SINGLEFLIGHT_WAIT_TIMEOUT_MS` environment
/// variable (milliseconds) the first time it is needed. A missing, unparsable
/// or zero value falls back to 5000 ms.
static SINGLEFLIGHT_WAIT_TIMEOUT: Lazy<Duration> = Lazy::new(|| {
    const DEFAULT_TIMEOUT_MS: u64 = 5000;

    let configured_ms = match std::env::var("ATHENA_FETCH_SINGLEFLIGHT_WAIT_TIMEOUT_MS") {
        // Zero is rejected so a misconfiguration cannot disable waiting entirely.
        Ok(raw) => raw.parse::<u64>().ok().filter(|ms| *ms > 0),
        Err(_) => None,
    };

    Duration::from_millis(configured_ms.unwrap_or(DEFAULT_TIMEOUT_MS))
});

/// Shared state for one in-progress fetch: the slot that receives the outcome
/// and the signal used to wake tasks waiting on it.
pub(crate) struct InFlightFetch {
    /// `None` until an outcome is published; afterwards holds the fetched rows
    /// or an error message.
    result: Mutex<Option<Result<Vec<Value>, String>>>,
    /// Signalled with `notify_waiters` once `result` has been stored.
    notify: Notify,
}

impl InFlightFetch {
    /// Creates a flight whose result slot is still empty.
    fn new() -> Self {
        let result = Mutex::new(None);
        let notify = Notify::new();
        Self { result, notify }
    }
}

/// Role handed to a caller by `acquire_fetch_singleflight`.
pub(crate) enum SingleflightRole {
    /// No flight was registered for the key; this caller registered the
    /// carried flight and is expected to publish its outcome.
    Leader(Arc<InFlightFetch>),
    /// A flight was already registered for the key; carries that same flight
    /// so the caller can wait on it.
    Follower(Arc<InFlightFetch>),
}

/// Joins or starts the flight registered under `cache_key`.
///
/// Returns [`SingleflightRole::Follower`] with the existing flight when one is
/// already registered for the key. Otherwise registers a fresh flight and
/// returns it as [`SingleflightRole::Leader`]. The lookup and the insertion
/// happen under a single hold of the registry lock.
pub(crate) async fn acquire_fetch_singleflight(cache_key: &str) -> SingleflightRole {
    let mut registry = IN_FLIGHT_FETCHES.lock().await;

    let already_running = registry.get(cache_key).cloned();
    match already_running {
        Some(existing) => SingleflightRole::Follower(existing),
        None => {
            let flight = Arc::new(InFlightFetch::new());
            registry.insert(cache_key.to_string(), Arc::clone(&flight));
            SingleflightRole::Leader(flight)
        }
    }
}

/// Stores `result` on `flight`, deregisters the flight, and wakes waiters.
///
/// The steps run in a fixed order:
/// 1. the outcome is written into the flight's result slot;
/// 2. the registry entry for `cache_key` is removed, but only if it still
///    points at this exact flight (pointer identity), so a different flight
///    registered under the same key is left alone;
/// 3. every task currently waiting on the flight's `Notify` is woken.
pub(crate) async fn publish_fetch_singleflight_result(
    cache_key: &str,
    flight: Arc<InFlightFetch>,
    result: Result<Vec<Value>, String>,
) {
    // The guard is a temporary, so the slot lock is released at the end of
    // this statement, before the registry lock is taken.
    *flight.result.lock().await = Some(result);

    {
        let mut registry = IN_FLIGHT_FETCHES.lock().await;
        let still_registered = registry
            .get(cache_key)
            .is_some_and(|current| Arc::ptr_eq(current, &flight));
        if still_registered {
            registry.remove(cache_key);
        }
    }

    flight.notify.notify_waiters();
}

/// Waits for the outcome of `flight` and returns a clone of it.
///
/// Returns `Some(outcome)` as soon as an outcome is stored on the flight, and
/// `None` when the wait exceeds `SINGLEFLIGHT_WAIT_TIMEOUT` and the result
/// slot is still empty.
///
/// The `Notified` future is created *before* the result slot is inspected.
/// `Notify::notify_waiters` only wakes futures that already exist, so creating
/// it after the check would leave a window in which a publish could slip in
/// unnoticed and the caller would sleep for the full timeout despite the
/// outcome being available.
pub(crate) async fn wait_for_fetch_singleflight_result(
    flight: Arc<InFlightFetch>,
) -> Option<Result<Vec<Value>, String>> {
    // Register interest first; from this point on a `notify_waiters` call is
    // guaranteed to wake this future even though it has not been polled yet.
    let notified = flight.notify.notified();

    {
        let shared_result = flight.result.lock().await;
        if let Some(result) = &*shared_result {
            return Some(result.clone());
        }
    }

    // Whether the wait ended by notification or by timeout, the result slot is
    // the source of truth, so it is re-read in both cases rather than
    // discarding an outcome that arrived right at the deadline.
    let _ = tokio::time::timeout(*SINGLEFLIGHT_WAIT_TIMEOUT, notified).await;

    let shared_result = flight.result.lock().await;
    shared_result.clone()
}

#[cfg(test)]
mod tests {
    use super::{
        SingleflightRole, acquire_fetch_singleflight, publish_fetch_singleflight_result,
        wait_for_fetch_singleflight_result,
    };
    use serde_json::json;

    /// A follower that joined before the leader published receives the
    /// leader's outcome.
    #[tokio::test]
    async fn follower_receives_leader_result() {
        let key = "singleflight-test-key";

        let SingleflightRole::Leader(leader) = acquire_fetch_singleflight(key).await else {
            panic!("expected leader role");
        };
        let SingleflightRole::Follower(follower) = acquire_fetch_singleflight(key).await else {
            panic!("expected follower role");
        };

        let publisher = tokio::spawn(async move {
            let payload = Ok(vec![json!({ "from": "leader" })]);
            publish_fetch_singleflight_result(key, leader, payload).await;
        });

        let waited = wait_for_fetch_singleflight_result(follower).await;
        let _ = publisher.await;

        assert_eq!(waited, Some(Ok(vec![json!({ "from": "leader" })])));
    }

    /// Publishing removes the registry entry, so the next caller for the same
    /// key becomes a leader again.
    #[tokio::test]
    async fn entry_is_removed_after_publish() {
        let key = "singleflight-test-key-cleanup";

        let SingleflightRole::Leader(leader) = acquire_fetch_singleflight(key).await else {
            panic!("expected leader role");
        };

        let payload = Ok(vec![json!({ "ok": true })]);
        publish_fetch_singleflight_result(key, leader, payload).await;

        let next_role = acquire_fetch_singleflight(key).await;
        assert!(
            matches!(next_role, SingleflightRole::Leader(_)),
            "expected a new leader after cleanup"
        );
    }
}