1use crate::{
2 data::QueryReturn,
3 query::{Query, QueryType},
4 request::{ContentType, Method, Request, Response},
5 resolve::Resolve,
6 variable::Variables,
7 Config, Data, Field, Paginator,
8};
9#[cfg(feature = "subscriptions")]
10use crate::{
11 data::SubscriptionAuthData,
12 subscription::{Subscription, SubscriptionEvent, SubscriptionModel},
13 to_query_string::ToQueryString,
14 Object, Value,
15};
16use serde_json::json;
17use std::{sync::Arc, time::Duration};
18
19type GetResult = Result<Data, String>;
20
21#[cfg(feature = "subscriptions")]
22type SubscriptionResult = Result<Arc<Subscription>, String>;
23
24#[derive(Clone, Debug)]
25pub struct Kit {
26 pub config: Arc<Config>,
27}
28
29impl Kit {
30 pub fn new(config: Config) -> Self {
31 Self {
32 config: Arc::new(config),
33 }
34 }
35
36 #[cfg(feature = "async")]
37 pub async fn get(&self, query: &Query) -> GetResult {
38 self.inner_get(query, None).await
39 }
40
41 #[cfg(feature = "sync")]
42 pub fn get_sync(&self, query: &Query) -> GetResult {
43 self.inner_get_sync(query, None)
44 }
45
46 #[cfg(feature = "async")]
47 pub async fn get_with_variables(&self, query: &Query, variables: &Variables) -> GetResult {
48 self.inner_get(query, Some(variables)).await
49 }
50
51 #[cfg(feature = "sync")]
52 pub fn get_with_variables_sync(&self, query: &Query, variables: &Variables) -> GetResult {
53 self.inner_get_sync(query, Some(variables))
54 }
55
56 fn parse_response(&self, response: Response) -> GetResult {
57 let result = serde_json::from_str::<QueryReturn>(&response.body);
58 match result {
59 Ok(json) => {
60 if json.errors.is_some() {
61 let errors = json.errors.unwrap();
62 let mut error_messages = Vec::with_capacity(errors.len());
63 for error in errors {
64 error_messages.push(error.message);
65 }
66 return Err(error_messages.join(", "));
67 }
68 match json.data {
69 Some(d) => Ok(d),
70 None => Err("No data".to_string()),
71 }
72 },
73 Err(err) => Err(err.to_string()),
74 }
75 }
76
77 fn hit(&self) -> u64 {
79 let mut rate_limiter = self.config.rate_limiter.lock().unwrap();
80 let wait = rate_limiter.hit();
81 drop(rate_limiter);
82 wait
83 }
84
85 fn handle_429(&self, x_ratelimit_reset: Option<u64>) -> u64 {
86 let mut rate_limiter = self.config.rate_limiter.lock().unwrap();
87 let wait = rate_limiter.handle_429(x_ratelimit_reset);
88 drop(rate_limiter);
89 wait
90 }
91
92 #[cfg(feature = "async")]
93 async fn inner_get(&self, query: &Query, variables: Option<&Variables>) -> GetResult {
94 let request = self.build_request(query, variables);
95 if let Err(msg) = &request {
96 return Err(msg.clone());
97 }
98 let request = request.unwrap();
99 let mut err_msg = "Something went very wrong".to_string();
100 for _ in 1..5 {
101 loop {
102 let wait = self.hit();
103 if wait > 0 {
104 (self.config.sleep)(Duration::from_secs(wait)).await;
105 } else {
106 break;
107 }
108 }
109 let response = self.config.client.request(&request).await;
110 if let Err(err) = response {
111 err_msg = err.to_string();
112 continue;
113 }
114 let response = response.unwrap();
115 if response.status == 429 {
116 let wait = self.handle_429(response.x_ratelimit_reset);
117 (self.config.sleep)(Duration::from_secs(wait)).await;
118 }
119 return self.parse_response(response);
120 }
121 Err(format!("Max retries exceeded, returned error: {}", err_msg))
122 }
123
124 #[cfg(feature = "sync")]
125 fn inner_get_sync(&self, query: &Query, variables: Option<&Variables>) -> GetResult {
126 let request = self.build_request(query, variables);
127 if let Err(msg) = &request {
128 return Err(msg.clone());
129 }
130 let request = request.unwrap();
131 let mut err_msg = "Something went very wrong".to_string();
132 for _ in 1..5 {
133 loop {
134 let wait = self.hit();
135 if wait > 0 {
136 (self.config.sleep_sync)(Duration::from_secs(wait));
137 } else {
138 break;
139 }
140 }
141 let response = self.config.client.request_sync(&request);
142 if let Err(err) = response {
143 err_msg = err.to_string();
144 continue;
145 }
146 let response = response.unwrap();
147 if response.status == 429 {
148 let wait = self.handle_429(response.x_ratelimit_reset);
149 (self.config.sleep_sync)(Duration::from_secs(wait));
150 }
151 return self.parse_response(response);
152 }
153 Err(format!("Max retries exceeded, returned error: {}", err_msg))
154 }
155
156 pub fn build_request(
157 &self,
158 query: &Query,
159 variables: Option<&Variables>,
160 ) -> Result<Request, String> {
161 if let Err(msg) = query.valid() {
162 return Err(format!("Invalid query: {}", msg));
163 }
164 if let Some(v) = variables {
165 if let Err(msg) = v.valid(
166 query
167 .get_variables()
168 .iter()
169 .map(|v| v.name.clone())
170 .collect(),
171 ) {
172 return Err(format!("Invalid variables: {}", msg));
173 }
174 }
175 let body = match variables {
176 Some(vars) => {
177 vars.page_init();
178 json!({
179 "query": query.resolve(),
180 "variables": vars,
181 })
182 },
183 None => {
184 let vars = Variables::with_capacity(1);
185 vars.page_init();
186 json!({
187 "query": query.resolve(),
188 "variables": vars,
189 })
190 },
191 }
192 .to_string();
193 let method = Method::Post;
194 Ok(Request::new(
195 method,
196 self.config.api_url.clone(),
197 Some(body),
198 Some(self.config.headers.clone()),
199 Some(ContentType::Json),
200 ))
201 }
202
203 pub fn query(&self) -> Query {
204 Query::new(QueryType::Query)
205 }
206
207 pub fn mutation(&self) -> Query {
208 Query::new(QueryType::Mutation)
209 }
210
211 pub fn paginator(&self, field: Field) -> Paginator {
212 let query = Query::new(QueryType::Query).field(field);
213 Paginator::new(query)
214 }
215
216 pub fn paginator_with_capacity(&self, field: Field, capacity: u16) -> Paginator {
217 let query = Query::new(QueryType::Query).field(field);
218 Paginator::with_capacity(query, capacity)
219 }
220
221 pub fn paginator_with_variables(&self, field: Field, variables: Variables) -> Paginator {
222 let query = Query::new(QueryType::Query).field(field);
223 Paginator::with_variables(query, variables)
224 }
225
226 pub fn paginator_with_capacity_and_variables(
227 &self,
228 field: Field,
229 variables: Variables,
230 capacity: u16,
231 ) -> Paginator {
232 let query = Query::new(QueryType::Query).field(field);
233 Paginator::with_capacity_and_variables(query, capacity, variables)
234 }
235
236 #[cfg(feature = "subscriptions")]
237 pub async fn subscribe(
238 &self,
239 model: SubscriptionModel,
240 event: SubscriptionEvent,
241 ) -> SubscriptionResult {
242 self.subscribe_inner(model, event, Object::new()).await
243 }
244
245 #[cfg(feature = "subscriptions")]
246 pub async fn subscribe_with_filters(
247 &self,
248 model: SubscriptionModel,
249 event: SubscriptionEvent,
250 filters: Object,
251 ) -> SubscriptionResult {
252 self.subscribe_inner(model, event, filters).await
253 }
254
255 #[cfg(feature = "subscriptions")]
256 async fn subscribe_inner(
257 &self,
258 model: SubscriptionModel,
259 event: SubscriptionEvent,
260 filters: Object,
261 ) -> SubscriptionResult {
262 self.config.socket.init(self.clone()).await;
263 let channel = self
264 .request_subscription_channel(&model, &event, &filters)
265 .await?;
266
267 let subscription = Subscription::new(model, event, filters, channel);
268
269 self.subscribe_request(Arc::new(subscription)).await
270 }
271
272 #[cfg(feature = "subscriptions")]
273 pub async fn subscribe_request(&self, subscription: Arc<Subscription>) -> SubscriptionResult {
274 if !self.config.socket.get_connected().is_set().await {
275 self.config.socket.connect_ref().await?;
276 self.config.socket.start_ping_pong_task();
277 }
278
279 let mut channel = { subscription.channel.lock().await.clone() };
280 let auth = self.authorize_subscription(&channel).await;
281 if let Err(e) = &auth {
282 if e == "unauthorized" {
283 channel = self
284 .request_subscription_channel(
285 &subscription.model,
286 &subscription.event,
287 &subscription.filters,
288 )
289 .await?;
290 subscription.set_channel(channel.clone()).await;
291 let auth = self.authorize_subscription(&channel).await;
292 if let Err(e) = &auth {
293 return Err(e.clone());
294 }
295 }
296 }
297 let auth = auth.unwrap();
298
299 self.config
300 .socket
301 .add_subscription(subscription.clone())
302 .await;
303
304 self.config
305 .socket
306 .send(
307 json!({
308 "event": "pusher:subscribe",
309 "data": {
310 "channel": channel,
311 "auth": auth.clone(),
312 }
313 })
314 .to_string(),
315 )
316 .await?;
317
318 let timeout =
319 tokio::time::timeout(Duration::from_secs(60), subscription.succeeded.wait()).await;
320 if timeout.is_err() {
321 self.config
322 .socket
323 .remove_subscription(subscription.clone())
324 .await;
325 return Err("timed out waiting for subscription to succeed".to_string());
326 }
327
328 Ok(subscription.clone())
329 }
330
331 #[cfg(feature = "subscriptions")]
332 async fn request_subscription_channel(
333 &self,
334 model: &SubscriptionModel,
335 event: &SubscriptionEvent,
336 filters: &Object,
337 ) -> Result<String, String> {
338 let url = self
339 .config
340 .subscribe_url
341 .replace("{model}", &model.to_string())
342 .replace("{event}", &event.to_string());
343 let url = if !filters.is_empty() {
344 format!(
345 "{}?{}",
346 url,
347 serde_urlencoded::to_string(filters.to_query_string()).unwrap()
348 )
349 } else {
350 url
351 };
352 let request = Request::new(
353 Method::Get,
354 url,
355 None,
356 Some(self.config.headers.clone()),
357 Some(ContentType::Json),
358 );
359 let response = self.config.client.request(&request).await?;
360 let json = serde_json::from_str::<Value>(&response.body)
361 .unwrap()
362 .as_object()
363 .unwrap();
364 if let Some(err) = json.get("error") {
365 return Err(err.value().as_string().unwrap());
366 }
367 if let Some(channel) = json.get("channel") {
368 return Ok(channel.value().as_string().unwrap());
369 }
370 Err("malformed response".to_string())
371 }
372
373 #[cfg(feature = "subscriptions")]
374 async fn authorize_subscription(&self, channel: &String) -> Result<String, String> {
375 self.config.socket.get_established().wait().await;
376 let request = Request::new(
377 Method::Post,
378 self.config.subscription_auth_url.clone(),
379 Some(
380 serde_urlencoded::to_string([
381 ("socket_id", &self.config.socket.get_socket_id().await),
382 ("channel_name", channel),
383 ])
384 .unwrap(),
385 ),
386 None,
387 Some(ContentType::Form),
388 );
389 let response = self.config.client.request(&request).await?;
390 if response.status != 200 {
391 return Err("unauthorized".into());
392 }
393 let data = serde_json::from_str::<SubscriptionAuthData>(&response.body).unwrap();
394 Ok(data.auth)
395 }
396}