firebase_rs_sdk/data_connect/
query.rs1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
3use std::sync::{Arc, Mutex};
4use std::time::SystemTime;
5
6use serde_json::Value;
7
8use crate::data_connect::error::{DataConnectError, DataConnectResult};
9use crate::data_connect::reference::{
10 encode_query_key, string_to_system_time, DataSource, OpResult, QueryRef, QueryResult,
11 SerializedQuerySnapshot,
12};
13use crate::data_connect::transport::DataConnectTransport;
14use crate::platform::runtime;
15
16#[cfg(all(target_arch = "wasm32", feature = "wasm-web"))]
17type ValueCallback = dyn Fn() + 'static;
18#[cfg(not(all(target_arch = "wasm32", feature = "wasm-web")))]
19type ValueCallback = dyn Fn() + Send + Sync + 'static;
20
21#[cfg(all(target_arch = "wasm32", feature = "wasm-web"))]
22type DataCallback<T> = dyn Fn(&T) + 'static;
23#[cfg(not(all(target_arch = "wasm32", feature = "wasm-web")))]
24type DataCallback<T> = dyn Fn(&T) + Send + Sync + 'static;
25
26pub type QueryResultCallback = Arc<DataCallback<QueryResult>>;
27pub type QueryErrorCallback = Arc<DataCallback<DataConnectError>>;
28pub type QueryCompleteCallback = Arc<ValueCallback>;
29
30#[derive(Clone)]
32pub struct QuerySubscriptionHandlers {
33 pub on_next: QueryResultCallback,
34 pub on_error: Option<QueryErrorCallback>,
35 pub on_complete: Option<QueryCompleteCallback>,
36}
37
38impl QuerySubscriptionHandlers {
39 pub fn new(on_next: QueryResultCallback) -> Self {
40 Self {
41 on_next,
42 on_error: None,
43 on_complete: None,
44 }
45 }
46
47 pub fn with_error(mut self, callback: QueryErrorCallback) -> Self {
48 self.on_error = Some(callback);
49 self
50 }
51
52 pub fn with_complete(mut self, callback: QueryCompleteCallback) -> Self {
53 self.on_complete = Some(callback);
54 self
55 }
56}
57
58pub struct QuerySubscriptionHandle {
60 tracked: Arc<TrackedQuery>,
61 subscriber_id: u64,
62 closed: AtomicBool,
63}
64
65impl QuerySubscriptionHandle {
66 fn new(tracked: Arc<TrackedQuery>, subscriber_id: u64) -> Self {
67 Self {
68 tracked,
69 subscriber_id,
70 closed: AtomicBool::new(false),
71 }
72 }
73
74 pub fn unsubscribe(mut self) {
75 self.close();
76 }
77
78 fn close(&mut self) {
79 if !self.closed.swap(true, Ordering::SeqCst) {
80 self.tracked.remove_subscriber(self.subscriber_id);
81 }
82 }
83}
84
85impl Drop for QuerySubscriptionHandle {
86 fn drop(&mut self) {
87 self.close();
88 }
89}
90
91#[derive(Clone)]
93pub struct QueryManager {
94 inner: Arc<QueryManagerInner>,
95}
96
97impl QueryManager {
98 pub fn new(transport: Arc<dyn DataConnectTransport>) -> Self {
99 Self {
100 inner: Arc::new(QueryManagerInner::new(transport)),
101 }
102 }
103
104 pub async fn execute_query(&self, query_ref: QueryRef) -> DataConnectResult<QueryResult> {
105 self.inner.execute_query(query_ref).await
106 }
107
108 pub fn subscribe(
109 &self,
110 query_ref: QueryRef,
111 handlers: QuerySubscriptionHandlers,
112 initial_cache: Option<OpResult>,
113 ) -> DataConnectResult<QuerySubscriptionHandle> {
114 self.inner
115 .subscribe(self.clone(), query_ref, handlers, initial_cache)
116 }
117}
118
119struct QueryManagerInner {
120 transport: Arc<dyn DataConnectTransport>,
121 queries: Mutex<HashMap<String, Arc<TrackedQuery>>>,
122 next_id: AtomicU64,
123}
124
125impl QueryManagerInner {
126 fn new(transport: Arc<dyn DataConnectTransport>) -> Self {
127 Self {
128 transport,
129 queries: Mutex::new(HashMap::new()),
130 next_id: AtomicU64::new(1),
131 }
132 }
133
134 fn key_for(query_ref: &QueryRef) -> String {
135 encode_query_key(query_ref.operation_name(), query_ref.variables())
136 }
137
138 fn track(&self, query_ref: &QueryRef, initial_cache: Option<OpResult>) {
139 let tracked = self.tracked_entry(query_ref);
140 if let Some(cache) = initial_cache {
141 tracked.maybe_update_cache(cache);
142 }
143 }
144
145 fn tracked_entry(&self, query_ref: &QueryRef) -> Arc<TrackedQuery> {
146 let key = Self::key_for(query_ref);
147 let mut queries = self.queries.lock().unwrap();
148 queries
149 .entry(key.clone())
150 .or_insert_with(|| {
151 Arc::new(TrackedQuery::new(
152 key,
153 query_ref.operation_name().into(),
154 query_ref.variables().clone(),
155 ))
156 })
157 .clone()
158 }
159
160 async fn execute_query(&self, query_ref: QueryRef) -> DataConnectResult<QueryResult> {
161 let tracked = self.tracked_entry(&query_ref);
162 match self
163 .transport
164 .invoke_query(query_ref.operation_name(), query_ref.variables())
165 .await
166 {
167 Ok(data) => {
168 let fetch_time = SystemTime::now();
169 let result = QueryResult {
170 data: data.clone(),
171 source: DataSource::Server,
172 fetch_time,
173 query_ref: query_ref.clone(),
174 };
175 tracked.set_cache(OpResult {
176 data,
177 source: DataSource::Cache,
178 fetch_time,
179 });
180 tracked.clear_error();
181 tracked.notify_success(&result);
182 Ok(result)
183 }
184 Err(err) => {
185 tracked.record_error(err.clone());
186 Err(err)
187 }
188 }
189 }
190
191 fn subscribe(
192 &self,
193 manager: QueryManager,
194 query_ref: QueryRef,
195 handlers: QuerySubscriptionHandlers,
196 initial_cache: Option<OpResult>,
197 ) -> DataConnectResult<QuerySubscriptionHandle> {
198 self.track(&query_ref, initial_cache);
199 let tracked = self.tracked_entry(&query_ref);
200 let subscriber_id = self.next_id.fetch_add(1, Ordering::SeqCst);
201 tracked.add_subscriber(subscriber_id, handlers.clone());
202
203 if let Some(cache) = tracked.cache_snapshot() {
204 let snapshot = QueryResult {
205 data: cache.data,
206 source: cache.source,
207 fetch_time: cache.fetch_time,
208 query_ref: query_ref.clone(),
209 };
210 (handlers.on_next)(&snapshot);
211 } else {
212 let manager_clone = manager.clone();
213 let query_clone = query_ref.clone();
214 runtime::spawn_detached(async move {
215 let _ = manager_clone.execute_query(query_clone).await;
216 });
217 }
218
219 if let Some(last_error) = tracked.last_error() {
220 if let Some(on_error) = handlers.on_error {
221 on_error(&last_error);
222 }
223 }
224
225 Ok(QuerySubscriptionHandle::new(tracked, subscriber_id))
226 }
227}
228
229struct SubscriberEntry {
230 id: u64,
231 handlers: QuerySubscriptionHandlers,
232}
233
234struct TrackedState {
235 subscribers: Vec<SubscriberEntry>,
236 current_cache: Option<OpResult>,
237 last_error: Option<DataConnectError>,
238}
239
240struct TrackedQuery {
241 #[allow(unused)]
242 key: String,
243 #[allow(unused)]
244 name: Arc<str>,
245 #[allow(unused)]
246 variables: Value,
247 state: Mutex<TrackedState>,
248}
249
250impl TrackedQuery {
251 fn new(key: String, name: Arc<str>, variables: Value) -> Self {
252 Self {
253 key,
254 name,
255 variables,
256 state: Mutex::new(TrackedState {
257 subscribers: Vec::new(),
258 current_cache: None,
259 last_error: None,
260 }),
261 }
262 }
263
264 fn add_subscriber(&self, id: u64, handlers: QuerySubscriptionHandlers) {
265 self.state
266 .lock()
267 .unwrap()
268 .subscribers
269 .push(SubscriberEntry { id, handlers });
270 }
271
272 fn remove_subscriber(&self, id: u64) {
273 let mut state = self.state.lock().unwrap();
274 if let Some(pos) = state.subscribers.iter().position(|entry| entry.id == id) {
275 if let Some(callback) = state.subscribers[pos].handlers.on_complete.clone() {
276 callback();
277 }
278 state.subscribers.remove(pos);
279 }
280 }
281
282 fn maybe_update_cache(&self, cache: OpResult) {
283 let mut state = self.state.lock().unwrap();
284 match &state.current_cache {
285 Some(existing) if existing.fetch_time >= cache.fetch_time => {}
286 _ => state.current_cache = Some(cache),
287 }
288 }
289
290 fn set_cache(&self, cache: OpResult) {
291 self.state.lock().unwrap().current_cache = Some(cache);
292 }
293
294 fn cache_snapshot(&self) -> Option<OpResult> {
295 self.state.lock().unwrap().current_cache.clone()
296 }
297
298 fn record_error(&self, error: DataConnectError) {
299 let mut state = self.state.lock().unwrap();
300 state.last_error = Some(error.clone());
301 let subscribers = state
302 .subscribers
303 .iter()
304 .map(|entry| entry.handlers.on_error.clone())
305 .collect::<Vec<_>>();
306 drop(state);
307 for maybe_handler in subscribers {
308 if let Some(handler) = maybe_handler {
309 handler(&error);
310 }
311 }
312 }
313
314 fn last_error(&self) -> Option<DataConnectError> {
315 self.state.lock().unwrap().last_error.clone()
316 }
317
318 fn clear_error(&self) {
319 self.state.lock().unwrap().last_error = None;
320 }
321
322 fn notify_success(&self, result: &QueryResult) {
323 let handlers = self
324 .state
325 .lock()
326 .unwrap()
327 .subscribers
328 .iter()
329 .map(|entry| entry.handlers.on_next.clone())
330 .collect::<Vec<_>>();
331 for callback in handlers {
332 callback(result);
333 }
334 }
335}
336
337pub fn cache_from_serialized(snapshot: &SerializedQuerySnapshot) -> Option<OpResult> {
339 let fetch_time = string_to_system_time(&snapshot.fetch_time)?;
340 Some(OpResult {
341 data: snapshot.data.clone(),
342 source: snapshot.source,
343 fetch_time,
344 })
345}