1use ankurah_proto::{self as proto, CollectionId};
3use anyhow::anyhow;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::{Arc, OnceLock};
7use tracing::{debug, warn};
8
9use crate::error::{RequestError, RetrievalError};
10use crate::node::ContextData;
11use crate::util::safeset::SafeSet;
12
13#[async_trait::async_trait]
16pub trait RemoteQuerySubscriber: Clone + Send + Sync + 'static {
17 async fn subscription_established(&self, version: u32);
21
22 fn set_last_error(&self, error: RetrievalError);
24}
25
26#[derive(Debug, Clone)]
27pub enum Status {
28 PendingRemote,
29 Requested(proto::EntityId, u32), Established(proto::EntityId, u32), PendingUpdate(proto::EntityId, u32), Failed,
34}
35
36#[derive(Debug)]
37pub struct Content<CD: ContextData> {
38 pub query_id: proto::QueryId,
39 pub collection_id: CollectionId,
40 pub selection: ankql::ast::Selection,
41 pub context_data: CD,
42 pub version: u32,
43}
44
45pub struct RemoteQueryState<CD: ContextData, Q: RemoteQuerySubscriber> {
46 pub content: Arc<Content<CD>>,
47 pub status: Status,
48 pub livequery: Q,
49}
50
51struct SubscriptionRelayInner<CD: ContextData, Q: RemoteQuerySubscriber> {
52 subscriptions: std::sync::Mutex<HashMap<proto::QueryId, RemoteQueryState<CD, Q>>>,
54 connected_peers: SafeSet<proto::EntityId>,
56 node: OnceLock<Arc<dyn TNode<CD>>>,
58 _shutdown_tx: tokio::sync::mpsc::Sender<()>,
60}
61
62#[derive(Clone)]
90pub struct SubscriptionRelay<CD: ContextData, Q: RemoteQuerySubscriber> {
91 inner: Arc<SubscriptionRelayInner<CD, Q>>,
92}
93
94impl<CD: ContextData, Q: RemoteQuerySubscriber> Default for SubscriptionRelay<CD, Q> {
95 fn default() -> Self { Self::new() }
96}
97
98impl<CD: ContextData, Q: RemoteQuerySubscriber> SubscriptionRelay<CD, Q> {
99 pub fn new() -> Self {
100 let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::channel(1);
101
102 let relay = Self {
103 inner: Arc::new(SubscriptionRelayInner {
104 subscriptions: std::sync::Mutex::new(HashMap::new()),
105 connected_peers: SafeSet::new(),
106 node: OnceLock::new(),
107 _shutdown_tx: shutdown_tx,
108 }),
109 };
110
111 relay.start_retry_task(shutdown_rx);
113
114 relay
115 }
116
117 pub fn set_node(&self, node: Arc<dyn TNode<CD>>) -> Result<(), ()> { self.inner.node.set(node).map_err(|_| ()) }
122
123 pub fn subscribe_query(
128 &self,
129 query_id: proto::QueryId,
130 collection_id: CollectionId,
131 selection: ankql::ast::Selection,
132 context_data: CD,
133 version: u32,
134 livequery: Q,
135 ) {
136 debug!("SubscriptionRelay.subscribe_predicate() - New predicate {} needs remote registration", query_id);
137 {
138 self.inner.subscriptions.lock().expect("poisoned lock").insert(
139 query_id,
140 RemoteQueryState {
141 content: Arc::new(Content { collection_id, selection, context_data, query_id, version }),
142 status: Status::PendingRemote,
143 livequery,
144 },
145 );
146 }
147
148 if !self.inner.connected_peers.is_empty() {
150 self.setup_remote_subscriptions()
151 }
152 }
153 pub fn update_query(&self, query_id: proto::QueryId, selection: ankql::ast::Selection, version: u32) -> Result<(), anyhow::Error> {
154 debug!("SubscriptionRelay.update_query() - New query {} needs remote registration", query_id);
155
156 let update = {
157 let mut subscriptions = self.inner.subscriptions.lock().expect("poisoned lock");
158 match subscriptions.get_mut(&query_id) {
159 Some(state) => {
160 let old_content = &state.content;
162 state.content = Arc::new(Content {
163 collection_id: old_content.collection_id.clone(),
164 selection: selection.clone(),
165 context_data: old_content.context_data.clone(),
166 query_id: old_content.query_id,
167 version,
168 });
169
170 match state.status {
171 Status::Established(peer_id, _old_version) => {
172 state.status = Status::Requested(peer_id, version);
174 Some((peer_id, state.content.collection_id.clone(), state.content.context_data.clone()))
175 }
177 _ => {
178 state.status = Status::PendingRemote;
180 None
181 }
182 }
183 }
184 None => return Err(anyhow!("Predicate {} not found", query_id)),
185 }
186 };
187
188 match update {
189 Some((peer_id, collection_id, context_data)) => {
190 self.update_query_on_peer(peer_id, query_id, collection_id, selection, version, context_data);
191 }
192 None => {
193 self.setup_remote_subscriptions();
195 }
196 };
197
198 Ok(())
199 }
200
201 fn update_query_on_peer(
202 &self,
203 peer_id: proto::EntityId,
204 query_id: proto::QueryId,
205 collection_id: CollectionId,
206 selection: ankql::ast::Selection,
207 version: u32,
208 context_data: CD,
209 ) {
210 let me = self.clone();
211 crate::task::spawn(async move {
212 if let Some(node) = me.inner.node.get() {
213 let livequery = { me.inner.subscriptions.lock().unwrap().get(&query_id).map(|state| state.livequery.clone()) };
215
216 match node.remote_subscribe(peer_id, query_id, collection_id, selection, &context_data, version).await {
218 Ok(()) => {
219 if let Some(lq) = livequery {
221 lq.subscription_established(version).await;
222 }
223
224 let mut subscriptions = me.inner.subscriptions.lock().unwrap();
226 if let Some(info) = subscriptions.get_mut(&query_id) {
227 info.status = Status::Established(peer_id, version);
228 }
229 debug!("Successfully updated predicate {} on peer {} subscription", query_id, peer_id);
230 }
231 Err(e) => {
232 me.handle_error(query_id, peer_id, e, livequery).await;
234 }
235 }
236 }
237 });
238 }
239
240 pub fn unsubscribe_predicate(&self, query_id: proto::QueryId) {
245 debug!("Unregistering predicate {}", query_id);
246
247 {
249 let mut subscriptions = self.inner.subscriptions.lock().unwrap();
250 if let Some(info) = subscriptions.remove(&query_id) {
251 if let Status::Established(peer_id, _version) = &info.status {
252 let node = self.inner.node.get();
253 if let Some(node) = node {
254 let node = node.clone();
255 let peer_id = *peer_id;
256 crate::task::spawn(async move {
257 if let Err(e) = node.peer_unsubscribe(peer_id, query_id).await {
258 warn!("Failed to send unsubscribe message for {}: {}", query_id, e);
259 } else {
260 debug!("Successfully sent unsubscribe message for {}", query_id);
261 }
262 });
263 }
264 }
265 }
266 }
267 }
268
269 pub fn notify_peer_disconnected(&self, peer_id: proto::EntityId) {
275 debug!("Peer {} disconnected, orphaning predicate registrations", peer_id);
276
277 self.inner.connected_peers.remove(&peer_id);
279
280 for info in self.inner.subscriptions.lock().expect("poisoned lock").values_mut() {
281 if let Status::Established(established_peer_id, _) | Status::Requested(established_peer_id, _) = &info.status {
282 if *established_peer_id == peer_id {
283 info.status = Status::PendingRemote;
285 warn!("Predicate {} orphaned due to peer {} disconnect", info.content.query_id, peer_id);
286 }
287 }
288 }
289
290 self.setup_remote_subscriptions();
292 }
293
294 pub fn notify_peer_connected(&self, peer_id: proto::EntityId) {
299 debug!("SubscriptionRelay.notify_peer_connected() - Peer {} connected, registering predicates on peer subscription", peer_id);
300
301 self.inner.connected_peers.insert(peer_id);
303
304 self.setup_remote_subscriptions();
306 }
307
308 pub fn get_status(&self, query_id: proto::QueryId) -> Option<Status> {
310 let subscriptions = self.inner.subscriptions.lock().unwrap();
311 subscriptions.get(&query_id).map(|info| info.status.clone())
312 }
313
314 pub fn get_contexts_for_peer(&self, peer_id: &proto::EntityId) -> std::collections::HashSet<CD> {
317 let subscriptions = self.inner.subscriptions.lock().unwrap();
318 let mut contexts = std::collections::HashSet::new();
319
320 for (_, state) in subscriptions.iter() {
321 match &state.status {
322 Status::Established(established_peer, _) | Status::Requested(established_peer, _) => {
323 if established_peer == peer_id {
324 contexts.insert(state.content.context_data.clone());
325 }
326 }
327 _ => {}
328 }
329 }
330
331 contexts
332 }
333
334 fn setup_remote_subscriptions(&self) {
336 let node = match self.inner.node.get() {
337 Some(node) => node,
338 None => {
339 warn!("No node configured for remote subscription setup");
340 return;
341 }
342 };
343
344 let connected_peers = self.inner.connected_peers.to_vec();
346 if connected_peers.is_empty() {
347 warn!("No durable peers available for remote subscription setup");
348 return;
349 }
350
351 let target_peer = connected_peers[0];
352
353 let pending: Vec<_> = {
355 self.inner
356 .subscriptions
357 .lock()
358 .expect("poisoned lock")
359 .values_mut()
360 .filter_map(|info| {
361 if let Status::PendingRemote = info.status {
362 info.status = Status::Requested(target_peer, info.content.version);
363 Some(info.content.clone())
364 } else {
365 None
366 }
367 })
368 .collect()
369 };
370
371 if pending.is_empty() {
372 return;
373 }
374
375 debug!("Registering {} predicates on {} peer subscriptions", pending.len(), self.inner.connected_peers.len());
376
377 for content in pending {
378 crate::task::spawn(self.clone().attempt_subscribe(node.clone(), target_peer, content));
379 }
380 }
381
382 async fn attempt_subscribe(self, node: Arc<dyn TNode<CD>>, target_peer: proto::EntityId, content: Arc<Content<CD>>) {
383 let query_id = content.query_id;
384 let predicate = content.selection.clone();
385 let context_data = content.context_data.clone();
386 let version = content.version;
387
388 let livequery = { self.inner.subscriptions.lock().unwrap().get(&query_id).map(|state| state.livequery.clone()) };
390
391 match node.remote_subscribe(target_peer, query_id, content.collection_id.clone(), predicate, &context_data, version).await {
393 Ok(()) => {
394 if let Some(lq) = livequery {
397 lq.subscription_established(version).await;
398 }
399
400 let mut subscriptions = self.inner.subscriptions.lock().unwrap();
402 if let Some(info) = subscriptions.get_mut(&query_id) {
403 info.status = Status::Established(target_peer, version);
404 }
405 debug!("Successfully registered predicate {} on peer {} subscription", query_id, target_peer);
406 }
407 Err(e) => {
408 self.handle_error(query_id, target_peer, e, livequery).await;
410 }
411 }
412 }
413
414 fn start_retry_task(&self, mut shutdown_rx: tokio::sync::mpsc::Receiver<()>) {
416 let me = self.clone();
417 crate::task::spawn(async move {
418 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(5));
419 loop {
420 tokio::select! {
421 _ = interval.tick() => {
422 me.setup_remote_subscriptions();
424 }
425 _ = shutdown_rx.recv() => {
426 debug!("Retry task shutting down - SubscriptionRelay dropped");
427 break;
428 }
429 }
430 }
431 });
432 }
433
434 async fn handle_error(&self, query_id: proto::QueryId, target_peer: proto::EntityId, error: RetrievalError, livequery: Option<Q>) {
436 let error_msg = error.to_string();
437
438 let is_retryable = match &error {
440 RetrievalError::RequestError(req_err) => match req_err {
442 RequestError::PeerNotConnected => true,
443 RequestError::ConnectionLost => true,
444 RequestError::SendError(_) => true,
445 RequestError::InternalChannelClosed => true,
446 RequestError::ServerError(_) => false,
447 RequestError::UnexpectedResponse(_) => false,
448 },
449 _ => false,
451 };
452
453 let mut subscriptions = self.inner.subscriptions.lock().unwrap();
455 if let Some(info) = subscriptions.get_mut(&query_id) {
456 if is_retryable {
457 info.status = Status::PendingRemote;
459 warn!("Retryable failure for predicate {} with peer {}: {} - will retry", query_id, target_peer, error_msg);
460 } else {
461 info.status = Status::Failed;
463 tracing::error!("Permanent failure for predicate {} with peer {}: {} - no retry", query_id, target_peer, error_msg);
464
465 if let Some(lq) = livequery {
467 lq.set_last_error(error);
468 }
469 }
470 }
471 }
472}
473
474#[async_trait]
476pub trait TNode<CD: ContextData>: Send + Sync {
477 async fn remote_subscribe(
481 &self,
482 peer_id: proto::EntityId,
483 query_id: proto::QueryId,
484 collection_id: CollectionId,
485 selection: ankql::ast::Selection,
486 context_data: &CD,
487 version: u32,
488 ) -> Result<(), RetrievalError>;
489
490 async fn peer_unsubscribe(&self, peer_id: proto::EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error>;
493}
494
495#[async_trait]
497impl<SE, PA> TNode<PA::ContextData> for crate::node::WeakNode<SE, PA>
498where
499 SE: crate::storage::StorageEngine + Send + Sync + 'static,
500 PA: crate::policy::PolicyAgent + Send + Sync + 'static,
501{
502 async fn remote_subscribe(
503 &self,
504 peer_id: proto::EntityId,
505 query_id: proto::QueryId,
506 collection_id: CollectionId,
507 selection: ankql::ast::Selection,
508 context_data: &PA::ContextData,
509 version: u32,
510 ) -> Result<(), RetrievalError> {
511 let node = self.upgrade().ok_or_else(|| RetrievalError::Other("Node has been dropped".to_string()))?;
512
513 let known_matches: Vec<ankurah_proto::KnownEntity> = node
515 .fetch_entities_from_local(&collection_id, &selection)
516 .await?
517 .into_iter()
518 .map(|entity| ankurah_proto::KnownEntity { entity_id: entity.id(), head: entity.head() })
519 .collect();
520
521 let deltas = match node
523 .request(
524 peer_id,
525 context_data,
526 ankurah_proto::NodeRequestBody::SubscribeQuery {
527 query_id,
528 collection: collection_id.clone(),
529 selection: selection.clone(),
530 version,
531 known_matches,
532 },
533 )
534 .await
535 .map_err(|e| RetrievalError::RequestError(e))?
536 {
537 ankurah_proto::NodeResponseBody::QuerySubscribed { query_id: _response_query_id, deltas } => deltas,
538 ankurah_proto::NodeResponseBody::Error(e) => return Err(RetrievalError::RequestError(RequestError::ServerError(e))),
539 other => return Err(RetrievalError::RequestError(RequestError::UnexpectedResponse(other))),
540 };
541
542 let retriever = crate::retrieval::EphemeralNodeRetriever::new(collection_id, &node, context_data);
544 let apply_result = crate::node_applier::NodeApplier::apply_deltas(&node, &peer_id, deltas, &retriever).await;
545 let event_store_result = retriever.store_used_events().await;
546
547 apply_result?; event_store_result?;
549
550 Ok(())
551 }
552
553 async fn peer_unsubscribe(&self, peer_id: proto::EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error> {
554 let node = self.upgrade().ok_or_else(|| anyhow!("Node has been dropped"))?;
555
556 node.request_remote_unsubscribe(query_id, vec![peer_id]).await?;
558
559 Ok(())
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::ast::Predicate;
    use ankurah_proto::EntityId;
    use std::sync::{Arc, Mutex};

    impl ContextData for CollectionId {}

    /// A logged request: (peer, query, collection, selection).
    /// Unsubscribes are logged with collection `"unsubscribe"` and a `Predicate::True` selection.
    type SentRequest = (EntityId, proto::QueryId, CollectionId, ankql::ast::Selection);

    type TestRelay = SubscriptionRelay<CollectionId, MockLiveQuery>;

    /// In-memory `TNode` that records every request and can fail the next one on demand.
    #[derive(Debug)]
    struct MockMessageSender<CD: ContextData> {
        next_error: Arc<Mutex<Option<RequestError>>>,
        sent_requests: Arc<Mutex<Vec<SentRequest>>>,
        _phantom: std::marker::PhantomData<CD>,
    }

    impl<CD: ContextData> MockMessageSender<CD> {
        fn new() -> Self {
            Self {
                sent_requests: Arc::new(Mutex::new(Vec::new())),
                next_error: Arc::new(Mutex::new(None)),
                _phantom: std::marker::PhantomData,
            }
        }

        /// Makes the next request (subscribe or unsubscribe) fail with `error`.
        fn set_fail_next(&self, error: RequestError) { *self.next_error.lock().unwrap() = Some(error); }

        fn get_sent_requests(&self) -> Vec<SentRequest> { self.sent_requests.lock().unwrap().clone() }

        fn clear_sent_requests(&self) { self.sent_requests.lock().unwrap().clear(); }
    }

    #[async_trait]
    impl<CD: ContextData> TNode<CD> for MockMessageSender<CD> {
        async fn remote_subscribe(
            &self,
            peer_id: EntityId,
            query_id: proto::QueryId,
            collection_id: CollectionId,
            selection: ankql::ast::Selection,
            _context_data: &CD,
            _version: u32,
        ) -> Result<(), RetrievalError> {
            self.sent_requests.lock().unwrap().push((peer_id, query_id, collection_id, selection));

            match self.next_error.lock().unwrap().take() {
                Some(error) => Err(RetrievalError::RequestError(error)),
                None => Ok(()),
            }
        }

        async fn peer_unsubscribe(&self, peer_id: EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error> {
            self.sent_requests.lock().unwrap().push((peer_id, query_id, CollectionId::from("unsubscribe"), create_test_selection()));

            match self.next_error.lock().unwrap().take() {
                Some(error) => Err(anyhow!(error.to_string())),
                None => Ok(()),
            }
        }
    }

    /// Subscriber that ignores every notification.
    #[derive(Clone)]
    struct MockLiveQuery;

    #[async_trait]
    impl RemoteQuerySubscriber for MockLiveQuery {
        async fn subscription_established(&self, _version: u32) {}

        fn set_last_error(&self, _error: RetrievalError) {}
    }

    fn new_relay() -> TestRelay { SubscriptionRelay::new() }

    fn create_test_selection() -> ankql::ast::Selection { ankql::ast::Selection { predicate: Predicate::True, order_by: None, limit: None } }

    fn create_test_collection_id() -> CollectionId { CollectionId::from("test_collection") }

    fn is_established_on(relay: &TestRelay, query_id: proto::QueryId, peer_id: EntityId) -> bool {
        matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id)
    }

    /// Polls `check` (yielding to spawned tasks in between) until it holds or ~1 s has passed.
    /// Replaces fixed sleeps for positive assertions so slow CI machines do not cause flakes.
    async fn eventually(mut check: impl FnMut() -> bool) -> bool {
        for _ in 0..1000 {
            if check() {
                return true;
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        }
        check()
    }

    /// Gives spawned tasks a chance to run before asserting that something did NOT happen.
    async fn settle() { tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; }

    #[tokio::test]
    async fn test_new_subscription_setup() {
        let relay = new_relay();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        relay.notify_peer_connected(peer_id);
        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);

        // Marked as in flight synchronously, before the spawned request task runs.
        assert!(matches!(relay.get_status(query_id), Some(Status::Requested(_, _))));

        assert!(eventually(|| is_established_on(&relay, query_id, peer_id)).await);

        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].0, peer_id);
        assert_eq!(sent_requests[0].1, query_id);
        assert_eq!(sent_requests[0].2, collection_id);
    }

    #[tokio::test]
    async fn test_peer_disconnection_orphans_subscriptions() {
        let relay = new_relay();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let peer_id = EntityId::new();

        relay.notify_peer_connected(peer_id);
        relay.subscribe_query(query_id, collection_id.clone(), create_test_selection(), collection_id.clone(), 0, MockLiveQuery);

        assert!(eventually(|| is_established_on(&relay, query_id, peer_id)).await);

        relay.notify_peer_disconnected(peer_id);

        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
    }

    #[tokio::test]
    async fn test_peer_connection_triggers_setup() {
        let relay = new_relay();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let peer_id = EntityId::new();

        relay.subscribe_query(query_id, collection_id.clone(), create_test_selection(), collection_id.clone(), 0, MockLiveQuery);
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));

        mock_sender.clear_sent_requests();

        relay.notify_peer_connected(peer_id);

        assert!(eventually(|| is_established_on(&relay, query_id, peer_id)).await);

        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].0, peer_id);
        assert_eq!(sent_requests[0].1, query_id);
    }

    #[tokio::test]
    async fn test_failed_subscription_retry() {
        let relay = new_relay();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let peer_id = EntityId::new();

        relay.notify_peer_connected(peer_id);
        relay.subscribe_query(query_id, collection_id.clone(), create_test_selection(), collection_id.clone(), 0, MockLiveQuery);

        assert!(eventually(|| is_established_on(&relay, query_id, peer_id)).await);

        relay.notify_peer_disconnected(peer_id);
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));

        mock_sender.clear_sent_requests();
        mock_sender.set_fail_next(RequestError::ServerError("Invalid predicate".to_string()));

        relay.notify_peer_connected(peer_id);

        // A server error is not retryable, so the query ends up Failed after exactly one attempt.
        assert!(eventually(|| matches!(relay.get_status(query_id), Some(Status::Failed))).await);
        assert_eq!(mock_sender.get_sent_requests().len(), 1);
    }

    #[tokio::test]
    async fn test_retryable_vs_non_retryable_failures() {
        let relay = new_relay();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let retryable_query_id = proto::QueryId::new();
        let non_retryable_query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let peer_id = EntityId::new();

        relay.subscribe_query(retryable_query_id, collection_id.clone(), create_test_selection(), collection_id.clone(), 0, MockLiveQuery);
        relay.subscribe_query(non_retryable_query_id, collection_id.clone(), create_test_selection(), collection_id.clone(), 0, MockLiveQuery);

        {
            let mut subscriptions = relay.inner.subscriptions.lock().unwrap();
            if let Some(info) = subscriptions.get_mut(&retryable_query_id) {
                info.status = Status::PendingRemote;
            }
            if let Some(info) = subscriptions.get_mut(&non_retryable_query_id) {
                info.status = Status::Failed;
            }
        }

        relay.notify_peer_connected(peer_id);

        assert!(eventually(|| is_established_on(&relay, retryable_query_id, peer_id)).await);

        // Only the pending query is sent; the failed one is left alone.
        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].1, retryable_query_id);
        assert!(matches!(relay.get_status(non_retryable_query_id), Some(Status::Failed)));
    }

    #[tokio::test]
    async fn test_subscription_removal() {
        let relay = new_relay();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let peer_id = EntityId::new();

        relay.notify_peer_connected(peer_id);
        relay.subscribe_query(query_id, collection_id.clone(), create_test_selection(), collection_id.clone(), 0, MockLiveQuery);

        assert!(eventually(|| is_established_on(&relay, query_id, peer_id)).await);

        mock_sender.clear_sent_requests();

        relay.unsubscribe_predicate(query_id);

        assert!(eventually(|| mock_sender.get_sent_requests().len() == 1).await);

        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].0, peer_id);
        assert_eq!(sent_requests[0].1, query_id);

        assert!(relay.get_status(query_id).is_none());
    }

    #[tokio::test]
    async fn test_edge_cases() {
        let relay = new_relay();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let peer_id = EntityId::new();

        // No node and no peer: the query just waits.
        relay.subscribe_query(query_id, collection_id.clone(), create_test_selection(), collection_id.clone(), 0, MockLiveQuery);
        settle().await;
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));

        // A node alone is not enough without a connected peer.
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
        settle().await;
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
        assert_eq!(mock_sender.get_sent_requests().len(), 0);

        relay.notify_peer_connected(peer_id);

        assert!(eventually(|| is_established_on(&relay, query_id, peer_id)).await);
        assert_eq!(mock_sender.get_sent_requests().len(), 1);
    }

    #[tokio::test]
    async fn test_notify_unsubscribe_with_no_established_subscription() {
        let relay = new_relay();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();

        relay.subscribe_query(query_id, collection_id.clone(), create_test_selection(), collection_id.clone(), 0, MockLiveQuery);
        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));

        relay.unsubscribe_predicate(query_id);
        settle().await;

        // Never established, so no unsubscribe goes out.
        assert_eq!(mock_sender.get_sent_requests().len(), 0);
        assert!(relay.get_status(query_id).is_none());
    }
}