Skip to main content

strava_wrapper/
query.rs

1use async_trait::async_trait;
2use reqwest::header::HeaderMap;
3use reqwest::{Client, StatusCode, Url};
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::fmt;
8use std::fmt::Debug;
9use std::sync::{OnceLock, RwLock};
10
/// Base URL of the Strava REST API (no version segment, no trailing slash).
pub const API_URL: &str = "https://www.strava.com/api";
12
13/// Process-wide reqwest client. Reused across all requests so the connection
14/// pool and TLS session cache don't get rebuilt for every call.
15fn http_client() -> &'static Client {
16    static CLIENT: OnceLock<Client> = OnceLock::new();
17    CLIENT.get_or_init(Client::new)
18}
19
/// Counters from Strava's `X-RateLimit-Usage` and `X-RateLimit-Limit`
/// response headers. Each header carries a `short,long` pair, one value per
/// window Strava tracks for an application:
///
/// - **short term** — 15-minute window (Strava documents it as resetting at
///   :00, :15, :30 and :45 past the hour, not as a rolling window)
/// - **long term**  — daily window (documented as resetting at midnight UTC)
///
/// The quota values themselves are read from the headers, so nothing here
/// assumes Strava's defaults; elevated apps simply report higher limits.
/// Both counters are per-application, not per-token, so a process-wide
/// "last seen" snapshot is an accurate view of consumption.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimit {
    /// Requests counted in the 15-minute window (first value of `X-RateLimit-Usage`).
    pub short_term_usage: u32,
    /// Quota for the 15-minute window (first value of `X-RateLimit-Limit`).
    pub short_term_limit: u32,
    /// Requests counted in the daily window (second value of `X-RateLimit-Usage`).
    pub long_term_usage: u32,
    /// Quota for the daily window (second value of `X-RateLimit-Limit`).
    pub long_term_limit: u32,
}
36
37impl RateLimit {
38    /// Parse both headers from a response. Returns `None` if either header
39    /// is missing or malformed (e.g. mock servers without the headers set).
40    pub fn from_headers(headers: &HeaderMap) -> Option<Self> {
41        let limit = headers.get("x-ratelimit-limit")?.to_str().ok()?;
42        let usage = headers.get("x-ratelimit-usage")?.to_str().ok()?;
43        let (short_term_limit, long_term_limit) = parse_pair(limit)?;
44        let (short_term_usage, long_term_usage) = parse_pair(usage)?;
45        Some(Self {
46            short_term_usage,
47            short_term_limit,
48            long_term_usage,
49            long_term_limit,
50        })
51    }
52
53    /// Remaining calls before the 15-minute quota trips (saturating).
54    pub fn short_term_remaining(&self) -> u32 {
55        self.short_term_limit.saturating_sub(self.short_term_usage)
56    }
57
58    /// Remaining calls before the daily quota trips (saturating).
59    pub fn long_term_remaining(&self) -> u32 {
60        self.long_term_limit.saturating_sub(self.long_term_usage)
61    }
62}
63
/// Parses the first two comma-separated fields of `s` as `u32`s, trimming
/// whitespace around each. Any fields after the second are ignored; a missing
/// or non-numeric field yields `None`.
fn parse_pair(s: &str) -> Option<(u32, u32)> {
    let mut fields = s.split(',').map(|field| field.trim().parse::<u32>().ok());
    let first = fields.next()??;
    let second = fields.next()??;
    Some((first, second))
}
70
71fn rate_limit_slot() -> &'static RwLock<Option<RateLimit>> {
72    static SLOT: OnceLock<RwLock<Option<RateLimit>>> = OnceLock::new();
73    SLOT.get_or_init(|| RwLock::new(None))
74}
75
76/// Most recent rate-limit snapshot observed from any request in this process.
77/// `None` until the first response with the expected headers lands.
78pub fn last_rate_limit() -> Option<RateLimit> {
79    rate_limit_slot().read().ok().and_then(|g| *g)
80}
81
82fn record_rate_limit(headers: &HeaderMap) -> Option<RateLimit> {
83    let rl = RateLimit::from_headers(headers)?;
84    if let Ok(mut slot) = rate_limit_slot().write() {
85        *slot = Some(rl);
86    }
87    Some(rl)
88}
89
/// Test-only reset of the shared snapshot, so a test can start from `None`.
/// A poisoned lock is ignored, matching the best-effort writes elsewhere.
#[cfg(test)]
pub(crate) fn clear_rate_limit_for_testing() {
    let Ok(mut snapshot) = rate_limit_slot().write() else {
        return;
    };
    *snapshot = None;
}
96
/// Error type returned by the request helpers in this module.
#[derive(Debug)]
#[non_exhaustive]
pub enum ErrorWrapper {
    /// Any `reqwest::Error`, e.g. from sending the request or reading the
    /// response body. It is converted via the `From` impl, so `?` works.
    Network(reqwest::Error),
    /// A 2xx response whose body could not be deserialized into the
    /// requested type.
    #[non_exhaustive]
    Parse {
        /// The `serde_json` deserialization error.
        error: serde_json::Error,
        /// The raw response body that failed to parse, kept for diagnostics.
        body: String,
    },
    /// A non-2xx HTTP status returned by the API.
    #[non_exhaustive]
    Api {
        /// HTTP status of the failing response.
        status: StatusCode,
        /// Decoded error body. When the body does not decode as an
        /// `ErrorResponse`, `errors` is empty and `message` holds the raw text.
        response: ErrorResponse,
        /// Rate-limit snapshot from the response that triggered this error.
        /// `None` for mock servers that don't set `X-RateLimit-*` headers.
        /// Especially useful on 429 responses to decide a back-off window.
        rate_limit: Option<RateLimit>,
    },
    /// The request URL could not be built; holds the URL parser's message.
    Url(String),
}
117
118impl fmt::Display for ErrorWrapper {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        match self {
121            ErrorWrapper::Network(e) => write!(f, "network error: {}", e),
122            ErrorWrapper::Parse { error, .. } => {
123                write!(f, "failed to parse response: {}", error)
124            }
125            ErrorWrapper::Api {
126                status, response, ..
127            } => {
128                write!(f, "Strava API error {}: {}", status, response.message)
129            }
130            ErrorWrapper::Url(msg) => write!(f, "URL error: {}", msg),
131        }
132    }
133}
134
135impl std::error::Error for ErrorWrapper {
136    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
137        match self {
138            ErrorWrapper::Network(e) => Some(e),
139            ErrorWrapper::Parse { error, .. } => Some(error),
140            ErrorWrapper::Api { .. } | ErrorWrapper::Url(_) => None,
141        }
142    }
143}
144
145impl From<reqwest::Error> for ErrorWrapper {
146    fn from(e: reqwest::Error) -> Self {
147        ErrorWrapper::Network(e)
148    }
149}
150
151#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
152pub struct ErrorResponse {
153    pub errors: Vec<ErrorDetails>,
154    pub message: String,
155}
156
/// One entry of `ErrorResponse::errors`.
/// NOTE(review): field meanings follow Strava's documented error model
/// (resource / field / code) — confirm against the API reference.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ErrorDetails {
    /// Resource the error refers to.
    pub resource: String,
    /// Field of that resource the error refers to.
    pub field: String,
    /// Machine-readable error code.
    pub code: String,
}
163
164pub async fn get<T>(path: &str, token: &str) -> Result<T, ErrorWrapper>
165where
166    T: DeserializeOwned + Debug,
167{
168    let response = http_client()
169        .get(path)
170        .header("Authorization", format!("Bearer {}", token))
171        .send()
172        .await?;
173    handle_response::<T>(response).await
174}
175
176pub async fn get_raw(path: &str, token: &str) -> Result<String, ErrorWrapper> {
177    let response = http_client()
178        .get(path)
179        .header("Authorization", format!("Bearer {}", token))
180        .send()
181        .await?;
182    let status = response.status();
183    let rate_limit = record_rate_limit(response.headers());
184    let body = response.text().await?;
185    if status.is_success() {
186        Ok(body)
187    } else {
188        Err(ErrorWrapper::Api {
189            status,
190            response: parse_error_body(&body),
191            rate_limit,
192        })
193    }
194}
195
196pub async fn post<T, B>(path: &str, token: &str, body: B) -> Result<T, ErrorWrapper>
197where
198    T: DeserializeOwned + Debug,
199    B: Serialize + Debug,
200{
201    let response = http_client()
202        .post(path)
203        .header("Authorization", format!("Bearer {}", token))
204        .json(&body)
205        .send()
206        .await?;
207    handle_response::<T>(response).await
208}
209
210async fn handle_response<T>(response: reqwest::Response) -> Result<T, ErrorWrapper>
211where
212    T: DeserializeOwned + Debug,
213{
214    let status = response.status();
215    let rate_limit = record_rate_limit(response.headers());
216    let body = response.text().await?;
217    if status.is_success() {
218        serde_json::from_str::<T>(&body).map_err(|error| ErrorWrapper::Parse { error, body })
219    } else {
220        Err(ErrorWrapper::Api {
221            status,
222            response: parse_error_body(&body),
223            rate_limit,
224        })
225    }
226}
227
228pub(crate) fn parse_error_body(body: &str) -> ErrorResponse {
229    serde_json::from_str::<ErrorResponse>(body).unwrap_or_else(|_| ErrorResponse {
230        errors: Vec::new(),
231        message: body.to_string(),
232    })
233}
234
/// Something that can be consumed and sent asynchronously, yielding a `U`
/// or an [`ErrorWrapper`].
#[async_trait]
pub trait Sendable<U> {
    /// Consumes `self` and performs the operation.
    /// NOTE(review): presumably issues the HTTP request described by the
    /// builder — confirm against implementors.
    async fn send(self) -> Result<U, ErrorWrapper>;
}
239
240pub trait Query: Sized + Clone {
241    fn format_to_query_params(
242        url: &str,
243        params: Vec<(String, String)>,
244    ) -> Result<String, ErrorWrapper> {
245        Url::parse_with_params(url, params.iter())
246            .map(|u| u.to_string())
247            .map_err(|e| ErrorWrapper::Url(e.to_string()))
248    }
249
250    fn get_query_params(self) -> Vec<(String, String)>;
251}
252
/// A builder constructed from a base URL, an access token and a path, that
/// can report the URL template it targets.
pub trait Endpoint: Sized + Clone {
    /// Creates the builder.
    /// NOTE(review): how `url` and `path` are combined is up to implementors — confirm.
    // NOTE(review): `where Self: Sized` is redundant; `Sized` is already a supertrait.
    fn new(url: impl Into<String>, token: impl Into<String>, path: impl Into<String>) -> Self
    where
        Self: Sized;

