1use crate::{
2 constants::FALLBACK_ALG,
3 error::{Error, Result},
4 keyset::Keyset,
5 oauth_session::OAuthSession,
6 resolver::{OAuthResolver, OAuthResolverConfig},
7 server_agent::{OAuthRequest, OAuthServerAgent, OAuthServerFactory},
8 store::{
9 session::{Session, SessionStore},
10 session_registry::SessionRegistry,
11 state::{InternalStateData, StateStore},
12 },
13 types::{
14 AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions,
15 CallbackParams, OAuthAuthorizationServerMetadata, OAuthClientMetadata,
16 OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters,
17 TryIntoOAuthClientMetadata,
18 },
19 utils::{compare_algos, generate_key, generate_nonce},
20};
21use atrium_api::{
22 did_doc::DidDocument,
23 types::string::{Did, Handle},
24};
25use atrium_common::resolver::Resolver;
26use atrium_xrpc::HttpClient;
27use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
28use jose_jwk::{Jwk, JwkSet, Key};
29use serde::Serialize;
30use sha2::{Digest, Sha256};
31use std::sync::Arc;
32
/// Configuration consumed by `OAuthClient::new` when the `default-client`
/// feature is enabled.
///
/// There is no HTTP client field in this variant: `new` creates the crate's
/// default HTTP client itself.
#[cfg(feature = "default-client")]
pub struct OAuthClientConfig<S0, S1, M: TryIntoOAuthClientMetadata, D, H> {
    /// Client metadata; `new` converts it into [`OAuthClientMetadata`] via
    /// [`TryIntoOAuthClientMetadata`].
    pub client_metadata: M,
    /// Optional private keys; when present, `new` converts them into a [`Keyset`].
    pub keys: Option<Vec<Jwk>>,
    /// Store for pending authorization state, written by `authorize` and read
    /// back (then deleted) by `callback`.
    pub state_store: S0,
    /// Backing store for sessions; `new` wraps it in a [`SessionRegistry`].
    pub session_store: S1,
    /// Configuration handed to [`OAuthResolver::new`]; `D` and `H` are expected
    /// to resolve DIDs and handles respectively.
    pub resolver: OAuthResolverConfig<D, H>,
}
47
48#[cfg(not(feature = "default-client"))]
49pub struct OAuthClientConfig<S0, S1, T, M, D, H>
50where
51 M: TryIntoOAuthClientMetadata,
52{
53 pub client_metadata: M,
55 pub keys: Option<Vec<Jwk>>,
56 pub state_store: S0,
58 pub session_store: S1,
59 pub resolver: OAuthResolverConfig<D, H>,
61 pub http_client: T,
63}
64
/// OAuth client: starts authorization flows, completes callbacks, and restores
/// or revokes stored sessions.
///
/// With the `default-client` feature, the HTTP client type `T` defaults to the
/// crate's `DefaultHttpClient`.
#[cfg(feature = "default-client")]
pub struct OAuthClient<
    S0,
    S1: SessionStore + Send + Sync + 'static,
    D,
    H,
    T: HttpClient + Send + Sync + 'static = crate::http_client::default::DefaultHttpClient,
