atrium_api/
agent.rs

1//! Structs and traits for managing sessions and making the XRPC requests.
2
3pub mod atp_agent;
4#[cfg(feature = "bluesky")]
5pub mod bluesky;
6mod inner;
7mod session_manager;
8pub mod utils;
9
10pub use self::session_manager::SessionManager;
11use crate::{client::Service, types::string::Did};
12use atrium_xrpc::types::AuthorizationToken;
13use std::{future::Future, sync::Arc};
14
15/// A trait for providing authorization tokens.
16#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
17pub trait AuthorizationProvider {
18    #[allow(unused_variables)]
19    fn authorization_token(
20        &self,
21        is_refresh: bool,
22    ) -> impl Future<Output = Option<AuthorizationToken>>;
23}
24
25/// A trait for configuring the endpoint and headers of a client.
26pub trait Configure {
27    /// Set the current endpoint.
28    fn configure_endpoint(&self, endpoint: String);
29    /// Configures the moderation services to be applied on requests.
30    fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>);
31    /// Configures the atproto-proxy header to be applied on requests.
32    fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>);
33}
34
35/// A trait for cloning a client with a proxy header.
36pub trait CloneWithProxy {
37    fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self;
38}
39
/// Supported proxy targets.
#[cfg(feature = "bluesky")]
pub type AtprotoServiceType = self::bluesky::AtprotoServiceType;

/// Supported proxy targets.
///
/// Without the `bluesky` feature only the labeler service is available. Its string form
/// (via `AsRef<str>`) is the service identifier placed after `#` in the `atproto-proxy` header.
#[cfg(not(feature = "bluesky"))]
pub enum AtprotoServiceType {
    /// An AT Protocol labeler service, identified as `atproto_labeler`.
    AtprotoLabeler,
}

