// radicle_protocol/fetcher/state.rs
pub mod command;
9pub mod event;
10
11pub use command::Command;
12pub use event::Event;
13use serde::Serialize;
14
15use std::collections::{BTreeMap, VecDeque};
16use std::num::NonZeroUsize;
17use std::time;
18
19use radicle_core::{NodeId, RepoId};
20
21use crate::fetcher::RefsToFetch;
22
/// Default upper bound on the number of fetches queued per node
/// (used by [`MaxQueueSize::default`]).
pub const MAX_FETCH_QUEUE_SIZE: usize = 128;
/// Default maximum number of concurrent fetches per node.
/// `NonZeroUsize::MIN` is `1`, i.e. a single in-flight fetch per node.
pub const MAX_CONCURRENCY: NonZeroUsize = NonZeroUsize::MIN;
27
/// Pure state machine tracking active and queued fetches, per repository
/// and per remote node. Commands go in via [`FetcherState::handle`]; events
/// describing the outcome come out.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct FetcherState {
    // Fetches currently in flight, keyed by repository.
    active: BTreeMap<RepoId, ActiveFetch>,
    // Pending fetches, queued separately per remote node.
    queues: BTreeMap<NodeId, Queue>,
    // Concurrency and queue-size limits.
    config: Config,
}
55
56impl Default for FetcherState {
57 fn default() -> Self {
58 Self::new(Config::default())
59 }
60}
61
62impl FetcherState {
63 pub fn new(config: Config) -> Self {
65 Self {
66 active: BTreeMap::new(),
67 queues: BTreeMap::new(),
68 config,
69 }
70 }
71}
72
impl FetcherState {
    /// Dispatch a [`Command`] to its handler and wrap the result in an
    /// [`Event`].
    pub fn handle(&mut self, command: Command) -> Event {
        match command {
            Command::Fetch(fetch) => self.fetch(fetch).into(),
            Command::Fetched(fetched) => self.fetched(fetched).into(),
            Command::Cancel(cancel) => self.cancel(cancel).into(),
        }
    }

    /// Request a fetch of `rid` from node `from`.
    ///
    /// * An identical fetch already active (same `rid`, `from` and `refs`)
    ///   reports [`event::Fetch::AlreadyFetching`].
    /// * A *different* active fetch for the same `rid`, or a `from` already
    ///   at its concurrency limit, sends the request to [`Self::enqueue`].
    /// * Otherwise the fetch is recorded as active and
    ///   [`event::Fetch::Started`] is returned.
    pub fn fetch(
        &mut self,
        command::Fetch {
            from,
            rid,
            refs,
            timeout,
        }: command::Fetch,
    ) -> event::Fetch {
        if let Some(active) = self.active.get(&rid) {
            if active.refs == refs && active.from == from {
                return event::Fetch::AlreadyFetching { rid, from };
            } else {
                return self.enqueue(rid, from, refs, timeout);
            }
        }

        if self.is_at_node_capacity(&from) {
            self.enqueue(rid, from, refs, timeout)
        } else {
            self.active.insert(
                rid,
                ActiveFetch {
                    from,
                    // Cloned because `refs` is also moved into the event below.
                    refs: refs.clone(),
                },
            );
            event::Fetch::Started {
                rid,
                from,
                refs,
                timeout,
            }
        }
    }

    /// Mark the active fetch for `rid` as completed, removing it from the
    /// active set.
    ///
    /// NOTE(review): the destructuring in the `Some` arm shadows the
    /// command's `from` with the *active* fetch's `from`, so `Completed`
    /// reports the node the fetch was started from, and a `Fetched` from a
    /// different node still completes it — confirm this is intended.
    pub fn fetched(&mut self, command::Fetched { from, rid }: command::Fetched) -> event::Fetched {
        match self.active.remove(&rid) {
            None => event::Fetched::NotFound { from, rid },
            Some(ActiveFetch { from, refs }) => event::Fetched::Completed { from, rid, refs },
        }
    }

    /// Pop the next queued fetch for `from`, but only if the node is below
    /// its concurrency limit and the repository is not already being
    /// fetched; otherwise the entry stays at the front of the queue.
    pub fn dequeue(&mut self, from: &NodeId) -> Option<QueuedFetch> {
        // Capacity is computed up front so `self` isn't borrowed twice below.
        let is_at_capacity = self.is_at_node_capacity(from);
        let queue = self.queues.get_mut(from)?;
        let active = &self.active;
        queue.try_dequeue(|QueuedFetch { rid, .. }| !is_at_capacity && !active.contains_key(rid))
    }

