use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::broadcast;
use tokio_postgres::Row;
pub(crate) struct Singleflight {
in_flight: Mutex<HashMap<u64, broadcast::Sender<Arc<[Row]>>>>,
}
pub(crate) enum FlightStatus {
Leader,
Follower(broadcast::Receiver<Arc<[Row]>>),
}
impl Singleflight {
pub(crate) fn new() -> Self {
Self {
in_flight: Mutex::new(HashMap::new()),
}
}
pub(crate) fn try_join(&self, key: u64) -> FlightStatus {
let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
if let Some(tx) = map.get(&key) {
FlightStatus::Follower(tx.subscribe())
} else {
let (tx, _rx) = broadcast::channel(1);
map.insert(key, tx);
FlightStatus::Leader
}
}
pub(crate) fn complete(&self, key: u64, rows: Arc<[Row]>) {
let tx = {
let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
map.remove(&key)
};
if let Some(tx) = tx {
let _ = tx.send(rows);
}
}
pub(crate) fn abandon(&self, key: u64) {
let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
map.remove(&key);
}
}
/// Derives the singleflight key for a SQL string.
///
/// `sql` is passed unchanged to `crate::rapid_hash_str`; this module's tests expect the
/// resulting key to be case- and whitespace-sensitive, so textually different queries do
/// not share a flight. Only the 64-bit hash is kept as the map key, so two distinct
/// queries whose hashes collide would share a flight and one would receive the other's rows.
pub(crate) fn sql_key(sql: &str) -> u64 {
crate::rapid_hash_str(sql)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty result set; each call allocates a distinct `Arc`.
    fn empty_rows() -> Arc<[Row]> {
        Arc::from(Vec::<Row>::new())
    }

    /// Joins `key`, expecting a flight to already be open, and returns the follower receiver.
    fn join_as_follower(sf: &Singleflight, key: u64) -> broadcast::Receiver<Arc<[Row]>> {
        match sf.try_join(key) {
            FlightStatus::Follower(rx) => rx,
            FlightStatus::Leader => panic!("expected follower"),
        }
    }

    #[test]
    fn new_creates_empty_map() {
        let sf = Singleflight::new();
        let map = sf.in_flight.lock().unwrap();
        assert!(map.is_empty());
    }

    #[test]
    fn first_caller_is_leader() {
        let sf = Singleflight::new();
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn second_caller_is_follower() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        assert!(matches!(sf.try_join(42), FlightStatus::Follower(_)));
    }

    #[test]
    fn different_keys_are_independent() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        assert!(matches!(sf.try_join(99), FlightStatus::Leader));
    }

    #[test]
    fn complete_removes_entry() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        sf.complete(42, empty_rows());
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn abandon_removes_entry() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        sf.abandon(42);
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn sql_key_deterministic() {
        let a = sql_key("SELECT id FROM users");
        let b = sql_key("SELECT id FROM users");
        assert_eq!(a, b);
    }

    #[test]
    fn sql_key_different_sql_different_key() {
        let a = sql_key("SELECT id FROM users");
        let b = sql_key("SELECT name FROM users");
        assert_ne!(a, b);
    }

    #[test]
    fn complete_broadcasts_to_follower() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        let mut rx = join_as_follower(&sf, 42);
        let rows = empty_rows();
        sf.complete(42, Arc::clone(&rows));
        let received = rx.try_recv().expect("follower should receive broadcast");
        // The follower must get the leader's exact allocation, not merely some value.
        assert!(Arc::ptr_eq(&received, &rows));
    }

    #[test]
    fn abandon_closes_follower_channel() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        let mut rx = join_as_follower(&sf, 42);
        sf.abandon(42);
        // `is_err()` alone would also accept `Empty`, which does not prove the channel closed.
        assert!(
            matches!(rx.try_recv(), Err(broadcast::error::TryRecvError::Closed)),
            "follower channel should be closed after abandon"
        );
    }

    #[test]
    fn complete_nonexistent_key_is_noop() {
        let sf = Singleflight::new();
        sf.complete(999, empty_rows());
    }

    #[test]
    fn abandon_nonexistent_key_is_noop() {
        let sf = Singleflight::new();
        sf.abandon(999);
    }

    #[test]
    fn complete_after_followers_dropped_is_noop() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        drop(join_as_follower(&sf, 42));
        // The send error (no receivers left) must be tolerated, and the key freed.
        sf.complete(42, empty_rows());
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn multiple_followers_all_receive() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        let mut rx1 = join_as_follower(&sf, 42);
        let mut rx2 = join_as_follower(&sf, 42);
        let rows = empty_rows();
        sf.complete(42, Arc::clone(&rows));
        let got1 = rx1.try_recv().expect("follower 1 should receive");
        let got2 = rx2.try_recv().expect("follower 2 should receive");
        assert!(Arc::ptr_eq(&got1, &rows));
        assert!(Arc::ptr_eq(&got2, &rows));
    }

    #[test]
    fn reuse_key_after_complete() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        sf.complete(42, empty_rows());
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn reuse_key_after_abandon() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        sf.abandon(42);
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn sql_key_case_sensitive() {
        let a = sql_key("SELECT id FROM users");
        let b = sql_key("select id from users");
        assert_ne!(a, b, "sql_key should be case-sensitive");
    }

    #[test]
    fn sql_key_whitespace_sensitive() {
        // Escapes make the whitespace difference explicit; the previous version compared
        // two identical strings and could never pass.
        let base = sql_key("SELECT id FROM users");
        assert_ne!(
            base,
            sql_key("SELECT id\tFROM users"),
            "sql_key should be whitespace-sensitive"
        );
        assert_ne!(
            base,
            sql_key("SELECT id FROM users\n"),
            "sql_key should be whitespace-sensitive"
        );
    }
}