1use backon::{ExponentialBuilder, Retryable};
35use std::{
36 collections::HashMap, future::Future, marker::PhantomData, ops::Deref, pin::Pin, time::Duration,
37};
38use tokio::{
39 sync::{oneshot::Receiver, Mutex},
40 time::{interval_at, Instant},
41};
42use tracing::field::Empty;
43use ulid::Ulid;
44
45use crate::{context, cursor::Args, Aggregator, AggregatorEvent, Executor, ReadAggregator};
46
/// Selects which events a subscription receives based on their routing key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RoutingKey {
    /// Receive events regardless of routing key.
    All,
    /// Receive only events with the given routing key (`None` = unkeyed).
    Value(Option<String>),
}
58
/// Per-event handler context: the subscription's shared data store plus the
/// executor currently driving processing.
#[derive(Clone)]
pub struct Context<'a, E: Executor> {
    /// Type-keyed shared data populated via `SubscriptionBuilder::data`.
    context: context::RwContext,
    /// Executor the subscription is running against.
    pub executor: &'a E,
}
88
/// Lets handlers call `RwContext` methods directly on a `Context`.
impl<'a, E: Executor> Deref for Context<'a, E> {
    type Target = context::RwContext;

    fn deref(&self) -> &Self::Target {
        &self.context
    }
}
96
/// An event handler bound to one aggregator type / event name pair.
///
/// Implementations are registered on a `SubscriptionBuilder` and invoked for
/// each matching event read from the store.
pub trait Handler<E: Executor>: Sync + Send {
    /// Processes a single event. Returning `Err` aborts the current
    /// processing pass (which the subscription may then retry).
    fn handle<'a>(
        &'a self,
        context: &'a Context<'a, E>,
        event: &'a crate::Event,
    ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>>;