#[cfg(not(feature = "bluesky"))]
impl AsRef<str> for AtprotoServiceType {
    fn as_ref(&self) -> &str {
        match *self {
            AtprotoServiceType::AtprotoLabeler => "atproto_labeler",
        }
    }
}
57
58/// A wrapper around a session manager and a client service.
59///
60/// An agent provides the following utilities:
61/// - AT Protocol labelers configuration utilities
62/// - AT Protocol proxy configuration utilities
63/// - Cloning utilities (if the session manager implements [`CloneWithProxy`])
64///
65/// # Example
66///
67/// ```
68/// use atrium_api::agent::atp_agent::{store::MemorySessionStore, CredentialSession};
69/// use atrium_api::agent::Agent;
70/// use atrium_xrpc_client::reqwest::ReqwestClient;
71///
72/// let session = CredentialSession::new(
73///     ReqwestClient::new("https://bsky.social"),
74///     MemorySessionStore::default(),
75/// );
76/// let agent = Agent::new(session);
77/// ``````
78pub struct Agent<M>
79where
80    M: SessionManager + Send + Sync,
81{
82    session_manager: Arc<inner::Wrapper<M>>,
83    pub api: Service<inner::Wrapper<M>>,
84}
85
86impl<M> Agent<M>
87where
88    M: SessionManager + Send + Sync,
89{
90    /// Creates a new agent with the given session manager.
91    pub fn new(session_manager: M) -> Self {
92        let session_manager = Arc::new(inner::Wrapper::new(session_manager));
93        let api = Service::new(session_manager.clone());
94        Self { session_manager, api }
95    }
96    /// Returns the DID of the current session.
97    pub async fn did(&self) -> Option<Did> {
98        self.session_manager.did().await
99    }
100}
101
102impl<M> Agent<M>
103where
104    M: CloneWithProxy + SessionManager + Send + Sync,
105{
106    /// Configures the atproto-proxy header to be applied on requests.
107    ///
108    /// Returns a new client service with the proxy header configured.
109    pub fn api_with_proxy(
110        &self,
111        did: Did,
112        service_type: impl AsRef<str>,
113    ) -> Service<inner::Wrapper<M>> {
114        Service::new(Arc::new(self.session_manager.clone_with_proxy(did, service_type)))
115    }
116}
117
118impl<M> Configure for Agent<M>
119where
120    M: Configure + SessionManager + Send + Sync,
121{
122    fn configure_endpoint(&self, endpoint: String) {
123        self.session_manager.configure_endpoint(endpoint);
124    }
125    fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
126        self.session_manager.configure_labelers_header(labeler_dids);
127    }
128    fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
129        self.session_manager.configure_proxy_header(did, service_type);
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::inner::Wrapper;
    use super::utils::{SessionClient, SessionWithEndpointStore};
    use super::*;
    use atrium_common::store::Store;
    use atrium_xrpc::{Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
    use http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, Request, Response};
    use serde::{de::DeserializeOwned, Serialize};
    use std::fmt::Debug;
    use tokio::sync::Mutex;

    /// DID reported by the mock session manager and used as `aud` in test requests.
    const FAKE_DID: &str = "did:fake:handle.test";

    /// Snapshot of the most recent HTTP request seen by [`MockClient`].
    #[derive(Default)]
    struct RecordData {
        host: Option<String>,
        headers: HeaderMap<HeaderValue>,
    }

    /// Shared slot that the mock client writes into and the tests read back from.
    type Recorded = Arc<Mutex<Option<RecordData>>>;

    /// HTTP client that records each request and answers with a canned `getServiceAuth` body.
    struct MockClient {
        data: Recorded,
    }

    impl HttpClient for MockClient {
        async fn send_http(
            &self,
            request: Request<Vec<u8>>,
        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
            let snapshot = RecordData {
                host: request.uri().host().map(String::from),
                headers: request.headers().clone(),
            };
            *self.data.lock().await = Some(snapshot);

            let output = crate::com::atproto::server::get_service_auth::OutputData {
                token: String::from("fake_token"),
            };
            let body = serde_json::to_vec(&output)?;
            Response::builder()
                .header(CONTENT_TYPE, "application/json")
                .body(body)
                .map_err(Into::into)
        }
    }

    impl XrpcClient for MockClient {
        fn base_uri(&self) -> String {
            unimplemented!()
        }
    }

    #[derive(thiserror::Error, Debug)]
    enum MockStoreError {}

    /// Store stub: only its `AuthorizationProvider` impl is meant to be reached; the `Store`
    /// methods are unimplemented.
    struct MockStore;

    impl Store<(), ()> for MockStore {
        type Error = MockStoreError;

        async fn get(&self, _key: &()) -> Result<Option<()>, Self::Error> {
            unimplemented!()
        }
        async fn set(&self, _key: (), _value: ()) -> Result<(), Self::Error> {
            unimplemented!()
        }
        async fn del(&self, _key: &()) -> Result<(), Self::Error> {
            unimplemented!()
        }
        async fn clear(&self) -> Result<(), Self::Error> {
            unimplemented!()
        }
    }

    impl AuthorizationProvider for MockStore {
        async fn authorization_token(&self, _: bool) -> Option<AuthorizationToken> {
            None
        }
    }

    /// Session manager that forwards everything to a [`SessionClient`] and reports [`FAKE_DID`].
    struct MockSessionManager {
        inner: SessionClient<MockStore, MockClient, ()>,
    }

    impl HttpClient for MockSessionManager {
        async fn send_http(
            &self,
            request: Request<Vec<u8>>,
        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
            self.inner.send_http(request).await
        }
    }

    impl XrpcClient for MockSessionManager {
        fn base_uri(&self) -> String {
            self.inner.base_uri()
        }
        async fn send_xrpc<P, I, O, E>(
            &self,
            request: &XrpcRequest<P, I>,
        ) -> Result<OutputDataOrBytes<O>, Error<E>>
        where
            P: Serialize + Send + Sync,
            I: Serialize + Send + Sync,
            O: DeserializeOwned + Send + Sync,
            E: DeserializeOwned + Send + Sync + Debug,
        {
            self.inner.send_xrpc(request).await
        }
    }

    impl SessionManager for MockSessionManager {
        async fn did(&self) -> Option<Did> {
            Did::new(String::from(FAKE_DID)).ok()
        }
    }

    impl Configure for MockSessionManager {
        fn configure_endpoint(&self, endpoint: String) {
            self.inner.configure_endpoint(endpoint);
        }
        fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
            self.inner.configure_labelers_header(labeler_dids);
        }
        fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
            self.inner.configure_proxy_header(did, service_type);
        }
    }

    impl CloneWithProxy for MockSessionManager {
        fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
            Self { inner: self.inner.clone_with_proxy(did, service_type) }
        }
    }

    /// Builds an agent pointed at `https://example.com` whose requests are recorded into `data`.
    fn agent(data: Recorded) -> Agent<MockSessionManager> {
        let store = SessionWithEndpointStore::new(MockStore, String::from("https://example.com"));
        let inner = SessionClient::new(Arc::new(store), MockClient { data });
        Agent::new(MockSessionManager { inner })
    }

    /// Calls `com.atproto.server.getServiceAuth` through `service` and checks the canned token.
    async fn call_service(
        service: &Service<Wrapper<MockSessionManager>>,
    ) -> Result<(), Error<crate::com::atproto::server::get_service_auth::Error>> {
        use crate::com::atproto::server::get_service_auth::ParametersData;

        let params = ParametersData {
            aud: Did::new(String::from(FAKE_DID)).expect("did should be valid"),
            exp: None,
            lxm: None,
        };
        let output = service.com.atproto.server.get_service_auth(params.into()).await?;
        assert_eq!(output.token, "fake_token");
        Ok(())
    }

    /// Host of the most recently recorded request (panics if nothing was recorded yet).
    async fn recorded_host(data: &Recorded) -> Option<String> {
        let guard = data.lock().await;
        guard.as_ref().expect("data should be recorded").host.clone()
    }

    /// Headers of the most recently recorded request (panics if nothing was recorded yet).
    async fn recorded_headers(data: &Recorded) -> HeaderMap<HeaderValue> {
        let guard = data.lock().await;
        guard.as_ref().expect("data should be recorded").headers.clone()
    }

    /// Header map containing exactly one `name: value` entry.
    fn single_header(name: &'static str, value: &'static str) -> HeaderMap<HeaderValue> {
        HeaderMap::from_iter([(HeaderName::from_static(name), HeaderValue::from_static(value))])
    }

    #[tokio::test]
    async fn test_new() -> Result<(), Box<dyn std::error::Error>> {
        let agent = agent(Arc::new(Mutex::new(None)));
        let reported = agent.did().await;
        assert_eq!(reported, Some(Did::new(String::from(FAKE_DID))?));
        Ok(())
    }

    #[tokio::test]
    async fn test_configure_endpoint() -> Result<(), Box<dyn std::error::Error>> {
        let data = Arc::new(Mutex::new(None));
        let agent = agent(Arc::clone(&data));

        call_service(&agent.api).await?;
        assert_eq!(recorded_host(&data).await.as_deref(), Some("example.com"));

        agent.configure_endpoint(String::from("https://pds.example.com"));
        call_service(&agent.api).await?;
        assert_eq!(recorded_host(&data).await.as_deref(), Some("pds.example.com"));
        Ok(())
    }

    #[tokio::test]
    async fn test_configure_labelers_header() -> Result<(), Box<dyn std::error::Error>> {
        let data = Arc::new(Mutex::new(None));
        let agent = agent(Arc::clone(&data));

        // Nothing configured: no extra headers are sent.
        call_service(&agent.api).await?;
        assert_eq!(recorded_headers(&data).await, HeaderMap::new());

        // A single labeler without redaction.
        agent.configure_labelers_header(Some(vec![(
            Did::new(String::from("did:fake:labeler.test"))?,
            false,
        )]));
        call_service(&agent.api).await?;
        assert_eq!(
            recorded_headers(&data).await,
            single_header("atproto-accept-labelers", "did:fake:labeler.test")
        );

        // Two labelers, the first one asked to redact.
        agent.configure_labelers_header(Some(vec![
            (Did::new(String::from("did:fake:labeler.test_redact"))?, true),
            (Did::new(String::from("did:fake:labeler.test"))?, false),
        ]));
        call_service(&agent.api).await?;
        assert_eq!(
            recorded_headers(&data).await,
            single_header(
                "atproto-accept-labelers",
                "did:fake:labeler.test_redact;redact, did:fake:labeler.test",
            )
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_configure_proxy_header() -> Result<(), Box<dyn std::error::Error>> {
        let data = Arc::new(Mutex::new(None));
        let agent = agent(Arc::clone(&data));

        // Nothing configured: no proxy header is sent.
        call_service(&agent.api).await?;
        assert_eq!(recorded_headers(&data).await, HeaderMap::new());

        // A predefined service type.
        agent.configure_proxy_header(
            Did::new(String::from("did:fake:service.test"))?,
            AtprotoServiceType::AtprotoLabeler,
        );
        call_service(&agent.api).await?;
        assert_eq!(
            recorded_headers(&data).await,
            single_header("atproto-proxy", "did:fake:service.test#atproto_labeler")
        );

        // An arbitrary service type given as a string.
        agent.configure_proxy_header(
            Did::new(String::from("did:fake:service.test"))?,
            "custom_service",
        );
        call_service(&agent.api).await?;
        assert_eq!(
            recorded_headers(&data).await,
            single_header("atproto-proxy", "did:fake:service.test#custom_service")
        );

        // `api_with_proxy` applies its header only to the service it returns...
        let temp_api = agent.api_with_proxy(
            Did::new(String::from("did:fake:service.test"))?,
            "temp_service",
        );
        call_service(&temp_api).await?;
        assert_eq!(
            recorded_headers(&data).await,
            single_header("atproto-proxy", "did:fake:service.test#temp_service")
        );
        // ...while the agent's own service keeps the previously configured proxy.
        call_service(&agent.api).await?;
        assert_eq!(
            recorded_headers(&data).await,
            single_header("atproto-proxy", "did:fake:service.test#custom_service")
        );
        Ok(())
    }
}