1use std::sync::{Arc, Mutex};
7use tokio::sync::{Mutex as AsyncMutex, RwLock};
8
9use proto_blue_xrpc::{
10 CallOptions, HeadersMap, QueryParams, QueryValue, ResponseType, XrpcBody, XrpcClient,
11};
12
13use crate::rich_text::RichText;
14
/// Lifecycle events reported to listeners registered with `Agent::on_session`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AtpSessionEvent {
    /// A new session was established (login, account creation, or resumption).
    Create,
    /// Establishing a new session failed.
    CreateFailed,
    /// The existing session's tokens were refreshed.
    Update,
    /// The session ended: its refresh token was rejected, or the user logged out.
    Expired,
    /// A session refresh failed for a reason other than token rejection.
    NetworkError,
}
34
35pub type SessionEventCallback =
42 Arc<dyn Fn(AtpSessionEvent, Option<&Session>) + Send + Sync>;
43
44#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct Session {
48 pub did: String,
49 pub handle: String,
50 pub access_jwt: String,
51 pub refresh_jwt: String,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub email: Option<String>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub email_confirmed: Option<bool>,
56}
57
58#[derive(Debug, thiserror::Error)]
60pub enum AgentError {
61 #[error("XRPC error: {0}")]
62 Xrpc(#[from] proto_blue_xrpc::Error),
63 #[error("Not authenticated")]
64 NotAuthenticated,
65 #[error("JSON error: {0}")]
66 Json(#[from] serde_json::Error),
67 #[error("{0}")]
68 Other(String),
69}
70
/// High-level atproto client.
///
/// It wraps an [`XrpcClient`], tracks the current [`Session`], notifies session
/// listeners, and attaches auth, `atproto-proxy` and `atproto-accept-labelers` headers
/// to requests.
pub struct Agent {
    /// Underlying XRPC transport.
    client: XrpcClient,
    /// Current session; `None` when not authenticated. Behind an `Arc` so clones made by
    /// `shallow_clone` observe the same session (and its refreshes).
    session: Arc<RwLock<Option<Session>>>,
    /// Registered session-event listeners (a synchronous mutex; never held across `.await`).
    listeners: Arc<Mutex<Vec<SessionEventCallback>>>,
    /// Serializes token refreshes so concurrent expired requests share one refresh.
    refresh_lock: Arc<AsyncMutex<()>>,
    /// Optional `atproto-proxy` header value; per-agent (not shared by `shallow_clone`).
    proxy: Arc<RwLock<Option<String>>>,
    /// Labelers advertised in the `atproto-accept-labelers` header.
    labelers: Arc<RwLock<Vec<LabelerOpts>>>,
}
105
/// A labeler the agent asks the service to apply labels from.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LabelerOpts {
    /// DID of the labeler service.
    pub did: String,
    /// When `true`, the labeler entry carries the `;redirect` flag.
    pub redirect: bool,
}

impl LabelerOpts {
    /// Renders this labeler as one entry of the `atproto-accept-labelers` header.
    fn header_value(&self) -> String {
        let flag = if self.redirect { ";redirect" } else { "" };
        format!("{}{}", self.did, flag)
    }
}
126
127impl Agent {
128 pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
130 let client = XrpcClient::new(service)?;
131 Ok(Agent {
132 client,
133 session: Arc::new(RwLock::new(None)),
134 listeners: Arc::new(Mutex::new(Vec::new())),
135 refresh_lock: Arc::new(AsyncMutex::new(())),
136 proxy: Arc::new(RwLock::new(None)),
137 labelers: Arc::new(RwLock::new(Vec::new())),
138 })
139 }
140
141 pub fn on_session<F>(&self, callback: F)
148 where
149 F: Fn(AtpSessionEvent, Option<&Session>) + Send + Sync + 'static,
150 {
151 self.listeners.lock().unwrap().push(Arc::new(callback));
152 }
153
154 fn emit(&self, event: AtpSessionEvent, session: Option<&Session>) {
156 let listeners = self.listeners.lock().unwrap().clone();
161 for cb in listeners {
162 cb(event, session);
163 }
164 }
165
166 pub fn service(&self) -> String {
168 self.client.service_url().to_string()
169 }
170
171 pub async fn did(&self) -> Option<String> {
173 self.session.read().await.as_ref().map(|s| s.did.clone())
174 }
175
176 pub async fn session(&self) -> Option<Session> {
178 self.session.read().await.clone()
179 }
180
181 async fn auth_call_options(&self) -> Option<CallOptions> {
188 let guard = self.session.read().await;
189 let session = guard.as_ref()?;
190 let mut headers = HeadersMap::new();
191 headers.insert(
192 "Authorization".into(),
193 format!("Bearer {}", session.access_jwt),
194 );
195 self.inject_proxy_and_labelers(&mut headers).await;
196 Some(CallOptions {
197 encoding: None,
198 headers: Some(headers),
199 ..Default::default()
200 })
201 }
202
203 pub async fn anon_call_options(&self) -> Option<CallOptions> {
209 let mut headers = HeadersMap::new();
210 self.inject_proxy_and_labelers(&mut headers).await;
211 if headers.is_empty() {
212 None
213 } else {
214 Some(CallOptions {
215 encoding: None,
216 headers: Some(headers),
217 ..Default::default()
218 })
219 }
220 }
221
222 async fn inject_proxy_and_labelers(&self, headers: &mut HeadersMap) {
223 if let Some(proxy) = self.proxy.read().await.as_ref() {
224 headers.insert("atproto-proxy".into(), proxy.clone());
225 }
226 let labelers = self.labelers.read().await;
227 if !labelers.is_empty() {
228 let v = labelers
229 .iter()
230 .map(|l| l.header_value())
231 .collect::<Vec<_>>()
232 .join(", ");
233 headers.insert("atproto-accept-labelers".into(), v);
234 }
235 }
236
237 pub async fn configure_proxy(&self, target: Option<&str>) {
243 *self.proxy.write().await = target.map(String::from);
244 }
245
246 pub async fn with_proxy(&self, target: &str) -> Self {
249 let cloned = self.shallow_clone();
250 cloned.configure_proxy(Some(target)).await;
251 cloned
252 }
253
254 pub async fn configure_labelers(&self, labelers: &[LabelerOpts]) {
257 *self.labelers.write().await = labelers.to_vec();
258 }
259
260 fn shallow_clone(&self) -> Self {
264 Agent {
265 client: self.client.clone(),
266 session: self.session.clone(),
267 listeners: self.listeners.clone(),
268 refresh_lock: self.refresh_lock.clone(),
269 proxy: Arc::new(RwLock::new(None)),
270 labelers: self.labelers.clone(),
271 }
272 }
273
274 pub async fn login(&self, identifier: &str, password: &str) -> Result<Session, AgentError> {
280 let body = serde_json::json!({
281 "identifier": identifier,
282 "password": password,
283 });
284
285 let response = match self
286 .client
287 .procedure(
288 "com.atproto.server.createSession",
289 None,
290 Some(XrpcBody::Json(body)),
291 None,
292 )
293 .await
294 {
295 Ok(r) => r,
296 Err(e) => {
297 self.emit(AtpSessionEvent::CreateFailed, None);
298 return Err(AgentError::Xrpc(e));
299 }
300 };
301
302 let session: Session = serde_json::from_value(response.data)?;
303
304 *self.session.write().await = Some(session.clone());
306 self.emit(AtpSessionEvent::Create, Some(&session));
307 Ok(session)
308 }
309
310 pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
315 let mut headers = HeadersMap::new();
318 headers.insert(
319 "Authorization".into(),
320 format!("Bearer {}", session.access_jwt),
321 );
322 let opts = CallOptions {
323 encoding: None,
324 headers: Some(headers),
325 ..Default::default()
326 };
327 let response = self
328 .client
329 .query("com.atproto.server.getSession", None, Some(&opts))
330 .await?;
331 let verified_did = response
332 .data
333 .get("did")
334 .and_then(|v| v.as_str())
335 .map(|s| s.to_string());
336
337 let mut committed = session;
339 if let Some(did) = verified_did {
340 committed.did = did;
341 }
342 *self.session.write().await = Some(committed.clone());
343 self.emit(AtpSessionEvent::Create, Some(&committed));
344
345 Ok(())
346 }
347
348 pub async fn refresh_session(&self) -> Result<Session, AgentError> {
356 let refresh_jwt = {
357 let sess = self.session.read().await;
358 let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
359 sess.refresh_jwt.clone()
360 };
361
362 let mut headers = HeadersMap::new();
364 headers.insert("Authorization".into(), format!("Bearer {}", refresh_jwt));
365 let opts = CallOptions {
366 encoding: None,
367 headers: Some(headers),
368 ..Default::default()
369 };
370
371 let response = match self
372 .client
373 .procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
374 .await
375 {
376 Ok(r) => r,
377 Err(e) => {
378 if is_refresh_rejected(&e) {
384 *self.session.write().await = None;
385 self.emit(AtpSessionEvent::Expired, None);
386 } else {
387 self.emit(AtpSessionEvent::NetworkError, None);
388 }
389 return Err(AgentError::Xrpc(e));
390 }
391 };
392
393 let session: Session = serde_json::from_value(response.data)?;
394
395 *self.session.write().await = Some(session.clone());
397 self.emit(AtpSessionEvent::Update, Some(&session));
398 Ok(session)
399 }
400
401 async fn assert_did(&self) -> Result<String, AgentError> {
405 self.did().await.ok_or(AgentError::NotAuthenticated)
406 }
407
408 async fn xrpc_query(
415 &self,
416 nsid: &str,
417 params: Option<&QueryParams>,
418 ) -> Result<serde_json::Value, AgentError> {
419 let opts = self.auth_call_options().await;
420 let first = self.client.query(nsid, params, opts.as_ref()).await;
421 match first {
422 Ok(r) => Ok(r.data),
423 Err(e) if is_auth_expired(&e) => {
424 self.refresh_and_retry(|opts| {
425 let c = self.client.clone();
426 let nsid = nsid.to_string();
427 let params = params.cloned();
428 async move {
429 c.query(&nsid, params.as_ref(), opts.as_ref()).await
430 }
431 })
432 .await
433 }
434 Err(e) => Err(AgentError::Xrpc(e)),
435 }
436 }
437
438 async fn xrpc_procedure(
440 &self,
441 nsid: &str,
442 body: serde_json::Value,
443 ) -> Result<serde_json::Value, AgentError> {
444 let opts = self.auth_call_options().await;
445 let first = self
446 .client
447 .procedure(nsid, None, Some(XrpcBody::Json(body.clone())), opts.as_ref())
448 .await;
449 match first {
450 Ok(r) => Ok(r.data),
451 Err(e) if is_auth_expired(&e) => {
452 self.refresh_and_retry(|opts| {
453 let c = self.client.clone();
454 let nsid = nsid.to_string();
455 let body = body.clone();
456 async move {
457 c.procedure(&nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
458 .await
459 }
460 })
461 .await
462 }
463 Err(e) => Err(AgentError::Xrpc(e)),
464 }
465 }
466
467 async fn refresh_and_retry<F, Fut>(
477 &self,
478 replay: F,
479 ) -> Result<serde_json::Value, AgentError>
480 where
481 F: FnOnce(Option<CallOptions>) -> Fut,
482 Fut: std::future::Future<
483 Output = Result<proto_blue_xrpc::XrpcResponse, proto_blue_xrpc::Error>,
484 >,
485 {
486 let pre_refresh_jwt = self
490 .session
491 .read()
492 .await
493 .as_ref()
494 .map(|s| s.access_jwt.clone());
495 let _guard = self.refresh_lock.lock().await;
496 let current_jwt = self
497 .session
498 .read()
499 .await
500 .as_ref()
501 .map(|s| s.access_jwt.clone());
502 if pre_refresh_jwt == current_jwt {
503 self.refresh_session().await?;
505 }
506 drop(_guard);
507
508 let opts = self.auth_call_options().await;
509 let response = replay(opts).await?;
510 Ok(response.data)
511 }
512
513 async fn create_record(
515 &self,
516 collection: &str,
517 record: serde_json::Value,
518 ) -> Result<serde_json::Value, AgentError> {
519 let did = self.assert_did().await?;
520 let body = serde_json::json!({
521 "repo": did,
522 "collection": collection,
523 "record": record,
524 });
525 self.xrpc_procedure("com.atproto.repo.createRecord", body)
526 .await
527 }
528
529 async fn delete_record(&self, collection: &str, uri: &str) -> Result<(), AgentError> {
531 let did = self.assert_did().await?;
532 let rkey = uri
533 .rsplit('/')
534 .next()
535 .ok_or_else(|| AgentError::Other("Invalid AT-URI".into()))?;
536
537 let body = serde_json::json!({
538 "repo": did,
539 "collection": collection,
540 "rkey": rkey,
541 });
542 self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
543 .await?;
544 Ok(())
545 }
546
547 fn now_iso() -> String {
549 chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
550 }
551
552 fn resolve_timestamp(created_at: Option<&str>) -> String {
554 created_at.map(String::from).unwrap_or_else(Self::now_iso)
555 }
556
557 pub async fn post(
563 &self,
564 text: &str,
565 facets: Option<Vec<crate::rich_text::Facet>>,
566 created_at: Option<&str>,
567 ) -> Result<serde_json::Value, AgentError> {
568 let mut record = serde_json::json!({
569 "$type": "app.bsky.feed.post",
570 "text": text,
571 "createdAt": Self::resolve_timestamp(created_at),
572 });
573
574 if let Some(facets) = facets {
575 record["facets"] = serde_json::to_value(&facets)?;
576 }
577
578 self.create_record("app.bsky.feed.post", record).await
579 }
580
581 pub async fn post_rich(
583 &self,
584 rt: &RichText,
585 created_at: Option<&str>,
586 ) -> Result<serde_json::Value, AgentError> {
587 let facets = if rt.facets().is_empty() {
588 None
589 } else {
590 Some(rt.facets().to_vec())
591 };
592 self.post(rt.text(), facets, created_at).await
593 }
594
595 pub async fn delete_post(&self, uri: &str) -> Result<(), AgentError> {
597 self.delete_record("app.bsky.feed.post", uri).await
598 }
599
600 pub async fn like(
606 &self,
607 uri: &str,
608 cid: &str,
609 created_at: Option<&str>,
610 ) -> Result<serde_json::Value, AgentError> {
611 let record = serde_json::json!({
612 "$type": "app.bsky.feed.like",
613 "subject": { "uri": uri, "cid": cid },
614 "createdAt": Self::resolve_timestamp(created_at),
615 });
616 self.create_record("app.bsky.feed.like", record).await
617 }
618
619 pub async fn delete_like(&self, like_uri: &str) -> Result<(), AgentError> {
621 self.delete_record("app.bsky.feed.like", like_uri).await
622 }
623
624 pub async fn repost(
628 &self,
629 uri: &str,
630 cid: &str,
631 created_at: Option<&str>,
632 ) -> Result<serde_json::Value, AgentError> {
633 let record = serde_json::json!({
634 "$type": "app.bsky.feed.repost",
635 "subject": { "uri": uri, "cid": cid },
636 "createdAt": Self::resolve_timestamp(created_at),
637 });
638 self.create_record("app.bsky.feed.repost", record).await
639 }
640
641 pub async fn delete_repost(&self, repost_uri: &str) -> Result<(), AgentError> {
643 self.delete_record("app.bsky.feed.repost", repost_uri).await
644 }
645
646 pub async fn follow(
652 &self,
653 subject_did: &str,
654 created_at: Option<&str>,
655 ) -> Result<serde_json::Value, AgentError> {
656 let record = serde_json::json!({
657 "$type": "app.bsky.graph.follow",
658 "subject": subject_did,
659 "createdAt": Self::resolve_timestamp(created_at),
660 });
661 self.create_record("app.bsky.graph.follow", record).await
662 }
663
664 pub async fn delete_follow(&self, follow_uri: &str) -> Result<(), AgentError> {
666 self.delete_record("app.bsky.graph.follow", follow_uri)
667 .await
668 }
669
670 pub async fn get_profile(&self, actor: &str) -> Result<serde_json::Value, AgentError> {
674 let mut params = QueryParams::new();
675 params.insert("actor".into(), QueryValue::String(actor.into()));
676 self.xrpc_query("app.bsky.actor.getProfile", Some(¶ms))
677 .await
678 }
679
680 pub async fn get_timeline(
682 &self,
683 limit: Option<i64>,
684 cursor: Option<&str>,
685 ) -> Result<serde_json::Value, AgentError> {
686 let mut params = QueryParams::new();
687 if let Some(limit) = limit {
688 params.insert("limit".into(), QueryValue::Integer(limit));
689 }
690 if let Some(cursor) = cursor {
691 params.insert("cursor".into(), QueryValue::String(cursor.into()));
692 }
693 self.xrpc_query("app.bsky.feed.getTimeline", Some(¶ms))
694 .await
695 }
696
697 pub async fn get_post_thread(
699 &self,
700 uri: &str,
701 depth: Option<i64>,
702 ) -> Result<serde_json::Value, AgentError> {
703 let mut params = QueryParams::new();
704 params.insert("uri".into(), QueryValue::String(uri.into()));
705 if let Some(depth) = depth {
706 params.insert("depth".into(), QueryValue::Integer(depth));
707 }
708 self.xrpc_query("app.bsky.feed.getPostThread", Some(¶ms))
709 .await
710 }
711
712 pub async fn search_actors(
714 &self,
715 query: &str,
716 limit: Option<i64>,
717 ) -> Result<serde_json::Value, AgentError> {
718 let mut params = QueryParams::new();
719 params.insert("q".into(), QueryValue::String(query.into()));
720 if let Some(limit) = limit {
721 params.insert("limit".into(), QueryValue::Integer(limit));
722 }
723 self.xrpc_query("app.bsky.actor.searchActors", Some(¶ms))
724 .await
725 }
726
727 pub async fn resolve_handle(&self, handle: &str) -> Result<String, AgentError> {
729 let mut params = QueryParams::new();
730 params.insert("handle".into(), QueryValue::String(handle.into()));
731 let data = self
732 .xrpc_query("com.atproto.identity.resolveHandle", Some(¶ms))
733 .await?;
734 data.get("did")
735 .and_then(|v| v.as_str())
736 .map(|s| s.to_string())
737 .ok_or_else(|| AgentError::Other("Missing DID in response".into()))
738 }
739
740 pub async fn list_notifications(
742 &self,
743 limit: Option<i64>,
744 cursor: Option<&str>,
745 ) -> Result<serde_json::Value, AgentError> {
746 let mut params = QueryParams::new();
747 if let Some(limit) = limit {
748 params.insert("limit".into(), QueryValue::Integer(limit));
749 }
750 if let Some(cursor) = cursor {
751 params.insert("cursor".into(), QueryValue::String(cursor.into()));
752 }
753 self.xrpc_query("app.bsky.notification.listNotifications", Some(¶ms))
754 .await
755 }
756
757 pub async fn upload_blob(
759 &self,
760 data: Vec<u8>,
761 content_type: &str,
762 ) -> Result<serde_json::Value, AgentError> {
763 let mut headers = HeadersMap::new();
764 headers.insert("Content-Type".into(), content_type.into());
765
766 if let Some(sess) = self.session.read().await.as_ref() {
768 headers.insert(
769 "Authorization".into(),
770 format!("Bearer {}", sess.access_jwt),
771 );
772 }
773
774 let opts = CallOptions {
775 encoding: Some(content_type.to_string()),
776 headers: Some(headers),
777 ..Default::default()
778 };
779
780 let response = self
781 .client
782 .procedure(
783 "com.atproto.repo.uploadBlob",
784 None,
785 Some(XrpcBody::Bytes(data)),
786 Some(&opts),
787 )
788 .await?;
789
790 Ok(response.data)
791 }
792
793 pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
795 self.xrpc_query("com.atproto.server.describeServer", None)
796 .await
797 }
798
799 pub async fn logout(&self) -> Result<(), AgentError> {
809 let refresh_jwt = {
810 let guard = self.session.read().await;
811 guard.as_ref().map(|s| s.refresh_jwt.clone())
812 };
813
814 let server_result = if let Some(refresh_jwt) = refresh_jwt {
815 let mut headers = HeadersMap::new();
816 headers.insert("Authorization".into(), format!("Bearer {}", refresh_jwt));
817 let opts = CallOptions {
818 encoding: None,
819 headers: Some(headers),
820 ..Default::default()
821 };
822 self.client
823 .procedure(
824 "com.atproto.server.deleteSession",
825 None,
826 None,
827 Some(&opts),
828 )
829 .await
830 .map(|_| ())
831 } else {
832 Ok(())
833 };
834
835 *self.session.write().await = None;
837 self.emit(AtpSessionEvent::Expired, None);
838
839 server_result.map_err(AgentError::Xrpc)
840 }
841
842 pub async fn create_account(
851 &self,
852 handle: &str,
853 password: &str,
854 email: Option<&str>,
855 extra: Option<serde_json::Value>,
856 ) -> Result<Session, AgentError> {
857 let mut body = serde_json::json!({
858 "handle": handle,
859 "password": password,
860 });
861 if let Some(email) = email {
862 body["email"] = serde_json::Value::String(email.to_string());
863 }
864 if let Some(extra) = extra
865 && let Some(extra_map) = extra.as_object()
866 && let Some(body_map) = body.as_object_mut()
867 {
868 for (k, v) in extra_map {
869 body_map.insert(k.clone(), v.clone());
870 }
871 }
872
873 let response = match self
874 .client
875 .procedure(
876 "com.atproto.server.createAccount",
877 None,
878 Some(XrpcBody::Json(body)),
879 None,
880 )
881 .await
882 {
883 Ok(r) => r,
884 Err(e) => {
885 self.emit(AtpSessionEvent::CreateFailed, None);
886 return Err(AgentError::Xrpc(e));
887 }
888 };
889
890 let session: Session = serde_json::from_value(response.data)?;
891 *self.session.write().await = Some(session.clone());
892 self.emit(AtpSessionEvent::Create, Some(&session));
893 Ok(session)
894 }
895
896 pub async fn upsert_profile<F>(&self, mutate: F) -> Result<serde_json::Value, AgentError>
907 where
908 F: Fn(serde_json::Value) -> serde_json::Value,
909 {
910 let did = self.assert_did().await?;
911 const MAX_RETRIES: u32 = 5;
912
913 for _ in 0..MAX_RETRIES {
914 let existing_result = self
916 .xrpc_query(
917 "com.atproto.repo.getRecord",
918 Some(&{
919 let mut p = QueryParams::new();
920 p.insert("repo".into(), QueryValue::String(did.clone()));
921 p.insert(
922 "collection".into(),
923 QueryValue::String("app.bsky.actor.profile".into()),
924 );
925 p.insert("rkey".into(), QueryValue::String("self".into()));
926 p
927 }),
928 )
929 .await;
930
931 let (existing_record, swap_cid) = match existing_result {
932 Ok(r) => {
933 let record = r.get("value").cloned().unwrap_or(serde_json::Value::Null);
934 let cid = r.get("cid").and_then(|v| v.as_str()).map(String::from);
935 (record, cid)
936 }
937 Err(AgentError::Xrpc(ref e)) if is_not_found(e) => {
938 (serde_json::Value::Null, None)
939 }
940 Err(e) => return Err(e),
941 };
942
943 let updated = mutate(existing_record);
944 let mut body = serde_json::json!({
945 "repo": did,
946 "collection": "app.bsky.actor.profile",
947 "rkey": "self",
948 "record": updated,
949 });
950 if let Some(cid) = swap_cid {
951 body["swapRecord"] = serde_json::Value::String(cid);
952 }
953
954 match self
955 .xrpc_procedure("com.atproto.repo.putRecord", body)
956 .await
957 {
958 Ok(r) => return Ok(r),
959 Err(AgentError::Xrpc(ref e)) if is_invalid_swap(e) => {
960 continue;
963 }
964 Err(e) => return Err(e),
965 }
966 }
967
968 Err(AgentError::Other(
969 "upsert_profile: exceeded maximum retries due to concurrent writes".into(),
970 ))
971 }
972}
973
974fn is_not_found(err: &proto_blue_xrpc::Error) -> bool {
977 match err {
978 proto_blue_xrpc::Error::Xrpc(x) => x.is_error("RecordNotFound"),
979 _ => false,
980 }
981}
982
983fn is_invalid_swap(err: &proto_blue_xrpc::Error) -> bool {
986 match err {
987 proto_blue_xrpc::Error::Xrpc(x) => x.is_error("InvalidSwap"),
988 _ => false,
989 }
990}
991
992fn is_auth_expired(err: &proto_blue_xrpc::Error) -> bool {
999 match err {
1000 proto_blue_xrpc::Error::Xrpc(x) => {
1001 matches!(x.status, ResponseType::AuthenticationRequired)
1002 && x.is_error("ExpiredToken")
1003 }
1004 _ => false,
1005 }
1006}
1007
1008fn is_refresh_rejected(err: &proto_blue_xrpc::Error) -> bool {
1013 match err {
1014 proto_blue_xrpc::Error::Xrpc(x) => {
1015 matches!(x.status, ResponseType::AuthenticationRequired)
1016 }
1017 _ => false,
1018 }
1019}
1020
// Unit tests for `Agent`.
//
// The later tests drive the agent through a scripted in-memory fetch handler instead
// of a real network.
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with a valid URL succeeds.
    #[test]
    fn agent_creation() {
        let _agent = Agent::new("https://bsky.social").unwrap();
    }

    // A session survives a JSON round trip, including optional fields.
    #[test]
    fn session_serde_roundtrip() {
        let session = Session {
            did: "did:plc:abc123".to_string(),
            handle: "alice.bsky.social".to_string(),
            access_jwt: "eyJ...".to_string(),
            refresh_jwt: "eyJ...".to_string(),
            email: Some("alice@example.com".to_string()),
            email_confirmed: Some(true),
        };

        let json = serde_json::to_string(&session).unwrap();
        let parsed: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.did, "did:plc:abc123");
        assert_eq!(parsed.handle, "alice.bsky.social");
        assert_eq!(parsed.email, Some("alice@example.com".to_string()));
    }

    // Display strings of the non-wrapping error variants.
    #[test]
    fn agent_error_display() {
        let err = AgentError::NotAuthenticated;
        assert_eq!(err.to_string(), "Not authenticated");

        let err = AgentError::Other("test error".into());
        assert_eq!(err.to_string(), "test error");
    }

    // A fresh agent is unauthenticated.
    #[tokio::test]
    async fn agent_no_session_by_default() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.did().await.is_none());
        assert!(agent.session().await.is_none());
    }

    // Record operations require a session.
    #[tokio::test]
    async fn agent_assert_did_fails_when_not_logged_in() {
        let agent = Agent::new("https://bsky.social").unwrap();
        let err = agent.assert_did().await.unwrap_err();
        assert!(matches!(err, AgentError::NotAuthenticated));
    }

    // Generated timestamps are UTC (`Z`) date-times.
    #[test]
    fn now_iso_format() {
        let ts = Agent::now_iso();
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    // A caller-provided timestamp is passed through unchanged.
    #[test]
    fn resolve_timestamp_with_provided() {
        let ts = Agent::resolve_timestamp(Some("2024-01-15T12:00:00.000Z"));
        assert_eq!(ts, "2024-01-15T12:00:00.000Z");
    }

    // Without a timestamp, the current time is used.
    #[test]
    fn resolve_timestamp_without_provided() {
        let ts = Agent::resolve_timestamp(None);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    // `service()` is synchronous; the URL is normalized with a trailing slash.
    #[test]
    fn service_url_accessible_without_async() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert_eq!(agent.service(), "https://bsky.social/");
    }

    // No session means no bearer options.
    #[tokio::test]
    async fn auth_call_options_none_when_not_authenticated() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.auth_call_options().await.is_none());
    }

    use async_trait::async_trait;
    use proto_blue_common::fetch::{
        FetchError, FetchHandler, HttpRequest, HttpResponse,
    };

    // Fake fetch handler.
    //
    // It replays scripted responses keyed by NSID and counts calls per NSID.
    // Unscripted `createSession` calls return `createsession_body` with status 200.
    struct ScriptedFetcher {
        createsession_body: Vec<u8>,
        // NSID -> queued responses; the last response repeats once the queue is down to one.
        scripts: std::sync::Mutex<std::collections::HashMap<String, Vec<ScriptedResponse>>>,
        // NSID -> number of fetches seen.
        call_counts: std::sync::Mutex<std::collections::HashMap<String, usize>>,
    }

    // One canned HTTP response (always served with a JSON content type).
    #[derive(Clone)]
    struct ScriptedResponse {
        status: u16,
        body: Vec<u8>,
    }

    impl ScriptedFetcher {
        fn new(createsession_body: Vec<u8>) -> Self {
            Self {
                createsession_body,
                scripts: Default::default(),
                call_counts: Default::default(),
            }
        }
        // Replaces the response queue for `path` (an NSID).
        fn script(&self, path: &str, responses: Vec<ScriptedResponse>) {
            self.scripts
                .lock()
                .unwrap()
                .insert(path.to_string(), responses);
        }
        // How many times `path` (an NSID) was fetched.
        fn call_count(&self, path: &str) -> usize {
            *self.call_counts.lock().unwrap().get(path).unwrap_or(&0)
        }
    }

    #[async_trait]
    impl FetchHandler for ScriptedFetcher {
        async fn fetch(&self, req: HttpRequest) -> Result<HttpResponse, FetchError> {
            // Key = the NSID: the URL segment after `/xrpc/`, minus any query string.
            let path = req.url.clone();
            let key = path
                .split("/xrpc/")
                .nth(1)
                .unwrap_or(&path)
                .split('?')
                .next()
                .unwrap_or("")
                .to_string();
            *self
                .call_counts
                .lock()
                .unwrap()
                .entry(key.clone())
                .or_insert(0) += 1;

            {
                // Scoped so the scripts mutex is released before the fallbacks below.
                // Pop queued responses in order, but keep re-serving the final one.
                let mut scripts = self.scripts.lock().unwrap();
                if let Some(list) = scripts.get_mut(&key) {
                    let resp = if list.len() == 1 {
                        list[0].clone()
                    } else {
                        list.remove(0)
                    };
                    let mut headers = proto_blue_common::fetch::HttpHeaders::new();
                    headers.insert("content-type".into(), "application/json".into());
                    return Ok(HttpResponse {
                        status: resp.status,
                        headers,
                        body: resp.body,
                    });
                }
            }

            if key == "com.atproto.server.createSession" {
                // Default login response when createSession is not scripted.
                let mut headers = proto_blue_common::fetch::HttpHeaders::new();
                headers.insert("content-type".into(), "application/json".into());
                return Ok(HttpResponse {
                    status: 200,
                    headers,
                    body: self.createsession_body.clone(),
                });
            }

            Err(FetchError::Other(format!("no script for {key}")))
        }
    }

    // Minimal createSession body: did "did:plc:u", access "a1", refresh "r1".
    fn login_body() -> Vec<u8> {
        br#"{"did":"did:plc:u","handle":"alice","accessJwt":"a1","refreshJwt":"r1"}"#
            .to_vec()
    }

    // Builds an unauthenticated agent whose transport is `fetcher`.
    fn agent_with_fetcher(fetcher: Arc<ScriptedFetcher>) -> Agent {
        let client = XrpcClient::with_fetch_handler(
            "https://example.com",
            fetcher,
        )
        .unwrap();
        Agent {
            client,
            session: Arc::new(RwLock::new(None)),
            listeners: Arc::new(Mutex::new(Vec::new())),
            refresh_lock: Arc::new(AsyncMutex::new(())),
            proxy: Arc::new(RwLock::new(None)),
            labelers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    // A successful login emits exactly one Create.
    #[tokio::test]
    async fn emits_create_on_successful_login() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> =
            Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        agent.login("alice", "secret").await.unwrap();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Create]);
    }

    // A rejected login emits exactly one CreateFailed.
    #[tokio::test]
    async fn emits_create_failed_on_login_rejection() {
        let fetcher = Arc::new(ScriptedFetcher::new(vec![]));
        fetcher.script(
            // Scripting createSession overrides the default success body.
            "com.atproto.server.createSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"bad pwd"}"#
                    .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> =
            Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let _ = agent.login("alice", "bad").await.unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::CreateFailed]);
    }

    // An ExpiredToken response triggers one refresh and one replay of the request.
    #[tokio::test]
    async fn auto_refreshes_on_expired_access_token() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));

        fetcher.script(
            // First call: expired token; afterwards: success.
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );

        let agent = agent_with_fetcher(fetcher.clone());
        agent.login("alice", "secret").await.unwrap();

        // Subscribe after login so only refresh-related events are recorded.
        let events: Arc<Mutex<Vec<AtpSessionEvent>>> =
            Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let result = agent.describe_server().await.unwrap();
        assert_eq!(result["did"], "did:plc:svr");

        // Original attempt + replay, with exactly one refresh in between.
        assert_eq!(fetcher.call_count("com.atproto.server.describeServer"), 2);
        assert_eq!(fetcher.call_count("com.atproto.server.refreshSession"), 1);

        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Update]);
    }

    // Concurrent callers hitting ExpiredToken must share a single refreshSession call.
    #[tokio::test]
    async fn concurrent_expired_token_refreshes_once() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));

        fetcher.script(
            // Only the first describeServer fetch returns ExpiredToken; the rest succeed.
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );

        let agent = Arc::new(agent_with_fetcher(fetcher.clone()));
        agent.login("alice", "secret").await.unwrap();

        let mut handles = Vec::new();
        // Five tasks issue the same call concurrently.
        for _ in 0..5 {
            let a = agent.clone();
            handles.push(tokio::spawn(async move {
                a.describe_server().await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(
            fetcher.call_count("com.atproto.server.refreshSession"),
            1,
            "concurrent callers must share one refreshSession call",
        );
    }

    // `configure_proxy` stores the proxy target.
    #[tokio::test]
    async fn configure_proxy_sets_header_on_next_call() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:svr"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent.configure_proxy(Some("did:web:api.bsky.chat#bsky_chat")).await;

        agent.describe_server().await.unwrap();

        // NOTE(review): despite the test name, this only checks the stored proxy value.
        // ScriptedFetcher does not record request headers, so whether `atproto-proxy`
        // was actually sent is not verified here.
        let p = agent.proxy.read().await;
        assert_eq!(p.as_deref(), Some("did:web:api.bsky.chat#bsky_chat"));
    }

    // `configure_labelers` stores the list and renders header entries.
    #[tokio::test]
    async fn configure_labelers_stores_list() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);
        agent
            .configure_labelers(&[
                LabelerOpts {
                    did: "did:plc:a".into(),
                    redirect: false,
                },
                LabelerOpts {
                    did: "did:plc:b".into(),
                    redirect: true,
                },
            ])
            .await;
        let l = agent.labelers.read().await;
        assert_eq!(l.len(), 2);
        assert_eq!(l[0].header_value(), "did:plc:a");
        assert_eq!(l[1].header_value(), "did:plc:b;redirect");
    }

    // Logout calls deleteSession once and clears the local session.
    #[tokio::test]
    async fn logout_clears_session() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 200,
                body: b"{}".to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent.login("alice", "secret").await.unwrap();
        assert!(agent.session().await.is_some());
        agent.logout().await.unwrap();
        assert!(agent.session().await.is_none());
        assert_eq!(
            fetcher.call_count("com.atproto.server.deleteSession"),
            1,
        );
    }

    // A server-side logout failure still clears local credentials.
    #[tokio::test]
    async fn logout_clears_session_even_on_server_error() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 500,
                body: br#"{"error":"InternalServerError"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent.login("alice", "secret").await.unwrap();
        let _ = agent.logout().await;
        // The returned error is ignored on purpose; only local state is asserted.
        assert!(agent.session().await.is_none());
    }

    // Account creation stores the new session and emits Create.
    #[tokio::test]
    async fn create_account_emits_create_on_success() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.createAccount",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:new","handle":"newuser","accessJwt":"a","refreshJwt":"r"}"#
                    .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev = events.clone();
        agent.on_session(move |e, _| ev.lock().unwrap().push(e));

        let session = agent
            .create_account("newuser", "pw", Some("new@example.com"), None)
            .await
            .unwrap();
        assert_eq!(session.did, "did:plc:new");
        assert_eq!(events.lock().unwrap().clone(), vec![AtpSessionEvent::Create]);
    }

    // With no existing profile, `mutate` sees Null and the record is written.
    #[tokio::test]
    async fn upsert_profile_creates_when_absent() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            // getRecord reports the profile record as missing.
            "com.atproto.repo.getRecord",
            vec![ScriptedResponse {
                status: 400,
                body: br#"{"error":"RecordNotFound","message":"no such record"}"#.to_vec(),
            }],
        );
        fetcher.script(
            "com.atproto.repo.putRecord",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"uri":"at://did:plc:u/app.bsky.actor.profile/self","cid":"bafy"}"#
                    .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent.login("alice", "secret").await.unwrap();

        let result = agent
            .upsert_profile(|prev| {
                assert!(prev.is_null(), "no existing profile");
                serde_json::json!({"$type": "app.bsky.actor.profile", "displayName": "Alice"})
            })
            .await
            .unwrap();
        assert_eq!(result["uri"], "at://did:plc:u/app.bsky.actor.profile/self");
    }

    // A 401 from refreshSession clears the session and emits Expired.
    #[tokio::test]
    async fn emits_expired_when_refresh_itself_401s() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"refresh expired"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent.login("alice", "secret").await.unwrap();

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> =
            Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let _ = agent.refresh_session().await.unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Expired]);
        assert!(agent.session().await.is_none(), "session cleared on expired refresh");
    }

}