// ankurah_core/peer_subscription/client_relay.rs
1// TODO: Rename this module from client_relay to remote_subscription for clarity
2use ankurah_proto::{self as proto, CollectionId};
3use anyhow::anyhow;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::{Arc, OnceLock};
7use tracing::{debug, warn};
8
9use crate::error::{RequestError, RetrievalError};
10use crate::node::ContextData;
11use crate::util::safeset::SafeSet;
12
/// Trait for query initialization that can be driven by SubscriptionRelay.
/// Abstracts the relay's interaction with LiveQuery so the relay can be tested
/// without a concrete LiveQuery implementation.
#[async_trait::async_trait]
pub trait RemoteQuerySubscriber: Clone + Send + Sync + 'static {
    /// Called after remote subscription deltas have been applied.
    /// Dispatches to initialize (version 1) or update_selection_init (version >1) internally.
    /// Handles marking initialization as complete and setting last_error on failure.
    async fn subscription_established(&self, version: u32);

    /// Record the most recent error for this subscription (invoked by the relay
    /// on non-retryable failures; see `handle_error`).
    fn set_last_error(&self, error: RetrievalError);
}
25
/// Lifecycle of a query's registration with a remote durable peer.
#[derive(Debug, Clone)]
pub enum Status {
    /// Waiting for a durable peer to become available.
    PendingRemote,
    /// Subscribe request sent to the peer; awaiting completion.
    Requested(proto::EntityId, u32), // peer_id, version
    /// Successfully registered on the peer's subscription.
    Established(proto::EntityId, u32), // peer_id, version
    /// A selection update is pending delivery to the peer.
    PendingUpdate(proto::EntityId, u32), // peer_id, version
    /// Non-retryable failure; the relay will not attempt this query again.
    Failed,
}
35
/// Immutable snapshot of a query's definition. Shared via `Arc` so in-flight
/// subscribe tasks keep a consistent view even if the query is updated concurrently.
#[derive(Debug)]
pub struct Content<CD: ContextData> {
    pub query_id: proto::QueryId,
    pub collection_id: CollectionId,
    pub selection: ankql::ast::Selection,
    /// Context used for authorization when registering with the remote peer.
    pub context_data: CD,
    /// Selection version; incremented by `update_query`.
    pub version: u32,
}
44
/// Per-query tracking state held by the relay.
pub struct RemoteQueryState<CD: ContextData, Q: RemoteQuerySubscriber> {
    /// Shared definition of the query (selection, context, version).
    pub content: Arc<Content<CD>>,
    /// Where this query currently is in its remote-registration lifecycle.
    pub status: Status,
    /// Local subscriber notified once the remote subscription is established.
    pub livequery: Q,
}
50
/// Shared state behind `SubscriptionRelay`; accessed through `Arc` clones of the relay.
struct SubscriptionRelayInner<CD: ContextData, Q: RemoteQuerySubscriber> {
    // All subscription information in one place, keyed by query id
    subscriptions: std::sync::Mutex<HashMap<proto::QueryId, RemoteQueryState<CD, Q>>>,
    // Track connected durable peers
    connected_peers: SafeSet<proto::EntityId>,
    // Node for communicating with remote peers (set once via `set_node`)
    node: OnceLock<Arc<dyn TNode<CD>>>,
    // Shutdown signal for retry task - when dropped, the task will stop
    _shutdown_tx: tokio::sync::mpsc::Sender<()>,
}
61
/// Manages predicate registration on remote peer reactor subscriptions.
///
/// The SubscriptionRelay provides a resilient, event-driven approach to managing which predicates
/// are registered with remote durable peers. It automatically handles:
/// - Registering predicates on peer reactor subscriptions when peers connect
/// - Re-registering predicates when peers disconnect and reconnect
/// - Retrying failed predicate registration attempts
/// - Clean teardown when predicates are removed
/// - Storing ContextData for each predicate to enable proper authorization
///
/// This design separates predicate management concerns from the main Node implementation,
/// making it easier to test and reason about predicate lifecycle management.
///
/// # Public API (for Node integration)
///
/// - `subscribe_query()` - Call when local subscriptions are created (parallel to reactor.subscribe)
/// - `update_query()` - Call when a local subscription's selection changes
/// - `unsubscribe_predicate()` - Call when local subscriptions are removed (parallel to reactor.unsubscribe)
/// - `notify_peer_connected()` - Call when durable peers connect (triggers automatic predicate registration)
/// - `notify_peer_disconnected()` - Call when durable peers disconnect (orphans predicate registrations)
/// - `get_status()` - Query current state of a predicate registration
///
/// # Internal/Testing API
///
/// - `setup_remote_subscriptions()` - Internal method for triggering predicate registration with specific peers
///   (called automatically by notify_peer_connected, but exposed for testing)
///
/// The relay will automatically handle predicate registration/teardown asynchronously.
/// Cloning is cheap: all state lives in the shared `inner` Arc.
#[derive(Clone)]
pub struct SubscriptionRelay<CD: ContextData, Q: RemoteQuerySubscriber> {
    inner: Arc<SubscriptionRelayInner<CD, Q>>,
}
93
94impl<CD: ContextData, Q: RemoteQuerySubscriber> Default for SubscriptionRelay<CD, Q> {
95    fn default() -> Self { Self::new() }
96}
97
98impl<CD: ContextData, Q: RemoteQuerySubscriber> SubscriptionRelay<CD, Q> {
    /// Create a relay with no tracked subscriptions and no connected peers.
    ///
    /// Also spawns the background retry task; the task stops automatically when the
    /// relay (and therefore the shutdown sender it holds) is dropped.
    pub fn new() -> Self {
        // The channel is used purely as a drop-signal: the receiver's `recv()` resolves
        // with `None` once the last sender (stored in `_shutdown_tx`) is dropped.
        let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::channel(1);

        let relay = Self {
            inner: Arc::new(SubscriptionRelayInner {
                subscriptions: std::sync::Mutex::new(HashMap::new()),
                connected_peers: SafeSet::new(),
                node: OnceLock::new(),
                _shutdown_tx: shutdown_tx,
            }),
        };

        // Start background retry task
        relay.start_retry_task(shutdown_rx);

        relay
    }
116
117    /// Inject the node (typically a WeakNode for production)
118    ///
119    /// This should be called once during initialization. Returns an error if
120    /// the node has already been set.
121    pub fn set_node(&self, node: Arc<dyn TNode<CD>>) -> Result<(), ()> { self.inner.node.set(node).map_err(|_| ()) }
122
123    /// Notify the relay that a new predicate needs to be registered on remote peer subscriptions
124    ///
125    /// This should be called whenever a local subscription is established. The relay will
126    /// track this predicate and automatically attempt to register it with available durable peers.
127    pub fn subscribe_query(
128        &self,
129        query_id: proto::QueryId,
130        collection_id: CollectionId,
131        selection: ankql::ast::Selection,
132        context_data: CD,
133        version: u32,
134        livequery: Q,
135    ) {
136        debug!("SubscriptionRelay.subscribe_predicate() - New predicate {} needs remote registration", query_id);
137        {
138            self.inner.subscriptions.lock().expect("poisoned lock").insert(
139                query_id,
140                RemoteQueryState {
141                    content: Arc::new(Content { collection_id, selection, context_data, query_id, version }),
142                    status: Status::PendingRemote,
143                    livequery,
144                },
145            );
146        }
147
148        // Immediately attempt setup with available peers
149        if !self.inner.connected_peers.is_empty() {
150            self.setup_remote_subscriptions()
151        }
152    }
    /// Replace a query's selection/version and propagate the change to the remote peer.
    ///
    /// If the query is currently `Established` with a peer, the new selection is sent
    /// straight to that peer (status moves to `Requested`). In any other state the query
    /// is reset to `PendingRemote` and the normal setup path re-registers it.
    ///
    /// Returns an error if `query_id` is not tracked by the relay.
    pub fn update_query(&self, query_id: proto::QueryId, selection: ankql::ast::Selection, version: u32) -> Result<(), anyhow::Error> {
        debug!("SubscriptionRelay.update_query() - New query {} needs remote registration", query_id);

        // Decide what to do while holding the lock, but defer any network I/O
        // until after the guard is released.
        let update = {
            let mut subscriptions = self.inner.subscriptions.lock().expect("poisoned lock");
            match subscriptions.get_mut(&query_id) {
                Some(state) => {
                    // Update the content with new predicate and version
                    let old_content = &state.content;
                    state.content = Arc::new(Content {
                        collection_id: old_content.collection_id.clone(),
                        selection: selection.clone(),
                        context_data: old_content.context_data.clone(),
                        query_id: old_content.query_id,
                        version,
                    });

                    match state.status {
                        Status::Established(peer_id, _old_version) => {
                            // Update to new version, mark as requested for this peer
                            state.status = Status::Requested(peer_id, version);
                            Some((peer_id, state.content.collection_id.clone(), state.content.context_data.clone()))
                            // Return the peer_id to send update to
                        }
                        _ => {
                            // Not established yet (Pending/Requested/Failed): fall back to
                            // PendingRemote so setup_remote_subscriptions handles it.
                            state.status = Status::PendingRemote;
                            None
                        }
                    }
                }
                None => return Err(anyhow!("Predicate {} not found", query_id)),
            }
        };

        match update {
            Some((peer_id, collection_id, context_data)) => {
                self.update_query_on_peer(peer_id, query_id, collection_id, selection, version, context_data);
            }
            None => {
                // Not established yet - use setup_remote_subscriptions for initial setup
                self.setup_remote_subscriptions();
            }
        };

        Ok(())
    }
200
    /// Spawn a task that sends an updated selection for an already-established query
    /// to `peer_id`, then transitions the status based on the outcome.
    ///
    /// On success the livequery is notified via `subscription_established` and the status
    /// becomes `Established(peer_id, version)`; on failure `handle_error` decides between
    /// retry (`PendingRemote`) and permanent failure (`Failed`).
    ///
    /// NOTE(review): this largely mirrors `attempt_subscribe`; consider unifying the two
    /// success/error paths to avoid drift.
    fn update_query_on_peer(
        &self,
        peer_id: proto::EntityId,
        query_id: proto::QueryId,
        collection_id: CollectionId,
        selection: ankql::ast::Selection,
        version: u32,
        context_data: CD,
    ) {
        let me = self.clone();
        crate::task::spawn(async move {
            // If the node was never injected, silently do nothing; the status stays
            // Requested until a later path resets it.
            if let Some(node) = me.inner.node.get() {
                // Get the livequery for error handling
                // (cloned out so the mutex guard is not held across any await point)
                let livequery = {
                    me.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner()).get(&query_id).map(|state| state.livequery.clone())
                };

                // Send the updated predicate to the peer
                match node.remote_subscribe(peer_id, query_id, collection_id, selection, &context_data, version).await {
                    Ok(()) => {
                        // Deltas applied successfully, now activate the livequery
                        if let Some(lq) = livequery {
                            lq.subscription_established(version).await;
                        }

                        // Mark as established - subscription succeeded even if livequery activation had issues
                        let mut subscriptions = me.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
                        if let Some(info) = subscriptions.get_mut(&query_id) {
                            info.status = Status::Established(peer_id, version);
                        }
                        debug!("Successfully updated predicate {} on peer {} subscription", query_id, peer_id);
                    }
                    Err(e) => {
                        // Handle error with retry logic
                        me.handle_error(query_id, peer_id, e, livequery).await;
                    }
                }
            }
        });
    }
