1use std::collections::HashMap;
24use std::sync::{Arc, Mutex};
25
26use tokio::sync::broadcast;
27
28use crate::error::BsqlError;
29
/// Outcome of one coalesced query, shared by the leader and all of its followers.
///
/// `tokio::sync::broadcast` hands each receiver its own clone of the sent value, so the
/// outer `Arc` keeps that per-follower clone to a reference-count bump. The inner `Arc`
/// lets a successful snapshot be shared independently of the wrapping `Result`.
type SharedResult = Arc<Result<Arc<OwnedResultSnapshot>, BsqlError>>;
32
/// Flights currently in progress, keyed by the value from `Singleflight::compute_key`.
///
/// Each value is a clone of the leader's broadcast sender; followers `subscribe()` to it.
/// The map is behind an `Arc` so every `FlightLeader` can hold its own handle and remove
/// its entry when it is dropped without completing.
type InFlightMap = Arc<Mutex<HashMap<u64, broadcast::Sender<SharedResult>>>>;
35
/// A query result kept together with its arena, so the pair can be shared as one unit
/// (wrapped in `Arc`) between the leader and its followers.
// NOTE(review): presumably `result` refers to memory held in `arena`. If so, the field
// order matters: Rust drops fields in declaration order, so `result` is dropped before
// `arena`. Confirm before reordering these fields.
pub struct OwnedResultSnapshot {
    /// The query result.
    pub result: bsql_driver_postgres::QueryResult,
    /// Arena owned alongside `result` (presumably backing its data — TODO confirm).
    pub arena: bsql_driver_postgres::Arena,
}
46
/// Coalesces identical concurrent queries so that only one of them does the work.
///
/// The first caller for a key becomes the leader. It is expected to produce the result
/// and publish it with `FlightLeader::complete`. Callers that arrive while that flight
/// is running become followers and receive the leader's result over a broadcast channel.
pub struct Singleflight {
    /// Key -> broadcast sender of the flight currently running for that key.
    in_flight: InFlightMap,
}
57
/// What a caller of `Singleflight::try_join` should do next.
pub enum FlightResult {
    /// No flight was running for the key. This caller must produce the result and publish
    /// it with `FlightLeader::complete`; dropping the leader instead cancels the flight.
    Leader(FlightLeader),
    /// A flight is already running; await the leader's result on this receiver.
    /// `recv()` yields an error if the leader is dropped without completing.
    Follower(broadcast::Receiver<SharedResult>),
}
65
/// Handle held by the caller that started the flight for a key.
///
/// Call `FlightLeader::complete` to publish the result. If the handle is dropped without
/// completing (early return, panic, ...), its `Drop` impl removes the map entry. Followers
/// are then released with an error, and the key can be led again.
pub struct FlightLeader {
    /// Key this flight occupies in the in-flight map.
    key: u64,
    /// Sender used to broadcast the single result to followers.
    tx: broadcast::Sender<SharedResult>,
    /// Map to remove `key` from; `None` once the entry has already been removed by
    /// `complete`, which makes `Drop` a no-op.
    in_flight: Option<InFlightMap>,
}
79
80impl FlightLeader {
81 pub fn complete(mut self, sf: &Singleflight, result: SharedResult) {
83 sf.in_flight
85 .lock()
86 .unwrap_or_else(|e| e.into_inner())
87 .remove(&self.key);
88 self.in_flight = None;
90 let _ = self.tx.send(result);
92 }
93}
94
95impl Drop for FlightLeader {
96 fn drop(&mut self) {
97 if let Some(ref map) = self.in_flight {
101 map.lock()
102 .unwrap_or_else(|e| e.into_inner())
103 .remove(&self.key);
104 }
105 }
106}
107
108impl Singleflight {
109 pub fn new() -> Self {
111 Self {
112 in_flight: Arc::new(Mutex::new(HashMap::new())),
113 }
114 }
115
116 pub fn try_join(&self, key: u64) -> FlightResult {
120 let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
121
122 if let Some(tx) = map.get(&key) {
123 FlightResult::Follower(tx.subscribe())
125 } else {
126 let (tx, _) = broadcast::channel(1);
129 map.insert(key, tx.clone());
130 FlightResult::Leader(FlightLeader {
131 key,
132 tx,
133 in_flight: Some(Arc::clone(&self.in_flight)),
134 })
135 }
136 }
137
138 pub fn compute_key(
144 sql_hash: u64,
145 params: &[&(dyn bsql_driver_postgres::Encode + Sync)],
146 ) -> u64 {
147 use std::hash::{Hash, Hasher};
148 let mut hasher = rapidhash::quality::RapidHasher::default();
149 sql_hash.hash(&mut hasher);
150 let mut scratch = Vec::with_capacity(64);
151 for param in params {
153 if param.is_null() {
154 hasher.write_u8(0xFF); } else {
156 scratch.clear();
157 param.encode_binary(&mut scratch);
158 hasher.write(&scratch);
159 }
160 }
161 hasher.finish()
162 }
163}
164
165impl Default for Singleflight {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn singleflight_leader_when_empty() {
        let sf = Singleflight::new();
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_follower_when_in_flight() {
        let sf = Singleflight::new();
        let _leader = sf.try_join(42);
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Follower(_)));
    }

    #[test]
    fn singleflight_different_keys_both_leaders() {
        let sf = Singleflight::new();
        let r1 = sf.try_join(42);
        let r2 = sf.try_join(43);
        assert!(matches!(r1, FlightResult::Leader(_)));
        assert!(matches!(r2, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_complete_removes_from_map() {
        let sf = Singleflight::new();
        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("test".into()));
        leader.complete(&sf, Arc::new(Err(err)));

        // The flight is retired, so the next caller leads a new one.
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    #[test]
    fn compute_key_same_inputs_same_key() {
        let k1 = Singleflight::compute_key(123, &[]);
        let k2 = Singleflight::compute_key(123, &[]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_sql_hash_different_key() {
        let k1 = Singleflight::compute_key(123, &[]);
        let k2 = Singleflight::compute_key(456, &[]);
        assert_ne!(k1, k2);
    }

    // Equal values held in different variables must hash identically.
    #[test]
    fn compute_key_same_params_same_key() {
        let a = 42i32;
        let b = 42i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_params_different_key() {
        let a = 42i32;
        let b = 99i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_different_sql_same_params_different_key() {
        let a = 42i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(200, &[&a]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_null_param_handling() {
        let null_val: Option<i32> = None;
        let some_val: Option<i32> = Some(42);
        let k1 = Singleflight::compute_key(100, &[&null_val]);
        let k2 = Singleflight::compute_key(100, &[&some_val]);
        assert_ne!(k1, k2, "NULL and Some(42) should produce different keys");
    }

    #[test]
    fn compute_key_two_nulls_same_key() {
        let a: Option<i32> = None;
        let b: Option<i32> = None;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_multiple_params() {
        let a = 1i32;
        let b = "hello";
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&a, &b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_param_order_matters() {
        let a = 1i32;
        let b = 2i32;
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&b, &a]);
        assert_ne!(k1, k2);
    }

    // A follower subscribed before completion receives the leader's result.
    #[tokio::test]
    async fn leader_complete_broadcasts_to_follower() {
        let sf = Singleflight::new();

        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };

        let mut rx = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower"),
        };

        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("test".into()));
        leader.complete(&sf, Arc::new(Err(err)));

        let received = rx.recv().await.unwrap();
        assert!(received.is_err());
    }

    // Every follower gets its own copy of the single broadcast value.
    #[tokio::test]
    async fn multiple_followers_receive_result() {
        let sf = Singleflight::new();

        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };

        let mut rx1 = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower 1"),
        };
        let mut rx2 = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower 2"),
        };

        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("done".into()));
        leader.complete(&sf, Arc::new(Err(err)));

        let r1 = rx1.recv().await.unwrap();
        let r2 = rx2.recv().await.unwrap();
        assert!(r1.is_err());
        assert!(r2.is_err());
    }

    // Dropping a leader without completing must not leave the key stuck in flight.
    #[test]
    fn drop_leader_without_complete_cleans_up_map() {
        let sf = Singleflight::new();

        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };

        drop(leader);

        let result = sf.try_join(42);
        assert!(
            matches!(result, FlightResult::Leader(_)),
            "key should be removed from map after leader drop without complete"
        );
    }

    /// Drives `try_join`/`complete` from real OS threads so callers genuinely race.
    ///
    /// A `#[tokio::test]` runs on a current-thread runtime by default, and tasks without
    /// `.await` points run to completion one after another. Such a test never exercises
    /// contention, so plain threads released together by a barrier are used instead.
    #[test]
    fn concurrent_stress_test() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Barrier;
        use std::thread;

        const THREADS: usize = 10;
        const KEYS: u64 = 5;

        let sf = Singleflight::new();
        let leader_count = AtomicUsize::new(0);
        let follower_count = AtomicUsize::new(0);
        // Release every thread at once to maximise overlap on each key.
        let start = Barrier::new(THREADS);

        thread::scope(|s| {
            for i in 0..THREADS {
                let sf = &sf;
                let leaders = &leader_count;
                let followers = &follower_count;
                let start = &start;
                let key = i as u64 % KEYS;

                s.spawn(move || {
                    start.wait();
                    match sf.try_join(key) {
                        FlightResult::Leader(leader) => {
                            leaders.fetch_add(1, Ordering::Relaxed);
                            let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool(
                                "stress".into(),
                            ));
                            leader.complete(sf, Arc::new(Err(err)));
                        }
                        FlightResult::Follower(_rx) => {
                            followers.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                });
            }
        });

        let leaders = leader_count.load(Ordering::Relaxed);
        let followers = follower_count.load(Ordering::Relaxed);
        assert_eq!(leaders + followers, THREADS, "all 10 tasks should participate");
        assert!(
            leaders >= KEYS as usize,
            "should have at least 5 leaders (one per key)"
        );
        // Every leader completed, so no flight may remain registered.
        assert!(
            sf.in_flight.lock().unwrap().is_empty(),
            "all flights should be removed once every leader has completed"
        );
    }

    #[test]
    fn singleflight_default() {
        let sf = Singleflight::default();
        let result = sf.try_join(1);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    // Compile-time checks: these only type-check if `T` satisfies the bound.
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn singleflight_is_send_and_sync() {
        _assert_send::<Singleflight>();
        _assert_sync::<Singleflight>();
    }

    #[test]
    fn compute_key_string_params() {
        let a = "hello";
        let b = "world";
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&a, &b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_empty_params_consistent() {
        let k1 = Singleflight::compute_key(0, &[]);
        let k2 = Singleflight::compute_key(0, &[]);
        assert_eq!(k1, k2);
    }

    // Completing with nobody subscribed must neither panic nor leave the key in flight.
    #[test]
    fn leader_complete_with_no_followers() {
        let sf = Singleflight::new();
        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("solo".into()));
        leader.complete(&sf, Arc::new(Err(err)));

        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    // Once every sender is gone, a waiting follower is released with a RecvError.
    #[tokio::test]
    async fn follower_gets_error_when_leader_dropped_without_complete() {
        let sf = Singleflight::new();

        let leader = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };

        let mut rx = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower"),
        };

        drop(leader);

        let result = rx.recv().await;
        assert!(
            result.is_err(),
            "follower should get RecvError when leader is dropped without complete"
        );
    }

    // After a cancelled flight, a new leader can run and serve followers normally.
    #[tokio::test]
    async fn new_leader_succeeds_after_previous_leader_dropped() {
        let sf = Arc::new(Singleflight::new());

        let leader1 = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected leader"),
        };
        drop(leader1);

        let leader2 = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            _ => panic!("expected new leader after previous leader drop"),
        };

        let mut rx = match sf.try_join(42) {
            FlightResult::Follower(rx) => rx,
            _ => panic!("expected follower for second leader"),
        };

        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool("retry".into()));
        leader2.complete(&sf, Arc::new(Err(err)));

        let received = rx.recv().await.unwrap();
        assert!(received.is_err());
    }
}