1mod inner;
4pub mod store;
5
6use self::store::AtpSessionStore;
7use super::{
8 inner::Wrapper, utils::SessionWithEndpointStore, Agent, CloneWithProxy, Configure,
9 SessionManager,
10};
11use crate::{
12 client::com::atproto::Service,
13 did_doc::DidDocument,
14 types::{string::Did, TryFromUnknown},
15};
16use atrium_xrpc::{Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
17use http::{Request, Response};
18use serde::{de::DeserializeOwned, Serialize};
19use std::{convert, fmt::Debug, ops::Deref, sync::Arc};
20
/// Session data kept by [`CredentialSession`]: the output of
/// `com.atproto.server.createSession`, which carries the access/refresh JWTs together
/// with account details such as the DID and handle.
pub type AtpSession = crate::com::atproto::server::create_session::Output;
23
/// Session manager that keeps an [`AtpSession`] in a store and routes requests through an
/// inner client built over that same store.
pub struct CredentialSession<S, T>
where
    S: AtpSessionStore + Send + Sync,
    T: XrpcClient + Send + Sync,
    S::Error: std::error::Error + Send + Sync + 'static,
{
    /// Session store paired with the current endpoint; the same `Arc` is handed to `inner`.
    store: Arc<SessionWithEndpointStore<S, AtpSession>>,
    /// Client that actually sends requests, built from `store` and the caller's XRPC client.
    inner: Arc<inner::Client<S, T>>,
    /// Typed `com.atproto` API bound to `inner`; used by `login` and `resume_session`.
    atproto_service: Service<inner::Client<S, T>>,
}
34
35impl<S, T> CredentialSession<S, T>
36where
37 S: AtpSessionStore + Send + Sync,
38 T: XrpcClient + Send + Sync,
39 S::Error: std::error::Error + Send + Sync + 'static,
40{
41 pub fn new(xrpc: T, store: S) -> Self {
42 let store = Arc::new(SessionWithEndpointStore::new(store, xrpc.base_uri()));
43 let inner = Arc::new(inner::Client::new(Arc::clone(&store), xrpc));
44 let atproto_service = Service::new(Arc::clone(&inner));
45 Self { store, inner, atproto_service }
46 }
47 pub async fn login(
49 &self,
50 identifier: impl AsRef<str>,
51 password: impl AsRef<str>,
52 ) -> Result<AtpSession, Error<crate::com::atproto::server::create_session::Error>> {
53 let result = self
54 .atproto_service
55 .server
56 .create_session(
57 crate::com::atproto::server::create_session::InputData {
58 allow_takendown: None,
59 auth_factor_token: None,
60 identifier: identifier.as_ref().into(),
61 password: password.as_ref().into(),
62 }
63 .into(),
64 )
65 .await?;
66 self.store.set(result.clone()).await.ok();
67 if let Some(did_doc) = result
68 .did_doc
69 .as_ref()
70 .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok())
71 {
72 self.store.update_endpoint(&did_doc);
73 }
74 Ok(result)
75 }
76 pub async fn resume_session(
78 &self,
79 session: AtpSession,
80 ) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
81 self.store.set(session.clone()).await.ok();
82 let result = self.atproto_service.server.get_session().await;
83 match result {
84 Ok(output) => {
85 assert_eq!(output.data.did, session.data.did);
86 if let Ok(Some(mut session)) = self.store.get().await {
87 session.did_doc = output.data.did_doc.clone();
88 session.email = output.data.email;
89 session.email_confirmed = output.data.email_confirmed;
90 session.handle = output.data.handle;
91 self.store.set(session).await.ok();
92 }
93 if let Some(did_doc) = output
94 .data
95 .did_doc
96 .as_ref()
97 .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok())
98 {
99 self.store.update_endpoint(&did_doc);
100 }
101 Ok(())
102 }
103 Err(err) => {
104 self.store.clear().await.ok();
105 Err(err)
106 }
107 }
108 }
109 pub async fn get_session(&self) -> Option<AtpSession> {
111 self.store.get().await.ok().and_then(convert::identity)
112 }
113 pub async fn get_endpoint(&self) -> String {
115 self.store.get_endpoint()
116 }
117 pub async fn get_labelers_header(&self) -> Option<Vec<String>> {
119 self.inner.get_labelers_header().await
120 }
121 pub async fn get_proxy_header(&self) -> Option<String> {
123 self.inner.get_proxy_header().await
124 }
125}
126
127impl<S, T> HttpClient for CredentialSession<S, T>
128where
129 S: AtpSessionStore + Send + Sync,
130 T: XrpcClient + Send + Sync,
131 S::Error: std::error::Error + Send + Sync + 'static,
132{
133 async fn send_http(
134 &self,
135 request: Request<Vec<u8>>,
136 ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
137 self.inner.send_http(request).await
138 }
139}
140
141impl<S, T> XrpcClient for CredentialSession<S, T>
142where
143 S: AtpSessionStore + Send + Sync,
144 T: XrpcClient + Send + Sync,
145 S::Error: std::error::Error + Send + Sync + 'static,
146{
147 fn base_uri(&self) -> String {
148 self.inner.base_uri()
149 }
150 async fn send_xrpc<P, I, O, E>(
151 &self,
152 request: &XrpcRequest<P, I>,
153 ) -> Result<OutputDataOrBytes<O>, Error<E>>
154 where
155 P: Serialize + Send + Sync,
156 I: Serialize + Send + Sync,
157 O: DeserializeOwned + Send + Sync,
158 E: DeserializeOwned + Send + Sync + Debug,
159 {
160 self.inner.send_xrpc(request).await
161 }
162}
163
164impl<S, T> SessionManager for CredentialSession<S, T>
165where
166 S: AtpSessionStore + Send + Sync,
167 T: XrpcClient + Send + Sync,
168 S::Error: std::error::Error + Send + Sync + 'static,
169{
170 async fn did(&self) -> Option<Did> {
171 self.store.get().await.ok().and_then(|session| session.map(|session| session.data.did))
172 }
173}
174
175impl<S, T> Configure for CredentialSession<S, T>
176where
177 S: AtpSessionStore + Send + Sync,
178 T: XrpcClient + Send + Sync,
179 S::Error: std::error::Error + Send + Sync + 'static,
180{
181 fn configure_endpoint(&self, endpoint: String) {
182 self.inner.configure_endpoint(endpoint);
183 }
184 fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
185 self.inner.configure_labelers_header(labeler_dids);
186 }
187 fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
188 self.inner.configure_proxy_header(did, service_type);
189 }
190}
191
192impl<S, T> CloneWithProxy for CredentialSession<S, T>
193where
194 S: AtpSessionStore + Send + Sync,
195 S::Error: std::error::Error + Send + Sync + 'static,
196 T: XrpcClient + Send + Sync,
197{
198 fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
199 let inner = Arc::new(self.inner.clone_with_proxy(did, service_type));
200 let atproto_service = Service::new(Arc::clone(&inner));
201 Self { store: Arc::clone(&self.store), inner, atproto_service }
202 }
203}
204
/// Agent that logs in with an identifier and password through a [`CredentialSession`].
///
/// The generated API is reached through `Deref` to [`Agent`].
pub struct AtpAgent<S, T>
where
    S: AtpSessionStore + Send + Sync,
    T: XrpcClient + Send + Sync,
    S::Error: std::error::Error + Send + Sync + 'static,
{
    /// Handle used for session operations (`login`, `resume_session`, getters).
    session_manager: Wrapper<CredentialSession<S, T>>,
    /// Agent built from a clone of `session_manager`; this is the `Deref` target.
    inner: Agent<Wrapper<CredentialSession<S, T>>>,
}
233
234impl<S, T> AtpAgent<S, T>
235where
236 S: AtpSessionStore + Send + Sync,
237 T: XrpcClient + Send + Sync,
238 S::Error: std::error::Error + Send + Sync + 'static,
239{
240 pub fn new(xrpc: T, store: S) -> Self {
242 let session_manager = Wrapper::new(CredentialSession::new(xrpc, store));
243 let inner = Agent::new(session_manager.clone());
244 Self { session_manager, inner }
245 }
246 pub async fn login(
248 &self,
249 identifier: impl AsRef<str>,
250 password: impl AsRef<str>,
251 ) -> Result<AtpSession, Error<crate::com::atproto::server::create_session::Error>> {
252 self.session_manager.login(identifier, password).await
253 }
254 pub async fn resume_session(
256 &self,
257 session: AtpSession,
258 ) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
259 self.session_manager.resume_session(session).await
260 }
261 pub async fn get_session(&self) -> Option<AtpSession> {
263 self.session_manager.get_session().await
264 }
265 pub async fn get_endpoint(&self) -> String {
267 self.session_manager.get_endpoint().await
268 }
269 pub async fn get_labelers_header(&self) -> Option<Vec<String>> {
271 self.session_manager.get_labelers_header().await
272 }
273 pub async fn get_proxy_header(&self) -> Option<String> {
275 self.session_manager.get_proxy_header().await
276 }
277}
278
279impl<S, T> Deref for AtpAgent<S, T>
280where
281 S: AtpSessionStore + Send + Sync,
282 T: XrpcClient + Send + Sync,
283 S::Error: std::error::Error + Send + Sync + 'static,
284{
285 type Target = Agent<Wrapper<CredentialSession<S, T>>>;
286
287 fn deref(&self) -> &Self::Target {
288 &self.inner
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::store::MemorySessionStore;
    use super::*;
    use crate::{
        agent::AtprotoServiceType,
        com::atproto::server::create_session::OutputData,
        did_doc::{DidDocument, Service, VerificationMethod},
        types::TryIntoUnknown,
    };
    use atrium_xrpc::HttpClient;
    use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
    use std::collections::HashMap;
    use tokio::sync::RwLock;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Agent type used throughout these tests.
    type TestAgent = AtpAgent<MockClient, MemorySessionStore>;

    /// Canned bodies served by [`MockClient`]; a `None` entry makes that endpoint fail.
    #[derive(Default)]
    struct MockResponses {
        create_session: Option<crate::com::atproto::server::create_session::OutputData>,
        get_session: Option<crate::com::atproto::server::get_session::OutputData>,
    }

    /// Fake transport that records every request's headers and counts calls per NSID.
    #[derive(Default)]
    struct MockClient {
        responses: MockResponses,
        counts: Arc<RwLock<HashMap<String, usize>>>,
        headers: Arc<RwLock<Vec<HeaderMap<HeaderValue>>>>,
    }

    impl HttpClient for MockClient {
        async fn send_http(
            &self,
            request: Request<Vec<u8>>,
        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
            // A short pause lets concurrently spawned requests interleave.
            #[cfg(not(target_arch = "wasm32"))]
            tokio::time::sleep(std::time::Duration::from_micros(10)).await;

            self.headers.write().await.push(request.headers().clone());
            let builder =
                Response::builder().header(http::header::CONTENT_TYPE, "application/json");
            // The token is the last space-separated word of the Authorization header.
            let token = request
                .headers()
                .get(http::header::AUTHORIZATION)
                .and_then(|value| value.to_str().ok())
                .and_then(|value| value.rsplit(' ').next());
            if token == Some("expired") {
                let expired = atrium_xrpc::error::ErrorResponseBody {
                    error: Some(String::from("ExpiredToken")),
                    message: Some(String::from("Token has expired")),
                };
                return Ok(builder
                    .status(http::StatusCode::BAD_REQUEST)
                    .body(serde_json::to_vec(&expired)?)?);
            }

            let mut body = Vec::new();
            if let Some(nsid) = request.uri().path().strip_prefix("/xrpc/") {
                *self.counts.write().await.entry(nsid.into()).or_default() += 1;
                match nsid {
                    crate::com::atproto::server::create_session::NSID => {
                        if let Some(output) = &self.responses.create_session {
                            body = serde_json::to_vec(output)?;
                        }
                    }
                    crate::com::atproto::server::get_session::NSID if token == Some("access") => {
                        if let Some(output) = &self.responses.get_session {
                            body = serde_json::to_vec(output)?;
                        }
                    }
                    crate::com::atproto::server::refresh_session::NSID
                        if token == Some("refresh") =>
                    {
                        body = serde_json::to_vec(
                            &crate::com::atproto::server::refresh_session::OutputData {
                                access_jwt: String::from("access"),
                                active: None,
                                did: "did:web:example.com".parse().expect("valid"),
                                did_doc: None,
                                handle: "example.com".parse().expect("valid"),
                                refresh_jwt: String::from("refresh"),
                                status: None,
                            },
                        )?;
                    }
                    crate::com::atproto::server::describe_server::NSID => {
                        body = serde_json::to_vec(
                            &crate::com::atproto::server::describe_server::OutputData {
                                available_user_domains: Vec::new(),
                                contact: None,
                                did: "did:web:example.com".parse().expect("valid"),
                                invite_code_required: None,
                                links: None,
                                phone_verification_required: None,
                            },
                        )?;
                    }
                    _ => {}
                }
            }

            if !body.is_empty() {
                return Ok(builder.status(http::StatusCode::OK).body(body)?);
            }
            let unauthorized = atrium_xrpc::error::ErrorResponseBody {
                error: Some(String::from("AuthenticationRequired")),
                message: Some(String::from("Invalid identifier or password")),
            };
            Ok(builder.status(http::StatusCode::UNAUTHORIZED).body(serde_json::to_vec(&unauthorized)?)?)
        }
    }

    impl XrpcClient for MockClient {
        fn base_uri(&self) -> String {
            String::from("http://localhost:8080")
        }
    }

    /// Baseline session used by the tests.
    fn session_data() -> OutputData {
        OutputData {
            access_jwt: String::from("access"),
            active: None,
            did: "did:web:example.com".parse().expect("valid"),
            did_doc: None,
            email: None,
            email_auth_factor: None,
            email_confirmed: None,
            handle: "example.com".parse().expect("valid"),
            refresh_jwt: String::from("refresh"),
            status: None,
        }
    }

    /// Builds the `getSession` response mirroring the account fields of `data`.
    fn get_session_output(
        data: &OutputData,
    ) -> crate::com::atproto::server::get_session::OutputData {
        crate::com::atproto::server::get_session::OutputData {
            active: data.active,
            did: data.did.clone(),
            did_doc: data.did_doc.clone(),
            email: data.email.clone(),
            email_auth_factor: data.email_auth_factor,
            email_confirmed: data.email_confirmed,
            handle: data.handle.clone(),
            status: data.status.clone(),
        }
    }

    /// Client whose only canned response is `getSession`, derived from `data`.
    fn get_session_client(data: &OutputData) -> MockClient {
        MockClient {
            responses: MockResponses {
                get_session: Some(get_session_output(data)),
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Client whose only canned response is `createSession`, returning `data`.
    fn create_session_client(data: OutputData) -> MockClient {
        MockClient {
            responses: MockResponses { create_session: Some(data), ..Default::default() },
            ..Default::default()
        }
    }

    /// Writes `data` directly into the agent's session store.
    async fn seed_session(agent: &TestAgent, data: &OutputData) {
        agent
            .session_manager
            .store
            .set(data.clone().into())
            .await
            .expect("set session should be succeeded");
    }

    /// Reads the access token currently held in the agent's session store.
    async fn stored_access_jwt(agent: &TestAgent) -> Option<String> {
        agent
            .session_manager
            .store
            .get()
            .await
            .expect("session should be stored")
            .map(|session| session.data.access_jwt)
    }

    /// Calls `describeServer` through the agent's default API.
    async fn describe_server(agent: &TestAgent) {
        agent
            .api
            .com
            .atproto
            .server
            .describe_server()
            .await
            .expect("describe_server should be succeeded");
    }

    /// Headers recorded for the most recent request, if any.
    async fn last_headers(
        headers: &RwLock<Vec<HeaderMap<HeaderValue>>>,
    ) -> Option<HeaderMap<HeaderValue>> {
        headers.read().await.last().cloned()
    }

    /// A header map holding exactly one `name: value` entry.
    fn single_header(name: &'static str, value: &'static str) -> HeaderMap<HeaderValue> {
        HeaderMap::from_iter([(HeaderName::from_static(name), HeaderValue::from_static(value))])
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_new() {
        let agent = AtpAgent::new(MockClient::default(), MemorySessionStore::default());
        assert_eq!(agent.get_session().await, None);
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_login() {
        let session_data = session_data();
        // A successful login stores the returned session.
        {
            let agent = AtpAgent::new(
                create_session_client(session_data.clone()),
                MemorySessionStore::default(),
            );
            agent.login("test", "pass").await.expect("login should be succeeded");
            assert_eq!(agent.get_session().await, Some(session_data.into()));
        }
        // A failed login leaves the store empty.
        {
            let agent = AtpAgent::new(MockClient::default(), MemorySessionStore::default());
            agent.login("test", "bad").await.expect_err("login should be failed");
            assert_eq!(agent.get_session().await, None);
        }
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_xrpc_get_session() {
        let session_data = session_data();
        let agent =
            AtpAgent::new(get_session_client(&session_data), MemorySessionStore::default());
        seed_session(&agent, &session_data).await;
        let output = agent
            .api
            .com
            .atproto
            .server
            .get_session()
            .await
            .expect("get session should be succeeded");
        assert_eq!(output.did.as_str(), "did:web:example.com");
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_xrpc_get_session_with_refresh() {
        let session_data = OutputData { access_jwt: String::from("expired"), ..session_data() };
        let agent =
            AtpAgent::new(get_session_client(&session_data), MemorySessionStore::default());
        seed_session(&agent, &session_data).await;
        let output = agent
            .api
            .com
            .atproto
            .server
            .get_session()
            .await
            .expect("get session should be succeeded");
        assert_eq!(output.did.as_str(), "did:web:example.com");
        assert_eq!(stored_access_jwt(&agent).await, Some(String::from("access")));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_xrpc_get_session_with_duplicated_refresh() {
        let session_data = OutputData { access_jwt: String::from("expired"), ..session_data() };
        let client = get_session_client(&session_data);
        let counts = Arc::clone(&client.counts);
        let agent = Arc::new(AtpAgent::new(client, MemorySessionStore::default()));
        seed_session(&agent, &session_data).await;
        let handles = (0..3).map(|_| {
            let agent = Arc::clone(&agent);
            tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
        });
        let results = futures::future::join_all(handles).await;
        for result in &results {
            let output = result
                .as_ref()
                .expect("task should be successfully executed")
                .as_ref()
                .expect("get session should be succeeded");
            assert_eq!(output.did.as_str(), "did:web:example.com");
        }
        assert_eq!(stored_access_jwt(&agent).await, Some(String::from("access")));
        assert_eq!(
            counts.read().await.clone(),
            HashMap::from_iter([
                ("com.atproto.server.refreshSession".into(), 1),
                ("com.atproto.server.getSession".into(), 3)
            ])
        );
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_resume_session() {
        let session_data = session_data();
        // The server's view of the account overrides the resumed session's fields.
        {
            let agent =
                AtpAgent::new(get_session_client(&session_data), MemorySessionStore::default());
            assert_eq!(agent.get_session().await, None);
            agent
                .resume_session(
                    OutputData {
                        email: Some(String::from("test@example.com")),
                        ..session_data.clone()
                    }
                    .into(),
                )
                .await
                .expect("resume_session should be succeeded");
            assert_eq!(agent.get_session().await, Some(session_data.clone().into()));
        }
        // A session the server rejects is not kept.
        {
            let agent = AtpAgent::new(MockClient::default(), MemorySessionStore::default());
            assert_eq!(agent.get_session().await, None);
            agent
                .resume_session(session_data.clone().into())
                .await
                .expect_err("resume_session should be failed");
            assert_eq!(agent.get_session().await, None);
        }
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_resume_session_with_refresh() {
        let session_data = session_data();
        let agent =
            AtpAgent::new(get_session_client(&session_data), MemorySessionStore::default());
        agent
            .resume_session(
                OutputData { access_jwt: "expired".into(), ..session_data.clone() }.into(),
            )
            .await
            .expect("resume_session should be succeeded");
        assert_eq!(agent.get_session().await, Some(session_data.clone().into()));
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_login_with_diddoc() {
        let session_data = session_data();
        let did_doc = DidDocument {
            context: None,
            id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(),
            also_known_as: Some(vec!["at://atproto.com".into()]),
            verification_method: Some(vec![VerificationMethod {
                id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz#atproto".into(),
                r#type: "Multikey".into(),
                controller: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(),
                public_key_multibase: Some(
                    "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9pribSF".into(),
                ),
            }]),
            service: Some(vec![Service {
                id: "#atproto_pds".into(),
                r#type: "AtprotoPersonalDataServer".into(),
                service_endpoint: "https://bsky.social".into(),
            }]),
        };
        // A valid `#atproto_pds` service switches the endpoint.
        {
            let client = create_session_client(OutputData {
                did_doc: Some(
                    did_doc.clone().try_into_unknown().expect("failed to convert to unknown"),
                ),
                ..session_data.clone()
            });
            let agent = AtpAgent::new(client, MemorySessionStore::default());
            agent.login("test", "pass").await.expect("login should be succeeded");
            assert_eq!(agent.get_endpoint().await, "https://bsky.social");
            assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "https://bsky.social");
        }
        // A wrong service id or an invalid URL leaves the endpoint unchanged.
        {
            let broken = DidDocument {
                service: Some(vec![
                    Service {
                        id: "#pds".into(),
                        r#type: "AtprotoPersonalDataServer".into(),
                        service_endpoint: "https://bsky.social".into(),
                    },
                    Service {
                        id: "#atproto_pds".into(),
                        r#type: "AtprotoPersonalDataServer".into(),
                        service_endpoint: "htps://bsky.social".into(),
                    },
                ]),
                ..did_doc.clone()
            };
            let client = create_session_client(OutputData {
                did_doc: Some(broken.try_into_unknown().expect("failed to convert to unknown")),
                ..session_data.clone()
            });
            let agent = AtpAgent::new(client, MemorySessionStore::default());
            agent.login("test", "pass").await.expect("login should be succeeded");
            assert_eq!(agent.get_endpoint().await, "http://localhost:8080");
            assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "http://localhost:8080");
        }
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_configure_labelers_header() {
        let client = MockClient::default();
        let headers = Arc::clone(&client.headers);
        let agent = AtpAgent::new(client, MemorySessionStore::default());

        describe_server(&agent).await;
        assert_eq!(last_headers(&headers).await, Some(HeaderMap::new()));

        agent.configure_labelers_header(Some(vec![(
            "did:plc:test1".parse().expect("did should be valid"),
            false,
        )]));
        describe_server(&agent).await;
        assert_eq!(
            last_headers(&headers).await,
            Some(single_header("atproto-accept-labelers", "did:plc:test1"))
        );

        agent.configure_labelers_header(Some(vec![
            ("did:plc:test1".parse().expect("did should be valid"), true),
            ("did:plc:test2".parse().expect("did should be valid"), false),
        ]));
        describe_server(&agent).await;
        assert_eq!(
            last_headers(&headers).await,
            Some(single_header("atproto-accept-labelers", "did:plc:test1;redact, did:plc:test2"))
        );

        assert_eq!(
            agent.get_labelers_header().await,
            Some(vec![String::from("did:plc:test1;redact"), String::from("did:plc:test2")])
        );
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_configure_proxy_header() {
        let client = MockClient::default();
        let headers = Arc::clone(&client.headers);
        let agent = AtpAgent::new(client, MemorySessionStore::default());

        describe_server(&agent).await;
        assert_eq!(last_headers(&headers).await, Some(HeaderMap::new()));

        agent.configure_proxy_header(
            "did:plc:test1".parse().expect("did should be valid"),
            AtprotoServiceType::AtprotoLabeler,
        );
        describe_server(&agent).await;
        assert_eq!(
            last_headers(&headers).await,
            Some(single_header("atproto-proxy", "did:plc:test1#atproto_labeler"))
        );

        agent.configure_proxy_header(
            "did:plc:test1".parse().expect("did should be valid"),
            "atproto_labeler",
        );
        describe_server(&agent).await;
        assert_eq!(
            last_headers(&headers).await,
            Some(single_header("atproto-proxy", "did:plc:test1#atproto_labeler"))
        );

        // A one-off proxied API uses its own header...
        agent
            .api_with_proxy(
                "did:plc:test2".parse().expect("did should be valid"),
                "atproto_labeler",
            )
            .com
            .atproto
            .server
            .describe_server()
            .await
            .expect("describe_server should be succeeded");
        assert_eq!(
            last_headers(&headers).await,
            Some(single_header("atproto-proxy", "did:plc:test2#atproto_labeler"))
        );

        // ...without affecting the agent's configured proxy.
        describe_server(&agent).await;
        assert_eq!(
            last_headers(&headers).await,
            Some(single_header("atproto-proxy", "did:plc:test1#atproto_labeler"))
        );

        assert_eq!(
            agent.get_proxy_header().await,
            Some(String::from("did:plc:test1#atproto_labeler"))
        );
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_agent_did() {
        let session_data = session_data();
        let agent = AtpAgent::new(MockClient::default(), MemorySessionStore::default());
        assert_eq!(agent.did().await, None);
        seed_session(&agent, &session_data).await;
        assert_eq!(agent.did().await, Some(session_data.did));
    }
}