mod oauth_authorization_server_resolver;
mod oauth_protected_resource_resolver;
use self::oauth_authorization_server_resolver::DefaultOAuthAuthorizationServerResolver;
use self::oauth_protected_resource_resolver::DefaultOAuthProtectedResourceResolver;
use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata};
use atrium_api::{
did_doc::DidDocument,
types::string::{Did, Handle},
};
use atrium_common::{
resolver::{CachedResolver, Resolver, ThrottledResolver},
types::{
cached::{
r#impl::{Cache, CacheImpl},
{CacheConfig, Cacheable},
},
throttled::Throttleable,
},
};
use atrium_identity::{
identity_resolver::{IdentityResolver, IdentityResolverConfig, ResolvedIdentity},
{Error, Result},
};
use atrium_xrpc::HttpClient;
use std::{marker::PhantomData, sync::Arc, time::Duration};
#[derive(Clone, Debug)]
pub struct OAuthAuthorizationServerMetadataResolverConfig {
pub cache: CacheConfig,
}
impl Default for OAuthAuthorizationServerMetadataResolverConfig {
fn default() -> Self {
Self {
cache: CacheConfig {
max_capacity: Some(100),
time_to_live: Some(Duration::from_secs(60)),
},
}
}
}
#[derive(Clone, Debug)]
pub struct OAuthProtectedResourceMetadataResolverConfig {
pub cache: CacheConfig,
}
impl Default for OAuthProtectedResourceMetadataResolverConfig {
fn default() -> Self {
Self {
cache: CacheConfig {
max_capacity: Some(100),
time_to_live: Some(Duration::from_secs(60)),
},
}
}
}
#[derive(Clone, Debug)]
pub struct OAuthResolverConfig<D, H> {
pub did_resolver: D,
pub handle_resolver: H,
pub authorization_server_metadata: OAuthAuthorizationServerMetadataResolverConfig,
pub protected_resource_metadata: OAuthProtectedResourceMetadataResolverConfig,
}
pub struct OAuthResolver<
T,
D,
H,
PR = DefaultOAuthProtectedResourceResolver<T>,
AS = DefaultOAuthAuthorizationServerResolver<T>,
> where
PR: Resolver<Input = String, Output = OAuthProtectedResourceMetadata> + Send + Sync + 'static,
AS: Resolver<Input = String, Output = OAuthAuthorizationServerMetadata> + Send + Sync + 'static,
{
identity_resolver: IdentityResolver<D, H>,
protected_resource_resolver: CachedResolver<ThrottledResolver<PR>>,
authorization_server_resolver: CachedResolver<ThrottledResolver<AS>>,
_phantom: PhantomData<T>,
}
impl<T, D, H> OAuthResolver<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
{
pub fn new(config: OAuthResolverConfig<D, H>, http_client: Arc<T>) -> Self {
let protected_resource_resolver =
DefaultOAuthProtectedResourceResolver::new(http_client.clone())
.throttled()
.cached(CacheImpl::new(config.authorization_server_metadata.cache));
let authorization_server_resolver =
DefaultOAuthAuthorizationServerResolver::new(http_client.clone())
.throttled()
.cached(CacheImpl::new(config.protected_resource_metadata.cache));
Self {
identity_resolver: IdentityResolver::new(IdentityResolverConfig {
did_resolver: config.did_resolver,
handle_resolver: config.handle_resolver,
}),
protected_resource_resolver,
authorization_server_resolver,
_phantom: PhantomData,
}
}
}
impl<T, D, H> OAuthResolver<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
{
pub async fn get_authorization_server_metadata(
&self,
issuer: impl AsRef<str>,
) -> Result<OAuthAuthorizationServerMetadata> {
let result =
self.authorization_server_resolver.resolve(&issuer.as_ref().to_string()).await?;
result.ok_or_else(|| Error::NotFound)
}
async fn resolve_from_service(&self, input: &str) -> Result<OAuthAuthorizationServerMetadata> {
if let Ok(metadata) = self.get_resource_server_metadata(input).await {
return Ok(metadata);
}
self.get_authorization_server_metadata(input).await
}
pub(crate) async fn resolve_from_identity(
&self,
input: &str,
) -> Result<(OAuthAuthorizationServerMetadata, ResolvedIdentity)> {
let identity = self.identity_resolver.resolve(input).await?;
let metadata = self.get_resource_server_metadata(&identity.pds).await?;
Ok((metadata, identity))
}
async fn get_resource_server_metadata(
&self,
pds: &str,
) -> Result<OAuthAuthorizationServerMetadata> {
let result = self.protected_resource_resolver.resolve(&pds.to_string()).await?;
let rs_metadata = result.ok_or_else(|| Error::NotFound)?;
let issuer = match &rs_metadata.authorization_servers {
Some(servers) if !servers.is_empty() => {
if servers.len() > 1 {
return Err(Error::ProtectedResourceMetadata(format!(
"unable to determine authorization server for PDS: {pds}"
)));
}
&servers[0]
}
_ => {
return Err(Error::ProtectedResourceMetadata(format!(
"no authorization server found for PDS: {pds}"
)));
}
};
let as_metadata = self.get_authorization_server_metadata(issuer).await?;
if let Some(protected_resources) = &as_metadata.protected_resources {
if !protected_resources.contains(&rs_metadata.resource) {
return Err(Error::AuthorizationServerMetadata(format!(
"pds {pds} does not protected by issuer: {issuer}",
)));
}
}
Ok(as_metadata)
}
}
impl<T, D, H> Resolver for OAuthResolver<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
{
type Input = str;
type Output = (OAuthAuthorizationServerMetadata, Option<ResolvedIdentity>);
type Error = Error;
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
Ok(if input.starts_with("https://") {
(self.resolve_from_service(input.as_ref()).await?, None)
} else {
let (metadata, identity) = self.resolve_from_identity(input).await?;
(metadata, Some(identity))
})
}
}