1use std::collections::HashMap;
24use std::sync::{Arc, Condvar, Mutex};
25
26use crate::error::BsqlError;
27
28type SharedResult = Arc<Result<Arc<OwnedResultSnapshot>, BsqlError>>;
30
31pub struct FlightState {
33 result: Mutex<Option<SharedResult>>,
34 condvar: Condvar,
35}
36
37type InFlightMap = Arc<Mutex<HashMap<u64, Arc<FlightState>>>>;
39
40pub struct OwnedResultSnapshot {
45 pub result: bsql_driver_postgres::QueryResult,
47 pub arena: bsql_driver_postgres::Arena,
49}
50
51pub struct Singleflight {
55 in_flight: InFlightMap,
60}
61
62pub enum FlightResult {
64 Leader(FlightLeader),
66 Follower(Arc<FlightState>),
68}
69
70pub struct FlightLeader {
77 key: u64,
78 state: Arc<FlightState>,
79 in_flight: Option<InFlightMap>,
82}
83
84impl FlightLeader {
85 pub fn complete(mut self, sf: &Singleflight, result: SharedResult) {
87 sf.in_flight
89 .lock()
90 .unwrap_or_else(|e| e.into_inner())
91 .remove(&self.key);
92 self.in_flight = None;
94 *self.state.result.lock().unwrap_or_else(|e| e.into_inner()) = Some(result);
96 self.state.condvar.notify_all();
97 }
98}
99
100impl Drop for FlightLeader {
101 fn drop(&mut self) {
102 if let Some(ref map) = self.in_flight {
106 map.lock()
107 .unwrap_or_else(|e| e.into_inner())
108 .remove(&self.key);
109 self.state.condvar.notify_all();
111 }
112 }
113}
114
115impl Singleflight {
116 pub fn new() -> Self {
118 Self {
119 in_flight: Arc::new(Mutex::new(HashMap::new())),
120 }
121 }
122
123 pub fn try_join(&self, key: u64) -> FlightResult {
127 let mut map = self.in_flight.lock().unwrap_or_else(|e| e.into_inner());
128
129 if let Some(state) = map.get(&key) {
130 FlightResult::Follower(Arc::clone(state))
132 } else {
133 let state = Arc::new(FlightState {
135 result: Mutex::new(None),
136 condvar: Condvar::new(),
137 });
138 map.insert(key, Arc::clone(&state));
139 FlightResult::Leader(FlightLeader {
140 key,
141 state,
142 in_flight: Some(Arc::clone(&self.in_flight)),
143 })
144 }
145 }
146
147 pub fn wait_for_result(state: &FlightState) -> Option<SharedResult> {
152 let mut guard = state.result.lock().unwrap_or_else(|e| e.into_inner());
153 while guard.is_none() {
154 guard = state.condvar.wait(guard).unwrap_or_else(|e| e.into_inner());
155 }
164 guard.clone()
165 }
166
167 pub fn compute_key(
173 sql_hash: u64,
174 params: &[&(dyn bsql_driver_postgres::Encode + Sync)],
175 ) -> u64 {
176 use std::hash::{Hash, Hasher};
177 let mut hasher = rapidhash::quality::RapidHasher::default();
178 sql_hash.hash(&mut hasher);
179 let mut scratch = Vec::with_capacity(64);
180 for param in params {
182 if param.is_null() {
183 hasher.write_u8(0xFF); } else {
185 scratch.clear();
186 param.encode_binary(&mut scratch);
187 hasher.write(&scratch);
188 }
189 }
190 hasher.finish()
191 }
192}
193
194impl Default for Singleflight {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    // ---- shared fixtures -------------------------------------------------

    /// Builds the failed outcome the tests publish through `complete`.
    fn pool_failure(msg: &'static str) -> SharedResult {
        Arc::new(Err(BsqlError::from(
            bsql_driver_postgres::DriverError::Pool(msg.into()),
        )))
    }

    /// Joins `key` and returns the leader handle, panicking with `otherwise`
    /// if the caller was made a follower instead.
    fn take_leader(sf: &Singleflight, key: u64, otherwise: &str) -> FlightLeader {
        match sf.try_join(key) {
            FlightResult::Leader(leader) => leader,
            FlightResult::Follower(_) => panic!("{}", otherwise),
        }
    }

    /// Joins `key` and returns the follower state, panicking with `otherwise`
    /// if the caller was made the leader instead.
    fn take_follower(sf: &Singleflight, key: u64, otherwise: &str) -> Arc<FlightState> {
        match sf.try_join(key) {
            FlightResult::Follower(state) => state,
            FlightResult::Leader(_) => panic!("{}", otherwise),
        }
    }

    /// Compile-time probes for auto traits.
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    // ---- role assignment -------------------------------------------------

    #[test]
    fn singleflight_leader_when_empty() {
        let sf = Singleflight::new();
        let joined = sf.try_join(42);
        assert!(matches!(joined, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_follower_when_in_flight() {
        let sf = Singleflight::new();
        // Keep the first join alive so the flight stays registered.
        let _held = sf.try_join(42);
        let second = sf.try_join(42);
        assert!(matches!(second, FlightResult::Follower(_)));
    }

    #[test]
    fn singleflight_different_keys_both_leaders() {
        let sf = Singleflight::new();
        let first = sf.try_join(42);
        let second = sf.try_join(43);
        assert!(matches!(first, FlightResult::Leader(_)));
        assert!(matches!(second, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_complete_removes_from_map() {
        let sf = Singleflight::new();
        take_leader(&sf, 42, "expected leader").complete(&sf, pool_failure("test"));

        // With the flight finished, the next caller leads a new one.
        let rejoined = sf.try_join(42);
        assert!(matches!(rejoined, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_default() {
        let sf = Singleflight::default();
        let joined = sf.try_join(1);
        assert!(matches!(joined, FlightResult::Leader(_)));
    }

    #[test]
    fn singleflight_is_send_and_sync() {
        _assert_send::<Singleflight>();
        _assert_sync::<Singleflight>();
    }

    // ---- key derivation --------------------------------------------------

    #[test]
    fn compute_key_same_inputs_same_key() {
        let first = Singleflight::compute_key(123, &[]);
        let second = Singleflight::compute_key(123, &[]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_different_sql_hash_different_key() {
        let first = Singleflight::compute_key(123, &[]);
        let second = Singleflight::compute_key(456, &[]);
        assert_ne!(first, second);
    }

    #[test]
    fn compute_key_same_params_same_key() {
        let left = 42i32;
        let right = 42i32;
        let first = Singleflight::compute_key(100, &[&left]);
        let second = Singleflight::compute_key(100, &[&right]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_different_params_different_key() {
        let left = 42i32;
        let right = 99i32;
        let first = Singleflight::compute_key(100, &[&left]);
        let second = Singleflight::compute_key(100, &[&right]);
        assert_ne!(first, second);
    }

    #[test]
    fn compute_key_different_sql_same_params_different_key() {
        let value = 42i32;
        let first = Singleflight::compute_key(100, &[&value]);
        let second = Singleflight::compute_key(200, &[&value]);
        assert_ne!(first, second);
    }

    #[test]
    fn compute_key_null_param_handling() {
        let absent: Option<i32> = None;
        let present: Option<i32> = Some(42);
        let first = Singleflight::compute_key(100, &[&absent]);
        let second = Singleflight::compute_key(100, &[&present]);
        assert_ne!(
            first, second,
            "NULL and Some(42) should produce different keys"
        );
    }

    #[test]
    fn compute_key_two_nulls_same_key() {
        let left: Option<i32> = None;
        let right: Option<i32> = None;
        let first = Singleflight::compute_key(100, &[&left]);
        let second = Singleflight::compute_key(100, &[&right]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_multiple_params() {
        let number = 1i32;
        let text = "hello";
        let first = Singleflight::compute_key(100, &[&number, &text]);
        let second = Singleflight::compute_key(100, &[&number, &text]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_param_order_matters() {
        let one = 1i32;
        let two = 2i32;
        let forward = Singleflight::compute_key(100, &[&one, &two]);
        let reversed = Singleflight::compute_key(100, &[&two, &one]);
        assert_ne!(forward, reversed);
    }

    #[test]
    fn compute_key_string_params() {
        let hello = "hello";
        let world = "world";
        let first = Singleflight::compute_key(100, &[&hello, &world]);
        let second = Singleflight::compute_key(100, &[&hello, &world]);
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_empty_params_consistent() {
        let first = Singleflight::compute_key(0, &[]);
        let second = Singleflight::compute_key(0, &[]);
        assert_eq!(first, second);
    }

    // ---- leader / follower hand-off ---------------------------------------

    #[test]
    fn leader_complete_notifies_follower() {
        let sf = Arc::new(Singleflight::new());

        let leader = take_leader(&sf, 42, "expected leader");
        let follower_state = take_follower(&sf, 42, "expected follower");

        let waiter = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));

        leader.complete(&sf, pool_failure("test"));

        let received = waiter.join().unwrap();
        assert!(received.is_some());
        assert!(received.unwrap().is_err());
    }

    #[test]
    fn multiple_followers_receive_result() {
        let sf = Arc::new(Singleflight::new());

        let leader = take_leader(&sf, 42, "expected leader");
        let first_state = take_follower(&sf, 42, "expected follower 1");
        let second_state = take_follower(&sf, 42, "expected follower 2");

        let first_waiter =
            std::thread::spawn(move || Singleflight::wait_for_result(&first_state));
        let second_waiter =
            std::thread::spawn(move || Singleflight::wait_for_result(&second_state));

        leader.complete(&sf, pool_failure("done"));

        let first = first_waiter.join().unwrap();
        let second = second_waiter.join().unwrap();
        assert!(first.is_some());
        assert!(first.unwrap().is_err());
        assert!(second.is_some());
        assert!(second.unwrap().is_err());
    }

    #[test]
    fn leader_complete_with_no_followers() {
        let sf = Singleflight::new();
        let leader = take_leader(&sf, 42, "expected leader");
        leader.complete(&sf, pool_failure("solo"));

        let rejoined = sf.try_join(42);
        assert!(matches!(rejoined, FlightResult::Leader(_)));
    }

    // ---- leader dropped without completing --------------------------------

    #[test]
    fn drop_leader_without_complete_cleans_up_map() {
        let sf = Singleflight::new();

        drop(take_leader(&sf, 42, "expected leader"));

        let rejoined = sf.try_join(42);
        assert!(
            matches!(rejoined, FlightResult::Leader(_)),
            "key should be removed from map after leader drop without complete"
        );
    }

    #[test]
    fn follower_gets_none_when_leader_dropped_without_complete() {
        let sf = Arc::new(Singleflight::new());

        let leader = take_leader(&sf, 42, "expected leader");
        let follower_state = take_follower(&sf, 42, "expected follower");

        // The follower state is only moved to another thread and released
        // there; this test checks map cleanup, not the wait itself.
        let holder = std::thread::spawn(move || {
            let _ = follower_state;
        });

        drop(leader);

        holder.join().unwrap();

        let rejoined = sf.try_join(42);
        assert!(
            matches!(rejoined, FlightResult::Leader(_)),
            "key should be removed from map after leader drop"
        );
    }

    #[test]
    fn new_leader_succeeds_after_previous_leader_dropped() {
        let sf = Arc::new(Singleflight::new());

        let abandoned = take_leader(&sf, 42, "expected leader");
        drop(abandoned);

        let replacement = take_leader(
            &sf,
            42,
            "expected new leader after previous leader drop",
        );
        let follower_state = take_follower(&sf, 42, "expected follower for second leader");

        let waiter = std::thread::spawn(move || Singleflight::wait_for_result(&follower_state));

        replacement.complete(&sf, pool_failure("retry"));

        let received = waiter.join().unwrap();
        assert!(received.is_some());
        assert!(received.unwrap().is_err());
    }

    // ---- concurrency -------------------------------------------------------

    #[test]
    fn concurrent_stress_test() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let sf = Arc::new(Singleflight::new());
        let leader_count = Arc::new(AtomicUsize::new(0));
        let follower_count = Arc::new(AtomicUsize::new(0));

        // Ten threads spread over five keys; each either leads and completes
        // immediately or records that it was a follower.
        let workers: Vec<_> = (0..10u64)
            .map(|i| {
                let sf = Arc::clone(&sf);
                let leaders = Arc::clone(&leader_count);
                let followers = Arc::clone(&follower_count);
                let key = i % 5;

                std::thread::spawn(move || match sf.try_join(key) {
                    FlightResult::Leader(leader) => {
                        leaders.fetch_add(1, Ordering::Relaxed);
                        leader.complete(&sf, pool_failure("stress"));
                    }
                    FlightResult::Follower(_state) => {
                        followers.fetch_add(1, Ordering::Relaxed);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        let led = leader_count.load(Ordering::Relaxed);
        let followed = follower_count.load(Ordering::Relaxed);
        assert_eq!(led + followed, 10, "all 10 threads should participate");
        assert!(led >= 5, "should have at least 5 leaders (one per key)");
    }
}