241
242    /// Notify the relay that a predicate should be removed from remote peer subscriptions
243    ///
244    /// This will clean up all tracking state and send unsubscribe requests to any
245    /// remote peers that have this predicate registered.
246    pub fn unsubscribe_predicate(&self, query_id: proto::QueryId) {
247        debug!("Unregistering predicate {}", query_id);
248
249        // If subscription was established with a peer, send unsubscribe request
250        {
251            let mut subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
252            if let Some(info) = subscriptions.remove(&query_id) {
253                if let Status::Established(peer_id, _version) = &info.status {
254                    let node = self.inner.node.get();
255                    if let Some(node) = node {
256                        let node = node.clone();
257                        let peer_id = *peer_id;
258                        crate::task::spawn(async move {
259                            if let Err(e) = node.peer_unsubscribe(peer_id, query_id).await {
260                                warn!("Failed to send unsubscribe message for {}: {}", query_id, e);
261                            } else {
262                                debug!("Successfully sent unsubscribe message for {}", query_id);
263                            }
264                        });
265                    }
266                }
267            }
268        }
269    }
270
271    /// Handle peer disconnection - mark all predicates for that peer as needing re-registration
272    ///
273    /// This should be called when a durable peer disconnects. All predicates registered
274    /// with that peer will be marked as pending and will be automatically re-registered
275    /// when the peer reconnects or another suitable peer becomes available.
276    pub fn notify_peer_disconnected(&self, peer_id: proto::EntityId) {
277        debug!("Peer {} disconnected, orphaning predicate registrations", peer_id);
278
279        // Remove from connected peers
280        self.inner.connected_peers.remove(&peer_id);
281
282        for info in self.inner.subscriptions.lock().expect("poisoned lock").values_mut() {
283            if let Status::Established(established_peer_id, _) | Status::Requested(established_peer_id, _) = &info.status {
284                if *established_peer_id == peer_id {
285                    // Update state to pending
286                    info.status = Status::PendingRemote;
287                    warn!("Predicate {} orphaned due to peer {} disconnect", info.content.query_id, peer_id);
288                }
289            }
290        }
291
292        // Resubscribe any orphaned subscriptions
293        self.setup_remote_subscriptions();
294    }
295
296    /// Handle peer connection - trigger predicate registration on the new peer subscription
297    ///
298    /// This should be called when a new durable peer connects. The relay will automatically
299    /// attempt to register any pending predicates on the newly connected peer's subscription.
300    pub fn notify_peer_connected(&self, peer_id: proto::EntityId) {
301        debug!("SubscriptionRelay.notify_peer_connected() - Peer {} connected, registering predicates on peer subscription", peer_id);
302
303        // Add to connected peers
304        self.inner.connected_peers.insert(peer_id);
305
306        // Trigger setup with all connected peers
307        self.setup_remote_subscriptions();
308    }
309
310    /// Get the current state of a predicate registration
311    pub fn get_status(&self, query_id: proto::QueryId) -> Option<Status> {
312        let subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
313        subscriptions.get(&query_id).map(|info| info.status.clone())
314    }
315
316    /// Get all unique contexts for predicates established or requested with a specific peer
317    /// TODO: update the data structure to do this via a direct lookup rather than having to scan the entire map
318    pub fn get_contexts_for_peer(&self, peer_id: &proto::EntityId) -> std::collections::HashSet<CD> {
319        let subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
320        let mut contexts = std::collections::HashSet::new();
321
322        for (_, state) in subscriptions.iter() {
323            match &state.status {
324                Status::Established(established_peer, _) | Status::Requested(established_peer, _) => {
325                    if established_peer == peer_id {
326                        contexts.insert(state.content.context_data.clone());
327                    }
328                }
329                _ => {}
330            }
331        }
332
333        contexts
334    }
335
    /// Register predicates on available durable peer subscriptions
    ///
    /// Picks the first connected peer, atomically flips every `PendingRemote` query to
    /// `Requested`, and spawns one `attempt_subscribe` task per query. No-ops (with a
    /// warning) when no node or no peers are available — the 5-second retry task will
    /// invoke this again later.
    fn setup_remote_subscriptions(&self) {
        let node = match self.inner.node.get() {
            Some(node) => node,
            None => {
                warn!("No node configured for remote subscription setup");
                return;
            }
        };

        // For now, use the first available peer (could be made smarter)
        let connected_peers = self.inner.connected_peers.to_vec();
        if connected_peers.is_empty() {
            warn!("No durable peers available for remote subscription setup");
            return;
        }

        let target_peer = connected_peers[0];

        // Atomically get pending subscriptions and mark them as requested
        // (marking under the lock prevents a concurrent caller from double-sending).
        let pending: Vec<_> = {
            self.inner
                .subscriptions
                .lock()
                .expect("poisoned lock")
                .values_mut()
                .filter_map(|info| {
                    if let Status::PendingRemote = info.status {
                        info.status = Status::Requested(target_peer, info.content.version);
                        Some(info.content.clone())
                    } else {
                        None
                    }
                })
                .collect()
        };

        if pending.is_empty() {
            return;
        }

        debug!("Registering {} predicates on {} peer subscriptions", pending.len(), self.inner.connected_peers.len());

        // One detached task per query; each reports its own success/failure.
        for content in pending {
            crate::task::spawn(self.clone().attempt_subscribe(node.clone(), target_peer, content));
        }
    }
