// artemis/default_exchanges/dedup.rs

1use crate::{
2    exchange::Client,
3    types::{ExchangeResult, Operation, OperationResult},
4    Exchange, ExchangeFactory, GraphQLQuery, OperationType, QueryError
5};
6use futures::channel::{oneshot, oneshot::Sender};
7use std::{
8    any::Any,
9    collections::HashMap,
10    error::Error,
11    fmt,
12    sync::{Arc, Mutex}
13};
14
15type InFlightCache = Arc<Mutex<HashMap<u64, Vec<Sender<Result<Box<dyn Any + Send>, QueryError>>>>>>;
16
17/// The default deduplication exchange.
18///
19/// This will keep track of in-flight queries and catch any identical queries before they execute,
20/// instead waiting for the result from the in-flight query
21pub struct DedupExchange;
22pub struct DedupExchangeImpl<TNext: Exchange> {
23    next: TNext,
24    in_flight_operations: InFlightCache
25}
26
27impl<TNext: Exchange> ExchangeFactory<TNext> for DedupExchange {
28    type Output = DedupExchangeImpl<TNext>;
29
30    fn build(self, next: TNext) -> Self::Output {
31        DedupExchangeImpl {
32            next,
33            in_flight_operations: InFlightCache::default()
34        }
35    }
36}
37
/// Error type belonging to the deduplication exchange.
///
/// NOTE(review): not constructed anywhere in this file; confirm with its users exactly which
/// failure it reports.
#[derive(Debug, Clone)]
pub struct DedupError;

impl Error for DedupError {}

impl fmt::Display for DedupError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A non-empty message so the error is actually visible when logged or displayed.
        f.write_str("operation deduplication failed")
    }
}
46
47fn should_skip<Q: GraphQLQuery>(operation: &Operation<Q::Variables>) -> bool {
48    let op_type = &operation.meta.operation_type;
49    op_type != &OperationType::Query && op_type != &OperationType::Mutation
50}
51
52fn make_deduped_result<Q: GraphQLQuery>(
53    res: &ExchangeResult<Q::ResponseData>
54) -> Result<Box<dyn Any + Send>, QueryError> {
55    match res {
56        Ok(ref res) => {
57            let mut res = res.clone();
58            if let Some(ref mut debug_info) = res.response.debug_info {
59                debug_info.did_dedup = true;
60            }
61            Ok(Box::new(res))
62        }
63        Err(e) => Err(e.clone())
64    }
65}
66
67impl<TNext: Exchange> DedupExchangeImpl<TNext> {
68    fn notify_listeners<Q: GraphQLQuery>(&self, key: u64, res: &ExchangeResult<Q::ResponseData>) {
69        let mut cache = self.in_flight_operations.lock().unwrap();
70        let to_be_notified = cache.remove(&key).unwrap();
71        for sender in to_be_notified {
72            let res = make_deduped_result::<Q>(res);
73            sender.send(res).unwrap();
74        }
75    }
76}
77
78#[async_trait]
79impl<TNext: Exchange> Exchange for DedupExchangeImpl<TNext> {
80    async fn run<Q: GraphQLQuery, C: Client>(
81        &self,
82        operation: Operation<Q::Variables>,
83        _client: C
84    ) -> ExchangeResult<Q::ResponseData> {
85        if should_skip::<Q>(&operation) {
86            return self.next.run::<Q, _>(operation, _client).await;
87        }
88
89        let key = operation.key;
90        let rcv = {
91            let mut cache = self.in_flight_operations.lock().unwrap();
92            if let Some(listeners) = cache.get_mut(&key) {
93                let (sender, receiver) = oneshot::channel();
94                listeners.push(sender);
95                Some(receiver)
96            } else {
97                cache.insert(key, Vec::new());
98                None
99            }
100        };
101
102        if let Some(rcv) = rcv {
103            let res: Box<dyn Any> = rcv.await.unwrap()?;
104            let res: OperationResult<Q::ResponseData> = *res.downcast().unwrap();
105            Ok(res)
106        } else {
107            let res = self.next.run::<Q, _>(operation, _client).await;
108            self.notify_listeners::<Q>(key, &res);
109            res
110        }
111    }
112}
113
#[cfg(all(test, not(target_arch = "wasm32")))]
mod test {
    use super::DedupExchangeImpl;
    use crate::{
        default_exchanges::DedupExchange,
        exchange::Client,
        types::{Operation, OperationOptions, OperationResult},
        ClientBuilder, DebugInfo, Exchange, ExchangeFactory, ExchangeResult, FieldSelector,
        GraphQLQuery, OperationMeta, OperationType, QueryBody, QueryInfo, RequestPolicy, Response,
        ResultSource
    };
    use artemis_test::get_conference::{
        get_conference::{ResponseData, Variables, OPERATION_NAME, QUERY},
        GetConference
    };
    use lazy_static::lazy_static;
    use std::time::Duration;
    use tokio::time::delay_for;

