1use std::collections::HashMap;
24use std::sync::{Arc, Condvar, Mutex};
25use std::time::Duration;
26
27use crate::error::BsqlError;
28
29type SharedResult = Arc<Result<Arc<OwnedResultSnapshot>, BsqlError>>;
31
32pub struct FlightState {
34 result: Mutex<Option<SharedResult>>,
35 condvar: Condvar,
36 cancelled: std::sync::atomic::AtomicBool,
39}
40
41type InFlightMap = Arc<Mutex<HashMap<u64, Arc<FlightState>>>>;
43
44pub struct OwnedResultSnapshot {
49 pub result: bsql_driver_postgres::QueryResult,
51 pub arena: bsql_driver_postgres::Arena,
53}
54
55pub struct Singleflight {
59 in_flight: InFlightMap,
64}
65
66pub enum FlightResult {
68 Leader(FlightLeader),
70 Follower(Arc<FlightState>),
72}
73
74pub struct FlightLeader {
81 key: u64,
82 state: Arc<FlightState>,
83 in_flight: Option<InFlightMap>,
86}
87
88impl FlightLeader {
89 pub fn complete(mut self, sf: &Singleflight, result: SharedResult) {
91 sf.in_flight
93 .lock()
94 .unwrap_or_else(|e| e.into_inner())
95 .remove(&self.key);
96 self.in_flight = None;
98 *self.state.result.lock().unwrap_or_else(|e| e.into_inner()) = Some(result);
100 self.state.condvar.notify_all();
101 }
102}
103
104impl Drop for FlightLeader {
105 fn drop(&mut self) {
106 if let Some(ref map) = self.in_flight {
109 map.lock()
110 .unwrap_or_else(|e| e.into_inner())
111 .remove(&self.key);
112 self.state
114 .cancelled
115 .store(true, std::sync::atomic::Ordering::Release);
116 self.state.condvar.notify_all();
117 }
118 }
119}
120
121impl Singleflight {
122 pub fn new() -> Self {
124 Self {
125 in_flight: Arc::new(Mutex::new(HashMap::new())),
126 }
127 }
128
129 pub fn try_join(&self, key: u64) -> FlightResult {
133 let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
134
135 if let Some(state) = map.get(&key) {
136 FlightResult::Follower(Arc::clone(state))
138 } else {
139 let state = Arc::new(FlightState {
141 result: Mutex::new(None),
142 condvar: Condvar::new(),
143 cancelled: std::sync::atomic::AtomicBool::new(false),
144 });
145 map.insert(key, Arc::clone(&state));
146 FlightResult::Leader(FlightLeader {
147 key,
148 state,
149 in_flight: Some(Arc::clone(&self.in_flight)),
150 })
151 }
152 }
153
154 pub fn wait_for_result(state: &FlightState) -> Option<SharedResult> {
160 const SINGLEFLIGHT_TIMEOUT: Duration = Duration::from_secs(30);
161 let mut guard = state.result.lock().unwrap_or_else(|e| e.into_inner());
162 while guard.is_none() {
163 if state.cancelled.load(std::sync::atomic::Ordering::Acquire) {
165 return None;
166 }
167 let (new_guard, wait_result) = state
168 .condvar
169 .wait_timeout(guard, SINGLEFLIGHT_TIMEOUT)
170 .unwrap_or_else(|e| e.into_inner());
171 guard = new_guard;
172 if wait_result.timed_out() && guard.is_none() {
174 return None;
175 }
176 }
177 guard.clone()
178 }
179
180 pub fn compute_key(
186 sql_hash: u64,
187 params: &[&(dyn bsql_driver_postgres::Encode + Sync)],
188 ) -> u64 {
189 use std::hash::{Hash, Hasher};
190 let mut hasher = rapidhash::quality::RapidHasher::default();
191 sql_hash.hash(&mut hasher);
192 let mut scratch = Vec::with_capacity(64);
193 for param in params {
195 if param.is_null() {
196 hasher.write_u8(0xFF); } else {
198 scratch.clear();
199 param.encode_binary(&mut scratch);
200 hasher.write(&scratch);
201 }
202 }
203 hasher.finish()
204 }
205}
206
207impl Default for Singleflight {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Outcome a leader publishes in these tests. A driver pool error is the simplest
    /// `SharedResult` to build without a database.
    fn error_result(msg: &'static str) -> SharedResult {
        let err = BsqlError::from(bsql_driver_postgres::DriverError::Pool(msg.into()));
        Arc::new(Err(err))
    }

    /// Unwraps a `try_join` outcome that must be a leader.
    fn expect_leader(result: FlightResult) -> FlightLeader {
        match result {
            FlightResult::Leader(leader) => leader,
            FlightResult::Follower(_) => panic!("expected leader"),
        }
    }

    /// Unwraps a `try_join` outcome that must be a follower.
    fn expect_follower(result: FlightResult) -> Arc<FlightState> {
        match result {
            FlightResult::Follower(state) => state,
            FlightResult::Leader(_) => panic!("expected follower"),
        }
    }

    #[test]
    fn singleflight_leader_when_empty() {
        let sf = Singleflight::new();
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_follower_when_in_flight() {
        let sf = Singleflight::new();
        // Keep the first flight open while the second caller joins.
        let _leader = sf.try_join(42);
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Follower(_)));
    }

    #[test]
    fn singleflight_different_keys_both_leaders() {
        let sf = Singleflight::new();
        let r1 = sf.try_join(42);
        let r2 = sf.try_join(43);
        assert!(matches!(r1, FlightResult::Leader(_)));
        assert!(matches!(r2, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_complete_removes_from_map() {
        let sf = Singleflight::new();
        let leader = expect_leader(sf.try_join(42));
        leader.complete(&sf, error_result("test"));

        // The finished flight is deregistered, so the next caller leads a new one.
        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    #[test]
    fn compute_key_same_inputs_same_key() {
        let k1 = Singleflight::compute_key(123, &[]);
        let k2 = Singleflight::compute_key(123, &[]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_sql_hash_different_key() {
        let k1 = Singleflight::compute_key(123, &[]);
        let k2 = Singleflight::compute_key(456, &[]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_same_params_same_key() {
        let a = 42i32;
        let b = 42i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_params_different_key() {
        let a = 42i32;
        let b = 99i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_different_sql_same_params_different_key() {
        let a = 42i32;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(200, &[&a]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn compute_key_null_param_handling() {
        let null_val: Option<i32> = None;
        let some_val: Option<i32> = Some(42);
        let k1 = Singleflight::compute_key(100, &[&null_val]);
        let k2 = Singleflight::compute_key(100, &[&some_val]);
        assert_ne!(k1, k2, "NULL and Some(42) should produce different keys");
    }

    #[test]
    fn compute_key_two_nulls_same_key() {
        let a: Option<i32> = None;
        let b: Option<i32> = None;
        let k1 = Singleflight::compute_key(100, &[&a]);
        let k2 = Singleflight::compute_key(100, &[&b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_multiple_params() {
        let a = 1i32;
        let b = "hello";
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&a, &b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_param_order_matters() {
        let a = 1i32;
        let b = 2i32;
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&b, &a]);
        assert_ne!(k1, k2);
    }

    /// A follower blocked in `wait_for_result` receives the leader's published result.
    #[test]
    fn leader_complete_notifies_follower() {
        let sf = Arc::new(Singleflight::new());
        let leader = expect_leader(sf.try_join(42));
        let follower_state = expect_follower(sf.try_join(42));

        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));
        leader.complete(&sf, error_result("test"));

        let received = handle
            .join()
            .unwrap()
            .expect("follower should receive the result");
        assert!(received.is_err());
    }

    /// Every follower of a flight receives the same published result.
    #[test]
    fn multiple_followers_receive_result() {
        let sf = Arc::new(Singleflight::new());
        let leader = expect_leader(sf.try_join(42));

        let handles: Vec<_> = (0..2)
            .map(|_| {
                let state = expect_follower(sf.try_join(42));
                std::thread::spawn(move || Singleflight::wait_for_result(&state))
            })
            .collect();

        leader.complete(&sf, error_result("done"));

        for handle in handles {
            let received = handle
                .join()
                .unwrap()
                .expect("every follower should receive the result");
            assert!(received.is_err());
        }
    }

    #[test]
    fn drop_leader_without_complete_cleans_up_map() {
        let sf = Singleflight::new();
        let leader = expect_leader(sf.try_join(42));

        drop(leader);

        let result = sf.try_join(42);
        assert!(
            matches!(result, FlightResult::Leader(_)),
            "key should be removed from map after leader drop without complete"
        );
    }

    /// Many threads racing on a few keys: every thread participates, and each key gets
    /// at least one leader.
    #[test]
    fn concurrent_stress_test() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let sf = Arc::new(Singleflight::new());
        let leader_count = Arc::new(AtomicUsize::new(0));
        let follower_count = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..10u64)
            .map(|i| {
                let sf = Arc::clone(&sf);
                let leaders = Arc::clone(&leader_count);
                let followers = Arc::clone(&follower_count);
                let key = i % 5;

                std::thread::spawn(move || match sf.try_join(key) {
                    FlightResult::Leader(leader) => {
                        leaders.fetch_add(1, Ordering::Relaxed);
                        leader.complete(&sf, error_result("stress"));
                    }
                    FlightResult::Follower(_state) => {
                        followers.fetch_add(1, Ordering::Relaxed);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        let leaders = leader_count.load(Ordering::Relaxed);
        let followers = follower_count.load(Ordering::Relaxed);
        assert_eq!(leaders + followers, 10, "all 10 threads should participate");
        assert!(leaders >= 5, "should have at least 5 leaders (one per key)");
    }

    #[test]
    fn singleflight_default() {
        let sf = Singleflight::default();
        let result = sf.try_join(1);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn singleflight_is_send_and_sync() {
        _assert_send::<Singleflight>();
        _assert_sync::<Singleflight>();
    }

    #[test]
    fn compute_key_string_params() {
        let a = "hello";
        let b = "world";
        let k1 = Singleflight::compute_key(100, &[&a, &b]);
        let k2 = Singleflight::compute_key(100, &[&a, &b]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_empty_params_consistent() {
        let k1 = Singleflight::compute_key(0, &[]);
        let k2 = Singleflight::compute_key(0, &[]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn leader_complete_with_no_followers() {
        let sf = Singleflight::new();
        let leader = expect_leader(sf.try_join(42));
        leader.complete(&sf, error_result("solo"));

        let result = sf.try_join(42);
        assert!(matches!(result, FlightResult::Leader(_)));
    }

    /// A follower blocked in `wait_for_result` gets `None` when the leader is dropped
    /// without completing, and the key is deregistered.
    #[test]
    fn follower_gets_none_when_leader_dropped_without_complete() {
        let sf = Arc::new(Singleflight::new());
        let leader = expect_leader(sf.try_join(42));
        let follower_state = expect_follower(sf.try_join(42));

        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));
        drop(leader);

        assert!(
            handle.join().unwrap().is_none(),
            "follower should give up once the leader is dropped"
        );

        let result = sf.try_join(42);
        assert!(
            matches!(result, FlightResult::Leader(_)),
            "key should be removed from map after leader drop"
        );
    }

    /// Waiting on an already-cancelled flight returns `None` without blocking.
    #[test]
    fn wait_after_leader_dropped_returns_none_immediately() {
        let sf = Singleflight::new();
        let leader = expect_leader(sf.try_join(7));
        let state = expect_follower(sf.try_join(7));

        drop(leader);

        assert!(Singleflight::wait_for_result(&state).is_none());
    }

    /// Waiting on an already-completed flight returns its result without blocking.
    #[test]
    fn wait_after_complete_returns_result_immediately() {
        let sf = Singleflight::new();
        let leader = expect_leader(sf.try_join(7));
        let state = expect_follower(sf.try_join(7));

        leader.complete(&sf, error_result("early"));

        let received =
            Singleflight::wait_for_result(&state).expect("result was already published");
        assert!(received.is_err());
    }

    #[test]
    fn new_leader_succeeds_after_previous_leader_dropped() {
        let sf = Arc::new(Singleflight::new());

        let leader1 = expect_leader(sf.try_join(42));
        drop(leader1);

        let leader2 = match sf.try_join(42) {
            FlightResult::Leader(l) => l,
            FlightResult::Follower(_) => {
                panic!("expected new leader after previous leader drop")
            }
        };

        let follower_state = match sf.try_join(42) {
            FlightResult::Follower(s) => s,
            FlightResult::Leader(_) => panic!("expected follower for second leader"),
        };

        let handle = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));
        leader2.complete(&sf, error_result("retry"));

        let received = handle
            .join()
            .unwrap()
            .expect("follower of the second leader should receive its result");
        assert!(received.is_err());
    }
}