ankurah_core/peer_subscription/client_relay.rs

1// TODO: Rename this module from client_relay to remote_subscription for clarity
2use ankurah_proto::{self as proto, CollectionId};
3use anyhow::anyhow;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::{Arc, OnceLock};
7use tracing::{debug, warn};
8
9use crate::error::{RequestError, RetrievalError};
10use crate::node::ContextData;
11use crate::util::safeset::SafeSet;
12
13/// Trait for query initialization that can be driven by SubscriptionRelay
14/// Abstracts the relay's interaction with LiveQuery
15#[async_trait::async_trait]
16pub trait RemoteQuerySubscriber: Clone + Send + Sync + 'static {
17    /// Called after remote subscription deltas have been applied
18    /// Dispatches to initialize (version 1) or update_selection_init (version >1) internally
19    /// Handles marking initialization as complete and setting last_error on failure
20    async fn subscription_established(&self, version: u32);
21
22    /// Set the last error for this subscription
23    fn set_last_error(&self, error: RetrievalError);
24}
25
26#[derive(Debug, Clone)]
27pub enum Status {
28    PendingRemote,
29    Requested(proto::EntityId, u32),     // peer_id, version
30    Established(proto::EntityId, u32),   // peer_id, version
31    PendingUpdate(proto::EntityId, u32), // peer_id, version
32    /// Non-retryable
33    Failed,
34}
35
36#[derive(Debug)]
37pub struct Content<CD: ContextData> {
38    pub query_id: proto::QueryId,
39    pub collection_id: CollectionId,
40    pub selection: ankql::ast::Selection,
41    pub context_data: CD,
42    pub version: u32,
43}
44
45pub struct RemoteQueryState<CD: ContextData, Q: RemoteQuerySubscriber> {
46    pub content: Arc<Content<CD>>,
47    pub status: Status,
48    pub livequery: Q,
49}
50
51struct SubscriptionRelayInner<CD: ContextData, Q: RemoteQuerySubscriber> {
52    // All subscription information in one place
53    subscriptions: std::sync::Mutex<HashMap<proto::QueryId, RemoteQueryState<CD, Q>>>,
54    // Track connected durable peers
55    connected_peers: SafeSet<proto::EntityId>,
56    // Node for communicating with remote peers
57    node: OnceLock<Arc<dyn TNode<CD>>>,
58    // Shutdown signal for retry task - when dropped, the task will stop
59    _shutdown_tx: tokio::sync::mpsc::Sender<()>,
60}
61
62/// Manages predicate registration on remote peer reactor subscriptions.
63///
64/// The SubscriptionRelay provides a resilient, event-driven approach to managing which predicates
65/// are registered with remote durable peers. It automatically handles:
66/// - Registering predicates on peer reactor subscriptions when peers connect
67/// - Re-registering predicates when peers disconnect and reconnect
68/// - Retrying failed predicate registration attempts
69/// - Clean teardown when predicates are removed
70/// - Storing ContextData for each predicate to enable proper authorization
71///
72/// This design separates predicate management concerns from the main Node implementation,
73/// making it easier to test and reason about predicate lifecycle management.
74///
75/// # Public API (for Node integration)
76///
77/// - `subscribe_predicate()` - Call when local subscriptions are created (parallel to reactor.subscribe)
78/// - `unsubscribe_predicate()` - Call when local subscriptions are removed (parallel to reactor.unsubscribe)
79/// - `notify_peer_connected()` - Call when durable peers connect (triggers automatic predicate registration)
80/// - `notify_peer_disconnected()` - Call when durable peers disconnect (orphans predicate registrations)
81/// - `get_status()` - Query current state of a predicate registration
82///
83/// # Internal/Testing API
84///
85/// - `setup_remote_subscriptions()` - Internal method for triggering predicate registration with specific peers
86///   (called automatically by notify_peer_connected, but exposed for testing)
87///
88/// The relay will automatically handle predicate registration/teardown asynchronously.
89#[derive(Clone)]
90pub struct SubscriptionRelay<CD: ContextData, Q: RemoteQuerySubscriber> {
91    inner: Arc<SubscriptionRelayInner<CD, Q>>,
92}
93
94impl<CD: ContextData, Q: RemoteQuerySubscriber> Default for SubscriptionRelay<CD, Q> {
95    fn default() -> Self { Self::new() }
96}
97
98impl<CD: ContextData, Q: RemoteQuerySubscriber> SubscriptionRelay<CD, Q> {
99    pub fn new() -> Self {
100        let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::channel(1);
101
102        let relay = Self {
103            inner: Arc::new(SubscriptionRelayInner {
104                subscriptions: std::sync::Mutex::new(HashMap::new()),
105                connected_peers: SafeSet::new(),
106                node: OnceLock::new(),
107                _shutdown_tx: shutdown_tx,
108            }),
109        };
110
111        // Start background retry task
112        relay.start_retry_task(shutdown_rx);
113
114        relay
115    }
116
117    /// Inject the node (typically a WeakNode for production)
118    ///
119    /// This should be called once during initialization. Returns an error if
120    /// the node has already been set.
121    pub fn set_node(&self, node: Arc<dyn TNode<CD>>) -> Result<(), ()> { self.inner.node.set(node).map_err(|_| ()) }
122
123    /// Notify the relay that a new predicate needs to be registered on remote peer subscriptions
124    ///
125    /// This should be called whenever a local subscription is established. The relay will
126    /// track this predicate and automatically attempt to register it with available durable peers.
127    pub fn subscribe_query(
128        &self,
129        query_id: proto::QueryId,
130        collection_id: CollectionId,
131        selection: ankql::ast::Selection,
132        context_data: CD,
133        version: u32,
134        livequery: Q,
135    ) {
136        debug!("SubscriptionRelay.subscribe_predicate() - New predicate {} needs remote registration", query_id);
137        {
138            self.inner.subscriptions.lock().expect("poisoned lock").insert(
139                query_id,
140                RemoteQueryState {
141                    content: Arc::new(Content { collection_id, selection, context_data, query_id, version }),
142                    status: Status::PendingRemote,
143                    livequery,
144                },
145            );
146        }
147
148        // Immediately attempt setup with available peers
149        if !self.inner.connected_peers.is_empty() {
150            self.setup_remote_subscriptions()
151        }
152    }
153    pub fn update_query(&self, query_id: proto::QueryId, selection: ankql::ast::Selection, version: u32) -> Result<(), anyhow::Error> {
154        debug!("SubscriptionRelay.update_query() - New query {} needs remote registration", query_id);
155
156        let update = {
157            let mut subscriptions = self.inner.subscriptions.lock().expect("poisoned lock");
158            match subscriptions.get_mut(&query_id) {
159                Some(state) => {
160                    // Update the content with new predicate and version
161                    let old_content = &state.content;
162                    state.content = Arc::new(Content {
163                        collection_id: old_content.collection_id.clone(),
164                        selection: selection.clone(),
165                        context_data: old_content.context_data.clone(),
166                        query_id: old_content.query_id,
167                        version,
168                    });
169
170                    match state.status {
171                        Status::Established(peer_id, _old_version) => {
172                            // Update to new version, mark as requested for this peer
173                            state.status = Status::Requested(peer_id, version);
174                            Some((peer_id, state.content.collection_id.clone(), state.content.context_data.clone()))
175                            // Return the peer_id to send update to
176                        }
177                        _ => {
178                            // Not established yet, just update to PendingRemote and setup
179                            state.status = Status::PendingRemote;
180                            None
181                        }
182                    }
183                }
184                None => return Err(anyhow!("Predicate {} not found", query_id)),
185            }
186        };
187
188        match update {
189            Some((peer_id, collection_id, context_data)) => {
190                self.update_query_on_peer(peer_id, query_id, collection_id, selection, version, context_data);
191            }
192            None => {
193                // Not established yet - use setup_remote_subscriptions for initial setup
194                self.setup_remote_subscriptions();
195            }
196        };
197
198        Ok(())
199    }
200
201    fn update_query_on_peer(
202        &self,
203        peer_id: proto::EntityId,
204        query_id: proto::QueryId,
205        collection_id: CollectionId,
206        selection: ankql::ast::Selection,
207        version: u32,
208        context_data: CD,
209    ) {
210        let me = self.clone();
211        crate::task::spawn(async move {
212            if let Some(node) = me.inner.node.get() {
213                // Get the livequery for error handling
214                let livequery = { me.inner.subscriptions.lock().unwrap().get(&query_id).map(|state| state.livequery.clone()) };
215
216                // Send the updated predicate to the peer
217                match node.remote_subscribe(peer_id, query_id, collection_id, selection, &context_data, version).await {
218                    Ok(()) => {
219                        // Deltas applied successfully, now activate the livequery
220                        if let Some(lq) = livequery {
221                            lq.subscription_established(version).await;
222                        }
223
224                        // Mark as established - subscription succeeded even if livequery activation had issues
225                        let mut subscriptions = me.inner.subscriptions.lock().unwrap();
226                        if let Some(info) = subscriptions.get_mut(&query_id) {
227                            info.status = Status::Established(peer_id, version);
228                        }
229                        debug!("Successfully updated predicate {} on peer {} subscription", query_id, peer_id);
230                    }
231                    Err(e) => {
232                        // Handle error with retry logic
233                        me.handle_error(query_id, peer_id, e, livequery).await;
234                    }
235                }
236            }
237        });
238    }
239
240    /// Notify the relay that a predicate should be removed from remote peer subscriptions
241    ///
242    /// This will clean up all tracking state and send unsubscribe requests to any
243    /// remote peers that have this predicate registered.
244    pub fn unsubscribe_predicate(&self, query_id: proto::QueryId) {
245        debug!("Unregistering predicate {}", query_id);
246
247        // If subscription was established with a peer, send unsubscribe request
248        {
249            let mut subscriptions = self.inner.subscriptions.lock().unwrap();
250            if let Some(info) = subscriptions.remove(&query_id) {
251                if let Status::Established(peer_id, _version) = &info.status {
252                    let node = self.inner.node.get();
253                    if let Some(node) = node {
254                        let node = node.clone();
255                        let peer_id = *peer_id;
256                        crate::task::spawn(async move {
257                            if let Err(e) = node.peer_unsubscribe(peer_id, query_id).await {
258                                warn!("Failed to send unsubscribe message for {}: {}", query_id, e);
259                            } else {
260                                debug!("Successfully sent unsubscribe message for {}", query_id);
261                            }
262                        });
263                    }
264                }
265            }
266        }
267    }
268
269    /// Handle peer disconnection - mark all predicates for that peer as needing re-registration
270    ///
271    /// This should be called when a durable peer disconnects. All predicates registered
272    /// with that peer will be marked as pending and will be automatically re-registered
273    /// when the peer reconnects or another suitable peer becomes available.
274    pub fn notify_peer_disconnected(&self, peer_id: proto::EntityId) {
275        debug!("Peer {} disconnected, orphaning predicate registrations", peer_id);
276
277        // Remove from connected peers
278        self.inner.connected_peers.remove(&peer_id);
279
280        for info in self.inner.subscriptions.lock().expect("poisoned lock").values_mut() {
281            if let Status::Established(established_peer_id, _) | Status::Requested(established_peer_id, _) = &info.status {
282                if *established_peer_id == peer_id {
283                    // Update state to pending
284                    info.status = Status::PendingRemote;
285                    warn!("Predicate {} orphaned due to peer {} disconnect", info.content.query_id, peer_id);
286                }
287            }
288        }
289
290        // Resubscribe any orphaned subscriptions
291        self.setup_remote_subscriptions();
292    }
293
294    /// Handle peer connection - trigger predicate registration on the new peer subscription
295    ///
296    /// This should be called when a new durable peer connects. The relay will automatically
297    /// attempt to register any pending predicates on the newly connected peer's subscription.
298    pub fn notify_peer_connected(&self, peer_id: proto::EntityId) {
299        debug!("SubscriptionRelay.notify_peer_connected() - Peer {} connected, registering predicates on peer subscription", peer_id);
300
301        // Add to connected peers
302        self.inner.connected_peers.insert(peer_id);
303
304        // Trigger setup with all connected peers
305        self.setup_remote_subscriptions();
306    }
307
308    /// Get the current state of a predicate registration
309    pub fn get_status(&self, query_id: proto::QueryId) -> Option<Status> {
310        let subscriptions = self.inner.subscriptions.lock().unwrap();
311        subscriptions.get(&query_id).map(|info| info.status.clone())
312    }
313
314    /// Get all unique contexts for predicates established or requested with a specific peer
315    /// TODO: update the data structure to do this via a direct lookup rather than having to scan the entire map
316    pub fn get_contexts_for_peer(&self, peer_id: &proto::EntityId) -> std::collections::HashSet<CD> {
317        let subscriptions = self.inner.subscriptions.lock().unwrap();
318        let mut contexts = std::collections::HashSet::new();
319
320        for (_, state) in subscriptions.iter() {
321            match &state.status {
322                Status::Established(established_peer, _) | Status::Requested(established_peer, _) => {
323                    if established_peer == peer_id {
324                        contexts.insert(state.content.context_data.clone());
325                    }
326                }
327                _ => {}
328            }
329        }
330
331        contexts
332    }
333
334    /// Register predicates on available durable peer subscriptions
335    fn setup_remote_subscriptions(&self) {
336        let node = match self.inner.node.get() {
337            Some(node) => node,
338            None => {
339                warn!("No node configured for remote subscription setup");
340                return;
341            }
342        };
343
344        // For now, use the first available peer (could be made smarter)
345        let connected_peers = self.inner.connected_peers.to_vec();
346        if connected_peers.is_empty() {
347            warn!("No durable peers available for remote subscription setup");
348            return;
349        }
350
351        let target_peer = connected_peers[0];
352
353        // Atomically get pending subscriptions and mark them as requested
354        let pending: Vec<_> = {
355            self.inner
356                .subscriptions
357                .lock()
358                .expect("poisoned lock")
359                .values_mut()
360                .filter_map(|info| {
361                    if let Status::PendingRemote = info.status {
362                        info.status = Status::Requested(target_peer, info.content.version);
363                        Some(info.content.clone())
364                    } else {
365                        None
366                    }
367                })
368                .collect()
369        };
370
371        if pending.is_empty() {
372            return;
373        }
374
375        debug!("Registering {} predicates on {} peer subscriptions", pending.len(), self.inner.connected_peers.len());
376
377        for content in pending {
378            crate::task::spawn(self.clone().attempt_subscribe(node.clone(), target_peer, content));
379        }
380    }
381
382    async fn attempt_subscribe(self, node: Arc<dyn TNode<CD>>, target_peer: proto::EntityId, content: Arc<Content<CD>>) {
383        let query_id = content.query_id;
384        let predicate = content.selection.clone();
385        let context_data = content.context_data.clone();
386        let version = content.version;
387
388        // Get the livequery for error handling
389        let livequery = { self.inner.subscriptions.lock().unwrap().get(&query_id).map(|state| state.livequery.clone()) };
390
391        // Call remote_subscribe which fetches known matches, subscribes, applies deltas, and stores events
392        match node.remote_subscribe(target_peer, query_id, content.collection_id.clone(), predicate, &context_data, version).await {
393            Ok(()) => {
394                // Deltas applied successfully, now activate the livequery
395                // The livequery handles its own errors internally
396                if let Some(lq) = livequery {
397                    lq.subscription_established(version).await;
398                }
399
400                // Mark as established - subscription succeeded even if livequery activation had issues
401                let mut subscriptions = self.inner.subscriptions.lock().unwrap();
402                if let Some(info) = subscriptions.get_mut(&query_id) {
403                    info.status = Status::Established(target_peer, version);
404                }
405                debug!("Successfully registered predicate {} on peer {} subscription", query_id, target_peer);
406            }
407            Err(e) => {
408                // Handle error with retry logic
409                self.handle_error(query_id, target_peer, e, livequery).await;
410            }
411        }
412    }
413
414    /// Start background task that periodically retries pending subscriptions
415    fn start_retry_task(&self, mut shutdown_rx: tokio::sync::mpsc::Receiver<()>) {
416        let me = self.clone();
417        crate::task::spawn(async move {
418            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(5));
419            loop {
420                tokio::select! {
421                    _ = interval.tick() => {
422                        // Attempt to setup any pending subscriptions
423                        me.setup_remote_subscriptions();
424                    }
425                    _ = shutdown_rx.recv() => {
426                        debug!("Retry task shutting down - SubscriptionRelay dropped");
427                        break;
428                    }
429                }
430            }
431        });
432    }
433
434    /// Handle errors with retry logic
435    async fn handle_error(&self, query_id: proto::QueryId, target_peer: proto::EntityId, error: RetrievalError, livequery: Option<Q>) {
436        let error_msg = error.to_string();
437
438        // Evaluate retriability at failure time
439        let is_retryable = match &error {
440            // Retrieval errors from fetching are generally not retryable
441            RetrievalError::RequestError(req_err) => match req_err {
442                RequestError::PeerNotConnected => true,
443                RequestError::ConnectionLost => true,
444                RequestError::SendError(_) => true,
445                RequestError::InternalChannelClosed => true,
446                RequestError::ServerError(_) => false,
447                RequestError::UnexpectedResponse(_) => false,
448            },
449            // Other retrieval errors are not retryable
450            _ => false,
451        };
452
453        // Update state based on retriability
454        let mut subscriptions = self.inner.subscriptions.lock().unwrap();
455        if let Some(info) = subscriptions.get_mut(&query_id) {
456            if is_retryable {
457                // Retryable errors go back to pending for retry by background task
458                info.status = Status::PendingRemote;
459                warn!("Retryable failure for predicate {} with peer {}: {} - will retry", query_id, target_peer, error_msg);
460            } else {
461                // Non-retryable errors are permanently failed
462                info.status = Status::Failed;
463                tracing::error!("Permanent failure for predicate {} with peer {}: {} - no retry", query_id, target_peer, error_msg);
464
465                // Set error on livequery
466                if let Some(lq) = livequery {
467                    lq.set_last_error(error);
468                }
469            }
470        }
471    }
472}
473
474/// Trait for communicating with remote peers (abstraction over WeakNode for testing)
475#[async_trait]
476pub trait TNode<CD: ContextData>: Send + Sync {
477    /// Send a predicate registration request to a remote peer, fetch known matches,
478    /// apply received deltas, and store used events.
479    /// Returns Ok(()) if subscription was established and deltas applied successfully.
480    async fn remote_subscribe(
481        &self,
482        peer_id: proto::EntityId,
483        query_id: proto::QueryId,
484        collection_id: CollectionId,
485        selection: ankql::ast::Selection,
486        context_data: &CD,
487        version: u32,
488    ) -> Result<(), RetrievalError>;
489
490    /// Send a predicate unregistration message to a remote peer
491    /// This is a one-way message, no response expected
492    async fn peer_unsubscribe(&self, peer_id: proto::EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error>;
493}
494
495/// Implementation of TNode for WeakNode
496#[async_trait]
497impl<SE, PA> TNode<PA::ContextData> for crate::node::WeakNode<SE, PA>
498where
499    SE: crate::storage::StorageEngine + Send + Sync + 'static,
500    PA: crate::policy::PolicyAgent + Send + Sync + 'static,
501{
502    async fn remote_subscribe(
503        &self,
504        peer_id: proto::EntityId,
505        query_id: proto::QueryId,
506        collection_id: CollectionId,
507        selection: ankql::ast::Selection,
508        context_data: &PA::ContextData,
509        version: u32,
510    ) -> Result<(), RetrievalError> {
511        let node = self.upgrade().ok_or_else(|| RetrievalError::Other("Node has been dropped".to_string()))?;
512
513        // 1. Pre-fetch known_matches from local storage
514        let known_matches: Vec<ankurah_proto::KnownEntity> = node
515            .fetch_entities_from_local(&collection_id, &selection)
516            .await?
517            .into_iter()
518            .map(|entity| ankurah_proto::KnownEntity { entity_id: entity.id(), head: entity.head() })
519            .collect();
520
521        // 2. Send subscribe request with known_matches
522        let deltas = match node
523            .request(
524                peer_id,
525                context_data,
526                ankurah_proto::NodeRequestBody::SubscribeQuery {
527                    query_id,
528                    collection: collection_id.clone(),
529                    selection: selection.clone(),
530                    version,
531                    known_matches,
532                },
533            )
534            .await
535            .map_err(|e| RetrievalError::RequestError(e))?
536        {
537            ankurah_proto::NodeResponseBody::QuerySubscribed { query_id: _response_query_id, deltas } => deltas,
538            ankurah_proto::NodeResponseBody::Error(e) => return Err(RetrievalError::RequestError(RequestError::ServerError(e))),
539            other => return Err(RetrievalError::RequestError(RequestError::UnexpectedResponse(other))),
540        };
541
542        // 3. Apply deltas to local node using NodeApplier
543        let retriever = crate::retrieval::EphemeralNodeRetriever::new(collection_id, &node, context_data);
544        let apply_result = crate::node_applier::NodeApplier::apply_deltas(&node, &peer_id, deltas, &retriever).await;
545        let event_store_result = retriever.store_used_events().await;
546
547        apply_result?; // apply result is more important than event store result
548        event_store_result?;
549
550        Ok(())
551    }
552
553    async fn peer_unsubscribe(&self, peer_id: proto::EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error> {
554        let node = self.upgrade().ok_or_else(|| anyhow!("Node has been dropped"))?;
555
556        // Use the existing request_remote_unsubscribe method
557        node.request_remote_unsubscribe(query_id, vec![peer_id]).await?;
558
559        Ok(())
560    }
561}
562
563#[cfg(test)]
564mod tests {
565    use super::*;
566    use ankql::ast::Predicate;
567    use ankurah_proto::EntityId;
568    use std::sync::{Arc, Mutex};
569
570    // Note: Some tests call setup_remote_subscriptions() directly to test the core
571    // subscription setup logic in isolation, while others use notify_peer_connected()
572    // to test the full event-driven flow. Both approaches are valuable:
573    // - Direct calls test the setup mechanism itself (error handling, state transitions)
574    // - Event-driven calls test the integration and user-facing API
575
576    // For testing, we'll use CollectionId as our ContextData
577    impl ContextData for CollectionId {}
578
579    /// Mock message sender for testing
580    #[derive(Debug)]
581    struct MockMessageSender<CD: ContextData> {
582        next_error: Arc<Mutex<Option<RequestError>>>,
583        sent_requests: Arc<Mutex<Vec<(EntityId, proto::QueryId, CollectionId, ankql::ast::Selection)>>>,
584        should_fail: Arc<Mutex<bool>>,
585        failure_message: Arc<Mutex<String>>,
586        _phantom: std::marker::PhantomData<CD>,
587    }
588
589    impl<CD: ContextData> MockMessageSender<CD> {
590        fn new() -> Self {
591            Self {
592                sent_requests: Arc::new(Mutex::new(Vec::new())),
593                next_error: Arc::new(Mutex::new(None)),
594                should_fail: Arc::new(Mutex::new(false)),
595                failure_message: Arc::new(Mutex::new(String::new())),
596                _phantom: std::marker::PhantomData,
597            }
598        }
599
600        fn set_fail_next(&self, error: RequestError) { *self.next_error.lock().unwrap() = Some(error); }
601
602        fn get_sent_requests(&self) -> Vec<(EntityId, proto::QueryId, CollectionId, ankql::ast::Selection)> {
603            self.sent_requests.lock().unwrap().clone()
604        }
605
606        fn clear_sent_requests(&self) { self.sent_requests.lock().unwrap().clear(); }
607    }
608
609    #[async_trait]
610    impl<CD: ContextData> TNode<CD> for MockMessageSender<CD> {
611        async fn remote_subscribe(
612            &self,
613            peer_id: EntityId,
614            query_id: proto::QueryId,
615            collection_id: CollectionId,
616            selection: ankql::ast::Selection,
617            _context_data: &CD,
618            _version: u32,
619        ) -> Result<(), RetrievalError> {
620            self.sent_requests.lock().unwrap().push((peer_id, query_id, collection_id.clone(), selection.clone()));
621
622            // Check if there's an error to fail with
623            if let Some(error) = self.next_error.lock().unwrap().take() {
624                Err(RetrievalError::RequestError(error))
625            } else {
626                // Mock successful subscription (fetch, subscribe, apply, store all succeeded)
627                Ok(())
628            }
629        }
630
631        async fn peer_unsubscribe(&self, peer_id: EntityId, query_id: proto::QueryId) -> Result<(), anyhow::Error> {
632            self.sent_requests.lock().unwrap().push((
633                peer_id,
634                query_id,
635                CollectionId::from("unsubscribe"),
636                ankql::ast::Selection { predicate: ankql::ast::Predicate::True, order_by: None, limit: None },
637            ));
638
639            // Check if there's an error to fail with
640            if let Some(error) = self.next_error.lock().unwrap().take() {
641                Err(anyhow!(error.to_string()))
642            } else {
643                Ok(())
644            }
645        }
646    }
647
648    // Mock implementation of RemoteQuerySubscriber for tests
649    #[derive(Clone)]
650    struct MockLiveQuery;
651
652    #[async_trait::async_trait]
653    impl RemoteQuerySubscriber for MockLiveQuery {
654        async fn subscription_established(&self, _version: u32) {
655            // Mock - no-op
656        }
657
658        fn set_last_error(&self, _error: RetrievalError) {
659            // For tests, we don't track errors
660        }
661    }
662
663    fn create_test_selection() -> ankql::ast::Selection {
664        // Create a simple test predicate
665        ankql::ast::Selection { predicate: ankql::ast::Predicate::True, order_by: None, limit: None }
666    }
667
668    fn create_test_collection_id() -> CollectionId { CollectionId::from("test_collection") }
669
670    #[tokio::test]
671    async fn test_new_subscription_setup() {
672        let relay = SubscriptionRelay::new();
673        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
674        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
675
676        let query_id = proto::QueryId::new();
677        let collection_id = create_test_collection_id();
678        let predicate = create_test_selection();
679        let peer_id = EntityId::new();
680
681        // Connect the peer first
682        relay.notify_peer_connected(peer_id);
683
684        // Notify of new subscription
685        relay.subscribe_query(query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
686
687        // Check initial state - subscription should immediately go to Requested state since peer is connected
688        assert!(matches!(relay.get_status(query_id), Some(Status::Requested(_, _))));
689
690        // Give async task time to complete (setup should happen automatically)
691        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
692
693        // Verify request was sent
694        let sent_requests = mock_sender.get_sent_requests();
695        assert_eq!(sent_requests.len(), 1);
696        assert_eq!(sent_requests[0].0, peer_id);
697        assert_eq!(sent_requests[0].1, query_id);
698        assert_eq!(sent_requests[0].2, collection_id);
699
700        // Verify subscription is marked as established
701        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
702    }
703
704    #[tokio::test]
705    async fn test_peer_disconnection_orphans_subscriptions() {
706        let relay = SubscriptionRelay::new();
707
708        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
709        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
710
711        let query_id = proto::QueryId::new();
712        let collection_id = create_test_collection_id();
713        let predicate = create_test_selection();
714        let peer_id = EntityId::new();
715
716        // Connect the peer first
717        relay.notify_peer_connected(peer_id);
718
719        // Setup established subscription by going through the full flow
720        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
721
722        // Give async task time to complete
723        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
724
725        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
726
727        // Simulate peer disconnection
728        relay.notify_peer_disconnected(peer_id);
729
730        // Verify subscription is marked as pending again
731        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
732    }
733
734    #[tokio::test]
735    async fn test_peer_connection_triggers_setup() {
736        let relay = SubscriptionRelay::new();
737        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
738        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
739
740        let query_id = proto::QueryId::new();
741        let collection_id = create_test_collection_id();
742        let predicate = create_test_selection();
743        let peer_id = EntityId::new();
744
745        // Add pending subscription (no peers connected yet)
746        relay.subscribe_query(query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
747        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
748
749        // Clear any previous requests
750        mock_sender.clear_sent_requests();
751
752        // Simulate peer connection (should trigger automatic setup)
753        relay.notify_peer_connected(peer_id);
754
755        // Give async task time to complete
756        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
757
758        // Verify request was sent
759        let sent_requests = mock_sender.get_sent_requests();
760        assert_eq!(sent_requests.len(), 1);
761        assert_eq!(sent_requests[0].0, peer_id);
762        assert_eq!(sent_requests[0].1, query_id);
763
764        // Verify subscription is established
765        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
766    }
767
768    #[tokio::test]
769    async fn test_failed_subscription_retry() {
770        let relay = SubscriptionRelay::new();
771        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
772        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
773
774        let query_id = proto::QueryId::new();
775        let collection_id = create_test_collection_id();
776        let predicate = create_test_selection();
777        let peer_id = EntityId::new();
778
779        // Connect peer and add subscription (should succeed initially)
780        relay.notify_peer_connected(peer_id);
781        relay.subscribe_query(query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
782
783        // Give async task time to complete
784        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
785
786        // Verify subscription is marked as established (since no error was set)
787        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
788
789        // Now test the retry behavior by disconnecting the peer (puts subscription back to PendingRemote)
790        // then setting up the mock to fail, and reconnecting to trigger the retry
791        relay.notify_peer_disconnected(peer_id);
792
793        // Verify subscription is now in pending state
794        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
795
796        // Clear requests and set up mock to fail on the next call
797        mock_sender.clear_sent_requests();
798        mock_sender.set_fail_next(RequestError::ServerError("Invalid predicate".to_string()));
799
800        // Reconnect peer to trigger retry attempt
801        relay.notify_peer_connected(peer_id);
802
803        // Give async task time to complete
804        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
805
806        // Verify retry was attempted (the error gets consumed)
807        let sent_requests = mock_sender.get_sent_requests();
808        assert_eq!(sent_requests.len(), 1);
809
810        // Verify subscription remains in failed state (non-retryable error)
811        assert!(matches!(relay.get_status(query_id), Some(Status::Failed)));
812    }
813
814    #[tokio::test]
815    async fn test_retryable_vs_non_retryable_failures() {
816        let relay = SubscriptionRelay::new();
817        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
818        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
819
820        let retryable_query_id = proto::QueryId::new();
821        let non_retryable_query_id = proto::QueryId::new();
822        let collection_id = create_test_collection_id();
823        let predicate = create_test_selection();
824        let peer_id = EntityId::new();
825
826        // Add subscriptions
827        relay.subscribe_query(retryable_query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
828        relay.subscribe_query(non_retryable_query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
829
830        // Manually set different failure types - retryable goes back to pending, non-retryable stays failed
831        {
832            let mut subscriptions = relay.inner.subscriptions.lock().unwrap();
833            if let Some(info) = subscriptions.get_mut(&retryable_query_id) {
834                info.status = Status::PendingRemote; // Retryable errors go back to pending
835            }
836            if let Some(info) = subscriptions.get_mut(&non_retryable_query_id) {
837                info.status = Status::Failed; // Non-retryable errors stay failed
838            }
839        }
840
841        // Connect peer and trigger retry
842        relay.notify_peer_connected(peer_id);
843
844        // Give async task time to complete
845        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
846
847        // Verify only the retryable subscription was attempted
848        let sent_requests = mock_sender.get_sent_requests();
849        assert_eq!(sent_requests.len(), 1);
850        assert_eq!(sent_requests[0].1, retryable_query_id);
851
852        // Verify states
853        assert!(
854            matches!(relay.get_status(retryable_query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id)
855        );
856        assert!(matches!(relay.get_status(non_retryable_query_id), Some(Status::Failed)));
857    }
858
859    #[tokio::test]
860    async fn test_subscription_removal() {
861        let relay = SubscriptionRelay::new();
862        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
863        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
864
865        let query_id = proto::QueryId::new();
866        let collection_id = create_test_collection_id();
867        let predicate = create_test_selection();
868        let peer_id = EntityId::new();
869
870        // Connect peer and setup established subscription
871        relay.notify_peer_connected(peer_id);
872        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
873
874        // Give async task time to complete
875        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
876
877        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
878
879        // Clear previous requests to focus on unsubscribe
880        mock_sender.clear_sent_requests();
881
882        // Remove subscription
883        relay.unsubscribe_predicate(query_id);
884
885        // Give async task time to complete
886        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
887
888        // Verify unsubscribe message was sent
889        let sent_requests = mock_sender.get_sent_requests();
890        assert_eq!(sent_requests.len(), 1);
891        assert_eq!(sent_requests[0].0, peer_id);
892        assert_eq!(sent_requests[0].1, query_id);
893
894        // Verify subscription is gone
895        assert!(matches!(relay.get_status(query_id), None));
896    }
897
898    #[tokio::test]
899    async fn test_edge_cases() {
900        let relay = SubscriptionRelay::new();
901        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
902
903        let query_id = proto::QueryId::new();
904        let collection_id = create_test_collection_id();
905        let predicate = create_test_selection();
906        let peer_id = EntityId::new();
907
908        // Test setup without message sender - should not crash
909        relay.subscribe_query(query_id, collection_id.clone(), predicate.clone(), collection_id.clone(), 0, MockLiveQuery);
910        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
911
912        // Should still be pending since no sender
913        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
914
915        // Now set sender and test with no connected peers
916        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
917        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
918
919        // Should still be pending since no peers available
920        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
921
922        // Verify no requests were sent
923        assert_eq!(mock_sender.get_sent_requests().len(), 0);
924
925        // Now connect a peer (should trigger automatic setup)
926        relay.notify_peer_connected(peer_id);
927        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
928
929        // Should now be established
930        assert!(matches!(relay.get_status(query_id), Some(Status::Established(established_peer_id, _)) if established_peer_id == peer_id));
931        assert_eq!(mock_sender.get_sent_requests().len(), 1);
932    }
933
934    #[tokio::test]
935    async fn test_notify_unsubscribe_with_no_established_subscription() {
936        let relay = SubscriptionRelay::new();
937        let mock_sender = Arc::new(MockMessageSender::<CollectionId>::new());
938        relay.set_node(mock_sender.clone()).expect("Failed to set message sender");
939
940        let query_id = proto::QueryId::new();
941        let collection_id = create_test_collection_id();
942        let predicate = create_test_selection();
943
944        // Add subscription but don't establish it
945        relay.subscribe_query(query_id, collection_id.clone(), predicate, collection_id.clone(), 0, MockLiveQuery);
946        assert!(matches!(relay.get_status(query_id), Some(Status::PendingRemote)));
947
948        // Unsubscribe from pending subscription
949        relay.unsubscribe_predicate(query_id);
950
951        // Give async task time to complete (though no request should be sent)
952        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
953
954        // Verify no unsubscribe message was sent (since it wasn't established)
955        let sent_requests = mock_sender.get_sent_requests();
956        assert_eq!(sent_requests.len(), 0);
957
958        // Verify subscription is gone
959        assert!(matches!(relay.get_status(query_id), None));
960    }
961}