1use std::sync::{Arc, Mutex};
7use tokio::sync::{Mutex as AsyncMutex, RwLock};
8
9use proto_blue_xrpc::{
10 CallOptions, HeadersMap, QueryParams, QueryValue, ResponseType, XrpcBody, XrpcClient,
11};
12
13use crate::rich_text::RichText;
14
/// Session lifecycle notifications delivered to listeners registered with
/// `Agent::on_session`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AtpSessionEvent {
    /// A session was established and stored (login, account creation or resume).
    Create,
    /// A login or account-creation attempt failed; no session was stored.
    CreateFailed,
    /// The stored session was replaced after a successful token refresh.
    Update,
    /// The session was discarded: the refresh token was rejected, or the user
    /// logged out.
    Expired,
    /// A refresh attempt failed for a reason other than token rejection; the
    /// stored session was left in place.
    NetworkError,
}
34
35pub type SessionEventCallback = Arc<dyn Fn(AtpSessionEvent, Option<&Session>) + Send + Sync>;
42
43#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
45#[serde(rename_all = "camelCase")]
46pub struct Session {
47 pub did: String,
48 pub handle: String,
49 pub access_jwt: String,
50 pub refresh_jwt: String,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub email: Option<String>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub email_confirmed: Option<bool>,
55}
56
57#[derive(Debug, thiserror::Error)]
59pub enum AgentError {
60 #[error("XRPC error: {0}")]
61 Xrpc(#[from] proto_blue_xrpc::Error),
62 #[error("Not authenticated")]
63 NotAuthenticated,
64 #[error("JSON error: {0}")]
65 Json(#[from] serde_json::Error),
66 #[error("{0}")]
67 Other(String),
68}
69
70pub struct Agent {
88 client: XrpcClient,
89 session: Arc<RwLock<Option<Session>>>,
90 listeners: Arc<Mutex<Vec<SessionEventCallback>>>,
93 refresh_lock: Arc<AsyncMutex<()>>,
98 proxy: Arc<RwLock<Option<String>>>,
101 labelers: Arc<RwLock<Vec<LabelerOpts>>>,
103}
104
/// A labeler the agent asks the service to apply, rendered as one entry of the
/// `atproto-accept-labelers` request header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabelerOpts {
    /// DID of the labeler service.
    pub did: String,
    /// When `true`, the header entry carries a `;redirect` suffix.
    pub redirect: bool,
}

impl LabelerOpts {
    /// Formats this labeler as a header entry: the bare DID, or
    /// `"<did>;redirect"` when `redirect` is set.
    fn header_value(&self) -> String {
        let suffix = if self.redirect { ";redirect" } else { "" };
        format!("{}{suffix}", self.did)
    }
}
125
126impl Agent {
127 #[cfg(any(
133 all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
134 target_arch = "wasm32",
135 ))]
136 pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
137 let client = XrpcClient::new(service)?;
138 Ok(Self {
139 client,
140 session: Arc::new(RwLock::new(None)),
141 listeners: Arc::new(Mutex::new(Vec::new())),
142 refresh_lock: Arc::new(AsyncMutex::new(())),
143 proxy: Arc::new(RwLock::new(None)),
144 labelers: Arc::new(RwLock::new(Vec::new())),
145 })
146 }
147
148 pub fn on_session<F>(&self, callback: F)
155 where
156 F: Fn(AtpSessionEvent, Option<&Session>) + Send + Sync + 'static,
157 {
158 self.listeners.lock().unwrap().push(Arc::new(callback));
159 }
160
161 fn emit(&self, event: AtpSessionEvent, session: Option<&Session>) {
163 let listeners = self.listeners.lock().unwrap().clone();
168 for cb in listeners {
169 cb(event, session);
170 }
171 }
172
173 #[must_use]
175 pub fn service(&self) -> String {
176 self.client.service_url().to_string()
177 }
178
179 pub async fn did(&self) -> Option<String> {
181 self.session.read().await.as_ref().map(|s| s.did.clone())
182 }
183
184 pub async fn session(&self) -> Option<Session> {
186 self.session.read().await.clone()
187 }
188
189 async fn auth_call_options(&self) -> Option<CallOptions> {
196 let guard = self.session.read().await;
197 let session = guard.as_ref()?;
198 let mut headers = HeadersMap::new();
199 headers.insert(
200 "Authorization".into(),
201 format!("Bearer {}", session.access_jwt),
202 );
203 self.inject_proxy_and_labelers(&mut headers).await;
204 Some(CallOptions {
205 encoding: None,
206 headers: Some(headers),
207 ..Default::default()
208 })
209 }
210
211 pub async fn anon_call_options(&self) -> Option<CallOptions> {
217 let mut headers = HeadersMap::new();
218 self.inject_proxy_and_labelers(&mut headers).await;
219 if headers.is_empty() {
220 None
221 } else {
222 Some(CallOptions {
223 encoding: None,
224 headers: Some(headers),
225 ..Default::default()
226 })
227 }
228 }
229
230 async fn inject_proxy_and_labelers(&self, headers: &mut HeadersMap) {
231 if let Some(proxy) = self.proxy.read().await.as_ref() {
232 headers.insert("atproto-proxy".into(), proxy.clone());
233 }
234 let labelers = self.labelers.read().await;
235 if !labelers.is_empty() {
236 let v = labelers
237 .iter()
238 .map(LabelerOpts::header_value)
239 .collect::<Vec<_>>()
240 .join(", ");
241 headers.insert("atproto-accept-labelers".into(), v);
242 }
243 }
244
245 pub async fn configure_proxy(&self, target: Option<&str>) {
251 *self.proxy.write().await = target.map(String::from);
252 }
253
254 pub async fn with_proxy(&self, target: &str) -> Self {
257 let cloned = self.shallow_clone();
258 cloned.configure_proxy(Some(target)).await;
259 cloned
260 }
261
262 pub async fn configure_labelers(&self, labelers: &[LabelerOpts]) {
265 *self.labelers.write().await = labelers.to_vec();
266 }
267
268 fn shallow_clone(&self) -> Self {
272 Self {
273 client: self.client.clone(),
274 session: self.session.clone(),
275 listeners: self.listeners.clone(),
276 refresh_lock: self.refresh_lock.clone(),
277 proxy: Arc::new(RwLock::new(None)),
278 labelers: self.labelers.clone(),
279 }
280 }
281
282 pub async fn login(&self, identifier: &str, password: &str) -> Result<Session, AgentError> {
288 let body = serde_json::json!({
289 "identifier": identifier,
290 "password": password,
291 });
292
293 let response = match self
294 .client
295 .procedure(
296 "com.atproto.server.createSession",
297 None,
298 Some(XrpcBody::Json(body)),
299 None,
300 )
301 .await
302 {
303 Ok(r) => r,
304 Err(e) => {
305 self.emit(AtpSessionEvent::CreateFailed, None);
306 return Err(AgentError::Xrpc(e));
307 }
308 };
309
310 let session: Session = serde_json::from_value(response.data)?;
311
312 *self.session.write().await = Some(session.clone());
314 self.emit(AtpSessionEvent::Create, Some(&session));
315 Ok(session)
316 }
317
318 pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
323 let mut headers = HeadersMap::new();
326 headers.insert(
327 "Authorization".into(),
328 format!("Bearer {}", session.access_jwt),
329 );
330 let opts = CallOptions {
331 encoding: None,
332 headers: Some(headers),
333 ..Default::default()
334 };
335 let response = self
336 .client
337 .query("com.atproto.server.getSession", None, Some(&opts))
338 .await?;
339 let verified_did = response
340 .data
341 .get("did")
342 .and_then(|v| v.as_str())
343 .map(std::string::ToString::to_string);
344
345 let mut committed = session;
347 if let Some(did) = verified_did {
348 committed.did = did;
349 }
350 *self.session.write().await = Some(committed.clone());
351 self.emit(AtpSessionEvent::Create, Some(&committed));
352
353 Ok(())
354 }
355
356 pub async fn refresh_session(&self) -> Result<Session, AgentError> {
364 let refresh_jwt = {
365 let sess = self.session.read().await;
366 let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
367 sess.refresh_jwt.clone()
368 };
369
370 let mut headers = HeadersMap::new();
372 headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
373 let opts = CallOptions {
374 encoding: None,
375 headers: Some(headers),
376 ..Default::default()
377 };
378
379 let response = match self
380 .client
381 .procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
382 .await
383 {
384 Ok(r) => r,
385 Err(e) => {
386 if is_refresh_rejected(&e) {
392 *self.session.write().await = None;
393 self.emit(AtpSessionEvent::Expired, None);
394 } else {
395 self.emit(AtpSessionEvent::NetworkError, None);
396 }
397 return Err(AgentError::Xrpc(e));
398 }
399 };
400
401 let session: Session = serde_json::from_value(response.data)?;
402
403 *self.session.write().await = Some(session.clone());
405 self.emit(AtpSessionEvent::Update, Some(&session));
406 Ok(session)
407 }
408
409 async fn assert_did(&self) -> Result<String, AgentError> {
413 self.did().await.ok_or(AgentError::NotAuthenticated)
414 }
415
416 async fn xrpc_query(
423 &self,
424 nsid: &str,
425 params: Option<&QueryParams>,
426 ) -> Result<serde_json::Value, AgentError> {
427 let opts = self.auth_call_options().await;
428 let first = self.client.query(nsid, params, opts.as_ref()).await;
429 match first {
430 Ok(r) => Ok(r.data),
431 Err(e) if is_auth_expired(&e) => {
432 self.refresh_and_retry(|opts| {
433 let c = self.client.clone();
434 let nsid = nsid.to_string();
435 let params = params.cloned();
436 async move { c.query(&nsid, params.as_ref(), opts.as_ref()).await }
437 })
438 .await
439 }
440 Err(e) => Err(AgentError::Xrpc(e)),
441 }
442 }
443
444 async fn xrpc_procedure(
446 &self,
447 nsid: &str,
448 body: serde_json::Value,
449 ) -> Result<serde_json::Value, AgentError> {
450 let opts = self.auth_call_options().await;
451 let first = self
452 .client
453 .procedure(
454 nsid,
455 None,
456 Some(XrpcBody::Json(body.clone())),
457 opts.as_ref(),
458 )
459 .await;
460 match first {
461 Ok(r) => Ok(r.data),
462 Err(e) if is_auth_expired(&e) => {
463 self.refresh_and_retry(|opts| {
464 let c = self.client.clone();
465 let nsid = nsid.to_string();
466 let body = body.clone();
467 async move {
468 c.procedure(&nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
469 .await
470 }
471 })
472 .await
473 }
474 Err(e) => Err(AgentError::Xrpc(e)),
475 }
476 }
477
478 async fn refresh_and_retry<F, Fut>(&self, replay: F) -> Result<serde_json::Value, AgentError>
488 where
489 F: FnOnce(Option<CallOptions>) -> Fut,
490 Fut: std::future::Future<
491 Output = Result<proto_blue_xrpc::XrpcResponse, proto_blue_xrpc::Error>,
492 >,
493 {
494 let pre_refresh_jwt = self
498 .session
499 .read()
500 .await
501 .as_ref()
502 .map(|s| s.access_jwt.clone());
503 let _guard = self.refresh_lock.lock().await;
504 let current_jwt = self
505 .session
506 .read()
507 .await
508 .as_ref()
509 .map(|s| s.access_jwt.clone());
510 if pre_refresh_jwt == current_jwt {
511 self.refresh_session().await?;
513 }
514 drop(_guard);
515
516 let opts = self.auth_call_options().await;
517 let response = replay(opts).await?;
518 Ok(response.data)
519 }
520
521 async fn create_record(
523 &self,
524 collection: &str,
525 record: serde_json::Value,
526 ) -> Result<serde_json::Value, AgentError> {
527 let did = self.assert_did().await?;
528 let body = serde_json::json!({
529 "repo": did,
530 "collection": collection,
531 "record": record,
532 });
533 self.xrpc_procedure("com.atproto.repo.createRecord", body)
534 .await
535 }
536
537 async fn delete_record(&self, collection: &str, uri: &str) -> Result<(), AgentError> {
539 let did = self.assert_did().await?;
540 let rkey = uri
541 .rsplit('/')
542 .next()
543 .ok_or_else(|| AgentError::Other("Invalid AT-URI".into()))?;
544
545 let body = serde_json::json!({
546 "repo": did,
547 "collection": collection,
548 "rkey": rkey,
549 });
550 self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
551 .await?;
552 Ok(())
553 }
554
555 fn now_iso() -> String {
557 chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
558 }
559
560 fn resolve_timestamp(created_at: Option<&str>) -> String {
562 created_at.map_or_else(Self::now_iso, String::from)
563 }
564
565 pub async fn post(
571 &self,
572 text: &str,
573 facets: Option<Vec<crate::rich_text::Facet>>,
574 created_at: Option<&str>,
575 ) -> Result<serde_json::Value, AgentError> {
576 let mut record = serde_json::json!({
577 "$type": "app.bsky.feed.post",
578 "text": text,
579 "createdAt": Self::resolve_timestamp(created_at),
580 });
581
582 if let Some(facets) = facets {
583 record["facets"] = serde_json::to_value(&facets)?;
584 }
585
586 self.create_record("app.bsky.feed.post", record).await
587 }
588
589 pub async fn post_rich(
591 &self,
592 rt: &RichText,
593 created_at: Option<&str>,
594 ) -> Result<serde_json::Value, AgentError> {
595 let facets = if rt.facets().is_empty() {
596 None
597 } else {
598 Some(rt.facets().to_vec())
599 };
600 self.post(rt.text(), facets, created_at).await
601 }
602
603 pub async fn delete_post(&self, uri: &str) -> Result<(), AgentError> {
605 self.delete_record("app.bsky.feed.post", uri).await
606 }
607
608 pub async fn like(
614 &self,
615 uri: &str,
616 cid: &str,
617 created_at: Option<&str>,
618 ) -> Result<serde_json::Value, AgentError> {
619 let record = serde_json::json!({
620 "$type": "app.bsky.feed.like",
621 "subject": { "uri": uri, "cid": cid },
622 "createdAt": Self::resolve_timestamp(created_at),
623 });
624 self.create_record("app.bsky.feed.like", record).await
625 }
626
627 pub async fn delete_like(&self, like_uri: &str) -> Result<(), AgentError> {
629 self.delete_record("app.bsky.feed.like", like_uri).await
630 }
631
632 pub async fn repost(
636 &self,
637 uri: &str,
638 cid: &str,
639 created_at: Option<&str>,
640 ) -> Result<serde_json::Value, AgentError> {
641 let record = serde_json::json!({
642 "$type": "app.bsky.feed.repost",
643 "subject": { "uri": uri, "cid": cid },
644 "createdAt": Self::resolve_timestamp(created_at),
645 });
646 self.create_record("app.bsky.feed.repost", record).await
647 }
648
649 pub async fn delete_repost(&self, repost_uri: &str) -> Result<(), AgentError> {
651 self.delete_record("app.bsky.feed.repost", repost_uri).await
652 }
653
654 pub async fn follow(
660 &self,
661 subject_did: &str,
662 created_at: Option<&str>,
663 ) -> Result<serde_json::Value, AgentError> {
664 let record = serde_json::json!({
665 "$type": "app.bsky.graph.follow",
666 "subject": subject_did,
667 "createdAt": Self::resolve_timestamp(created_at),
668 });
669 self.create_record("app.bsky.graph.follow", record).await
670 }
671
672 pub async fn delete_follow(&self, follow_uri: &str) -> Result<(), AgentError> {
674 self.delete_record("app.bsky.graph.follow", follow_uri)
675 .await
676 }
677
678 pub async fn get_profile(&self, actor: &str) -> Result<serde_json::Value, AgentError> {
682 let mut params = QueryParams::new();
683 params.insert("actor".into(), QueryValue::String(actor.into()));
684 self.xrpc_query("app.bsky.actor.getProfile", Some(¶ms))
685 .await
686 }
687
688 pub async fn get_timeline(
690 &self,
691 limit: Option<i64>,
692 cursor: Option<&str>,
693 ) -> Result<serde_json::Value, AgentError> {
694 let mut params = QueryParams::new();
695 if let Some(limit) = limit {
696 params.insert("limit".into(), QueryValue::Integer(limit));
697 }
698 if let Some(cursor) = cursor {
699 params.insert("cursor".into(), QueryValue::String(cursor.into()));
700 }
701 self.xrpc_query("app.bsky.feed.getTimeline", Some(¶ms))
702 .await
703 }
704
705 pub async fn get_post_thread(
707 &self,
708 uri: &str,
709 depth: Option<i64>,
710 ) -> Result<serde_json::Value, AgentError> {
711 let mut params = QueryParams::new();
712 params.insert("uri".into(), QueryValue::String(uri.into()));
713 if let Some(depth) = depth {
714 params.insert("depth".into(), QueryValue::Integer(depth));
715 }
716 self.xrpc_query("app.bsky.feed.getPostThread", Some(¶ms))
717 .await
718 }
719
720 pub async fn search_actors(
722 &self,
723 query: &str,
724 limit: Option<i64>,
725 ) -> Result<serde_json::Value, AgentError> {
726 let mut params = QueryParams::new();
727 params.insert("q".into(), QueryValue::String(query.into()));
728 if let Some(limit) = limit {
729 params.insert("limit".into(), QueryValue::Integer(limit));
730 }
731 self.xrpc_query("app.bsky.actor.searchActors", Some(¶ms))
732 .await
733 }
734
735 pub async fn resolve_handle(&self, handle: &str) -> Result<String, AgentError> {
737 let mut params = QueryParams::new();
738 params.insert("handle".into(), QueryValue::String(handle.into()));
739 let data = self
740 .xrpc_query("com.atproto.identity.resolveHandle", Some(¶ms))
741 .await?;
742 data.get("did")
743 .and_then(|v| v.as_str())
744 .map(std::string::ToString::to_string)
745 .ok_or_else(|| AgentError::Other("Missing DID in response".into()))
746 }
747
748 pub async fn list_notifications(
750 &self,
751 limit: Option<i64>,
752 cursor: Option<&str>,
753 ) -> Result<serde_json::Value, AgentError> {
754 let mut params = QueryParams::new();
755 if let Some(limit) = limit {
756 params.insert("limit".into(), QueryValue::Integer(limit));
757 }
758 if let Some(cursor) = cursor {
759 params.insert("cursor".into(), QueryValue::String(cursor.into()));
760 }
761 self.xrpc_query("app.bsky.notification.listNotifications", Some(¶ms))
762 .await
763 }
764
765 pub async fn upload_blob(
767 &self,
768 data: Vec<u8>,
769 content_type: &str,
770 ) -> Result<serde_json::Value, AgentError> {
771 let mut headers = HeadersMap::new();
772 headers.insert("Content-Type".into(), content_type.into());
773
774 if let Some(sess) = self.session.read().await.as_ref() {
776 headers.insert(
777 "Authorization".into(),
778 format!("Bearer {}", sess.access_jwt),
779 );
780 }
781
782 let opts = CallOptions {
783 encoding: Some(content_type.to_string()),
784 headers: Some(headers),
785 ..Default::default()
786 };
787
788 let response = self
789 .client
790 .procedure(
791 "com.atproto.repo.uploadBlob",
792 None,
793 Some(XrpcBody::Bytes(data)),
794 Some(&opts),
795 )
796 .await?;
797
798 Ok(response.data)
799 }
800
801 pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
803 self.xrpc_query("com.atproto.server.describeServer", None)
804 .await
805 }
806
807 pub async fn logout(&self) -> Result<(), AgentError> {
817 let refresh_jwt = {
818 let guard = self.session.read().await;
819 guard.as_ref().map(|s| s.refresh_jwt.clone())
820 };
821
822 let server_result = if let Some(refresh_jwt) = refresh_jwt {
823 let mut headers = HeadersMap::new();
824 headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
825 let opts = CallOptions {
826 encoding: None,
827 headers: Some(headers),
828 ..Default::default()
829 };
830 self.client
831 .procedure("com.atproto.server.deleteSession", None, None, Some(&opts))
832 .await
833 .map(|_| ())
834 } else {
835 Ok(())
836 };
837
838 *self.session.write().await = None;
840 self.emit(AtpSessionEvent::Expired, None);
841
842 server_result.map_err(AgentError::Xrpc)
843 }
844
845 pub async fn create_account(
854 &self,
855 handle: &str,
856 password: &str,
857 email: Option<&str>,
858 extra: Option<serde_json::Value>,
859 ) -> Result<Session, AgentError> {
860 let mut body = serde_json::json!({
861 "handle": handle,
862 "password": password,
863 });
864 if let Some(email) = email {
865 body["email"] = serde_json::Value::String(email.to_string());
866 }
867 if let Some(extra) = extra
868 && let Some(extra_map) = extra.as_object()
869 && let Some(body_map) = body.as_object_mut()
870 {
871 for (k, v) in extra_map {
872 body_map.insert(k.clone(), v.clone());
873 }
874 }
875
876 let response = match self
877 .client
878 .procedure(
879 "com.atproto.server.createAccount",
880 None,
881 Some(XrpcBody::Json(body)),
882 None,
883 )
884 .await
885 {
886 Ok(r) => r,
887 Err(e) => {
888 self.emit(AtpSessionEvent::CreateFailed, None);
889 return Err(AgentError::Xrpc(e));
890 }
891 };
892
893 let session: Session = serde_json::from_value(response.data)?;
894 *self.session.write().await = Some(session.clone());
895 self.emit(AtpSessionEvent::Create, Some(&session));
896 Ok(session)
897 }
898
899 pub async fn upsert_profile<F>(&self, mutate: F) -> Result<serde_json::Value, AgentError>
910 where
911 F: Fn(serde_json::Value) -> serde_json::Value,
912 {
913 let did = self.assert_did().await?;
914 const MAX_RETRIES: u32 = 5;
915
916 for _ in 0..MAX_RETRIES {
917 let existing_result = self
919 .xrpc_query(
920 "com.atproto.repo.getRecord",
921 Some(&{
922 let mut p = QueryParams::new();
923 p.insert("repo".into(), QueryValue::String(did.clone()));
924 p.insert(
925 "collection".into(),
926 QueryValue::String("app.bsky.actor.profile".into()),
927 );
928 p.insert("rkey".into(), QueryValue::String("self".into()));
929 p
930 }),
931 )
932 .await;
933
934 let (existing_record, swap_cid) = match existing_result {
935 Ok(r) => {
936 let record = r.get("value").cloned().unwrap_or(serde_json::Value::Null);
937 let cid = r.get("cid").and_then(|v| v.as_str()).map(String::from);
938 (record, cid)
939 }
940 Err(AgentError::Xrpc(ref e)) if is_not_found(e) => (serde_json::Value::Null, None),
941 Err(e) => return Err(e),
942 };
943
944 let updated = mutate(existing_record);
945 let mut body = serde_json::json!({
946 "repo": did,
947 "collection": "app.bsky.actor.profile",
948 "rkey": "self",
949 "record": updated,
950 });
951 if let Some(cid) = swap_cid {
952 body["swapRecord"] = serde_json::Value::String(cid);
953 }
954
955 match self
956 .xrpc_procedure("com.atproto.repo.putRecord", body)
957 .await
958 {
959 Ok(r) => return Ok(r),
960 Err(AgentError::Xrpc(ref e)) if is_invalid_swap(e) => {
961 continue;
964 }
965 Err(e) => return Err(e),
966 }
967 }
968
969 Err(AgentError::Other(
970 "upsert_profile: exceeded maximum retries due to concurrent writes".into(),
971 ))
972 }
973}
974
975fn is_not_found(err: &proto_blue_xrpc::Error) -> bool {
978 match err {
979 proto_blue_xrpc::Error::Xrpc(x) => x.is_error("RecordNotFound"),
980 _ => false,
981 }
982}
983
984fn is_invalid_swap(err: &proto_blue_xrpc::Error) -> bool {
987 match err {
988 proto_blue_xrpc::Error::Xrpc(x) => x.is_error("InvalidSwap"),
989 _ => false,
990 }
991}
992
993fn is_auth_expired(err: &proto_blue_xrpc::Error) -> bool {
1000 match err {
1001 proto_blue_xrpc::Error::Xrpc(x) => {
1002 matches!(x.status, ResponseType::AuthenticationRequired) && x.is_error("ExpiredToken")
1003 }
1004 _ => false,
1005 }
1006}
1007
1008const fn is_refresh_rejected(err: &proto_blue_xrpc::Error) -> bool {
1013 match err {
1014 proto_blue_xrpc::Error::Xrpc(x) => {
1015 matches!(x.status, ResponseType::AuthenticationRequired)
1016 }
1017 _ => false,
1018 }
1019}
1020
#[cfg(test)]
mod tests {
    use super::*;

