use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::broadcast;
use crate::error::BsqlError;
/// Result value broadcast from a flight's leader to its followers.
///
/// `tokio::sync::broadcast` clones the value once per receiver, so the whole
/// `Result` is wrapped in an `Arc` to make those clones a refcount bump.
type SharedResult = Arc<Result<Arc<OwnedResultSnapshot>, BsqlError>>;
/// Registry of queries currently in flight: singleflight key -> the sender
/// of that flight's broadcast channel. Shared (via `Arc`) between a
/// `Singleflight` and every `FlightLeader` it hands out.
type InFlightMap = Arc<Mutex<HashMap<u64, broadcast::Sender<SharedResult>>>>;
// NOTE(review): presumably `result` refers to data allocated in `arena`,
// which would explain why the two are kept together in one owned value —
// confirm against `bsql_driver_postgres`.
/// An owned query result bundled with the `Arena` stored alongside it, so the
/// pair can be shared behind an `Arc` with singleflight followers.
pub struct OwnedResultSnapshot {
    /// The query result.
    pub result: bsql_driver_postgres::QueryResult,
    /// Arena kept alive together with `result` (see the review note above).
    pub arena: bsql_driver_postgres::Arena,
}
/// Deduplicates concurrent identical queries.
///
/// The first caller of [`Singleflight::try_join`] for a key becomes the
/// leader and is expected to run the query and publish its result with
/// [`FlightLeader::complete`]. Callers arriving for the same key while that
/// flight is registered become followers and receive the leader's result
/// over a broadcast channel instead of running the query again.
pub struct Singleflight {
    /// Keys currently in flight, mapped to their leader's broadcast sender.
    in_flight: InFlightMap,
}
/// Outcome of [`Singleflight::try_join`].
pub enum FlightResult {
    /// No flight was registered for the key; the caller now owns it and
    /// should publish the result with [`FlightLeader::complete`].
    Leader(FlightLeader),
    /// A flight for the key is already registered; await the leader's result
    /// on this receiver. `recv` yields an error if the leader is dropped
    /// without completing.
    Follower(broadcast::Receiver<SharedResult>),
}
/// Handle owned by the caller that leads a flight.
///
/// Consumed by [`FlightLeader::complete`]. If it is dropped instead, the key
/// is unregistered and the channel's senders go away, so followers observe a
/// closed channel.
pub struct FlightLeader {
    /// Key this leader registered in the in-flight map.
    key: u64,
    /// Sender used to broadcast the single result to followers.
    tx: broadcast::Sender<SharedResult>,
    /// Map the key was registered in; `None` once the key has already been
    /// removed, so `Drop` does not remove it again.
    in_flight: Option<InFlightMap>,
}
impl FlightLeader {
pub fn complete(mut self, sf: &Singleflight, result: SharedResult) {
sf.in_flight
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&self.key);
self.in_flight = None;
let _ = self.tx.send(result);
}
}
impl Drop for FlightLeader {
fn drop(&mut self) {
if let Some(ref map) = self.in_flight {
map.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(&self.key);
}
}
}
impl Singleflight {
    /// Creates an empty registry with no queries in flight.
    pub fn new() -> Self {
        Self {
            in_flight: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Joins the flight for `key`, or starts it if none is registered.
    ///
    /// Returns [`FlightResult::Leader`] when no flight for `key` exists. The
    /// key is registered and the caller must publish the result with
    /// [`FlightLeader::complete`]; dropping the leader also unregisters it.
    /// Otherwise returns [`FlightResult::Follower`] with a receiver subscribed
    /// to the existing leader's channel.
    ///
    /// A poisoned lock is recovered rather than propagated.
    pub fn try_join(&self, key: u64) -> FlightResult {
        use std::collections::hash_map::Entry;

        let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
        match map.entry(key) {
            // Subscribing while the lock is held means the receiver exists
            // before the leader can take the same lock to unregister and send.
            Entry::Occupied(slot) => FlightResult::Follower(slot.get().subscribe()),
            Entry::Vacant(slot) => {
                // Capacity 1: a leader sends at most once (`complete` consumes it).
                let (tx, _) = broadcast::channel(1);
                slot.insert(tx.clone());
                FlightResult::Leader(FlightLeader {
                    key,
                    tx,
                    in_flight: Some(Arc::clone(&self.in_flight)),
                })
            }
        }
    }

    /// Derives the singleflight key for a query from its SQL hash and its
    /// bound parameters (order-sensitive).
    ///
    /// Each parameter is fed to the hasher as a self-delimiting record:
    /// - a tag byte, `TAG_NULL` for NULL or `TAG_VALUE` otherwise;
    /// - for values only, the encoded length followed by the binary-encoded
    ///   bytes.
    ///
    /// This keeps the hasher input prefix-free. Adjacent values such as
    /// `("ab", "c")` and `("a", "bc")` cannot produce the same input, and a
    /// NULL cannot collide with any value regardless of its encoding.
    /// Distinct queries therefore differ in hasher input; only a genuine
    /// 64-bit hash collision can map them to the same key.
    pub fn compute_key(
        sql_hash: u64,
        params: &[&(dyn bsql_driver_postgres::Encode + Sync)],
    ) -> u64 {
        use std::hash::{Hash, Hasher};

        const TAG_NULL: u8 = 0;
        const TAG_VALUE: u8 = 1;

        let mut hasher = rapidhash::quality::RapidHasher::default();
        sql_hash.hash(&mut hasher);
        // Reused across parameters to avoid one allocation per value.
        let mut scratch = Vec::with_capacity(64);
        for param in params {
            if param.is_null() {
                hasher.write_u8(TAG_NULL);
            } else {
                scratch.clear();
                param.encode_binary(&mut scratch);
                hasher.write_u8(TAG_VALUE);
                hasher.write_usize(scratch.len());
                hasher.write(&scratch);
            }
        }
        hasher.finish()
    }
}
impl Default for Singleflight {
fn default() -> Self {
Self::new()
}
}
// Unit tests for leader/follower selection, key cleanup on complete/drop,
// result broadcasting, and `compute_key` determinism and sensitivity.
#[cfg(test)]
mod tests {
    use super::*;

    // First caller for a key becomes the leader.
    #[test]
    fn singleflight_leader_when_empty() {
        let sf = Singleflight::new();
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_follower_when_in_flight() {
        let sf = Singleflight::new();
        // `_leader` (not `_`) keeps the leader alive so the key stays registered.
        let _leader = sf.try_join(42);
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Follower(_)));
    }

    // Flights are tracked per key; one in-flight key does not affect another.
    #[test]
    fn singleflight_different_keys_both_leaders() {
        let sf = Singleflight::new();
        let r1 = sf.try_join(42);
        let r2 = sf.try_join(43);
        assert!(matches!(r1, FlightResult::Leader(_)));
        assert!(matches!(r2, FlightResult::Leader(_)));
    }

    // After `complete`, the key is free again and the next caller leads.
    #[test]
    fn singleflight_complete_removes_from_map() {
        let sf = Singleflight::new();
        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("test".into()));
        leader.complete(&sf, Arc::new(Err(err)));
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    // --- compute_key: determinism and sensitivity to inputs ---

    #[test]
    fn compute_key_same_inputs_same_key() {
        let k1 = Singleflight::compute_key(123, &[]);
        let k2 = Singleflight::compute_key(123, &[]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_sql_hash_different_key() {
        let k1 = Singleflight::compute_key(123, &[]);
        let k2 = Singleflight::compute_key(456, &[]);
        assert_ne!(k1, k2);
    }

    // Keys depend on parameter values, not on the identity of the bound variables.
    #[test]
    fn compute_key_same_params_same_key() {
        let a = 42i32;
        let b = 42i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_params_different_key() {
        let a = 42i32;
        let b = 99i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_different_sql_same_params_different_key() {
        let a = 42i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(200, &[&a]);
        assert_ne!(k1, k2);
    }

    // NULL must hash differently from any concrete value.
    #[test]
    fn compute_key_null_param_handling() {
        let null_val: Option<i32> = None;
        let some_val: Option<i32> = Some(42);
        let k1 = Singleflight::compute_key(100, &[&null_val]);
        let k2 = Singleflight::compute_key(100, &[&some_val]);
        assert_ne!(k1, k2, "NULL and Some(42) should produce different keys");
    }

    #[test]
    fn compute_key_two_nulls_same_key() {
        let a: Option<i32> = None;
        let b: Option<i32> = None;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    // Mixed parameter types hash deterministically.
    #[test]
    fn compute_key_multiple_params() {
        let a = 1i32;
        let b = "hello";
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&a, &b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_param_order_matters() {
        let a = 1i32;
        let b = 2i32;
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&b, &a]);
        assert_ne!(k1, k2);
    }

    // --- broadcasting results to followers ---

    #[tokio::test]
    async fn leader_complete_broadcasts_to_follower() {
        let sf = Singleflight::new();
        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        let mut rx = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower"),
        };
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("test".into()));
        leader.complete(&sf, Arc::new(Err(err)));
        // The follower subscribed before the send, so it receives the value.
        let received = rx.recv().await.unwrap();
        assert!(received.is_err());
    }

    #[tokio::test]
    async fn multiple_followers_receive_result() {
        let sf = Singleflight::new();
        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        let mut rx1 = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower 1"),
        };
        let mut rx2 = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower 2"),
        };
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("done".into()));
        leader.complete(&sf, Arc::new(Err(err)));
        let r1 = rx1.recv().await.unwrap();
        let r2 = rx2.recv().await.unwrap();
        assert!(r1.is_err());
        assert!(r2.is_err());
    }

