1mod oauth_authorization_server_resolver;
2mod oauth_protected_resource_resolver;
3
4use self::oauth_authorization_server_resolver::DefaultOAuthAuthorizationServerResolver;
5use self::oauth_protected_resource_resolver::DefaultOAuthProtectedResourceResolver;
6use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata};
7use atrium_api::{
8 did_doc::DidDocument,
9 types::string::{Did, Handle},
10};
11use atrium_common::{
12 resolver::{CachedResolver, Resolver, ThrottledResolver},
13 types::{
14 cached::{
15 r#impl::{Cache, CacheImpl},
16 {CacheConfig, Cacheable},
17 },
18 throttled::Throttleable,
19 },
20};
21use atrium_identity::{
22 identity_resolver::{IdentityResolver, IdentityResolverConfig, ResolvedIdentity},
23 {Error, Result},
24};
25use atrium_xrpc::HttpClient;
26use std::{marker::PhantomData, sync::Arc, time::Duration};
27
28#[derive(Clone, Debug)]
29pub struct OAuthAuthorizationServerMetadataResolverConfig {
30 pub cache: CacheConfig,
31}
32
33impl Default for OAuthAuthorizationServerMetadataResolverConfig {
34 fn default() -> Self {
35 Self {
36 cache: CacheConfig {
37 max_capacity: Some(100),
38 time_to_live: Some(Duration::from_secs(60)),
39 },
40 }
41 }
42}
43
44#[derive(Clone, Debug)]
45pub struct OAuthProtectedResourceMetadataResolverConfig {
46 pub cache: CacheConfig,
47}
48
49impl Default for OAuthProtectedResourceMetadataResolverConfig {
50 fn default() -> Self {
51 Self {
52 cache: CacheConfig {
53 max_capacity: Some(100),
54 time_to_live: Some(Duration::from_secs(60)),
55 },
56 }
57 }
58}
59
/// Settings consumed by `OAuthResolver::new`.
#[derive(Clone, Debug)]
pub struct OAuthResolverConfig<D, H> {
    /// DID resolver handed to the internal `IdentityResolver`.
    pub did_resolver: D,
    /// Handle resolver handed to the internal `IdentityResolver`.
    pub handle_resolver: H,
    /// Settings for authorization server metadata resolution.
    pub authorization_server_metadata: OAuthAuthorizationServerMetadataResolverConfig,
    /// Settings for protected resource metadata resolution.
    pub protected_resource_metadata: OAuthProtectedResourceMetadataResolverConfig,
}
67
/// Resolves OAuth authorization server metadata, starting either from an account identifier
/// (handle or DID) or from an `https://` service URL.
///
/// `PR` and `AS` default to the resolvers constructed from an HTTP client of type `T`; both
/// are stored wrapped in `ThrottledResolver` and then `CachedResolver`.
pub struct OAuthResolver<
    T,
    D,
    H,
    PR = DefaultOAuthProtectedResourceResolver<T>,
    AS = DefaultOAuthAuthorizationServerResolver<T>,