    /// Returns the absolute endpoint URL template. `{name}` placeholders in it
    /// are filled from `PathQuery::get_path_params` by this module's helpers.
    fn endpoint(&self) -> String;
}
260
/// An [`Endpoint`] whose URL template contains `{name}` path placeholders.
pub trait PathQuery: Endpoint {
    /// Returns the placeholder name → value map substituted into the template
    /// (names are given without the surrounding braces).
    fn get_path_params(&self) -> HashMap<String, String>;
}
264
/// Builder setter for the page number of a paginated request.
pub trait Page {
    /// Consumes the builder and returns it with the page set to `number`.
    fn page(self, number: u32) -> Self;
}
/// Builder setter for the number of items per page.
pub trait PerPage {
    /// Consumes the builder and returns it with the per-page count set to `number`.
    fn per_page(self, number: u32) -> Self;
}

/// Builder setter for a page size.
/// NOTE(review): overlaps with `PerPage`; presumably for endpoints whose
/// parameter is named `page_size` — confirm against implementors.
pub trait PageSize {
    /// Consumes the builder and returns it with the page size set to `number`.
    fn page_size(self, number: u32) -> Self;
}
275
/// Builder setter for an upper time bound on returned items.
pub trait Before {
    /// Consumes the builder and returns it with `before` set.
    /// NOTE(review): presumably a Unix epoch timestamp in seconds — confirm.
    fn before(self, before: u64) -> Self;
}