    lazy_static! {
        static ref VARIABLES: Variables = Variables {
            id: "1".to_string()
        };
        // A `'static` exchange, so that futures borrowing it can be passed to `tokio::spawn`.
        static ref EXCHANGE: DedupExchangeImpl<FakeFetchExchange> =
            DedupExchange.build(FakeFetchExchange);
    }

    /// URL placed in the options of every test operation.
    fn url() -> String {
        "http://localhost:8080/graphql".to_string()
    }

    /// Terminal exchange standing in for the network.
    ///
    /// After a 10 ms delay (long enough for concurrent calls to overlap), it answers every
    /// operation with an empty response. The response's debug info says it came from the
    /// network and was not deduplicated.
    struct FakeFetchExchange;

    impl<TNext: Exchange> ExchangeFactory<TNext> for FakeFetchExchange {
        type Output = FakeFetchExchange;

        fn build(self, _next: TNext) -> FakeFetchExchange {
            Self
        }
    }

    #[async_trait]
    impl Exchange for FakeFetchExchange {
        async fn run<Q: GraphQLQuery, C: Client>(
            &self,
            operation: Operation<Q::Variables>,
            _client: C
        ) -> ExchangeResult<Q::ResponseData> {
            delay_for(Duration::from_millis(10)).await;
            let res = OperationResult {
                key: operation.key,
                meta: operation.meta,
                response: Response {
                    debug_info: Some(DebugInfo {
                        source: ResultSource::Network,
                        did_dedup: false
                    }),
                    data: None,
                    errors: None
                }
            };
            Ok(res)
        }
    }

    /// Wraps a query body and its metadata in a network-only operation keyed by the query key.
    fn make_operation(query: QueryBody<Variables>, meta: OperationMeta) -> Operation<Variables> {
        Operation {
            key: meta.query_key as u64,
            meta,
            query,
            options: OperationOptions {
                request_policy: RequestPolicy::NetworkOnly,
                extra_headers: None,
                url: url(),
                extensions: None
            }
        }
    }

    /// Builds the `GetConference` query body and its (fixed) query metadata.
    fn build_query(variables: Variables) -> (QueryBody<Variables>, OperationMeta) {
        let meta = OperationMeta {
            query_key: 13543040u32,
            operation_type: OperationType::Query,
            involved_types: vec!["Conference", "Person", "Talk"]
        };
        let body = QueryBody {
            variables,
            query: QUERY,
            operation_name: OPERATION_NAME
        };
        (body, meta)
    }

    // Only the associated types are used by these tests; `build_query` is never called.
    impl GraphQLQuery for GetConference {
        type Variables = Variables;
        type ResponseData = ResponseData;

        fn build_query(_variables: Self::Variables) -> (QueryBody<Self::Variables>, OperationMeta) {
            unimplemented!()
        }
    }

    impl QueryInfo<Variables> for ResponseData {
        fn selection(_variables: &Variables) -> Vec<FieldSelector> {
            unimplemented!()
        }
    }

    #[tokio::test]
    async fn test_dedup() {
        let (query, meta) = build_query(VARIABLES.clone());

        let client = ClientBuilder::new("http://localhost:4000/graphql").build();

        let fut1 = EXCHANGE.run::<GetConference, _>(
            make_operation(query.clone(), meta.clone()),
            client.0.clone()
        );
        let fut2 = EXCHANGE.run::<GetConference, _>(
            make_operation(query.clone(), meta.clone()),
            client.0.clone()
        );
        let join = tokio::spawn(async { fut1.await.unwrap() });
        let res2 = fut2.await.unwrap();
        let res1 = join.await.unwrap();

        // Whichever future registers first executes the query and the other is served the
        // shared result. Scheduling decides which is which, so require exactly one dedup.
        let did_1_dedup = res1.response.debug_info.unwrap().did_dedup;
        let did_2_dedup = res2.response.debug_info.unwrap().did_dedup;
        assert!(
            did_1_dedup ^ did_2_dedup,
            "expected exactly one of two identical in-flight queries to be deduplicated \
             (first: {}, second: {})",
            did_1_dedup,
            did_2_dedup
        );
    }
}