// atrium_api/agent/atp_agent.rs

1//! Implementation of [`AtpAgent`] and definitions of [`AtpSessionStore`] for it.
2
3mod inner;
4pub mod store;
5
6use self::store::AtpSessionStore;
7use super::{
8    inner::Wrapper, utils::SessionWithEndpointStore, Agent, CloneWithProxy, Configure,
9    SessionManager,
10};
11use crate::{
12    client::com::atproto::Service,
13    did_doc::DidDocument,
14    types::{string::Did, TryFromUnknown},
15};
16use atrium_xrpc::{Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
17use http::{Request, Response};
18use serde::{de::DeserializeOwned, Serialize};
19use std::{convert, fmt::Debug, ops::Deref, sync::Arc};
20
21/// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output)
22pub type AtpSession = crate::com::atproto::server::create_session::Output;
23
24pub struct CredentialSession<S, T>
25where
26    S: AtpSessionStore + Send + Sync,
27    T: XrpcClient + Send + Sync,
28    S::Error: std::error::Error + Send + Sync + 'static,
29{
30    store: Arc<SessionWithEndpointStore<S, AtpSession>>,
31    inner: Arc<inner::Client<S, T>>,
32    atproto_service: Service<inner::Client<S, T>>,
33}
34
35impl<S, T> CredentialSession<S, T>
36where
37    S: AtpSessionStore + Send + Sync,
38    T: XrpcClient + Send + Sync,
39    S::Error: std::error::Error + Send + Sync + 'static,
40{
41    pub fn new(xrpc: T, store: S) -> Self {
42        let store = Arc::new(SessionWithEndpointStore::new(store, xrpc.base_uri()));
43        let inner = Arc::new(inner::Client::new(Arc::clone(&store), xrpc));
44        let atproto_service = Service::new(Arc::clone(&inner));
45        Self { store, inner, atproto_service }
46    }
47    /// Start a new session with this agent.
48    pub async fn login(
49        &self,
50        identifier: impl AsRef<str>,
51        password: impl AsRef<str>,
52    ) -> Result<AtpSession, Error<crate::com::atproto::server::create_session::Error>> {
53        let result = self
54            .atproto_service
55            .server
56            .create_session(
57                crate::com::atproto::server::create_session::InputData {
58                    allow_takendown: None,
59                    auth_factor_token: None,
60                    identifier: identifier.as_ref().into(),
61                    password: password.as_ref().into(),
62                }
63                .into(),
64            )
65            .await?;
66        self.store.set(result.clone()).await.ok();
67        if let Some(did_doc) = result
68            .did_doc
69            .as_ref()
70            .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok())
71        {
72            self.store.update_endpoint(&did_doc);
73        }
74        Ok(result)
75    }
76    /// Resume a pre-existing session with this agent.
77    pub async fn resume_session(
78        &self,
79        session: AtpSession,
80    ) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
81        self.store.set(session.clone()).await.ok();
82        let result = self.atproto_service.server.get_session().await;
83        match result {
84            Ok(output) => {
85                assert_eq!(output.data.did, session.data.did);
86                if let Ok(Some(mut session)) = self.store.get().await {
87                    session.did_doc = output.data.did_doc.clone();
88                    session.email = output.data.email;
89                    session.email_confirmed = output.data.email_confirmed;
90                    session.handle = output.data.handle;
91                    self.store.set(session).await.ok();
92                }
93                if let Some(did_doc) = output
94                    .data
95                    .did_doc
96                    .as_ref()
97                    .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok())
98                {
99                    self.store.update_endpoint(&did_doc);
100                }
101                Ok(())
102            }
103            Err(err) => {
104                self.store.clear().await.ok();
105                Err(err)
106            }
107        }
108    }
109    /// Get the current session.
110    pub async fn get_session(&self) -> Option<AtpSession> {
111        self.store.get().await.ok().and_then(convert::identity)
112    }
113    /// Get the current endpoint.
114    pub async fn get_endpoint(&self) -> String {
115        self.store.get_endpoint()
116    }
117    /// Get the current labelers header.
118    pub async fn get_labelers_header(&self) -> Option<Vec<String>> {
119        self.inner.get_labelers_header().await
120    }
121    /// Get the current proxy header.
122    pub async fn get_proxy_header(&self) -> Option<String> {
123        self.inner.get_proxy_header().await
124    }
125}
126
127impl<S, T> HttpClient for CredentialSession<S, T>
128where
129    S: AtpSessionStore + Send + Sync,
130    T: XrpcClient + Send + Sync,
131    S::Error: std::error::Error + Send + Sync + 'static,
132{
133    async fn send_http(
134        &self,
135        request: Request<Vec<u8>>,
136    ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
137        self.inner.send_http(request).await
138    }
139}
140
141impl<S, T> XrpcClient for CredentialSession<S, T>
142where
143    S: AtpSessionStore + Send + Sync,
144    T: XrpcClient + Send + Sync,
145    S::Error: std::error::Error + Send + Sync + 'static,
146{
147    fn base_uri(&self) -> String {
148        self.inner.base_uri()
149    }
150    async fn send_xrpc<P, I, O, E>(
151        &self,
152        request: &XrpcRequest<P, I>,
153    ) -> Result<OutputDataOrBytes<O>, Error<E>>
154    where
155        P: Serialize + Send + Sync,
156        I: Serialize + Send + Sync,
157        O: DeserializeOwned + Send + Sync,
158        E: DeserializeOwned + Send + Sync + Debug,
159    {
160        self.inner.send_xrpc(request).await
161    }
162}
163
164impl<S, T> SessionManager for CredentialSession<S, T>
165where
166    S: AtpSessionStore + Send + Sync,
167    T: XrpcClient + Send + Sync,
168    S::Error: std::error::Error + Send + Sync + 'static,
169{
170    async fn did(&self) -> Option<Did> {
171        self.store.get().await.ok().and_then(|session| session.map(|session| session.data.did))
172    }
173}
174
175impl<S, T> Configure for CredentialSession<S, T>
176where
177    S: AtpSessionStore + Send + Sync,
178    T: XrpcClient + Send + Sync,
179    S::Error: std::error::Error + Send + Sync + 'static,
180{
181    fn configure_endpoint(&self, endpoint: String) {
182        self.inner.configure_endpoint(endpoint);
183    }
184    fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
185        self.inner.configure_labelers_header(labeler_dids);
186    }
187    fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
188        self.inner.configure_proxy_header(did, service_type);
189    }
190}
191
192impl<S, T> CloneWithProxy for CredentialSession<S, T>
193where
194    S: AtpSessionStore + Send + Sync,
195    S::Error: std::error::Error + Send + Sync + 'static,
196    T: XrpcClient + Send + Sync,
197{
198    fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
199        let inner = Arc::new(self.inner.clone_with_proxy(did, service_type));
200        let atproto_service = Service::new(Arc::clone(&inner));
201        Self { store: Arc::clone(&self.store), inner, atproto_service }
202    }
203}
204
205/// An ATP "Agent".
206/// Manages session token lifecycles and provides convenience methods.
207///
208/// This will be deprecated in the near future. Use [`Agent`] directly
209/// with a [`CredentialSession`] instead:
210///
211/// # Example
212///
213/// ```
214/// use atrium_api::agent::atp_agent::{store::MemorySessionStore, CredentialSession};
215/// use atrium_api::agent::Agent;
216/// use atrium_xrpc_client::reqwest::ReqwestClient;
217///
218/// let session = CredentialSession::new(
219///     ReqwestClient::new("https://bsky.social"),
220///     MemorySessionStore::default(),
221/// );
222/// let agent = Agent::new(session);
223/// ```
224pub struct AtpAgent<S, T>
225where
226    S: AtpSessionStore + Send + Sync,
227    T: XrpcClient + Send + Sync,
228    S::Error: std::error::Error + Send + Sync + 'static,
229{
230    session_manager: Wrapper<CredentialSession<S, T>>,
231    inner: Agent<Wrapper<CredentialSession<S, T>>>,
232}
233
234impl<S, T> AtpAgent<S, T>
235where
236    S: AtpSessionStore + Send + Sync,
237    T: XrpcClient + Send + Sync,
238    S::Error: std::error::Error + Send + Sync + 'static,
239{
240    /// Create a new agent.
241    pub fn new(xrpc: T, store: S) -> Self {
242        let session_manager = Wrapper::new(CredentialSession::new(xrpc, store));
243        let inner = Agent::new(session_manager.clone());
244        Self { session_manager, inner }
245    }
246    /// Start a new session with this agent.
247    pub async fn login(
248        &self,
249        identifier: impl AsRef<str>,
250        password: impl AsRef<str>,
251    ) -> Result<AtpSession, Error<crate::com::atproto::server::create_session::Error>> {
252        self.session_manager.login(identifier, password).await
253    }
254    // /// Resume a pre-existing session with this agent.
255    pub async fn resume_session(
256        &self,
257        session: AtpSession,
258    ) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
259        self.session_manager.resume_session(session).await
260    }
261    /// Get the current session.
262    pub async fn get_session(&self) -> Option<AtpSession> {
263        self.session_manager.get_session().await
264    }
265    /// Get the current endpoint.
266    pub async fn get_endpoint(&self) -> String {
267        self.session_manager.get_endpoint().await
268    }
269    /// Get the current labelers header.
270    pub async fn get_labelers_header(&self) -> Option<Vec<String>> {
271        self.session_manager.get_labelers_header().await
272    }
273    /// Get the current proxy header.
274    pub async fn get_proxy_header(&self) -> Option<String> {
275        self.session_manager.get_proxy_header().await
276    }
277}
278
279impl<S, T> Deref for AtpAgent<S, T>
280where
281    S: AtpSessionStore + Send + Sync,
282    T: XrpcClient + Send + Sync,
283    S::Error: std::error::Error + Send + Sync + 'static,
284{
285    type Target = Agent<Wrapper<CredentialSession<S, T>>>;
286
287    fn deref(&self) -> &Self::Target {
288        &self.inner
289    }
290}
291
292#[cfg(test)]
293mod tests {
294    use super::store::MemorySessionStore;
295    use super::*;
296    use crate::{
297        agent::AtprotoServiceType,
298        com::atproto::server::create_session::OutputData,
299        did_doc::{DidDocument, Service, VerificationMethod},
300        types::TryIntoUnknown,
301    };
302    use atrium_xrpc::HttpClient;
303    use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
304    use std::collections::HashMap;
305    use tokio::sync::RwLock;
306    #[cfg(target_arch = "wasm32")]
307    use wasm_bindgen_test::wasm_bindgen_test;
308
309    #[derive(Default)]
310    struct MockResponses {
311        create_session: Option<crate::com::atproto::server::create_session::OutputData>,
312        get_session: Option<crate::com::atproto::server::get_session::OutputData>,
313    }
314
315    #[derive(Default)]
316    struct MockClient {
317        responses: MockResponses,
318        counts: Arc<RwLock<HashMap<String, usize>>>,
319        headers: Arc<RwLock<Vec<HeaderMap<HeaderValue>>>>,
320    }
321
322    impl HttpClient for MockClient {
323        async fn send_http(
324            &self,
325            request: Request<Vec<u8>>,
326        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
327            // tick tokio time
328            #[cfg(not(target_arch = "wasm32"))]
329            tokio::time::sleep(std::time::Duration::from_micros(10)).await;
330
331            self.headers.write().await.push(request.headers().clone());
332            let builder =
333                Response::builder().header(http::header::CONTENT_TYPE, "application/json");
334            let token = request
335                .headers()
336                .get(http::header::AUTHORIZATION)
337                .and_then(|value| value.to_str().ok())
338                .and_then(|value| value.split(' ').last());
339            if token == Some("expired") {
340                return Ok(builder.status(http::StatusCode::BAD_REQUEST).body(
341                    serde_json::to_vec(&atrium_xrpc::error::ErrorResponseBody {
342                        error: Some(String::from("ExpiredToken")),
343                        message: Some(String::from("Token has expired")),
344                    })?,
345                )?);
346            }
347            let mut body = Vec::new();
348            if let Some(nsid) = request.uri().path().strip_prefix("/xrpc/") {
349                *self.counts.write().await.entry(nsid.into()).or_default() += 1;
350                match nsid {
351                    crate::com::atproto::server::create_session::NSID => {
352                        if let Some(output) = &self.responses.create_session {
353                            body.extend(serde_json::to_vec(output)?);
354                        }
355                    }
356                    crate::com::atproto::server::get_session::NSID => {
357                        if token == Some("access") {
358                            if let Some(output) = &self.responses.get_session {
359                                body.extend(serde_json::to_vec(output)?);
360                            }
361                        }
362                    }
363                    crate::com::atproto::server::refresh_session::NSID => {
364                        if token == Some("refresh") {
365                            body.extend(serde_json::to_vec(
366                                &crate::com::atproto::server::refresh_session::OutputData {
367                                    access_jwt: String::from("access"),
368                                    active: None,
369                                    did: "did:web:example.com".parse().expect("valid"),
370                                    did_doc: None,
371                                    handle: "example.com".parse().expect("valid"),
372                                    refresh_jwt: String::from("refresh"),
373                                    status: None,
374                                },
375                            )?);
376                        }
377                    }
378                    crate::com::atproto::server::describe_server::NSID => {
379                        body.extend(serde_json::to_vec(
380                            &crate::com::atproto::server::describe_server::OutputData {
381                                available_user_domains: Vec::new(),
382                                contact: None,
383                                did: "did:web:example.com".parse().expect("valid"),
384                                invite_code_required: None,
385                                links: None,
386                                phone_verification_required: None,
387                            },
388                        )?);
389                    }
390                    _ => {}
391                }
392            }
393            if body.is_empty() {
394                Ok(builder.status(http::StatusCode::UNAUTHORIZED).body(serde_json::to_vec(
395                    &atrium_xrpc::error::ErrorResponseBody {
396                        error: Some(String::from("AuthenticationRequired")),
397                        message: Some(String::from("Invalid identifier or password")),
398                    },
399                )?)?)
400            } else {
401                Ok(builder.status(http::StatusCode::OK).body(body)?)
402            }
403        }
404    }
405
406    impl XrpcClient for MockClient {
407        fn base_uri(&self) -> String {
408            "http://localhost:8080".into()
409        }
410    }
411
412    fn session_data() -> OutputData {
413        OutputData {
414            access_jwt: String::from("access"),
415            active: None,
416            did: "did:web:example.com".parse().expect("valid"),
417            did_doc: None,
418            email: None,
419            email_auth_factor: None,
420            email_confirmed: None,
421            handle: "example.com".parse().expect("valid"),
422            refresh_jwt: String::from("refresh"),
423            status: None,
424        }
425    }
426
427    #[tokio::test]
428    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
429    async fn test_new() {
430        let agent = AtpAgent::new(MockClient::default(), MemorySessionStore::default());
431        assert_eq!(agent.get_session().await, None);
432    }
433
434    #[tokio::test]
435    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
436    async fn test_login() {
437        let session_data = session_data();
438        // success
439        {
440            let client = MockClient {
441                responses: MockResponses {
442                    create_session: Some(crate::com::atproto::server::create_session::OutputData {
443                        ..session_data.clone()
444                    }),
445                    ..Default::default()
446                },
447                ..Default::default()
448            };
449            let agent = AtpAgent::new(client, MemorySessionStore::default());
450            agent.login("test", "pass").await.expect("login should be succeeded");
451            assert_eq!(agent.get_session().await, Some(session_data.into()));
452        }
453        // failure with `createSession` error
454        {
455            let client = MockClient {
456                responses: MockResponses { ..Default::default() },
457                ..Default::default()
458            };
459            let agent = AtpAgent::new(client, MemorySessionStore::default());
460            agent.login("test", "bad").await.expect_err("login should be failed");
461            assert_eq!(agent.get_session().await, None);
462        }
463    }
464
465    #[tokio::test]
466    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
467    async fn test_xrpc_get_session() {
468        let session_data = session_data();
469        let client = MockClient {
470            responses: MockResponses {
471                get_session: Some(crate::com::atproto::server::get_session::OutputData {
472                    active: session_data.active,
473                    did: session_data.did.clone(),
474                    did_doc: session_data.did_doc.clone(),
475                    email: session_data.email.clone(),
476                    email_auth_factor: session_data.email_auth_factor,
477                    email_confirmed: session_data.email_confirmed,
478                    handle: session_data.handle.clone(),
479                    status: session_data.status.clone(),
480                }),
481                ..Default::default()
482            },
483            ..Default::default()
484        };
485        let agent = AtpAgent::new(client, MemorySessionStore::default());
486        agent
487            .session_manager
488            .store
489            .set(session_data.clone().into())
490            .await
491            .expect("set session should be succeeded");
492        let output = agent
493            .api
494            .com
495            .atproto
496            .server
497            .get_session()
498            .await
499            .expect("get session should be succeeded");
500        assert_eq!(output.did.as_str(), "did:web:example.com");
501    }
502
503    #[tokio::test]
504    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
505    async fn test_xrpc_get_session_with_refresh() {
506        let mut session_data = session_data();
507        session_data.access_jwt = String::from("expired");
508        let client = MockClient {
509            responses: MockResponses {
510                get_session: Some(crate::com::atproto::server::get_session::OutputData {
511                    active: session_data.active,
512                    did: session_data.did.clone(),
513                    did_doc: session_data.did_doc.clone(),
514                    email: session_data.email.clone(),
515                    email_auth_factor: session_data.email_auth_factor,
516                    email_confirmed: session_data.email_confirmed,
517                    handle: session_data.handle.clone(),
518                    status: session_data.status.clone(),
519                }),
520                ..Default::default()
521            },
522            ..Default::default()
523        };
524        let agent = AtpAgent::new(client, MemorySessionStore::default());
525        agent
526            .session_manager
527            .store
528            .set(session_data.clone().into())
529            .await
530            .expect("set session should be succeeded");
531        let output = agent
532            .api
533            .com
534            .atproto
535            .server
536            .get_session()
537            .await
538            .expect("get session should be succeeded");
539        assert_eq!(output.did.as_str(), "did:web:example.com");
540        assert_eq!(
541            agent
542                .session_manager
543                .store
544                .get()
545                .await
546                .expect("session should be stored")
547                .map(|session| session.data.access_jwt),
548            Some("access".into())
549        );
550    }
551
552    #[cfg(not(target_arch = "wasm32"))]
553    #[tokio::test]
554    async fn test_xrpc_get_session_with_duplicated_refresh() {
555        let mut session_data = session_data();
556        session_data.access_jwt = String::from("expired");
557        let client = MockClient {
558            responses: MockResponses {
559                get_session: Some(crate::com::atproto::server::get_session::OutputData {
560                    active: session_data.active,
561                    did: session_data.did.clone(),
562                    did_doc: session_data.did_doc.clone(),
563                    email: session_data.email.clone(),
564                    email_auth_factor: session_data.email_auth_factor,
565                    email_confirmed: session_data.email_confirmed,
566                    handle: session_data.handle.clone(),
567                    status: session_data.status.clone(),
568                }),
569                ..Default::default()
570            },
571            ..Default::default()
572        };
573        let counts = Arc::clone(&client.counts);
574        let agent = Arc::new(AtpAgent::new(client, MemorySessionStore::default()));
575        agent
576            .session_manager
577            .store
578            .set(session_data.clone().into())
579            .await
580            .expect("set session should be succeeded");
581        let handles = (0..3).map(|_| {
582            let agent = Arc::clone(&agent);
583            tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
584        });
585        let results = futures::future::join_all(handles).await;
586        for result in &results {
587            let output = result
588                .as_ref()
589                .expect("task should be successfully executed")
590                .as_ref()
591                .expect("get session should be succeeded");
592            assert_eq!(output.did.as_str(), "did:web:example.com");
593        }
594        assert_eq!(
595            agent
596                .session_manager
597                .store
598                .get()
599                .await
600                .expect("session should be stored")
601                .map(|session| session.data.access_jwt),
602            Some("access".into())
603        );
604        assert_eq!(
605            counts.read().await.clone(),
606            HashMap::from_iter([
607                ("com.atproto.server.refreshSession".into(), 1),
608                ("com.atproto.server.getSession".into(), 3)
609            ])
610        );
611    }
612
613    #[tokio::test]
614    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
615    async fn test_resume_session() {
616        let session_data = session_data();
617        // success
618        {
619            let client = MockClient {
620                responses: MockResponses {
621                    get_session: Some(crate::com::atproto::server::get_session::OutputData {
622                        active: session_data.active,
623                        did: session_data.did.clone(),
624                        did_doc: session_data.did_doc.clone(),
625                        email: session_data.email.clone(),
626                        email_auth_factor: session_data.email_auth_factor,
627                        email_confirmed: session_data.email_confirmed,
628                        handle: session_data.handle.clone(),
629                        status: session_data.status.clone(),
630                    }),
631                    ..Default::default()
632                },
633                ..Default::default()
634            };
635            let agent = AtpAgent::new(client, MemorySessionStore::default());
636            assert_eq!(agent.get_session().await, None);
637            agent
638                .resume_session(
639                    OutputData {
640                        email: Some(String::from("test@example.com")),
641                        ..session_data.clone()
642                    }
643                    .into(),
644                )
645                .await
646                .expect("resume_session should be succeeded");
647            assert_eq!(agent.get_session().await, Some(session_data.clone().into()));
648        }
649        // failure with `getSession` error
650        {
651            let client = MockClient {
652                responses: MockResponses { ..Default::default() },
653                ..Default::default()
654            };
655            let agent = AtpAgent::new(client, MemorySessionStore::default());
656            assert_eq!(agent.get_session().await, None);
657            agent
658                .resume_session(session_data.clone().into())
659                .await
660                .expect_err("resume_session should be failed");
661            assert_eq!(agent.get_session().await, None);
662        }
663    }
664
665    #[tokio::test]
666    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
667    async fn test_resume_session_with_refresh() {
668        let session_data = session_data();
669        let client = MockClient {
670            responses: MockResponses {
671                get_session: Some(crate::com::atproto::server::get_session::OutputData {
672                    active: session_data.active,
673                    did: session_data.did.clone(),
674                    did_doc: session_data.did_doc.clone(),
675                    email: session_data.email.clone(),
676                    email_auth_factor: session_data.email_auth_factor,
677                    email_confirmed: session_data.email_confirmed,
678                    handle: session_data.handle.clone(),
679                    status: session_data.status.clone(),
680                }),
681                ..Default::default()
682            },
683            ..Default::default()
684        };
685        let agent = AtpAgent::new(client, MemorySessionStore::default());
686        agent
687            .resume_session(
688                OutputData { access_jwt: "expired".into(), ..session_data.clone() }.into(),
689            )
690            .await
691            .expect("resume_session should be succeeded");
692        assert_eq!(agent.get_session().await, Some(session_data.clone().into()));
693    }
694
695    #[tokio::test]
696    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
697    async fn test_login_with_diddoc() {
698        let session_data = session_data();
699        let did_doc = DidDocument {
700            context: None,
701            id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(),
702            also_known_as: Some(vec!["at://atproto.com".into()]),
703            verification_method: Some(vec![VerificationMethod {
704                id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz#atproto".into(),
705                r#type: "Multikey".into(),
706                controller: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(),
707                public_key_multibase: Some(
708                    "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9pribSF".into(),
709                ),
710            }]),
711            service: Some(vec![Service {
712                id: "#atproto_pds".into(),
713                r#type: "AtprotoPersonalDataServer".into(),
714                service_endpoint: "https://bsky.social".into(),
715            }]),
716        };
717        // success
718        {
719            let client = MockClient {
720                responses: MockResponses {
721                    create_session: Some(crate::com::atproto::server::create_session::OutputData {
722                        did_doc: Some(
723                            did_doc
724                                .clone()
725                                .try_into_unknown()
726                                .expect("failed to convert to unknown"),
727                        ),
728                        ..session_data.clone()
729                    }),
730                    ..Default::default()
731                },
732                ..Default::default()
733            };
734            let agent = AtpAgent::new(client, MemorySessionStore::default());
735            agent.login("test", "pass").await.expect("login should be succeeded");
736            assert_eq!(agent.get_endpoint().await, "https://bsky.social");
737            assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "https://bsky.social");
738        }
739        // invalid services
740        {
741            let client = MockClient {
742                responses: MockResponses {
743                    create_session: Some(crate::com::atproto::server::create_session::OutputData {
744                        did_doc: Some(
745                            DidDocument {
746                                service: Some(vec![
747                                    Service {
748                                        id: "#pds".into(), // not `#atproto_pds`
749                                        r#type: "AtprotoPersonalDataServer".into(),
750                                        service_endpoint: "https://bsky.social".into(),
751                                    },
752                                    Service {
753                                        id: "#atproto_pds".into(),
754                                        r#type: "AtprotoPersonalDataServer".into(),
755                                        service_endpoint: "htps://bsky.social".into(), // invalid url (not `https`)
756                                    },
757                                ]),
758                                ..did_doc.clone()
759                            }
760                            .try_into_unknown()
761                            .expect("failed to convert to unknown"),
762                        ),
763                        ..session_data.clone()
764                    }),
765                    ..Default::default()
766                },
767                ..Default::default()
768            };
769            let agent = AtpAgent::new(client, MemorySessionStore::default());
770            agent.login("test", "pass").await.expect("login should be succeeded");
771            // not updated
772            assert_eq!(agent.get_endpoint().await, "http://localhost:8080");
773            assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "http://localhost:8080");
774        }
775    }
776
777    #[tokio::test]
778    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
779    async fn test_configure_labelers_header() {
780        let client = MockClient::default();
781        let headers = Arc::clone(&client.headers);
782        let agent = AtpAgent::new(client, MemorySessionStore::default());
783
784        agent
785            .api
786            .com
787            .atproto
788            .server
789            .describe_server()
790            .await
791            .expect("describe_server should be succeeded");
792        assert_eq!(headers.read().await.last(), Some(&HeaderMap::new()));
793
794        agent.configure_labelers_header(Some(vec![(
795            "did:plc:test1".parse().expect("did should be valid"),
796            false,
797        )]));
798        agent
799            .api
800            .com
801            .atproto
802            .server
803            .describe_server()
804            .await
805            .expect("describe_server should be succeeded");
806        assert_eq!(
807            headers.read().await.last(),
808            Some(&HeaderMap::from_iter([(
809                HeaderName::from_static("atproto-accept-labelers"),
810                HeaderValue::from_static("did:plc:test1"),
811            )]))
812        );
813
814        agent.configure_labelers_header(Some(vec![
815            ("did:plc:test1".parse().expect("did should be valid"), true),
816            ("did:plc:test2".parse().expect("did should be valid"), false),
817        ]));
818        agent
819            .api
820            .com
821            .atproto
822            .server
823            .describe_server()
824            .await
825            .expect("describe_server should be succeeded");
826        assert_eq!(
827            headers.read().await.last(),
828            Some(&HeaderMap::from_iter([(
829                HeaderName::from_static("atproto-accept-labelers"),
830                HeaderValue::from_static("did:plc:test1;redact, did:plc:test2"),
831            )]))
832        );
833
834        assert_eq!(
835            agent.get_labelers_header().await,
836            Some(vec![String::from("did:plc:test1;redact"), String::from("did:plc:test2")])
837        );
838    }
839
840    #[tokio::test]
841    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
842    async fn test_configure_proxy_header() {
843        let client = MockClient::default();
844        let headers = Arc::clone(&client.headers);
845        let agent = AtpAgent::new(client, MemorySessionStore::default());
846
847        agent
848            .api
849            .com
850            .atproto
851            .server
852            .describe_server()
853            .await
854            .expect("describe_server should be succeeded");
855        assert_eq!(headers.read().await.last(), Some(&HeaderMap::new()));
856
857        agent.configure_proxy_header(
858            "did:plc:test1".parse().expect("did should be valid"),
859            AtprotoServiceType::AtprotoLabeler,
860        );
861        agent
862            .api
863            .com
864            .atproto
865            .server
866            .describe_server()
867            .await
868            .expect("describe_server should be succeeded");
869        assert_eq!(
870            headers.read().await.last(),
871            Some(&HeaderMap::from_iter([(
872                HeaderName::from_static("atproto-proxy"),
873                HeaderValue::from_static("did:plc:test1#atproto_labeler"),
874            ),]))
875        );
876
877        agent.configure_proxy_header(
878            "did:plc:test1".parse().expect("did should be valid"),
879            "atproto_labeler",
880        );
881        agent
882            .api
883            .com
884            .atproto
885            .server
886            .describe_server()
887            .await
888            .expect("describe_server should be succeeded");
889        assert_eq!(
890            headers.read().await.last(),
891            Some(&HeaderMap::from_iter([(
892                HeaderName::from_static("atproto-proxy"),
893                HeaderValue::from_static("did:plc:test1#atproto_labeler"),
894            ),]))
895        );
896
897        agent
898            .api_with_proxy(
899                "did:plc:test2".parse().expect("did should be valid"),
900                "atproto_labeler",
901            )
902            .com
903            .atproto
904            .server
905            .describe_server()
906            .await
907            .expect("describe_server should be succeeded");
908        assert_eq!(
909            headers.read().await.last(),
910            Some(&HeaderMap::from_iter([(
911                HeaderName::from_static("atproto-proxy"),
912                HeaderValue::from_static("did:plc:test2#atproto_labeler"),
913            ),]))
914        );
915
916        agent
917            .api
918            .com
919            .atproto
920            .server
921            .describe_server()
922            .await
923            .expect("describe_server should be succeeded");
924        assert_eq!(
925            headers.read().await.last(),
926            Some(&HeaderMap::from_iter([(
927                HeaderName::from_static("atproto-proxy"),
928                HeaderValue::from_static("did:plc:test1#atproto_labeler"),
929            ),]))
930        );
931
932        assert_eq!(
933            agent.get_proxy_header().await,
934            Some(String::from("did:plc:test1#atproto_labeler"))
935        );
936    }
937
938    #[tokio::test]
939    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
940    async fn test_agent_did() {
941        let session_data = session_data();
942        let client = MockClient { responses: MockResponses::default(), ..Default::default() };
943        let agent = AtpAgent::new(client, MemorySessionStore::default());
944        assert_eq!(agent.did().await, None);
945        agent
946            .session_manager
947            .store
948            .set(session_data.clone().into())
949            .await
950            .expect("set session should be succeeded");
951        assert_eq!(agent.did().await, Some(session_data.did));
952    }
953}