/// Builder setter for a lower time bound on returned items.
pub trait After {
    /// Consumes the builder and returns it with `after` set.
    /// NOTE(review): presumably a Unix epoch timestamp in seconds — confirm.
    fn after(self, after: u64) -> Self;
}
283
/// Numeric `{id}` path param, for resources with integer identifiers.
pub trait ID {
    /// Consumes the builder and returns it with the `{id}` value set.
    fn id(self, id: u64) -> Self;
}

/// String-valued `{id}` path param. Use on resources where Strava assigns
/// non-numeric identifiers (e.g. gear ids like `"b12345678"`).
pub trait GearID {
    /// Consumes the builder and returns it with the `{id}` value set.
    fn id(self, id: impl Into<String>) -> Self;
}
293
/// Builder setter for cursor-based pagination.
pub trait AfterCursor {
    /// Consumes the builder and returns it with the cursor set.
    /// NOTE(review): presumably the opaque cursor returned with a previous page — confirm.
    fn after_cursor(self, cursor: String) -> Self;
}

/// Builder setter for the include-all-efforts flag.
pub trait IncludeAllEfforts {
    /// Consumes the builder and returns it with the flag set to `should_include`.
    fn include_all_efforts(self, should_include: bool) -> Self;
}
301
/// Fills `{name}` placeholders in `template` with values from `params`.
///
/// The template is scanned once, left to right. Each `{name}` whose `name`
/// is a key of `params` is replaced by its value; placeholders with no
/// matching key are left untouched. Inserted values are never re-scanned, so
/// a value that itself contains `{other}` is copied literally. As a result,
/// the output no longer depends on `HashMap` iteration order.
fn format_path(template: &str, params: &HashMap<String, String>) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(open) = rest.find('{') {
        out.push_str(&rest[..open]);
        // `{` and `}` are single-byte ASCII, so these slices stay on char boundaries.
        let after_open = &rest[open + 1..];
        let substitution = after_open
            .find('}')
            .and_then(|close| params.get(&after_open[..close]).map(|value| (close, value)));
        match substitution {
            Some((close, value)) => {
                out.push_str(value);
                rest = &after_open[close + 1..];
            }
            None => {
                // Not a known placeholder: keep the brace and resume scanning
                // right after it, so `{{id}}` still yields `{<value>}`.
                out.push('{');
                rest = after_open;
            }
        }
    }
    out.push_str(rest);
    out
}
310
311pub async fn get_with_query_and_path<T, U>(inst: T, token: &str) -> Result<U, ErrorWrapper>
312where
313    T: Query + PathQuery + Endpoint,
314    U: DeserializeOwned + Debug,
315{
316    let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
317    let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
318    get(&url, token).await
319}
320
321pub async fn get_raw_with_query_and_path<T>(inst: T, token: &str) -> Result<String, ErrorWrapper>
322where
323    T: Query + PathQuery + Endpoint,
324{
325    let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
326    let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
327    get_raw(&url, token).await
328}
329
330pub async fn put_json<T, B>(path: &str, token: &str, body: &B) -> Result<T, ErrorWrapper>
331where
332    T: DeserializeOwned + Debug,
333    B: Serialize + ?Sized,
334{
335    let response = http_client()
336        .put(path)
337        .header("Authorization", format!("Bearer {}", token))
338        .json(body)
339        .send()
340        .await?;
341    handle_response::<T>(response).await
342}
343
344pub async fn put_form<T, B>(path: &str, token: &str, body: &B) -> Result<T, ErrorWrapper>
345where
346    T: DeserializeOwned + Debug,
347    B: Serialize + ?Sized,
348{
349    let response = http_client()
350        .put(path)
351        .header("Authorization", format!("Bearer {}", token))
352        .form(body)
353        .send()
354        .await?;
355    handle_response::<T>(response).await
356}
357
358pub async fn post_form<T, B>(path: &str, token: &str, body: &B) -> Result<T, ErrorWrapper>
359where
360    T: DeserializeOwned + Debug,
361    B: Serialize + ?Sized,
362{
363    let response = http_client()
364        .post(path)
365        .header("Authorization", format!("Bearer {}", token))
366        .form(body)
367        .send()
368        .await?;
369    handle_response::<T>(response).await
370}
371
372pub async fn post_multipart<T>(
373    path: &str,
374    token: &str,
375    form: reqwest::multipart::Form,
376) -> Result<T, ErrorWrapper>
377where
378    T: DeserializeOwned + Debug,
379{
380    let response = http_client()
381        .post(path)
382        .header("Authorization", format!("Bearer {}", token))
383        .multipart(form)
384        .send()
385        .await?;
386    handle_response::<T>(response).await
387}
388
389pub async fn put_json_with_path<T, U, B>(inst: T, token: &str, body: &B) -> Result<U, ErrorWrapper>
390where
391    T: Query + PathQuery + Endpoint,
392    U: DeserializeOwned + Debug,
393    B: Serialize + ?Sized,
394{
395    let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
396    let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
397    put_json(&url, token, body).await
398}
399
400pub async fn put_form_with_path<T, U, B>(inst: T, token: &str, body: &B) -> Result<U, ErrorWrapper>
401where
402    T: Query + PathQuery + Endpoint,
403    U: DeserializeOwned + Debug,
404    B: Serialize + ?Sized,
405{
406    let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
407    let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
408    put_form(&url, token, body).await
409}
410
411pub async fn post_form_with_path<T, U, B>(inst: T, token: &str, body: &B) -> Result<U, ErrorWrapper>
412where
413    T: Query + PathQuery + Endpoint,
414    U: DeserializeOwned + Debug,
415    B: Serialize + ?Sized,
416{
417    let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
418    let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
419    post_form(&url, token, body).await
420}
421
422pub async fn post_multipart_with_path<T, U>(
423    inst: T,
424    token: &str,
425    form: reqwest::multipart::Form,
426) -> Result<U, ErrorWrapper>
427where
428    T: Query + PathQuery + Endpoint,
429    U: DeserializeOwned + Debug,
430{
431    let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
432    let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
433    post_multipart(&url, token, form).await
434}