use std::collections::HashMap;
use std::sync::{Arc, Condvar, Mutex};
use crate::error::BsqlError;
/// Result published by a flight leader and shared, behind an `Arc`, with every follower.
type SharedResult = Arc<Result<Arc<OwnedResultSnapshot>, BsqlError>>;

/// Progress of one flight. Always read and written under the mutex in `FlightState`.
enum FlightOutcome {
    /// The leader has neither completed nor been dropped yet; followers keep waiting.
    Pending,
    /// The leader called `FlightLeader::complete` with this result.
    Ready(SharedResult),
    /// The leader was dropped without completing, so no result will ever be published.
    Abandoned,
}

/// Rendezvous point between the leader of a flight and the followers of the same key.
pub struct FlightState {
    /// Current progress of the flight.
    outcome: Mutex<FlightOutcome>,
    /// Signalled whenever `outcome` leaves `Pending`.
    condvar: Condvar,
}

/// Registry of the flights currently in progress, indexed by the key given to `try_join`.
type InFlightMap = Arc<Mutex<HashMap<u64, Arc<FlightState>>>>;

/// Owned query output carried inside a `SharedResult`.
///
/// NOTE(review): presumably `arena` backs data referenced by `result` — confirm against
/// the driver crate.
pub struct OwnedResultSnapshot {
    pub result: bsql_driver_postgres::QueryResult,
    pub arena: bsql_driver_postgres::Arena,
}

/// Coalesces concurrent work on the same key: the first caller of `try_join` for a key
/// becomes the leader, later callers become followers that wait for the leader's result.
pub struct Singleflight {
    in_flight: InFlightMap,
}

/// Role assigned to a caller by `Singleflight::try_join`.
pub enum FlightResult {
    /// No flight existed for the key; the caller must do the work and publish the result.
    Leader(FlightLeader),
    /// A flight already exists; pass the state to `Singleflight::wait_for_result`.
    Follower(Arc<FlightState>),
}

/// Handle held by the leader of a flight.
///
/// Dropping it without calling `complete` unregisters the flight and releases the
/// followers, which then observe `None` from `Singleflight::wait_for_result`.
pub struct FlightLeader {
    key: u64,
    state: Arc<FlightState>,
    /// Registry this leader is registered in; `None` once `complete` has unregistered it,
    /// which turns the `Drop` impl into a no-op.
    in_flight: Option<InFlightMap>,
}

impl FlightLeader {
    /// Publishes `result` to every follower and unregisters the flight.
    ///
    /// The key is removed from the registry this leader was created by. `sf` is only
    /// consulted as a fallback when the leader no longer holds a registry handle, so
    /// passing a different `Singleflight` can neither leak this flight's entry nor evict
    /// an unrelated flight from the other registry.
    pub fn complete(mut self, sf: &Singleflight, result: SharedResult) {
        // Taking the handle disarms `Drop`, which would otherwise mark the flight abandoned.
        let own_map = self.in_flight.take();
        let map = own_map.as_ref().unwrap_or(&sf.in_flight);
        map.lock()
            .unwrap_or_else(|e| e.into_inner())
            .remove(&self.key);
        // The state changes under the lock and waiters are notified afterwards, so a
        // follower cannot miss the wake-up.
        *self.state.outcome.lock().unwrap_or_else(|e| e.into_inner()) =
            FlightOutcome::Ready(result);
        self.state.condvar.notify_all();
    }
}

impl Drop for FlightLeader {
    /// Cleans up a flight whose leader went away without calling `complete`.
    fn drop(&mut self) {
        if let Some(map) = self.in_flight.take() {
            map.lock()
                .unwrap_or_else(|e| e.into_inner())
                .remove(&self.key);
            {
                let mut outcome = self.state.outcome.lock().unwrap_or_else(|e| e.into_inner());
                // Without an observable state change the followers would wake up, still
                // see a pending flight and go back to sleep forever.
                if matches!(*outcome, FlightOutcome::Pending) {
                    *outcome = FlightOutcome::Abandoned;
                }
            }
            self.state.condvar.notify_all();
        }
    }
}

