Skip to main content

strava_wrapper/
auth.rs

1use crate::query::{parse_error_body, ErrorWrapper};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::sync::OnceLock;
5
/// Strava's production OAuth token endpoint (code exchange and refresh).
pub const AUTH_URL: &str = "https://www.strava.com/oauth/token";
/// Strava's production OAuth endpoint for revoking an access token.
pub const DEAUTH_URL: &str = "https://www.strava.com/oauth/deauthorize";
8
9fn auth_client() -> &'static Client {
10    static CLIENT: OnceLock<Client> = OnceLock::new();
11    CLIENT.get_or_init(Client::new)
12}
13
14/// Exchange an authorization code for an access token, hitting Strava's
15/// production OAuth endpoint.
16pub async fn get_token(
17    client_id: u32,
18    client_secret: &str,
19    code: &str,
20) -> Result<TokenResponse, ErrorWrapper> {
21    get_token_at(AUTH_URL, client_id, client_secret, code).await
22}
23
24/// Same as [`get_token`] but against a caller-supplied URL. Useful for
25/// integration tests (pointing at httpmock) or for staging/dev OAuth hosts.
26pub async fn get_token_at(
27    url: &str,
28    client_id: u32,
29    client_secret: &str,
30    code: &str,
31) -> Result<TokenResponse, ErrorWrapper> {
32    let response = auth_client()
33        .post(url)
34        .form(&TokenRequest {
35            client_id,
36            client_secret: client_secret.to_string(),
37            code: code.to_string(),
38            grant_type: "authorization_code".into(),
39        })
40        .send()
41        .await?;
42
43    handle_token_response(response).await
44}
45
46/// Swap an expiring access token for a fresh one using a stored refresh token.
47pub async fn refresh_token(
48    client_id: u32,
49    client_secret: &str,
50    refresh_token: &str,
51) -> Result<TokenResponse, ErrorWrapper> {
52    refresh_token_at(AUTH_URL, client_id, client_secret, refresh_token).await
53}
54
55/// Same as [`refresh_token`] but against a caller-supplied URL.
56pub async fn refresh_token_at(
57    url: &str,
58    client_id: u32,
59    client_secret: &str,
60    refresh_token: &str,
61) -> Result<TokenResponse, ErrorWrapper> {
62    let response = auth_client()
63        .post(url)
64        .form(&RefreshRequest {
65            client_id,
66            client_secret: client_secret.to_string(),
67            refresh_token: refresh_token.to_string(),
68            grant_type: "refresh_token".into(),
69        })
70        .send()
71        .await?;
72
73    handle_token_response(response).await
74}
75
76/// Revoke an access token.
77pub async fn deauthorize(access_token: &str) -> Result<(), ErrorWrapper> {
78    deauthorize_at(DEAUTH_URL, access_token).await
79}
80
81/// Same as [`deauthorize`] but against a caller-supplied URL.
82pub async fn deauthorize_at(url: &str, access_token: &str) -> Result<(), ErrorWrapper> {
83    let response = auth_client()
84        .post(url)
85        .header("Authorization", format!("Bearer {}", access_token))
86        .send()
87        .await?;
88
89    let status = response.status();
90    if status.is_success() {
91        Ok(())
92    } else {
93        let body = response.text().await?;
94        Err(ErrorWrapper::Api {
95            status,
96            response: parse_error_body(&body),
97            rate_limit: None,
98        })
99    }
100}
101
102async fn handle_token_response(response: reqwest::Response) -> Result<TokenResponse, ErrorWrapper> {
103    let status = response.status();
104    let body = response.text().await?;
105    if status.is_success() {
106        serde_json::from_str::<TokenResponse>(&body)
107            .map_err(|error| ErrorWrapper::Parse { error, body })
108    } else {
109        Err(ErrorWrapper::Api {
110            status,
111            response: parse_error_body(&body),
112            rate_limit: None,
113        })
114    }
115}
116
117#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
118struct TokenRequest {
119    client_id: u32,
120    client_secret: String,
121    code: String,
122    grant_type: String,
123}
124
125#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
126struct RefreshRequest {
127    client_id: u32,
128    client_secret: String,
129    refresh_token: String,
130    grant_type: String,
131}
132
133#[derive(Debug, Clone, Deserialize)]
134pub struct TokenResponse {
135    pub token_type: String,
136    pub access_token: String,
137    pub expires_at: u64,
138    pub expires_in: u64,
139    pub refresh_token: String,
140}