1use async_trait::async_trait;
2use reqwest::header::HeaderMap;
3use reqwest::{Client, StatusCode, Url};
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::fmt;
8use std::fmt::Debug;
9use std::sync::{OnceLock, RwLock};
10
/// Base URL of the Strava API.
///
/// NOTE(review): the value carries no version segment (e.g. `/v3`); presumably callers
/// append it when building endpoint URLs — confirm.
pub const API_URL: &str = "https://www.strava.com/api";
12
13fn http_client() -> &'static Client {
16 static CLIENT: OnceLock<Client> = OnceLock::new();
17 CLIENT.get_or_init(Client::new)
18}
19
/// Rate-limit counters read from a response's `x-ratelimit-limit` and `x-ratelimit-usage`
/// headers (see [`RateLimit::from_headers`]).
///
/// Each header holds a comma-separated pair: the first value fills the `short_term_*` field
/// and the second fills the `long_term_*` field.
/// NOTE(review): presumably these are Strava's 15-minute and daily windows — confirm against
/// the Strava rate-limit documentation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimit {
    /// Requests used in the short-term window (first value of `x-ratelimit-usage`).
    pub short_term_usage: u32,
    /// Requests allowed in the short-term window (first value of `x-ratelimit-limit`).
    pub short_term_limit: u32,
    /// Requests used in the long-term window (second value of `x-ratelimit-usage`).
    pub long_term_usage: u32,
    /// Requests allowed in the long-term window (second value of `x-ratelimit-limit`).
    pub long_term_limit: u32,
}
36
37impl RateLimit {
38 pub fn from_headers(headers: &HeaderMap) -> Option<Self> {
41 let limit = headers.get("x-ratelimit-limit")?.to_str().ok()?;
42 let usage = headers.get("x-ratelimit-usage")?.to_str().ok()?;
43 let (short_term_limit, long_term_limit) = parse_pair(limit)?;
44 let (short_term_usage, long_term_usage) = parse_pair(usage)?;
45 Some(Self {
46 short_term_usage,
47 short_term_limit,
48 long_term_usage,
49 long_term_limit,
50 })
51 }
52
53 pub fn short_term_remaining(&self) -> u32 {
55 self.short_term_limit.saturating_sub(self.short_term_usage)
56 }
57
58 pub fn long_term_remaining(&self) -> u32 {
60 self.long_term_limit.saturating_sub(self.long_term_usage)
61 }
62}
63
/// Parses the first two comma-separated fields of `s` as `u32`s, trimming whitespace.
///
/// Returns `None` when fewer than two fields exist or either of the first two fails to
/// parse. Fields after the second are ignored.
fn parse_pair(s: &str) -> Option<(u32, u32)> {
    let mut fields = s.split(',').map(|field| field.trim().parse::<u32>().ok());
    match (fields.next(), fields.next()) {
        (Some(Some(first)), Some(Some(second))) => Some((first, second)),
        _ => None,
    }
}
70
71fn rate_limit_slot() -> &'static RwLock<Option<RateLimit>> {
72 static SLOT: OnceLock<RwLock<Option<RateLimit>>> = OnceLock::new();
73 SLOT.get_or_init(|| RwLock::new(None))
74}
75
76pub fn last_rate_limit() -> Option<RateLimit> {
79 rate_limit_slot().read().ok().and_then(|g| *g)
80}
81
82fn record_rate_limit(headers: &HeaderMap) -> Option<RateLimit> {
83 let rl = RateLimit::from_headers(headers)?;
84 if let Ok(mut slot) = rate_limit_slot().write() {
85 *slot = Some(rl);
86 }
87 Some(rl)
88}
89
/// Test-only helper that resets the stored rate limit to `None`.
///
/// Tests use it to avoid observing values left behind by earlier requests. A poisoned lock
/// is silently ignored.
#[cfg(test)]
pub(crate) fn clear_rate_limit_for_testing() {
    let Ok(mut slot) = rate_limit_slot().write() else {
        return;
    };
    *slot = None;
}
96
/// Error type returned by every request helper in this module.
#[derive(Debug)]
#[non_exhaustive]
pub enum ErrorWrapper {
    /// Failure reported by reqwest, e.g. while sending the request or reading the body.
    Network(reqwest::Error),
    /// A success (2xx) response body could not be deserialized into the expected type.
    #[non_exhaustive]
    Parse {
        /// The underlying deserialization error.
        error: serde_json::Error,
        /// The raw response body that failed to parse, kept for diagnostics.
        body: String,
    },
    /// The server answered with a non-success status code.
    #[non_exhaustive]
    Api {
        /// HTTP status of the failing response.
        status: StatusCode,
        /// The error body, parsed as an `ErrorResponse` when possible. Otherwise the raw
        /// body is used as `message` with no `errors`.
        response: ErrorResponse,
        /// Rate-limit headers of the failing response, if present and well-formed.
        rate_limit: Option<RateLimit>,
    },
    /// A URL could not be parsed while appending query parameters; holds the parser's message.
    Url(String),
}
117
118impl fmt::Display for ErrorWrapper {
119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120 match self {
121 ErrorWrapper::Network(e) => write!(f, "network error: {}", e),
122 ErrorWrapper::Parse { error, .. } => {
123 write!(f, "failed to parse response: {}", error)
124 }
125 ErrorWrapper::Api {
126 status, response, ..
127 } => {
128 write!(f, "Strava API error {}: {}", status, response.message)
129 }
130 ErrorWrapper::Url(msg) => write!(f, "URL error: {}", msg),
131 }
132 }
133}
134
135impl std::error::Error for ErrorWrapper {
136 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
137 match self {
138 ErrorWrapper::Network(e) => Some(e),
139 ErrorWrapper::Parse { error, .. } => Some(error),
140 ErrorWrapper::Api { .. } | ErrorWrapper::Url(_) => None,
141 }
142 }
143}
144
145impl From<reqwest::Error> for ErrorWrapper {
146 fn from(e: reqwest::Error) -> Self {
147 ErrorWrapper::Network(e)
148 }
149}
150
151#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
152pub struct ErrorResponse {
153 pub errors: Vec<ErrorDetails>,
154 pub message: String,
155}
156
/// One entry in [`ErrorResponse::errors`].
///
/// All three fields are required when deserializing: an entry missing any of them makes the
/// whole `ErrorResponse` fail to parse.
/// NOTE(review): the field meanings below are inferred from the names only — confirm against
/// Strava's `Error` model documentation.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ErrorDetails {
    /// Presumably the API resource the error refers to.
    pub resource: String,
    /// Presumably the field on that resource that caused the error.
    pub field: String,
    /// Presumably a machine-readable error code.
    pub code: String,
}
163
164pub async fn get<T>(path: &str, token: &str) -> Result<T, ErrorWrapper>
165where
166 T: DeserializeOwned + Debug,
167{
168 let response = http_client()
169 .get(path)
170 .header("Authorization", format!("Bearer {}", token))
171 .send()
172 .await?;
173 handle_response::<T>(response).await
174}
175
176pub async fn get_raw(path: &str, token: &str) -> Result<String, ErrorWrapper> {
177 let response = http_client()
178 .get(path)
179 .header("Authorization", format!("Bearer {}", token))
180 .send()
181 .await?;
182 let status = response.status();
183 let rate_limit = record_rate_limit(response.headers());
184 let body = response.text().await?;
185 if status.is_success() {
186 Ok(body)
187 } else {
188 Err(ErrorWrapper::Api {
189 status,
190 response: parse_error_body(&body),
191 rate_limit,
192 })
193 }
194}
195
196pub async fn post<T, B>(path: &str, token: &str, body: B) -> Result<T, ErrorWrapper>
197where
198 T: DeserializeOwned + Debug,
199 B: Serialize + Debug,
200{
201 let response = http_client()
202 .post(path)
203 .header("Authorization", format!("Bearer {}", token))
204 .json(&body)
205 .send()
206 .await?;
207 handle_response::<T>(response).await
208}
209
210async fn handle_response<T>(response: reqwest::Response) -> Result<T, ErrorWrapper>
211where
212 T: DeserializeOwned + Debug,
213{
214 let status = response.status();
215 let rate_limit = record_rate_limit(response.headers());
216 let body = response.text().await?;
217 if status.is_success() {
218 serde_json::from_str::<T>(&body).map_err(|error| ErrorWrapper::Parse { error, body })
219 } else {
220 Err(ErrorWrapper::Api {
221 status,
222 response: parse_error_body(&body),
223 rate_limit,
224 })
225 }
226}
227
228pub(crate) fn parse_error_body(body: &str) -> ErrorResponse {
229 serde_json::from_str::<ErrorResponse>(body).unwrap_or_else(|_| ErrorResponse {
230 errors: Vec::new(),
231 message: body.to_string(),
232 })
233}
234
/// Request builders that can be consumed to produce a `U`.
///
/// Declared through `#[async_trait]` so the trait can contain an `async fn`.
#[async_trait]
pub trait Sendable<U> {
    /// Consumes `self` and resolves to the decoded result or an [`ErrorWrapper`].
    ///
    /// NOTE(review): implementations live outside this file; presumably they delegate to one
    /// of the request helpers in this module — confirm.
    async fn send(self) -> Result<U, ErrorWrapper>;
}
239
240pub trait Query: Sized + Clone {
241 fn format_to_query_params(
242 url: &str,
243 params: Vec<(String, String)>,
244 ) -> Result<String, ErrorWrapper> {
245 Url::parse_with_params(url, params.iter())
246 .map(|u| u.to_string())
247 .map_err(|e| ErrorWrapper::Url(e.to_string()))
248 }
249
250 fn get_query_params(self) -> Vec<(String, String)>;
251}
252
/// Constructor and URL-template accessor shared by endpoint request builders.
pub trait Endpoint: Sized + Clone {
    /// Creates a builder from a base `url`, an access `token`, and a `path`.
    ///
    /// NOTE(review): how `url` and `path` are combined is up to each implementor — confirm.
    // `where Self: Sized` is already implied by the `Sized` supertrait; it is left as-is.
    fn new(url: impl Into<String>, token: impl Into<String>, path: impl Into<String>) -> Self
    where
        Self: Sized;

