1use std::sync::{Arc, Mutex};
7use tokio::sync::{Mutex as AsyncMutex, RwLock};
8
9use proto_blue_lex_data::Cid;
10use proto_blue_syntax::{AtIdentifier, AtUri, Did, Handle};
11use proto_blue_xrpc::{
12 CallOptions, HeadersMap, QueryParams, QueryValue, ResponseType, XrpcBody, XrpcClient,
13};
14
15use crate::rich_text::RichText;
16
/// Lifecycle events reported to listeners registered with `Agent::on_session`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtpSessionEvent {
    /// A session was established (login, account creation, or resume).
    Create,
    /// A login or account-creation attempt failed.
    CreateFailed,
    /// The session tokens were refreshed successfully.
    Update,
    /// The session ended: the refresh token was rejected, or the user logged out.
    Expired,
    /// A session refresh failed for a reason other than token rejection
    /// (transport failure, server error, ...); the session is kept.
    NetworkError,
}
36
37pub type SessionEventCallback = Arc<dyn Fn(AtpSessionEvent, Option<&Session>) + Send + Sync>;
44
/// Credentials and identity for an authenticated account.
///
/// (De)serialized with camelCase keys, so it can be built straight from a
/// `createSession` / `refreshSession` JSON response and persisted for later use
/// with `Agent::resume_session`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Session {
    /// DID of the account.
    pub did: Did,
    /// Handle of the account.
    pub handle: Handle,
    /// Token sent as `Authorization: Bearer …` on authenticated calls.
    pub access_jwt: String,
    /// Token sent as `Authorization: Bearer …` to `refreshSession` / `deleteSession`.
    pub refresh_jwt: String,
    /// Account email, when the server returned one; omitted from output if absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// Whether the email is confirmed, when the server reported it; omitted if absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email_confirmed: Option<bool>,
}
64
/// Errors returned by [`Agent`] operations.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
    /// Transport failure or XRPC error response from the service.
    #[error("XRPC error: {0}")]
    Xrpc(#[from] proto_blue_xrpc::Error),
    /// The operation needs a session but none is active.
    #[error("Not authenticated")]
    NotAuthenticated,
    /// A request or response body could not be converted to/from JSON.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// Any other failure; the message is shown verbatim.
    #[error("{0}")]
    Other(String),
}
77
/// High-level AT Protocol / Bluesky client that manages a session, refreshes
/// expired access tokens, and exposes convenience methods for common records.
pub struct Agent {
    /// Underlying XRPC transport bound to the service URL.
    client: XrpcClient,
    /// Current session; `None` when logged out. Shared with clones made by `with_proxy`.
    session: Arc<RwLock<Option<Session>>>,
    /// Session event callbacks. Shared with clones made by `with_proxy`.
    listeners: Arc<Mutex<Vec<SessionEventCallback>>>,
    /// Serializes token refreshes so concurrent expiries trigger a single refresh.
    refresh_lock: Arc<AsyncMutex<()>>,
    /// Value of the `atproto-proxy` header; per agent (clones start without one).
    proxy: Arc<RwLock<Option<String>>>,
    /// Labelers advertised via `atproto-accept-labelers`. Shared with clones.
    labelers: Arc<RwLock<Vec<LabelerOpts>>>,
}
112
/// One labeler entry for the `atproto-accept-labelers` request header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabelerOpts {
    /// DID of the labeler service.
    pub did: Did,
    /// When `true`, the entry is rendered with a `;redirect` suffix.
    pub redirect: bool,
}
122
123impl LabelerOpts {
124 fn header_value(&self) -> String {
126 if self.redirect {
127 format!("{};redirect", self.did)
128 } else {
129 self.did.to_string()
130 }
131 }
132}
133
134impl Agent {
135 #[cfg(any(
141 all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
142 target_arch = "wasm32",
143 ))]
144 pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
145 let client = XrpcClient::new(service)?;
146 Ok(Self {
147 client,
148 session: Arc::new(RwLock::new(None)),
149 listeners: Arc::new(Mutex::new(Vec::new())),
150 refresh_lock: Arc::new(AsyncMutex::new(())),
151 proxy: Arc::new(RwLock::new(None)),
152 labelers: Arc::new(RwLock::new(Vec::new())),
153 })
154 }
155
156 pub fn on_session<F>(&self, callback: F)
163 where
164 F: Fn(AtpSessionEvent, Option<&Session>) + Send + Sync + 'static,
165 {
166 self.listeners.lock().unwrap().push(Arc::new(callback));
167 }
168
169 fn emit(&self, event: AtpSessionEvent, session: Option<&Session>) {
171 let listeners = self.listeners.lock().unwrap().clone();
176 for cb in listeners {
177 cb(event, session);
178 }
179 }
180
181 #[must_use]
183 pub fn service(&self) -> String {
184 self.client.service_url().to_string()
185 }
186
187 pub async fn did(&self) -> Option<Did> {
189 self.session.read().await.as_ref().map(|s| s.did.clone())
190 }
191
192 pub async fn session(&self) -> Option<Session> {
194 self.session.read().await.clone()
195 }
196
197 async fn auth_call_options(&self) -> Option<CallOptions> {
204 let access_jwt = {
208 let guard = self.session.read().await;
209 guard.as_ref()?.access_jwt.clone()
210 };
211 let mut headers = HeadersMap::new();
212 headers.insert("Authorization".into(), format!("Bearer {access_jwt}"));
213 self.inject_proxy_and_labelers(&mut headers).await;
214 Some(CallOptions {
215 encoding: None,
216 headers: Some(headers),
217 ..Default::default()
218 })
219 }
220
221 pub async fn anon_call_options(&self) -> Option<CallOptions> {
227 let mut headers = HeadersMap::new();
228 self.inject_proxy_and_labelers(&mut headers).await;
229 if headers.is_empty() {
230 None
231 } else {
232 Some(CallOptions {
233 encoding: None,
234 headers: Some(headers),
235 ..Default::default()
236 })
237 }
238 }
239
240 async fn inject_proxy_and_labelers(&self, headers: &mut HeadersMap) {
241 if let Some(proxy) = self.proxy.read().await.as_ref() {
242 headers.insert("atproto-proxy".into(), proxy.clone());
243 }
244 let labelers = self.labelers.read().await;
245 if !labelers.is_empty() {
246 let v = labelers
247 .iter()
248 .map(LabelerOpts::header_value)
249 .collect::<Vec<_>>()
250 .join(", ");
251 drop(labelers);
252 headers.insert("atproto-accept-labelers".into(), v);
253 }
254 }
255
256 pub async fn configure_proxy(&self, target: Option<&str>) {
262 *self.proxy.write().await = target.map(String::from);
263 }
264
265 pub async fn with_proxy(&self, target: &str) -> Self {
268 let cloned = self.shallow_clone();
269 cloned.configure_proxy(Some(target)).await;
270 cloned
271 }
272
273 pub async fn configure_labelers(&self, labelers: &[LabelerOpts]) {
276 *self.labelers.write().await = labelers.to_vec();
277 }
278
279 fn shallow_clone(&self) -> Self {
283 Self {
284 client: self.client.clone(),
285 session: self.session.clone(),
286 listeners: self.listeners.clone(),
287 refresh_lock: self.refresh_lock.clone(),
288 proxy: Arc::new(RwLock::new(None)),
289 labelers: self.labelers.clone(),
290 }
291 }
292
293 pub async fn login(
299 &self,
300 identifier: &AtIdentifier,
301 password: &str,
302 ) -> Result<Session, AgentError> {
303 let body = serde_json::json!({
304 "identifier": identifier,
305 "password": password,
306 });
307
308 let response = match self
309 .client
310 .procedure(
311 "com.atproto.server.createSession",
312 None,
313 Some(XrpcBody::Json(body)),
314 None,
315 )
316 .await
317 {
318 Ok(r) => r,
319 Err(e) => {
320 self.emit(AtpSessionEvent::CreateFailed, None);
321 return Err(AgentError::Xrpc(e));
322 }
323 };
324
325 let session: Session = serde_json::from_value(response.data)?;
326
327 *self.session.write().await = Some(session.clone());
329 self.emit(AtpSessionEvent::Create, Some(&session));
330 Ok(session)
331 }
332
333 pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
338 let mut headers = HeadersMap::new();
341 headers.insert(
342 "Authorization".into(),
343 format!("Bearer {}", session.access_jwt),
344 );
345 let opts = CallOptions {
346 encoding: None,
347 headers: Some(headers),
348 ..Default::default()
349 };
350 let response = self
351 .client
352 .query("com.atproto.server.getSession", None, Some(&opts))
353 .await?;
354 let verified_did = response
355 .data
356 .get("did")
357 .and_then(|v| v.as_str())
358 .map(Did::new)
359 .transpose()
360 .map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))?;
361
362 let mut committed = session;
364 if let Some(did) = verified_did {
365 committed.did = did;
366 }
367 *self.session.write().await = Some(committed.clone());
368 self.emit(AtpSessionEvent::Create, Some(&committed));
369
370 Ok(())
371 }
372
373 #[allow(clippy::significant_drop_tightening)]
387 pub async fn refresh_session(&self) -> Result<Session, AgentError> {
388 let refresh_jwt = {
389 let sess = self.session.read().await;
390 let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
391 sess.refresh_jwt.clone()
392 };
393
394 let mut headers = HeadersMap::new();
396 headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
397 let opts = CallOptions {
398 encoding: None,
399 headers: Some(headers),
400 ..Default::default()
401 };
402
403 let response = match self
404 .client
405 .procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
406 .await
407 {
408 Ok(r) => r,
409 Err(e) => {
410 if is_refresh_rejected(&e) {
416 *self.session.write().await = None;
417 self.emit(AtpSessionEvent::Expired, None);
418 } else {
419 self.emit(AtpSessionEvent::NetworkError, None);
420 }
421 return Err(AgentError::Xrpc(e));
422 }
423 };
424
425 let session: Session = serde_json::from_value(response.data)?;
426
427 *self.session.write().await = Some(session.clone());
429 self.emit(AtpSessionEvent::Update, Some(&session));
430 Ok(session)
431 }
432
433 async fn assert_did(&self) -> Result<Did, AgentError> {
437 self.did().await.ok_or(AgentError::NotAuthenticated)
438 }
439
440 async fn xrpc_query(
447 &self,
448 nsid: &str,
449 params: Option<&QueryParams>,
450 ) -> Result<serde_json::Value, AgentError> {
451 let opts = self.auth_call_options().await;
452 let first = self.client.query(nsid, params, opts.as_ref()).await;
453 match first {
454 Ok(r) => Ok(r.data),
455 Err(e) if is_auth_expired(&e) => {
456 self.refresh_and_retry(|opts| {
457 let c = self.client.clone();
458 let nsid = nsid.to_string();
459 let params = params.cloned();
460 async move { c.query(&nsid, params.as_ref(), opts.as_ref()).await }
461 })
462 .await
463 }
464 Err(e) => Err(AgentError::Xrpc(e)),
465 }
466 }
467
468 async fn xrpc_procedure(
470 &self,
471 nsid: &str,
472 body: serde_json::Value,
473 ) -> Result<serde_json::Value, AgentError> {
474 let opts = self.auth_call_options().await;
475 let first = self
476 .client
477 .procedure(
478 nsid,
479 None,
480 Some(XrpcBody::Json(body.clone())),
481 opts.as_ref(),
482 )
483 .await;
484 match first {
485 Ok(r) => Ok(r.data),
486 Err(e) if is_auth_expired(&e) => {
487 self.refresh_and_retry(|opts| {
488 let c = self.client.clone();
489 let nsid = nsid.to_string();
490 let body = body.clone();
491 async move {
492 c.procedure(&nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
493 .await
494 }
495 })
496 .await
497 }
498 Err(e) => Err(AgentError::Xrpc(e)),
499 }
500 }
501
502 async fn refresh_and_retry<F, Fut>(&self, replay: F) -> Result<serde_json::Value, AgentError>
512 where
513 F: FnOnce(Option<CallOptions>) -> Fut,
514 Fut: std::future::Future<
515 Output = Result<proto_blue_xrpc::XrpcResponse, proto_blue_xrpc::Error>,
516 >,
517 {
518 let pre_refresh_jwt = self
522 .session
523 .read()
524 .await
525 .as_ref()
526 .map(|s| s.access_jwt.clone());
527 let guard = self.refresh_lock.lock().await;
528 let current_jwt = self
529 .session
530 .read()
531 .await
532 .as_ref()
533 .map(|s| s.access_jwt.clone());
534 if pre_refresh_jwt == current_jwt {
535 self.refresh_session().await?;
537 }
538 drop(guard);
539
540 let opts = self.auth_call_options().await;
541 let response = replay(opts).await?;
542 Ok(response.data)
543 }
544
545 async fn create_record(
547 &self,
548 collection: &str,
549 record: serde_json::Value,
550 ) -> Result<serde_json::Value, AgentError> {
551 let did = self.assert_did().await?;
552 let body = serde_json::json!({
553 "repo": did,
554 "collection": collection,
555 "record": record,
556 });
557 self.xrpc_procedure("com.atproto.repo.createRecord", body)
558 .await
559 }
560
561 async fn delete_record(&self, collection: &str, uri: &AtUri) -> Result<(), AgentError> {
563 let did = self.assert_did().await?;
564 let rkey = uri
565 .rkey()
566 .ok_or_else(|| AgentError::Other("AT-URI has no rkey segment".into()))?;
567
568 let body = serde_json::json!({
569 "repo": did,
570 "collection": collection,
571 "rkey": rkey,
572 });
573 self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
574 .await?;
575 Ok(())
576 }
577
578 fn now_iso() -> String {
580 chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
581 }
582
583 fn resolve_timestamp(created_at: Option<&str>) -> String {
585 created_at.map_or_else(Self::now_iso, String::from)
586 }
587
588 pub async fn post(
594 &self,
595 text: &str,
596 facets: Option<Vec<crate::rich_text::Facet>>,
597 created_at: Option<&str>,
598 ) -> Result<serde_json::Value, AgentError> {
599 let mut record = serde_json::json!({
600 "$type": "app.bsky.feed.post",
601 "text": text,
602 "createdAt": Self::resolve_timestamp(created_at),
603 });
604
605 if let Some(facets) = facets {
606 record["facets"] = serde_json::to_value(&facets)?;
607 }
608
609 self.create_record("app.bsky.feed.post", record).await
610 }
611
612 pub async fn post_rich(
614 &self,
615 rt: &RichText,
616 created_at: Option<&str>,
617 ) -> Result<serde_json::Value, AgentError> {
618 let facets = if rt.facets().is_empty() {
619 None
620 } else {
621 Some(rt.facets().to_vec())
622 };
623 self.post(rt.text(), facets, created_at).await
624 }
625
626 pub async fn delete_post(&self, uri: &AtUri) -> Result<(), AgentError> {
628 self.delete_record("app.bsky.feed.post", uri).await
629 }
630
631 pub async fn like(
637 &self,
638 uri: &AtUri,
639 cid: &Cid,
640 created_at: Option<&str>,
641 ) -> Result<serde_json::Value, AgentError> {
642 let record = serde_json::json!({
643 "$type": "app.bsky.feed.like",
644 "subject": { "uri": uri, "cid": cid },
645 "createdAt": Self::resolve_timestamp(created_at),
646 });
647 self.create_record("app.bsky.feed.like", record).await
648 }
649
650 pub async fn delete_like(&self, like_uri: &AtUri) -> Result<(), AgentError> {
652 self.delete_record("app.bsky.feed.like", like_uri).await
653 }
654
655 pub async fn repost(
659 &self,
660 uri: &AtUri,
661 cid: &Cid,
662 created_at: Option<&str>,
663 ) -> Result<serde_json::Value, AgentError> {
664 let record = serde_json::json!({
665 "$type": "app.bsky.feed.repost",
666 "subject": { "uri": uri, "cid": cid },
667 "createdAt": Self::resolve_timestamp(created_at),
668 });
669 self.create_record("app.bsky.feed.repost", record).await
670 }
671
672 pub async fn delete_repost(&self, repost_uri: &AtUri) -> Result<(), AgentError> {
674 self.delete_record("app.bsky.feed.repost", repost_uri).await
675 }
676
677 pub async fn follow(
683 &self,
684 subject_did: &Did,
685 created_at: Option<&str>,
686 ) -> Result<serde_json::Value, AgentError> {
687 let record = serde_json::json!({
688 "$type": "app.bsky.graph.follow",
689 "subject": subject_did,
690 "createdAt": Self::resolve_timestamp(created_at),
691 });
692 self.create_record("app.bsky.graph.follow", record).await
693 }
694
695 pub async fn delete_follow(&self, follow_uri: &AtUri) -> Result<(), AgentError> {
697 self.delete_record("app.bsky.graph.follow", follow_uri)
698 .await
699 }
700
701 pub async fn get_profile(&self, actor: &AtIdentifier) -> Result<serde_json::Value, AgentError> {
705 let mut params = QueryParams::new();
706 params.insert("actor".into(), QueryValue::String(actor.to_string()));
707 self.xrpc_query("app.bsky.actor.getProfile", Some(¶ms))
708 .await
709 }
710
711 pub async fn get_timeline(
713 &self,
714 limit: Option<i64>,
715 cursor: Option<&str>,
716 ) -> Result<serde_json::Value, AgentError> {
717 let mut params = QueryParams::new();
718 if let Some(limit) = limit {
719 params.insert("limit".into(), QueryValue::Integer(limit));
720 }
721 if let Some(cursor) = cursor {
722 params.insert("cursor".into(), QueryValue::String(cursor.into()));
723 }
724 self.xrpc_query("app.bsky.feed.getTimeline", Some(¶ms))
725 .await
726 }
727
728 pub async fn get_post_thread(
730 &self,
731 uri: &AtUri,
732 depth: Option<i64>,
733 ) -> Result<serde_json::Value, AgentError> {
734 let mut params = QueryParams::new();
735 params.insert("uri".into(), QueryValue::String(uri.to_string()));
736 if let Some(depth) = depth {
737 params.insert("depth".into(), QueryValue::Integer(depth));
738 }
739 self.xrpc_query("app.bsky.feed.getPostThread", Some(¶ms))
740 .await
741 }
742
743 pub async fn search_actors(
745 &self,
746 query: &str,
747 limit: Option<i64>,
748 ) -> Result<serde_json::Value, AgentError> {
749 let mut params = QueryParams::new();
750 params.insert("q".into(), QueryValue::String(query.into()));
751 if let Some(limit) = limit {
752 params.insert("limit".into(), QueryValue::Integer(limit));
753 }
754 self.xrpc_query("app.bsky.actor.searchActors", Some(¶ms))
755 .await
756 }
757
758 pub async fn resolve_handle(&self, handle: &Handle) -> Result<Did, AgentError> {
760 let mut params = QueryParams::new();
761 params.insert("handle".into(), QueryValue::String(handle.to_string()));
762 let data = self
763 .xrpc_query("com.atproto.identity.resolveHandle", Some(¶ms))
764 .await?;
765 let did_str = data
766 .get("did")
767 .and_then(|v| v.as_str())
768 .ok_or_else(|| AgentError::Other("Missing DID in response".into()))?;
769 Did::new(did_str)
770 .map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))
771 }
772
773 pub async fn list_notifications(
775 &self,
776 limit: Option<i64>,
777 cursor: Option<&str>,
778 ) -> Result<serde_json::Value, AgentError> {
779 let mut params = QueryParams::new();
780 if let Some(limit) = limit {
781 params.insert("limit".into(), QueryValue::Integer(limit));
782 }
783 if let Some(cursor) = cursor {
784 params.insert("cursor".into(), QueryValue::String(cursor.into()));
785 }
786 self.xrpc_query("app.bsky.notification.listNotifications", Some(¶ms))
787 .await
788 }
789
790 pub async fn upload_blob(
792 &self,
793 data: Vec<u8>,
794 content_type: &str,
795 ) -> Result<serde_json::Value, AgentError> {
796 let mut headers = HeadersMap::new();
797 headers.insert("Content-Type".into(), content_type.into());
798
799 if let Some(sess) = self.session.read().await.as_ref() {
801 headers.insert(
802 "Authorization".into(),
803 format!("Bearer {}", sess.access_jwt),
804 );
805 }
806
807 let opts = CallOptions {
808 encoding: Some(content_type.to_string()),
809 headers: Some(headers),
810 ..Default::default()
811 };
812
813 let response = self
814 .client
815 .procedure(
816 "com.atproto.repo.uploadBlob",
817 None,
818 Some(XrpcBody::Bytes(data)),
819 Some(&opts),
820 )
821 .await?;
822
823 Ok(response.data)
824 }
825
826 pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
828 self.xrpc_query("com.atproto.server.describeServer", None)
829 .await
830 }
831
832 pub async fn logout(&self) -> Result<(), AgentError> {
842 let refresh_jwt = {
843 let guard = self.session.read().await;
844 guard.as_ref().map(|s| s.refresh_jwt.clone())
845 };
846
847 let server_result = if let Some(refresh_jwt) = refresh_jwt {
848 let mut headers = HeadersMap::new();
849 headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
850 let opts = CallOptions {
851 encoding: None,
852 headers: Some(headers),
853 ..Default::default()
854 };
855 self.client
856 .procedure("com.atproto.server.deleteSession", None, None, Some(&opts))
857 .await
858 .map(|_| ())
859 } else {
860 Ok(())
861 };
862
863 *self.session.write().await = None;
865 self.emit(AtpSessionEvent::Expired, None);
866
867 server_result.map_err(AgentError::Xrpc)
868 }
869
870 pub async fn create_account(
879 &self,
880 handle: &Handle,
881 password: &str,
882 email: Option<&str>,
883 extra: Option<serde_json::Value>,
884 ) -> Result<Session, AgentError> {
885 let mut body = serde_json::json!({
886 "handle": handle,
887 "password": password,
888 });
889 if let Some(email) = email {
890 body["email"] = serde_json::Value::String(email.to_string());
891 }
892 if let Some(extra) = extra
893 && let Some(extra_map) = extra.as_object()
894 && let Some(body_map) = body.as_object_mut()
895 {
896 for (k, v) in extra_map {
897 body_map.insert(k.clone(), v.clone());
898 }
899 }
900
901 let response = match self
902 .client
903 .procedure(
904 "com.atproto.server.createAccount",
905 None,
906 Some(XrpcBody::Json(body)),
907 None,
908 )
909 .await
910 {
911 Ok(r) => r,
912 Err(e) => {
913 self.emit(AtpSessionEvent::CreateFailed, None);
914 return Err(AgentError::Xrpc(e));
915 }
916 };
917
918 let session: Session = serde_json::from_value(response.data)?;
919 *self.session.write().await = Some(session.clone());
920 self.emit(AtpSessionEvent::Create, Some(&session));
921 Ok(session)
922 }
923
924 pub async fn upsert_profile<F>(&self, mutate: F) -> Result<serde_json::Value, AgentError>
935 where
936 F: Fn(serde_json::Value) -> serde_json::Value,
937 {
938 const MAX_RETRIES: u32 = 5;
939
940 let did = self.assert_did().await?;
941
942 for _ in 0..MAX_RETRIES {
943 let existing_result = self
945 .xrpc_query(
946 "com.atproto.repo.getRecord",
947 Some(&{
948 let mut p = QueryParams::new();
949 p.insert("repo".into(), QueryValue::String(did.to_string()));
950 p.insert(
951 "collection".into(),
952 QueryValue::String("app.bsky.actor.profile".into()),
953 );
954 p.insert("rkey".into(), QueryValue::String("self".into()));
955 p
956 }),
957 )
958 .await;
959
960 let (existing_record, swap_cid) = match existing_result {
961 Ok(r) => {
962 let record = r.get("value").cloned().unwrap_or(serde_json::Value::Null);
963 let cid = r.get("cid").and_then(|v| v.as_str()).map(String::from);
964 (record, cid)
965 }
966 Err(AgentError::Xrpc(ref e)) if is_not_found(e) => (serde_json::Value::Null, None),
967 Err(e) => return Err(e),
968 };
969
970 let updated = mutate(existing_record);
971 let mut body = serde_json::json!({
972 "repo": did,
973 "collection": "app.bsky.actor.profile",
974 "rkey": "self",
975 "record": updated,
976 });
977 if let Some(cid) = swap_cid {
978 body["swapRecord"] = serde_json::Value::String(cid);
979 }
980
981 match self
982 .xrpc_procedure("com.atproto.repo.putRecord", body)
983 .await
984 {
985 Ok(r) => return Ok(r),
986 Err(AgentError::Xrpc(ref e)) if is_invalid_swap(e) => {
987 }
991 Err(e) => return Err(e),
992 }
993 }
994
995 Err(AgentError::Other(
996 "upsert_profile: exceeded maximum retries due to concurrent writes".into(),
997 ))
998 }
999}
1000
1001fn is_not_found(err: &proto_blue_xrpc::Error) -> bool {
1004 match err {
1005 proto_blue_xrpc::Error::Xrpc(x) => x.is_error("RecordNotFound"),
1006 _ => false,
1007 }
1008}
1009
1010fn is_invalid_swap(err: &proto_blue_xrpc::Error) -> bool {
1013 match err {
1014 proto_blue_xrpc::Error::Xrpc(x) => x.is_error("InvalidSwap"),
1015 _ => false,
1016 }
1017}
1018
1019fn is_auth_expired(err: &proto_blue_xrpc::Error) -> bool {
1026 match err {
1027 proto_blue_xrpc::Error::Xrpc(x) => {
1028 matches!(x.status, ResponseType::AuthenticationRequired) && x.is_error("ExpiredToken")
1029 }
1030 _ => false,
1031 }
1032}
1033
1034const fn is_refresh_rejected(err: &proto_blue_xrpc::Error) -> bool {
1039 match err {
1040 proto_blue_xrpc::Error::Xrpc(x) => {
1041 matches!(x.status, ResponseType::AuthenticationRequired)
1042 }
1043 _ => false,
1044 }
1045}
1046
1047#[cfg(test)]
1048mod tests {
1049 use super::*;
1050
1051 #[test]
1052 fn agent_creation() {
1053 let _agent = Agent::new("https://bsky.social").unwrap();
1054 }
1055
1056 #[test]
1057 fn session_serde_roundtrip() {
1058 let session = Session {
1059 did: Did::new("did:plc:abc123").unwrap(),
1060 handle: Handle::new("alice.bsky.social").unwrap(),
1061 access_jwt: "eyJ...".to_string(),
1062 refresh_jwt: "eyJ...".to_string(),
1063 email: Some("alice@example.com".to_string()),
1064 email_confirmed: Some(true),
1065 };
1066
1067 let json = serde_json::to_string(&session).unwrap();
1068 let parsed: Session = serde_json::from_str(&json).unwrap();
1069 assert_eq!(parsed.did.as_str(), "did:plc:abc123");
1070 assert_eq!(parsed.handle.as_str(), "alice.bsky.social");
1071 assert_eq!(parsed.email, Some("alice@example.com".to_string()));
1072 }
1073
1074 #[test]
1075 fn agent_error_display() {
1076 let err = AgentError::NotAuthenticated;
1077 assert_eq!(err.to_string(), "Not authenticated");
1078
1079 let err = AgentError::Other("test error".into());
1080 assert_eq!(err.to_string(), "test error");
1081 }
1082
1083 #[tokio::test]
1084 async fn agent_no_session_by_default() {
1085 let agent = Agent::new("https://bsky.social").unwrap();
1086 assert!(agent.did().await.is_none());
1087 assert!(agent.session().await.is_none());
1088 }
1089
1090 #[tokio::test]
1091 async fn agent_assert_did_fails_when_not_logged_in() {
1092 let agent = Agent::new("https://bsky.social").unwrap();
1093 let err = agent.assert_did().await.unwrap_err();
1094 assert!(matches!(err, AgentError::NotAuthenticated));
1095 }
1096
1097 #[test]
1098 fn now_iso_format() {
1099 let ts = Agent::now_iso();
1100 assert!(ts.ends_with('Z'));
1101 assert!(ts.contains('T'));
1102 }
1103
1104 #[test]
1105 fn resolve_timestamp_with_provided() {
1106 let ts = Agent::resolve_timestamp(Some("2024-01-15T12:00:00.000Z"));
1107 assert_eq!(ts, "2024-01-15T12:00:00.000Z");
1108 }
1109
1110 #[test]
1111 fn resolve_timestamp_without_provided() {
1112 let ts = Agent::resolve_timestamp(None);
1113 assert!(ts.ends_with('Z'));
1114 assert!(ts.contains('T'));
1115 }
1116
1117 #[test]
1118 fn service_url_accessible_without_async() {
1119 let agent = Agent::new("https://bsky.social").unwrap();
1120 assert_eq!(agent.service(), "https://bsky.social/");
1121 }
1122
1123 #[tokio::test]
1124 async fn auth_call_options_none_when_not_authenticated() {
1125 let agent = Agent::new("https://bsky.social").unwrap();
1126 assert!(agent.auth_call_options().await.is_none());
1127 }
1128
1129 use async_trait::async_trait;
1132 use proto_blue_common::fetch::{FetchError, FetchHandler, HttpRequest, HttpResponse};
1133
1134 struct ScriptedFetcher {
1138 createsession_body: Vec<u8>,
1139 scripts: std::sync::Mutex<std::collections::HashMap<String, Vec<ScriptedResponse>>>,
1141 call_counts: std::sync::Mutex<std::collections::HashMap<String, usize>>,
1142 }
1143
1144 #[derive(Clone)]
1145 struct ScriptedResponse {
1146 status: u16,
1147 body: Vec<u8>,
1148 }
1149
1150 impl ScriptedFetcher {
1151 fn new(createsession_body: Vec<u8>) -> Self {
1152 Self {
1153 createsession_body,
1154 scripts: Default::default(),
1155 call_counts: Default::default(),
1156 }
1157 }
1158 fn script(&self, path: &str, responses: Vec<ScriptedResponse>) {
1159 self.scripts
1160 .lock()
1161 .unwrap()
1162 .insert(path.to_string(), responses);
1163 }
1164 fn call_count(&self, path: &str) -> usize {
1165 *self.call_counts.lock().unwrap().get(path).unwrap_or(&0)
1166 }
1167 }
1168
1169 #[async_trait]
1170 impl FetchHandler for ScriptedFetcher {
1171 async fn fetch(&self, req: HttpRequest) -> Result<HttpResponse, FetchError> {
1172 let path = req.url.clone();
1173 let key = path
1174 .split("/xrpc/")
1175 .nth(1)
1176 .unwrap_or(&path)
1177 .split('?')
1178 .next()
1179 .unwrap_or("")
1180 .to_string();
1181 *self
1182 .call_counts
1183 .lock()
1184 .unwrap()
1185 .entry(key.clone())
1186 .or_insert(0) += 1;
1187
1188 {
1192 let mut scripts = self.scripts.lock().unwrap();
1193 if let Some(list) = scripts.get_mut(&key) {
1194 let resp = if list.len() == 1 {
1195 list[0].clone()
1196 } else {
1197 list.remove(0)
1198 };
1199 let mut headers = proto_blue_common::fetch::HttpHeaders::new();
1200 headers.insert("content-type".into(), "application/json".into());
1201 return Ok(HttpResponse {
1202 status: resp.status,
1203 headers,
1204 body: resp.body,
1205 });
1206 }
1207 }
1208
1209 if key == "com.atproto.server.createSession" {
1211 let mut headers = proto_blue_common::fetch::HttpHeaders::new();
1212 headers.insert("content-type".into(), "application/json".into());
1213 return Ok(HttpResponse {
1214 status: 200,
1215 headers,
1216 body: self.createsession_body.clone(),
1217 });
1218 }
1219
1220 Err(FetchError::Other(format!("no script for {key}")))
1221 }
1222 }
1223
1224 fn login_body() -> Vec<u8> {
1225 br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a1","refreshJwt":"r1"}"#.to_vec()
1226 }
1227
1228 fn agent_with_fetcher(fetcher: Arc<ScriptedFetcher>) -> Agent {
1229 let client = XrpcClient::with_fetch_handler("https://example.com", fetcher).unwrap();
1230 Agent {
1231 client,
1232 session: Arc::new(RwLock::new(None)),
1233 listeners: Arc::new(Mutex::new(Vec::new())),
1234 refresh_lock: Arc::new(AsyncMutex::new(())),
1235 proxy: Arc::new(RwLock::new(None)),
1236 labelers: Arc::new(RwLock::new(Vec::new())),
1237 }
1238 }
1239
1240 #[tokio::test]
1241 async fn emits_create_on_successful_login() {
1242 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1243 let agent = agent_with_fetcher(fetcher);
1244
1245 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1246 let ev_clone = events.clone();
1247 agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
1248
1249 agent
1250 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1251 .await
1252 .unwrap();
1253 let got = events.lock().unwrap().clone();
1254 assert_eq!(got, vec![AtpSessionEvent::Create]);
1255 }
1256
1257 #[tokio::test]
1258 async fn emits_create_failed_on_login_rejection() {
1259 let fetcher = Arc::new(ScriptedFetcher::new(vec![]));
1260 fetcher.script(
1262 "com.atproto.server.createSession",
1263 vec![ScriptedResponse {
1264 status: 401,
1265 body: br#"{"error":"AuthenticationRequired","message":"bad pwd"}"#.to_vec(),
1266 }],
1267 );
1268 let agent = agent_with_fetcher(fetcher);
1269
1270 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1271 let ev_clone = events.clone();
1272 agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
1273
1274 let _ = agent
1278 .login(&AtIdentifier::new("alice.test").unwrap(), "bad")
1279 .await
1280 .unwrap_err();
1281 let got = events.lock().unwrap().clone();
1282 assert_eq!(got, vec![AtpSessionEvent::CreateFailed]);
1283 }
1284
1285 #[tokio::test]
1286 async fn auto_refreshes_on_expired_access_token() {
1287 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1288
1289 fetcher.script(
1292 "com.atproto.server.describeServer",
1293 vec![
1294 ScriptedResponse {
1295 status: 401,
1296 body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
1297 },
1298 ScriptedResponse {
1299 status: 200,
1300 body: br#"{"did":"did:plc:svr"}"#.to_vec(),
1301 },
1302 ],
1303 );
1304 fetcher.script(
1305 "com.atproto.server.refreshSession",
1306 vec![ScriptedResponse {
1307 status: 200,
1308 body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
1309 .to_vec(),
1310 }],
1311 );
1312
1313 let agent = agent_with_fetcher(fetcher.clone());
1314 agent
1315 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1316 .await
1317 .unwrap();
1318
1319 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1320 let ev_clone = events.clone();
1321 agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
1322
1323 let result = agent.describe_server().await.unwrap();
1324 assert_eq!(result["did"], "did:plc:svr");
1325
1326 assert_eq!(fetcher.call_count("com.atproto.server.describeServer"), 2);
1329 assert_eq!(fetcher.call_count("com.atproto.server.refreshSession"), 1);
1330
1331 let got = events.lock().unwrap().clone();
1333 assert_eq!(got, vec![AtpSessionEvent::Update]);
1334 }
1335
1336 #[tokio::test]
1337 async fn concurrent_expired_token_refreshes_once() {
1338 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1339
1340 fetcher.script(
1343 "com.atproto.server.describeServer",
1344 vec![
1345 ScriptedResponse {
1346 status: 401,
1347 body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
1348 },
1349 ScriptedResponse {
1350 status: 200,
1351 body: br#"{"did":"did:plc:svr"}"#.to_vec(),
1352 },
1353 ],
1354 );
1355 fetcher.script(
1356 "com.atproto.server.refreshSession",
1357 vec![ScriptedResponse {
1358 status: 200,
1359 body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
1360 .to_vec(),
1361 }],
1362 );
1363
1364 let agent = Arc::new(agent_with_fetcher(fetcher.clone()));
1365 agent
1366 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1367 .await
1368 .unwrap();
1369
1370 let mut handles = Vec::new();
1374 for _ in 0..5 {
1375 let a = agent.clone();
1376 handles.push(tokio::spawn(async move {
1377 a.describe_server().await.unwrap();
1378 }));
1379 }
1380 for h in handles {
1381 h.await.unwrap();
1382 }
1383
1384 assert_eq!(
1385 fetcher.call_count("com.atproto.server.refreshSession"),
1386 1,
1387 "concurrent callers must share one refreshSession call",
1388 );
1389 }
1390
1391 #[tokio::test]
1392 async fn configure_proxy_sets_header_on_next_call() {
1393 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1394 fetcher.script(
1395 "com.atproto.server.describeServer",
1396 vec![ScriptedResponse {
1397 status: 200,
1398 body: br#"{"did":"did:plc:svr"}"#.to_vec(),
1399 }],
1400 );
1401 let agent = agent_with_fetcher(fetcher.clone());
1402 agent
1403 .configure_proxy(Some("did:web:api.bsky.chat#bsky_chat"))
1404 .await;
1405
1406 agent.describe_server().await.unwrap();
1407
1408 let p = agent.proxy.read().await;
1411 assert_eq!(p.as_deref(), Some("did:web:api.bsky.chat#bsky_chat"));
1412 }
1413
1414 #[tokio::test]
1415 async fn configure_labelers_stores_list() {
1416 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1417 let agent = agent_with_fetcher(fetcher);
1418 agent
1419 .configure_labelers(&[
1420 LabelerOpts {
1421 did: Did::new("did:plc:a").unwrap(),
1422 redirect: false,
1423 },
1424 LabelerOpts {
1425 did: Did::new("did:plc:b").unwrap(),
1426 redirect: true,
1427 },
1428 ])
1429 .await;
1430 let l = agent.labelers.read().await;
1431 assert_eq!(l.len(), 2);
1432 assert_eq!(l[0].header_value(), "did:plc:a");
1433 assert_eq!(l[1].header_value(), "did:plc:b;redirect");
1434 }
1435
1436 #[tokio::test]
1437 async fn logout_clears_session() {
1438 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1439 fetcher.script(
1440 "com.atproto.server.deleteSession",
1441 vec![ScriptedResponse {
1442 status: 200,
1443 body: b"{}".to_vec(),
1444 }],
1445 );
1446 let agent = agent_with_fetcher(fetcher.clone());
1447 agent
1448 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1449 .await
1450 .unwrap();
1451 assert!(agent.session().await.is_some());
1452 agent.logout().await.unwrap();
1453 assert!(agent.session().await.is_none());
1454 assert_eq!(fetcher.call_count("com.atproto.server.deleteSession"), 1,);
1455 }
1456
1457 #[tokio::test]
1458 async fn logout_clears_session_even_on_server_error() {
1459 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1460 fetcher.script(
1461 "com.atproto.server.deleteSession",
1462 vec![ScriptedResponse {
1463 status: 500,
1464 body: br#"{"error":"InternalServerError"}"#.to_vec(),
1465 }],
1466 );
1467 let agent = agent_with_fetcher(fetcher);
1468 agent
1469 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1470 .await
1471 .unwrap();
1472 let _ = agent.logout().await;
1474 assert!(agent.session().await.is_none());
1475 }
1476
1477 #[tokio::test]
1478 async fn create_account_emits_create_on_success() {
1479 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1480 fetcher.script(
1481 "com.atproto.server.createAccount",
1482 vec![ScriptedResponse {
1483 status: 200,
1484 body:
1485 br#"{"did":"did:plc:new","handle":"newuser.test","accessJwt":"a","refreshJwt":"r"}"#
1486 .to_vec(),
1487 }],
1488 );
1489 let agent = agent_with_fetcher(fetcher);
1490
1491 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1492 let ev = events.clone();
1493 agent.on_session(move |e, _| ev.lock().unwrap().push(e));
1494
1495 let session = agent
1496 .create_account(
1497 &Handle::new("newuser.test").unwrap(),
1498 "pw",
1499 Some("new@example.com"),
1500 None,
1501 )
1502 .await
1503 .unwrap();
1504 assert_eq!(session.did.as_str(), "did:plc:new");
1505 assert_eq!(
1506 events.lock().unwrap().clone(),
1507 vec![AtpSessionEvent::Create]
1508 );
1509 }
1510
1511 #[tokio::test]
1512 async fn upsert_profile_creates_when_absent() {
1513 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1514 fetcher.script(
1516 "com.atproto.repo.getRecord",
1517 vec![ScriptedResponse {
1518 status: 400,
1519 body: br#"{"error":"RecordNotFound","message":"no such record"}"#.to_vec(),
1520 }],
1521 );
1522 fetcher.script(
1523 "com.atproto.repo.putRecord",
1524 vec![ScriptedResponse {
1525 status: 200,
1526 body: br#"{"uri":"at://did:plc:u/app.bsky.actor.profile/self","cid":"bafy"}"#
1527 .to_vec(),
1528 }],
1529 );
1530 let agent = agent_with_fetcher(fetcher);
1531 agent
1532 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1533 .await
1534 .unwrap();
1535
1536 let result = agent
1537 .upsert_profile(|prev| {
1538 assert!(prev.is_null(), "no existing profile");
1539 serde_json::json!({"$type": "app.bsky.actor.profile", "displayName": "Alice"})
1540 })
1541 .await
1542 .unwrap();
1543 assert_eq!(result["uri"], "at://did:plc:u/app.bsky.actor.profile/self");
1544 }
1545
1546 #[tokio::test]
1547 async fn emits_expired_when_refresh_itself_401s() {
1548 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1549 fetcher.script(
1550 "com.atproto.server.refreshSession",
1551 vec![ScriptedResponse {
1552 status: 401,
1553 body: br#"{"error":"AuthenticationRequired","message":"refresh expired"}"#.to_vec(),
1554 }],
1555 );
1556 let agent = agent_with_fetcher(fetcher);
1557 agent
1558 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1559 .await
1560 .unwrap();
1561
1562 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1563 let ev_clone = events.clone();
1564 agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
1565
1566 let _ = agent.refresh_session().await.unwrap_err();
1567 let got = events.lock().unwrap().clone();
1568 assert_eq!(got, vec![AtpSessionEvent::Expired]);
1569 assert!(
1570 agent.session().await.is_none(),
1571 "session cleared on expired refresh"
1572 );
1573 }
1574}