// pnwkit_core/kit.rs

1use crate::{
2    data::QueryReturn,
3    query::{Query, QueryType},
4    request::{ContentType, Method, Request, Response},
5    resolve::Resolve,
6    variable::Variables,
7    Config, Data, Field, Paginator,
8};
9#[cfg(feature = "subscriptions")]
10use crate::{
11    data::SubscriptionAuthData,
12    subscription::{Subscription, SubscriptionEvent, SubscriptionModel},
13    to_query_string::ToQueryString,
14    Object, Value,
15};
16use serde_json::json;
17use std::{sync::Arc, time::Duration};
18
19type GetResult = Result<Data, String>;
20
21#[cfg(feature = "subscriptions")]
22type SubscriptionResult = Result<Arc<Subscription>, String>;
23
24#[derive(Clone, Debug)]
25pub struct Kit {
26    pub config: Arc<Config>,
27}
28
29impl Kit {
30    pub fn new(config: Config) -> Self {
31        Self {
32            config: Arc::new(config),
33        }
34    }
35
36    #[cfg(feature = "async")]
37    pub async fn get(&self, query: &Query) -> GetResult {
38        self.inner_get(query, None).await
39    }
40
41    #[cfg(feature = "sync")]
42    pub fn get_sync(&self, query: &Query) -> GetResult {
43        self.inner_get_sync(query, None)
44    }
45
46    #[cfg(feature = "async")]
47    pub async fn get_with_variables(&self, query: &Query, variables: &Variables) -> GetResult {
48        self.inner_get(query, Some(variables)).await
49    }
50
51    #[cfg(feature = "sync")]
52    pub fn get_with_variables_sync(&self, query: &Query, variables: &Variables) -> GetResult {
53        self.inner_get_sync(query, Some(variables))
54    }
55
56    fn parse_response(&self, response: Response) -> GetResult {
57        let result = serde_json::from_str::<QueryReturn>(&response.body);
58        match result {
59            Ok(json) => {
60                if json.errors.is_some() {
61                    let errors = json.errors.unwrap();
62                    let mut error_messages = Vec::with_capacity(errors.len());
63                    for error in errors {
64                        error_messages.push(error.message);
65                    }
66                    return Err(error_messages.join(", "));
67                }
68                match json.data {
69                    Some(d) => Ok(d),
70                    None => Err("No data".to_string()),
71                }
72            },
73            Err(err) => Err(err.to_string()),
74        }
75    }
76
77    // this should work fine without the drop, but it's here just in case
78    fn hit(&self) -> u64 {
79        let mut rate_limiter = self.config.rate_limiter.lock().unwrap();
80        let wait = rate_limiter.hit();
81        drop(rate_limiter);
82        wait
83    }
84
85    fn handle_429(&self, x_ratelimit_reset: Option<u64>) -> u64 {
86        let mut rate_limiter = self.config.rate_limiter.lock().unwrap();
87        let wait = rate_limiter.handle_429(x_ratelimit_reset);
88        drop(rate_limiter);
89        wait
90    }
91
92    #[cfg(feature = "async")]
93    async fn inner_get(&self, query: &Query, variables: Option<&Variables>) -> GetResult {
94        let request = self.build_request(query, variables);
95        if let Err(msg) = &request {
96            return Err(msg.clone());
97        }
98        let request = request.unwrap();
99        let mut err_msg = "Something went very wrong".to_string();
100        for _ in 1..5 {
101            loop {
102                let wait = self.hit();
103                if wait > 0 {
104                    (self.config.sleep)(Duration::from_secs(wait)).await;
105                } else {
106                    break;
107                }
108            }
109            let response = self.config.client.request(&request).await;
110            if let Err(err) = response {
111                err_msg = err.to_string();
112                continue;
113            }
114            let response = response.unwrap();
115            if response.status == 429 {
116                let wait = self.handle_429(response.x_ratelimit_reset);
117                (self.config.sleep)(Duration::from_secs(wait)).await;
118            }
119            return self.parse_response(response);
120        }
121        Err(format!("Max retries exceeded, returned error: {}", err_msg))
122    }
123
124    #[cfg(feature = "sync")]
125    fn inner_get_sync(&self, query: &Query, variables: Option<&Variables>) -> GetResult {
126        let request = self.build_request(query, variables);
127        if let Err(msg) = &request {
128            return Err(msg.clone());
129        }
130        let request = request.unwrap();
131        let mut err_msg = "Something went very wrong".to_string();
132        for _ in 1..5 {
133            loop {
134                let wait = self.hit();
135                if wait > 0 {
136                    (self.config.sleep_sync)(Duration::from_secs(wait));
137                } else {
138                    break;
139                }
140            }
141            let response = self.config.client.request_sync(&request);
142            if let Err(err) = response {
143                err_msg = err.to_string();
144                continue;
145            }
146            let response = response.unwrap();
147            if response.status == 429 {
148                let wait = self.handle_429(response.x_ratelimit_reset);
149                (self.config.sleep_sync)(Duration::from_secs(wait));
150            }
151            return self.parse_response(response);
152        }
153        Err(format!("Max retries exceeded, returned error: {}", err_msg))
154    }
155
156    pub fn build_request(
157        &self,
158        query: &Query,
159        variables: Option<&Variables>,
160    ) -> Result<Request, String> {
161        if let Err(msg) = query.valid() {
162            return Err(format!("Invalid query: {}", msg));
163        }
164        if let Some(v) = variables {
165            if let Err(msg) = v.valid(
166                query
167                    .get_variables()
168                    .iter()
169                    .map(|v| v.name.clone())
170                    .collect(),
171            ) {
172                return Err(format!("Invalid variables: {}", msg));
173            }
174        }
175        let body = match variables {
176            Some(vars) => {
177                vars.page_init();
178                json!({
179                    "query": query.resolve(),
180                    "variables": vars,
181                })
182            },
183            None => {
184                let vars = Variables::with_capacity(1);
185                vars.page_init();
186                json!({
187                    "query": query.resolve(),
188                    "variables": vars,
189                })
190            },
191        }
192        .to_string();
193        let method = Method::Post;
194        Ok(Request::new(
195            method,
196            self.config.api_url.clone(),
197            Some(body),
198            Some(self.config.headers.clone()),
199            Some(ContentType::Json),
200        ))
201    }
202
203    pub fn query(&self) -> Query {
204        Query::new(QueryType::Query)
205    }
206
207    pub fn mutation(&self) -> Query {
208        Query::new(QueryType::Mutation)
209    }
210
211    pub fn paginator(&self, field: Field) -> Paginator {
212        let query = Query::new(QueryType::Query).field(field);
213        Paginator::new(query)
214    }
215
216    pub fn paginator_with_capacity(&self, field: Field, capacity: u16) -> Paginator {
217        let query = Query::new(QueryType::Query).field(field);
218        Paginator::with_capacity(query, capacity)
219    }
220
221    pub fn paginator_with_variables(&self, field: Field, variables: Variables) -> Paginator {
222        let query = Query::new(QueryType::Query).field(field);
223        Paginator::with_variables(query, variables)
224    }
225
226    pub fn paginator_with_capacity_and_variables(
227        &self,
228        field: Field,
229        variables: Variables,
230        capacity: u16,
231    ) -> Paginator {
232        let query = Query::new(QueryType::Query).field(field);
233        Paginator::with_capacity_and_variables(query, capacity, variables)
234    }
235
236    #[cfg(feature = "subscriptions")]
237    pub async fn subscribe(
238        &self,
239        model: SubscriptionModel,
240        event: SubscriptionEvent,
241    ) -> SubscriptionResult {
242        self.subscribe_inner(model, event, Object::new()).await
243    }
244
245    #[cfg(feature = "subscriptions")]
246    pub async fn subscribe_with_filters(
247        &self,
248        model: SubscriptionModel,
249        event: SubscriptionEvent,
250        filters: Object,
251    ) -> SubscriptionResult {
252        self.subscribe_inner(model, event, filters).await
253    }
254
255    #[cfg(feature = "subscriptions")]
256    async fn subscribe_inner(
257        &self,
258        model: SubscriptionModel,
259        event: SubscriptionEvent,
260        filters: Object,
261    ) -> SubscriptionResult {
262        self.config.socket.init(self.clone()).await;
263        let channel = self
264            .request_subscription_channel(&model, &event, &filters)
265            .await?;
266
267        let subscription = Subscription::new(model, event, filters, channel);
268
269        self.subscribe_request(Arc::new(subscription)).await
270    }
271
272    #[cfg(feature = "subscriptions")]
273    pub async fn subscribe_request(&self, subscription: Arc<Subscription>) -> SubscriptionResult {
274        if !self.config.socket.get_connected().is_set().await {
275            self.config.socket.connect_ref().await?;
276            self.config.socket.start_ping_pong_task();
277        }
278
279        let mut channel = { subscription.channel.lock().await.clone() };
280        let auth = self.authorize_subscription(&channel).await;
281        if let Err(e) = &auth {
282            if e == "unauthorized" {
283                channel = self
284                    .request_subscription_channel(
285                        &subscription.model,
286                        &subscription.event,
287                        &subscription.filters,
288                    )
289                    .await?;
290                subscription.set_channel(channel.clone()).await;
291                let auth = self.authorize_subscription(&channel).await;
292                if let Err(e) = &auth {
293                    return Err(e.clone());
294                }
295            }
296        }
297        let auth = auth.unwrap();
298
299        self.config
300            .socket
301            .add_subscription(subscription.clone())
302            .await;
303
304        self.config
305            .socket
306            .send(
307                json!({
308                    "event": "pusher:subscribe",
309                    "data": {
310                        "channel": channel,
311                        "auth": auth.clone(),
312                    }
313                })
314                .to_string(),
315            )
316            .await?;
317
318        let timeout =
319            tokio::time::timeout(Duration::from_secs(60), subscription.succeeded.wait()).await;
320        if timeout.is_err() {
321            self.config
322                .socket
323                .remove_subscription(subscription.clone())
324                .await;
325            return Err("timed out waiting for subscription to succeed".to_string());
326        }
327
328        Ok(subscription.clone())
329    }
330
331    #[cfg(feature = "subscriptions")]
332    async fn request_subscription_channel(
333        &self,
334        model: &SubscriptionModel,
335        event: &SubscriptionEvent,
336        filters: &Object,
337    ) -> Result<String, String> {
338        let url = self
339            .config
340            .subscribe_url
341            .replace("{model}", &model.to_string())
342            .replace("{event}", &event.to_string());
343        let url = if !filters.is_empty() {
344            format!(
345                "{}?{}",
346                url,
347                serde_urlencoded::to_string(filters.to_query_string()).unwrap()
348            )
349        } else {
350            url
351        };
352        let request = Request::new(
353            Method::Get,
354            url,
355            None,
356            Some(self.config.headers.clone()),
357            Some(ContentType::Json),
358        );
359        let response = self.config.client.request(&request).await?;
360        let json = serde_json::from_str::<Value>(&response.body)
361            .unwrap()
362            .as_object()
363            .unwrap();
364        if let Some(err) = json.get("error") {
365            return Err(err.value().as_string().unwrap());
366        }
367        if let Some(channel) = json.get("channel") {
368            return Ok(channel.value().as_string().unwrap());
369        }
370        Err("malformed response".to_string())
371    }
372
373    #[cfg(feature = "subscriptions")]
374    async fn authorize_subscription(&self, channel: &String) -> Result<String, String> {
375        self.config.socket.get_established().wait().await;
376        let request = Request::new(
377            Method::Post,
378            self.config.subscription_auth_url.clone(),
379            Some(
380                serde_urlencoded::to_string([
381                    ("socket_id", &self.config.socket.get_socket_id().await),
382                    ("channel_name", channel),
383                ])
384                .unwrap(),
385            ),
386            None,
387            Some(ContentType::Form),
388        );
389        let response = self.config.client.request(&request).await?;
390        if response.status != 200 {
391            return Err("unauthorized".into());
392        }
393        let data = serde_json::from_str::<SubscriptionAuthData>(&response.body).unwrap();
394        Ok(data.auth)
395    }
396}