Skip to main content

pkrs_fork/
client.rs

1use std::fmt::Display;
2
3use reqwest::{Client, RequestBuilder, Response, StatusCode};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use time::format_description::well_known::Rfc3339;
7use time::OffsetDateTime;
8use uuid::Uuid;
9
10#[cfg(feature = "metrics")]
11use std::time::Instant;
12#[cfg(feature = "metrics")]
13use url::Url;
14
15use crate::model::*;
16
17const BASE_URL: &str = "https://api.pluralkit.me/v2/";
18
19pub struct PkClient {
20    pub client: Client,
21    pub token: String,
22    pub user_agent: String,
23}
24
25impl Default for PkClient {
26    fn default() -> Self {
27        PkClient {
28            client: Client::builder()
29                .build()
30                .expect("failed to create reqwest client"),
31            token: "".to_string(),
32            user_agent: "pk + rust project".to_string(), // todo: this is kind of bad
33        }
34    }
35}
36
37impl PkClient {
38    /// return a new PkClient based on the current one, but with the specified token
39    pub fn with_token(&self, token: String) -> Self {
40        Self {
41            token,
42            client: self.client.clone(),
43            user_agent: self.user_agent.clone(),
44        }
45    }
46
47    pub async fn get_system(&self, system_id: &PkId) -> Result<System, PluralKitError> {
48        self.get(format!("systems/{}", system_id)).await
49    }
50
51    pub async fn update_system(&self, system: &System) -> Result<System, PluralKitError> {
52        self.patch_get_json(format!("systems/{}", system.id), system)
53            .await
54    }
55
56    pub async fn get_system_settings(
57        &self,
58        system_id: &PkId,
59    ) -> Result<SystemSettings, PluralKitError> {
60        self.get(format!("systems/{}/settings", system_id)).await
61    }
62
63    pub async fn update_system_settings(
64        &self,
65        system_id: &PkId,
66        settings: &SystemSettings,
67    ) -> Result<SystemSettings, PluralKitError> {
68        self.patch_get_json(format!("systems/{}/settings", system_id), settings)
69            .await
70    }
71
72    pub async fn get_system_guild_settings(
73        &self,
74        system_id: &PkId,
75        guild_id: &str,
76    ) -> Result<SystemGuildSettings, PluralKitError> {
77        let req = format!("systems/{}/settings/guilds/{}", system_id, guild_id);
78        self.get(req).await
79    }
80
81    pub async fn update_system_guild_settings(
82        &self,
83        system_id: &PkId,
84        guild_id: &str,
85        settings: &SystemGuildSettings,
86    ) -> Result<SystemGuildSettings, PluralKitError> {
87        let req = format!("systems/{}/settings/guilds/{}", system_id, guild_id);
88        self.patch_get_json(req, settings).await
89    }
90
91    pub async fn get_system_autoproxy_settings(
92        &self,
93        system_id: &PkId,
94        guild_id: &str,
95    ) -> Result<AutoProxySettings, PluralKitError> {
96        let req = format!("systems/{}/autoproxy", system_id);
97        self.get_query_get_json(req, &[("guild_id", guild_id)])
98            .await
99    }
100
101    pub async fn update_system_autoproxy_settings(
102        &self,
103        system_id: &PkId,
104        guild_id: &str,
105        settings: &AutoProxySettings,
106    ) -> Result<AutoProxySettings, PluralKitError> {
107        let req = format!("systems/{}/autoproxy", system_id);
108        self.patch_query_get_json(req, settings, &[("guild_id", guild_id)])
109            .await
110    }
111
112    pub async fn get_system_members(
113        &self,
114        system_id: &PkId,
115    ) -> Result<Vec<Member>, PluralKitError> {
116        self.get(format!("systems/{}/members", system_id)).await
117    }
118
119    pub async fn create_member(&self, member: &Member) -> Result<Member, PluralKitError> {
120        self.post_get_json("members".to_string(), member).await
121    }
122
123    pub async fn get_member(&self, member_id: &PkId) -> Result<Member, PluralKitError> {
124        self.get(format!("members/{}", member_id)).await
125    }
126
127    pub async fn update_member(&self, member: &Member) -> Result<Member, PluralKitError> {
128        self.patch_get_json(format!("members/{}", member.id), member)
129            .await
130    }
131
132    pub async fn delete_member(&self, member_id: &PkId) -> Result<(), PluralKitError> {
133        expect_no_content(&self.delete(format!("members/{}", member_id)).await?)
134    }
135
136    pub async fn get_member_groups(&self, member_id: &PkId) -> Result<Vec<Group>, PluralKitError> {
137        self.get(format!("members/{}/groups", member_id)).await
138    }
139
140    async fn member_groups(
141        &self,
142        action: &str,
143        member_id: &PkId,
144        group_ids: &[&PkId],
145    ) -> Result<(), PluralKitError> {
146        expect_no_content(
147            &self
148                .request(
149                    self.client
150                        .post(format!(
151                            "{}members/{}/groups/{}",
152                            BASE_URL, member_id, action
153                        ))
154                        .json(group_ids),
155                )
156                .await?,
157        )
158    }
159
160    pub async fn add_member_groups(
161        &self,
162        member_id: &PkId,
163        group_ids: &[&PkId],
164    ) -> Result<(), PluralKitError> {
165        self.member_groups("add", member_id, group_ids).await
166    }
167
168    pub async fn remove_member_groups(
169        &self,
170        member_id: &PkId,
171        group_ids: &[&PkId],
172    ) -> Result<(), PluralKitError> {
173        self.member_groups("remove", member_id, group_ids).await
174    }
175
176    pub async fn overwrite_member_groups(
177        &self,
178        member_id: &PkId,
179        group_ids: &[&PkId],
180    ) -> Result<(), PluralKitError> {
181        self.member_groups("overwrite", member_id, group_ids).await
182    }
183
184    pub async fn get_member_guild_settings(
185        &self,
186        member_id: &PkId,
187        guild_id: &str,
188    ) -> Result<Vec<MemberGuildSettings>, PluralKitError> {
189        self.get(format!("members/{}/guilds/{}", member_id, guild_id))
190            .await
191    }
192
193    pub async fn update_member_guild_settings(
194        &self,
195        member_id: &PkId,
196        guild_id: &str,
197        settings: &MemberGuildSettings,
198    ) -> Result<MemberGuildSettings, PluralKitError> {
199        self.patch_get_json(
200            format!("members/{}/guilds/{}", member_id, guild_id),
201            settings,
202        )
203        .await
204    }
205
206    pub async fn get_system_groups(&self, system_id: &PkId) -> Result<Vec<Group>, PluralKitError> {
207        self.get(format!("systems/{}/groups", system_id)).await
208    }
209
210    pub async fn create_group(&self, group: &Group) -> Result<Group, PluralKitError> {
211        self.post_get_json("groups".to_string(), group).await
212    }
213
214    pub async fn get_group(&self, group_id: &PkId) -> Result<Group, PluralKitError> {
215        self.get(format!("groups/{}", group_id)).await
216    }
217
218    pub async fn update_group(&self, group: &Group) -> Result<Group, PluralKitError> {
219        self.patch_get_json(format!("groups/{}", group.id), group)
220            .await
221    }
222
223    pub async fn delete_group(&self, group_id: &PkId) -> Result<(), PluralKitError> {
224        expect_no_content(&self.delete(format!("groups/{}", group_id)).await?)
225    }
226
227    pub async fn get_group_members(&self, group_id: &PkId) -> Result<Vec<Member>, PluralKitError> {
228        self.get(format!("groups/{}/members", group_id)).await
229    }
230
231    async fn group_members(
232        &self,
233        action: &str,
234        group_id: &PkId,
235        member_ids: &[&PkId],
236    ) -> Result<(), PluralKitError> {
237        expect_no_content(
238            &self
239                .request(
240                    self.client
241                        .post(format!(
242                            "{}groups/{}/members/{}",
243                            BASE_URL, group_id, action
244                        ))
245                        .json(member_ids),
246                )
247                .await?,
248        )
249    }
250
251    pub async fn add_group_members(
252        &self,
253        group_id: &PkId,
254        member_ids: &[&PkId],
255    ) -> Result<(), PluralKitError> {
256        self.group_members("add", group_id, member_ids).await
257    }
258
259    pub async fn remove_group_members(
260        &self,
261        group_id: &PkId,
262        member_ids: &[&PkId],
263    ) -> Result<(), PluralKitError> {
264        self.group_members("remove", group_id, member_ids).await
265    }
266
267    pub async fn overwrite_group_members(
268        &self,
269        group_id: &PkId,
270        member_ids: &[&PkId],
271    ) -> Result<(), PluralKitError> {
272        self.group_members("overwrite", group_id, member_ids).await
273    }
274
275    pub async fn get_system_switches(
276        &self,
277        system_id: &PkId,
278        before: &OffsetDateTime,
279        limit: &i32,
280    ) -> Result<Vec<Switch>, PluralKitError> {
281        self.get_query_get_json(
282            format!("systems/{}/switches", system_id),
283            &[
284                ("before", before.format(&Rfc3339).unwrap().as_str()),
285                ("limit", limit.to_string().as_str()),
286            ],
287        )
288        .await
289    }
290
291    pub async fn get_system_fronters(
292        &self,
293        system_id: &PkId,
294    ) -> Result<Option<Switch>, PluralKitError> {
295        let resp = self
296            .request(
297                self.client
298                    .get(format!("{BASE_URL}systems/{system_id}/fronters")),
299            )
300            .await?;
301
302        if resp.status() == StatusCode::NO_CONTENT {
303            return Ok(None);
304        }
305
306        Ok(Some(resp.json().await?))
307    }
308
309    pub async fn create_switch(
310        &self,
311        system_id: &PkId,
312        member_ids: Vec<PkId>,
313        time: Option<OffsetDateTime>,
314    ) -> Result<Response, PluralKitError> {
315        #[derive(Serialize, Deserialize, Debug)]
316        struct SwitchCreate {
317            #[serde(with = "time::serde::rfc3339::option")]
318            timestamp: Option<OffsetDateTime>,
319            members: Vec<PkId>,
320        }
321        self.post(
322            format!("systems/{}/switches", system_id),
323            &SwitchCreate {
324                timestamp: time,
325                members: member_ids,
326            },
327        )
328        .await
329    }
330
331    pub async fn get_switch(
332        &self,
333        system_id: &PkId,
334        switch_id: &Uuid,
335    ) -> Result<Vec<Switch>, PluralKitError> {
336        self.get(format!("systems/{}/switches/{}", system_id, switch_id))
337            .await
338    }
339
340    pub async fn update_switch(
341        &self,
342        system_id: &PkId,
343        switch_id: &Uuid,
344        time: OffsetDateTime,
345    ) -> Result<Switch, PluralKitError> {
346        let req = format!("systems/{}/switches/{}", system_id, switch_id);
347        #[derive(Serialize, Deserialize, Debug)]
348        struct SwitchTimeUpdate {
349            #[serde(with = "time::serde::rfc3339")]
350            timestamp: OffsetDateTime,
351        }
352        self.get_response_json(
353            self.client
354                .patch(BASE_URL.to_string() + req.as_str())
355                .json(&SwitchTimeUpdate { timestamp: time }),
356        )
357        .await
358    }
359
360    pub async fn update_switch_members(
361        &self,
362        system_id: &PkId,
363        switch_id: &Uuid,
364        members: &[&PkId],
365    ) -> Result<Switch, PluralKitError> {
366        let req = format!("systems/{}/switches/{}/members", system_id, switch_id);
367        self.get_response_json(
368            self.client
369                .patch(BASE_URL.to_string() + req.as_str())
370                .json(members),
371        )
372        .await
373    }
374
375    pub async fn delete_switch(
376        &self,
377        system_id: &PkId,
378        switch_id: &Uuid,
379    ) -> Result<(), PluralKitError> {
380        expect_no_content(
381            &self
382                .delete(format!("systems/{}/switches/{}", system_id, switch_id))
383                .await?,
384        )
385    }
386
387    pub async fn get_message(&self, id: &str) -> Result<Message, PluralKitError> {
388        self.get(format!("messages/{}", id)).await
389    }
390
391    // all
392    async fn get<T: for<'a> Deserialize<'a>>(&self, endpoint: String) -> Result<T, PluralKitError> {
393        self.get_response_json(self.client.get(BASE_URL.to_string() + &*endpoint))
394            .await
395    }
396
397    // of this
398    async fn get_query_get_json<T: for<'a> Deserialize<'a>>(
399        &self,
400        endpoint: String,
401        query: &[(&str, &str)],
402    ) -> Result<T, PluralKitError> {
403        self.get_response_json(
404            self.client
405                .get(BASE_URL.to_string() + &*endpoint)
406                .query(query),
407        )
408        .await
409    }
410
411    // duplication
412    async fn patch_get_json<T>(&self, endpoint: String, body: &T) -> Result<T, PluralKitError>
413    where
414        T: Serialize + for<'a> Deserialize<'a>,
415    {
416        self.get_response_json(
417            self.client
418                .patch(BASE_URL.to_string() + &*endpoint)
419                .json(body),
420        )
421        .await
422    }
423
424    // feels
425    async fn patch_query_get_json<T>(
426        &self,
427        endpoint: String,
428        body: &T,
429        query: &[(&str, &str)],
430    ) -> Result<T, PluralKitError>
431    where
432        T: Serialize + for<'a> Deserialize<'a>,
433    {
434        self.get_response_json(
435            self.client
436                .patch(BASE_URL.to_string() + &*endpoint)
437                .query(query)
438                .json(body),
439        )
440        .await
441    }
442
443    // extremely
444    async fn post_get_json<T>(&self, endpoint: String, body: &T) -> Result<T, PluralKitError>
445    where
446        T: Serialize + for<'a> Deserialize<'a>,
447    {
448        self.get_response_json(
449            self.client
450                .post(BASE_URL.to_string() + &*endpoint)
451                .json(body),
452        )
453        .await
454    }
455
456    // unclean
457    async fn post<T: Serialize>(
458        &self,
459        endpoint: String,
460        body: &T,
461    ) -> Result<Response, PluralKitError> {
462        self.request(
463            self.client
464                .post(BASE_URL.to_string() + &*endpoint)
465                .json(body),
466        )
467        .await
468    }
469
470    async fn get_response_json<T: for<'a> Deserialize<'a>>(
471        &self,
472        builder: RequestBuilder,
473    ) -> Result<T, PluralKitError> {
474        Ok(self.request(builder).await?.json::<T>().await?)
475    }
476
477    async fn delete(&self, endpoint: String) -> Result<Response, PluralKitError> {
478        self.request(self.client.delete(BASE_URL.to_string() + &*endpoint))
479            .await
480    }
481
482    async fn request(&self, builder: RequestBuilder) -> Result<Response, PluralKitError> {
483        #[cfg(feature = "metrics")]
484        let now = Instant::now();
485
486        let req = builder
487            .header("User-Agent", &self.user_agent)
488            .header("Authorization", &self.token)
489            .build()?;
490
491        #[cfg(feature = "metrics")]
492        let method = req.method().to_string();
493        #[cfg(feature = "metrics")]
494        let url = req.url().clone();
495
496        let resp = self.client.execute(req).await;
497
498        #[cfg(feature = "metrics")]
499        match resp {
500            Ok(ref resp) => track_request(
501                method,
502                &url,
503                Some(resp.status()),
504                now.elapsed().as_secs_f64(),
505            ),
506            Err(_) => track_request(method, &url, None, now.elapsed().as_secs_f64()),
507        }
508
509        let resp = match resp {
510            Ok(resp) => resp,
511            Err(err) => return Err(PluralKitError::Reqwest(err)),
512        };
513
514        let err_for_status = match resp.error_for_status_ref() {
515            Ok(_) => return Ok(resp),
516            Err(err) => err,
517        };
518
519        match resp.json::<ErrorMessage>().await {
520            Ok(message) => Err(PluralKitError::Pk(
521                err_for_status
522                    .status()
523                    .expect("error_for_status() always populates status()"),
524                message,
525            )),
526            Err(_) => Err(PluralKitError::Reqwest(err_for_status)),
527        }
528    }
529}
530
/// Records latency and a request counter for one API call, labelled by method, a
/// normalised path and the response status (`"Unknown"` when there was none).
#[cfg(feature = "metrics")]
fn track_request(method: String, url: &Url, status: Option<StatusCode>, latency: f64) {
    // normalise the request path, stripping ids/snowflakes/uuids so the label set stays bounded
    let normalized_path: String = match url.path_segments() {
        None => "/".to_string(),
        Some(segments) => {
            // we skip the api version; unrecognised shapes (e.g. v1) fall back to the raw path
            let segment_vec: Vec<&str> = segments.skip(1).collect();
            normalize_segments(&segment_vec).unwrap_or_else(|| url.path().to_string())
        }
    };

    metrics::histogram!(
        "pkrs_request_latency",
        "path" => normalized_path.clone(),
        "method" => method.clone(),
    )
    .record(latency);
    metrics::counter!(
        "pkrs_request",
        "path" => normalized_path,
        "method" => method,
        "status" => status.map(|s| s.to_string()).unwrap_or_else(|| "Unknown".to_string()),
    )
    .increment(1);
}

