use std::time::Duration;
use super::{TokenCacheError, TokenSource};
/// A [`TokenSource`] that obtains access tokens with the OAuth 2.0 client
/// credentials grant (`grant_type=client_credentials`).
///
/// NOTE(review): `client_secret` is a credential. No `Debug` impl is derived
/// here; if one is ever added, make sure it redacts the secret.
pub struct ClientCredentialsSource {
/// HTTP client used to reach the token endpoint.
http: reqwest::Client,
/// Token endpoint URL that the credentials are POSTed to.
token_url: String,
/// Client identifier, sent as the `client_id` form field.
client_id: String,
/// Client secret, sent as the `client_secret` form field.
client_secret: String,
}
impl ClientCredentialsSource {
    /// Creates a source that requests tokens from `token_url` using the given
    /// client credentials.
    ///
    /// A new `reqwest::Client` with default settings is created for this source.
    ///
    /// # Panics
    ///
    /// Propagates any panic from `reqwest::Client::new()`, which reqwest documents
    /// as panicking if a TLS backend cannot be initialized.
    pub fn new(
        token_url: impl Into<String>,
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
    ) -> Self {
        let http = reqwest::Client::new();
        let (token_url, client_id, client_secret): (String, String, String) =
            (token_url.into(), client_id.into(), client_secret.into());
        Self {
            http,
            token_url,
            client_id,
            client_secret,
        }
    }
}
/// The subset of the token endpoint's JSON success response that this source reads.
///
/// No `deny_unknown_fields` is set, so any other fields in the response are
/// ignored during deserialization.
#[derive(serde::Deserialize)]
struct TokenResponse {
/// The access token returned by the endpoint.
access_token: String,
/// Token lifetime in seconds. This field is required here, so a response that
/// omits it fails to deserialize.
expires_in: u64,
}
#[async_trait::async_trait]
impl TokenSource for ClientCredentialsSource {
    /// Fetches a new access token using the OAuth 2.0 client credentials grant.
    ///
    /// Sends `grant_type=client_credentials`, `client_id` and `client_secret` as a
    /// form-encoded POST to `token_url`. Returns the `access_token` together with
    /// its lifetime, taken from `expires_in` (seconds).
    ///
    /// # Errors
    ///
    /// - [`TokenCacheError::Fetch`] if the request cannot be sent or times out
    ///   before a response arrives. Also returned if the endpoint answers with a
    ///   non-success status; the status and the (best-effort) response body are
    ///   included in the message.
    /// - [`TokenCacheError::Malformed`] if reading or JSON-decoding the success
    ///   response fails. This also covers a missing `access_token` or `expires_in`.
    async fn fetch_token(&self) -> Result<(String, Duration), TokenCacheError> {
        // reqwest clients built with `Client::new()` have no timeout by default.
        // Without a per-request limit, a stalled token endpoint would leave this
        // future pending forever. The limit spans connecting through reading the
        // full response body.
        // NOTE(review): 30s is a chosen default; tune it if the endpoint needs longer.
        const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

        let resp = self
            .http
            .post(&self.token_url)
            .timeout(REQUEST_TIMEOUT)
            .form(&[
                ("grant_type", "client_credentials"),
                ("client_id", self.client_id.as_str()),
                ("client_secret", self.client_secret.as_str()),
            ])
            .send()
            .await
            .map_err(|e| TokenCacheError::Fetch(e.to_string()))?;

        let status = resp.status();
        if !status.is_success() {
            // Best-effort: an unreadable error body must not hide the status code.
            let body = resp.text().await.unwrap_or_default();
            return Err(TokenCacheError::Fetch(format!("{status}: {body}")));
        }

        // RFC 6749 §5.1 marks `expires_in` as RECOMMENDED, not required. Servers
        // that omit it are rejected here as `Malformed`.
        let token: TokenResponse = resp
            .json()
            .await
            .map_err(|e| TokenCacheError::Malformed(e.to_string()))?;
        Ok((token.access_token, Duration::from_secs(token.expires_in)))
    }
}