383
    /// Perform a single subscribe attempt for `content` against `target_peer`.
    ///
    /// Delegates the wire work to `TNode::remote_subscribe`; on success notifies the
    /// livequery and marks the query `Established`, on failure defers to `handle_error`
    /// (which chooses between retry and permanent failure).
    async fn attempt_subscribe(self, node: Arc<dyn TNode<CD>>, target_peer: proto::EntityId, content: Arc<Content<CD>>) {
        let query_id = content.query_id;
        let predicate = content.selection.clone();
        let context_data = content.context_data.clone();
        let version = content.version;

        // Get the livequery for error handling
        // (cloned out in a block so the mutex guard is dropped before any await)
        let livequery =
            { self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner()).get(&query_id).map(|state| state.livequery.clone()) };

        // Call remote_subscribe which fetches known matches, subscribes, applies deltas, and stores events
        match node.remote_subscribe(target_peer, query_id, content.collection_id.clone(), predicate, &context_data, version).await {
            Ok(()) => {
                // Deltas applied successfully, now activate the livequery
                // The livequery handles its own errors internally
                if let Some(lq) = livequery {
                    lq.subscription_established(version).await;
                }

                // Mark as established - subscription succeeded even if livequery activation had issues
                let mut subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
                if let Some(info) = subscriptions.get_mut(&query_id) {
                    info.status = Status::Established(target_peer, version);
                }
                debug!("Successfully registered predicate {} on peer {} subscription", query_id, target_peer);
            }
            Err(e) => {
                // Handle error with retry logic
                self.handle_error(query_id, target_peer, e, livequery).await;
            }
        }
    }