    /// Aggregator type this handler is registered for.
    fn aggregator_type(&self) -> &'static str;
    /// Event name this handler is registered for.
    fn event_name(&self) -> &'static str;
}
120
/// Builder for event subscriptions: registers handlers, configures routing,
/// batching and retry behaviour, then starts (background) or executes
/// (inline) the subscription.
pub struct SubscriptionBuilder<E: Executor> {
    /// Base subscriber key (prefixed by the routing key when one is set).
    key: String,
    /// Handlers keyed by `{aggregator_type}_{event_name}`.
    handlers: HashMap<String, Box<dyn Handler<E>>>,
    /// Shared data exposed to handlers through `Context`.
    context: context::RwContext,
    /// Which events to receive (all, or filtered by routing key).
    routing_key: RoutingKey,
    /// Optional delay before the first processing tick of `start`.
    delay: Option<Duration>,
    /// Number of events read per processing iteration.
    chunk_size: u16,
    /// When true, the background loop keeps running after a failed pass.
    is_accept_failure: bool,
    /// Maximum retry attempts per processing pass (`None` = no retries).
    retry: Option<u8>,
    /// Pinned aggregator ids, keyed by aggregator type.
    aggregators: HashMap<String, String>,
    /// When true (the default), events without a handler are skipped
    /// silently instead of failing the pass.
    safety_disabled: bool,
    /// Receiver for the graceful-shutdown signal; set by `start`.
    shutdown_rx: Option<Mutex<Receiver<()>>>,
}
155
156impl<E: Executor + 'static> SubscriptionBuilder<E> {
157 pub fn new(key: impl Into<String>) -> Self {
161 Self {
162 key: key.into(),
163 handlers: HashMap::new(),
164 safety_disabled: true,
165 context: Default::default(),
166 delay: None,
167 retry: Some(30),
168 chunk_size: 300,
169 is_accept_failure: false,
170 routing_key: RoutingKey::Value(None),
171 aggregators: Default::default(),
172 shutdown_rx: None,
173 }
174 }
175
176 pub fn safety_check(mut self) -> Self {
180 self.safety_disabled = false;
181
182 self
183 }
184
185 pub fn handler<H: Handler<E> + 'static>(mut self, h: H) -> Self {
191 let key = format!("{}_{}", h.aggregator_type(), h.event_name());
192 if self.handlers.insert(key.to_owned(), Box::new(h)).is_some() {
193 panic!("Cannot register event handler: key {} already exists", key);
194 }
195 self
196 }
197
198 pub fn skip<EV: AggregatorEvent + Send + Sync + 'static>(self) -> Self {
206 self.handler(SkipHandler::<EV>(PhantomData))
207 }
208
209 pub fn data<D: Send + Sync + 'static>(self, v: D) -> Self {
213 self.context.insert(v);
214
215 self
216 }
217
218 pub fn accept_failure(mut self) -> Self {
223 self.is_accept_failure = true;
224
225 self
226 }
227
228 pub fn chunk_size(mut self, v: u16) -> Self {
232 self.chunk_size = v;
233
234 self
235 }
236
237 pub fn delay(mut self, v: Duration) -> Self {
241 self.delay = Some(v);
242
243 self
244 }
245
246 pub fn routing_key(mut self, v: impl Into<String>) -> Self {
250 self.routing_key = RoutingKey::Value(Some(v.into()));
251
252 self
253 }
254
255 pub fn retry(mut self, v: u8) -> Self {
259 self.retry = Some(v);
260
261 self
262 }
263
264 pub fn all(mut self) -> Self {
266 self.routing_key = RoutingKey::All;
267
268 self
269 }
270
271 pub fn aggregator<A: Aggregator>(mut self, id: impl Into<String>) -> Self {
273 self.aggregators
274 .insert(A::aggregator_type().to_owned(), id.into());
275
276 self
277 }
278
279 fn read_aggregators(&self) -> Vec<ReadAggregator> {
280 self.handlers
281 .values()
282 .map(|h| match self.aggregators.get(h.aggregator_type()) {
283 Some(id) => ReadAggregator {
284 aggregator_type: h.aggregator_type().to_owned(),
285 aggregator_id: Some(id.to_owned()),
286 name: if self.safety_disabled {
287 Some(h.event_name().to_owned())
288 } else {
289 None
290 },
291 },
292 _ => {
293 if self.safety_disabled {
294 ReadAggregator::event(h.aggregator_type(), h.event_name())
295 } else {
296 ReadAggregator::aggregator(h.aggregator_type())
297 }
298 }
299 })
300 .collect()
301 }
302
303 fn key(&self) -> String {
304 if let RoutingKey::Value(Some(ref key)) = self.routing_key {
305 return format!("{key}.{}", self.key);
306 }
307
308 self.key.to_owned()
309 }
310
311 #[tracing::instrument(
312 skip_all,
313 fields(
314 subscription = Empty,
315 aggregator_type = Empty,
316 aggregator_id = Empty,
317 event = Empty,
318 )
319 )]
320 async fn process(
321 &self,
322 executor: &E,
323 id: &Ulid,
324 aggregators: &[ReadAggregator],
325 ) -> anyhow::Result<bool> {
326 let mut interval = interval_at(
327 Instant::now() - Duration::from_millis(400),
328 Duration::from_millis(300),
329 );
330
331 tracing::Span::current().record("subscription", self.key());
332
333 loop {
334 interval.tick().await;
335
336 if !executor.is_subscriber_running(self.key(), *id).await? {
337 return Ok(false);
338 }
339
340 let cursor = executor.get_subscriber_cursor(self.key()).await?;
341
342 let timestamp = executor
343 .read(
344 Some(aggregators.to_vec()),
345 Some(self.routing_key.to_owned()),
346 Args::backward(1, None),
347 )
348 .await?
349 .edges
350 .last()
351 .map(|e| e.node.timestamp)
352 .unwrap_or_default();
353
354 let res = executor
355 .read(
356 Some(aggregators.to_vec()),
357 Some(self.routing_key.to_owned()),
358 Args::forward(self.chunk_size, cursor),
359 )
360 .await?;
361
362 if res.edges.is_empty() {
363 return Ok(false);
364 }
365
366 let context = Context {
367 context: self.context.clone(),
368 executor,
369 };
370
371 for event in res.edges {
372 if let Some(ref rx) = self.shutdown_rx {
373 let mut rx = rx.lock().await;
374 if rx.try_recv().is_ok() {
375 tracing::info!(
376 key = self.key(),
377 "Subscription received shutdown signal, stopping gracefull"
378 );
379
380 return Ok(true);
381 }
382 drop(rx);
383 }
384
385 tracing::Span::current().record("aggregator_type", &event.node.aggregator_type);
386 tracing::Span::current().record("aggregator_id", &event.node.aggregator_id);
387 tracing::Span::current().record("event", &event.node.name);
388
389 let all_key = format!("{}_all", event.node.aggregator_type);
390 let key = format!("{}_{}", event.node.aggregator_type, event.node.name);
391 let Some(handler) = self.handlers.get(&all_key).or(self.handlers.get(&key)) else {
392 if !self.safety_disabled {
393 anyhow::bail!("no handler s={} k={key}", self.key());
394 }
395
396 continue;
397 };
398
399 if let Err(err) = handler.handle(&context, &event.node).await {
400 tracing::error!("failed");
401
402 return Err(err);
403 }
404
405 tracing::debug!("completed");
406
407 executor
408 .acknowledge(
409 self.key(),
410 event.cursor.to_owned(),
411 timestamp - event.node.timestamp,
412 )
413 .await?;
414 }
415 }
416 }
417
418 pub async fn unretry_start(mut self, executor: &E) -> anyhow::Result<Subscription>
422 where
423 E: Clone,
424 {
425 self.retry = None;
426 self.start(executor).await
427 }
428
429 #[tracing::instrument(skip_all, fields(
434 subscription = self.key(),
435 aggregator_type = tracing::field::Empty,
436 aggregator_id = tracing::field::Empty,
437 event = tracing::field::Empty,
438 ))]
439 pub async fn start(mut self, executor: &E) -> anyhow::Result<Subscription>
440 where
441 E: Clone,
442 {
443 let executor = executor.clone();
444 let id = Ulid::new();
445 let subscription_id = id;
446 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
447 self.shutdown_rx = Some(Mutex::new(shutdown_rx));
448
449 executor
450 .upsert_subscriber(self.key(), id.to_owned())
451 .await?;
452
453 let task_handle = tokio::spawn(async move {
454 let read_aggregators = self.read_aggregators();
455 let start = self
456 .delay
457 .map(|d| Instant::now() + d)
458 .unwrap_or_else(Instant::now);
459
460 let mut interval = interval_at(
461 start - Duration::from_millis(1200),
462 Duration::from_millis(1000),
463 );
464
465 loop {
466 interval.tick().await;
467
468 if let Some(ref rx) = self.shutdown_rx {
469 let mut rx = rx.lock().await;
470 if rx.try_recv().is_ok() {
471 tracing::info!(
472 key = self.key(),
473 "Subscription received shutdown signal, stopping gracefull"
474 );
475
476 break;
477 }
478 drop(rx);
479 }
480
481 let result = match self.retry {
482 Some(retry) => {
483 (|| async { self.process(&executor, &id, &read_aggregators).await })
484 .retry(ExponentialBuilder::default().with_max_times(retry.into()))
485 .sleep(tokio::time::sleep)
486 .notify(|err, dur| {
487 tracing::error!(
488 error = %err,
489 duration = ?dur,
490 "Failed to process event"
491 );
492 })
493 .await
494 }
495 _ => self.process(&executor, &id, &read_aggregators).await,
496 };
497
498 match result {
499 Ok(shutdown) => {
500 if shutdown {
501 break;
502 }
503 }
504 Err(err) => {
505 tracing::error!(error = %err, "Failed to process event");
506
507 if !self.is_accept_failure {
508 break;
509 }
510 }
511 };
512 }
513 });
514
515 Ok(Subscription {
516 id: subscription_id,
517 task_handle,
518 shutdown_tx,
519 })
520 }
521
522 pub async fn unretry_execute(mut self, executor: &E) -> anyhow::Result<()> {
526 self.retry = None;
527 self.execute(executor).await
528 }
529
530 #[tracing::instrument(skip_all, fields(
535 subscription = self.key(),
536 aggregator_type = tracing::field::Empty,
537 aggregator_id = tracing::field::Empty,
538 event = tracing::field::Empty,
539 ))]
540 pub async fn execute(&self, executor: &E) -> anyhow::Result<()> {
541 let id = Ulid::new();
542
543 executor
544 .upsert_subscriber(self.key(), id.to_owned())
545 .await?;
546
547 let read_aggregators = self.read_aggregators();
548
549 match self.retry {
550 Some(retry) => {
551 (|| async { self.process(executor, &id, &read_aggregators).await })
552 .retry(ExponentialBuilder::default().with_max_times(retry.into()))
553 .sleep(tokio::time::sleep)
554 .notify(|err, dur| {
555 tracing::error!(
556 error = %err,
557 duration = ?dur,
558 "Failed to process event"
559 );
560 })
561 .await
562 }
563 _ => self.process(executor, &id, &read_aggregators).await,
564 }?;
565
566 Ok(())
567 }
568}
569
/// Handle to a running subscription spawned by `SubscriptionBuilder::start`.
#[derive(Debug)]
pub struct Subscription {
    /// Unique id of this subscriber instance.
    pub id: Ulid,
    /// Join handle of the background processing task.
    task_handle: tokio::task::JoinHandle<()>,
    /// Sender used to request a graceful shutdown.
    shutdown_tx: tokio::sync::oneshot::Sender<()>,
}
595
596impl Subscription {
597 pub async fn shutdown(self) -> Result<(), tokio::task::JoinError> {
602 let _ = self.shutdown_tx.send(());
603
604 self.task_handle.await
605 }
606}
607
/// Handler that ignores events of type `E`; see `SubscriptionBuilder::skip`.
struct SkipHandler<E: AggregatorEvent>(PhantomData<E>);
609
610impl<E: Executor, EV: AggregatorEvent + Send + Sync> Handler<E> for SkipHandler<EV> {
611 fn handle<'a>(
612 &'a self,
613 _context: &'a Context<'a, E>,
614 _event: &'a crate::Event,
615 ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>> {
616 Box::pin(async { Ok(()) })
617 }
618
619 fn aggregator_type(&self) -> &'static str {
620 EV::aggregator_type()
621 }
622
623 fn event_name(&self) -> &'static str {
624 EV::event_name()
625 }
626}