// atrium_oauth/oauth_client.rs

1use crate::{
2    constants::FALLBACK_ALG,
3    error::{Error, Result},
4    keyset::Keyset,
5    oauth_session::OAuthSession,
6    resolver::{OAuthResolver, OAuthResolverConfig},
7    server_agent::{OAuthRequest, OAuthServerAgent, OAuthServerFactory},
8    store::{
9        session::{Session, SessionStore},
10        session_registry::SessionRegistry,
11        state::{InternalStateData, StateStore},
12    },
13    types::{
14        AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions,
15        CallbackParams, OAuthAuthorizationServerMetadata, OAuthClientMetadata,
16        OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters,
17        TryIntoOAuthClientMetadata,
18    },
19    utils::{compare_algos, generate_key, generate_nonce},
20};
21use atrium_api::{
22    did_doc::DidDocument,
23    types::string::{Did, Handle},
24};
25use atrium_common::resolver::Resolver;
26use atrium_xrpc::HttpClient;
27use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
28use jose_jwk::{Jwk, JwkSet, Key};
29use serde::Serialize;
30use sha2::{Digest, Sha256};
31use std::sync::Arc;
32
/// Configuration used to build an [`OAuthClient`] backed by the crate's default HTTP client.
#[cfg(feature = "default-client")]
pub struct OAuthClientConfig<S0, S1, M: TryIntoOAuthClientMetadata, D, H> {
    /// Client metadata; converted into [`OAuthClientMetadata`] when the client is built.
    pub client_metadata: M,
    /// Optional private keys; turned into a [`Keyset`] when present.
    pub keys: Option<Vec<Jwk>>,
    /// Store holding pending authorization state between `authorize` and `callback`.
    pub state_store: S0,
    /// Store holding established sessions.
    pub session_store: S1,
    /// Configuration handed to the [`OAuthResolver`].
    pub resolver: OAuthResolverConfig<D, H>,
}
47
48#[cfg(not(feature = "default-client"))]
49pub struct OAuthClientConfig<S0, S1, T, M, D, H>
50where
51    M: TryIntoOAuthClientMetadata,
52{
53    // Config
54    pub client_metadata: M,
55    pub keys: Option<Vec<Jwk>>,
56    // Stores
57    pub state_store: S0,
58    pub session_store: S1,
59    // Services
60    pub resolver: OAuthResolverConfig<D, H>,
61    // Others
62    pub http_client: T,
63}
64
/// An OAuth client for AT Protocol.
///
/// This client is used to process OAuth flows with AT Protocol. With the `default-client`
/// feature the HTTP client type `T` defaults to the crate's built-in implementation.
#[cfg(feature = "default-client")]
pub struct OAuthClient<S0, S1, D, H, T = crate::http_client::default::DefaultHttpClient>
where
    S1: SessionStore + Send + Sync + 'static,
    S1::Error: std::error::Error + Send + Sync + 'static,
    T: HttpClient + Send + Sync + 'static,
{
    /// Client metadata derived from the configuration when the client was built.
    pub client_metadata: OAuthClientMetadata,
    /// Key set built from the configured keys, if any were given.
    keyset: Option<Keyset>,
    /// Resolver for identities and authorization-server metadata.
    resolver: Arc<OAuthResolver<T, D, H>>,
    /// Factory producing server agents bound to a DPoP key.
    server_factory: Arc<OAuthServerFactory<T, D, H>>,
    /// Store for pending authorization state, keyed by the `state` parameter.
    state_store: S0,
    /// Registry through which sessions are stored, loaded and removed.
    session_registry: Arc<SessionRegistry<S1, T, D, H>>,
    /// HTTP client shared with the resolver, the server agents and the sessions.
    http_client: Arc<T>,
}
83
84#[cfg(not(feature = "default-client"))]
85pub struct OAuthClient<S0, S1, D, H, T>
86where
87    T: HttpClient + Send + Sync + 'static,
88    S1: SessionStore + Send + Sync + 'static,
89    S1::Error: std::error::Error + Send + Sync + 'static,
90{
91    pub client_metadata: OAuthClientMetadata,
92    keyset: Option<Keyset>,
93    resolver: Arc<OAuthResolver<T, D, H>>,
94    server_factory: Arc<OAuthServerFactory<T, D, H>>,
95    state_store: S0,
96    session_registry: Arc<SessionRegistry<S1, T, D, H>>,
97    http_client: Arc<T>,
98}
99
#[cfg(feature = "default-client")]
impl<S0, S1, D, H> OAuthClient<S0, S1, D, H, crate::http_client::default::DefaultHttpClient>
where
    S1: SessionStore + Send + Sync + 'static,
    S1::Error: std::error::Error + Send + Sync + 'static,
{
    /// Create a new OAuth client backed by the crate's default HTTP client.
    ///
    /// The resolver, the server factory and the session registry all share the same
    /// HTTP client instance.
    ///
    /// # Errors
    ///
    /// Fails when `config.keys` cannot be turned into a [`Keyset`], or when the client
    /// metadata cannot be converted into [`OAuthClientMetadata`].
    pub fn new<M>(config: OAuthClientConfig<S0, S1, M, D, H>) -> Result<Self>
    where
        M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>,
    {
        let OAuthClientConfig {
            client_metadata: metadata_source,
            keys,
            state_store,
            session_store,
            resolver: resolver_config,
        } = config;

        let keyset: Option<Keyset> = match keys {
            Some(jwks) => Some(jwks.try_into()?),
            None => None,
        };
        let client_metadata = metadata_source.try_into_client_metadata(&keyset)?;

        let http_client = Arc::new(crate::http_client::default::DefaultHttpClient::default());
        let resolver = Arc::new(OAuthResolver::new(resolver_config, Arc::clone(&http_client)));
        let server_factory = Arc::new(OAuthServerFactory::new(
            client_metadata.clone(),
            Arc::clone(&resolver),
            Arc::clone(&http_client),
            keyset.clone(),
        ));
        let session_registry =
            Arc::new(SessionRegistry::new(session_store, Arc::clone(&server_factory)));

        Ok(Self {
            client_metadata,
            keyset,
            resolver,
            server_factory,
            state_store,
            session_registry,
            http_client,
        })
    }
}
134
135#[cfg(not(feature = "default-client"))]
136impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T>
137where
138    T: HttpClient + Send + Sync + 'static,
139    S1: SessionStore + Send + Sync + 'static,
140    S1::Error: std::error::Error + Send + Sync + 'static,
141{
142    pub fn new<M>(config: OAuthClientConfig<S0, S1, T, M, D, H>) -> Result<Self>
143    where
144        M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>,
145    {
146        let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None };
147        let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?;
148        let http_client = Arc::new(config.http_client);
149        let resolver = Arc::new(OAuthResolver::new(config.resolver, Arc::clone(&http_client)));
150        let server_factory = Arc::new(OAuthServerFactory::new(
151            client_metadata.clone(),
152            Arc::clone(&resolver),
153            Arc::clone(&http_client),
154            keyset.clone(),
155        ));
156        let session_registry =
157            Arc::new(SessionRegistry::new(config.session_store, Arc::clone(&server_factory)));
158        Ok(Self {
159            client_metadata,
160            keyset,
161            resolver,
162            server_factory,
163            state_store: config.state_store,
164            session_registry,
165            http_client,
166        })
167    }
168}
169
170impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T>
171where
172    S0: StateStore + Send + Sync + 'static,
173    S1: SessionStore + Send + Sync + 'static,
174    D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
175    H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
176    T: HttpClient + Send + Sync + 'static,
177    S0::Error: std::error::Error + Send + Sync + 'static,
178    S1::Error: std::error::Error + Send + Sync + 'static,
179{
180    /// Get the jwks of the client.
181    pub fn jwks(&self) -> JwkSet {
182        self.keyset.as_ref().map(|keyset| keyset.public_jwks()).unwrap_or_default()
183    }
184    /// Start the authorization process.
185    ///
186    /// This method will return a URL that the user should be redirected to.
187    pub async fn authorize(
188        &self,
189        input: impl AsRef<str>,
190        options: AuthorizeOptions,
191    ) -> Result<String> {
192        let redirect_uri = if let Some(uri) = options.redirect_uri {
193            if !self.client_metadata.redirect_uris.contains(&uri) {
194                return Err(Error::Authorize("invalid redirect_uri".into()));
195            }
196            uri
197        } else {
198            self.client_metadata.redirect_uris[0].clone()
199        };
200        let (metadata, identity) = self.resolver.resolve(input.as_ref()).await?;
201        let Some(dpop_key) = Self::generate_dpop_key(&metadata) else {
202            return Err(Error::Authorize("none of the algorithms worked".into()));
203        };
204        let (code_challenge, verifier) = Self::generate_pkce();
205        let state = generate_nonce();
206        let state_data = InternalStateData {
207            iss: metadata.issuer.clone(),
208            dpop_key: dpop_key.clone(),
209            verifier,
210            app_state: options.state,
211        };
212        self.state_store
213            .set(state.clone(), state_data)
214            .await
215            .map_err(|e| Error::StateStore(Box::new(e)))?;
216        let login_hint = if identity.is_some() { Some(input.as_ref().into()) } else { None };
217        let parameters = PushedAuthorizationRequestParameters {
218            response_type: AuthorizationResponseType::Code,
219            redirect_uri,
220            state,
221            scope: Some(options.scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
222            response_mode: None,
223            code_challenge,
224            code_challenge_method: AuthorizationCodeChallengeMethod::S256,
225            login_hint,
226            prompt: options.prompt.map(String::from),
227        };
228        if metadata.pushed_authorization_request_endpoint.is_some() {
229            let server = self.server_factory.build_from_metadata(dpop_key, metadata.clone())?;
230            let par_response = server
231                .request::<OAuthPusehedAuthorizationRequestResponse>(
232                    OAuthRequest::PushedAuthorizationRequest(parameters),
233                )
234                .await?;
235
236            #[derive(Serialize)]
237            struct Parameters {
238                client_id: String,
239                request_uri: String,
240            }
241            Ok(metadata.authorization_endpoint
242                + "?"
243                + &serde_html_form::to_string(Parameters {
244                    client_id: self.client_metadata.client_id.clone(),
245                    request_uri: par_response.request_uri,
246                })
247                .unwrap())
248        } else if metadata.require_pushed_authorization_requests == Some(true) {
249            Err(Error::Authorize("server requires PAR but no endpoint is available".into()))
250        } else {
251            // now "the use of PAR is *mandatory* for all clients"
252            // https://github.com/bluesky-social/proposals/tree/main/0004-oauth#framework
253            todo!()
254        }
255    }
256    /// Handle the callback from the authorization server.
257    ///
258    /// This method will exchange the authorization code for an access token and store the session,
259    /// and return the [`OAuthSession`] and the application state.
260    pub async fn callback(
261        &self,
262        params: CallbackParams,
263    ) -> Result<(OAuthSession<T, D, H, S1>, Option<String>)> {
264        let Some(state_key) = params.state else {
265            return Err(Error::Callback("missing `state` parameter".into()));
266        };
267
268        let Some(state) =
269            self.state_store.get(&state_key).await.map_err(|e| Error::StateStore(Box::new(e)))?
270        else {
271            return Err(Error::Callback(format!("unknown authorization state: {state_key}")));
272        };
273        // Prevent any kind of replay
274        self.state_store.del(&state_key).await.map_err(|e| Error::StateStore(Box::new(e)))?;
275
276        let metadata = self.resolver.get_authorization_server_metadata(&state.iss).await?;
277        // https://datatracker.ietf.org/doc/html/rfc9207#section-2.4
278        if let Some(iss) = params.iss {
279            if iss != metadata.issuer {
280                return Err(Error::Callback(format!(
281                    "issuer mismatch: expected {}, got {iss}",
282                    metadata.issuer
283                )));
284            }
285        } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
286            return Err(Error::Callback("missing `iss` parameter".into()));
287        }
288        let server =
289            self.server_factory.build_from_metadata(state.dpop_key.clone(), metadata.clone())?;
290        match server.exchange_code(&params.code, &state.verifier).await {
291            Ok(token_set) => {
292                let sub = token_set.sub.clone();
293                self.session_registry
294                    .set(sub.clone(), Session { dpop_key: state.dpop_key.clone(), token_set })
295                    .await
296                    .map_err(|e| Error::SessionStore(Box::new(e)))?;
297                Ok((self.create_session(server, &sub).await?, state.app_state))
298            }
299            Err(_) => {
300                todo!()
301            }
302        }
303    }
304    /// Load a stored session by giving the subject DID.
305    ///
306    /// This method will return the [`OAuthSession`] if it exists.
307    pub async fn restore(&self, sub: &Did) -> Result<OAuthSession<T, D, H, S1>> {
308        // let session_handle = self.session_registry.get(sub).await?;
309        // let session = session_handle.read().await.session();
310        let session = self.session_registry.get(sub, false).await?;
311        self.create_session(
312            self.server_factory.build_from_issuer(session.dpop_key, &session.token_set.iss).await?,
313            sub,
314        )
315        .await
316    }
317    /// Revoke a session by giving the subject DID.
318    pub async fn revoke(&self, sub: &Did) -> Result<()> {
319        let session = self.session_registry.get(sub, false).await?;
320        let server_agent =
321            self.server_factory.build_from_issuer(session.dpop_key, &session.token_set.iss).await?;
322        server_agent.revoke(&session.token_set.access_token).await?;
323        self.session_registry.del(sub).await.map_err(|e| Error::SessionStore(Box::new(e)))
324    }
325    async fn create_session(
326        &self,
327        server: OAuthServerAgent<T, D, H>,
328        sub: &Did,
329    ) -> Result<OAuthSession<T, D, H, S1>> {
330        Ok(OAuthSession::new(
331            server.server_metadata.clone(),
332            sub.clone(),
333            Arc::clone(&self.http_client),
334            Arc::clone(&self.session_registry),
335        )
336        .await?)
337    }
338    fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option<Key> {
339        let mut algs =
340            metadata.dpop_signing_alg_values_supported.clone().unwrap_or(vec![FALLBACK_ALG.into()]);
341        algs.sort_by(compare_algos);
342        generate_key(&algs)
343    }
344    fn generate_pkce() -> (String, String) {
345        // https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
346        let verifier = [generate_nonce(), generate_nonce()].join("");
347        (URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier)
348    }
349}