firebase_rs_sdk/data_connect/query.rs
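//! Query execution, caching, and subscription management for Data Connect.
//!
//! `QueryManager` runs queries over a [`DataConnectTransport`], caches the latest payload per
//! query, and fans results and errors out to observer-style subscribers.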

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::SystemTime;

use serde_json::Value;

use crate::data_connect::error::{DataConnectError, DataConnectResult};
use crate::data_connect::reference::{
    encode_query_key, string_to_system_time, DataSource, OpResult, QueryRef, QueryResult,
    SerializedQuerySnapshot,
};
use crate::data_connect::transport::DataConnectTransport;
use crate::platform::runtime;

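// On the wasm-web target, callbacks are invoked on the browser's single-threaded event
// loop, so the `Send + Sync` bounds required on native targets are omitted there.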
#[cfg(all(target_arch = "wasm32", feature = "wasm-web"))]
type ValueCallback = dyn Fn() + 'static;
#[cfg(not(all(target_arch = "wasm32", feature = "wasm-web")))]
type ValueCallback = dyn Fn() + Send + Sync + 'static;

#[cfg(all(target_arch = "wasm32", feature = "wasm-web"))]
type DataCallback<T> = dyn Fn(&T) + 'static;
#[cfg(not(all(target_arch = "wasm32", feature = "wasm-web")))]
type DataCallback<T> = dyn Fn(&T) + Send + Sync + 'static;

pub type QueryResultCallback = Arc<DataCallback<QueryResult>>;
pub type QueryErrorCallback = Arc<DataCallback<DataConnectError>>;
pub type QueryCompleteCallback = Arc<ValueCallback>;

/// Observer-style subscription handlers.
#[derive(Clone)]
pub struct QuerySubscriptionHandlers {
    pub on_next: QueryResultCallback,
    pub on_error: Option<QueryErrorCallback>,
    pub on_complete: Option<QueryCompleteCallback>,
}

impl QuerySubscriptionHandlers {
    pub fn new(on_next: QueryResultCallback) -> Self {
        Self {
            on_next,
            on_error: None,
            on_complete: None,
        }
    }

    pub fn with_error(mut self, callback: QueryErrorCallback) -> Self {
        self.on_error = Some(callback);
        self
    }

    pub fn with_complete(mut self, callback: QueryCompleteCallback) -> Self {
        self.on_complete = Some(callback);
        self
    }
}
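// Illustrative sketch of composing handlers with the builder methods above; the closures
// stand in for hypothetical application callbacks:
//
//     let handlers = QuerySubscriptionHandlers::new(Arc::new(|_result: &QueryResult| {
//         // react to fresh data
//     }))
//     .with_error(Arc::new(|_err: &DataConnectError| {
//         // report the failure
//     }))
//     .with_complete(Arc::new(|| {
//         // subscription was removed
//     }));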

/// Guard returned when subscribing to a query.
pub struct QuerySubscriptionHandle {
    tracked: Arc<TrackedQuery>,
    subscriber_id: u64,
    closed: AtomicBool,
}

impl QuerySubscriptionHandle {
    fn new(tracked: Arc<TrackedQuery>, subscriber_id: u64) -> Self {
        Self {
            tracked,
            subscriber_id,
            closed: AtomicBool::new(false),
        }
    }

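    /// Explicitly ends the subscription. Dropping the handle has the same effect, and the
    /// subscriber is removed at most once either way.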
    pub fn unsubscribe(mut self) {
        self.close();
    }

    fn close(&mut self) {
        if !self.closed.swap(true, Ordering::SeqCst) {
            self.tracked.remove_subscriber(self.subscriber_id);
        }
    }
}

impl Drop for QuerySubscriptionHandle {
    fn drop(&mut self) {
        self.close();
    }
}

/// Tracks outstanding queries, cached payloads, and subscribers.
#[derive(Clone)]
pub struct QueryManager {
    inner: Arc<QueryManagerInner>,
}

impl QueryManager {
    pub fn new(transport: Arc<dyn DataConnectTransport>) -> Self {
        Self {
            inner: Arc::new(QueryManagerInner::new(transport)),
        }
    }

    pub async fn execute_query(&self, query_ref: QueryRef) -> DataConnectResult<QueryResult> {
        self.inner.execute_query(query_ref).await
    }

    pub fn subscribe(
        &self,
        query_ref: QueryRef,
        handlers: QuerySubscriptionHandlers,
        initial_cache: Option<OpResult>,
    ) -> DataConnectResult<QuerySubscriptionHandle> {
        self.inner
            .subscribe(self.clone(), query_ref, handlers, initial_cache)
    }
}
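// Illustrative end-to-end sketch (the `transport`, `query_ref`, and `handlers` values are
// hypothetical inputs produced elsewhere in the SDK):
//
//     let manager = QueryManager::new(transport);
//     let result = manager.execute_query(query_ref.clone()).await?;
//     let handle = manager.subscribe(query_ref, handlers, None)?;
//     // Dropping `handle` removes the subscriber again.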

struct QueryManagerInner {
    transport: Arc<dyn DataConnectTransport>,
    queries: Mutex<HashMap<String, Arc<TrackedQuery>>>,
    next_id: AtomicU64,
}

impl QueryManagerInner {
    fn new(transport: Arc<dyn DataConnectTransport>) -> Self {
        Self {
            transport,
            queries: Mutex::new(HashMap::new()),
            next_id: AtomicU64::new(1),
        }
    }

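    /// Queries with the same operation name and variables map to the same tracked entry.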
    fn key_for(query_ref: &QueryRef) -> String {
        encode_query_key(query_ref.operation_name(), query_ref.variables())
    }

    fn track(&self, query_ref: &QueryRef, initial_cache: Option<OpResult>) {
        let tracked = self.tracked_entry(query_ref);
        if let Some(cache) = initial_cache {
            tracked.maybe_update_cache(cache);
        }
    }

    fn tracked_entry(&self, query_ref: &QueryRef) -> Arc<TrackedQuery> {
        let key = Self::key_for(query_ref);
        let mut queries = self.queries.lock().unwrap();
        queries
            .entry(key.clone())
            .or_insert_with(|| {
                Arc::new(TrackedQuery::new(
                    key,
                    query_ref.operation_name().into(),
                    query_ref.variables().clone(),
                ))
            })
            .clone()
    }

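    /// Executes the query over the transport; on success the cached payload is refreshed and
    /// subscribers are notified, on failure the error is recorded and fanned out.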
    async fn execute_query(&self, query_ref: QueryRef) -> DataConnectResult<QueryResult> {
        let tracked = self.tracked_entry(&query_ref);
        match self
            .transport
            .invoke_query(query_ref.operation_name(), query_ref.variables())
            .await
        {
            Ok(data) => {
                let fetch_time = SystemTime::now();
                let result = QueryResult {
                    data: data.clone(),
                    source: DataSource::Server,
                    fetch_time,
                    query_ref: query_ref.clone(),
                };
                tracked.set_cache(OpResult {
                    data,
                    source: DataSource::Cache,
                    fetch_time,
                });
                tracked.clear_error();
                tracked.notify_success(&result);
                Ok(result)
            }
            Err(err) => {
                tracked.record_error(err.clone());
                Err(err)
            }
        }
    }

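    /// Registers a subscriber. A cached payload (and any previously recorded error) is
    /// replayed immediately; with no cache, a background fetch is spawned to populate it.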
    fn subscribe(
        &self,
        manager: QueryManager,
        query_ref: QueryRef,
        handlers: QuerySubscriptionHandlers,
        initial_cache: Option<OpResult>,
    ) -> DataConnectResult<QuerySubscriptionHandle> {
        self.track(&query_ref, initial_cache);
        let tracked = self.tracked_entry(&query_ref);
        let subscriber_id = self.next_id.fetch_add(1, Ordering::SeqCst);
        tracked.add_subscriber(subscriber_id, handlers.clone());

        if let Some(cache) = tracked.cache_snapshot() {
            let snapshot = QueryResult {
                data: cache.data,
                source: cache.source,
                fetch_time: cache.fetch_time,
                query_ref: query_ref.clone(),
            };
            (handlers.on_next)(&snapshot);
        } else {
            let manager_clone = manager.clone();
            let query_clone = query_ref.clone();
            runtime::spawn_detached(async move {
                let _ = manager_clone.execute_query(query_clone).await;
            });
        }

        if let Some(last_error) = tracked.last_error() {
            if let Some(on_error) = handlers.on_error {
                on_error(&last_error);
            }
        }

        Ok(QuerySubscriptionHandle::new(tracked, subscriber_id))
    }
}

struct SubscriberEntry {
    id: u64,
    handlers: QuerySubscriptionHandlers,
}

struct TrackedState {
    subscribers: Vec<SubscriberEntry>,
    current_cache: Option<OpResult>,
    last_error: Option<DataConnectError>,
}

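/// Per-query bookkeeping: the cached payload, the last error, and the active subscribers.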
struct TrackedQuery {
    #[allow(unused)]
    key: String,
    #[allow(unused)]
    name: Arc<str>,
    #[allow(unused)]
    variables: Value,
    state: Mutex<TrackedState>,
}

impl TrackedQuery {
    fn new(key: String, name: Arc<str>, variables: Value) -> Self {
        Self {
            key,
            name,
            variables,
            state: Mutex::new(TrackedState {
                subscribers: Vec::new(),
                current_cache: None,
                last_error: None,
            }),
        }
    }

    fn add_subscriber(&self, id: u64, handlers: QuerySubscriptionHandlers) {
        self.state
            .lock()
            .unwrap()
            .subscribers
            .push(SubscriberEntry { id, handlers });
    }

    fn remove_subscriber(&self, id: u64) {
        let mut state = self.state.lock().unwrap();
        let position = state.subscribers.iter().position(|entry| entry.id == id);
        let removed = position.map(|pos| state.subscribers.remove(pos));
        // Release the lock before running the completion callback, mirroring
        // `record_error` and `notify_success`, so the callback cannot deadlock by
        // re-entering this tracked query.
        drop(state);
        if let Some(callback) = removed.and_then(|entry| entry.handlers.on_complete) {
            callback();
        }
    }

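    /// Keeps the existing cache entry unless the incoming one has a newer fetch time.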
    fn maybe_update_cache(&self, cache: OpResult) {
        let mut state = self.state.lock().unwrap();
        match &state.current_cache {
            Some(existing) if existing.fetch_time >= cache.fetch_time => {}
            _ => state.current_cache = Some(cache),
        }
    }

    fn set_cache(&self, cache: OpResult) {
        self.state.lock().unwrap().current_cache = Some(cache);
    }

    fn cache_snapshot(&self) -> Option<OpResult> {
        self.state.lock().unwrap().current_cache.clone()
    }

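    /// Stores the error and notifies subscribers' error handlers; callbacks are invoked only
    /// after the state lock has been released.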
    fn record_error(&self, error: DataConnectError) {
        let mut state = self.state.lock().unwrap();
        state.last_error = Some(error.clone());
        let subscribers = state
            .subscribers
            .iter()
            .map(|entry| entry.handlers.on_error.clone())
            .collect::<Vec<_>>();
        drop(state);
        for handler in subscribers.into_iter().flatten() {
            handler(&error);
        }
    }

    fn last_error(&self) -> Option<DataConnectError> {
        self.state.lock().unwrap().last_error.clone()
    }

    fn clear_error(&self) {
        self.state.lock().unwrap().last_error = None;
    }

    fn notify_success(&self, result: &QueryResult) {
        let handlers = self
            .state
            .lock()
            .unwrap()
            .subscribers
            .iter()
            .map(|entry| entry.handlers.on_next.clone())
            .collect::<Vec<_>>();
        for callback in handlers {
            callback(result);
        }
    }
}

/// Converts a serialized query snapshot (e.g. produced on the server) into an initial cache entry.
pub fn cache_from_serialized(snapshot: &SerializedQuerySnapshot) -> Option<OpResult> {
    let fetch_time = string_to_system_time(&snapshot.fetch_time)?;
    Some(OpResult {
        data: snapshot.data.clone(),
        source: snapshot.source,
        fetch_time,
    })
}
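
// Illustrative sketch of hydrating a subscription from server-rendered data; `snapshot` is a
// hypothetical `SerializedQuerySnapshot` received out of band, and `manager`, `query_ref`,
// and `handlers` are assumed to exist as in the earlier sketches:
//
//     let initial_cache = cache_from_serialized(&snapshot);
//     let handle = manager.subscribe(query_ref, handlers, initial_cache)?;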