impl Singleflight {
    /// Creates a coalescer with no flights in progress.
    pub fn new() -> Self {
        Self {
            in_flight: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Joins the flight for `key`, registering a new one when none is in progress.
    ///
    /// Returns `FlightResult::Leader` to the caller that registered the flight and
    /// `FlightResult::Follower` to every caller that found it already registered.
    pub fn try_join(&self, key: u64) -> FlightResult {
        let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
        if let Some(state) = map.get(&key) {
            FlightResult::Follower(Arc::clone(state))
        } else {
            let state = Arc::new(FlightState {
                outcome: Mutex::new(FlightOutcome::Pending),
                condvar: Condvar::new(),
            });
            map.insert(key, Arc::clone(&state));
            FlightResult::Leader(FlightLeader {
                key,
                state,
                in_flight: Some(Arc::clone(&self.in_flight)),
            })
        }
    }

    /// Blocks the calling thread until the flight behind `state` is resolved.
    ///
    /// Returns `Some(result)` when the leader completed, and `None` when the leader was
    /// dropped without completing.
    pub fn wait_for_result(state: &FlightState) -> Option<SharedResult> {
        let guard = state.outcome.lock().unwrap_or_else(|e| e.into_inner());
        let guard = state
            .condvar
            .wait_while(guard, |outcome| matches!(*outcome, FlightOutcome::Pending))
            .unwrap_or_else(|e| e.into_inner());
        match &*guard {
            FlightOutcome::Ready(result) => Some(Arc::clone(result)),
            // `wait_while` only returns once the flight is no longer pending; `Pending`
            // is listed solely to keep the match exhaustive without a panic path.
            FlightOutcome::Pending | FlightOutcome::Abandoned => None,
        }
    }

    /// Derives the flight key for a statement hash and its bound parameters.
    ///
    /// Every parameter contributes a tag byte and, when it is not NULL, the length of its
    /// binary encoding followed by the encoding itself. The framing keeps parameter lists
    /// such as `["ab", "c"]` and `["a", "bc"]`, or NULL and a value encoded as the single
    /// byte `0xFF`, from feeding identical byte streams to the hasher.
    pub fn compute_key(
        sql_hash: u64,
        params: &[&(dyn bsql_driver_postgres::Encode + Sync)],
    ) -> u64 {
        use std::hash::{Hash, Hasher};

        const NULL_TAG: u8 = 0xFF;
        const VALUE_TAG: u8 = 0x00;

        let mut hasher = rapidhash::quality::RapidHasher::default();
        sql_hash.hash(&mut hasher);
        // One buffer reused for every parameter to avoid an allocation per value.
        let mut scratch: Vec<u8> = Vec::with_capacity(64);
        for param in params {
            if param.is_null() {
                hasher.write_u8(NULL_TAG);
            } else {
                scratch.clear();
                param.encode_binary(&mut scratch);
                hasher.write_u8(VALUE_TAG);
                hasher.write_usize(scratch.len());
                hasher.write(&scratch);
            }
        }
        hasher.finish()
    }
}
impl Default for Singleflight {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared error payload the tests publish through `complete`.
    fn pool_failure(msg: &'static str) -> SharedResult {
        let driver_err = bsql_driver_postgres::DriverError::Pool(msg.into());
        Arc::new(Err(BsqlError::from(driver_err)))
    }

    /// Unwraps the leader handle, panicking with `msg` when the caller was made a follower.
    fn take_leader(joined: FlightResult, msg: &str) -> FlightLeader {
        match joined {
            FlightResult::Leader(leader) => leader,
            FlightResult::Follower(_) => panic!("{}", msg),
        }
    }

    /// Unwraps the follower state, panicking with `msg` when the caller was made the leader.
    fn take_follower(joined: FlightResult, msg: &str) -> Arc<FlightState> {
        match joined {
            FlightResult::Follower(state) => state,
            FlightResult::Leader(_) => panic!("{}", msg),
        }
    }

    // --- role assignment -------------------------------------------------

    #[test]
    fn singleflight_leader_when_empty() {
        let sf = Singleflight::new();
        let joined = sf.try_join(42);
        assert!(matches!(joined, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_follower_when_in_flight() {
        let sf = Singleflight::new();
        // Keep the first join alive so the key stays registered.
        let _held = sf.try_join(42);
        let second = sf.try_join(42);
        assert!(matches!(second, FlightResult::Follower(_)));
    }

    #[test]
    fn singleflight_different_keys_both_leaders() {
        let sf = Singleflight::new();
        let first = sf.try_join(42);
        let second = sf.try_join(43);
        assert!(matches!(first, FlightResult::Leader(_)));
        assert!(matches!(second, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_complete_removes_from_map() {
        let sf = Singleflight::new();
        let leader = take_leader(sf.try_join(42), "expected leader");
        leader.complete(&sf, pool_failure("test"));
        let rejoined = sf.try_join(42);
        assert!(matches!(rejoined, FlightResult::Leader(_)));
    }

    // --- key derivation --------------------------------------------------

    #[test]
    fn compute_key_same_inputs_same_key() {
        let first = Singleflight::compute_key(123, &[]);
        let second = Singleflight::compute_key(123, &[]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_different_sql_hash_different_key() {
        let first = Singleflight::compute_key(123, &[]);
        let second = Singleflight::compute_key(456, &[]);
        assert_ne!(first, second);
    }

    #[test]
    fn compute_key_same_params_same_key() {
        let left = 42i32;
        let right = 42i32;
        let first = Singleflight::compute_key(100, &[&left]);
        let second = Singleflight::compute_key(100, &[&right]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_different_params_different_key() {
        let left = 42i32;
        let right = 99i32;
        let first = Singleflight::compute_key(100, &[&left]);
        let second = Singleflight::compute_key(100, &[&right]);
        assert_ne!(first, second);
    }

    #[test]
    fn compute_key_different_sql_same_params_different_key() {
        let value = 42i32;
        let first = Singleflight::compute_key(100, &[&value]);
        let second = Singleflight::compute_key(200, &[&value]);
        assert_ne!(first, second);
    }

    #[test]
    fn compute_key_null_param_handling() {
        let absent: Option<i32> = None;
        let present: Option<i32> = Some(42);
        let first = Singleflight::compute_key(100, &[&absent]);
        let second = Singleflight::compute_key(100, &[&present]);
        assert_ne!(first, second, "NULL and Some(42) should produce different keys");
    }

    #[test]
    fn compute_key_two_nulls_same_key() {
        let left: Option<i32> = None;
        let right: Option<i32> = None;
        let first = Singleflight::compute_key(100, &[&left]);
        let second = Singleflight::compute_key(100, &[&right]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_multiple_params() {
        let number = 1i32;
        let text = "hello";
        let first = Singleflight::compute_key(100, &[&number, &text]);
        let second = Singleflight::compute_key(100, &[&number, &text]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_param_order_matters() {
        let one = 1i32;
        let two = 2i32;
        let forward = Singleflight::compute_key(100, &[&one, &two]);
        let reversed = Singleflight::compute_key(100, &[&two, &one]);
        assert_ne!(forward, reversed);
    }

    // --- result delivery -------------------------------------------------

    #[test]
    fn leader_complete_notifies_follower() {
        let sf = Arc::new(Singleflight::new());
        let leader = take_leader(sf.try_join(42), "expected leader");
        let waiter = take_follower(sf.try_join(42), "expected follower");
        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&waiter));
        leader.complete(&sf, pool_failure("test"));
        let received = handle.join().unwrap();
        assert!(received.is_some());
        assert!(received.unwrap().is_err());
    }

    #[test]
    fn multiple_followers_receive_result() {
        let sf = Arc::new(Singleflight::new());
        let leader = take_leader(sf.try_join(42), "expected leader");
        let first_waiter = take_follower(sf.try_join(42), "expected follower 1");
        let second_waiter = take_follower(sf.try_join(42), "expected follower 2");
        let first_handle =
            std::thread::spawn(move || Singleflight::wait_for_result(&first_waiter));
        let second_handle =
            std::thread::spawn(move || Singleflight::wait_for_result(&second_waiter));
        leader.complete(&sf, pool_failure("done"));
        let first = first_handle.join().unwrap();
        let second = second_handle.join().unwrap();
        assert!(first.is_some());
        assert!(first.unwrap().is_err());
        assert!(second.is_some());
        assert!(second.unwrap().is_err());
    }

    #[test]
    fn drop_leader_without_complete_cleans_up_map() {
        let sf = Singleflight::new();
        let leader = take_leader(sf.try_join(42), "expected leader");
        drop(leader);
        let rejoined = sf.try_join(42);
        assert!(
            matches!(rejoined, FlightResult::Leader(_)),
            "key should be removed from map after leader drop without complete"
        );
    }

    #[test]
    fn concurrent_stress_test() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let sf = Arc::new(Singleflight::new());
        let leader_count = Arc::new(AtomicUsize::new(0));
        let follower_count = Arc::new(AtomicUsize::new(0));

        // Ten threads over five keys: every key is contended by exactly two threads.
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let sf = Arc::clone(&sf);
                let leaders = Arc::clone(&leader_count);
                let followers = Arc::clone(&follower_count);
                let key = (i % 5) as u64;
                std::thread::spawn(move || match sf.try_join(key) {
                    FlightResult::Leader(leader) => {
                        leaders.fetch_add(1, Ordering::Relaxed);
                        leader.complete(&sf, pool_failure("stress"));
                    }
                    FlightResult::Follower(_state) => {
                        followers.fetch_add(1, Ordering::Relaxed);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }

        let led = leader_count.load(Ordering::Relaxed);
        let followed = follower_count.load(Ordering::Relaxed);
        assert_eq!(led + followed, 10, "all 10 threads should participate");
        assert!(led >= 5, "should have at least 5 leaders (one per key)");
    }

    #[test]
    fn singleflight_default() {
        let sf = Singleflight::default();
        let joined = sf.try_join(1);
        assert!(matches!(joined, FlightResult::Leader(_)));
    }

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn singleflight_is_send_and_sync() {
        _assert_send::<Singleflight>();
        _assert_sync::<Singleflight>();
    }

    #[test]
    fn compute_key_string_params() {
        let greeting = "hello";
        let subject = "world";
        let first = Singleflight::compute_key(100, &[&greeting, &subject]);
        let second = Singleflight::compute_key(100, &[&greeting, &subject]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_empty_params_consistent() {
        let first = Singleflight::compute_key(0, &[]);
        let second = Singleflight::compute_key(0, &[]);
        assert_eq!(first, second);
    }

    #[test]
    fn leader_complete_with_no_followers() {
        let sf = Singleflight::new();
        let leader = take_leader(sf.try_join(42), "expected leader");
        leader.complete(&sf, pool_failure("solo"));
        let rejoined = sf.try_join(42);
        assert!(matches!(rejoined, FlightResult::Leader(_)));
    }

    #[test]
    fn follower_gets_none_when_leader_dropped_without_complete() {
        let sf = Arc::new(Singleflight::new());
        let leader = take_leader(sf.try_join(42), "expected leader");
        let waiter = take_follower(sf.try_join(42), "expected follower");
        // The follower thread only lets go of its handle; it never calls
        // `wait_for_result`, so this test checks map cleanup rather than the value a
        // waiting follower would observe.
        let handle = std::thread::spawn(move || {
            let _ = waiter;
        });
        drop(leader);
        handle.join().unwrap();
        let rejoined = sf.try_join(42);
        assert!(
            matches!(rejoined, FlightResult::Leader(_)),
            "key should be removed from map after leader drop"
        );
    }

    #[test]
    fn new_leader_succeeds_after_previous_leader_dropped() {
        let sf = Arc::new(Singleflight::new());
        let abandoned = take_leader(sf.try_join(42), "expected leader");
        drop(abandoned);
        let successor = take_leader(
            sf.try_join(42),
            "expected new leader after previous leader drop",
        );
        let waiter = take_follower(sf.try_join(42), "expected follower for second leader");
        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&waiter));
        successor.complete(&sf, pool_failure("retry"));
        let received = handle.join().unwrap();
        assert!(received.is_some());
        assert!(received.unwrap().is_err());
    }
}