use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use tokio::sync::broadcast;
use tokio_postgres::Row;
/// Deduplicates concurrent computations of the same result set.
///
/// The first caller for a key becomes the leader (see [`Singleflight::try_join`]) and is
/// expected to produce the rows and then call [`Singleflight::complete`] or
/// [`Singleflight::abandon`]. Callers that arrive while that flight is still registered
/// become followers and receive the leader's rows over a broadcast channel.
#[derive(Debug)]
pub(crate) struct Singleflight {
    /// Registered flights, keyed by the caller-supplied `u64` key (e.g. the output of
    /// [`sql_key`]). Each entry holds the sender that the leader's result is broadcast on.
    in_flight: Mutex<HashMap<u64, broadcast::Sender<Arc<Vec<Row>>>>>,
}
/// Outcome of [`Singleflight::try_join`].
#[derive(Debug)]
pub(crate) enum FlightStatus {
    /// No flight was registered for the key. The caller now owns a new one and must
    /// finish it with [`Singleflight::complete`] or [`Singleflight::abandon`]; until then
    /// the entry stays in the map.
    Leader,
    /// A flight is already registered. The receiver yields the leader's rows once
    /// `complete` is called. If the leader calls `abandon` instead, the sender is dropped
    /// and the receiver reports a closed channel.
    Follower(broadcast::Receiver<Arc<Vec<Row>>>),
}
impl Singleflight {
pub(crate) fn new() -> Self {
Self {
in_flight: Mutex::new(HashMap::new()),
}
}
pub(crate) fn try_join(&self, key: u64) -> FlightStatus {
let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
if let Some(tx) = map.get(&key) {
FlightStatus::Follower(tx.subscribe())
} else {
let (tx, _rx) = broadcast::channel(2);
map.insert(key, tx);
FlightStatus::Leader
}
}
pub(crate) fn complete(&self, key: u64, rows: Arc<Vec<Row>>) {
let tx = {
let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
map.remove(&key)
};
if let Some(tx) = tx {
let _ = tx.send(rows);
}
}
pub(crate) fn abandon(&self, key: u64) {
let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
map.remove(&key);
}
}
/// Derives a singleflight key for `sql` by hashing its text with a default-constructed
/// rapidhash quality hasher.
///
/// Identical strings produce the same key. Distinct strings produce distinct keys only
/// with high probability: a 64-bit collision would make two different queries share one
/// flight (and thus one result).
pub(crate) fn sql_key(sql: &str) -> u64 {
    let mut state = rapidhash::quality::RapidHasher::default();
    Hash::hash(sql, &mut state);
    state.finish()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Joins `key` and returns the follower receiver, panicking if the caller became leader.
    fn join_as_follower(sf: &Singleflight, key: u64) -> broadcast::Receiver<Arc<Vec<Row>>> {
        match sf.try_join(key) {
            FlightStatus::Follower(rx) => rx,
            FlightStatus::Leader => panic!("expected an in-progress flight for key {key}"),
        }
    }

    #[test]
    fn new_creates_empty_map() {
        let sf = Singleflight::new();
        let map = sf.in_flight.lock().unwrap();
        assert!(map.is_empty());
    }

    #[test]
    fn first_caller_is_leader() {
        let sf = Singleflight::new();
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn second_caller_is_follower() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        assert!(matches!(sf.try_join(42), FlightStatus::Follower(_)));
    }

    #[test]
    fn different_keys_are_independent() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        assert!(matches!(sf.try_join(99), FlightStatus::Leader));
    }

    #[test]
    fn complete_removes_entry() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        sf.complete(42, Arc::new(Vec::new()));
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn abandon_removes_entry() {
        let sf = Singleflight::new();
        let _ = sf.try_join(42);
        sf.abandon(42);
        assert!(matches!(sf.try_join(42), FlightStatus::Leader));
    }

    #[test]
    fn followers_receive_leader_rows() {
        let sf = Singleflight::new();
        assert!(matches!(sf.try_join(7), FlightStatus::Leader));
        let mut first = join_as_follower(&sf, 7);
        let mut second = join_as_follower(&sf, 7);

        let rows = Arc::new(Vec::new());
        sf.complete(7, Arc::clone(&rows));

        // Followers must see the very same allocation the leader broadcast.
        let got_first = first.try_recv().expect("first follower should receive the rows");
        let got_second = second.try_recv().expect("second follower should receive the rows");
        assert!(Arc::ptr_eq(&got_first, &rows));
        assert!(Arc::ptr_eq(&got_second, &rows));
    }

    #[test]
    fn abandon_closes_follower_channels() {
        let sf = Singleflight::new();
        let _ = sf.try_join(7);
        let mut rx = join_as_follower(&sf, 7);
        sf.abandon(7);
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Closed)
        ));
    }

    #[test]
    fn unknown_key_operations_are_noops() {
        let sf = Singleflight::new();
        sf.complete(1, Arc::new(Vec::new()));
        sf.abandon(1);
        assert!(sf.in_flight.lock().unwrap().is_empty());
    }

    #[test]
    fn sql_key_deterministic() {
        let a = sql_key("SELECT id FROM users");
        let b = sql_key("SELECT id FROM users");
        assert_eq!(a, b);
    }

    #[test]
    fn sql_key_different_sql_different_key() {
        let a = sql_key("SELECT id FROM users");
        let b = sql_key("SELECT name FROM users");
        assert_ne!(a, b);
    }
}