1use std::collections::HashMap;
24use std::sync::{Arc, Condvar, Mutex};
25use std::time::Duration;
26
27use crate::error::BsqlError;
28
/// Outcome of one flight as handed to every participant: the leader's
/// `Result`, wrapped in an `Arc` so each follower receives a cheap clone
/// instead of a copy of the snapshot or the error.
type SharedResult = Arc<Result<Arc<OwnedResultSnapshot>, BsqlError>>;
31
/// Rendezvous point shared by the leader of a flight and its followers.
///
/// Followers block on it through `Singleflight::wait_for_result`; the leader
/// fills it in through `FlightLeader::complete`, or marks it cancelled when it
/// is dropped without completing.
pub struct FlightState {
    /// `None` while the flight is running; `Some` once the leader has published.
    result: Mutex<Option<SharedResult>>,
    /// Notified (`notify_all`) when the leader completes or is dropped.
    condvar: Condvar,
    /// Set to `true` when the leader is dropped without completing, so that
    /// waiters stop waiting for a result that will not arrive.
    cancelled: std::sync::atomic::AtomicBool,
}
40
/// Registry of flights currently in progress, keyed by the caller-supplied
/// `u64` flight key. It is reference-counted because both the `Singleflight`
/// and every outstanding `FlightLeader` hold a handle to it.
type InFlightMap = Arc<Mutex<HashMap<u64, Arc<FlightState>>>>;
43
/// Owned query outcome that a leader can hand to followers inside a
/// `SharedResult`.
pub struct OwnedResultSnapshot {
    /// The driver's query result.
    pub result: bsql_driver_postgres::QueryResult,
    /// Driver arena stored alongside `result`.
    /// NOTE(review): presumably the storage that `result` refers into — confirm
    /// against the driver before relying on it.
    pub arena: bsql_driver_postgres::Arena,
}
54
/// Coalesces concurrent requests that share a key.
///
/// The first caller of `try_join` for a key becomes the leader; callers that
/// arrive while that flight is still registered become followers and can wait
/// for the leader's result instead of doing the work themselves.
pub struct Singleflight {
    /// Flights that have a leader and have not yet completed or been cancelled.
    in_flight: InFlightMap,
}
65
/// Role assigned to a caller by `Singleflight::try_join`.
pub enum FlightResult {
    /// No flight existed for the key: the caller does the work and must publish
    /// the outcome through the contained `FlightLeader`.
    Leader(FlightLeader),
    /// A flight already exists for the key: the caller can wait on the shared
    /// state with `Singleflight::wait_for_result`.
    Follower(Arc<FlightState>),
}
73
/// Handle held by the caller that is executing a flight.
///
/// Consuming it with `complete` publishes the result. Dropping it without
/// completing removes the flight from the registry and marks it cancelled.
pub struct FlightLeader {
    /// Key under which this flight is registered.
    key: u64,
    /// State shared with the followers of this flight.
    state: Arc<FlightState>,
    /// Registry to clean up on drop. `complete` sets this to `None`, which turns
    /// the `Drop` implementation into a no-op.
    in_flight: Option<InFlightMap>,
}
87
88impl FlightLeader {
89 pub fn complete(mut self, sf: &Singleflight, result: SharedResult) {
91 sf.in_flight
93 .lock()
94 .unwrap_or_else(|e| e.into_inner())
95 .remove(&self.key);
96 self.in_flight = None;
98 *self.state.result.lock().unwrap_or_else(|e| e.into_inner()) = Some(result);
100 self.state.condvar.notify_all();
101 }
102}
103
104impl Drop for FlightLeader {
105 fn drop(&mut self) {
106 if let Some(ref map) = self.in_flight {
109 map.lock()
110 .unwrap_or_else(|e| e.into_inner())
111 .remove(&self.key);
112 self.state
114 .cancelled
115 .store(true, std::sync::atomic::Ordering::Release);
116 self.state.condvar.notify_all();
117 }
118 }
119}
120
121impl Singleflight {
122 pub fn new() -> Self {
124 Self {
125 in_flight: Arc::new(Mutex::new(HashMap::new())),
126 }
127 }
128
129 pub fn try_join(&self, key: u64) -> FlightResult {
133 let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
134
135 if let Some(state) = map.get(&key) {
136 FlightResult::Follower(Arc::clone(state))
138 } else {
139 let state = Arc::new(FlightState {
141 result: Mutex::new(None),
142 condvar: Condvar::new(),
143 cancelled: std::sync::atomic::AtomicBool::new(false),
144 });
145 map.insert(key, Arc::clone(&state));
146 FlightResult::Leader(FlightLeader {
147 key,
148 state,
149 in_flight: Some(Arc::clone(&self.in_flight)),
150 })
151 }
152 }
153
154 pub fn wait_for_result(state: &FlightState) -> Option<SharedResult> {
160 const SINGLEFLIGHT_TIMEOUT: Duration = Duration::from_secs(30);
161 let mut guard = state.result.lock().unwrap_or_else(|e| e.into_inner());
162 while guard.is_none() {
163 if state.cancelled.load(std::sync::atomic::Ordering::Acquire) {
165 return None;
166 }
167 let (new_guard, wait_result) = state
168 .condvar
169 .wait_timeout(guard, SINGLEFLIGHT_TIMEOUT)
170 .unwrap_or_else(|e| e.into_inner());
171 guard = new_guard;
172 if wait_result.timed_out() && guard.is_none() {
174 return None;
175 }
176 }
177 guard.clone()
178 }
179
180 pub fn compute_key(
186 sql_hash: u64,
187 params: &[&(dyn bsql_driver_postgres::Encode + Sync)],
188 ) -> u64 {
189 use std::hash::{Hash, Hasher};
190 let mut hasher = rapidhash::quality::RapidHasher::default();
191 sql_hash.hash(&mut hasher);
192 thread_local! {
195 static SCRATCH: std::cell::RefCell<Vec<u8>> = const { std::cell::RefCell::new(Vec::new()) };
196 }
197 SCRATCH.with(|cell| {
198 let mut scratch = cell.borrow_mut();
199 for param in params {
200 if param.is_null() {
201 hasher.write_u8(0xFF); } else {
203 scratch.clear();
204 param.encode_binary(&mut scratch);
205 hasher.write(&scratch);
206 }
207 }
208 });
209 hasher.finish()
210 }
211}
212
213impl Default for Singleflight {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps the leader handle, panicking with `msg` if the caller was made a follower.
    fn expect_leader(outcome: FlightResult, msg: &str) -> FlightLeader {
        match outcome {
            FlightResult::Leader(leader) => leader,
            FlightResult::Follower(_) => panic!("{}", msg),
        }
    }

    /// Unwraps the follower state, panicking with `msg` if the caller was made the leader.
    fn expect_follower(outcome: FlightResult, msg: &str) -> Arc<FlightState> {
        match outcome {
            FlightResult::Follower(state) => state,
            FlightResult::Leader(_) => panic!("{}", msg),
        }
    }

    /// Builds an error result to publish; the tests only look at `is_err()`.
    fn pool_failure(msg: &'static str) -> SharedResult {
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool(msg.into()));
        Arc::new(Err(err))
    }

    // --- try_join / complete: role assignment ---

    #[test]
    fn singleflight_leader_when_empty() {
        let sf = Singleflight::new();
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_follower_when_in_flight() {
        let sf = Singleflight::new();
        // Keep the leader alive so the flight stays registered.
        let _leader = sf.try_join(42);
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Follower(_)));
    }

    #[test]
    fn singleflight_different_keys_both_leaders() {
        let sf = Singleflight::new();
        let r1 = sf.try_join(42);
        let r2 = sf.try_join(43);
        assert!(matches!(r1, FlightResult::Leader(_)));
        assert!(matches!(r2, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_complete_removes_from_map() {
        let sf = Singleflight::new();
        let leader = expect_leader(sf.try_join(42), "expected leader");
        leader.complete(&sf, pool_failure("test"));

        // The key is free again, so the next caller leads.
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    // --- compute_key: basic properties ---

    #[test]
    fn compute_key_same_inputs_same_key() {
        let k1 = Singleflight::compute_key(123, &[]);
        let k2 = Singleflight::compute_key(123, &[]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_sql_hash_different_key() {
        let k1 = Singleflight::compute_key(123, &[]);
        let k2 = Singleflight::compute_key(456, &[]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_same_params_same_key() {
        let a = 42i32;
        let b = 42i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_params_different_key() {
        let a = 42i32;
        let b = 99i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_different_sql_same_params_different_key() {
        let a = 42i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(200, &[&a]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_null_param_handling() {
        let null_val: Option<i32> = None;
        let some_val: Option<i32> = Some(42);
        let k1 = Singleflight::compute_key(100, &[&null_val]);
        let k2 = Singleflight::compute_key(100, &[&some_val]);
        assert_ne!(k1, k2, "NULL and Some(42) should produce different keys");
    }

    #[test]
    fn compute_key_two_nulls_same_key() {
        let a: Option<i32> = None;
        let b: Option<i32> = None;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_multiple_params() {
        let a = 1i32;
        let b = "hello";
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&a, &b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_param_order_matters() {
        let a = 1i32;
        let b = 2i32;
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&b, &a]);
        assert_ne!(k1, k2);
    }

    // --- leader / follower hand-off ---

    #[test]
    fn leader_complete_notifies_follower() {
        let sf = Arc::new(Singleflight::new());

        let leader = expect_leader(sf.try_join(42), "expected leader");
        let follower_state = expect_follower(sf.try_join(42), "expected follower");

        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));

        leader.complete(&sf, pool_failure("test"));

        let received = handle.join().unwrap();
        assert!(received.is_some());
        assert!(received.unwrap().is_err());
    }

    #[test]
    fn multiple_followers_receive_result() {
        let sf = Arc::new(Singleflight::new());

        let leader = expect_leader(sf.try_join(42), "expected leader");
        let state1 = expect_follower(sf.try_join(42), "expected follower 1");
        let state2 = expect_follower(sf.try_join(42), "expected follower 2");

        let h1 = std::thread::spawn(move || Singleflight::wait_for_result(&state1));
        let h2 = std::thread::spawn(move || Singleflight::wait_for_result(&state2));

        leader.complete(&sf, pool_failure("done"));

        let r1 = h1.join().unwrap();
        let r2 = h2.join().unwrap();
        assert!(r1.is_some());
        assert!(r1.unwrap().is_err());
        assert!(r2.is_some());
        assert!(r2.unwrap().is_err());
    }

    #[test]
    fn drop_leader_without_complete_cleans_up_map() {
        let sf = Singleflight::new();

        let leader = expect_leader(sf.try_join(42), "expected leader");

        drop(leader);

        let result = sf.try_join(42);
        assert!(
            matches!(result, FlightResult::Leader(_)),
            "key should be removed from map after leader drop without complete"
        );
    }

    #[test]
    fn concurrent_stress_test() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let sf = Arc::new(Singleflight::new());
        let leader_count = Arc::new(AtomicUsize::new(0));
        let follower_count = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();

        // Ten threads over five keys: two contenders per key.
        for i in 0..10 {
            let sf = Arc::clone(&sf);
            let leaders = Arc::clone(&leader_count);
            let followers = Arc::clone(&follower_count);
            let key = (i % 5) as u64;

            handles.push(std::thread::spawn(move || match sf.try_join(key) {
                FlightResult::Leader(leader) => {
                    leaders.fetch_add(1, Ordering::Relaxed);
                    leader.complete(&sf, pool_failure("stress"));
                }
                FlightResult::Follower(_state) => {
                    followers.fetch_add(1, Ordering::Relaxed);
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        let total = leader_count.load(Ordering::Relaxed) + follower_count.load(Ordering::Relaxed);
        assert_eq!(total, 10, "all 10 threads should participate");
        // The first thread to reach each key always leads; a second one may also
        // lead if the first has already completed.
        assert!(
            leader_count.load(Ordering::Relaxed) >= 5,
            "should have at least 5 leaders (one per key)"
        );
    }

    #[test]
    fn singleflight_default() {
        let sf = Singleflight::default();
        let result = sf.try_join(1);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    // Compile-time checks: these only need to type-check.
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn singleflight_is_send_and_sync() {
        _assert_send::<Singleflight>();
        _assert_sync::<Singleflight>();
    }

    // --- compute_key: parameter types ---

    #[test]
    fn compute_key_string_params() {
        let a = "hello";
        let b = "world";
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&a, &b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_empty_params_consistent() {
        let k1 = Singleflight::compute_key(0, &[]);
        let k2 = Singleflight::compute_key(0, &[]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_bool_params() {
        let t = true;
        let f = false;
        let k1 = Singleflight::compute_key(100, &[&t]);
        let k2 = Singleflight::compute_key(100, &[&f]);
        assert_ne!(k1, k2, "true and false should produce different keys");
    }

    #[test]
    fn compute_key_bool_same_value_same_key() {
        let a = true;
        let b = true;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_mixed_types() {
        let i = 42i32;
        let s = "hello";
        let b = true;
        let k1 = Singleflight::compute_key(100, &[&i, &s, &b]);
        let k2 = Singleflight::compute_key(100, &[&i, &s, &b]);
        assert_eq!(k1, k2, "same mixed params should produce same key");
    }

    #[test]
    fn compute_key_mixed_types_different_values() {
        let i1 = 42i32;
        let i2 = 43i32;
        let s = "hello";
        let k1 = Singleflight::compute_key(100, &[&i1, &s]);
        let k2 = Singleflight::compute_key(100, &[&i2, &s]);
        assert_ne!(k1, k2, "different int values should produce different keys");
    }

    #[test]
    fn compute_key_f64_params() {
        let a = 1.23f64;
        let b = 1.23f64;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_f64_different_values() {
        let a = 1.23f64;
        let b = 4.56f64;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_ne!(k1, k2);
    }

    // --- completion and cancellation edge cases ---

    #[test]
    fn leader_complete_with_no_followers() {
        let sf = Singleflight::new();
        let leader = expect_leader(sf.try_join(42), "expected leader");
        leader.complete(&sf, pool_failure("solo"));

        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    #[test]
    fn follower_gets_none_when_leader_dropped_without_complete() {
        let sf = Arc::new(Singleflight::new());

        let leader = expect_leader(sf.try_join(42), "expected leader");
        let follower_state = expect_follower(sf.try_join(42), "expected follower");

        // Really wait on the flight. Whether the follower starts before or after
        // the drop, it must come back with `None`.
        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));

        drop(leader);

        let received = handle.join().unwrap();
        assert!(
            received.is_none(),
            "follower should get None when leader dropped without completing"
        );

        let result = sf.try_join(42);
        assert!(
            matches!(result, FlightResult::Leader(_)),
            "key should be removed from map after leader drop"
        );
    }

    #[test]
    fn follower_wait_for_result_returns_none_when_leader_dropped() {
        let sf = Arc::new(Singleflight::new());

        let leader = expect_leader(sf.try_join(42), "expected leader");
        let follower_state = expect_follower(sf.try_join(42), "expected follower");

        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));

        // Give the follower time to park before the leader disappears.
        std::thread::sleep(std::time::Duration::from_millis(10));

        drop(leader);

        let received = handle.join().unwrap();
        assert!(
            received.is_none(),
            "follower should get None when leader dropped without completing"
        );
    }

    #[test]
    fn new_leader_succeeds_after_previous_leader_dropped() {
        let sf = Arc::new(Singleflight::new());

        let leader1 = expect_leader(sf.try_join(42), "expected leader");
        drop(leader1);

        let leader2 = expect_leader(
            sf.try_join(42),
            "expected new leader after previous leader drop",
        );
        let follower_state =
            expect_follower(sf.try_join(42), "expected follower for second leader");

        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));

        leader2.complete(&sf, pool_failure("retry"));

        let received = handle.join().unwrap();
        assert!(received.is_some());
        assert!(received.unwrap().is_err());
    }
}