1use std::collections::HashMap;
2use std::sync::Arc;
3
4use futures_util::stream::Stream;
5use redis::{
6 AsyncConnectionConfig, Client, FromRedisValue, PushInfo, PushKind, RedisError, RedisResult,
7 Value, cmd,
8};
9use tokio::sync::{Mutex, mpsc};
10use tracing::error;
11use uuid::Uuid;
12
13use crate::types::{Event, SierraMessage};
14
15#[derive(Debug, Clone)]
17enum PartitionSelector {
18 Id(u16),
19 Key(Uuid),
20}
21
22#[derive(Clone)]
25pub struct SubscriptionManager {
26 inner: Arc<Mutex<SubscriptionManagerInner>>,
27 }
29
30struct SubscriptionManagerInner {
31 connection: redis::aio::MultiplexedConnection,
32 subscriptions: HashMap<Uuid, mpsc::UnboundedSender<SierraMessage>>,
33}
34
35impl SubscriptionManager {
36 pub async fn new(client: &Client) -> RedisResult<Self> {
41 let (tx, mut rx) = mpsc::unbounded_channel::<PushInfo>();
42
43 let connection = client
44 .get_multiplexed_async_connection_with_config(
45 &AsyncConnectionConfig::new().set_push_sender(tx),
46 )
47 .await?;
48
49 let inner = Arc::new(Mutex::new(SubscriptionManagerInner {
50 connection,
51 subscriptions: HashMap::new(),
52 }));
53
54 let inner_clone = inner.clone();
56 let _background_task = tokio::spawn(async move {
57 while let Some(push_info) = rx.recv().await {
58 let mut inner_guard = inner_clone.lock().await;
59 if let Err(err) = inner_guard.handle_push_message(push_info).await {
60 error!("error handling push message: {err}");
61 }
62 }
63 });
64
65 Ok(SubscriptionManager {
66 inner,
67 })
69 }
70
71 pub async fn subscribe_to_stream<S: redis::ToRedisArgs>(
73 &mut self,
74 stream_id: S,
75 ) -> RedisResult<EventSubscription> {
76 self.subscribe_to_stream_with_options(stream_id, None, None, None)
77 .await
78 }
79
80 pub async fn subscribe_to_stream_with_window<S: redis::ToRedisArgs>(
82 &mut self,
83 stream_id: S,
84 window_size: u32,
85 ) -> RedisResult<EventSubscription> {
86 self.subscribe_to_stream_with_options(stream_id, None, None, Some(window_size))
87 .await
88 }
89
90 pub async fn subscribe_to_stream_from_version<S: redis::ToRedisArgs>(
92 &mut self,
93 stream_id: S,
94 from_version: u64,
95 ) -> RedisResult<EventSubscription> {
96 self.subscribe_to_stream_with_options(stream_id, None, Some(from_version), None)
97 .await
98 }
99
100 pub async fn subscribe_to_stream_from_version_with_window<S: redis::ToRedisArgs>(
103 &mut self,
104 stream_id: S,
105 from_version: u64,
106 window_size: u32,
107 ) -> RedisResult<EventSubscription> {
108 self.subscribe_to_stream_with_options(
109 stream_id,
110 None,
111 Some(from_version),
112 Some(window_size),
113 )
114 .await
115 }
116
117 async fn subscribe_to_stream_with_options<S: redis::ToRedisArgs>(
119 &mut self,
120 stream_id: S,
121 partition_key: Option<Uuid>,
122 from_version: Option<u64>,
123 window_size: Option<u32>,
124 ) -> RedisResult<EventSubscription> {
125 let mut inner = self.inner.lock().await;
126
127 let mut cmd = cmd("ESUB");
128 cmd.arg(stream_id);
129
130 if let Some(key) = partition_key {
131 cmd.arg("PARTITION_KEY").arg(key.to_string());
132 }
133
134 if let Some(version) = from_version {
135 cmd.arg("FROM").arg(version);
136 }
137
138 if let Some(size) = window_size {
139 cmd.arg("WINDOW").arg(size);
140 }
141
142 let response: Value = cmd.query_async(&mut inner.connection).await?;
143
144 let subscription_id = match response {
145 Value::SimpleString(id_str) => Uuid::parse_str(&id_str).map_err(|_| {
146 RedisError::from((redis::ErrorKind::TypeError, "Invalid UUID in response"))
147 })?,
148 _ => {
149 return Err(RedisError::from((
150 redis::ErrorKind::TypeError,
151 "Expected subscription ID",
152 )));
153 }
154 };
155
156 let (sender, receiver) = mpsc::unbounded_channel();
157 inner.subscriptions.insert(subscription_id, sender);
158
159 Ok(EventSubscription {
160 subscription_id,
161 receiver,
162 manager: self.inner.clone(),
163 })
164 }
165
166 pub async fn subscribe_to_partition(
168 &mut self,
169 partition: u16,
170 ) -> RedisResult<EventSubscription> {
171 self.subscribe_to_partition_with_options(PartitionSelector::Id(partition), None, None)
172 .await
173 }
174
175 pub async fn subscribe_to_partition_with_window(
177 &mut self,
178 partition: u16,
179 window_size: u32,
180 ) -> RedisResult<EventSubscription> {
181 self.subscribe_to_partition_with_options(
182 PartitionSelector::Id(partition),
183 None,
184 Some(window_size),
185 )
186 .await
187 }
188
189 pub async fn subscribe_to_partition_key(
191 &mut self,
192 key: Uuid,
193 ) -> RedisResult<EventSubscription> {
194 self.subscribe_to_partition_with_options(PartitionSelector::Key(key), None, None)
195 .await
196 }
197
198 pub async fn subscribe_to_partition_key_with_window(
201 &mut self,
202 key: Uuid,
203 window_size: u32,
204 ) -> RedisResult<EventSubscription> {
205 self.subscribe_to_partition_with_options(
206 PartitionSelector::Key(key),
207 None,
208 Some(window_size),
209 )
210 .await
211 }
212
213 pub async fn subscribe_to_stream_with_partition_key<S: redis::ToRedisArgs>(
215 &mut self,
216 stream_id: S,
217 partition_key: Uuid,
218 ) -> RedisResult<EventSubscription> {
219 self.subscribe_to_stream_with_options(stream_id, Some(partition_key), None, None)
220 .await
221 }
222
223 pub async fn subscribe_to_stream_with_partition_key_and_window<S: redis::ToRedisArgs>(
225 &mut self,
226 stream_id: S,
227 partition_key: Uuid,
228 window_size: u32,
229 ) -> RedisResult<EventSubscription> {
230 self.subscribe_to_stream_with_options(
231 stream_id,
232 Some(partition_key),
233 None,
234 Some(window_size),
235 )
236 .await
237 }
238
239 pub async fn subscribe_to_stream_with_partition_and_version<S: redis::ToRedisArgs>(
242 &mut self,
243 stream_id: S,
244 partition_key: Uuid,
245 from_version: u64,
246 ) -> RedisResult<EventSubscription> {
247 self.subscribe_to_stream_with_options(
248 stream_id,
249 Some(partition_key),
250 Some(from_version),
251 None,
252 )
253 .await
254 }
255
256 pub async fn subscribe_to_stream_with_partition_and_version_and_window<
259 S: redis::ToRedisArgs,
260 >(
261 &mut self,
262 stream_id: S,
263 partition_key: Uuid,
264 from_version: u64,
265 window_size: u32,
266 ) -> RedisResult<EventSubscription> {
267 self.subscribe_to_stream_with_options(
268 stream_id,
269 Some(partition_key),
270 Some(from_version),
271 Some(window_size),
272 )
273 .await
274 }
275
276 pub async fn subscribe_to_partition_from_sequence(
278 &mut self,
279 partition: u16,
280 from_sequence: u64,
281 ) -> RedisResult<EventSubscription> {
282 self.subscribe_to_partition_with_options(
283 PartitionSelector::Id(partition),
284 Some(from_sequence),
285 None,
286 )
287 .await
288 }
289
290 pub async fn subscribe_to_partition_from_sequence_with_window(
293 &mut self,
294 partition: u16,
295 from_sequence: u64,
296 window_size: u32,
297 ) -> RedisResult<EventSubscription> {
298 self.subscribe_to_partition_with_options(
299 PartitionSelector::Id(partition),
300 Some(from_sequence),
301 Some(window_size),
302 )
303 .await
304 }
305
306 pub async fn subscribe_to_partition_key_from_sequence(
309 &mut self,
310 key: Uuid,
311 from_sequence: u64,
312 ) -> RedisResult<EventSubscription> {
313 self.subscribe_to_partition_with_options(
314 PartitionSelector::Key(key),
315 Some(from_sequence),
316 None,
317 )
318 .await
319 }
320
321 pub async fn subscribe_to_partition_key_from_sequence_with_window(
324 &mut self,
325 key: Uuid,
326 from_sequence: u64,
327 window_size: u32,
328 ) -> RedisResult<EventSubscription> {
329 self.subscribe_to_partition_with_options(
330 PartitionSelector::Key(key),
331 Some(from_sequence),
332 Some(window_size),
333 )
334 .await
335 }
336
337 async fn subscribe_to_partition_with_options(
339 &mut self,
340 partition: PartitionSelector,
341 from_sequence: Option<u64>,
342 window_size: Option<u32>,
343 ) -> RedisResult<EventSubscription> {
344 let mut inner = self.inner.lock().await;
345
346 let mut cmd = cmd("EPSUB");
347
348 match partition {
349 PartitionSelector::Id(id) => cmd.arg(id),
350 PartitionSelector::Key(key) => cmd.arg(key.to_string()),
351 };
352
353 if let Some(sequence) = from_sequence {
354 cmd.arg("FROM").arg(sequence);
355 }
356
357 if let Some(size) = window_size {
358 cmd.arg("WINDOW").arg(size);
359 }
360
361 let response: Value = cmd.query_async(&mut inner.connection).await?;
362
363 let subscription_id = match response {
364 Value::SimpleString(id_str) => Uuid::parse_str(&id_str).map_err(|_| {
365 RedisError::from((redis::ErrorKind::TypeError, "Invalid UUID in response"))
366 })?,
367 _ => {
368 return Err(RedisError::from((
369 redis::ErrorKind::TypeError,
370 "Expected subscription ID",
371 )));
372 }
373 };
374
375 let (sender, receiver) = mpsc::unbounded_channel();
376 inner.subscriptions.insert(subscription_id, sender);
377
378 Ok(EventSubscription {
379 subscription_id,
380 receiver,
381 manager: self.inner.clone(),
382 })
383 }
384
385 pub async fn acknowledge_up_to_cursor(
390 &mut self,
391 subscription_id: Uuid,
392 cursor: u64,
393 ) -> RedisResult<()> {
394 let mut inner = self.inner.lock().await;
395
396 let _response: Value = cmd("EACK")
397 .arg(subscription_id.to_string())
398 .arg(cursor)
399 .query_async(&mut inner.connection)
400 .await?;
401
402 Ok(())
403 }
404
405 pub async fn subscribe_to_partitions(
421 &mut self,
422 partition_range: &str,
423 from_sequence: u64,
424 window_size: Option<u32>,
425 ) -> RedisResult<EventSubscription> {
426 let mut inner = self.inner.lock().await;
427
428 let mut cmd = cmd("EPSUB");
429 cmd.arg(partition_range);
430 cmd.arg("FROM").arg(from_sequence);
431
432 if let Some(size) = window_size {
433 cmd.arg("WINDOW").arg(size);
434 }
435
436 let response: Value = cmd.query_async(&mut inner.connection).await?;
437
438 let subscription_id = match response {
439 Value::SimpleString(id_str) => Uuid::parse_str(&id_str).map_err(|_| {
440 RedisError::from((redis::ErrorKind::TypeError, "Invalid UUID in response"))
441 })?,
442 _ => {
443 return Err(RedisError::from((
444 redis::ErrorKind::TypeError,
445 "Expected subscription ID",
446 )));
447 }
448 };
449
450 let (sender, receiver) = mpsc::unbounded_channel();
451 inner.subscriptions.insert(subscription_id, sender);
452
453 Ok(EventSubscription {
454 subscription_id,
455 receiver,
456 manager: self.inner.clone(),
457 })
458 }
459
460 pub async fn subscribe_to_partitions_with_sequences(
476 &mut self,
477 partition_sequences: HashMap<u16, u64>,
478 window_size: Option<u32>,
479 ) -> RedisResult<EventSubscription> {
480 if partition_sequences.is_empty() {
481 return Err(RedisError::from((
482 redis::ErrorKind::InvalidClientConfig,
483 "At least one partition must be specified",
484 )));
485 }
486
487 let mut inner = self.inner.lock().await;
488
489 let mut cmd = cmd("EPSUB");
490
491 let partition_list: Vec<String> =
493 partition_sequences.keys().map(|&p| p.to_string()).collect();
494 cmd.arg(partition_list.join(","));
495
496 cmd.arg("FROM");
497 cmd.arg("MAP");
498
499 for (partition_id, sequence) in &partition_sequences {
501 cmd.arg(format!("{partition_id}={sequence}"));
502 }
503
504 if let Some(size) = window_size {
505 cmd.arg("WINDOW").arg(size);
506 }
507
508 let response: Value = cmd.query_async(&mut inner.connection).await?;
509
510 let subscription_id = match response {
511 Value::SimpleString(id_str) => Uuid::parse_str(&id_str).map_err(|_| {
512 RedisError::from((redis::ErrorKind::TypeError, "Invalid UUID in response"))
513 })?,
514 _ => {
515 return Err(RedisError::from((
516 redis::ErrorKind::TypeError,
517 "Expected subscription ID",
518 )));
519 }
520 };
521
522 let (sender, receiver) = mpsc::unbounded_channel();
523 inner.subscriptions.insert(subscription_id, sender);
524
525 Ok(EventSubscription {
526 subscription_id,
527 receiver,
528 manager: self.inner.clone(),
529 })
530 }
531
532 pub async fn subscribe_to_all_partitions(
537 &mut self,
538 from_sequence: u64,
539 window_size: Option<u32>,
540 ) -> RedisResult<EventSubscription> {
541 self.subscribe_to_partitions("*", from_sequence, window_size)
542 .await
543 }
544
545 pub async fn subscribe_to_partition_range(
559 &mut self,
560 start_partition: u16,
561 end_partition: u16,
562 from_sequence: u64,
563 window_size: Option<u32>,
564 ) -> RedisResult<EventSubscription> {
565 let range = format!("{start_partition}-{end_partition}");
566 self.subscribe_to_partitions(&range, from_sequence, window_size)
567 .await
568 }
569
570 pub async fn subscribe_to_stream_from_latest<S: redis::ToRedisArgs>(
572 &mut self,
573 stream_id: S,
574 ) -> RedisResult<EventSubscription> {
575 let mut inner = self.inner.lock().await;
576
577 let mut cmd = cmd("ESUB");
578 cmd.arg(stream_id);
579 cmd.arg("FROM");
580 cmd.arg("LATEST");
581
582 let response: Value = cmd.query_async(&mut inner.connection).await?;
583
584 let subscription_id = match response {
585 Value::SimpleString(id_str) => Uuid::parse_str(&id_str).map_err(|_| {
586 RedisError::from((redis::ErrorKind::TypeError, "Invalid UUID in response"))
587 })?,
588 _ => {
589 return Err(RedisError::from((
590 redis::ErrorKind::TypeError,
591 "Expected subscription ID",
592 )));
593 }
594 };
595
596 let (sender, receiver) = mpsc::unbounded_channel();
597 inner.subscriptions.insert(subscription_id, sender);
598
599 Ok(EventSubscription {
600 subscription_id,
601 receiver,
602 manager: self.inner.clone(),
603 })
604 }
605
606 pub async fn subscribe_to_all_partitions_from_latest(
608 &mut self,
609 ) -> RedisResult<EventSubscription> {
610 let mut inner = self.inner.lock().await;
611
612 let mut cmd = cmd("EPSUB");
613 cmd.arg("*");
614 cmd.arg("FROM");
615 cmd.arg("LATEST");
616
617 let response: Value = cmd.query_async(&mut inner.connection).await?;
618
619 let subscription_id = match response {
620 Value::SimpleString(id_str) => Uuid::parse_str(&id_str).map_err(|_| {
621 RedisError::from((redis::ErrorKind::TypeError, "Invalid UUID in response"))
622 })?,
623 _ => {
624 return Err(RedisError::from((
625 redis::ErrorKind::TypeError,
626 "Expected subscription ID",
627 )));
628 }
629 };
630
631 let (sender, receiver) = mpsc::unbounded_channel();
632 inner.subscriptions.insert(subscription_id, sender);
633
634 Ok(EventSubscription {
635 subscription_id,
636 receiver,
637 manager: self.inner.clone(),
638 })
639 }
640
641 pub async fn subscribe_to_all_partitions_with_fallback(
660 &mut self,
661 partition_sequences: HashMap<u16, u64>,
662 fallback_sequence: u64,
663 window_size: Option<u32>,
664 ) -> RedisResult<EventSubscription> {
665 self.subscribe_to_all_partitions_flexible(
666 partition_sequences,
667 Some(fallback_sequence),
668 window_size,
669 )
670 .await
671 }
672
673 pub async fn subscribe_to_all_partitions_flexible(
698 &mut self,
699 from_map: HashMap<u16, u64>,
700 fallback_sequence: Option<u64>,
701 window_size: Option<u32>,
702 ) -> RedisResult<EventSubscription> {
703 let mut inner = self.inner.lock().await;
704
705 let mut cmd = cmd("EPSUB");
706 cmd.arg("*");
707
708 if from_map.is_empty() {
710 match fallback_sequence {
711 None => {
712 cmd.arg("FROM");
714 cmd.arg("LATEST");
715 }
716 Some(fallback) => {
717 cmd.arg("FROM");
719 cmd.arg(fallback);
720 }
721 }
722 } else {
723 cmd.arg("FROM");
725 cmd.arg("MAP");
726
727 for (partition_id, sequence) in &from_map {
729 cmd.arg(format!("{partition_id}={sequence}"));
730 }
731
732 if let Some(fallback) = fallback_sequence {
734 cmd.arg("DEFAULT");
735 cmd.arg(fallback);
736 }
737 }
738
739 if let Some(size) = window_size {
740 cmd.arg("WINDOW").arg(size);
741 }
742
743 let response: Value = cmd.query_async(&mut inner.connection).await?;
744
745 let subscription_id = match response {
746 Value::SimpleString(id_str) => Uuid::parse_str(&id_str).map_err(|_| {
747 RedisError::from((redis::ErrorKind::TypeError, "Invalid UUID in response"))
748 })?,
749 _ => {
750 return Err(RedisError::from((
751 redis::ErrorKind::TypeError,
752 "Expected subscription ID",
753 )));
754 }
755 };
756
757 let (sender, receiver) = mpsc::unbounded_channel();
758 inner.subscriptions.insert(subscription_id, sender);
759
760 Ok(EventSubscription {
761 subscription_id,
762 receiver,
763 manager: self.inner.clone(),
764 })
765 }
766}
767
768impl SubscriptionManagerInner {
769 async fn handle_push_message(&mut self, push_info: PushInfo) -> Result<(), RedisError> {
770 let PushInfo { kind, data } = push_info;
771
772 match kind {
773 PushKind::Message => {
774 match data.as_slice() {
776 [
777 Value::SimpleString(subscription_id_str),
778 Value::Int(cursor),
779 event_value,
780 ] => {
781 let subscription_id =
782 Uuid::parse_str(subscription_id_str).map_err(|_| {
783 RedisError::from((
784 redis::ErrorKind::TypeError,
785 "Invalid subscription ID",
786 ))
787 })?;
788
789 if let Some(sender) = self.subscriptions.get(&subscription_id) {
790 let event = Event::from_redis_value(event_value)?;
791 let cursor = *cursor as u64;
792 let message = SierraMessage::Event { event, cursor };
793
794 if sender.send(message).is_err() {
795 self.subscriptions.remove(&subscription_id);
797 }
798 }
799 }
800 _ => {
801 return Err(RedisError::from((
802 redis::ErrorKind::TypeError,
803 "Unexpected message format",
804 )));
805 }
806 }
807 }
808 PushKind::Subscribe => {
809 match data.as_slice() {
811 [Value::SimpleString(subscription_id_str), Value::Int(count)] => {
812 let subscription_id =
813 Uuid::parse_str(subscription_id_str).map_err(|_| {
814 RedisError::from((
815 redis::ErrorKind::TypeError,
816 "Invalid subscription ID",
817 ))
818 })?;
819
820 if let Some(sender) = self.subscriptions.get(&subscription_id) {
821 let message = SierraMessage::SubscriptionConfirmed {
822 subscription_count: *count,
823 };
824
825 if sender.send(message).is_err() {
826 self.subscriptions.remove(&subscription_id);
828 }
829 }
830 }
831 _ => {
832 return Err(RedisError::from((
833 redis::ErrorKind::TypeError,
834 "Unexpected subscribe format",
835 )));
836 }
837 }
838 }
839 PushKind::Disconnection => {}
840 kind => {
841 return Err(RedisError::from((
842 redis::ErrorKind::TypeError,
843 "Unknown push kind",
844 kind.to_string(),
845 )));
846 }
847 }
848
849 Ok(())
850 }
851}
852
853pub struct EventSubscription {
857 subscription_id: Uuid,
858 receiver: mpsc::UnboundedReceiver<SierraMessage>,
859 manager: Arc<Mutex<SubscriptionManagerInner>>,
860}
861
862impl EventSubscription {
863 pub fn subscription_id(&self) -> Uuid {
865 self.subscription_id
866 }
867
868 pub async fn next_message(&mut self) -> Option<SierraMessage> {
872 self.receiver.recv().await
873 }
874
875 pub fn into_stream(self) -> impl Stream<Item = SierraMessage> {
881 futures_util::stream::unfold(self, |mut subscription| async move {
882 subscription
883 .next_message()
884 .await
885 .map(|msg| (msg, subscription))
886 })
887 }
888
889 pub async fn acknowledge_up_to_cursor(&self, cursor: u64) -> RedisResult<()> {
894 let mut manager = self.manager.lock().await;
895
896 let _response: Value = cmd("EACK")
897 .arg(self.subscription_id.to_string())
898 .arg(cursor)
899 .query_async(&mut manager.connection)
900 .await?;
901
902 Ok(())
903 }
904
905 pub async fn unsubscribe(mut self) -> RedisResult<()> {
909 self.receiver.close();
911
912 let mut manager = self.manager.lock().await;
914 manager.subscriptions.remove(&self.subscription_id);
915
916 Ok(())
918 }
919}
920
921impl Drop for EventSubscription {
922 fn drop(&mut self) {
923 let manager = self.manager.clone();
925 let subscription_id = self.subscription_id;
926
927 tokio::spawn(async move {
928 let mut inner = manager.lock().await;
929 inner.subscriptions.remove(&subscription_id);
930 });
931 }
932}