1mod inner;
2mod store;
3
4use self::store::MemorySessionStore;
5use crate::{
6 http_client::dpop::DpopClient,
7 store::{session::SessionStore, session_registry::SessionRegistry},
8 types::OAuthAuthorizationServerMetadata,
9};
10use atrium_api::{
11 agent::{utils::SessionWithEndpointStore, CloneWithProxy, Configure, SessionManager},
12 did_doc::DidDocument,
13 types::string::{Did, Handle},
14};
15use atrium_common::resolver::Resolver;
16use atrium_xrpc::{
17 http::{Request, Response},
18 HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
19};
20use serde::{de::DeserializeOwned, Serialize};
21use std::{fmt::Debug, sync::Arc};
22use thiserror::Error;
23
24#[derive(Error, Debug)]
25pub enum Error {
26 #[error(transparent)]
27 Dpop(#[from] crate::http_client::dpop::Error),
28 #[error(transparent)]
29 SessionRegistry(#[from] crate::store::session_registry::Error),
30 #[error(transparent)]
31 Store(#[from] atrium_common::store::memory::Error),
32}
33
34pub struct OAuthSession<T, D, H, S>
35where
36 T: HttpClient + Send + Sync + 'static,
37 S: SessionStore + Send + Sync + 'static,
38{
39 store: Arc<SessionWithEndpointStore<store::MemorySessionStore, String>>,
40 inner: inner::Client<S, T, D, H>,
41 sub: Did,
42 session_registry: Arc<SessionRegistry<S, T, D, H>>,
43}
44
45impl<T, D, H, S> OAuthSession<T, D, H, S>
46where
47 T: HttpClient + Send + Sync,
48 D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
49 H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
50 S: SessionStore + Send + Sync + 'static,
51{
52 pub(crate) async fn new(
53 server_metadata: OAuthAuthorizationServerMetadata,
54 sub: Did,
55 http_client: Arc<T>,
56 session_registry: Arc<SessionRegistry<S, T, D, H>>,
57 ) -> Result<Self, Error> {
58 let (dpop_key, token_set) = {
60 let s = session_registry.get(&sub, false).await?;
61 (s.dpop_key.clone(), s.token_set.clone())
62 };
63 let store = Arc::new(SessionWithEndpointStore::new(
64 MemorySessionStore::default(),
65 token_set.aud.clone(),
66 ));
67 store.set(token_set.access_token.clone()).await?;
68 let inner = inner::Client::new(
70 Arc::clone(&store),
71 DpopClient::new(
72 dpop_key,
73 http_client,
74 false,
75 &server_metadata.token_endpoint_auth_signing_alg_values_supported,
76 )?,
77 sub.clone(),
78 Arc::clone(&session_registry),
79 );
80 Ok(Self { store, inner, sub, session_registry })
81 }
82}
83
84impl<T, D, H, S> HttpClient for OAuthSession<T, D, H, S>
85where
86 T: HttpClient + Send + Sync + 'static,
87 D: Send + Sync,
88 H: Send + Sync,
89 S: SessionStore + Send + Sync,
90{
91 async fn send_http(
92 &self,
93 request: Request<Vec<u8>>,
94 ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
95 self.inner.send_http(request).await
96 }
97}
98
99impl<T, D, H, S> XrpcClient for OAuthSession<T, D, H, S>
100where
101 T: HttpClient + Send + Sync + 'static,
102 D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
103 H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
104 S: SessionStore + Send + Sync + 'static,
105{
106 fn base_uri(&self) -> String {
107 self.inner.base_uri()
108 }
109 async fn send_xrpc<P, I, O, E>(
110 &self,
111 request: &XrpcRequest<P, I>,
112 ) -> Result<OutputDataOrBytes<O>, atrium_xrpc::Error<E>>
113 where
114 P: Serialize + Send + Sync,
115 I: Serialize + Send + Sync,
116 O: DeserializeOwned + Send + Sync,
117 E: DeserializeOwned + Send + Sync + Debug,
118 {
119 self.inner.send_xrpc(request).await
120 }
121}
122
123impl<T, D, H, S> SessionManager for OAuthSession<T, D, H, S>
124where
125 T: HttpClient + Send + Sync + 'static,
126 D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
127 H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
128 S: SessionStore + Send + Sync + 'static,
129{
130 async fn did(&self) -> Option<Did> {
131 Some(self.sub.clone())
132 }
133}
134
135impl<T, D, H, S> Configure for OAuthSession<T, D, H, S>
136where
137 T: HttpClient + Send + Sync,
138 S: SessionStore + Send + Sync,
139{
140 fn configure_endpoint(&self, endpoint: String) {
141 self.inner.configure_endpoint(endpoint);
142 }
143 fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
144 self.inner.configure_labelers_header(labeler_dids);
145 }
146 fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
147 self.inner.configure_proxy_header(did, service_type);
148 }
149}
150
151impl<T, D, H, S> CloneWithProxy for OAuthSession<T, D, H, S>
152where
153 T: HttpClient + Send + Sync,
154 S: SessionStore + Send + Sync,
155{
156 fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
157 Self {
158 store: self.store.clone(),
159 inner: self.inner.clone_with_proxy(did, service_type),
160 sub: self.sub.clone(),
161 session_registry: Arc::clone(&self.session_registry),
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server_agent::OAuthServerFactory;
    use crate::tests::{
        client_metadata, dpop_key, oauth_resolver, protected_resource_metadata, server_metadata,
        MockDidResolver, NoopHandleResolver,
    };
    use crate::{
        jose::jwt::Claims,
        store::session::Session,
        types::{OAuthTokenResponse, OAuthTokenType, RefreshRequestParameters, TokenSet},
    };
    use atrium_api::{
        agent::{Agent, AtprotoServiceType},
        client::Service,
        xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, StatusCode},
    };
    use atrium_common::store::Store;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
    use std::{collections::HashMap, time::Duration};
    use tokio::sync::Mutex;

    /// What the mock HTTP client saw on the most recent authorized request.
    #[derive(Default)]
    struct RecordData {
        /// Host of the request URI.
        host: Option<String>,
        /// Request headers, with `authorization` and `dpop` removed before recording.
        headers: HeaderMap<HeaderValue>,
    }

    /// Fake server. It serves:
    /// - the two discovery documents;
    /// - a `/token` refresh endpoint;
    /// - a canned `getServiceAuth` response to every other DPoP-authorized request.
    struct MockHttpClient {
        /// Slot where the last authorized request is recorded.
        data: Arc<Mutex<Option<RecordData>>>,
        /// Token response returned by `/token`. It is taken on first use, so at most
        /// one refresh can ever succeed.
        next_token: Arc<Mutex<Option<OAuthTokenResponse>>>,
    }

    impl MockHttpClient {
        /// Creates a mock that records into `data` and will grant exactly one refresh.
        fn new(data: Arc<Mutex<Option<RecordData>>>) -> Self {
            Self {
                data,
                next_token: Arc::new(Mutex::new(Some(OAuthTokenResponse {
                    access_token: String::from("new_accesstoken"),
                    token_type: OAuthTokenType::DPoP,
                    expires_in: Some(10),
                    refresh_token: Some(String::from("new_refreshtoken")),
                    scope: None,
                    sub: None,
                }))),
            }
        }
    }

    impl HttpClient for MockHttpClient {
        async fn send_http(
            &self,
            request: Request<Vec<u8>>,
        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
            // Introduce an await point so concurrently spawned callers can interleave.
            tokio::time::sleep(Duration::from_micros(0)).await;

            // Discovery documents are served without any authorization.
            match (request.uri().host(), request.uri().path()) {
                (Some("iss.example.com"), "/.well-known/oauth-authorization-server") => {
                    return Response::builder()
                        .header(CONTENT_TYPE, "application/json")
                        .body(serde_json::to_vec(&server_metadata())?)
                        .map_err(|e| e.into());
                }
                (Some("aud.example.com"), "/.well-known/oauth-protected-resource") => {
                    return Response::builder()
                        .header(CONTENT_TYPE, "application/json")
                        .body(serde_json::to_vec(&protected_resource_metadata())?)
                        .map_err(|e| e.into());
                }
                _ => {}
            }

            let mut headers = request.headers().clone();
            let Some(authorization) = headers
                .remove("authorization")
                .and_then(|value| value.to_str().map(String::from).ok())
            else {
                // Unauthenticated request. Only a refresh using the original refresh
                // token can succeed, and only once (`next_token` is taken). Anything
                // else is rejected with 401.
                let response = if request.uri().path() == "/token" {
                    let parameters =
                        serde_html_form::from_bytes::<RefreshRequestParameters>(request.body())?;
                    let token_response = if parameters.refresh_token == "refreshtoken" {
                        self.next_token.lock().await.take()
                    } else {
                        None
                    };
                    if let Some(token_response) = token_response {
                        Response::builder()
                            .status(StatusCode::OK)
                            .header(CONTENT_TYPE, "application/json")
                            .body(serde_json::to_vec(&token_response)?)?
                    } else {
                        Response::builder()
                            .status(StatusCode::UNAUTHORIZED)
                            .header("WWW-Authenticate", "DPoP error=\"invalid_token\"")
                            .body(Vec::new())?
                    }
                } else {
                    Response::builder().status(StatusCode::UNAUTHORIZED).body(Vec::new())?
                };
                return Ok(response);
            };
            let Some(token) = authorization.strip_prefix("DPoP ") else {
                panic!("authorization header should start with DPoP");
            };
            // The literal access token "expired" simulates an expired token.
            if token == "expired" {
                return Ok(Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .header("WWW-Authenticate", "DPoP error=\"invalid_token\"")
                    .body(Vec::new())?);
            }
            // Sanity-check the DPoP proof: it must carry iat/jti and bind the method and URI.
            let dpop_jwt = headers.remove("dpop").expect("dpop header should be present");
            let payload = dpop_jwt
                .to_str()
                .expect("dpop header should be valid")
                .split('.')
                .nth(1)
                .expect("dpop header should have 2 parts");
            let claims = URL_SAFE_NO_PAD
                .decode(payload)
                .ok()
                .and_then(|value| serde_json::from_slice::<Claims>(&value).ok())
                .expect("dpop payload should be valid");
            assert!(claims.registered.iat.is_some());
            assert!(claims.registered.jti.is_some());
            assert_eq!(claims.public.htm, Some(request.method().to_string()));
            assert_eq!(claims.public.htu, Some(request.uri().to_string()));

            self.data
                .lock()
                .await
                .replace(RecordData { host: request.uri().host().map(String::from), headers });
            let output = atrium_api::com::atproto::server::get_service_auth::OutputData {
                token: String::from("fake_token"),
            };
            Response::builder()
                .header(CONTENT_TYPE, "application/json")
                .body(serde_json::to_vec(&output)?)
                .map_err(|e| e.into())
        }
    }

    /// Session store backed by a shared map, so tests can inspect what was persisted.
    struct MockSessionStore {
        data: Arc<Mutex<HashMap<Did, Session>>>,
    }

    impl Store<Did, Session> for MockSessionStore {
        type Error = Error;

        async fn get(&self, key: &Did) -> Result<Option<Session>, Self::Error> {
            tokio::time::sleep(Duration::from_micros(10)).await;
            Ok(self.data.lock().await.get(key).cloned())
        }
        async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> {
            tokio::time::sleep(Duration::from_micros(10)).await;
            self.data.lock().await.insert(key, value);
            Ok(())
        }
        async fn del(&self, _: &Did) -> Result<(), Self::Error> {
            unimplemented!()
        }
        async fn clear(&self) -> Result<(), Self::Error> {
            unimplemented!()
        }
    }

    impl SessionStore for MockSessionStore {}

    /// The subject DID used throughout these tests.
    fn did() -> Did {
        Did::new(String::from("did:fake:sub.test")).expect("did should be valid")
    }

    /// A store holding one session for `did()`. Its access token is "accesstoken" and
    /// its refresh token is "refreshtoken".
    fn default_store() -> Arc<Mutex<HashMap<Did, Session>>> {
        let did = did();
        let token_set = TokenSet {
            iss: String::from("https://iss.example.com"),
            sub: did.clone(),
            aud: String::from("https://aud.example.com"),
            scope: None,
            refresh_token: Some(String::from("refreshtoken")),
            access_token: String::from("accesstoken"),
            token_type: OAuthTokenType::DPoP,
            expires_at: None,
        };
        let dpop_key = dpop_key();
        let session = Session { token_set, dpop_key };
        Arc::new(Mutex::new(HashMap::from([(did, session)])))
    }

    /// Builds an `OAuthSession` wired to the mock HTTP client and the given session map.
    async fn oauth_session(
        data: Arc<Mutex<Option<RecordData>>>,
        store: Arc<Mutex<HashMap<Did, Session>>>,
    ) -> OAuthSession<MockHttpClient, MockDidResolver, NoopHandleResolver, MockSessionStore> {
        let http_client = Arc::new(MockHttpClient::new(data));
        let resolver = Arc::new(oauth_resolver(Arc::clone(&http_client)));
        let server_factory = Arc::new(OAuthServerFactory::new(
            client_metadata(),
            resolver,
            Arc::clone(&http_client),
            None,
        ));
        let session_registry = Arc::new(SessionRegistry::new(
            MockSessionStore { data: Arc::clone(&store) },
            server_factory,
        ));
        OAuthSession::new(server_metadata(), did(), http_client, session_registry)
            .await
            .expect("failed to create oauth session")
    }

    /// Wraps a fresh session over `default_store()` in an `Agent`.
    async fn oauth_agent(
        data: Arc<Mutex<Option<RecordData>>>,
    ) -> Agent<impl SessionManager + Configure + CloneWithProxy> {
        Agent::new(oauth_session(data, default_store()).await)
    }

    /// Calls `com.atproto.server.getServiceAuth` through `service`. On success, asserts
    /// that the mock's canned token came back; errors are returned to the caller.
    async fn call_service(
        service: &Service<impl SessionManager + Sync>,
    ) -> Result<(), atrium_xrpc::Error<atrium_api::com::atproto::server::get_service_auth::Error>>
    {
        let output = service
            .com
            .atproto
            .server
            .get_service_auth(
                atrium_api::com::atproto::server::get_service_auth::ParametersData {
                    aud: Did::new(String::from("did:fake:handle.test"))
                        .expect("did should be valid"),
                    exp: None,
                    lxm: None,
                }
                .into(),
            )
            .await?;
        assert_eq!(output.token, "fake_token");
        Ok(())
    }

    #[tokio::test]
    async fn test_new() -> Result<(), Box<dyn std::error::Error>> {
        let agent = oauth_agent(Default::default()).await;
        assert_eq!(agent.did().await.as_deref(), Some("did:fake:sub.test"));
        Ok(())
    }

    #[tokio::test]
    async fn test_configure_endpoint() -> Result<(), Box<dyn std::error::Error>> {
        let data = Default::default();
        let agent = oauth_agent(Arc::clone(&data)).await;
        // Initially requests go to the token set's `aud`.
        call_service(&agent.api).await?;
        assert_eq!(
            data.lock().await.as_ref().expect("data should be recorded").host.as_deref(),
            Some("aud.example.com")
        );
        agent.configure_endpoint(String::from("https://pds.example.com"));
        call_service(&agent.api).await?;
        assert_eq!(
            data.lock().await.as_ref().expect("data should be recorded").host.as_deref(),
            Some("pds.example.com")
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_configure_labelers_header() -> Result<(), Box<dyn std::error::Error>> {
        let data = Default::default();
        let agent = oauth_agent(Arc::clone(&data)).await;
        // Not configured: no extra headers.
        {
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::new()
            );
        }
        // A single labeler.
        {
            agent.configure_labelers_header(Some(vec![(
                Did::new(String::from("did:fake:labeler.test"))?,
                false,
            )]));
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-accept-labelers"),
                    HeaderValue::from_static("did:fake:labeler.test"),
                )])
            );
        }
        // Multiple labelers, one with redaction.
        {
            agent.configure_labelers_header(Some(vec![
                (Did::new(String::from("did:fake:labeler.test_redact"))?, true),
                (Did::new(String::from("did:fake:labeler.test"))?, false),
            ]));
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-accept-labelers"),
                    HeaderValue::from_static(
                        "did:fake:labeler.test_redact;redact, did:fake:labeler.test"
                    ),
                )])
            );
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_configure_proxy_header() -> Result<(), Box<dyn std::error::Error>> {
        let data = Default::default();
        let agent = oauth_agent(Arc::clone(&data)).await;
        // Not configured: no extra headers.
        {
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::new()
            );
        }
        // Well-known service type.
        {
            agent.configure_proxy_header(
                Did::new(String::from("did:fake:service.test"))?,
                AtprotoServiceType::AtprotoLabeler,
            );
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-proxy"),
                    HeaderValue::from_static("did:fake:service.test#atproto_labeler"),
                )])
            );
        }
        // Custom service type string.
        {
            agent.configure_proxy_header(
                Did::new(String::from("did:fake:service.test"))?,
                "custom_service",
            );
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-proxy"),
                    HeaderValue::from_static("did:fake:service.test#custom_service"),
                )])
            );
        }
        // A temporary proxy via `api_with_proxy` must not affect the agent's own setting.
        {
            call_service(
                &agent.api_with_proxy(
                    Did::new(String::from("did:fake:service.test"))?,
                    "temp_service",
                ),
            )
            .await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-proxy"),
                    HeaderValue::from_static("did:fake:service.test#temp_service"),
                )])
            );
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-proxy"),
                    HeaderValue::from_static("did:fake:service.test#custom_service"),
                )])
            );
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_xrpc_without_token() -> Result<(), Box<dyn std::error::Error>> {
        let oauth_session = oauth_session(Default::default(), default_store()).await;
        oauth_session.store.clear().await?;
        let agent = Agent::new(oauth_session);
        match call_service(&agent.api).await.expect_err("should fail without token") {
            atrium_xrpc::Error::XrpcResponse(err) => {
                assert_eq!(err.status, StatusCode::UNAUTHORIZED);
            }
            _ => panic!("unexpected error"),
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_xrpc_with_refresh() -> Result<(), Box<dyn std::error::Error>> {
        let session_data = default_store();
        if let Some(session) = session_data.lock().await.get_mut(&did()) {
            session.token_set.access_token = String::from("expired");
        }
        let oauth_session = oauth_session(Default::default(), Arc::clone(&session_data)).await;
        let agent = Agent::new(oauth_session);
        call_service(&agent.api).await?;
        // Yield once before inspecting the store, in case the refreshed session is
        // persisted asynchronously.
        tokio::time::sleep(Duration::from_micros(0)).await;
        {
            let token_set = session_data
                .lock()
                .await
                .get(&did())
                .expect("session should be present")
                .token_set
                .clone();
            assert_eq!(token_set.access_token, "new_accesstoken");
            assert_eq!(token_set.refresh_token, Some(String::from("new_refreshtoken")));
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_xrpc_with_duplicated_refresh() -> Result<(), Box<dyn std::error::Error>> {
        let session_data = default_store();
        if let Some(session) = session_data.lock().await.get_mut(&did()) {
            session.token_set.access_token = String::from("expired");
        }
        let oauth_session = oauth_session(Default::default(), session_data).await;
        let agent = Arc::new(Agent::new(oauth_session));

        // Three concurrent requests all start with the expired token. The mock grants
        // a single refresh and answers any second refresh attempt with 401, so every
        // call succeeding means the refresh was not duplicated.
        let handles = (0..3).map(|_| {
            let agent = Arc::clone(&agent);
            tokio::spawn(async move { call_service(&agent.api).await })
        });
        for result in futures::future::join_all(handles).await {
            result??;
        }
        Ok(())
    }
}