aft/executor/
single_flight.rs1use std::{collections::HashMap, hash::Hash, sync::Arc};
2
3use parking_lot::{Condvar, Mutex};
4
5pub struct SingleFlight<K, T> {
11 inner: Mutex<HashMap<K, FlightEntry<T>>>,
12 changed: Condvar,
13}
14
/// State stored for one key of a `SingleFlight`.
enum FlightEntry<T> {
    /// Some caller has claimed the key and is running the build for `generation`.
    Building {
        generation: u64,
    },
    /// A finished value; it is handed out to requests whose generation is at most
    /// `generation`.
    Ready {
        generation: u64,
        value: Arc<T>,
    },
}
19
20struct BuildingCleanup<'a, K, T>
21where
22 K: Clone + Eq + Hash,
23{
24 flight: &'a SingleFlight<K, T>,
25 id: K,
26 generation: u64,
27 installed: bool,
28}
29
30impl<'a, K, T> BuildingCleanup<'a, K, T>
31where
32 K: Clone + Eq + Hash,
33{
34 fn new(flight: &'a SingleFlight<K, T>, id: K, generation: u64) -> Self {
35 Self {
36 flight,
37 id,
38 generation,
39 installed: false,
40 }
41 }
42
43 fn disarm(&mut self) {
44 self.installed = true;
45 }
46}
47
48impl<K, T> Drop for BuildingCleanup<'_, K, T>
49where
50 K: Clone + Eq + Hash,
51{
52 fn drop(&mut self) {
53 if !self.installed {
54 self.flight.clear_building(&self.id, self.generation);
55 }
56 }
57}
58
59impl<K, T> Default for SingleFlight<K, T>
60where
61 K: Clone + Eq + Hash,
62{
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl<K, T> SingleFlight<K, T>
69where
70 K: Clone + Eq + Hash,
71{
72 pub fn new() -> Self {
73 Self {
74 inner: Mutex::new(HashMap::new()),
75 changed: Condvar::new(),
76 }
77 }
78
79 pub fn get_or_build<E>(
91 &self,
92 id: K,
93 generation: u64,
94 build_fn: impl FnOnce() -> Result<T, E>,
95 ) -> Result<Arc<T>, E> {
96 let mut build_fn = Some(build_fn);
97
98 loop {
99 let mut guard = self.inner.lock();
100 match guard.get(&id) {
101 Some(FlightEntry::Ready {
102 generation: ready_generation,
103 value,
104 }) if *ready_generation >= generation => return Ok(Arc::clone(value)),
105 Some(FlightEntry::Building {
106 generation: building_generation,
107 }) if *building_generation >= generation => {
108 self.changed.wait(&mut guard);
109 }
110 _ => {
111 guard.insert(id.clone(), FlightEntry::Building { generation });
112 drop(guard);
113
114 let mut cleanup = BuildingCleanup::new(self, id.clone(), generation);
115 let build = build_fn
116 .take()
117 .expect("single-flight build function used more than once");
118 let built = Arc::new(build()?);
119
120 let mut superseded = false;
121 loop {
122 let mut guard = self.inner.lock();
123 match guard.get(&id) {
124 Some(FlightEntry::Building {
125 generation: current_generation,
126 }) if *current_generation > generation => {
127 superseded = true;
128 self.changed.wait(&mut guard);
129 }
130 Some(FlightEntry::Ready {
131 generation: current_generation,
132 value,
133 }) if *current_generation >= generation => {
134 let value = Arc::clone(value);
135 cleanup.disarm();
136 self.changed.notify_all();
137 return Ok(value);
138 }
139 _ if superseded => {
140 cleanup.disarm();
141 self.changed.notify_all();
142 return Ok(built);
143 }
144 _ => {
145 guard.insert(
146 id.clone(),
147 FlightEntry::Ready {
148 generation,
149 value: Arc::clone(&built),
150 },
151 );
152 cleanup.disarm();
153 self.changed.notify_all();
154 return Ok(built);
155 }
156 }
157 }
158 }
159 }
160 }
161 }
162
163 fn clear_building(&self, id: &K, generation: u64) {
164 let mut guard = self.inner.lock();
165 if matches!(
166 guard.get(id),
167 Some(FlightEntry::Building {
168 generation: current_generation,
169 }) if *current_generation == generation
170 ) {
171 guard.remove(id);
172 }
173 self.changed.notify_all();
174 }
175}