    /// Returns the endpoint URL template. Its `{name}` placeholders are filled from
    /// [`PathQuery::get_path_params`] by this module's `*_with_path` /
    /// `*_with_query_and_path` helpers.
    fn endpoint(&self) -> String;
}
260
/// An [`Endpoint`] whose URL template contains `{name}` placeholders.
pub trait PathQuery: Endpoint {
    /// Maps placeholder names (without braces) to the values substituted into
    /// [`Endpoint::endpoint`]'s template.
    fn get_path_params(&self) -> HashMap<String, String>;
}
264
// Builder-style setters implemented by endpoint request types. Each method consumes `self`
// and returns the updated builder. How a value becomes a query or path parameter is decided
// by the implementors, which are not in this file.

/// Sets the page number to request.
pub trait Page {
    fn page(self, number: u32) -> Self;
}
/// Sets how many items to return per page.
pub trait PerPage {
    fn per_page(self, number: u32) -> Self;
}

/// Sets the page size.
///
/// NOTE(review): overlaps [`PerPage`] by name; presumably used by endpoints that take a
/// `page_size` parameter instead — confirm.
pub trait PageSize {
    fn page_size(self, number: u32) -> Self;
}

/// Restricts results to those before the given value.
///
/// NOTE(review): presumably an epoch timestamp in seconds — confirm with implementors.
pub trait Before {
    fn before(self, before: u64) -> Self;
}