> where
    PR: Resolver<Input = String, Output = OAuthProtectedResourceMetadata> + Send + Sync + 'static,
    AS: Resolver<Input = String, Output = OAuthAuthorizationServerMetadata> + Send + Sync + 'static,
{
    /// Resolves a handle or DID into a `ResolvedIdentity`.
    identity_resolver: IdentityResolver<D, H>,
    /// Protected resource (PDS) metadata lookups.
    protected_resource_resolver: CachedResolver<ThrottledResolver<PR>>,
    /// Authorization server metadata lookups.
    authorization_server_resolver: CachedResolver<ThrottledResolver<AS>>,
    // `T` appears only in the default type parameters, which do not count as a use of it,
    // so this marker is required for the struct to compile.
    _phantom: PhantomData<T>,
}
83
84impl<T, D, H> OAuthResolver<T, D, H>
85where
86 T: HttpClient + Send + Sync + 'static,
87{
88 pub fn new(config: OAuthResolverConfig<D, H>, http_client: Arc<T>) -> Self {
89 let protected_resource_resolver =
90 DefaultOAuthProtectedResourceResolver::new(http_client.clone())
91 .throttled()
92 .cached(CacheImpl::new(config.authorization_server_metadata.cache));
93 let authorization_server_resolver =
94 DefaultOAuthAuthorizationServerResolver::new(http_client.clone())
95 .throttled()
96 .cached(CacheImpl::new(config.protected_resource_metadata.cache));
97 Self {
98 identity_resolver: IdentityResolver::new(IdentityResolverConfig {
99 did_resolver: config.did_resolver,
100 handle_resolver: config.handle_resolver,
101 }),
102 protected_resource_resolver,
103 authorization_server_resolver,
104 _phantom: PhantomData,
105 }
106 }
107}
108
109impl<T, D, H> OAuthResolver<T, D, H>
110where
111 T: HttpClient + Send + Sync + 'static,
112 D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
113 H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
114{
115 pub async fn get_authorization_server_metadata(
116 &self,
117 issuer: impl AsRef<str>,
118 ) -> Result<OAuthAuthorizationServerMetadata> {
119 let result =
120 self.authorization_server_resolver.resolve(&issuer.as_ref().to_string()).await?;
121 result.ok_or_else(|| Error::NotFound)
122 }
123 async fn resolve_from_service(&self, input: &str) -> Result<OAuthAuthorizationServerMetadata> {
124 if let Ok(metadata) = self.get_resource_server_metadata(input).await {
126 return Ok(metadata);
127 }
128 self.get_authorization_server_metadata(input).await
130 }
131 pub(crate) async fn resolve_from_identity(
132 &self,
133 input: &str,
134 ) -> Result<(OAuthAuthorizationServerMetadata, ResolvedIdentity)> {
135 let identity = self.identity_resolver.resolve(input).await?;
136 let metadata = self.get_resource_server_metadata(&identity.pds).await?;
137 Ok((metadata, identity))
138 }
139 async fn get_resource_server_metadata(
140 &self,
141 pds: &str,
142 ) -> Result<OAuthAuthorizationServerMetadata> {
143 let result = self.protected_resource_resolver.resolve(&pds.to_string()).await?;
144 let rs_metadata = result.ok_or_else(|| Error::NotFound)?;
145 let issuer = match &rs_metadata.authorization_servers {
149 Some(servers) if !servers.is_empty() => {
150 if servers.len() > 1 {
151 return Err(Error::ProtectedResourceMetadata(format!(
152 "unable to determine authorization server for PDS: {pds}"
153 )));
154 }
155 &servers[0]
156 }
157 _ => {
158 return Err(Error::ProtectedResourceMetadata(format!(
159 "no authorization server found for PDS: {pds}"
160 )))
161 }
162 };
163 let as_metadata = self.get_authorization_server_metadata(issuer).await?;
164 if let Some(protected_resources) = &as_metadata.protected_resources {
166 if !protected_resources.contains(&rs_metadata.resource) {
167 return Err(Error::AuthorizationServerMetadata(format!(
168 "pds {pds} does not protected by issuer: {issuer}",
169 )));
170 }
171 }
172
173 Ok(as_metadata)
185 }
186}
187
188impl<T, D, H> Resolver for OAuthResolver<T, D, H>
189where
190 T: HttpClient + Send + Sync + 'static,
191 D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
192 H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
193{
194 type Input = str;
195 type Output = (OAuthAuthorizationServerMetadata, Option<ResolvedIdentity>);
196 type Error = Error;
197
198 async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
199 Ok(if input.starts_with("https://") {
203 (self.resolve_from_service(input.as_ref()).await?, None)
204 } else {
205 let (metadata, identity) = self.resolve_from_identity(input).await?;
206 (metadata, Some(identity))
207 })
208 }
209}