1use std::num::NonZeroU64;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use aion_core::{Event, WorkflowFilter, WorkflowId};
9use aion_proto::{
10 FilteredSubscription, FirehoseSubscription, PerWorkflowSubscription, ProtoWorkflowId,
11 SubscriptionRequest, subscription_request,
12};
13use futures::Stream;
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16
17use crate::error::ClientError;
18use crate::transport::{SubscriptionAttempt, WorkflowTransport};
19
20pub type EventStream = Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>;
22
23#[derive(Clone, Debug, PartialEq, Eq)]
25pub enum SubscribeTarget {
26 Workflow {
28 workflow_id: WorkflowId,
30 },
31 Filtered {
33 filter: WorkflowFilter,
35 },
36 Firehose,
38}
39
40impl SubscribeTarget {
41 pub(crate) fn request(&self, namespace: &str) -> SubscriptionRequest {
42 match self {
43 Self::Workflow { workflow_id } => SubscriptionRequest {
44 subscription: Some(subscription_request::Subscription::PerWorkflow(
45 PerWorkflowSubscription {
46 namespace: namespace.to_owned(),
47 workflow_id: Some(ProtoWorkflowId::from(workflow_id.clone())),
48 resume_from_seq: None,
49 },
50 )),
51 },
52 Self::Filtered { filter } => SubscriptionRequest {
53 subscription: Some(subscription_request::Subscription::Filtered(
54 FilteredSubscription {
55 namespace: namespace.to_owned(),
56 workflow_type: filter.workflow_type.clone(),
57 status: filter
58 .status
59 .map(|status| aion_proto::ProtoWorkflowStatus::from(status) as i32),
60 namespace_selector: None,
61 },
62 )),
63 },
64 Self::Firehose => SubscriptionRequest {
65 subscription: Some(subscription_request::Subscription::Firehose(
66 FirehoseSubscription {
67 namespace: namespace.to_owned(),
68 },
69 )),
70 },
71 }
72 }
73}
74
75pub struct ResumingEventStream {
86 transport: Arc<dyn WorkflowTransport>,
87 namespace: String,
88 target: SubscribeTarget,
89 last_seq: Option<u64>,
90 delivered_any: bool,
91 current: Option<BoxStream<'static, Result<Event, ClientError>>>,
92 pending_subscribe: Option<BoxFuture<'static, Result<SubscriptionAttempt, ClientError>>>,
93 terminal_error: Option<ClientError>,
94 finished: bool,
95}
96
97impl ResumingEventStream {
98 #[must_use]
100 pub fn new(
101 transport: Arc<dyn WorkflowTransport>,
102 namespace: impl Into<String>,
103 target: SubscribeTarget,
104 ) -> Self {
105 Self {
106 transport,
107 namespace: namespace.into(),
108 target,
109 last_seq: None,
110 delivered_any: false,
111 current: None,
112 pending_subscribe: None,
113 terminal_error: None,
114 finished: false,
115 }
116 }
117
118 #[must_use]
126 pub fn from_sequence(
127 transport: Arc<dyn WorkflowTransport>,
128 namespace: impl Into<String>,
129 workflow_id: WorkflowId,
130 resume_from: NonZeroU64,
131 ) -> Self {
132 let mut stream = Self::new(
133 transport,
134 namespace,
135 SubscribeTarget::Workflow { workflow_id },
136 );
137 stream.last_seq = Some(resume_from.get() - 1);
141 stream
142 }
143
144 fn is_per_workflow(&self) -> bool {
145 matches!(self.target, SubscribeTarget::Workflow { .. })
146 }
147
148 fn start_subscribe(&mut self) {
149 let transport = Arc::clone(&self.transport);
150 let request = self.target.request(&self.namespace);
151 let resume_from_sequence = if self.is_per_workflow() {
154 self.last_seq.map(|seq| seq.saturating_add(1))
155 } else {
156 None
157 };
158 self.pending_subscribe = Some(Box::pin(async move {
159 transport.subscribe(request, resume_from_sequence).await
160 }));
161 }
162}
163
164impl Stream for ResumingEventStream {
165 type Item = Result<Event, ClientError>;
166
167 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
168 let this = self.get_mut();
169 loop {
170 if this.finished {
171 return Poll::Ready(None);
172 }
173
174 if let Some(error) = this.terminal_error.take() {
175 this.finished = true;
176 return Poll::Ready(Some(Err(error)));
177 }
178
179 if this.current.is_none() && this.pending_subscribe.is_none() {
180 this.start_subscribe();
181 }
182
183 if let Some(pending) = this.pending_subscribe.as_mut() {
184 match pending.as_mut().poll(cx) {
185 Poll::Pending => return Poll::Pending,
186 Poll::Ready(Ok(attempt)) => {
187 this.pending_subscribe = None;
188 this.current = Some(attempt.events);
189 }
190 Poll::Ready(Err(error)) => {
191 this.pending_subscribe = None;
192 this.finished = true;
193 return Poll::Ready(Some(Err(error)));
194 }
195 }
196 }
197
198 let Some(current) = this.current.as_mut() else {
199 continue;
200 };
201 match current.as_mut().poll_next(cx) {
202 Poll::Pending => return Poll::Pending,
203 Poll::Ready(Some(Ok(event))) => {
204 if this.is_per_workflow() {
205 if this.last_seq.is_some_and(|seq| event.seq() <= seq) {
208 continue;
209 }
210 this.last_seq = Some(event.seq());
211 }
212 this.delivered_any = true;
213 return Poll::Ready(Some(Ok(event)));
214 }
215 Poll::Ready(Some(Err(error))) => {
216 this.current = None;
217 if is_retryable(&error) {
218 if this.is_per_workflow() {
219 continue;
220 }
221 if !this.delivered_any {
222 continue;
225 }
226 }
230 this.terminal_error = Some(error);
231 }
232 Poll::Ready(None) => {
233 this.current = None;
234 this.finished = true;
235 return Poll::Ready(None);
236 }
237 }
238 }
239 }
240}
241
242#[must_use]
244pub fn event_stream(
245 transport: Arc<dyn WorkflowTransport>,
246 namespace: impl Into<String>,
247 target: SubscribeTarget,
248) -> EventStream {
249 Box::pin(ResumingEventStream::new(transport, namespace, target))
250}
251
252#[must_use]
254pub fn event_stream_from(
255 transport: Arc<dyn WorkflowTransport>,
256 namespace: impl Into<String>,
257 workflow_id: WorkflowId,
258 resume_from: NonZeroU64,
259) -> EventStream {
260 Box::pin(ResumingEventStream::from_sequence(
261 transport,
262 namespace,
263 workflow_id,
264 resume_from,
265 ))
266}
267
268fn is_retryable(error: &ClientError) -> bool {
269 matches!(error, ClientError::Unavailable { .. })
270}
271
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::sync::Arc;

    use aion_core::{ContentType, Event, EventEnvelope, Payload, WorkflowId};
    use aion_proto::{
        ProtoCancelResponse, ProtoDescribeWorkflowResponse, ProtoListWorkflowsResponse,
        ProtoQueryResponse, ProtoSignalResponse, ProtoStartWorkflowResponse,
    };
    use async_trait::async_trait;
    use chrono::Utc;
    use futures::StreamExt;
    use futures::stream;
    use tokio::sync::Mutex;

    use super::{ResumingEventStream, SubscribeTarget};
    use crate::error::ClientError;
    use crate::transport::{SubscriptionAttempt, WorkflowTransport};

    /// Transport double whose `subscribe` hands out pre-scripted attempts in order
    /// and records the resume cursor of every call. All other RPCs fail.
    #[derive(Default)]
    struct SubscribeStub {
        /// Attempts still to be handed out, front first.
        attempts: Mutex<VecDeque<SubscriptionAttempt>>,
        /// Cursor passed to each `subscribe` call, in call order.
        resume_points: Mutex<Vec<Option<u64>>>,
    }

    impl SubscribeStub {
        /// Builds a stub where each inner `Vec` is the full item sequence of one
        /// subscribe attempt; the stream of an attempt ends after its last item.
        fn scripted(scripts: Vec<Vec<Result<Event, ClientError>>>) -> Arc<Self> {
            let attempts = scripts
                .into_iter()
                .map(|items| SubscriptionAttempt::new(stream::iter(items).boxed()))
                .collect::<VecDeque<_>>();
            Arc::new(Self {
                attempts: Mutex::new(attempts),
                ..Self::default()
            })
        }
    }

    #[async_trait]
    impl WorkflowTransport for SubscribeStub {
        async fn start_workflow(
            &self,
            _: aion_proto::ProtoStartWorkflowRequest,
        ) -> Result<ProtoStartWorkflowResponse, ClientError> {
            Err(ClientError::unavailable("stub transport"))
        }

        async fn signal(
            &self,
            _: aion_proto::ProtoSignalRequest,
        ) -> Result<ProtoSignalResponse, ClientError> {
            Err(ClientError::unavailable("stub transport"))
        }

        async fn query(
            &self,
            _: aion_proto::ProtoQueryRequest,
        ) -> Result<ProtoQueryResponse, ClientError> {
            Err(ClientError::unavailable("stub transport"))
        }

        async fn cancel(
            &self,
            _: aion_proto::ProtoCancelRequest,
        ) -> Result<ProtoCancelResponse, ClientError> {
            Err(ClientError::unavailable("stub transport"))
        }

        async fn list_workflows(
            &self,
            _: aion_proto::ProtoListWorkflowsRequest,
        ) -> Result<ProtoListWorkflowsResponse, ClientError> {
            Err(ClientError::unavailable("stub transport"))
        }

        async fn describe_workflow(
            &self,
            _: aion_proto::ProtoDescribeWorkflowRequest,
        ) -> Result<ProtoDescribeWorkflowResponse, ClientError> {
            Err(ClientError::unavailable("stub transport"))
        }

        /// Records the cursor, then pops the next scripted attempt; running out of
        /// attempts is reported as a server error.
        async fn subscribe(
            &self,
            _: aion_proto::SubscriptionRequest,
            resume_from_sequence: Option<u64>,
        ) -> Result<SubscriptionAttempt, ClientError> {
            self.resume_points.lock().await.push(resume_from_sequence);
            let next_attempt = self.attempts.lock().await.pop_front();
            next_attempt.ok_or_else(|| ClientError::server("missing subscribe attempt"))
        }
    }

    /// Builds a `WorkflowStarted` event with the given sequence number and workflow.
    fn event(seq: u64, workflow_id: &WorkflowId) -> Event {
        let envelope = EventEnvelope {
            seq,
            recorded_at: Utc::now(),
            workflow_id: workflow_id.clone(),
        };
        Event::WorkflowStarted {
            envelope,
            workflow_type: String::from("checkout"),
            input: Payload::new(ContentType::Json, Vec::new()),
            run_id: aion_core::RunId::new(uuid::Uuid::from_u128(1)),
            parent_run_id: None,
        }
    }

    /// Drains `events` to its end and returns the delivered events.
    ///
    /// Fails the test if the stream yields any `Err`, so an unexpected error cannot
    /// hide behind an otherwise matching list of events.
    async fn collect_events(mut events: ResumingEventStream) -> Vec<Event> {
        let mut delivered = Vec::new();
        let mut failures = Vec::new();
        while let Some(item) = events.next().await {
            match item {
                Ok(event) => delivered.push(event),
                Err(error) => failures.push(error),
            }
        }
        assert!(
            failures.is_empty(),
            "unexpected stream errors: {failures:?}"
        );
        delivered
    }

    /// Same as [`collect_events`], reduced to the sequence numbers.
    async fn collect_seqs(events: ResumingEventStream) -> Vec<u64> {
        collect_events(events)
            .await
            .iter()
            .map(|event| event.seq())
            .collect()
    }

    #[tokio::test]
    async fn resumes_after_transient_disconnect_without_gaps_or_duplicates() {
        let workflow_id = WorkflowId::new_v4();
        let stub = SubscribeStub::scripted(vec![
            vec![
                Ok(event(1, &workflow_id)),
                Ok(event(2, &workflow_id)),
                Err(ClientError::unavailable("transient disconnect")),
            ],
            // The re-attach replays seq 2, which must be dropped as a duplicate.
            vec![
                Ok(event(2, &workflow_id)),
                Ok(event(3, &workflow_id)),
                Ok(event(4, &workflow_id)),
            ],
        ]);
        let events = ResumingEventStream::new(
            stub.clone(),
            "tenant-a",
            SubscribeTarget::Workflow {
                workflow_id: workflow_id.clone(),
            },
        );

        let seqs = collect_seqs(events).await;

        assert_eq!(seqs, vec![1, 2, 3, 4]);
        assert_eq!(*stub.resume_points.lock().await, vec![None, Some(3)]);
    }

    #[tokio::test]
    async fn terminal_failure_is_yielded_before_end() {
        let workflow_id = WorkflowId::new_v4();
        let stub = SubscribeStub::scripted(vec![vec![Err(ClientError::unauthenticated(
            "bad token",
        ))]]);
        let mut events =
            ResumingEventStream::new(stub, "tenant-a", SubscribeTarget::Workflow { workflow_id });

        assert_eq!(
            events.next().await,
            Some(Err(ClientError::unauthenticated("bad token")))
        );
        assert_eq!(events.next().await, None);
    }

    #[tokio::test]
    async fn namespace_denied_is_terminal_and_never_retried() {
        let workflow_id = WorkflowId::new_v4();
        let denied =
            ClientError::namespace_denied("namespace tenant-b is not granted to this caller");
        let stub = SubscribeStub::scripted(vec![vec![Err(denied.clone())]]);
        let mut events = ResumingEventStream::new(
            stub.clone(),
            "tenant-b",
            SubscribeTarget::Workflow { workflow_id },
        );

        assert_eq!(events.next().await, Some(Err(denied)));
        assert_eq!(events.next().await, None);
        assert_eq!(stub.resume_points.lock().await.len(), 1);
    }

    #[tokio::test]
    async fn from_sequence_passes_the_cursor_on_the_initial_attach() {
        let workflow_id = WorkflowId::new_v4();
        let stub = SubscribeStub::scripted(vec![vec![
            Ok(event(1, &workflow_id)),
            Ok(event(2, &workflow_id)),
        ]]);
        // `MIN` is the non-zero value 1, i.e. "resume from the first event".
        let resume_from = std::num::NonZeroU64::MIN;
        let events =
            ResumingEventStream::from_sequence(stub.clone(), "tenant-a", workflow_id, resume_from);

        let seqs = collect_seqs(events).await;

        assert_eq!(seqs, vec![1, 2]);
        assert_eq!(
            *stub.resume_points.lock().await,
            vec![Some(1)],
            "the initial attach must carry the explicit cursor"
        );
    }

    #[tokio::test]
    async fn live_only_streams_reconnect_only_before_any_delivery() {
        let workflow_id = WorkflowId::new_v4();
        let stub = SubscribeStub::scripted(vec![
            vec![Err(ClientError::unavailable("transient disconnect"))],
            vec![Ok(event(1, &workflow_id))],
        ]);
        let events = ResumingEventStream::new(
            stub.clone(),
            "tenant-a",
            SubscribeTarget::Filtered {
                filter: aion_core::WorkflowFilter::default(),
            },
        );

        let seqs = collect_seqs(events).await;

        assert_eq!(seqs, vec![1]);
        assert_eq!(
            *stub.resume_points.lock().await,
            vec![None, None],
            "live-only streams never carry a resume cursor"
        );
    }

    #[tokio::test]
    async fn live_only_disconnect_after_delivery_is_honest_unavailable() {
        let live_only_targets = [
            SubscribeTarget::Filtered {
                filter: aion_core::WorkflowFilter::default(),
            },
            SubscribeTarget::Firehose,
        ];
        for target in live_only_targets {
            let workflow_id = WorkflowId::new_v4();
            let stub = SubscribeStub::scripted(vec![vec![
                Ok(event(1, &workflow_id)),
                Err(ClientError::unavailable("transient disconnect")),
            ]]);
            let mut events = ResumingEventStream::new(stub.clone(), "tenant-a", target);

            let first = events.next().await;
            assert!(matches!(first, Some(Ok(_))), "got {first:?}");
            assert_eq!(
                events.next().await,
                Some(Err(ClientError::unavailable("transient disconnect")))
            );
            assert_eq!(events.next().await, None);
            assert_eq!(
                stub.resume_points.lock().await.len(),
                1,
                "no reattach may follow a post-delivery live-only disconnect"
            );
        }
    }

    #[tokio::test]
    async fn live_only_streams_do_not_dedupe_sequence_numbers_across_workflows() {
        let first_workflow = WorkflowId::new_v4();
        let second_workflow = WorkflowId::new_v4();
        // Both events share seq 1 but belong to different workflows.
        let stub = SubscribeStub::scripted(vec![vec![
            Ok(event(1, &first_workflow)),
            Ok(event(1, &second_workflow)),
        ]]);
        let events = ResumingEventStream::new(stub, "tenant-a", SubscribeTarget::Firehose);

        let delivered: Vec<WorkflowId> = collect_events(events)
            .await
            .iter()
            .map(|event| event.envelope().workflow_id.clone())
            .collect();

        assert_eq!(delivered, vec![first_workflow, second_workflow]);
    }

    #[tokio::test]
    async fn not_found_is_terminal_and_never_retried() {
        let workflow_id = WorkflowId::new_v4();
        let stub = SubscribeStub::scripted(vec![vec![Err(ClientError::not_found(
            "workflow was not found",
        ))]]);
        let mut events = ResumingEventStream::new(
            stub.clone(),
            "tenant-a",
            SubscribeTarget::Workflow { workflow_id },
        );

        assert_eq!(
            events.next().await,
            Some(Err(ClientError::not_found("workflow was not found")))
        );
        assert_eq!(events.next().await, None);
        assert_eq!(stub.resume_points.lock().await.len(), 1);
    }
}