Skip to main content

aft/executor/
single_flight.rs

1use std::{collections::HashMap, hash::Hash, sync::Arc};
2
3use parking_lot::{Condvar, Mutex};
4
5/// Generation-aware single-flight cache.
6///
7/// Calls for the same key and generation share one in-flight build. A newer
8/// generation supersedes an older in-flight build; the older result is not
9/// installed if the entry has moved on by the time it finishes.
10pub struct SingleFlight<K, T> {
11    inner: Mutex<HashMap<K, FlightEntry<T>>>,
12    changed: Condvar,
13}
14
/// State stored for one key in [`SingleFlight`]'s map.
enum FlightEntry<T> {
    /// A caller has claimed the key and is running the builder for `generation`.
    Building {
        generation: u64,
    },
    /// A finished value, tagged with the generation it was built for.
    Ready {
        value: Arc<T>,
        generation: u64,
    },
}
19
20struct BuildingCleanup<'a, K, T>
21where
22    K: Clone + Eq + Hash,
23{
24    flight: &'a SingleFlight<K, T>,
25    id: K,
26    generation: u64,
27    installed: bool,
28}
29
30impl<'a, K, T> BuildingCleanup<'a, K, T>
31where
32    K: Clone + Eq + Hash,
33{
34    fn new(flight: &'a SingleFlight<K, T>, id: K, generation: u64) -> Self {
35        Self {
36            flight,
37            id,
38            generation,
39            installed: false,
40        }
41    }
42
43    fn disarm(&mut self) {
44        self.installed = true;
45    }
46}
47
48impl<K, T> Drop for BuildingCleanup<'_, K, T>
49where
50    K: Clone + Eq + Hash,
51{
52    fn drop(&mut self) {
53        if !self.installed {
54            self.flight.clear_building(&self.id, self.generation);
55        }
56    }
57}
58
59impl<K, T> Default for SingleFlight<K, T>
60where
61    K: Clone + Eq + Hash,
62{
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl<K, T> SingleFlight<K, T>
69where
70    K: Clone + Eq + Hash,
71{
72    pub fn new() -> Self {
73        Self {
74            inner: Mutex::new(HashMap::new()),
75            changed: Condvar::new(),
76        }
77    }
78
79    /// Return the cached value for `id` at `generation`, or build it once.
80    ///
81    /// The build function runs outside the internal lock. Concurrent callers for
82    /// the same `(id, generation)` wait for the in-flight build and receive the
83    /// installed value. If a newer generation supersedes this call while its
84    /// build is running, this call returns the newer ready value instead of
85    /// overwriting it with stale work.
86    ///
87    /// If the builder returns an error or panics, the in-flight marker is cleared
88    /// and waiters are notified so the key can be retried instead of remaining
89    /// permanently stuck in `Building`.
90    pub fn get_or_build<E>(
91        &self,
92        id: K,
93        generation: u64,
94        build_fn: impl FnOnce() -> Result<T, E>,
95    ) -> Result<Arc<T>, E> {
96        let mut build_fn = Some(build_fn);
97
98        loop {
99            let mut guard = self.inner.lock();
100            match guard.get(&id) {
101                Some(FlightEntry::Ready {
102                    generation: ready_generation,
103                    value,
104                }) if *ready_generation >= generation => return Ok(Arc::clone(value)),
105                Some(FlightEntry::Building {
106                    generation: building_generation,
107                }) if *building_generation >= generation => {
108                    self.changed.wait(&mut guard);
109                }
110                _ => {
111                    guard.insert(id.clone(), FlightEntry::Building { generation });
112                    drop(guard);
113
114                    let mut cleanup = BuildingCleanup::new(self, id.clone(), generation);
115                    let build = build_fn
116                        .take()
117                        .expect("single-flight build function used more than once");
118                    let built = Arc::new(build()?);
119
120                    let mut superseded = false;
121                    loop {
122                        let mut guard = self.inner.lock();
123                        match guard.get(&id) {
124                            Some(FlightEntry::Building {
125                                generation: current_generation,
126                            }) if *current_generation > generation => {
127                                superseded = true;
128                                self.changed.wait(&mut guard);
129                            }
130                            Some(FlightEntry::Ready {
131                                generation: current_generation,
132                                value,
133                            }) if *current_generation >= generation => {
134                                let value = Arc::clone(value);
135                                cleanup.disarm();
136                                self.changed.notify_all();
137                                return Ok(value);
138                            }
139                            _ if superseded => {
140                                cleanup.disarm();
141                                self.changed.notify_all();
142                                return Ok(built);
143                            }
144                            _ => {
145                                guard.insert(
146                                    id.clone(),
147                                    FlightEntry::Ready {
148                                        generation,
149                                        value: Arc::clone(&built),
150                                    },
151                                );
152                                cleanup.disarm();
153                                self.changed.notify_all();
154                                return Ok(built);
155                            }
156                        }
157                    }
158                }
159            }
160        }
161    }
162
163    fn clear_building(&self, id: &K, generation: u64) {
164        let mut guard = self.inner.lock();
165        if matches!(
166            guard.get(id),
167            Some(FlightEntry::Building {
168                generation: current_generation,
169            }) if *current_generation == generation
170        ) {
171            guard.remove(id);
172        }
173        self.changed.notify_all();
174    }
175}