416
417    /// Start background task that periodically retries pending subscriptions
418    fn start_retry_task(&self, mut shutdown_rx: tokio::sync::mpsc::Receiver<()>) {
419        let me = self.clone();
420        crate::task::spawn(async move {
421            loop {
422                let delay = futures_timer::Delay::new(std::time::Duration::from_secs(5));
423                tokio::select! {
424                    _ = delay => {
425                        // Attempt to setup any pending subscriptions
426                        me.setup_remote_subscriptions();
427                    }
428                    _ = shutdown_rx.recv() => {
429                        debug!("Retry task shutting down - SubscriptionRelay dropped");
430                        break;
431                    }
432                }
433            }
434        });
435    }
436
    /// Handle errors with retry logic
    ///
    /// Connection-level request failures are considered transient: the query is put back
    /// to `PendingRemote` for the background retry task. Everything else (server errors,
    /// access denial, unexpected responses, retrieval failures) is permanent: the query
    /// is marked `Failed` and the error is surfaced to the livequery.
    async fn handle_error(&self, query_id: proto::QueryId, target_peer: proto::EntityId, error: RetrievalError, livequery: Option<Q>) {
        // Stringify up front; `error` is moved into `set_last_error` on the permanent path.
        let error_msg = error.to_string();

        // Evaluate retriability at failure time
        let is_retryable = match &error {
            // Retrieval errors from fetching are generally not retryable
            RetrievalError::RequestError(req_err) => match req_err {
                RequestError::PeerNotConnected => true,
                RequestError::ConnectionLost => true,
                RequestError::SendError(_) => true,
                RequestError::InternalChannelClosed => true,
                RequestError::ServerError(_) => false,
                RequestError::UnexpectedResponse(_) => false,
                RequestError::AccessDenied(_) => false,
            },
            // Other retrieval errors are not retryable
            _ => false,
        };

        // Update state based on retriability
        // (if the query was unsubscribed meanwhile, there is nothing to update)
        let mut subscriptions = self.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
        if let Some(info) = subscriptions.get_mut(&query_id) {
            if is_retryable {
                // Retryable errors go back to pending for retry by background task
                info.status = Status::PendingRemote;
                warn!("Retryable failure for predicate {} with peer {}: {} - will retry", query_id, target_peer, error_msg);
            } else {
                // Non-retryable errors are permanently failed
                info.status = Status::Failed;
                tracing::error!("Permanent failure for predicate {} with peer {}: {} - no retry", query_id, target_peer, error_msg);

                // Set error on livequery
                if let Some(lq) = livequery {
                    lq.set_last_error(error);
                }
            }
        }
    }
