use rustauth_oauth::oauth2::{OAuthError, OAuthHttpClient};
use serde::Deserialize;
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use super::GenericOAuthConfig;
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize)]
pub struct DiscoveryDocument {
pub issuer: Option<String>,
pub authorization_endpoint: Option<String>,
pub token_endpoint: Option<String>,
pub userinfo_endpoint: Option<String>,
}
pub(super) fn resolve_http_client(config: &GenericOAuthConfig) -> OAuthHttpClient {
config.http_client.clone().unwrap_or_else(|| {
OAuthHttpClient::default_client()
.unwrap_or_else(|_| OAuthHttpClient::new(reqwest::Client::new()))
})
}
#[derive(Debug, Clone, Default)]
pub struct DiscoveryCache {
documents: Arc<Mutex<BTreeMap<String, DiscoveryDocument>>>,
}
impl DiscoveryCache {
pub async fn fetch(
&self,
config: &GenericOAuthConfig,
http_client: &OAuthHttpClient,
) -> Result<Option<DiscoveryDocument>, OAuthError> {
let Some(url) = config.discovery_url.as_deref() else {
return Ok(None);
};
if let Some(document) = self.get(&config.provider_id)? {
return Ok(Some(document));
}
let document = fetch_url(config, url, http_client).await?;
self.insert(config.provider_id.clone(), document.clone())?;
Ok(Some(document))
}
fn get(&self, provider_id: &str) -> Result<Option<DiscoveryDocument>, OAuthError> {
let documents = self.documents.lock().map_err(|_| {
OAuthError::InvalidResponse("discovery cache lock was poisoned".to_owned())
})?;
Ok(documents.get(provider_id).cloned())
}
fn insert(&self, provider_id: String, document: DiscoveryDocument) -> Result<(), OAuthError> {
let mut documents = self.documents.lock().map_err(|_| {
OAuthError::InvalidResponse("discovery cache lock was poisoned".to_owned())
})?;
documents.insert(provider_id, document);
Ok(())
}
}
async fn fetch_url(
config: &GenericOAuthConfig,
url: &str,
http_client: &OAuthHttpClient,
) -> Result<DiscoveryDocument, OAuthError> {
let header_pairs = config
.discovery_headers
.iter()
.map(|(key, value)| (key.as_str(), value.as_str()))
.collect::<Vec<_>>();
let bytes = http_client
.get_bytes_with_headers(url, &header_pairs)
.await?;
serde_json::from_slice::<DiscoveryDocument>(&bytes)
.map_err(|error| OAuthError::InvalidResponse(error.to_string()))
}
pub fn headers(headers: &BTreeMap<String, String>) -> BTreeMap<String, String> {
headers
.iter()
.map(|(key, value)| (key.to_ascii_lowercase(), value.clone()))
.collect()
}