atrium_oauth/
resolver.rs

1mod oauth_authorization_server_resolver;
2mod oauth_protected_resource_resolver;
3
4use self::oauth_authorization_server_resolver::DefaultOAuthAuthorizationServerResolver;
5use self::oauth_protected_resource_resolver::DefaultOAuthProtectedResourceResolver;
6use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata};
7use atrium_api::{
8    did_doc::DidDocument,
9    types::string::{Did, Handle},
10};
11use atrium_common::{
12    resolver::{CachedResolver, Resolver, ThrottledResolver},
13    types::{
14        cached::{
15            r#impl::{Cache, CacheImpl},
16            {CacheConfig, Cacheable},
17        },
18        throttled::Throttleable,
19    },
20};
21use atrium_identity::{
22    identity_resolver::{IdentityResolver, IdentityResolverConfig, ResolvedIdentity},
23    {Error, Result},
24};
25use atrium_xrpc::HttpClient;
26use std::{marker::PhantomData, sync::Arc, time::Duration};
27
28#[derive(Clone, Debug)]
29pub struct OAuthAuthorizationServerMetadataResolverConfig {
30    pub cache: CacheConfig,
31}
32
33impl Default for OAuthAuthorizationServerMetadataResolverConfig {
34    fn default() -> Self {
35        Self {
36            cache: CacheConfig {
37                max_capacity: Some(100),
38                time_to_live: Some(Duration::from_secs(60)),
39            },
40        }
41    }
42}
43
44#[derive(Clone, Debug)]
45pub struct OAuthProtectedResourceMetadataResolverConfig {
46    pub cache: CacheConfig,
47}
48
49impl Default for OAuthProtectedResourceMetadataResolverConfig {
50    fn default() -> Self {
51        Self {
52            cache: CacheConfig {
53                max_capacity: Some(100),
54                time_to_live: Some(Duration::from_secs(60)),
55            },
56        }
57    }
58}
59
60#[derive(Clone, Debug)]
61pub struct OAuthResolverConfig<D, H> {
62    pub did_resolver: D,
63    pub handle_resolver: H,
64    pub authorization_server_metadata: OAuthAuthorizationServerMetadataResolverConfig,
65    pub protected_resource_metadata: OAuthProtectedResourceMetadataResolverConfig,
66}
67
68pub struct OAuthResolver<
69    T,
70    D,
71    H,
72    PR = DefaultOAuthProtectedResourceResolver<T>,
73    AS = DefaultOAuthAuthorizationServerResolver<T>,
74> where
75    PR: Resolver<Input = String, Output = OAuthProtectedResourceMetadata> + Send + Sync + 'static,
76    AS: Resolver<Input = String, Output = OAuthAuthorizationServerMetadata> + Send + Sync + 'static,
77{
78    identity_resolver: IdentityResolver<D, H>,
79    protected_resource_resolver: CachedResolver<ThrottledResolver<PR>>,
80    authorization_server_resolver: CachedResolver<ThrottledResolver<AS>>,
81    _phantom: PhantomData<T>,
82}
83
84impl<T, D, H> OAuthResolver<T, D, H>
85where
86    T: HttpClient + Send + Sync + 'static,
87{
88    pub fn new(config: OAuthResolverConfig<D, H>, http_client: Arc<T>) -> Self {
89        let protected_resource_resolver =
90            DefaultOAuthProtectedResourceResolver::new(http_client.clone())
91                .throttled()
92                .cached(CacheImpl::new(config.authorization_server_metadata.cache));
93        let authorization_server_resolver =
94            DefaultOAuthAuthorizationServerResolver::new(http_client.clone())
95                .throttled()
96                .cached(CacheImpl::new(config.protected_resource_metadata.cache));
97        Self {
98            identity_resolver: IdentityResolver::new(IdentityResolverConfig {
99                did_resolver: config.did_resolver,
100                handle_resolver: config.handle_resolver,
101            }),
102            protected_resource_resolver,
103            authorization_server_resolver,
104            _phantom: PhantomData,
105        }
106    }
107}
108
109impl<T, D, H> OAuthResolver<T, D, H>
110where
111    T: HttpClient + Send + Sync + 'static,
112    D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
113    H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
114{
115    pub async fn get_authorization_server_metadata(
116        &self,
117        issuer: impl AsRef<str>,
118    ) -> Result<OAuthAuthorizationServerMetadata> {
119        let result =
120            self.authorization_server_resolver.resolve(&issuer.as_ref().to_string()).await?;
121        result.ok_or_else(|| Error::NotFound)
122    }
123    async fn resolve_from_service(&self, input: &str) -> Result<OAuthAuthorizationServerMetadata> {
124        // Assume first that input is a PDS URL (as required by ATPROTO)
125        if let Ok(metadata) = self.get_resource_server_metadata(input).await {
126            return Ok(metadata);
127        }
128        // Fallback to trying to fetch as an issuer (Entryway)
129        self.get_authorization_server_metadata(input).await
130    }
131    pub(crate) async fn resolve_from_identity(
132        &self,
133        input: &str,
134    ) -> Result<(OAuthAuthorizationServerMetadata, ResolvedIdentity)> {
135        let identity = self.identity_resolver.resolve(input).await?;
136        let metadata = self.get_resource_server_metadata(&identity.pds).await?;
137        Ok((metadata, identity))
138    }
139    async fn get_resource_server_metadata(
140        &self,
141        pds: &str,
142    ) -> Result<OAuthAuthorizationServerMetadata> {
143        let result = self.protected_resource_resolver.resolve(&pds.to_string()).await?;
144        let rs_metadata = result.ok_or_else(|| Error::NotFound)?;
145        // ATPROTO requires one, and only one, authorization server entry
146        // > That document MUST contain a single item in the authorization_servers array.
147        // https://github.com/bluesky-social/proposals/tree/main/0004-oauth#server-metadata
148        let issuer = match &rs_metadata.authorization_servers {
149            Some(servers) if !servers.is_empty() => {
150                if servers.len() > 1 {
151                    return Err(Error::ProtectedResourceMetadata(format!(
152                        "unable to determine authorization server for PDS: {pds}"
153                    )));
154                }
155                &servers[0]
156            }
157            _ => {
158                return Err(Error::ProtectedResourceMetadata(format!(
159                    "no authorization server found for PDS: {pds}"
160                )))
161            }
162        };
163        let as_metadata = self.get_authorization_server_metadata(issuer).await?;
164        // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-08#name-authorization-server-metada
165        if let Some(protected_resources) = &as_metadata.protected_resources {
166            if !protected_resources.contains(&rs_metadata.resource) {
167                return Err(Error::AuthorizationServerMetadata(format!(
168                    "pds {pds} does not protected by issuer: {issuer}",
169                )));
170            }
171        }
172
173        // TODO: atproot specific validation?
174        // https://github.com/bluesky-social/proposals/tree/main/0004-oauth#server-metadata
175        //
176        // eg.
177        // https://drafts.aaronpk.com/draft-parecki-oauth-client-id-metadata-document/draft-parecki-oauth-client-id-metadata-document.html
178        // if as_metadata.client_id_metadata_document_supported != Some(true) {
179        //     return Err(Error::AuthorizationServerMetadata(format!(
180        //         "authorization server does not support client_id_metadata_document: {issuer}"
181        //     )));
182        // }
183
184        Ok(as_metadata)
185    }
186}
187
188impl<T, D, H> Resolver for OAuthResolver<T, D, H>
189where
190    T: HttpClient + Send + Sync + 'static,
191    D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
192    H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
193{
194    type Input = str;
195    type Output = (OAuthAuthorizationServerMetadata, Option<ResolvedIdentity>);
196    type Error = Error;
197
198    async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
199        // Allow using an entryway, or PDS url, directly as login input (e.g.
200        // when the user forgot their handle, or when the handle does not
201        // resolve to a DID)
202        Ok(if input.starts_with("https://") {
203            (self.resolve_from_service(input.as_ref()).await?, None)
204        } else {
205            let (metadata, identity) = self.resolve_from_identity(input).await?;
206            (metadata, Some(identity))
207        })
208    }
209}