/// Maps the path segments that follow the API version to a metrics label, replacing
/// id-like segments with `!`.
///
/// Returns `None` for shapes it does not recognise; `track_request` then uses the raw
/// URL path instead.
#[cfg(any(feature = "metrics", test))]
#[cfg_attr(not(feature = "metrics"), allow(dead_code))] // only exercised by tests without `metrics`
#[rustfmt::skip]
fn normalize_segments(segments: &[&str]) -> Option<String> {
    let path = match *segments {
        // /members
        // /groups
        [resource                                          ] => format!("/v2/{resource}"),
        // /systems/{systemRef}
        // /members/{memberRef}
        // /groups/{groupRef}
        // /messages/{message}
        [resource , _                                      ] => format!("/v2/{resource}/!"),
        // /members/{memberRef}/groups/{add,remove,overwrite}
        ["members", _, "groups"    , action                ] => format!("/v2/members/!/groups/{action}"),
        // /groups/{groupRef}/members/{add,remove,overwrite}
        ["groups" , _, "members"   , action                ] => format!("/v2/groups/!/members/{action}"),
        // /systems/{systemRef}/{settings,autoproxy,members,groups,switches,fronters}
        // /members/{memberRef}/groups
        // /groups/{groupRef}/members
        [resource , _, subresource                         ] => format!("/v2/{resource}/!/{subresource}"),
        // /members/{memberRef}/guilds/{guild_id}
        // /systems/{systemRef}/switches/{switchRef}
        [resource , _, subresource , _                     ] => format!("/v2/{resource}/!/{subresource}/!"),
        // /systems/{systemRef}/settings/guilds/{guild_id}
        // must precede the generic five-segment arm, which would otherwise put the
        // guild snowflake into the label
        [resource , _, "settings"  , "guilds", _           ] => format!("/v2/{resource}/!/settings/guilds/!"),
        // /systems/{systemRef}/switches/{switchRef}/members
        [resource , _, subresource , _, subsubresource     ] => format!("/v2/{resource}/!/{subresource}/!/{subsubresource}"),
        _ => return None,
    };
    Some(path)
}
591
592/// convenience function for handling endpoints that return `204 No Content`
593/// on success
594fn expect_no_content(response: &Response) -> Result<(), PluralKitError> {
595    if response.status() == StatusCode::NO_CONTENT {
596        Ok(())
597    } else {
598        Err(PluralKitError::Other(
599            format!(
600                "received status {} expected {}",
601                response.status(),
602                StatusCode::NO_CONTENT
603            )
604            .into(),
605        ))
606    }
607}
608
609#[derive(Error, Debug)]
610pub enum PluralKitError {
611    #[error("pluralkit api error, status: {0}, error: {1}")]
612    Pk(StatusCode, ErrorMessage),
613    #[error(transparent)]
614    Reqwest(#[from] reqwest::Error),
615    #[error(transparent)]
616    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
617}
618
619#[derive(Deserialize, Serialize, Debug)]
620pub struct ErrorMessage {
621    pub code: u32,
622    pub message: String,
623    pub errors: Option<Vec<ErrorMessageObject>>,
624    pub retry_after: Option<u32>,
625}
626
627impl Display for ErrorMessage {
628    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
629        write!(f, "{}: {}", self.code, self.message)
630    }
631}
632
633#[derive(Deserialize, Serialize, Debug)]
634pub struct ErrorMessageObject {
635    pub message: String,
636    pub max_length: Option<u32>,
637    pub actual_length: Option<u32>,
638}