atrium_oauth/
oauth_session.rs

1mod inner;
2mod store;
3
4use self::store::MemorySessionStore;
5use crate::{
6    http_client::dpop::DpopClient,
7    store::{session::SessionStore, session_registry::SessionRegistry},
8    types::OAuthAuthorizationServerMetadata,
9};
10use atrium_api::{
11    agent::{utils::SessionWithEndpointStore, CloneWithProxy, Configure, SessionManager},
12    did_doc::DidDocument,
13    types::string::{Did, Handle},
14};
15use atrium_common::resolver::Resolver;
16use atrium_xrpc::{
17    http::{Request, Response},
18    HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
19};
20use serde::{de::DeserializeOwned, Serialize};
21use std::{fmt::Debug, sync::Arc};
22use thiserror::Error;
23
24#[derive(Error, Debug)]
25pub enum Error {
26    #[error(transparent)]
27    Dpop(#[from] crate::http_client::dpop::Error),
28    #[error(transparent)]
29    SessionRegistry(#[from] crate::store::session_registry::Error),
30    #[error(transparent)]
31    Store(#[from] atrium_common::store::memory::Error),
32}
33
34pub struct OAuthSession<T, D, H, S>
35where
36    T: HttpClient + Send + Sync + 'static,
37    S: SessionStore + Send + Sync + 'static,
38{
39    store: Arc<SessionWithEndpointStore<store::MemorySessionStore, String>>,
40    inner: inner::Client<S, T, D, H>,
41    sub: Did,
42    session_registry: Arc<SessionRegistry<S, T, D, H>>,
43}
44
45impl<T, D, H, S> OAuthSession<T, D, H, S>
46where
47    T: HttpClient + Send + Sync,
48    D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
49    H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
50    S: SessionStore + Send + Sync + 'static,
51{
52    pub(crate) async fn new(
53        server_metadata: OAuthAuthorizationServerMetadata,
54        sub: Did,
55        http_client: Arc<T>,
56        session_registry: Arc<SessionRegistry<S, T, D, H>>,
57    ) -> Result<Self, Error> {
58        // initialize SessionWithEndpointStore
59        let (dpop_key, token_set) = {
60            let s = session_registry.get(&sub, false).await?;
61            (s.dpop_key.clone(), s.token_set.clone())
62        };
63        let store = Arc::new(SessionWithEndpointStore::new(
64            MemorySessionStore::default(),
65            token_set.aud.clone(),
66        ));
67        store.set(token_set.access_token.clone()).await?;
68        // initialize inner client
69        let inner = inner::Client::new(
70            Arc::clone(&store),
71            DpopClient::new(
72                dpop_key,
73                http_client,
74                false,
75                &server_metadata.token_endpoint_auth_signing_alg_values_supported,
76            )?,
77            sub.clone(),
78            Arc::clone(&session_registry),
79        );
80        Ok(Self { store, inner, sub, session_registry })
81    }
82}
83
84impl<T, D, H, S> HttpClient for OAuthSession<T, D, H, S>
85where
86    T: HttpClient + Send + Sync + 'static,
87    D: Send + Sync,
88    H: Send + Sync,
89    S: SessionStore + Send + Sync,
90{
91    async fn send_http(
92        &self,
93        request: Request<Vec<u8>>,
94    ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
95        self.inner.send_http(request).await
96    }
97}
98
99impl<T, D, H, S> XrpcClient for OAuthSession<T, D, H, S>
100where
101    T: HttpClient + Send + Sync + 'static,
102    D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
103    H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
104    S: SessionStore + Send + Sync + 'static,
105{
106    fn base_uri(&self) -> String {
107        self.inner.base_uri()
108    }
109    async fn send_xrpc<P, I, O, E>(
110        &self,
111        request: &XrpcRequest<P, I>,
112    ) -> Result<OutputDataOrBytes<O>, atrium_xrpc::Error<E>>
113    where
114        P: Serialize + Send + Sync,
115        I: Serialize + Send + Sync,
116        O: DeserializeOwned + Send + Sync,
117        E: DeserializeOwned + Send + Sync + Debug,
118    {
119        self.inner.send_xrpc(request).await
120    }
121}
122
123impl<T, D, H, S> SessionManager for OAuthSession<T, D, H, S>
124where
125    T: HttpClient + Send + Sync + 'static,
126    D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
127    H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
128    S: SessionStore + Send + Sync + 'static,
129{
130    async fn did(&self) -> Option<Did> {
131        Some(self.sub.clone())
132    }
133}
134
135impl<T, D, H, S> Configure for OAuthSession<T, D, H, S>
136where
137    T: HttpClient + Send + Sync,
138    S: SessionStore + Send + Sync,
139{
140    fn configure_endpoint(&self, endpoint: String) {
141        self.inner.configure_endpoint(endpoint);
142    }
143    fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
144        self.inner.configure_labelers_header(labeler_dids);
145    }
146    fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
147        self.inner.configure_proxy_header(did, service_type);
148    }
149}
150
151impl<T, D, H, S> CloneWithProxy for OAuthSession<T, D, H, S>
152where
153    T: HttpClient + Send + Sync,
154    S: SessionStore + Send + Sync,
155{
156    fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
157        Self {
158            store: self.store.clone(),
159            inner: self.inner.clone_with_proxy(did, service_type),
160            sub: self.sub.clone(),
161            session_registry: Arc::clone(&self.session_registry),
162        }
163    }
164}
165
166#[cfg(test)]
167mod tests {
168    use super::*;
169    use crate::server_agent::OAuthServerFactory;
170    use crate::tests::{
171        client_metadata, dpop_key, oauth_resolver, protected_resource_metadata, server_metadata,
172        MockDidResolver, NoopHandleResolver,
173    };
174    use crate::{
175        jose::jwt::Claims,
176        store::session::Session,
177        types::{OAuthTokenResponse, OAuthTokenType, RefreshRequestParameters, TokenSet},
178    };
179    use atrium_api::{
180        agent::{Agent, AtprotoServiceType},
181        client::Service,
182        xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, StatusCode},
183    };
184    use atrium_common::store::Store;
185    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
186    use std::{collections::HashMap, time::Duration};
187    use tokio::sync::Mutex;
188
189    #[derive(Default)]
190    struct RecordData {
191        host: Option<String>,
192        headers: HeaderMap<HeaderValue>,
193    }
194
195    struct MockHttpClient {
196        data: Arc<Mutex<Option<RecordData>>>,
197        next_token: Arc<Mutex<Option<OAuthTokenResponse>>>,
198    }
199
200    impl MockHttpClient {
201        fn new(data: Arc<Mutex<Option<RecordData>>>) -> Self {
202            Self {
203                data,
204                next_token: Arc::new(Mutex::new(Some(OAuthTokenResponse {
205                    access_token: String::from("new_accesstoken"),
206                    token_type: OAuthTokenType::DPoP,
207                    expires_in: Some(10),
208                    refresh_token: Some(String::from("new_refreshtoken")),
209                    scope: None,
210                    sub: None,
211                }))),
212            }
213        }
214    }
215
216    impl HttpClient for MockHttpClient {
217        async fn send_http(
218            &self,
219            request: Request<Vec<u8>>,
220        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
221            // tick tokio time
222            tokio::time::sleep(std::time::Duration::from_micros(0)).await;
223
224            match (request.uri().host(), request.uri().path()) {
225                (Some("iss.example.com"), "/.well-known/oauth-authorization-server") => {
226                    return Response::builder()
227                        .header(CONTENT_TYPE, "application/json")
228                        .body(serde_json::to_vec(&server_metadata())?)
229                        .map_err(|e| e.into());
230                }
231                (Some("aud.example.com"), "/.well-known/oauth-protected-resource") => {
232                    return Response::builder()
233                        .header(CONTENT_TYPE, "application/json")
234                        .body(serde_json::to_vec(&protected_resource_metadata())?)
235                        .map_err(|e| e.into());
236                }
237                _ => {}
238            }
239
240            let mut headers = request.headers().clone();
241            let Some(authorization) = headers
242                .remove("authorization")
243                .and_then(|value| value.to_str().map(String::from).ok())
244            else {
245                let response = if request.uri().path() == "/token" {
246                    let parameters =
247                        serde_html_form::from_bytes::<RefreshRequestParameters>(request.body())?;
248                    let token_response = if parameters.refresh_token == "refreshtoken" {
249                        self.next_token.lock().await.take()
250                    } else {
251                        None
252                    };
253                    if let Some(token_response) = token_response {
254                        Response::builder()
255                            .status(StatusCode::OK)
256                            .header(CONTENT_TYPE, "application/json")
257                            .body(serde_json::to_vec(&token_response)?)?
258                    } else {
259                        Response::builder()
260                            .status(StatusCode::UNAUTHORIZED)
261                            .header("WWW-Authenticate", "DPoP error=\"invalid_token\"")
262                            .body(Vec::new())?
263                    }
264                } else {
265                    Response::builder().status(StatusCode::UNAUTHORIZED).body(Vec::new())?
266                };
267                return Ok(response);
268            };
269            let Some(token) = authorization.strip_prefix("DPoP ") else {
270                panic!("authorization header should start with DPoP");
271            };
272            if token == "expired" {
273                return Ok(Response::builder()
274                    .status(StatusCode::UNAUTHORIZED)
275                    .header("WWW-Authenticate", "DPoP error=\"invalid_token\"")
276                    .body(Vec::new())?);
277            }
278            let dpop_jwt = headers.remove("dpop").expect("dpop header should be present");
279            let payload = dpop_jwt
280                .to_str()
281                .expect("dpop header should be valid")
282                .split('.')
283                .nth(1)
284                .expect("dpop header should have 2 parts");
285            let claims = URL_SAFE_NO_PAD
286                .decode(payload)
287                .ok()
288                .and_then(|value| serde_json::from_slice::<Claims>(&value).ok())
289                .expect("dpop payload should be valid");
290            assert!(claims.registered.iat.is_some());
291            assert!(claims.registered.jti.is_some());
292            assert_eq!(claims.public.htm, Some(request.method().to_string()));
293            assert_eq!(claims.public.htu, Some(request.uri().to_string()));
294
295            self.data
296                .lock()
297                .await
298                .replace(RecordData { host: request.uri().host().map(String::from), headers });
299            let output = atrium_api::com::atproto::server::get_service_auth::OutputData {
300                token: String::from("fake_token"),
301            };
302            Response::builder()
303                .header(CONTENT_TYPE, "application/json")
304                .body(serde_json::to_vec(&output)?)
305                .map_err(|e| e.into())
306        }
307    }
308
309    struct MockSessionStore {
310        data: Arc<Mutex<HashMap<Did, Session>>>,
311    }
312
313    impl Store<Did, Session> for MockSessionStore {
314        type Error = Error;
315
316        async fn get(&self, key: &Did) -> Result<Option<Session>, Self::Error> {
317            tokio::time::sleep(Duration::from_micros(10)).await;
318            Ok(self.data.lock().await.get(key).cloned())
319        }
320        async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> {
321            tokio::time::sleep(Duration::from_micros(10)).await;
322            self.data.lock().await.insert(key, value);
323            Ok(())
324        }
325        async fn del(&self, _: &Did) -> Result<(), Self::Error> {
326            unimplemented!()
327        }
328        async fn clear(&self) -> Result<(), Self::Error> {
329            unimplemented!()
330        }
331    }
332
333    impl SessionStore for MockSessionStore {}
334
335    fn did() -> Did {
336        Did::new(String::from("did:fake:sub.test")).expect("did should be valid")
337    }
338
339    fn default_store() -> Arc<Mutex<HashMap<Did, Session>>> {
340        let did = did();
341        let token_set = TokenSet {
342            iss: String::from("https://iss.example.com"),
343            sub: did.clone(),
344            aud: String::from("https://aud.example.com"),
345            scope: None,
346            refresh_token: Some(String::from("refreshtoken")),
347            access_token: String::from("accesstoken"),
348            token_type: OAuthTokenType::DPoP,
349            expires_at: None,
350        };
351        let dpop_key = dpop_key();
352        let session = Session { token_set, dpop_key };
353        Arc::new(Mutex::new(HashMap::from_iter([(did, session)])))
354    }
355
356    async fn oauth_session(
357        data: Arc<Mutex<Option<RecordData>>>,
358        store: Arc<Mutex<HashMap<Did, Session>>>,
359    ) -> OAuthSession<MockHttpClient, MockDidResolver, NoopHandleResolver, MockSessionStore> {
360        let http_client = Arc::new(MockHttpClient::new(data));
361        let resolver = Arc::new(oauth_resolver(Arc::clone(&http_client)));
362        let server_factory = Arc::new(OAuthServerFactory::new(
363            client_metadata(),
364            resolver,
365            Arc::clone(&http_client),
366            None,
367        ));
368        let session_registory = Arc::new(SessionRegistry::new(
369            MockSessionStore { data: Arc::clone(&store) },
370            server_factory,
371        ));
372        OAuthSession::new(server_metadata(), did(), http_client, session_registory)
373            .await
374            .expect("failed to create oauth session")
375    }
376
377    async fn oauth_agent(
378        data: Arc<Mutex<Option<RecordData>>>,
379    ) -> Agent<impl SessionManager + Configure + CloneWithProxy> {
380        Agent::new(oauth_session(data, default_store()).await)
381    }
382
383    async fn call_service(
384        service: &Service<impl SessionManager + Sync>,
385    ) -> Result<(), atrium_xrpc::Error<atrium_api::com::atproto::server::get_service_auth::Error>>
386    {
387        let output = service
388            .com
389            .atproto
390            .server
391            .get_service_auth(
392                atrium_api::com::atproto::server::get_service_auth::ParametersData {
393                    aud: Did::new(String::from("did:fake:handle.test"))
394                        .expect("did should be valid"),
395                    exp: None,
396                    lxm: None,
397                }
398                .into(),
399            )
400            .await?;
401        assert_eq!(output.token, "fake_token");
402        Ok(())
403    }
404
405    #[tokio::test]
406    async fn test_new() -> Result<(), Box<dyn std::error::Error>> {
407        let agent = oauth_agent(Default::default()).await;
408        assert_eq!(agent.did().await.as_deref(), Some("did:fake:sub.test"));
409        Ok(())
410    }
411
412    #[tokio::test]
413    async fn test_configure_endpoint() -> Result<(), Box<dyn std::error::Error>> {
414        let data = Default::default();
415        let agent = oauth_agent(Arc::clone(&data)).await;
416        call_service(&agent.api).await?;
417        assert_eq!(
418            data.lock().await.as_ref().expect("data should be recorded").host.as_deref(),
419            Some("aud.example.com")
420        );
421        agent.configure_endpoint(String::from("https://pds.example.com"));
422        call_service(&agent.api).await?;
423        assert_eq!(
424            data.lock().await.as_ref().expect("data should be recorded").host.as_deref(),
425            Some("pds.example.com")
426        );
427        Ok(())
428    }
429
430    #[tokio::test]
431    async fn test_configure_labelers_header() -> Result<(), Box<dyn std::error::Error>> {
432        let data = Default::default();
433        let agent = oauth_agent(Arc::clone(&data)).await;
434        // not configured
435        {
436            call_service(&agent.api).await?;
437            assert_eq!(
438                data.lock().await.as_ref().expect("data should be recorded").headers,
439                HeaderMap::new()
440            );
441        }
442        // configured 1
443        {
444            agent.configure_labelers_header(Some(vec![(
445                Did::new(String::from("did:fake:labeler.test"))?,
446                false,
447            )]));
448            call_service(&agent.api).await?;
449            assert_eq!(
450                data.lock().await.as_ref().expect("data should be recorded").headers,
451                HeaderMap::from_iter([(
452                    HeaderName::from_static("atproto-accept-labelers"),
453                    HeaderValue::from_static("did:fake:labeler.test"),
454                )])
455            );
456        }
457        // configured 2
458        {
459            agent.configure_labelers_header(Some(vec![
460                (Did::new(String::from("did:fake:labeler.test_redact"))?, true),
461                (Did::new(String::from("did:fake:labeler.test"))?, false),
462            ]));
463            call_service(&agent.api).await?;
464            assert_eq!(
465                data.lock().await.as_ref().expect("data should be recorded").headers,
466                HeaderMap::from_iter([(
467                    HeaderName::from_static("atproto-accept-labelers"),
468                    HeaderValue::from_static(
469                        "did:fake:labeler.test_redact;redact, did:fake:labeler.test"
470                    ),
471                )])
472            );
473        }
474        Ok(())
475    }
476
477    #[tokio::test]
478    async fn test_configure_proxy_header() -> Result<(), Box<dyn std::error::Error>> {
479        let data = Arc::new(Mutex::new(Default::default()));
480        let agent = oauth_agent(Arc::clone(&data)).await;
481        // not configured
482        {
483            call_service(&agent.api).await?;
484            assert_eq!(
485                data.lock().await.as_ref().expect("data should be recorded").headers,
486                HeaderMap::new()
487            );
488        }
489        // labeler service
490        {
491            agent.configure_proxy_header(
492                Did::new(String::from("did:fake:service.test"))?,
493                AtprotoServiceType::AtprotoLabeler,
494            );
495            call_service(&agent.api).await?;
496            assert_eq!(
497                data.lock().await.as_ref().expect("data should be recorded").headers,
498                HeaderMap::from_iter([(
499                    HeaderName::from_static("atproto-proxy"),
500                    HeaderValue::from_static("did:fake:service.test#atproto_labeler"),
501                )])
502            );
503        }
504        // custom service
505        {
506            agent.configure_proxy_header(
507                Did::new(String::from("did:fake:service.test"))?,
508                "custom_service",
509            );
510            call_service(&agent.api).await?;
511            assert_eq!(
512                data.lock().await.as_ref().expect("data should be recorded").headers,
513                HeaderMap::from_iter([(
514                    HeaderName::from_static("atproto-proxy"),
515                    HeaderValue::from_static("did:fake:service.test#custom_service"),
516                )])
517            );
518        }
519        // api_with_proxy
520        {
521            call_service(
522                &agent.api_with_proxy(
523                    Did::new(String::from("did:fake:service.test"))?,
524                    "temp_service",
525                ),
526            )
527            .await?;
528            assert_eq!(
529                data.lock().await.as_ref().expect("data should be recorded").headers,
530                HeaderMap::from_iter([(
531                    HeaderName::from_static("atproto-proxy"),
532                    HeaderValue::from_static("did:fake:service.test#temp_service"),
533                )])
534            );
535            call_service(&agent.api).await?;
536            assert_eq!(
537                data.lock().await.as_ref().expect("data should be recorded").headers,
538                HeaderMap::from_iter([(
539                    HeaderName::from_static("atproto-proxy"),
540                    HeaderValue::from_static("did:fake:service.test#custom_service"),
541                )])
542            );
543        }
544        Ok(())
545    }
546
547    #[tokio::test]
548    async fn test_xrpc_without_token() -> Result<(), Box<dyn std::error::Error>> {
549        let oauth_session = oauth_session(Default::default(), default_store()).await;
550        oauth_session.store.clear().await?;
551        let agent = Agent::new(oauth_session);
552        let result = agent
553            .api
554            .com
555            .atproto
556            .server
557            .get_service_auth(
558                atrium_api::com::atproto::server::get_service_auth::ParametersData {
559                    aud: Did::new(String::from("did:fake:handle.test"))
560                        .expect("did should be valid"),
561                    exp: None,
562                    lxm: None,
563                }
564                .into(),
565            )
566            .await;
567        match result.expect_err("should fail without token") {
568            atrium_xrpc::Error::XrpcResponse(err) => {
569                assert_eq!(err.status, StatusCode::UNAUTHORIZED);
570            }
571            _ => panic!("unexpected error"),
572        }
573        Ok(())
574    }
575
576    #[tokio::test]
577    async fn test_xrpc_with_refresh() -> Result<(), Box<dyn std::error::Error>> {
578        let session_data = default_store();
579        if let Some(session) = session_data.lock().await.get_mut(&did()) {
580            session.token_set.access_token = String::from("expired");
581        }
582        let oauth_session = oauth_session(Default::default(), Arc::clone(&session_data)).await;
583        let agent = Agent::new(oauth_session);
584        let result = agent
585            .api
586            .com
587            .atproto
588            .server
589            .get_service_auth(
590                atrium_api::com::atproto::server::get_service_auth::ParametersData {
591                    aud: Did::new(String::from("did:fake:handle.test"))
592                        .expect("did should be valid"),
593                    exp: None,
594                    lxm: None,
595                }
596                .into(),
597            )
598            .await;
599        match result {
600            Ok(output) => {
601                assert_eq!(output.token, "fake_token");
602            }
603            Err(err) => {
604                panic!("unexpected error: {err:?}");
605            }
606        }
607        // wait for async update
608        tokio::time::sleep(Duration::from_micros(0)).await;
609        {
610            let token_set = session_data
611                .lock()
612                .await
613                .get(&did())
614                .expect("session should be present")
615                .token_set
616                .clone();
617            assert_eq!(token_set.access_token, "new_accesstoken");
618            assert_eq!(token_set.refresh_token, Some(String::from("new_refreshtoken")));
619        }
620        Ok(())
621    }
622
623    #[tokio::test]
624    async fn test_xrpc_with_duplicated_refresh() -> Result<(), Box<dyn std::error::Error>> {
625        let session_data = default_store();
626        if let Some(session) = session_data.lock().await.get_mut(&did()) {
627            session.token_set.access_token = String::from("expired");
628        }
629        let oauth_session = oauth_session(Default::default(), session_data).await;
630        let agent = Arc::new(Agent::new(oauth_session));
631
632        let handles = (0..3).map(|_| {
633            let agent = Arc::clone(&agent);
634            tokio::spawn(async move {
635                agent
636                    .api
637                    .com
638                    .atproto
639                    .server
640                    .get_service_auth(
641                        atrium_api::com::atproto::server::get_service_auth::ParametersData {
642                            aud: Did::new(String::from("did:fake:handle.test"))
643                                .expect("did should be valid"),
644                            exp: None,
645                            lxm: None,
646                        }
647                        .into(),
648                    )
649                    .await
650            })
651        });
652        for result in futures::future::join_all(handles).await {
653            match result? {
654                Ok(output) => {
655                    assert_eq!(output.token, "fake_token");
656                }
657                Err(err) => {
658                    panic!("unexpected error: {err:?}");
659                }
660            }
661        }
662        Ok(())
663    }
664}