1use crate::{query_observer::ListenerKey, *};
2use leptos::*;
3use std::{borrow::Borrow, cell::Cell, collections::HashMap, future::Future, rc::Rc};
4
5use self::{
6    cache_observer::CacheObserver, query::Query, query_cache::QueryCache,
7    query_observer::QueryObserver, query_persister::QueryPersister,
8};
9
10pub fn provide_query_client() {
12    provide_query_client_with_options(DefaultQueryOptions::default());
13}
14
15pub fn provide_query_client_with_options(options: DefaultQueryOptions) {
17    let owner = Owner::current().expect("Owner to be present");
18
19    provide_context(QueryClient::new(owner, options));
20}
21
22pub fn provide_query_client_with_options_and_persister(
24    options: DefaultQueryOptions,
25    persister: impl QueryPersister + Clone + 'static,
26) {
27    let owner = Owner::current().expect("Owner to be present");
28
29    let client = QueryClient::new(owner, options);
30
31    client.add_persister(persister);
32
33    provide_context(client);
34}
35
36pub fn use_query_client() -> QueryClient {
38    use_context::<QueryClient>().expect("Query Client Missing.")
39}
40
41#[derive(Clone)]
54pub struct QueryClient {
55    pub(crate) cache: QueryCache,
56    pub(crate) default_options: DefaultQueryOptions,
57}
58
59impl QueryClient {
60    pub fn new(owner: Owner, default_options: DefaultQueryOptions) -> Self {
62        Self {
63            cache: QueryCache::new(owner),
64            default_options,
65        }
66    }
67
68    pub async fn fetch_query<K, V, Fu>(
73        &self,
74        key: K,
75        fetcher: impl Fn(K) -> Fu + 'static,
76    ) -> QueryState<V>
77    where
78        K: QueryKey + 'static,
79        V: QueryValue + 'static,
80        Fu: Future<Output = V> + 'static,
81    {
82        #[cfg(any(feature = "hydrate", feature = "csr"))]
83        {
84            let query = self.cache.get_or_create_query::<K, V>(key);
85
86            query::execute_query(query.clone(), fetcher).await;
87
88            query.get_state()
89        }
90        #[cfg(not(any(feature = "hydrate", feature = "csr")))]
91        {
92            let _ = key;
93            let _ = fetcher;
94            QueryState::Created
95        }
96    }
97
98    pub async fn prefetch_query<K, V, Fu>(&self, key: K, fetcher: impl Fn(K) -> Fu + 'static)
103    where
104        K: QueryKey + 'static,
105        V: QueryValue + 'static,
106        Fu: Future<Output = V> + 'static,
107    {
108        #[cfg(any(feature = "hydrate", feature = "csr"))]
109        {
110            let query = self.cache.get_or_create_query::<K, V>(key);
111
112            query::execute_query(query.clone(), fetcher).await;
113        }
114        #[cfg(not(any(feature = "hydrate", feature = "csr")))]
115        {
116            let _ = key;
117            let _ = fetcher;
118        }
119    }
120
121    pub fn get_query_state<K, V>(
124        &self,
125        key: impl Fn() -> K + 'static,
126    ) -> Signal<Option<QueryState<V>>>
127    where
128        K: QueryKey + 'static,
129        V: QueryValue + 'static,
130    {
131        let cache = self.cache.clone();
132        let size = self.size();
133
134        let maybe_query = create_memo(move |_| {
136            let key = key();
137            size.track();
139            cache.get_query::<K, V>(&key)
140        });
141
142        let observer = Rc::new(QueryObserver::no_fetcher(
143            QueryOptions::default(),
144            maybe_query.get_untracked(),
145        ));
146
147        let state_signal = RwSignal::new(maybe_query.get_untracked().map(|q| q.get_state()));
148
149        let listener = Rc::new(Cell::new(None::<ListenerKey>));
150
151        create_isomorphic_effect({
152            move |_| {
153                if listener.get().is_none() {
155                    let listener_id = observer.add_listener(move |state| {
156                        state_signal.set(Some(state.clone()));
157                    });
158                    listener.set(Some(listener_id));
159                }
160
161                let query = maybe_query.get();
163                let current_state = query.as_ref().map(|q| q.get_state());
164                observer.update_query(query);
165                state_signal.set(current_state);
166            }
167        });
168
169        state_signal.into()
170    }
171
172    pub fn peek_query_state<K, V>(&self, key: &K) -> Option<QueryState<V>>
176    where
177        K: QueryKey + 'static,
178        V: QueryValue + 'static,
179    {
180        self.cache.get_query::<K, V>(key).map(|q| q.get_state())
181    }
182
183    pub fn invalidate_query<K, V>(&self, key: impl Borrow<K>) -> bool
199    where
200        K: QueryKey + 'static,
201        V: QueryValue + 'static,
202    {
203        self.cache
204            .use_cache_option(|cache: &HashMap<K, Query<K, V>>| {
205                cache
206                    .get(Borrow::borrow(&key))
207                    .map(|state| state.mark_invalid())
208            })
209            .unwrap_or(false)
210    }
211
212    pub fn invalidate_queries<K, V, Q>(&self, keys: impl IntoIterator<Item = Q>) -> Option<Vec<Q>>
228    where
229        K: crate::QueryKey + 'static,
230
231        V: crate::QueryValue + 'static,
232        Q: Borrow<K> + 'static,
233    {
234        self.cache
235            .use_cache_option(|cache: &HashMap<K, Query<K, V>>| {
236                let result = keys
237                    .into_iter()
238                    .filter(|key| {
239                        cache
240                            .get(Borrow::borrow(key))
241                            .map(|query| query.mark_invalid())
242                            .unwrap_or(false)
243                    })
244                    .collect::<Vec<_>>();
245                Some(result)
246            })
247    }
248
249    pub fn invalidate_query_type<K, V>(&self)
271    where
272        K: QueryKey + 'static,
273        V: QueryValue + 'static,
274    {
275        self.cache
276            .use_cache_option(|cache: &HashMap<K, Query<K, V>>| {
277                for q in cache.values() {
278                    q.mark_invalid();
279                }
280                Some(())
281            });
282    }
283
284    pub fn invalidate_all_queries(&self) {
301        self.cache.invalidate_all_queries()
302    }
303
304    pub fn size(&self) -> Signal<usize> {
318        self.cache.size()
319    }
320
321    pub fn update_query_data<K, V>(
361        &self,
362        key: K,
363        updater: impl FnOnce(Option<&V>) -> Option<V> + 'static,
364    ) where
365        K: QueryKey + 'static,
366        V: QueryValue + 'static,
367    {
368        self.cache
369            .use_cache_entry(key.clone(), move |(owner, entry)| match entry {
370                Some(query) => {
371                    query.maybe_map_state(|state| match state {
372                        QueryState::Created | QueryState::Loading => {
373                            if let Some(result) = updater(None) {
374                                Ok(QueryState::Loaded(QueryData::now(result)))
375                            } else {
376                                Err(state)
377                            }
378                        }
379                        QueryState::Fetching(ref data) => {
380                            if let Some(result) = updater(Some(&data.data)) {
381                                Ok(QueryState::Fetching(QueryData::now(result)))
382                            } else {
383                                Err(state)
384                            }
385                        }
386                        QueryState::Loaded(ref data) => {
387                            if let Some(result) = updater(Some(&data.data)) {
388                                Ok(QueryState::Loaded(QueryData::now(result)))
389                            } else {
390                                Err(state)
391                            }
392                        }
393                        QueryState::Invalid(ref data) => {
394                            if let Some(result) = updater(Some(&data.data)) {
395                                Ok(QueryState::Loaded(QueryData::now(result)))
396                            } else {
397                                Err(state)
398                            }
399                        }
400                    });
401                    None
402                }
403                None => {
404                    if let Some(result) = updater(None) {
405                        let query = with_owner(owner, || Query::new(key));
406                        query.set_state(QueryState::Loaded(QueryData::now(result)));
407                        Some(query)
408                    } else {
409                        None
410                    }
411                }
412            });
413    }
414
415    pub fn set_query_data<K, V>(&self, key: K, data: V)
418    where
419        K: QueryKey + 'static,
420        V: QueryValue + 'static,
421    {
422        self.update_query_data(key, |_| Some(data));
423    }
424
425    pub fn update_query_data_mut<K, V>(
428        &self,
429        key: impl Borrow<K>,
430        updater: impl FnOnce(&mut V),
431    ) -> bool
432    where
433        K: QueryKey + 'static,
434        V: QueryValue + 'static,
435    {
436        self.cache.use_cache::<K, V, bool>(move |cache| {
437            let mut updated = false;
438            if let Some(query) = cache.get(key.borrow()) {
439                query.update_state(|state| {
440                    if let Some(data) = state.data_mut() {
441                        updater(data);
442                        updated = true;
443                    }
444                });
445            }
446            updated
447        })
448    }
449
450    pub fn cancel_query<K, V>(&self, key: K) -> bool
453    where
454        K: QueryKey + 'static,
455        V: QueryValue + 'static,
456    {
457        self.cache.use_cache::<K, V, bool>(move |cache| {
458            if let Some(query) = cache.get(&key) {
459                query.cancel()
460            } else {
461                false
462            }
463        })
464    }
465
466    pub fn register_cache_observer(&self, observer: impl CacheObserver + 'static) {
468        let key = self.cache.register_observer(observer);
469        let cache = self.cache.clone();
470
471        on_cleanup(move || {
472            cache.unregister_observer(key);
473        })
474    }
475
476    pub fn add_persister(&self, persister: impl QueryPersister + Clone + 'static) {
478        self.register_cache_observer(persister.clone());
479        self.cache.add_persister(persister);
480    }
481
482    pub fn remove_persister(&self) -> bool {
484        self.cache.remove_persister().is_some()
485    }
486
487    pub fn clear(&self) {
489        self.cache.clear_all_queries()
490    }
491}
492
#[cfg(all(test, not(any(feature = "csr", feature = "hydrate"))))]
mod tests {
    use super::*;

    /// Asserts that a query-state signal currently holds an `Invalid` state.
    macro_rules! assert_invalid {
        ($signal:expr) => {
            assert!(matches!(
                $signal.get_untracked(),
                Some(QueryState::Invalid { .. })
            ));
        };
    }

    /// Creates a reactive runtime, provides a default client in it, and returns that client.
    fn setup_client() -> QueryClient {
        let _ = create_runtime();
        provide_query_client();
        use_query_client()
    }

    /// Reads the data currently cached for `key` in the `u32 -> String` query space.
    fn string_data(key: u32) -> Option<String> {
        use_query_client()
            .cache
            .get_query::<u32, String>(&key)
            .map(|query| query.get_state())
            .and_then(|state| state.data().cloned())
    }

    #[test]
    fn update_query_data() {
        let client = setup_client();

        let current_state = || {
            use_query_client()
                .cache
                .get_query::<u32, String>(&0)
                .map(|query| query.get_state())
        };

        // Nothing is cached before the first update.
        assert_eq!(None, current_state());
        assert_eq!(0, client.size().get_untracked());

        // An updater that declines must not create a query.
        client.update_query_data::<u32, String>(0, |_| None);

        assert_eq!(None, current_state());
        assert_eq!(0, client.size().get_untracked());

        // An updater that produces data creates a loaded query.
        client.update_query_data::<u32, String>(0, |_| Some("0".to_string()));

        assert_eq!(1, client.size().get_untracked());
        assert_eq!(
            Some("0".to_string()),
            current_state().and_then(|state| state.data().cloned())
        );
        assert!(matches!(current_state(), Some(QueryState::Loaded { .. })));

        // A second update replaces the data of the existing query.
        client.update_query_data::<u32, String>(0, |_| Some("1".to_string()));

        assert_eq!(
            Some("1".to_string()),
            current_state().and_then(|state| state.data().cloned())
        );
    }

    #[test]
    fn set_query_data_new_query() {
        let client = setup_client();

        assert_eq!(None, string_data(0));

        client.set_query_data::<u32, String>(0, "New Data".to_string());

        assert_eq!(Some("New Data".to_string()), string_data(0));
    }

    #[test]
    fn set_query_data_existing_query() {
        let client = setup_client();

        client.set_query_data::<u32, String>(1, "Initial Data".to_string());
        assert_eq!(Some("Initial Data".to_string()), string_data(1));

        client.set_query_data::<u32, String>(1, "Updated Data".to_string());
        assert_eq!(Some("Updated Data".to_string()), string_data(1));
    }

    #[test]
    fn can_use_same_key_with_different_value_types() {
        let client = setup_client();

        client.update_query_data::<u32, String>(0, |_| Some("0".to_string()));
        client.update_query_data::<u32, u32>(0, |_| Some(1234));

        // Same key, different value type: two distinct queries.
        assert_eq!(2, client.size().get_untracked());
    }

    #[test]
    fn can_invalidate_while_subscribed() {
        let client = setup_client();

        let subscription = client.get_query_state::<u32, u32>(|| 0_u32);

        // Keep an active subscriber on the state signal.
        create_isomorphic_effect(move |_| {
            subscription.track();
        });

        client.update_query_data::<u32, u32>(0_u32, |_| Some(1234));

        assert!(client.invalidate_query::<u32, u32>(0));

        let observed = subscription.get_untracked();
        assert!(
            matches!(observed, Some(QueryState::Invalid { .. })),
            "Query should be invalid"
        );
    }

    #[test]
    fn can_invalidate_multiple() {
        let client = setup_client();

        for key in [0_u32, 1] {
            client.update_query_data::<u32, u32>(key, |_| Some(1234));
        }

        let keys: Vec<u32> = vec![0, 1];
        let invalidated = client
            .invalidate_queries::<u32, u32, _>(keys.clone())
            .unwrap_or_default();

        assert_eq!(keys, invalidated)
    }

    #[test]
    fn can_invalidate_multiple_strings() {
        let client = setup_client();

        let first = "0".to_string();
        let second = "1".to_string();

        client.update_query_data::<String, String>(first.clone(), |_| Some("1234".into()));
        client.update_query_data::<String, String>(second.clone(), |_| Some("5678".into()));

        let keys = vec![first, second];
        let invalidated = client
            .invalidate_queries::<String, String, _>(keys.clone())
            .unwrap_or_default();

        assert_eq!(keys, invalidated)
    }

    #[test]
    fn invalidate_all() {
        let client = setup_client();

        let first = "0".to_string();
        let second = "1".to_string();

        client.update_query_data::<String, String>(first.clone(), |_| Some("1234".into()));
        client.update_query_data::<String, String>(second.clone(), |_| Some("5678".into()));
        client.update_query_data::<u32, u32>(0, |_| Some(1234));
        client.update_query_data::<u32, u32>(1, |_| Some(5678));

        let first_string_state =
            client.get_query_state::<String, String>(move || first.clone());
        let second_string_state =
            client.get_query_state::<String, String>(move || second.clone());

        let first_state = client.get_query_state::<u32, u32>(|| 0);
        let second_state = client.get_query_state::<u32, u32>(|| 1);

        client.invalidate_all_queries();

        // Every query type is affected.
        assert_invalid!(first_state);
        assert_invalid!(second_state);
        assert_invalid!(first_string_state);
        assert_invalid!(second_string_state);
    }

    #[test]
    fn can_invalidate_subset() {
        let client = setup_client();

        client.update_query_data::<u32, u32>(0, |_| Some(1234));
        client.update_query_data::<u32, u32>(1, |_| Some(1234));

        let first_state = client.get_query_state::<u32, u32>(|| 0);
        let second_state = client.get_query_state::<u32, u32>(|| 1);

        client.invalidate_query_type::<u32, u32>();

        assert_invalid!(first_state);
        assert_invalid!(second_state);
    }

    #[test]
    fn update_query_data_mut() {
        let client = setup_client();

        let data_for = |key: u32| {
            use_query_client()
                .cache
                .get_query::<u32, u32>(&key)
                .map(|query| query.get_state())
                .and_then(|state| state.data().cloned())
        };

        let initial_value = 100_u32;
        client.update_query_data::<u32, u32>(0, move |_| Some(initial_value));

        assert_eq!(data_for(0), Some(100));

        // Mutating an existing query reports success and changes its data.
        let update_result = client.update_query_data_mut::<u32, u32>(0, |data| *data += 50);

        assert!(update_result, "Expected data to be updated");
        assert_eq!(
            data_for(0),
            Some(initial_value + 50),
            "Data was not updated correctly"
        );

        // Mutating a missing query reports failure and creates nothing.
        let non_existent_update_result =
            client.update_query_data_mut::<u32, u32>(1, |data| *data += 50);

        assert!(
            !non_existent_update_result,
            "Expected no data to be updated for a non-existent query"
        );

        assert_eq!(data_for(1), None, "Data was updated for a non-existent query")
    }
}