    // Constructing an agent against a well-formed URL succeeds.
    #[test]
    fn agent_creation() {
        let _agent = Agent::new("https://bsky.social").unwrap();
    }

    // A `Session` survives a JSON round trip with its optional fields intact.
    #[test]
    fn session_serde_roundtrip() {
        let session = Session {
            did: "did:plc:abc123".to_string(),
            handle: "alice.bsky.social".to_string(),
            access_jwt: "eyJ...".to_string(),
            refresh_jwt: "eyJ...".to_string(),
            email: Some("alice@example.com".to_string()),
            email_confirmed: Some(true),
        };

        let json = serde_json::to_string(&session).unwrap();
        let parsed: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.did, "did:plc:abc123");
        assert_eq!(parsed.handle, "alice.bsky.social");
        assert_eq!(parsed.email, Some("alice@example.com".to_string()));
    }

    // `Display` output of the message-only variants.
    #[test]
    fn agent_error_display() {
        let err = AgentError::NotAuthenticated;
        assert_eq!(err.to_string(), "Not authenticated");

        let err = AgentError::Other("test error".into());
        assert_eq!(err.to_string(), "test error");
    }

    // A freshly constructed agent is logged out.
    #[tokio::test]
    async fn agent_no_session_by_default() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.did().await.is_none());
        assert!(agent.session().await.is_none());
    }

    // `assert_did` reports NotAuthenticated without a session.
    #[tokio::test]
    async fn agent_assert_did_fails_when_not_logged_in() {
        let agent = Agent::new("https://bsky.social").unwrap();
        let err = agent.assert_did().await.unwrap_err();
        assert!(matches!(err, AgentError::NotAuthenticated));
    }

    // `now_iso` produces a UTC ("Z"-suffixed) date-time string.
    #[test]
    fn now_iso_format() {
        let ts = Agent::now_iso();
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    // An explicit timestamp is passed through unchanged.
    #[test]
    fn resolve_timestamp_with_provided() {
        let ts = Agent::resolve_timestamp(Some("2024-01-15T12:00:00.000Z"));
        assert_eq!(ts, "2024-01-15T12:00:00.000Z");
    }

    // Without a timestamp, the current time is used.
    #[test]
    fn resolve_timestamp_without_provided() {
        let ts = Agent::resolve_timestamp(None);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    // `service` is synchronous; the client normalizes the URL with a trailing slash.
    #[test]
    fn service_url_accessible_without_async() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert_eq!(agent.service(), "https://bsky.social/");
    }

    // No session means no authenticated call options.
    #[tokio::test]
    async fn auth_call_options_none_when_not_authenticated() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.auth_call_options().await.is_none());
    }

    use async_trait::async_trait;
    use proto_blue_common::fetch::{FetchError, FetchHandler, HttpRequest, HttpResponse};

    // In-memory `FetchHandler` that answers by NSID from per-endpoint scripts.
    struct ScriptedFetcher {
        // Body returned for `createSession` when that NSID has no script.
        createsession_body: Vec<u8>,
        // Queued responses per NSID. Entries are consumed in order; once a single
        // entry remains it is returned for every further call.
        scripts: std::sync::Mutex<std::collections::HashMap<String, Vec<ScriptedResponse>>>,
        // Number of requests seen per NSID.
        call_counts: std::sync::Mutex<std::collections::HashMap<String, usize>>,
    }

    // One canned HTTP response (always served with a JSON content type).
    #[derive(Clone)]
    struct ScriptedResponse {
        status: u16,
        body: Vec<u8>,
    }

    impl ScriptedFetcher {
        fn new(createsession_body: Vec<u8>) -> Self {
            Self {
                createsession_body,
                scripts: Default::default(),
                call_counts: Default::default(),
            }
        }
        // Installs (replacing any previous) the response queue for `path` (an NSID).
        fn script(&self, path: &str, responses: Vec<ScriptedResponse>) {
            self.scripts
                .lock()
                .unwrap()
                .insert(path.to_string(), responses);
        }
        // How many requests have been made to `path` (an NSID).
        fn call_count(&self, path: &str) -> usize {
            *self.call_counts.lock().unwrap().get(path).unwrap_or(&0)
        }
    }

    #[async_trait]
    impl FetchHandler for ScriptedFetcher {
        async fn fetch(&self, req: HttpRequest) -> Result<HttpResponse, FetchError> {
            // Key = the NSID: the URL segment after "/xrpc/", minus any query string.
            let path = req.url.clone();
            let key = path
                .split("/xrpc/")
                .nth(1)
                .unwrap_or(&path)
                .split('?')
                .next()
                .unwrap_or("")
                .to_string();
            *self
                .call_counts
                .lock()
                .unwrap()
                .entry(key.clone())
                .or_insert(0) += 1;

            {
                // Scripted responses take precedence. The last remaining entry is
                // cloned rather than removed, so it repeats for later calls.
                let mut scripts = self.scripts.lock().unwrap();
                if let Some(list) = scripts.get_mut(&key) {
                    let resp = if list.len() == 1 {
                        list[0].clone()
                    } else {
                        list.remove(0)
                    };
                    let mut headers = proto_blue_common::fetch::HttpHeaders::new();
                    headers.insert("content-type".into(), "application/json".into());
                    return Ok(HttpResponse {
                        status: resp.status,
                        headers,
                        body: resp.body,
                    });
                }
            }

            if key == "com.atproto.server.createSession" {
                // Unscripted login succeeds with the body given to `new`.
                let mut headers = proto_blue_common::fetch::HttpHeaders::new();
                headers.insert("content-type".into(), "application/json".into());
                return Ok(HttpResponse {
                    status: 200,
                    headers,
                    body: self.createsession_body.clone(),
                });
            }

            Err(FetchError::Other(format!("no script for {key}")))
        }
    }

    // Minimal createSession payload: access token "a1", refresh token "r1".
    fn login_body() -> Vec<u8> {
        br#"{"did":"did:plc:u","handle":"alice","accessJwt":"a1","refreshJwt":"r1"}"#.to_vec()
    }

    // Builds an agent wired to `fetcher`, with the same initial state as `Agent::new`.
    fn agent_with_fetcher(fetcher: Arc<ScriptedFetcher>) -> Agent {
        let client = XrpcClient::with_fetch_handler("https://example.com", fetcher).unwrap();
        Agent {
            client,
            session: Arc::new(RwLock::new(None)),
            listeners: Arc::new(Mutex::new(Vec::new())),
            refresh_lock: Arc::new(AsyncMutex::new(())),
            proxy: Arc::new(RwLock::new(None)),
            labelers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    // A successful login emits exactly one Create.
    #[tokio::test]
    async fn emits_create_on_successful_login() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        agent.login("alice", "secret").await.unwrap();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Create]);
    }

    // A rejected login emits exactly one CreateFailed.
    #[tokio::test]
    async fn emits_create_failed_on_login_rejection() {
        let fetcher = Arc::new(ScriptedFetcher::new(vec![]));
        fetcher.script(
            "com.atproto.server.createSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"bad pwd"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let _ = agent.login("alice", "bad").await.unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::CreateFailed]);
    }

    // An ExpiredToken response triggers one refresh and one replay of the request.
    #[tokio::test]
    async fn auto_refreshes_on_expired_access_token() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));

        // First describeServer call fails as expired; the replay succeeds.
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );

        let agent = agent_with_fetcher(fetcher.clone());
        agent.login("alice", "secret").await.unwrap();

        // Registered after login so the login's Create event is not recorded.
        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let result = agent.describe_server().await.unwrap();
        assert_eq!(result["did"], "did:plc:svr");

        // One failed attempt plus one replay.
        assert_eq!(fetcher.call_count("com.atproto.server.describeServer"), 2);
        assert_eq!(fetcher.call_count("com.atproto.server.refreshSession"), 1);

        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Update]);
    }

    // Concurrent callers must not each trigger their own refresh.
    #[tokio::test]
    async fn concurrent_expired_token_refreshes_once() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));

        // Only the first describeServer call sees ExpiredToken; the 200 then repeats.
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );

        let agent = Arc::new(agent_with_fetcher(fetcher.clone()));
        agent.login("alice", "secret").await.unwrap();

        // NOTE(review): because the script hands ExpiredToken to only one caller,
        // this exercises a single refresh rather than several callers racing on
        // the same expiry.
        let mut handles = Vec::new();
        for _ in 0..5 {
            let a = agent.clone();
            handles.push(tokio::spawn(async move {
                a.describe_server().await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(
            fetcher.call_count("com.atproto.server.refreshSession"),
            1,
            "concurrent callers must share one refreshSession call",
        );
    }

    // Configuring a proxy persists it for later calls.
    #[tokio::test]
    async fn configure_proxy_sets_header_on_next_call() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:svr"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .configure_proxy(Some("did:web:api.bsky.chat#bsky_chat"))
            .await;

        agent.describe_server().await.unwrap();

        // NOTE(review): the fetcher does not record request headers, so this checks
        // the stored proxy value rather than the `atproto-proxy` header on the wire.
        let p = agent.proxy.read().await;
        assert_eq!(p.as_deref(), Some("did:web:api.bsky.chat#bsky_chat"));
    }

    // Labelers are stored in order and render their header entries correctly.
    #[tokio::test]
    async fn configure_labelers_stores_list() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);
        agent
            .configure_labelers(&[
                LabelerOpts {
                    did: "did:plc:a".into(),
                    redirect: false,
                },
                LabelerOpts {
                    did: "did:plc:b".into(),
                    redirect: true,
                },
            ])
            .await;
        let l = agent.labelers.read().await;
        assert_eq!(l.len(), 2);
        assert_eq!(l[0].header_value(), "did:plc:a");
        assert_eq!(l[1].header_value(), "did:plc:b;redirect");
    }

    // Logout calls deleteSession once and clears the local session.
    #[tokio::test]
    async fn logout_clears_session() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 200,
                body: b"{}".to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent.login("alice", "secret").await.unwrap();
        assert!(agent.session().await.is_some());
        agent.logout().await.unwrap();
        assert!(agent.session().await.is_none());
        assert_eq!(fetcher.call_count("com.atproto.server.deleteSession"), 1,);
    }

    // A server-side deleteSession failure still clears the local session.
    #[tokio::test]
    async fn logout_clears_session_even_on_server_error() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 500,
                body: br#"{"error":"InternalServerError"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent.login("alice", "secret").await.unwrap();
        let _ = agent.logout().await;
        assert!(agent.session().await.is_none());
    }

    // Account creation stores the returned session and emits Create.
    #[tokio::test]
    async fn create_account_emits_create_on_success() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.createAccount",
            vec![ScriptedResponse {
                status: 200,
                body:
                    br#"{"did":"did:plc:new","handle":"newuser","accessJwt":"a","refreshJwt":"r"}"#
                        .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev = events.clone();
        agent.on_session(move |e, _| ev.lock().unwrap().push(e));

        let session = agent
            .create_account("newuser", "pw", Some("new@example.com"), None)
            .await
            .unwrap();
        assert_eq!(session.did, "did:plc:new");
        assert_eq!(
            events.lock().unwrap().clone(),
            vec![AtpSessionEvent::Create]
        );
    }

    // With no existing profile, `mutate` sees Null and the record is written.
    #[tokio::test]
    async fn upsert_profile_creates_when_absent() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.repo.getRecord",
            vec![ScriptedResponse {
                status: 400,
                body: br#"{"error":"RecordNotFound","message":"no such record"}"#.to_vec(),
            }],
        );
        fetcher.script(
            "com.atproto.repo.putRecord",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"uri":"at://did:plc:u/app.bsky.actor.profile/self","cid":"bafy"}"#
                    .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent.login("alice", "secret").await.unwrap();

        let result = agent
            .upsert_profile(|prev| {
                assert!(prev.is_null(), "no existing profile");
                serde_json::json!({"$type": "app.bsky.actor.profile", "displayName": "Alice"})
            })
            .await
            .unwrap();
        assert_eq!(result["uri"], "at://did:plc:u/app.bsky.actor.profile/self");
    }

    // A rejected refresh clears the session and emits only Expired.
    #[tokio::test]
    async fn emits_expired_when_refresh_itself_401s() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"refresh expired"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent.login("alice", "secret").await.unwrap();

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let _ = agent.refresh_session().await.unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Expired]);
        assert!(
            agent.session().await.is_none(),
            "session cleared on expired refresh"
        );
    }
}
1519}