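//! `bytesandbrains_core/pending_requests.rs`: bookkeeping for in-flight
//! request-response exchanges.
//!
//! [`PendingRequestManager`] stores pending entries under an arbitrary key and
//! expires them after a fixed timeout; [`RequestTracker`] builds on it to key
//! entries by `(PeerId, RequestId)`.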

1use std::collections::{HashMap, VecDeque};
2use std::hash::Hash;
3use std::time::{Duration, Instant};
4
5use crate::PeerId;
6
/// Identifier of one request-response exchange with a peer.
///
/// Where a `QueryId` names a whole query operation, a `RequestId` names a
/// single request sent to a peer as part of it. It is a plain newtype over
/// `u64`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestId(pub u64);

impl RequestId {
    /// Wraps the raw numeric id `id`.
    pub fn new(id: u64) -> Self {
        RequestId(id)
    }
}
18
19/// Composite key for tracking pending requests.
20/// Correlates responses by (peer_id, request_id) rather than by content.
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct RequestKey {
23    pub peer_id: PeerId,
24    pub request_id: RequestId,
25}
26
27impl RequestKey {
28    pub fn new(peer_id: PeerId, request_id: RequestId) -> Self {
29        Self { peer_id, request_id }
30    }
31}
32
/// Generic table of pending requests, each of which expires after a timeout.
///
/// Every inserted request receives the deadline `insertion instant +
/// default_timeout`. All requests share one timeout and [`Instant`] is
/// monotonic, so deadlines are non-decreasing in insertion order. Expired
/// requests can therefore be found by popping a FIFO queue instead of scanning
/// the whole map.
#[derive(Debug, Clone)]
pub struct PendingRequestManager<K, V> {
    /// Payload of every pending request.
    requests: HashMap<K, V>,
    /// Deadline currently in force for each key.
    ///
    /// The queue may still hold entries for keys that were since removed,
    /// replaced, or re-inserted. A queue entry only expires a request when its
    /// deadline matches the one recorded here. Keys with no recorded deadline
    /// never expire. Code that mutates `requests` directly (e.g.
    /// `RequestTracker::remove_all_for_peer`) can leave a leftover deadline
    /// behind; it is dropped when its queue entry comes due.
    deadlines: HashMap<K, Instant>,
    /// One entry per timed insertion, in insertion (and thus deadline) order.
    timeout_queue: VecDeque<TimeoutEntry<K>>,
    /// Timeout applied to every inserted request.
    default_timeout: Duration,
}

/// A deadline scheduled for `key` when it was inserted.
#[derive(Debug, Clone)]
struct TimeoutEntry<K> {
    key: K,
    deadline: Instant,
}

/// Outcome of [`PendingRequestManager::insert`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InsertResult<V> {
    /// The key was not pending before.
    Inserted,
    /// The key was already pending; carries the value that was replaced.
    Replaced(V),
}

impl<K, V> PendingRequestManager<K, V>
where
    K: Clone + Eq + Hash,
{
    /// Creates an empty manager whose requests expire `default_timeout` after
    /// they are inserted.
    pub fn new(default_timeout: Duration) -> Self {
        Self {
            requests: HashMap::new(),
            deadlines: HashMap::new(),
            timeout_queue: VecDeque::new(),
            default_timeout,
        }
    }

    /// Inserts `data` under `key`, (re)starting its timeout.
    ///
    /// Returns [`InsertResult::Replaced`] with the previous value if `key` was
    /// already pending. The replacement gets a fresh deadline, and the previous
    /// deadline no longer applies to it.
    ///
    /// If `insertion instant + default_timeout` cannot be represented (for
    /// example a timeout of [`Duration::MAX`]), the request never times out.
    pub fn insert(&mut self, key: K, data: V) -> InsertResult<V> {
        match Instant::now().checked_add(self.default_timeout) {
            Some(deadline) => {
                self.deadlines.insert(key.clone(), deadline);
                self.timeout_queue.push_back(TimeoutEntry {
                    key: key.clone(),
                    deadline,
                });
            }
            None => {
                // Forget any earlier deadline so a queued entry from a previous
                // insertion of this key cannot expire the new value.
                self.deadlines.remove(&key);
            }
        }

        match self.requests.insert(key, data) {
            Some(old) => InsertResult::Replaced(old),
            None => InsertResult::Inserted,
        }
    }

    /// Removes and returns the request stored under `key`, if any.
    ///
    /// Its queued timeout entry stays in the queue and is discarded as stale
    /// once its deadline passes.
    pub fn remove(&mut self, key: &K) -> Option<V> {
        self.deadlines.remove(key);
        self.requests.remove(key)
    }

    /// Removes and returns every request whose deadline has passed.
    ///
    /// Equivalent to [`Self::process_timeouts_at`] with `Instant::now()`.
    pub fn process_timeouts(&mut self) -> Vec<(K, V)> {
        self.process_timeouts_at(Instant::now())
    }

    /// Removes and returns every request whose deadline is at or before `now`.
    ///
    /// Taking the time as a parameter lets callers that already hold an
    /// `Instant` (and tests) drive expiry deterministically. Pairs are returned
    /// in the order their timeouts were scheduled.
    pub fn process_timeouts_at(&mut self, now: Instant) -> Vec<(K, V)> {
        let mut timed_out = Vec::new();

        while let Some(front) = self.timeout_queue.front() {
            if front.deadline > now {
                // Deadlines are non-decreasing, so nothing further back is due.
                break;
            }
            let entry = self
                .timeout_queue
                .pop_front()
                .expect("front() returned Some, so pop_front() cannot fail");

            // Skip entries superseded by a later insert of the same key, or
            // whose request was removed in the meantime.
            if self.deadlines.get(&entry.key) != Some(&entry.deadline) {
                continue;
            }
            self.deadlines.remove(&entry.key);

            if let Some(data) = self.requests.remove(&entry.key) {
                timed_out.push((entry.key, data));
            }
        }

        timed_out
    }

    /// Number of pending requests.
    pub fn len(&self) -> usize {
        self.requests.len()
    }

    /// Whether no requests are pending.
    pub fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

    /// Whether a request is pending under `key`.
    pub fn contains_key(&self, key: &K) -> bool {
        self.requests.contains_key(key)
    }

    /// Keys of all pending requests, in unspecified order.
    pub fn keys(&self) -> impl Iterator<Item = &K> {
        self.requests.keys()
    }

    /// Shared access to the request stored under `key`.
    pub fn get(&self, key: &K) -> Option<&V> {
        self.requests.get(key)
    }

    /// Mutable access to the request stored under `key`; does not affect its
    /// deadline.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        self.requests.get_mut(key)
    }

    /// Iterates over `(key, value)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
        self.requests.iter()
    }

    /// Iterates over `(key, value)` pairs with mutable values, in unspecified
    /// order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&K, &mut V)> {
        self.requests.iter_mut()
    }
}
139
140impl<'a, K, V> IntoIterator for &'a PendingRequestManager<K, V>
141where
142    K: Clone + Eq + Hash,
143{
144    type Item = (&'a K, &'a V);
145    type IntoIter = std::collections::hash_map::Iter<'a, K, V>;
146
147    fn into_iter(self) -> Self::IntoIter {
148        self.requests.iter()
149    }
150}
151
152impl<'a, K, V> IntoIterator for &'a mut PendingRequestManager<K, V>
153where
154    K: Clone + Eq + Hash,
155{
156    type Item = (&'a K, &'a mut V);
157    type IntoIter = std::collections::hash_map::IterMut<'a, K, V>;
158
159    fn into_iter(self) -> Self::IntoIter {
160        self.requests.iter_mut()
161    }
162}
163
164/// Tracks pending request-response exchanges by (PeerId, RequestId).
165///
166/// This enables correct response correlation even when peers change their
167/// embedding (drift), since we match by who responded, not where they claim to be.
168///
169/// The value type `V` can be an enum to support multiple request types:
170/// ```ignore
171/// enum RequestData {
172///     Knn { search_embedding: Embedding, k: usize },
173///     Kfn { search_embedding: Embedding, k: usize },
174///     Ping,
175/// }
176/// let tracker: RequestTracker<RequestData> = RequestTracker::new(timeout);
177/// ```
178#[derive(Debug, Clone)]
179pub struct RequestTracker<V> {
180    inner: PendingRequestManager<RequestKey, V>,
181    next_request_id: u64,
182}
183
184impl<V> RequestTracker<V>
185where
186    V: Clone,
187{
188    pub fn new(default_timeout: Duration) -> Self {
189        Self {
190            inner: PendingRequestManager::new(default_timeout),
191            next_request_id: 1, // Start at 1 like libp2p
192        }
193    }
194
195    /// Insert a new pending request, returning the assigned RequestId.
196    pub fn insert(&mut self, peer_id: PeerId, data: V) -> RequestId {
197        let request_id = RequestId::new(self.next_request_id);
198        self.next_request_id += 1;
199        let key = RequestKey::new(peer_id, request_id);
200        self.inner.insert(key, data);
201        request_id
202    }
203
204    /// Remove a pending request by peer_id and request_id.
205    pub fn remove(&mut self, peer_id: &PeerId, request_id: &RequestId) -> Option<V> {
206        let key = RequestKey { peer_id: peer_id.clone(), request_id: *request_id };
207        self.inner.remove(&key)
208    }
209
210    /// Check if a request is pending for this peer.
211    pub fn is_pending(&self, peer_id: &PeerId, request_id: &RequestId) -> bool {
212        let key = RequestKey { peer_id: peer_id.clone(), request_id: *request_id };
213        self.inner.contains_key(&key)
214    }
215
216    /// Process timeouts, returning timed-out (RequestKey, V) pairs.
217    pub fn process_timeouts(&mut self) -> Vec<(RequestKey, V)> {
218        self.inner.process_timeouts()
219    }
220
221    pub fn len(&self) -> usize {
222        self.inner.len()
223    }
224
225    pub fn is_empty(&self) -> bool {
226        self.inner.is_empty()
227    }
228
229    /// Iterate over all pending requests.
230    pub fn iter(&self) -> impl Iterator<Item = (&RequestKey, &V)> {
231        self.inner.iter()
232    }
233
234    /// Remove all pending requests for a specific peer.
235    /// Returns all removed (RequestKey, V) pairs.
236    pub fn remove_all_for_peer(&mut self, peer_id: &PeerId) -> Vec<(RequestKey, V)> {
237        self.inner
238            .requests
239            .extract_if(|key, _| &key.peer_id == peer_id)
240            .collect()
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_manager() {
        let manager: PendingRequestManager<u32, String> =
            PendingRequestManager::new(Duration::from_millis(100));

        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_insert_and_contains() {
        let mut manager = PendingRequestManager::new(Duration::from_millis(100));

        let result = manager.insert(1, "first".to_string());
        assert!(matches!(result, InsertResult::Inserted));
        assert_eq!(manager.len(), 1);
        assert!(manager.contains_key(&1));
        assert!(!manager.contains_key(&2));
    }

    #[test]
    fn test_insert_duplicate_key() {
        let mut manager = PendingRequestManager::new(Duration::from_millis(100));

        let result1 = manager.insert(1, "first".to_string());
        assert!(matches!(result1, InsertResult::Inserted));

        let result2 = manager.insert(1, "second".to_string());
        match result2 {
            InsertResult::Replaced(returned_value) => assert_eq!(returned_value, "first"),
            _ => panic!("Expected Replaced variant"),
        }

        assert_eq!(manager.len(), 1);
        assert!(manager.contains_key(&1));

        let removed = manager.remove(&1);
        assert_eq!(removed, Some("second".to_string()));
    }

    #[test]
    fn test_remove() {
        let mut manager = PendingRequestManager::new(Duration::from_millis(100));

        manager.insert(1, "test".to_string());
        assert_eq!(manager.len(), 1);

        let removed = manager.remove(&1);
        assert_eq!(removed, Some("test".to_string()));
        assert_eq!(manager.len(), 0);
        assert!(!manager.contains_key(&1));

        let removed2 = manager.remove(&2);
        assert_eq!(removed2, None);
    }

    #[test]
    fn test_timeout_processing() {
        // With a zero timeout each deadline equals its insertion instant, so
        // both requests are already due when timeouts are processed. This
        // avoids sleeping and makes the test deterministic.
        let mut manager = PendingRequestManager::new(Duration::ZERO);

        manager.insert(1, "first".to_string());
        manager.insert(2, "second".to_string());

        let timed_out = manager.process_timeouts();
        assert_eq!(timed_out.len(), 2);
        assert_eq!(manager.len(), 0);

        let mut keys: Vec<_> = timed_out.iter().map(|(k, _)| *k).collect();
        keys.sort();
        assert_eq!(keys, vec![1, 2]);
    }

    #[test]
    fn test_no_timeouts() {
        let mut manager = PendingRequestManager::new(Duration::from_millis(100));

        manager.insert(1, "test".to_string());

        let timed_out = manager.process_timeouts();
        assert_eq!(timed_out.len(), 0);
        assert_eq!(manager.len(), 1);
    }

    #[test]
    fn test_manual_remove_before_timeout() {
        // Zero timeout: the queued entry is already due. This checks that a
        // request removed by hand is not reported again as timed out.
        let mut manager = PendingRequestManager::new(Duration::ZERO);

        manager.insert(1, "test".to_string());

        let removed = manager.remove(&1);
        assert_eq!(removed, Some("test".to_string()));

        let timed_out = manager.process_timeouts();
        assert_eq!(timed_out.len(), 0);
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_get_and_get_mut() {
        let mut manager = PendingRequestManager::new(Duration::from_secs(60));

        manager.insert(1, 10);
        assert_eq!(manager.get(&1), Some(&10));
        assert_eq!(manager.get(&2), None);

        *manager.get_mut(&1).expect("key 1 was just inserted") += 5;
        assert_eq!(manager.get(&1), Some(&15));
        assert!(manager.get_mut(&2).is_none());
    }

    #[test]
    fn test_iteration() {
        let mut manager = PendingRequestManager::new(Duration::from_secs(60));
        manager.insert(1, 10);
        manager.insert(2, 20);

        let mut keys: Vec<_> = manager.keys().copied().collect();
        keys.sort_unstable();
        assert_eq!(keys, vec![1, 2]);

        // Exercise both mutable iteration paths.
        for (_, v) in &mut manager {
            *v *= 2;
        }
        for (_, v) in manager.iter_mut() {
            *v += 1;
        }

        let mut pairs = Vec::new();
        for (k, v) in &manager {
            pairs.push((*k, *v));
        }
        pairs.sort_unstable();
        assert_eq!(pairs, vec![(1, 21), (2, 41)]);
        assert_eq!(manager.iter().count(), 2);
    }
}