> where
    S1::Error: std::error::Error + Send + Sync + 'static,
{
    /// Client metadata produced from the configuration at construction time.
    pub client_metadata: OAuthClientMetadata,
    /// Keyset built from the configured keys, if any were provided.
    keyset: Option<Keyset>,
    /// Resolver shared with the server factory.
    resolver: Arc<OAuthResolver<T, D, H>>,
    /// Factory used to build per-server agents.
    server_factory: Arc<OAuthServerFactory<T, D, H>>,
    /// Store for pending authorization state.
    state_store: S0,
    /// Registry wrapping the configured session store.
    session_registry: Arc<SessionRegistry<S1, T, D, H>>,
    /// HTTP client shared with the resolver, server factory and sessions.
    http_client: Arc<T>,
}
83
84#[cfg(not(feature = "default-client"))]
85pub struct OAuthClient<S0, S1, D, H, T>
86where
87 T: HttpClient + Send + Sync + 'static,
88 S1: SessionStore + Send + Sync + 'static,
89 S1::Error: std::error::Error + Send + Sync + 'static,
90{
91 pub client_metadata: OAuthClientMetadata,
92 keyset: Option<Keyset>,
93 resolver: Arc<OAuthResolver<T, D, H>>,
94 server_factory: Arc<OAuthServerFactory<T, D, H>>,
95 state_store: S0,
96 session_registry: Arc<SessionRegistry<S1, T, D, H>>,
97 http_client: Arc<T>,
98}
99
#[cfg(feature = "default-client")]
impl<S0, S1, D, H> OAuthClient<S0, S1, D, H, crate::http_client::default::DefaultHttpClient>
where
    S1: SessionStore + Send + Sync + 'static,
    S1::Error: std::error::Error + Send + Sync + 'static,
{
    /// Builds a client that uses the crate's default HTTP client.
    ///
    /// # Errors
    ///
    /// Fails when `config.keys` cannot be converted into a [`Keyset`], or when
    /// `config.client_metadata` cannot be converted into [`OAuthClientMetadata`].
    pub fn new<M>(config: OAuthClientConfig<S0, S1, M, D, H>) -> Result<Self>
    where
        M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>,
    {
        let keyset: Option<Keyset> = match config.keys {
            Some(keys) => Some(keys.try_into()?),
            None => None,
        };
        let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?;

        // A single HTTP client instance is shared by the resolver and the server
        // factory, and kept on the client itself.
        let http = Arc::new(crate::http_client::default::DefaultHttpClient::default());
        let oauth_resolver = Arc::new(OAuthResolver::new(config.resolver, Arc::clone(&http)));
        let factory = Arc::new(OAuthServerFactory::new(
            client_metadata.clone(),
            Arc::clone(&oauth_resolver),
            Arc::clone(&http),
            keyset.clone(),
        ));
        let registry = Arc::new(SessionRegistry::new(config.session_store, Arc::clone(&factory)));

        Ok(Self {
            client_metadata,
            keyset,
            resolver: oauth_resolver,
            server_factory: factory,
            state_store: config.state_store,
            session_registry: registry,
            http_client: http,
        })
    }
}
134
135#[cfg(not(feature = "default-client"))]
136impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T>
137where
138 T: HttpClient + Send + Sync + 'static,
139 S1: SessionStore + Send + Sync + 'static,
140 S1::Error: std::error::Error + Send + Sync + 'static,
141{
142 pub fn new<M>(config: OAuthClientConfig<S0, S1, T, M, D, H>) -> Result<Self>
143 where
144 M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>,
145 {
146 let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None };
147 let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?;
148 let http_client = Arc::new(config.http_client);
149 let resolver = Arc::new(OAuthResolver::new(config.resolver, Arc::clone(&http_client)));
150 let server_factory = Arc::new(OAuthServerFactory::new(
151 client_metadata.clone(),
152 Arc::clone(&resolver),
153 Arc::clone(&http_client),
154 keyset.clone(),
155 ));
156 let session_registry =
157 Arc::new(SessionRegistry::new(config.session_store, Arc::clone(&server_factory)));
158 Ok(Self {
159 client_metadata,
160 keyset,
161 resolver,
162 server_factory,
163 state_store: config.state_store,
164 session_registry,
165 http_client,
166 })
167 }
168}
169
170impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T>
171where
172 S0: StateStore + Send + Sync + 'static,
173 S1: SessionStore + Send + Sync + 'static,
174 D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
175 H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
176 T: HttpClient + Send + Sync + 'static,
177 S0::Error: std::error::Error + Send + Sync + 'static,
178 S1::Error: std::error::Error + Send + Sync + 'static,
179{
180 pub fn jwks(&self) -> JwkSet {
182 self.keyset.as_ref().map(|keyset| keyset.public_jwks()).unwrap_or_default()
183 }
184 pub async fn authorize(
188 &self,
189 input: impl AsRef<str>,
190 options: AuthorizeOptions,
191 ) -> Result<String> {
192 let redirect_uri = if let Some(uri) = options.redirect_uri {
193 if !self.client_metadata.redirect_uris.contains(&uri) {
194 return Err(Error::Authorize("invalid redirect_uri".into()));
195 }
196 uri
197 } else {
198 self.client_metadata.redirect_uris[0].clone()
199 };
200 let (metadata, identity) = self.resolver.resolve(input.as_ref()).await?;
201 let Some(dpop_key) = Self::generate_dpop_key(&metadata) else {
202 return Err(Error::Authorize("none of the algorithms worked".into()));
203 };
204 let (code_challenge, verifier) = Self::generate_pkce();
205 let state = generate_nonce();
206 let state_data = InternalStateData {
207 iss: metadata.issuer.clone(),
208 dpop_key: dpop_key.clone(),
209 verifier,
210 app_state: options.state,
211 };
212 self.state_store
213 .set(state.clone(), state_data)
214 .await
215 .map_err(|e| Error::StateStore(Box::new(e)))?;
216 let login_hint = if identity.is_some() { Some(input.as_ref().into()) } else { None };
217 let parameters = PushedAuthorizationRequestParameters {
218 response_type: AuthorizationResponseType::Code,
219 redirect_uri,
220 state,
221 scope: Some(options.scopes.iter().map(AsRef::as_ref).collect::<Vec<_>>().join(" ")),
222 response_mode: None,
223 code_challenge,
224 code_challenge_method: AuthorizationCodeChallengeMethod::S256,
225 login_hint,
226 prompt: options.prompt.map(String::from),
227 };
228 if metadata.pushed_authorization_request_endpoint.is_some() {
229 let server = self.server_factory.build_from_metadata(dpop_key, metadata.clone())?;
230 let par_response = server
231 .request::<OAuthPusehedAuthorizationRequestResponse>(
232 OAuthRequest::PushedAuthorizationRequest(parameters),
233 )
234 .await?;
235
236 #[derive(Serialize)]
237 struct Parameters {
238 client_id: String,
239 request_uri: String,
240 }
241 Ok(metadata.authorization_endpoint
242 + "?"
243 + &serde_html_form::to_string(Parameters {
244 client_id: self.client_metadata.client_id.clone(),
245 request_uri: par_response.request_uri,
246 })
247 .unwrap())
248 } else if metadata.require_pushed_authorization_requests == Some(true) {
249 Err(Error::Authorize("server requires PAR but no endpoint is available".into()))
250 } else {
251 todo!()
254 }
255 }
256 pub async fn callback(
261 &self,
262 params: CallbackParams,
263 ) -> Result<(OAuthSession<T, D, H, S1>, Option<String>)> {
264 let Some(state_key) = params.state else {
265 return Err(Error::Callback("missing `state` parameter".into()));
266 };
267
268 let Some(state) =
269 self.state_store.get(&state_key).await.map_err(|e| Error::StateStore(Box::new(e)))?
270 else {
271 return Err(Error::Callback(format!("unknown authorization state: {state_key}")));
272 };
273 self.state_store.del(&state_key).await.map_err(|e| Error::StateStore(Box::new(e)))?;
275
276 let metadata = self.resolver.get_authorization_server_metadata(&state.iss).await?;
277 if let Some(iss) = params.iss {
279 if iss != metadata.issuer {
280 return Err(Error::Callback(format!(
281 "issuer mismatch: expected {}, got {iss}",
282 metadata.issuer
283 )));
284 }
285 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
286 return Err(Error::Callback("missing `iss` parameter".into()));
287 }
288 let server =
289 self.server_factory.build_from_metadata(state.dpop_key.clone(), metadata.clone())?;
290 match server.exchange_code(¶ms.code, &state.verifier).await {
291 Ok(token_set) => {
292 let sub = token_set.sub.clone();
293 self.session_registry
294 .set(sub.clone(), Session { dpop_key: state.dpop_key.clone(), token_set })
295 .await
296 .map_err(|e| Error::SessionStore(Box::new(e)))?;
297 Ok((self.create_session(server, &sub).await?, state.app_state))
298 }
299 Err(_) => {
300 todo!()
301 }
302 }
303 }
304 pub async fn restore(&self, sub: &Did) -> Result<OAuthSession<T, D, H, S1>> {
308 let session = self.session_registry.get(sub, false).await?;
311 self.create_session(
312 self.server_factory.build_from_issuer(session.dpop_key, &session.token_set.iss).await?,
313 sub,
314 )
315 .await
316 }
317 pub async fn revoke(&self, sub: &Did) -> Result<()> {
319 let session = self.session_registry.get(sub, false).await?;
320 let server_agent =
321 self.server_factory.build_from_issuer(session.dpop_key, &session.token_set.iss).await?;
322 server_agent.revoke(&session.token_set.access_token).await?;
323 self.session_registry.del(sub).await.map_err(|e| Error::SessionStore(Box::new(e)))
324 }
325 async fn create_session(
326 &self,
327 server: OAuthServerAgent<T, D, H>,
328 sub: &Did,
329 ) -> Result<OAuthSession<T, D, H, S1>> {
330 Ok(OAuthSession::new(
331 server.server_metadata.clone(),
332 sub.clone(),
333 Arc::clone(&self.http_client),
334 Arc::clone(&self.session_registry),
335 )
336 .await?)
337 }
338 fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option<Key> {
339 let mut algs =
340 metadata.dpop_signing_alg_values_supported.clone().unwrap_or(vec![FALLBACK_ALG.into()]);
341 algs.sort_by(compare_algos);
342 generate_key(&algs)
343 }
344 fn generate_pkce() -> (String, String) {
345 let verifier = [generate_nonce(), generate_nonce()].join("");
347 (URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier)
348 }
349}