1use ankurah_proto::{self as proto, CollectionId};
3use anyhow::anyhow;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::{Arc, OnceLock};
7use tracing::{debug, warn};
8
9use crate::error::{RequestError, RetrievalError};
10use crate::node::ContextData;
11use crate::util::safeset::SafeSet;
12
13#[async_trait::async_trait]
16pub trait RemoteQuerySubscriber: Clone + Send + Sync + 'static {
17 async fn subscription_established(&self, version: u32);
21
22 fn set_last_error(&self, error: RetrievalError);
24}
25
26#[derive(Debug, Clone)]
27pub enum Status {
28 PendingRemote,
29 Requested(proto::EntityId, u32), Established(proto::EntityId, u32), PendingUpdate(proto::EntityId, u32), Failed,
34}
35
/// One version of a query, exactly as it is sent to a remote peer.
///
/// Stored behind an `Arc` in `RemoteQueryState` so it can be handed to spawned tasks;
/// `update_query` replaces the whole value instead of mutating it.
#[derive(Debug)]
pub struct Content<CD: ContextData> {
    /// Query identifier; also the key of the relay's subscription table.
    pub query_id: proto::QueryId,
    /// Collection the selection is evaluated against.
    pub collection_id: CollectionId,
    /// The query's selection (predicate, ordering, limit).
    pub selection: ankql::ast::Selection,
    /// Context data passed along with every remote request for this query.
    pub context_data: CD,
    /// Version supplied by the caller of `subscribe_query` / `update_query`.
    pub version: u32,
}
44
/// The relay's bookkeeping entry for one query.
pub struct RemoteQueryState<CD: ContextData, Q: RemoteQuerySubscriber> {
    /// The current version of the query; swapped out wholesale on update.
    pub content: Arc<Content<CD>>,
    /// Current remote registration status.
    pub status: Status,
    /// Local owner that is notified when registration succeeds or permanently fails.
    pub livequery: Q,
}
50
/// State shared by every clone of a `SubscriptionRelay`.
struct SubscriptionRelayInner<CD: ContextData, Q: RemoteQuerySubscriber> {
    /// All queries known to the relay, keyed by query id.
    subscriptions: std::sync::Mutex<HashMap<proto::QueryId, RemoteQueryState<CD, Q>>>,
    /// Currently connected peers; setup passes target the first one returned by `to_vec`.
    connected_peers: SafeSet<proto::EntityId>,
    /// Node used to send (un)subscribe requests; installed once through `set_node`.
    node: OnceLock<Arc<dyn TNode<CD>>>,
    /// Never sent on. Dropping it closes the channel the retry task waits on in its shutdown arm.
    _shutdown_tx: tokio::sync::mpsc::Sender<()>,
}
61
/// Mirrors locally registered queries onto a connected remote peer.
///
/// Tracks each query's [`Status`] and sends pending queries when a peer connects, when a query
/// is added or updated, and on a periodic retry. It reports the outcome to the query's
/// [`RemoteQuerySubscriber`] and unsubscribes established queries when they are removed.
/// Clones share the same state through an `Arc`.
#[derive(Clone)]
pub struct SubscriptionRelay<CD: ContextData, Q: RemoteQuerySubscriber> {
    inner: Arc<SubscriptionRelayInner<CD, Q>>,
}
93
94impl<CD: ContextData, Q: RemoteQuerySubscriber> Default for SubscriptionRelay<CD, Q> {
95 fn default() -> Self { Self::new() }
96}
97
98impl<CD: ContextData, Q: RemoteQuerySubscriber> SubscriptionRelay<CD, Q> {
99 pub fn new() -> Self {
100 let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::channel(1);
101
102 let relay = Self {
103 inner: Arc::new(SubscriptionRelayInner {
104 subscriptions: std::sync::Mutex::new(HashMap::new()),
105 connected_peers: SafeSet::new(),
106 node: OnceLock::new(),
107 _shutdown_tx: shutdown_tx,
108 }),
109 };
110
111 relay.start_retry_task(shutdown_rx);
113
114 relay
115 }
116
117 pub fn set_node(&self, node: Arc<dyn TNode<CD>>) -> Result<(), ()> { self.inner.node.set(node).map_err(|_| ()) }
122
123 pub fn subscribe_query(
128 &self,
129 query_id: proto::QueryId,
130 collection_id: CollectionId,
131 selection: ankql::ast::Selection,
132 context_data: CD,
133 version: u32,
134 livequery: Q,
135 ) {
136 debug!("SubscriptionRelay.subscribe_predicate() - New predicate {} needs remote registration", query_id);
137 {
138 self.inner.subscriptions.lock().expect("poisoned lock").insert(
139 query_id,
140 RemoteQueryState {
141 content: Arc::new(Content { collection_id, selection, context_data, query_id, version }),
142 status: Status::PendingRemote,
143 livequery,
144 },
145 );
146 }
147
148 if !self.inner.connected_peers.is_empty() {
150 self.setup_remote_subscriptions()
151 }
152 }
153 pub fn update_query(&self, query_id: proto::QueryId, selection: ankql::ast::Selection, version: u32) -> Result<(), anyhow::Error> {
154 debug!("SubscriptionRelay.update_query() - New query {} needs remote registration", query_id);
155
156 let update = {
157 let mut subscriptions = self.inner.subscriptions.lock().expect("poisoned lock");
158 match subscriptions.get_mut(&query_id) {
159 Some(state) => {
160 let old_content = &state.content;
162 state.content = Arc::new(Content {
163 collection_id: old_content.collection_id.clone(),
164 selection: selection.clone(),
165 context_data: old_content.context_data.clone(),
166 query_id: old_content.query_id,
167 version,
168 });
169
170 match state.status {
171 Status::Established(peer_id, _old_version) => {
172 state.status = Status::Requested(peer_id, version);
174 Some((peer_id, state.content.collection_id.clone(), state.content.context_data.clone()))
175 }
177 _ => {
178 state.status = Status::PendingRemote;
180 None
181 }
182 }
183 }
184 None => return Err(anyhow!("Predicate {} not found", query_id)),
185 }
186 };
187
188 match update {
189 Some((peer_id, collection_id, context_data)) => {
190 self.update_query_on_peer(peer_id, query_id, collection_id, selection, version, context_data);
191 }
192 None => {
193 self.setup_remote_subscriptions();
195 }
196 };
197
198 Ok(())
199 }
200
201 fn update_query_on_peer(
202 &self,
203 peer_id: proto::EntityId,
204 query_id: proto::QueryId,
205 collection_id: CollectionId,
206 selection: ankql::ast::Selection,
207 version: u32,
208 context_data: CD,
209 ) {
210 let me = self.clone();
211 crate::task::spawn(async move {
212 if let Some(node) = me.inner.node.get() {
213 let livequery = {
215 me.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner()).get(&query_id).map(|state| state.livequery.clone())
216 };
217
218 match node.remote_subscribe(peer_id, query_id, collection_id, selection, &context_data, version).await {
220 Ok(()) => {
221 if let Some(lq) = livequery {
223 lq.subscription_established(version).await;
224 }
225
226 let mut subscriptions = me.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
228 if let Some(info) = subscriptions.get_mut(&query_id) {
229 info.status = Status::Established(peer_id, version);
230 }
231 debug!("Successfully updated predicate {} on peer {} subscription", query_id, peer_id);
232 }
233 Err(e) => {
234 me.handle_error(query_id, peer_id, e, livequery).await;
236 }
237 }
238 }
239 });
240 }
241
242 pub fn unsubscribe_predicate(&self, query_id: proto::QueryId) {
247 debug!("Unregistering predicate {}", query_id);
248
249 {
251 let mut subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
252 if let Some(info) = subscriptions.remove(&query_id) {
253 if let Status::Established(peer_id, _version) = &info.status {
254 let node = self.inner.node.get();
255 if let Some(node) = node {
256 let node = node.clone();
257 let peer_id = *peer_id;
258 crate::task::spawn(async move {
259 if let Err(e) = node.peer_unsubscribe(peer_id, query_id).await {
260 warn!("Failed to send unsubscribe message for {}: {}", query_id, e);
261 } else {
262 debug!("Successfully sent unsubscribe message for {}", query_id);
263 }
264 });
265 }
266 }
267 }
268 }
269 }
270
271 pub fn notify_peer_disconnected(&self, peer_id: proto::EntityId) {
277 debug!("Peer {} disconnected, orphaning predicate registrations", peer_id);
278
279 self.inner.connected_peers.remove(&peer_id);
281
282 for info in self.inner.subscriptions.lock().expect("poisoned lock").values_mut() {
283 if let Status::Established(established_peer_id, _) | Status::Requested(established_peer_id, _) = &info.status {
284 if *established_peer_id == peer_id {
285 info.status = Status::PendingRemote;
287 warn!("Predicate {} orphaned due to peer {} disconnect", info.content.query_id, peer_id);
288 }
289 }
290 }
291
292 self.setup_remote_subscriptions();
294 }
295
296 pub fn notify_peer_connected(&self, peer_id: proto::EntityId) {
301 debug!("SubscriptionRelay.notify_peer_connected() - Peer {} connected, registering predicates on peer subscription", peer_id);
302
303 self.inner.connected_peers.insert(peer_id);
305
306 self.setup_remote_subscriptions();
308 }
309
310 pub fn get_status(&self, query_id: proto::QueryId) -> Option<Status> {
312 let subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
313 subscriptions.get(&query_id).map(|info| info.status.clone())
314 }
315
316 pub fn get_contexts_for_peer(&self, peer_id: &proto::EntityId) -> std::collections::HashSet<CD> {
319 let subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
320 let mut contexts = std::collections::HashSet::new();
321
322 for (_, state) in subscriptions.iter() {
323 match &state.status {
324 Status::Established(established_peer, _) | Status::Requested(established_peer, _) => {
325 if established_peer == peer_id {
326 contexts.insert(state.content.context_data.clone());
327 }
328 }
329 _ => {}
330 }
331 }
332
333 contexts
334 }
335
336 fn setup_remote_subscriptions(&self) {
338 let node = match self.inner.node.get() {
339 Some(node) => node,
340 None => {
341 warn!("No node configured for remote subscription setup");
342 return;
343 }
344 };
345
346 let connected_peers = self.inner.connected_peers.to_vec();
348 if connected_peers.is_empty() {
349 warn!("No durable peers available for remote subscription setup");
350 return;
351 }
352
353 let target_peer = connected_peers[0];
354
355 let pending: Vec<_> = {
357 self.inner
358 .subscriptions
359 .lock()
360 .expect("poisoned lock")
361 .values_mut()
362 .filter_map(|info| {
363 if let Status::PendingRemote = info.status {
364 info.status = Status::Requested(target_peer, info.content.version);
365 Some(info.content.clone())
366 } else {
367 None
368 }
369 })
370 .collect()
371 };
372
373 if pending.is_empty() {
374 return;
375 }
376
377 debug!("Registering {} predicates on {} peer subscriptions", pending.len(), self.inner.connected_peers.len());
378
379 for content in pending {
380 crate::task::spawn(self.clone().attempt_subscribe(node.clone(), target_peer, content));
381 }
382 }
383
384 async fn attempt_subscribe(self, node: Arc<dyn TNode<CD>>, target_peer: proto::EntityId, content: Arc<Content<CD>>) {
385 let query_id = content.query_id;
386 let predicate = content.selection.clone();
387 let context_data = content.context_data.clone();
388 let version = content.version;
389
390 let livequery =
392 { self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner()).get(&query_id).map(|state| state.livequery.clone()) };
393
394 match node.remote_subscribe(target_peer, query_id, content.collection_id.clone(), predicate, &context_data, version).await {
396 Ok(()) => {
397 if let Some(lq) = livequery {
400 lq.subscription_established(version).await;
401 }
402
403 let mut subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
405 if let Some(info) = subscriptions.get_mut(&query_id) {
406 info.status = Status::Established(target_peer, version);
407 }
408 debug!("Successfully registered predicate {} on peer {} subscription", query_id, target_peer);
409 }
410 Err(e) => {
411 self.handle_error(query_id, target_peer, e, livequery).await;
413 }
414 }
415 }
416
417 fn start_retry_task(&self, mut shutdown_rx: tokio::sync::mpsc::Receiver<()>) {
419 let me = self.clone();
420 crate::task::spawn(async move {
421 loop {
422 let delay = futures_timer::Delay::new(std::time::Duration::from_secs(5));
423 tokio::select! {
424 _ = delay => {
425 me.setup_remote_subscriptions();
427 }
428 _ = shutdown_rx.recv() => {
429 debug!("Retry task shutting down - SubscriptionRelay dropped");
430 break;
431 }
432 }
433 }
434 });
435 }
436
437 async fn handle_error(&self, query_id: proto::QueryId, target_peer: proto::EntityId, error: RetrievalError, livequery: Option<Q>) {
439 let error_msg = error.to_string();
440
441 let is_retryable = match &error {
443 RetrievalError::RequestError(req_err) => match req_err {
445 RequestError::PeerNotConnected => true,
446 RequestError::ConnectionLost => true,
447 RequestError::SendError(_) => true,
448 RequestError::InternalChannelClosed => true,
449 RequestError::ServerError(_) => false,
450 RequestError::UnexpectedResponse(_) => false,
451 RequestError::AccessDenied(_) => false,
452 },
453 _ => false,
455 };
456
457 let mut subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
459 if let Some(info) = subscriptions.get_mut(&query_id) {
460 if is_retryable {
461 info.status = Status::PendingRemote;
463 warn!("Retryable failure for predicate {} with peer {}: {} - will retry", query_id, target_peer, error_msg);
464 } else {
465 info.status = Status::Failed;
467 tracing::error!("Permanent failure for predicate {} with peer {}: {} - no retry", query_id, target_peer, error_msg);
468
469 if let Some(lq) = livequery {
471 lq.set_last_error(error);
472 }
473 }
474 }
475 }
476}
477
478#[async_trait]
480pub trait TNode<CD: ContextData>: Send + Sync {
481 async fn remote_subscribe(
485 &self,
486 peer_id: proto::EntityId,
487 query_id: proto::QueryId,
488 collection_id: CollectionId,
489 selection: ankql::ast::Selection,
490 context_data: &CD,
491 version: u32,
492 ) -> Result<(), RetrievalError>;
493
494 async fn peer_unsubscribe(&self, peer_id: proto::EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error>;
497}
498
499#[async_trait]
501impl<SE, PA> TNode<PA::ContextData> for crate::node::WeakNode<SE, PA>
502where
503 SE: crate::storage::StorageEngine + Send + Sync + 'static,
504 PA: crate::policy::PolicyAgent + Send + Sync + 'static,
505{
506 async fn remote_subscribe(
507 &self,
508 peer_id: proto::EntityId,
509 query_id: proto::QueryId,
510 collection_id: CollectionId,
511 selection: ankql::ast::Selection,
512 context_data: &PA::ContextData,
513 version: u32,
514 ) -> Result<(), RetrievalError> {
515 let node = self.upgrade().ok_or_else(|| RetrievalError::Other("Node has been dropped".to_string()))?;
516
517 let known_matches: Vec<ankurah_proto::KnownEntity> = node
519 .fetch_entities_from_local(&collection_id, &selection)
520 .await?
521 .into_iter()
522 .map(|entity| ankurah_proto::KnownEntity { entity_id: entity.id(), head: entity.head() })
523 .collect();
524
525 let deltas = match node
527 .request(
528 peer_id,
529 context_data,
530 ankurah_proto::NodeRequestBody::SubscribeQuery {
531 query_id,
532 collection: collection_id.clone(),
533 selection: selection.clone(),
534 version,
535 known_matches,
536 },
537 )
538 .await
539 .map_err(|e| RetrievalError::RequestError(e))?
540 {
541 ankurah_proto::NodeResponseBody::QuerySubscribed { query_id: _response_query_id, deltas } => deltas,
542 ankurah_proto::NodeResponseBody::Error(e) => return Err(RetrievalError::RequestError(RequestError::ServerError(e))),
543 other => return Err(RetrievalError::RequestError(RequestError::UnexpectedResponse(other))),
544 };
545
546 tracing::debug!(
547 "Node.remote_subscribe: query_id: {}, collection_id: {}, received deltas: {}",
548 query_id,
549 collection_id,
550 deltas.len()
551 );
552 let retriever = crate::retrieval::EphemeralNodeRetriever::new(collection_id, &node, context_data);
554 let apply_result = crate::node_applier::NodeApplier::apply_deltas(&node, &peer_id, deltas, &retriever).await;
555 let event_store_result = retriever.store_used_events().await;
556
557 apply_result?; event_store_result?;
559
560 Ok(())
561 }
562
563 async fn peer_unsubscribe(&self, peer_id: proto::EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error> {
564 let node = self.upgrade().ok_or_else(|| anyhow!("Node has been dropped"))?;
565
566 node.request_remote_unsubscribe(query_id, vec![peer_id]).await?;
568
569 Ok(())
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use ankurah_proto::EntityId;
    use std::sync::{Arc, Mutex};

    impl ContextData for CollectionId {}

    /// One request observed by the mock: (peer, query, collection, selection).
    type SentRequest = (EntityId, proto::QueryId, CollectionId, ankql::ast::Selection);

    /// [`TNode`] test double that records every request and can be primed to fail the next one.
    #[derive(Debug)]
    struct MockMessageSender<CD: ContextData> {
        /// Error returned (and cleared) by the next `remote_subscribe` or `peer_unsubscribe` call.
        next_error: Mutex<Option<RequestError>>,
        /// Requests seen so far, in order. Unsubscribes are recorded with the collection
        /// `"unsubscribe"` and an always-true selection.
        sent_requests: Mutex<Vec<SentRequest>>,
        _phantom: std::marker::PhantomData<CD>,
    }

    impl<CD: ContextData> MockMessageSender<CD> {
        fn new() -> Self {
            Self { next_error: Mutex::new(None), sent_requests: Mutex::new(Vec::new()), _phantom: std::marker::PhantomData }
        }

        /// Makes the next request fail with `error`.
        fn set_fail_next(&self, error: RequestError) { *self.next_error.lock().unwrap() = Some(error); }

        fn get_sent_requests(&self) -> Vec<SentRequest> { self.sent_requests.lock().unwrap().clone() }

        fn clear_sent_requests(&self) { self.sent_requests.lock().unwrap().clear(); }
    }

    #[async_trait]
    impl<CD: ContextData> TNode<CD> for MockMessageSender<CD> {
        async fn remote_subscribe(
            &self,
            peer_id: EntityId,
            query_id: proto::QueryId,
            collection_id: CollectionId,
            selection: ankql::ast::Selection,
            _context_data: &CD,
            _version: u32,
        ) -> Result<(), RetrievalError> {
            self.sent_requests.lock().unwrap().push((peer_id, query_id, collection_id, selection));

            match self.next_error.lock().unwrap().take() {
                Some(error) => Err(RetrievalError::RequestError(error)),
                None => Ok(()),
            }
        }

        async fn peer_unsubscribe(&self, peer_id: EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error> {
            self.sent_requests.lock().unwrap().push((peer_id, query_id, CollectionId::from("unsubscribe"), create_test_selection()));

            match self.next_error.lock().unwrap().take() {
                Some(error) => Err(anyhow!(error.to_string())),
                None => Ok(()),
            }
        }
    }

    /// Live-query stub that ignores every notification.
    #[derive(Clone)]
    struct MockLiveQuery;

    #[async_trait]
    impl RemoteQuerySubscriber for MockLiveQuery {
        async fn subscription_established(&self, _version: u32) {}

        fn set_last_error(&self, _error: RetrievalError) {}
    }

    fn create_test_selection() -> ankql::ast::Selection {
        ankql::ast::Selection { predicate: ankql::ast::Predicate::True, order_by: None, limit: None }
    }

    fn create_test_collection_id() -> CollectionId { CollectionId::from("test_collection") }

    /// Yields long enough for the relay's spawned tasks (which complete immediately against the
    /// mock) to run.
    async fn settle() { futures_timer::Delay::new(std::time::Duration::from_millis(10)).await; }

    /// A query subscribed while a peer is connected is sent to that peer and becomes established.
    #[tokio::test]
    async fn test_new_subscription_setup() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        relay.notify_peer_connected(peer_id);
        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);

        // The status flips synchronously; the request itself runs in a spawned task.
        assert!(matches!(relay.get_status(query_id), Some(Status::Requested(_, _))));

        settle().await;

        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].0, peer_id);
        assert_eq!(sent_requests[0].1, query_id);
        assert_eq!(sent_requests[0].2, collection_id);

        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
    }

    /// Losing the only peer puts its established queries back to `PendingRemote`.
    #[tokio::test]
    async fn test_peer_disconnection_orphans_subscriptions() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        relay.notify_peer_connected(peer_id);
        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);

        settle().await;
        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));

        relay.notify_peer_disconnected(peer_id);
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
    }

    /// A query subscribed before any peer exists is sent as soon as a peer connects.
    #[tokio::test]
    async fn test_peer_connection_triggers_setup() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));

        mock_sender.clear_sent_requests();

        relay.notify_peer_connected(peer_id);
        settle().await;

        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].0, peer_id);
        assert_eq!(sent_requests[0].1, query_id);

        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
    }

    /// After a reconnect, a non-retryable (server) error marks the query `Failed`.
    #[tokio::test]
    async fn test_failed_subscription_retry() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        relay.notify_peer_connected(peer_id);
        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);

        settle().await;
        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));

        relay.notify_peer_disconnected(peer_id);
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));

        mock_sender.clear_sent_requests();
        mock_sender.set_fail_next(RequestError::ServerError("Invalid predicate".to_string()));

        relay.notify_peer_connected(peer_id);
        settle().await;

        assert_eq!(mock_sender.get_sent_requests().len(), 1);
        assert!(matches!(relay.get_status(query_id), Some(Status::Failed)));
    }

    /// Setup passes only pick up `PendingRemote` queries; `Failed` ones are left alone.
    #[tokio::test]
    async fn test_retryable_vs_non_retryable_failures() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let retryable_query_id = proto::QueryId::new();
        let non_retryable_query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        relay.subscribe_query(retryable_query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
        relay.subscribe_query(non_retryable_query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);

        {
            let mut subscriptions = relay.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
            if let Some(info) = subscriptions.get_mut(&retryable_query_id) {
                info.status = Status::PendingRemote;
            }
            if let Some(info) = subscriptions.get_mut(&non_retryable_query_id) {
                info.status = Status::Failed;
            }
        }

        relay.notify_peer_connected(peer_id);
        settle().await;

        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].1, retryable_query_id);

        assert!(
            matches!(relay.get_status(retryable_query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id)
        );
        assert!(matches!(relay.get_status(non_retryable_query_id), Some(Status::Failed)));
    }

    /// Removing an established query sends an unsubscribe to its peer and forgets the query.
    #[tokio::test]
    async fn test_subscription_removal() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        relay.notify_peer_connected(peer_id);
        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);

        settle().await;
        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));

        mock_sender.clear_sent_requests();

        relay.unsubscribe_predicate(query_id);
        settle().await;

        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].0, peer_id);
        assert_eq!(sent_requests[0].1, query_id);

        assert!(relay.get_status(query_id).is_none());
    }

    /// Nothing is sent without both a node and a peer; once both exist the query is registered.
    #[tokio::test]
    async fn test_edge_cases() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
        settle().await;
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));

        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
        settle().await;
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
        assert!(mock_sender.get_sent_requests().is_empty());

        relay.notify_peer_connected(peer_id);
        settle().await;
        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
        assert_eq!(mock_sender.get_sent_requests().len(), 1);
    }

    /// Removing a query that never reached a peer sends nothing.
    #[tokio::test]
    async fn test_notify_unsubscribe_with_no_established_subscription() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();

        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));

        relay.unsubscribe_predicate(query_id);
        settle().await;

        assert!(mock_sender.get_sent_requests().is_empty());
        assert!(relay.get_status(query_id).is_none());
    }
}