/// Restricts results to those after the given value.
///
/// NOTE(review): presumably an epoch timestamp in seconds — confirm with implementors.
pub trait After {
    fn after(self, after: u64) -> Self;
}

/// Sets a numeric resource id.
pub trait ID {
    fn id(self, id: u64) -> Self;
}

/// Sets a string gear id.
///
/// This trait shares the method name `id` with [`ID`]. A type implementing both, with both
/// traits in scope, needs a fully qualified call such as `GearID::id(builder, "...")`.
pub trait GearID {
    fn id(self, id: impl Into<String>) -> Self;
}

/// Sets a cursor value (presumably for cursor-based pagination — confirm).
pub trait AfterCursor {
    fn after_cursor(self, cursor: String) -> Self;
}

/// Toggles whether all efforts are included in the response.
pub trait IncludeAllEfforts {
    fn include_all_efforts(self, should_include: bool) -> Self;
}
301
/// Replaces each `{key}` placeholder in `template` with `params[key]`.
///
/// Substitution is a single left-to-right pass, and inserted values are never re-scanned. The
/// previous implementation called `str::replace` once per key in `HashMap` iteration order,
/// which is randomized per process. A value containing another key's placeholder (e.g.
/// `x = "{y}"`) could therefore be substituted again or not, nondeterministically.
///
/// Placeholders whose key is absent from `params` are left verbatim, as is an unmatched `{`.
/// Values are inserted as-is, without percent-encoding.
fn format_path(template: &str, params: &HashMap<String, String>) -> String {
    let mut path = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(open) = rest.find('{') {
        path.push_str(&rest[..open]);
        let after_open = &rest[open + 1..];
        // Only a `{key}` whose key is present in `params` is substituted.
        let substitution = after_open
            .find('}')
            .and_then(|close| params.get(&after_open[..close]).map(|value| (close, value)));
        match substitution {
            Some((close, value)) => {
                path.push_str(value);
                rest = &after_open[close + 1..];
            }
            None => {
                // Emit the brace literally and keep scanning right after it, so a nested form
                // like `{{id}` still substitutes the inner `{id}`.
                path.push('{');
                rest = after_open;
            }
        }
    }
    path.push_str(rest);
    path
}
310
311pub async fn get_with_query_and_path<T, U>(inst: T, token: &str) -> Result<U, ErrorWrapper>
312where
313 T: Query + PathQuery + Endpoint,
314 U: DeserializeOwned + Debug,
315{
316 let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
317 let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
318 get(&url, token).await
319}
320
321pub async fn get_raw_with_query_and_path<T>(inst: T, token: &str) -> Result<String, ErrorWrapper>
322where
323 T: Query + PathQuery + Endpoint,
324{
325 let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
326 let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
327 get_raw(&url, token).await
328}
329
330pub async fn put_json<T, B>(path: &str, token: &str, body: &B) -> Result<T, ErrorWrapper>
331where
332 T: DeserializeOwned + Debug,
333 B: Serialize + ?Sized,
334{
335 let response = http_client()
336 .put(path)
337 .header("Authorization", format!("Bearer {}", token))
338 .json(body)
339 .send()
340 .await?;
341 handle_response::<T>(response).await
342}
343
344pub async fn put_form<T, B>(path: &str, token: &str, body: &B) -> Result<T, ErrorWrapper>
345where
346 T: DeserializeOwned + Debug,
347 B: Serialize + ?Sized,
348{
349 let response = http_client()
350 .put(path)
351 .header("Authorization", format!("Bearer {}", token))
352 .form(body)
353 .send()
354 .await?;
355 handle_response::<T>(response).await
356}
357
358pub async fn post_form<T, B>(path: &str, token: &str, body: &B) -> Result<T, ErrorWrapper>
359where
360 T: DeserializeOwned + Debug,
361 B: Serialize + ?Sized,
362{
363 let response = http_client()
364 .post(path)
365 .header("Authorization", format!("Bearer {}", token))
366 .form(body)
367 .send()
368 .await?;
369 handle_response::<T>(response).await
370}
371
372pub async fn post_multipart<T>(
373 path: &str,
374 token: &str,
375 form: reqwest::multipart::Form,
376) -> Result<T, ErrorWrapper>
377where
378 T: DeserializeOwned + Debug,
379{
380 let response = http_client()
381 .post(path)
382 .header("Authorization", format!("Bearer {}", token))
383 .multipart(form)
384 .send()
385 .await?;
386 handle_response::<T>(response).await
387}
388
389pub async fn put_json_with_path<T, U, B>(inst: T, token: &str, body: &B) -> Result<U, ErrorWrapper>
390where
391 T: Query + PathQuery + Endpoint,
392 U: DeserializeOwned + Debug,
393 B: Serialize + ?Sized,
394{
395 let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
396 let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
397 put_json(&url, token, body).await
398}
399
400pub async fn put_form_with_path<T, U, B>(inst: T, token: &str, body: &B) -> Result<U, ErrorWrapper>
401where
402 T: Query + PathQuery + Endpoint,
403 U: DeserializeOwned + Debug,
404 B: Serialize + ?Sized,
405{
406 let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
407 let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
408 put_form(&url, token, body).await
409}
410
411pub async fn post_form_with_path<T, U, B>(inst: T, token: &str, body: &B) -> Result<U, ErrorWrapper>
412where
413 T: Query + PathQuery + Endpoint,
414 U: DeserializeOwned + Debug,
415 B: Serialize + ?Sized,
416{
417 let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
418 let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
419 post_form(&url, token, body).await
420}
421
422pub async fn post_multipart_with_path<T, U>(
423 inst: T,
424 token: &str,
425 form: reqwest::multipart::Form,
426) -> Result<U, ErrorWrapper>
427where
428 T: Query + PathQuery + Endpoint,
429 U: DeserializeOwned + Debug,
430{
431 let url_with_path_params = format_path(&inst.endpoint(), &inst.get_path_params());
432 let url = T::format_to_query_params(&url_with_path_params, inst.get_query_params())?;
433 post_multipart(&url, token, form).await
434}