476}
477
/// Trait for communicating with remote peers (abstraction over WeakNode for testing)
#[async_trait]
pub trait TNode<CD: ContextData>: Send + Sync {
    /// Send a predicate registration request to a remote peer, fetch known matches,
    /// apply received deltas, and store used events.
    /// Returns Ok(()) if subscription was established and deltas applied successfully.
    async fn remote_subscribe(
        &self,
        peer_id: proto::EntityId,
        query_id: proto::QueryId,
        collection_id: CollectionId,
        selection: ankql::ast::Selection,
        context_data: &CD,
        version: u32,
    ) -> Result<(), RetrievalError>;

    /// Send a predicate unregistration message to a remote peer.
    /// This is a one-way message, no response expected.
    async fn peer_unsubscribe(&self, peer_id: proto::EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error>;
}
498
499/// Implementation of TNode for WeakNode
500#[async_trait]
501impl<SE, PA> TNode<PA::ContextData> for crate::node::WeakNode<SE, PA>
502where
503    SE: crate::storage::StorageEngine + Send + Sync + 'static,
504    PA: crate::policy::PolicyAgent + Send + Sync + 'static,
505{
506    async fn remote_subscribe(
507        &self,
508        peer_id: proto::EntityId,
509        query_id: proto::QueryId,
510        collection_id: CollectionId,
511        selection: ankql::ast::Selection,
512        context_data: &PA::ContextData,
513        version: u32,
514    ) -> Result<(), RetrievalError> {
515        let node = self.upgrade().ok_or_else(|| RetrievalError::Other("Node has been dropped".to_string()))?;
516
517        // 1. Pre-fetch known_matches from local storage
518        let known_matches: Vec<ankurah_proto::KnownEntity> = node
519            .fetch_entities_from_local(&collection_id, &selection)
520            .await?
521            .into_iter()
522            .map(|entity| ankurah_proto::KnownEntity { entity_id: entity.id(), head: entity.head() })
523            .collect();
524
525        // 2. Send subscribe request with known_matches
526        let deltas = match node
527            .request(
528                peer_id,
529                context_data,
530                ankurah_proto::NodeRequestBody::SubscribeQuery {
531                    query_id,
532                    collection: collection_id.clone(),
533                    selection: selection.clone(),
534                    version,
535                    known_matches,
536                },
537            )
538            .await
539            .map_err(|e| RetrievalError::RequestError(e))?
540        {
541            ankurah_proto::NodeResponseBody::QuerySubscribed { query_id: _response_query_id, deltas } => deltas,
542            ankurah_proto::NodeResponseBody::Error(e) => return Err(RetrievalError::RequestError(RequestError::ServerError(e))),
543            other => return Err(RetrievalError::RequestError(RequestError::UnexpectedResponse(other))),
544        };
545
546        tracing::debug!(
547            "Node.remote_subscribe: query_id: {}, collection_id: {}, received deltas: {}",
548            query_id,
549            collection_id,
550            deltas.len()
551        );
552        // 3. Apply deltas to local node using NodeApplier
553        let retriever = crate::retrieval::EphemeralNodeRetriever::new(collection_id, &node, context_data);
554        let apply_result = crate::node_applier::NodeApplier::apply_deltas(&node, &peer_id, deltas, &retriever).await;
555        let event_store_result = retriever.store_used_events().await;
556
557        apply_result?; // apply result is more important than event store result
558        event_store_result?;
559
560        Ok(())
561    }
562
563    async fn peer_unsubscribe(&self, peer_id: proto::EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error> {
564        let node = self.upgrade().ok_or_else(|| anyhow!("Node has been dropped"))?;
565
566        // Use the existing request_remote_unsubscribe method
567        node.request_remote_unsubscribe(query_id, vec![peer_id]).await?;
568
569        Ok(())
570    }
571}
572
573#[cfg(test)]
574mod tests {
575    use super::*;
576    use ankql::ast::Predicate;
577    use ankurah_proto::EntityId;
578    use std::sync::{Arc, Mutex};
579
580    // Note: Some tests call setup_remote_subscriptions() directly to test the core
581    // subscription setup logic in isolation, while others use notify_peer_connected()
582    // to test the full event-driven flow. Both approaches are valuable:
583    // - Direct calls test the setup mechanism itself (error handling, state transitions)
584    // - Event-driven calls test the integration and user-facing API
585
    // For testing, we'll use CollectionId as our ContextData
    // (any Clone + Hash type works; CollectionId is convenient because the tests already build one)
    impl ContextData for CollectionId {}
588
    /// Mock message sender for testing
    #[derive(Debug)]
    struct MockMessageSender<CD: ContextData> {
        // One-shot error returned by the next remote_subscribe/peer_unsubscribe call (taken on use)
        next_error: Arc<Mutex<Option<RequestError>>>,
        // Log of every request "sent": (peer, query, collection, selection)
        sent_requests: Arc<Mutex<Vec<(EntityId, proto::QueryId, CollectionId, ankql::ast::Selection)>>>,
        // NOTE(review): should_fail and failure_message are initialized but never read in the
        // visible code — confirm they are unused and consider removing them.
        should_fail: Arc<Mutex<bool>>,
        failure_message: Arc<Mutex<String>>,
        _phantom: std::marker::PhantomData<CD>,
    }
598
599    impl<CD: ContextData> MockMessageSender<CD> {
600        fn new() -> Self {
601            Self {
602                sent_requests: Arc::new(Mutex::new(Vec::new())),
603                next_error: Arc::new(Mutex::new(None)),
604                should_fail: Arc::new(Mutex::new(false)),
605                failure_message: Arc::new(Mutex::new(String::new())),
606                _phantom: std::marker::PhantomData,
607            }
608        }
609
610        fn set_fail_next(&self, error: RequestError) { *self.next_error.lock().unwrap() = Some(error); }
611
612        fn get_sent_requests(&self) -> Vec<(EntityId, proto::QueryId, CollectionId, ankql::ast::Selection)> {
613            self.sent_requests.lock().unwrap().clone()
614        }
615
616        fn clear_sent_requests(&self) { self.sent_requests.lock().unwrap().clear(); }
617    }
618
    #[async_trait]
    impl<CD: ContextData> TNode<CD> for MockMessageSender<CD> {
        /// Records the request, then returns the queued error (if any) or Ok.
        async fn remote_subscribe(
            &self,
            peer_id: EntityId,
            query_id: proto::QueryId,
            collection_id: CollectionId,
            selection: ankql::ast::Selection,
            _context_data: &CD,
            _version: u32,
        ) -> Result<(), RetrievalError> {
            self.sent_requests.lock().unwrap().push((peer_id, query_id, collection_id.clone(), selection.clone()));

            // Check if there's an error to fail with (one-shot: `take` clears it)
            if let Some(error) = self.next_error.lock().unwrap().take() {
                Err(RetrievalError::RequestError(error))
            } else {
                // Mock successful subscription (fetch, subscribe, apply, store all succeeded)
                Ok(())
            }
        }

        /// Records the unsubscribe as a sentinel request (collection "unsubscribe"),
        /// then returns the queued error (if any) or Ok.
        async fn peer_unsubscribe(&self, peer_id: EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error> {
            self.sent_requests.lock().unwrap().push((
                peer_id,
                query_id,
                CollectionId::from("unsubscribe"),
                ankql::ast::Selection { predicate: ankql::ast::Predicate::True, order_by: None, limit: None },
            ));

            // Check if there's an error to fail with
            if let Some(error) = self.next_error.lock().unwrap().take() {
                Err(anyhow!(error.to_string()))
            } else {
                Ok(())
            }
        }
    }
657
    // Mock implementation of RemoteQuerySubscriber for tests: ignores all callbacks.
    #[derive(Clone)]
    struct MockLiveQuery;

    #[async_trait::async_trait]
    impl RemoteQuerySubscriber for MockLiveQuery {
        async fn subscription_established(&self, _version: u32) {
            // Mock - no-op
        }

        fn set_last_error(&self, _error: RetrievalError) {
            // For tests, we don't track errors
        }
    }
672
    /// Build a trivial selection: predicate `true`, no ordering, no limit.
    fn create_test_selection() -> ankql::ast::Selection {
        // Create a simple test predicate
        ankql::ast::Selection { predicate: ankql::ast::Predicate::True, order_by: None, limit: None }
    }

    /// Collection id shared by the tests.
    fn create_test_collection_id() -> CollectionId { CollectionId::from("test_collection") }
679
    /// End-to-end happy path: connect a peer, subscribe a query, and verify the
    /// relay sends exactly one request and transitions Requested -> Established.
    #[tokio::test]
    async fn test_new_subscription_setup() {
        let relay = SubscriptionRelay::new();
        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");

        let query_id = proto::QueryId::new();
        let collection_id = create_test_collection_id();
        let predicate = create_test_selection();
        let peer_id = EntityId::new();

        // Connect the peer first
        relay.notify_peer_connected(peer_id);

        // Notify of new subscription
        // (collection_id doubles as the ContextData for tests; see the ContextData impl above)
        relay.subscribe_query(query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);

        // Check initial state - subscription should immediately go to Requested state since peer is connected
        assert!(matches!(relay.get_status(query_id), Some(Status::Requested(_, _))));

        // Give async task time to complete (setup should happen automatically)
        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;

        // Verify request was sent
        let sent_requests = mock_sender.get_sent_requests();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].0, peer_id);
        assert_eq!(sent_requests[0].1, query_id);
        assert_eq!(sent_requests[0].2, collection_id);

        // Verify subscription is marked as established
        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
    }
713
714    #[tokio::test]
715    async fn test_peer_disconnection_orphans_subscriptions() {
716        let relay = SubscriptionRelay::new();
717
718        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
719        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
720
721        let query_id = proto::QueryId::new();
722        let collection_id = create_test_collection_id();
723        let predicate = create_test_selection();
724        let peer_id = EntityId::new();
725
726        // Connect the peer first
727        relay.notify_peer_connected(peer_id);
728
729        // Setup established subscription by going through the full flow
730        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
731
732        // Give async task time to complete
733        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
734
735        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
736
737        // Simulate peer disconnection
738        relay.notify_peer_disconnected(peer_id);
739
740        // Verify subscription is marked as pending again
741        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
742    }
743
744    #[tokio::test]
745    async fn test_peer_connection_triggers_setup() {
746        let relay = SubscriptionRelay::new();
747        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
748        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
749
750        let query_id = proto::QueryId::new();
751        let collection_id = create_test_collection_id();
752        let predicate = create_test_selection();
753        let peer_id = EntityId::new();
754
755        // Add pending subscription (no peers connected yet)
756        relay.subscribe_query(query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
757        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
758
759        // Clear any previous requests
760        mock_sender.clear_sent_requests();
761
762        // Simulate peer connection (should trigger automatic setup)
763        relay.notify_peer_connected(peer_id);
764
765        // Give async task time to complete
766        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
767
768        // Verify request was sent
769        let sent_requests = mock_sender.get_sent_requests();
770        assert_eq!(sent_requests.len(), 1);
771        assert_eq!(sent_requests[0].0, peer_id);
772        assert_eq!(sent_requests[0].1, query_id);
773
774        // Verify subscription is established
775        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
776    }
777
778    #[tokio::test]
779    async fn test_failed_subscription_retry() {
780        let relay = SubscriptionRelay::new();
781        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
782        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
783
784        let query_id = proto::QueryId::new();
785        let collection_id = create_test_collection_id();
786        let predicate = create_test_selection();
787        let peer_id = EntityId::new();
788
789        // Connect peer and add subscription (should succeed initially)
790        relay.notify_peer_connected(peer_id);
791        relay.subscribe_query(query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
792
793        // Give async task time to complete
794        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
795
796        // Verify subscription is marked as established (since no error was set)
797        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
798
799        // Now test the retry behavior by disconnecting the peer (puts subscription back to PendingRemote)
800        // then setting up the mock to fail, and reconnecting to trigger the retry
801        relay.notify_peer_disconnected(peer_id);
802
803        // Verify subscription is now in pending state
804        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
805
806        // Clear requests and set up mock to fail on the next call
807        mock_sender.clear_sent_requests();
808        mock_sender.set_fail_next(RequestError::ServerError("Invalid predicate".to_string()));
809
810        // Reconnect peer to trigger retry attempt
811        relay.notify_peer_connected(peer_id);
812
813        // Give async task time to complete
814        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
815
816        // Verify retry was attempted (the error gets consumed)
817        let sent_requests = mock_sender.get_sent_requests();
818        assert_eq!(sent_requests.len(), 1);
819
820        // Verify subscription remains in failed state (non-retryable error)
821        assert!(matches!(relay.get_status(query_id), Some(Status::Failed)));
822    }
823
824    #[tokio::test]
825    async fn test_retryable_vs_non_retryable_failures() {
826        let relay = SubscriptionRelay::new();
827        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
828        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
829
830        let retryable_query_id = proto::QueryId::new();
831        let non_retryable_query_id = proto::QueryId::new();
832        let collection_id = create_test_collection_id();
833        let predicate = create_test_selection();
834        let peer_id = EntityId::new();
835
836        // Add subscriptions
837        relay.subscribe_query(retryable_query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
838        relay.subscribe_query(non_retryable_query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
839
840        // Manually set different failure types - retryable goes back to pending, non-retryable stays failed
841        {
842            let mut subscriptions = relay.inner.subscriptions.lock().unwrap_or_else(|e| e.into_inner());
843            if let Some(info) = subscriptions.get_mut(&retryable_query_id) {
844                info.status = Status::PendingRemote; // Retryable errors go back to pending
845            }
846            if let Some(info) = subscriptions.get_mut(&non_retryable_query_id) {
847                info.status = Status::Failed; // Non-retryable errors stay failed
848            }
849        }
850
851        // Connect peer and trigger retry
852        relay.notify_peer_connected(peer_id);
853
854        // Give async task time to complete
855        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
856
857        // Verify only the retryable subscription was attempted
858        let sent_requests = mock_sender.get_sent_requests();
859        assert_eq!(sent_requests.len(), 1);
860        assert_eq!(sent_requests[0].1, retryable_query_id);
861
862        // Verify states
863        assert!(
864            matches!(relay.get_status(retryable_query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id)
865        );
866        assert!(matches!(relay.get_status(non_retryable_query_id), Some(Status::Failed)));
867    }
868
869    #[tokio::test]
870    async fn test_subscription_removal() {
871        let relay = SubscriptionRelay::new();
872        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
873        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
874
875        let query_id = proto::QueryId::new();
876        let collection_id = create_test_collection_id();
877        let predicate = create_test_selection();
878        let peer_id = EntityId::new();
879
880        // Connect peer and setup established subscription
881        relay.notify_peer_connected(peer_id);
882        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
883
884        // Give async task time to complete
885        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
886
887        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
888
889        // Clear previous requests to focus on unsubscribe
890        mock_sender.clear_sent_requests();
891
892        // Remove subscription
893        relay.unsubscribe_predicate(query_id);
894
895        // Give async task time to complete
896        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
897
898        // Verify unsubscribe message was sent
899        let sent_requests = mock_sender.get_sent_requests();
900        assert_eq!(sent_requests.len(), 1);
901        assert_eq!(sent_requests[0].0, peer_id);
902        assert_eq!(sent_requests[0].1, query_id);
903
904        // Verify subscription is gone
905        assert!(matches!(relay.get_status(query_id), None));
906    }
907
908    #[tokio::test]
909    async fn test_edge_cases() {
910        let relay = SubscriptionRelay::new();
911        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
912
913        let query_id = proto::QueryId::new();
914        let collection_id = create_test_collection_id();
915        let predicate = create_test_selection();
916        let peer_id = EntityId::new();
917
918        // Test setup without message sender - should not crash
919        relay.subscribe_query(query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
920        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
921
922        // Should still be pending since no sender
923        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
924
925        // Now set sender and test with no connected peers
926        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
927        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
928
929        // Should still be pending since no peers available
930        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
931
932        // Verify no requests were sent
933        assert_eq!(mock_sender.get_sent_requests().len(), 0);
934
935        // Now connect a peer (should trigger automatic setup)
936        relay.notify_peer_connected(peer_id);
937        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
938
939        // Should now be established
940        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
941        assert_eq!(mock_sender.get_sent_requests().len(), 1);
942    }
943
944    #[tokio::test]
945    async fn test_notify_unsubscribe_with_no_established_subscription() {
946        let relay = SubscriptionRelay::new();
947        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
948        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
949
950        let query_id = proto::QueryId::new();
951        let collection_id = create_test_collection_id();
952        let predicate = create_test_selection();
953
954        // Add subscription but don't establish it
955        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
956        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
957
958        // Unsubscribe from pending subscription
959        relay.unsubscribe_predicate(query_id);
960
961        // Give async task time to complete (though no request should be sent)
962        futures_timer::Delay::new(std::time::Duration::from_millis(10)).await;
963
964        // Verify no unsubscribe message was sent (since it wasn't established)
965        let sent_requests = mock_sender.get_sent_requests();
966        assert_eq!(sent_requests.len(), 0);
967
968        // Verify subscription is gone
969        assert!(matches!(relay.get_status(query_id), None));
970    }
971}