    // --- cleanup when the leader is dropped without completing ---

    #[test]
    fn drop_leader_without_complete_cleans_up_map() {
        let sf = Singleflight::new();
        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        drop(leader);
        let result = sf.try_join(42);
        assert!(
            matches!(result, FlightResult::Leader(_)),
            "key should be removed from map after leader drop without complete"
        );
    }

    // Ten tasks share five keys (two per key). Each key yields at least one
    // leader; the split between leaders and followers depends on scheduling.
    #[tokio::test]
    async fn concurrent_stress_test() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tokio::task;
        let sf = Arc::new(Singleflight::new());
        let leader_count = Arc::new(AtomicUsize::new(0));
        let follower_count = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for i in 0..10 {
            let sf = Arc::clone(&sf);
            let leaders = Arc::clone(&leader_count);
            let followers = Arc::clone(&follower_count);
            let key = (i % 5) as u64;
            handles.push(task::spawn(async move {
                match sf.try_join(key) {
                    FlightResult::Leader(leader) => {
                        leaders.fetch_add(1, Ordering::Relaxed);
                        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool(
                            "stress".into(),
                        ));
                        leader.complete(&sf, Arc::new(Err(err)));
                    }
                    FlightResult::Follower(_rx) => {
                        followers.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = leader_count.load(Ordering::Relaxed) + follower_count.load(Ordering::Relaxed);
        assert_eq!(total, 10, "all 10 tasks should participate");
        assert!(
            leader_count.load(Ordering::Relaxed) >= 5,
            "should have at least 5 leaders (one per key)"
        );
    }

    #[test]
    fn singleflight_default() {
        let sf = Singleflight::default();
        let result = sf.try_join(1);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    // Compile-time checks: these only need to type-check.
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn singleflight_is_send_and_sync() {
        _assert_send::<Singleflight>();
        _assert_sync::<Singleflight>();
    }

    #[test]
    fn compute_key_string_params() {
        let a = "hello";
        let b = "world";
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&a, &b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_empty_params_consistent() {
        let k1 = Singleflight::compute_key(0, &[]);
        let k2 = Singleflight::compute_key(0, &[]);
        assert_eq!(k1, k2);
    }

    // Sending with no receivers must not panic, and the key must still be freed.
    #[test]
    fn leader_complete_with_no_followers() {
        let sf = Singleflight::new();
        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("solo".into()));
        leader.complete(&sf, Arc::new(Err(err)));
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    // Dropping the leader removes the map's sender and drops the leader's own
    // sender, so the channel closes and the follower's `recv` returns an error.
    #[tokio::test]
    async fn follower_gets_error_when_leader_dropped_without_complete() {
        let sf = Singleflight::new();
        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        let mut rx = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower"),
        };
        drop(leader);
        let result = rx.recv().await;
        assert!(
            result.is_err(),
            "follower should get RecvError when leader is dropped without complete"
        );
    }

    // A dropped leader does not poison the key: a new flight can start and
    // serve its own followers normally.
    #[tokio::test]
    async fn new_leader_succeeds_after_previous_leader_dropped() {
        let sf = Arc::new(Singleflight::new());
        let leader1 = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        drop(leader1);
        let leader2 = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected new leader after previous leader drop"),
        };
        let mut rx = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower for second leader"),
        };
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("retry".into()));
        leader2.complete(&sf, Arc::new(Err(err)));
        let received = rx.recv().await.unwrap();
        assert!(received.is_err());
    }
}