1use std::{
2 collections::HashMap,
3 hash::Hash,
4 marker::PhantomData,
5 sync::Arc,
6 time::{Duration, SystemTime},
7};
8
9use anyhow::bail;
10use async_trait::async_trait;
11use futures::future::join_all;
12use ordered_float::OrderedFloat;
13use parking_lot::{Mutex, RwLock};
14use priority_queue::PriorityQueue;
15use std::fmt::Debug;
16use tokio::time::timeout;
17
18use crate::{backoff::BackoffTracker, clone_to_async, error::LogError, short_name, time::Sleeper};
19
20use super::service::MutJob;
21
/// Shared priority queue of `T`.
///
/// Priorities are `f64` values wrapped in [`OrderedFloat`], because plain `f64` is not `Ord`.
/// The queue sits behind `Arc<RwLock<..>>` so several holders (e.g. producers and a
/// [`QueueReceiver`]) can read and modify the same queue.
pub type Queue<T> = Arc<RwLock<PriorityQueue<T, OrderedFloat<f64>>>>;
24
/// A source that can hand out a [`QueueReceiver`] over its items.
pub trait Receiver<T: Hash + Eq> {
    /// Build a [`QueueReceiver`] over this source.
    ///
    /// `config` supplies the default receive settings, and `sleeper` is used to wait
    /// between polls.
    fn receiver(&self, config: QueueReceiveConfig, sleeper: Arc<dyn Sleeper>) -> QueueReceiver<T>;
}
28
29impl<T: Hash + Eq + Clone> Receiver<T> for Queue<T> {
30 fn receiver(&self, config: QueueReceiveConfig, sleeper: Arc<dyn Sleeper>) -> QueueReceiver<T> {
31 QueueReceiver::new(self.clone(), config, sleeper)
32 }
33}
34
/// Pulls batches of items out of a shared [`Queue`], applying the limits in a
/// [`QueueReceiveConfig`].
pub struct QueueReceiver<T: Hash + Eq> {
    /// The shared queue that items are popped from.
    queue: Queue<T>,
    /// When each item was last handed out.
    ///
    /// Written only when a `cooldown` is configured, and used to drop items handed out
    /// within that cooldown.
    /// NOTE(review): no visible code removes entries, so this grows with the number of
    /// distinct items seen.
    last_appearances: HashMap<T, SystemTime>,
    /// Default receive settings; individual calls may override them with `Override`s.
    config: QueueReceiveConfig,
    /// Used to wait between polls of the queue.
    sleeper: Arc<dyn Sleeper>,
}
42
43impl<T: Hash + Eq + Clone> QueueReceiver<T> {
45 pub fn new(queue: Queue<T>, config: QueueReceiveConfig, sleeper: Arc<dyn Sleeper>) -> Self {
46 Self {
47 queue,
48 config,
49 last_appearances: HashMap::new(),
50 sleeper,
51 }
52 }
53
54 pub fn len(&self) -> usize {
55 self.queue.read().len()
56 }
57
58 pub fn is_empty(&self) -> bool {
59 self.queue.read().is_empty()
60 }
61
62 pub async fn recv(&mut self) -> (Vec<T>, usize) {
63 self.recv_with(&[]).await
64 }
65
66 pub fn try_recv(&mut self) -> (Vec<T>, usize) {
67 self.try_recv_with(&[])
68 }
69
70 pub async fn recv_with(&mut self, overrides: &[Override]) -> (Vec<T>, usize) {
74 let config = self.config.with(overrides);
75 let start = SystemTime::now();
76 loop {
77 if self.queue.read().len() < config.min_chunk_size {
78 self.sleeper.sleep(config.poll_interval).await;
79 } else {
80 let x = self.try_recv_with(overrides);
81 if !x.0.is_empty() {
82 return x;
83 }
84 }
85 if let Some(max_wait) = config.max_wait {
86 if SystemTime::now().duration_since(start).unwrap() > max_wait {
87 return (vec![], self.queue.read().len());
88 }
89 }
90 }
91 }
92
93 pub fn try_recv_with(&mut self, overrides: &[Override]) -> (Vec<T>, usize) {
96 let mut reader = self.queue.write();
97 let max_return = self
98 .config
99 .with(overrides)
100 .max_chunk_size
101 .unwrap_or(usize::MAX);
102 let actual_return = std::cmp::min(max_return, reader.len());
103 let mut to_process = Vec::with_capacity(actual_return);
104
105 while to_process.len() < max_return {
106 match reader.pop() {
107 Some((item, _)) => {
108 if let Some(cooldown) = self.config.with(overrides).cooldown {
109 let now = SystemTime::now();
110 let last = self
111 .last_appearances
112 .get(&item)
113 .unwrap_or(&SystemTime::UNIX_EPOCH);
114 if now.duration_since(*last).unwrap() < cooldown {
115 continue;
116 };
117 self.last_appearances.insert(item.clone(), now);
118 }
119 to_process.push(item);
120 }
121 None => break,
122 }
123 }
124 (to_process, reader.len())
125 }
126}
127
/// Settings that control how a [`QueueReceiver`] pulls batches from its queue.
///
/// Individual calls can replace any field with an [`Override`].
#[derive(Clone, Copy, Debug)]
pub struct QueueReceiveConfig {
    /// How long `recv_with` sleeps between checks of the queue while waiting for items.
    pub poll_interval: Duration,
    /// How long `recv_with` keeps waiting before it returns an empty batch.
    ///
    /// The limit is checked after each sleep. `None` means wait indefinitely.
    pub max_wait: Option<Duration>,
    /// Maximum number of items returned in one batch; `None` means no limit.
    pub max_chunk_size: Option<usize>,
    /// `recv_with` attempts a batch only once the queue holds at least this many items.
    pub min_chunk_size: usize,
    /// When set, an item handed out less than this long ago is popped and dropped instead
    /// of being returned.
    pub cooldown: Option<Duration>,
    /// Batching strategy.
    ///
    /// NOTE(review): nothing in this section reads this field; confirm where it takes
    /// effect.
    pub batch_strategy: BatchStrategy,
}
145
146impl Default for QueueReceiveConfig {
147 fn default() -> Self {
148 Self {
149 poll_interval: Duration::from_secs(1),
150 max_wait: None, max_chunk_size: None,
152 min_chunk_size: 1,
153 cooldown: None,
154 batch_strategy: BatchStrategy::Responsive,
155 }
156 }
157}
158
/// Batching strategy carried in [`QueueReceiveConfig`].
///
/// NOTE(review): nothing in this section reads `batch_strategy`, so the behaviour of each
/// variant is defined elsewhere, if anywhere. Confirm before relying on the descriptions
/// below.
#[derive(Clone, Copy, Debug)]
pub enum BatchStrategy {
    /// The default in `QueueReceiveConfig::default`.
    ///
    /// Presumably favours low latency over batch size — TODO confirm.
    Responsive,
    /// Presumably favours fuller batches over latency — TODO confirm.
    Efficient,
}
166
/// Replacement for a single [`QueueReceiveConfig`] field.
///
/// Overrides are applied per call via `recv_with` / `try_recv_with`. Each variant
/// overwrites the config field of the matching name. When several overrides target the
/// same field, the last one wins.
#[derive(Clone, Copy, Debug)]
pub enum Override {
    /// Replaces `poll_interval`.
    PollInterval(Duration),
    /// Replaces `max_wait`.
    MaxWait(Option<Duration>),
    /// Replaces `max_chunk_size`.
    MaxChunkSize(Option<usize>),
    /// Replaces `min_chunk_size`.
    MinChunkSize(usize),
    /// Replaces `cooldown`.
    Cooldown(Option<Duration>),
    /// Replaces `batch_strategy`.
    BatchStrategy(BatchStrategy),
}
176
177impl QueueReceiveConfig {
178 fn with(mut self, settings: &[Override]) -> QueueReceiveConfig {
179 for setting in settings {
180 match setting {
181 Override::PollInterval(x) => self.poll_interval = *x,
182 Override::MaxWait(x) => self.max_wait = *x,
183 Override::MaxChunkSize(x) => self.max_chunk_size = *x,
184 Override::Cooldown(x) => self.cooldown = *x,
185 Override::BatchStrategy(x) => self.batch_strategy = *x,
186 Override::MinChunkSize(x) => self.min_chunk_size = *x,
187 }
188 }
189 self
190 }
191}
192
193impl From<QueueReceiveConfig> for Vec<Override> {
194 fn from(value: QueueReceiveConfig) -> Self {
195 vec![
196 Override::PollInterval(value.poll_interval),
197 Override::MaxWait(value.max_wait),
198 Override::MaxChunkSize(value.max_chunk_size),
199 Override::MinChunkSize(value.min_chunk_size),
200 Override::Cooldown(value.cooldown),
201 Override::BatchStrategy(value.batch_strategy),
202 ]
203 }
204}
205
/// Two-phase processor used by `process_batch`.
///
/// Each input is first prepared on its own. The successfully prepared values are then
/// processed together as one batch.
#[async_trait]
pub trait BatchProcessor<In, Out = ()> {
    /// Per-item value produced by `prepare_item` and consumed by `process_batch`.
    type Intermediate;