    /// Cancel everything associated with `from`: all of its active fetches
    /// and its entire queue. Returns [`event::Cancel::Unexpected`] when
    /// there was nothing to cancel.
    pub fn cancel(&mut self, command::Cancel { from }: command::Cancel) -> event::Cancel {
        // Collect keys first, then remove, to avoid mutating while iterating.
        let cancelled: Vec<_> = self
            .active
            .iter()
            .filter_map(|(rid, f)| (f.from == from).then_some(*rid))
            .collect();
        let ongoing: BTreeMap<_, _> = cancelled
            .iter()
            .filter_map(|rid| self.active.remove(rid).map(|f| (*rid, f)))
            .collect();
        // Normalize empty collections to `None` so the match below can tell
        // "nothing to cancel" apart from a real cancellation.
        let ongoing = (!ongoing.is_empty()).then_some(ongoing);
        let queued = self.queues.remove(&from).filter(|queue| !queue.is_empty());

        match (ongoing, queued) {
            (None, None) => event::Cancel::Unexpected { from },
            (ongoing, queued) => event::Cancel::Canceled {
                from,
                active: ongoing.unwrap_or_default(),
                queued: queued.map(|q| q.queue).unwrap_or_default(),
            },
        }
    }

    /// Queue a fetch on `from`'s queue, creating the queue on demand.
    /// A repository already queued is merged rather than duplicated; a full
    /// queue rejects the request with [`event::Fetch::QueueAtCapacity`].
    fn enqueue(
        &mut self,
        rid: RepoId,
        from: NodeId,
        refs: RefsToFetch,
        timeout: time::Duration,
    ) -> event::Fetch {
        let queue = self
            .queues
            .entry(from)
            .or_insert(Queue::new(self.config.maximum_queue_size));
        match queue.enqueue(QueuedFetch { rid, refs, timeout }) {
            Enqueue::CapacityReached(QueuedFetch { rid, refs, timeout }) => {
                event::Fetch::QueueAtCapacity {
                    rid,
                    from,
                    refs,
                    timeout,
                    // The queue is full here, so its length equals the capacity.
                    capacity: queue.len(),
                }
            }
            Enqueue::Queued => event::Fetch::Queued { rid, from },
            // A merge also reports `Queued`: callers only need to know the
            // request is pending.
            Enqueue::Merged => event::Fetch::Queued { rid, from },
        }
    }
}
209
210impl FetcherState {
211 pub fn queued_fetches(&self) -> &BTreeMap<NodeId, Queue> {
213 &self.queues
214 }
215
216 pub fn active_fetches(&self) -> &BTreeMap<RepoId, ActiveFetch> {
218 &self.active
219 }
220
221 pub fn get_active_fetch(&self, rid: &RepoId) -> Option<&ActiveFetch> {
224 self.active.get(rid)
225 }
226
227 fn is_at_node_capacity(&self, node: &NodeId) -> bool {
233 let count = self.active.values().filter(|f| &f.from == node).count();
234 count >= self.config.maximum_concurrency.into()
235 }
236}
237
/// Limits governing a [`FetcherState`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize)]
pub struct Config {
    // Maximum number of concurrent fetches per node.
    maximum_concurrency: NonZeroUsize,
    // Maximum number of queued fetches per node.
    maximum_queue_size: MaxQueueSize,
}
246
247impl Config {
248 pub fn new() -> Self {
249 Self::default()
250 }
251
252 pub fn with_max_capacity(mut self, capacity: MaxQueueSize) -> Self {
254 self.maximum_queue_size = capacity;
255 self
256 }
257
258 pub fn with_max_concurrency(mut self, concurrency: NonZeroUsize) -> Self {
260 self.maximum_concurrency = concurrency;
261 self
262 }
263}
264
265impl Default for Config {
266 fn default() -> Self {
267 Self {
268 maximum_concurrency: MAX_CONCURRENCY,
269 maximum_queue_size: MaxQueueSize::default(),
270 }
271 }
272}
273
/// A fetch that is currently in flight.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct ActiveFetch {
    /// The node the fetch was initiated from.
    pub from: NodeId,
    /// The references being fetched.
    pub refs: RefsToFetch,
}
280
281impl ActiveFetch {
282 pub fn from(&self) -> &NodeId {
284 &self.from
285 }
286
287 pub fn refs(&self) -> &RefsToFetch {
289 &self.refs
290 }
291}
292
/// A fetch waiting in a per-node [`Queue`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
pub struct QueuedFetch {
    /// The repository to fetch.
    pub rid: RepoId,
    /// The references to fetch.
    pub refs: RefsToFetch,
    /// How long the fetch may run once started.
    pub timeout: time::Duration,
}
303
/// A bounded FIFO of pending fetches for a single node.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct Queue {
    // Pending fetches, oldest at the front.
    queue: VecDeque<QueuedFetch>,
    // Capacity limit enforced by `Queue::enqueue`.
    max_queue_size: MaxQueueSize,
}
313
/// Capacity limit for a [`Queue`]. The constructors visible here
/// ([`MaxQueueSize::MIN`], [`MaxQueueSize::new`], `Default`) all produce a
/// value of at least 1.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
#[serde(transparent)]
pub struct MaxQueueSize(usize);
318
319impl MaxQueueSize {
320 pub const MIN: Self = MaxQueueSize(1);
322
323 pub fn new(size: NonZeroUsize) -> Self {
325 Self(size.into())
326 }
327
328 pub fn as_usize(&self) -> usize {
329 self.0
330 }
331
332 fn is_exceeded_by(&self, n: usize) -> bool {
334 n >= self.0
335 }
336}
337
338impl Default for MaxQueueSize {
339 fn default() -> Self {
340 Self(MAX_FETCH_QUEUE_SIZE)
341 }
342}
343
/// Outcome of [`Queue::enqueue`].
#[must_use]
#[derive(Debug, PartialEq, Eq)]
pub(super) enum Enqueue {
    /// The queue was full; the rejected fetch is handed back to the caller.
    CapacityReached(QueuedFetch),
    /// The fetch was appended to the back of the queue.
    Queued,
    /// A fetch for the same repository was already queued; the two were
    /// merged instead of queueing a duplicate.
    Merged,
}
355
356impl Queue {
357 pub(super) fn new(max_queue_size: MaxQueueSize) -> Self {
359 Self {
360 queue: VecDeque::with_capacity(max_queue_size.0),
361 max_queue_size,
362 }
363 }
364
365 pub(super) fn len(&self) -> usize {
367 self.queue.len()
368 }
369
370 pub(super) fn is_empty(&self) -> bool {
372 self.queue.is_empty()
373 }
374
375 pub(super) fn enqueue(&mut self, fetch: QueuedFetch) -> Enqueue {
378 if let Some(existing) = self.queue.iter_mut().find(|qf| qf.rid == fetch.rid) {
379 existing.refs = existing.refs.clone().merge(fetch.refs);
380 existing.timeout = existing.timeout.max(fetch.timeout);
382 return Enqueue::Merged;
383 }
384
385 if self.max_queue_size.is_exceeded_by(self.queue.len()) {
386 Enqueue::CapacityReached(fetch)
387 } else {
388 self.queue.push_back(fetch);
389 Enqueue::Queued
390 }
391 }
392
393 pub(super) fn try_dequeue<P>(&mut self, predicate: P) -> Option<QueuedFetch>
396 where
397 P: FnOnce(&QueuedFetch) -> bool,
398 {
399 let fetch = self.dequeue()?;
400 if predicate(&fetch) {
401 Some(fetch)
402 } else {
403 self.queue.push_front(fetch);
404 None
405 }
406 }
407
408 pub(super) fn dequeue(&mut self) -> Option<QueuedFetch> {
410 self.queue.pop_front()
411 }
412
413 pub fn iter<'a>(&'a self) -> QueueIter<'a> {
415 QueueIter {
416 inner: self.queue.iter(),
417 }
418 }
419}
420
/// Borrowing iterator over a [`Queue`], yielding queued fetches front to
/// back.
pub struct QueueIter<'a> {
    // Delegates directly to the underlying `VecDeque` iterator.
    inner: std::collections::vec_deque::Iter<'a, QueuedFetch>,
}
425
426impl<'a> Iterator for QueueIter<'a> {
427 type Item = &'a QueuedFetch;
428
429 fn next(&mut self) -> Option<Self::Item> {
430 self.inner.next()
431 }
432}
433
434impl<'a> IntoIterator for &'a Queue {
435 type Item = &'a QueuedFetch;
436 type IntoIter = QueueIter<'a>;
437
438 fn into_iter(self) -> Self::IntoIter {
439 self.iter()
440 }
441}