    /// Prepare a single input.
    ///
    /// In the module-level `process_batch`, these calls for one batch run concurrently
    /// (via `join_all`). Inputs whose preparation fails are left out of the batch.
    async fn prepare_item(&self, input: In) -> anyhow::Result<Self::Intermediate>;

    /// Process the prepared values of one batch.
    ///
    /// The module-level `process_batch` calls this even when the vector is empty.
    async fn process_batch(&self, mid: Vec<Self::Intermediate>) -> anyhow::Result<Out>;
}
217
/// Job that pulls a batch from a [`QueueReceiver`] and hands it to a [`BatchProcessor`].
///
/// Each `run_once_mut` call (see the `MutJob` impl) handles one batch via `process_batch`.
pub struct Dispatch<P, In>
where
    P: BatchProcessor<In>,
    In: Hash + Eq + Clone,
{
    /// Source of batches; passed to `process_batch`, which takes a `&Mutex<QueueReceiver<_>>`.
    provider: Mutex<QueueReceiver<In>>,
    /// Prepares and processes each batch.
    processor: P,
    /// Per-run timeout in seconds (passed to `process_batch` as `timeout_secs`).
    timeout: u64,
    /// Optional per-item backoff tracking.
    ///
    /// When present, received items it reports as not ready are dropped from the batch,
    /// and the outcome of each input is recorded after processing.
    backoff: Option<BackoffTracker<In>>,
    /// NOTE(review): `In` is already used by `provider`, so this marker looks redundant.
    _phantom: PhantomData<In>,
}
229
230impl<P: BatchProcessor<In>, In: Hash + Eq + Clone> Dispatch<P, In> {
231 pub fn new(
232 provider: QueueReceiver<In>,
233 processor: P,
234 timeout: u64,
235 backoff: Option<BackoffTracker<In>>,
236 ) -> Self {
237 Self {
238 provider: Mutex::new(provider),
239 processor,
240 timeout,
241 backoff,
242 _phantom: PhantomData,
243 }
244 }
245
246 pub fn is_empty(&self) -> bool {
247 self.provider.lock().is_empty()
248 }
249}
250
251#[async_trait]
252impl<P, In> MutJob for Dispatch<P, In>
253where
254 P: BatchProcessor<In> + Send + Sync,
255 In: Hash + Eq + Clone + Send + Sync + Debug,
256 P::Intermediate: Send + Sync,
257{
258 async fn run_once_mut(&mut self) -> anyhow::Result<()> {
259 process_batch(
260 &self.provider,
261 self.timeout,
262 &self.processor,
263 self.backoff.as_mut(),
264 )
265 .await
266 }
267}
268
269pub async fn process_batch<T: Hash + Eq + Clone, Ret, Processor: BatchProcessor<T, Ret>>(
278 receiver: &Mutex<QueueReceiver<T>>,
279 timeout_secs: u64,
280 processor: &Processor,
281 mut backoff: Option<&mut BackoffTracker<T>>,
282) -> anyhow::Result<Ret>
283where
284 T: std::fmt::Debug,
285{
286 let input_name = &short_name::<T>();
288 let proc_name = &short_name::<Processor>();
289 match timeout(Duration::from_secs(timeout_secs), async {
290 let (mut items, remaining) = receiver.lock().recv_with(&[]).await;
292 if let Some(backoff) = backoff.as_ref() {
293 items.retain(|x| backoff.is_ready(x));
294 }
295 if !items.is_empty() {
296 tracing::info!(
297 "Processing {} '{input_name}'s with '{proc_name}'. {remaining} remain.",
298 items.len()
299 );
300 } else if remaining > 0 {
301 tracing::debug!(
302 "Processing {} '{input_name}'s with '{proc_name}'. {remaining} remain.",
303 items.len()
304 );
305 }
306
307 let (oks, errs) = join_all(items.into_iter().map(clone_to_async! { (processor) |item|
309 (item.clone(), processor
310 .prepare_item(item.clone())
311 .await
312 .log_with_context_passthrough(|| format!("{proc_name} - preparing {item:?}")))
313 }))
314 .await
315 .into_iter()
316 .partition::<Vec<_>, _>(|x| x.1.is_ok());
317
318 let (ok_inputs, ok_intermediates) = oks
320 .into_iter()
321 .map(|(t, r)| (t, r.unwrap()))
322 .unzip::<_, _, Vec<_>, Vec<_>>();
323 let (err_inputs, _errors) = errs
324 .into_iter()
325 .map(|(t, r)| (t, r.err().unwrap()))
326 .unzip::<_, _, Vec<_>, Vec<_>>();
327
328 let batch_result = processor
330 .process_batch(ok_intermediates)
331 .await
332 .log_with_context_passthrough(|| format!("'{proc_name}' batch processor"));
333
334 if batch_result.is_err() {
336 let all_inputs = ok_inputs
339 .into_iter()
340 .chain(err_inputs.into_iter())
341 .collect::<Vec<_>>();
342 for input in all_inputs.clone() {
343 if let Some(b) = backoff.as_mut() { b.event(input) }
344 }
345 bail!("'{proc_name}' failed to process batch inputs, see logs. failed inputs: {all_inputs:?}");
346 } else {
347 for successful_input in ok_inputs {
348 if let Some(b) = backoff.as_mut() { b.clear(&successful_input) }
349 }
350 for failed_input in err_inputs.clone() {
351 if let Some(b) = backoff.as_mut() { b.event(failed_input) }
352 }
353 if !err_inputs.is_empty() {
354 bail!("'{proc_name}' failed to process some inputs, see logs. failed inputs: {err_inputs:?}");
355 }
356 batch_result
357 }
358 })
359 .await
360 {
361 Ok(Ok(x)) => Ok(x),
362 Ok(Err(e)) => Err(e),
363 Err(e) => {
364 Err(anyhow::anyhow!(e)).log_context_passthrough(&format!("'{proc_name}' timed out"))
365 }
366 }
367}