1use std::sync::{Arc, Mutex};
7use tokio::sync::{Mutex as AsyncMutex, RwLock};
8
9use proto_blue_lex_data::Cid;
10use proto_blue_syntax::{AtIdentifier, AtUri, Did, Handle};
11use proto_blue_xrpc::{
12 CallOptions, HeadersMap, QueryParams, QueryValue, ResponseType, XrpcBody, XrpcClient,
13};
14
15use crate::rich_text::RichText;
16
/// Session lifecycle notifications delivered to callbacks registered through
/// `Agent::on_session`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AtpSessionEvent {
    /// A session was stored: emitted after a successful `login`,
    /// `resume_session` or `create_account`.
    Create,
    /// The `createSession` / `createAccount` request returned an error.
    CreateFailed,
    /// The stored session was replaced by a successful `refresh_session`.
    Update,
    /// The stored session was cleared: emitted when the server rejects a
    /// refresh, and also by `logout`.
    Expired,
    /// A refresh attempt failed for a reason other than rejection; the stored
    /// session is left untouched.
    NetworkError,
}
36
/// Shared, thread-safe session listener. It receives the event together with
/// the session it concerns, or `None` when the event carries no session.
pub type SessionEventCallback = Arc<dyn Fn(AtpSessionEvent, Option<&Session>) + Send + Sync>;
44
/// Identity and credentials of an authenticated account.
///
/// (De)serialised with camelCase keys (`accessJwt`, `refreshJwt`,
/// `emailConfirmed`); the optional fields are omitted from the output when
/// they are `None`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Session {
    /// DID of the account.
    pub did: Did,
    /// Handle of the account.
    pub handle: Handle,
    /// Token sent as `Authorization: Bearer …` on authenticated calls.
    pub access_jwt: String,
    /// Token sent as `Authorization: Bearer …` to `refreshSession` and
    /// `deleteSession`.
    pub refresh_jwt: String,
    /// Account e-mail address, when present in the source document.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// E-mail confirmation flag, when present in the source document.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email_confirmed: Option<bool>,
}
64
/// Errors returned by `Agent` operations.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
    /// An error reported by the underlying XRPC client.
    #[error("XRPC error: {0}")]
    Xrpc(#[from] proto_blue_xrpc::Error),
    /// The operation needs a session but none is stored.
    #[error("Not authenticated")]
    NotAuthenticated,
    /// A JSON value could not be serialised or deserialised.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// Any other failure, described by the contained message.
    #[error("{0}")]
    Other(String),
}
77
/// High-level client that owns an XRPC client plus the session, listener,
/// proxy and labeler state applied to its calls.
///
/// All mutable state sits behind `Arc`s; `shallow_clone` shares every piece
/// except `proxy`, which gets a fresh slot.
pub struct Agent {
    /// XRPC client used for every request.
    client: XrpcClient,
    /// Currently stored session, `None` until one is established.
    session: Arc<RwLock<Option<Session>>>,
    /// Callbacks registered through `on_session`.
    listeners: Arc<Mutex<Vec<SessionEventCallback>>>,
    /// Held by `refresh_and_retry` while it decides whether to refresh and
    /// performs the refresh.
    refresh_lock: Arc<AsyncMutex<()>>,
    /// Value sent in the `atproto-proxy` header, when configured.
    proxy: Arc<RwLock<Option<String>>>,
    /// Entries joined into the `atproto-accept-labelers` header.
    labelers: Arc<RwLock<Vec<LabelerOpts>>>,
}
112
/// One labeler entry for the `atproto-accept-labelers` request header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabelerOpts {
    /// DID of the labeler.
    pub did: Did,
    /// When `true`, the header element is suffixed with `;redirect`.
    pub redirect: bool,
}
122
123impl LabelerOpts {
124 fn header_value(&self) -> String {
126 if self.redirect {
127 format!("{};redirect", self.did)
128 } else {
129 self.did.to_string()
130 }
131 }
132}
133
134impl Agent {
135 #[cfg(any(
141 all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
142 target_arch = "wasm32",
143 ))]
144 pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
145 let client = XrpcClient::new(service)?;
146 Ok(Self {
147 client,
148 session: Arc::new(RwLock::new(None)),
149 listeners: Arc::new(Mutex::new(Vec::new())),
150 refresh_lock: Arc::new(AsyncMutex::new(())),
151 proxy: Arc::new(RwLock::new(None)),
152 labelers: Arc::new(RwLock::new(Vec::new())),
153 })
154 }
155
156 pub fn on_session<F>(&self, callback: F)
163 where
164 F: Fn(AtpSessionEvent, Option<&Session>) + Send + Sync + 'static,
165 {
166 self.listeners.lock().unwrap().push(Arc::new(callback));
167 }
168
169 fn emit(&self, event: AtpSessionEvent, session: Option<&Session>) {
171 let listeners = self.listeners.lock().unwrap().clone();
176 for cb in listeners {
177 cb(event, session);
178 }
179 }
180
181 #[must_use]
183 pub fn service(&self) -> String {
184 self.client.service_url().to_string()
185 }
186
187 pub async fn did(&self) -> Option<Did> {
189 self.session.read().await.as_ref().map(|s| s.did.clone())
190 }
191
192 pub async fn session(&self) -> Option<Session> {
194 self.session.read().await.clone()
195 }
196
197 async fn auth_call_options(&self) -> Option<CallOptions> {
204 let guard = self.session.read().await;
205 let session = guard.as_ref()?;
206 let mut headers = HeadersMap::new();
207 headers.insert(
208 "Authorization".into(),
209 format!("Bearer {}", session.access_jwt),
210 );
211 self.inject_proxy_and_labelers(&mut headers).await;
212 Some(CallOptions {
213 encoding: None,
214 headers: Some(headers),
215 ..Default::default()
216 })
217 }
218
219 pub async fn anon_call_options(&self) -> Option<CallOptions> {
225 let mut headers = HeadersMap::new();
226 self.inject_proxy_and_labelers(&mut headers).await;
227 if headers.is_empty() {
228 None
229 } else {
230 Some(CallOptions {
231 encoding: None,
232 headers: Some(headers),
233 ..Default::default()
234 })
235 }
236 }
237
238 async fn inject_proxy_and_labelers(&self, headers: &mut HeadersMap) {
239 if let Some(proxy) = self.proxy.read().await.as_ref() {
240 headers.insert("atproto-proxy".into(), proxy.clone());
241 }
242 let labelers = self.labelers.read().await;
243 if !labelers.is_empty() {
244 let v = labelers
245 .iter()
246 .map(LabelerOpts::header_value)
247 .collect::<Vec<_>>()
248 .join(", ");
249 headers.insert("atproto-accept-labelers".into(), v);
250 }
251 }
252
253 pub async fn configure_proxy(&self, target: Option<&str>) {
259 *self.proxy.write().await = target.map(String::from);
260 }
261
262 pub async fn with_proxy(&self, target: &str) -> Self {
265 let cloned = self.shallow_clone();
266 cloned.configure_proxy(Some(target)).await;
267 cloned
268 }
269
270 pub async fn configure_labelers(&self, labelers: &[LabelerOpts]) {
273 *self.labelers.write().await = labelers.to_vec();
274 }
275
276 fn shallow_clone(&self) -> Self {
280 Self {
281 client: self.client.clone(),
282 session: self.session.clone(),
283 listeners: self.listeners.clone(),
284 refresh_lock: self.refresh_lock.clone(),
285 proxy: Arc::new(RwLock::new(None)),
286 labelers: self.labelers.clone(),
287 }
288 }
289
290 pub async fn login(
296 &self,
297 identifier: &AtIdentifier,
298 password: &str,
299 ) -> Result<Session, AgentError> {
300 let body = serde_json::json!({
301 "identifier": identifier,
302 "password": password,
303 });
304
305 let response = match self
306 .client
307 .procedure(
308 "com.atproto.server.createSession",
309 None,
310 Some(XrpcBody::Json(body)),
311 None,
312 )
313 .await
314 {
315 Ok(r) => r,
316 Err(e) => {
317 self.emit(AtpSessionEvent::CreateFailed, None);
318 return Err(AgentError::Xrpc(e));
319 }
320 };
321
322 let session: Session = serde_json::from_value(response.data)?;
323
324 *self.session.write().await = Some(session.clone());
326 self.emit(AtpSessionEvent::Create, Some(&session));
327 Ok(session)
328 }
329
330 pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
335 let mut headers = HeadersMap::new();
338 headers.insert(
339 "Authorization".into(),
340 format!("Bearer {}", session.access_jwt),
341 );
342 let opts = CallOptions {
343 encoding: None,
344 headers: Some(headers),
345 ..Default::default()
346 };
347 let response = self
348 .client
349 .query("com.atproto.server.getSession", None, Some(&opts))
350 .await?;
351 let verified_did = response
352 .data
353 .get("did")
354 .and_then(|v| v.as_str())
355 .map(Did::new)
356 .transpose()
357 .map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))?;
358
359 let mut committed = session;
361 if let Some(did) = verified_did {
362 committed.did = did;
363 }
364 *self.session.write().await = Some(committed.clone());
365 self.emit(AtpSessionEvent::Create, Some(&committed));
366
367 Ok(())
368 }
369
370 pub async fn refresh_session(&self) -> Result<Session, AgentError> {
378 let refresh_jwt = {
379 let sess = self.session.read().await;
380 let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
381 sess.refresh_jwt.clone()
382 };
383
384 let mut headers = HeadersMap::new();
386 headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
387 let opts = CallOptions {
388 encoding: None,
389 headers: Some(headers),
390 ..Default::default()
391 };
392
393 let response = match self
394 .client
395 .procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
396 .await
397 {
398 Ok(r) => r,
399 Err(e) => {
400 if is_refresh_rejected(&e) {
406 *self.session.write().await = None;
407 self.emit(AtpSessionEvent::Expired, None);
408 } else {
409 self.emit(AtpSessionEvent::NetworkError, None);
410 }
411 return Err(AgentError::Xrpc(e));
412 }
413 };
414
415 let session: Session = serde_json::from_value(response.data)?;
416
417 *self.session.write().await = Some(session.clone());
419 self.emit(AtpSessionEvent::Update, Some(&session));
420 Ok(session)
421 }
422
423 async fn assert_did(&self) -> Result<Did, AgentError> {
427 self.did().await.ok_or(AgentError::NotAuthenticated)
428 }
429
430 async fn xrpc_query(
437 &self,
438 nsid: &str,
439 params: Option<&QueryParams>,
440 ) -> Result<serde_json::Value, AgentError> {
441 let opts = self.auth_call_options().await;
442 let first = self.client.query(nsid, params, opts.as_ref()).await;
443 match first {
444 Ok(r) => Ok(r.data),
445 Err(e) if is_auth_expired(&e) => {
446 self.refresh_and_retry(|opts| {
447 let c = self.client.clone();
448 let nsid = nsid.to_string();
449 let params = params.cloned();
450 async move { c.query(&nsid, params.as_ref(), opts.as_ref()).await }
451 })
452 .await
453 }
454 Err(e) => Err(AgentError::Xrpc(e)),
455 }
456 }
457
458 async fn xrpc_procedure(
460 &self,
461 nsid: &str,
462 body: serde_json::Value,
463 ) -> Result<serde_json::Value, AgentError> {
464 let opts = self.auth_call_options().await;
465 let first = self
466 .client
467 .procedure(
468 nsid,
469 None,
470 Some(XrpcBody::Json(body.clone())),
471 opts.as_ref(),
472 )
473 .await;
474 match first {
475 Ok(r) => Ok(r.data),
476 Err(e) if is_auth_expired(&e) => {
477 self.refresh_and_retry(|opts| {
478 let c = self.client.clone();
479 let nsid = nsid.to_string();
480 let body = body.clone();
481 async move {
482 c.procedure(&nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
483 .await
484 }
485 })
486 .await
487 }
488 Err(e) => Err(AgentError::Xrpc(e)),
489 }
490 }
491
492 async fn refresh_and_retry<F, Fut>(&self, replay: F) -> Result<serde_json::Value, AgentError>
502 where
503 F: FnOnce(Option<CallOptions>) -> Fut,
504 Fut: std::future::Future<
505 Output = Result<proto_blue_xrpc::XrpcResponse, proto_blue_xrpc::Error>,
506 >,
507 {
508 let pre_refresh_jwt = self
512 .session
513 .read()
514 .await
515 .as_ref()
516 .map(|s| s.access_jwt.clone());
517 let _guard = self.refresh_lock.lock().await;
518 let current_jwt = self
519 .session
520 .read()
521 .await
522 .as_ref()
523 .map(|s| s.access_jwt.clone());
524 if pre_refresh_jwt == current_jwt {
525 self.refresh_session().await?;
527 }
528 drop(_guard);
529
530 let opts = self.auth_call_options().await;
531 let response = replay(opts).await?;
532 Ok(response.data)
533 }
534
535 async fn create_record(
537 &self,
538 collection: &str,
539 record: serde_json::Value,
540 ) -> Result<serde_json::Value, AgentError> {
541 let did = self.assert_did().await?;
542 let body = serde_json::json!({
543 "repo": did,
544 "collection": collection,
545 "record": record,
546 });
547 self.xrpc_procedure("com.atproto.repo.createRecord", body)
548 .await
549 }
550
551 async fn delete_record(&self, collection: &str, uri: &AtUri) -> Result<(), AgentError> {
553 let did = self.assert_did().await?;
554 let rkey = uri
555 .rkey()
556 .ok_or_else(|| AgentError::Other("AT-URI has no rkey segment".into()))?;
557
558 let body = serde_json::json!({
559 "repo": did,
560 "collection": collection,
561 "rkey": rkey,
562 });
563 self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
564 .await?;
565 Ok(())
566 }
567
568 fn now_iso() -> String {
570 chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
571 }
572
573 fn resolve_timestamp(created_at: Option<&str>) -> String {
575 created_at.map_or_else(Self::now_iso, String::from)
576 }
577
578 pub async fn post(
584 &self,
585 text: &str,
586 facets: Option<Vec<crate::rich_text::Facet>>,
587 created_at: Option<&str>,
588 ) -> Result<serde_json::Value, AgentError> {
589 let mut record = serde_json::json!({
590 "$type": "app.bsky.feed.post",
591 "text": text,
592 "createdAt": Self::resolve_timestamp(created_at),
593 });
594
595 if let Some(facets) = facets {
596 record["facets"] = serde_json::to_value(&facets)?;
597 }
598
599 self.create_record("app.bsky.feed.post", record).await
600 }
601
602 pub async fn post_rich(
604 &self,
605 rt: &RichText,
606 created_at: Option<&str>,
607 ) -> Result<serde_json::Value, AgentError> {
608 let facets = if rt.facets().is_empty() {
609 None
610 } else {
611 Some(rt.facets().to_vec())
612 };
613 self.post(rt.text(), facets, created_at).await
614 }
615
616 pub async fn delete_post(&self, uri: &AtUri) -> Result<(), AgentError> {
618 self.delete_record("app.bsky.feed.post", uri).await
619 }
620
621 pub async fn like(
627 &self,
628 uri: &AtUri,
629 cid: &Cid,
630 created_at: Option<&str>,
631 ) -> Result<serde_json::Value, AgentError> {
632 let record = serde_json::json!({
633 "$type": "app.bsky.feed.like",
634 "subject": { "uri": uri, "cid": cid },
635 "createdAt": Self::resolve_timestamp(created_at),
636 });
637 self.create_record("app.bsky.feed.like", record).await
638 }
639
640 pub async fn delete_like(&self, like_uri: &AtUri) -> Result<(), AgentError> {
642 self.delete_record("app.bsky.feed.like", like_uri).await
643 }
644
645 pub async fn repost(
649 &self,
650 uri: &AtUri,
651 cid: &Cid,
652 created_at: Option<&str>,
653 ) -> Result<serde_json::Value, AgentError> {
654 let record = serde_json::json!({
655 "$type": "app.bsky.feed.repost",
656 "subject": { "uri": uri, "cid": cid },
657 "createdAt": Self::resolve_timestamp(created_at),
658 });
659 self.create_record("app.bsky.feed.repost", record).await
660 }
661
662 pub async fn delete_repost(&self, repost_uri: &AtUri) -> Result<(), AgentError> {
664 self.delete_record("app.bsky.feed.repost", repost_uri).await
665 }
666
667 pub async fn follow(
673 &self,
674 subject_did: &Did,
675 created_at: Option<&str>,
676 ) -> Result<serde_json::Value, AgentError> {
677 let record = serde_json::json!({
678 "$type": "app.bsky.graph.follow",
679 "subject": subject_did,
680 "createdAt": Self::resolve_timestamp(created_at),
681 });
682 self.create_record("app.bsky.graph.follow", record).await
683 }
684
685 pub async fn delete_follow(&self, follow_uri: &AtUri) -> Result<(), AgentError> {
687 self.delete_record("app.bsky.graph.follow", follow_uri)
688 .await
689 }
690
691 pub async fn get_profile(&self, actor: &AtIdentifier) -> Result<serde_json::Value, AgentError> {
695 let mut params = QueryParams::new();
696 params.insert("actor".into(), QueryValue::String(actor.to_string()));
697 self.xrpc_query("app.bsky.actor.getProfile", Some(¶ms))
698 .await
699 }
700
701 pub async fn get_timeline(
703 &self,
704 limit: Option<i64>,
705 cursor: Option<&str>,
706 ) -> Result<serde_json::Value, AgentError> {
707 let mut params = QueryParams::new();
708 if let Some(limit) = limit {
709 params.insert("limit".into(), QueryValue::Integer(limit));
710 }
711 if let Some(cursor) = cursor {
712 params.insert("cursor".into(), QueryValue::String(cursor.into()));
713 }
714 self.xrpc_query("app.bsky.feed.getTimeline", Some(¶ms))
715 .await
716 }
717
718 pub async fn get_post_thread(
720 &self,
721 uri: &AtUri,
722 depth: Option<i64>,
723 ) -> Result<serde_json::Value, AgentError> {
724 let mut params = QueryParams::new();
725 params.insert("uri".into(), QueryValue::String(uri.to_string()));
726 if let Some(depth) = depth {
727 params.insert("depth".into(), QueryValue::Integer(depth));
728 }
729 self.xrpc_query("app.bsky.feed.getPostThread", Some(¶ms))
730 .await
731 }
732
733 pub async fn search_actors(
735 &self,
736 query: &str,
737 limit: Option<i64>,
738 ) -> Result<serde_json::Value, AgentError> {
739 let mut params = QueryParams::new();
740 params.insert("q".into(), QueryValue::String(query.into()));
741 if let Some(limit) = limit {
742 params.insert("limit".into(), QueryValue::Integer(limit));
743 }
744 self.xrpc_query("app.bsky.actor.searchActors", Some(¶ms))
745 .await
746 }
747
748 pub async fn resolve_handle(&self, handle: &Handle) -> Result<Did, AgentError> {
750 let mut params = QueryParams::new();
751 params.insert("handle".into(), QueryValue::String(handle.to_string()));
752 let data = self
753 .xrpc_query("com.atproto.identity.resolveHandle", Some(¶ms))
754 .await?;
755 let did_str = data
756 .get("did")
757 .and_then(|v| v.as_str())
758 .ok_or_else(|| AgentError::Other("Missing DID in response".into()))?;
759 Did::new(did_str)
760 .map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))
761 }
762
763 pub async fn list_notifications(
765 &self,
766 limit: Option<i64>,
767 cursor: Option<&str>,
768 ) -> Result<serde_json::Value, AgentError> {
769 let mut params = QueryParams::new();
770 if let Some(limit) = limit {
771 params.insert("limit".into(), QueryValue::Integer(limit));
772 }
773 if let Some(cursor) = cursor {
774 params.insert("cursor".into(), QueryValue::String(cursor.into()));
775 }
776 self.xrpc_query("app.bsky.notification.listNotifications", Some(¶ms))
777 .await
778 }
779
780 pub async fn upload_blob(
782 &self,
783 data: Vec<u8>,
784 content_type: &str,
785 ) -> Result<serde_json::Value, AgentError> {
786 let mut headers = HeadersMap::new();
787 headers.insert("Content-Type".into(), content_type.into());
788
789 if let Some(sess) = self.session.read().await.as_ref() {
791 headers.insert(
792 "Authorization".into(),
793 format!("Bearer {}", sess.access_jwt),
794 );
795 }
796
797 let opts = CallOptions {
798 encoding: Some(content_type.to_string()),
799 headers: Some(headers),
800 ..Default::default()
801 };
802
803 let response = self
804 .client
805 .procedure(
806 "com.atproto.repo.uploadBlob",
807 None,
808 Some(XrpcBody::Bytes(data)),
809 Some(&opts),
810 )
811 .await?;
812
813 Ok(response.data)
814 }
815
816 pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
818 self.xrpc_query("com.atproto.server.describeServer", None)
819 .await
820 }
821
822 pub async fn logout(&self) -> Result<(), AgentError> {
832 let refresh_jwt = {
833 let guard = self.session.read().await;
834 guard.as_ref().map(|s| s.refresh_jwt.clone())
835 };
836
837 let server_result = if let Some(refresh_jwt) = refresh_jwt {
838 let mut headers = HeadersMap::new();
839 headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
840 let opts = CallOptions {
841 encoding: None,
842 headers: Some(headers),
843 ..Default::default()
844 };
845 self.client
846 .procedure("com.atproto.server.deleteSession", None, None, Some(&opts))
847 .await
848 .map(|_| ())
849 } else {
850 Ok(())
851 };
852
853 *self.session.write().await = None;
855 self.emit(AtpSessionEvent::Expired, None);
856
857 server_result.map_err(AgentError::Xrpc)
858 }
859
860 pub async fn create_account(
869 &self,
870 handle: &Handle,
871 password: &str,
872 email: Option<&str>,
873 extra: Option<serde_json::Value>,
874 ) -> Result<Session, AgentError> {
875 let mut body = serde_json::json!({
876 "handle": handle,
877 "password": password,
878 });
879 if let Some(email) = email {
880 body["email"] = serde_json::Value::String(email.to_string());
881 }
882 if let Some(extra) = extra
883 && let Some(extra_map) = extra.as_object()
884 && let Some(body_map) = body.as_object_mut()
885 {
886 for (k, v) in extra_map {
887 body_map.insert(k.clone(), v.clone());
888 }
889 }
890
891 let response = match self
892 .client
893 .procedure(
894 "com.atproto.server.createAccount",
895 None,
896 Some(XrpcBody::Json(body)),
897 None,
898 )
899 .await
900 {
901 Ok(r) => r,
902 Err(e) => {
903 self.emit(AtpSessionEvent::CreateFailed, None);
904 return Err(AgentError::Xrpc(e));
905 }
906 };
907
908 let session: Session = serde_json::from_value(response.data)?;
909 *self.session.write().await = Some(session.clone());
910 self.emit(AtpSessionEvent::Create, Some(&session));
911 Ok(session)
912 }
913
914 pub async fn upsert_profile<F>(&self, mutate: F) -> Result<serde_json::Value, AgentError>
925 where
926 F: Fn(serde_json::Value) -> serde_json::Value,
927 {
928 let did = self.assert_did().await?;
929 const MAX_RETRIES: u32 = 5;
930
931 for _ in 0..MAX_RETRIES {
932 let existing_result = self
934 .xrpc_query(
935 "com.atproto.repo.getRecord",
936 Some(&{
937 let mut p = QueryParams::new();
938 p.insert("repo".into(), QueryValue::String(did.to_string()));
939 p.insert(
940 "collection".into(),
941 QueryValue::String("app.bsky.actor.profile".into()),
942 );
943 p.insert("rkey".into(), QueryValue::String("self".into()));
944 p
945 }),
946 )
947 .await;
948
949 let (existing_record, swap_cid) = match existing_result {
950 Ok(r) => {
951 let record = r.get("value").cloned().unwrap_or(serde_json::Value::Null);
952 let cid = r.get("cid").and_then(|v| v.as_str()).map(String::from);
953 (record, cid)
954 }
955 Err(AgentError::Xrpc(ref e)) if is_not_found(e) => (serde_json::Value::Null, None),
956 Err(e) => return Err(e),
957 };
958
959 let updated = mutate(existing_record);
960 let mut body = serde_json::json!({
961 "repo": did,
962 "collection": "app.bsky.actor.profile",
963 "rkey": "self",
964 "record": updated,
965 });
966 if let Some(cid) = swap_cid {
967 body["swapRecord"] = serde_json::Value::String(cid);
968 }
969
970 match self
971 .xrpc_procedure("com.atproto.repo.putRecord", body)
972 .await
973 {
974 Ok(r) => return Ok(r),
975 Err(AgentError::Xrpc(ref e)) if is_invalid_swap(e) => {
976 continue;
979 }
980 Err(e) => return Err(e),
981 }
982 }
983
984 Err(AgentError::Other(
985 "upsert_profile: exceeded maximum retries due to concurrent writes".into(),
986 ))
987 }
988}
989
990fn is_not_found(err: &proto_blue_xrpc::Error) -> bool {
993 match err {
994 proto_blue_xrpc::Error::Xrpc(x) => x.is_error("RecordNotFound"),
995 _ => false,
996 }
997}
998
999fn is_invalid_swap(err: &proto_blue_xrpc::Error) -> bool {
1002 match err {
1003 proto_blue_xrpc::Error::Xrpc(x) => x.is_error("InvalidSwap"),
1004 _ => false,
1005 }
1006}
1007
1008fn is_auth_expired(err: &proto_blue_xrpc::Error) -> bool {
1015 match err {
1016 proto_blue_xrpc::Error::Xrpc(x) => {
1017 matches!(x.status, ResponseType::AuthenticationRequired) && x.is_error("ExpiredToken")
1018 }
1019 _ => false,
1020 }
1021}
1022
1023const fn is_refresh_rejected(err: &proto_blue_xrpc::Error) -> bool {
1028 match err {
1029 proto_blue_xrpc::Error::Xrpc(x) => {
1030 matches!(x.status, ResponseType::AuthenticationRequired)
1031 }
1032 _ => false,
1033 }
1034}
1035
1036#[cfg(test)]
1037mod tests {
1038 use super::*;
1039
1040 #[test]
1041 fn agent_creation() {
1042 let _agent = Agent::new("https://bsky.social").unwrap();
1043 }
1044
1045 #[test]
1046 fn session_serde_roundtrip() {
1047 let session = Session {
1048 did: Did::new("did:plc:abc123").unwrap(),
1049 handle: Handle::new("alice.bsky.social").unwrap(),
1050 access_jwt: "eyJ...".to_string(),
1051 refresh_jwt: "eyJ...".to_string(),
1052 email: Some("alice@example.com".to_string()),
1053 email_confirmed: Some(true),
1054 };
1055
1056 let json = serde_json::to_string(&session).unwrap();
1057 let parsed: Session = serde_json::from_str(&json).unwrap();
1058 assert_eq!(parsed.did.as_str(), "did:plc:abc123");
1059 assert_eq!(parsed.handle.as_str(), "alice.bsky.social");
1060 assert_eq!(parsed.email, Some("alice@example.com".to_string()));
1061 }
1062
1063 #[test]
1064 fn agent_error_display() {
1065 let err = AgentError::NotAuthenticated;
1066 assert_eq!(err.to_string(), "Not authenticated");
1067
1068 let err = AgentError::Other("test error".into());
1069 assert_eq!(err.to_string(), "test error");
1070 }
1071
1072 #[tokio::test]
1073 async fn agent_no_session_by_default() {
1074 let agent = Agent::new("https://bsky.social").unwrap();
1075 assert!(agent.did().await.is_none());
1076 assert!(agent.session().await.is_none());
1077 }
1078
1079 #[tokio::test]
1080 async fn agent_assert_did_fails_when_not_logged_in() {
1081 let agent = Agent::new("https://bsky.social").unwrap();
1082 let err = agent.assert_did().await.unwrap_err();
1083 assert!(matches!(err, AgentError::NotAuthenticated));
1084 }
1085
1086 #[test]
1087 fn now_iso_format() {
1088 let ts = Agent::now_iso();
1089 assert!(ts.ends_with('Z'));
1090 assert!(ts.contains('T'));
1091 }
1092
1093 #[test]
1094 fn resolve_timestamp_with_provided() {
1095 let ts = Agent::resolve_timestamp(Some("2024-01-15T12:00:00.000Z"));
1096 assert_eq!(ts, "2024-01-15T12:00:00.000Z");
1097 }
1098
1099 #[test]
1100 fn resolve_timestamp_without_provided() {
1101 let ts = Agent::resolve_timestamp(None);
1102 assert!(ts.ends_with('Z'));
1103 assert!(ts.contains('T'));
1104 }
1105
1106 #[test]
1107 fn service_url_accessible_without_async() {
1108 let agent = Agent::new("https://bsky.social").unwrap();
1109 assert_eq!(agent.service(), "https://bsky.social/");
1110 }
1111
1112 #[tokio::test]
1113 async fn auth_call_options_none_when_not_authenticated() {
1114 let agent = Agent::new("https://bsky.social").unwrap();
1115 assert!(agent.auth_call_options().await.is_none());
1116 }
1117
1118 use async_trait::async_trait;
1121 use proto_blue_common::fetch::{FetchError, FetchHandler, HttpRequest, HttpResponse};
1122
    /// Test double for the HTTP layer: answers requests from per-endpoint
    /// scripts and counts how often each endpoint was called.
    struct ScriptedFetcher {
        /// Body returned with status 200 for `createSession` when that
        /// endpoint has no script of its own.
        createsession_body: Vec<u8>,
        /// Scripted responses, keyed by the XRPC method name taken from the
        /// request URL.
        scripts: std::sync::Mutex<std::collections::HashMap<String, Vec<ScriptedResponse>>>,
        /// Number of requests seen per XRPC method name.
        call_counts: std::sync::Mutex<std::collections::HashMap<String, usize>>,
    }
1132
    /// One canned HTTP response served by `ScriptedFetcher`.
    #[derive(Clone)]
    struct ScriptedResponse {
        /// HTTP status code to return.
        status: u16,
        /// Raw response body; it is served with a JSON content type.
        body: Vec<u8>,
    }
1138
1139 impl ScriptedFetcher {
1140 fn new(createsession_body: Vec<u8>) -> Self {
1141 Self {
1142 createsession_body,
1143 scripts: Default::default(),
1144 call_counts: Default::default(),
1145 }
1146 }
1147 fn script(&self, path: &str, responses: Vec<ScriptedResponse>) {
1148 self.scripts
1149 .lock()
1150 .unwrap()
1151 .insert(path.to_string(), responses);
1152 }
1153 fn call_count(&self, path: &str) -> usize {
1154 *self.call_counts.lock().unwrap().get(path).unwrap_or(&0)
1155 }
1156 }
1157
1158 #[async_trait]
1159 impl FetchHandler for ScriptedFetcher {
1160 async fn fetch(&self, req: HttpRequest) -> Result<HttpResponse, FetchError> {
1161 let path = req.url.clone();
1162 let key = path
1163 .split("/xrpc/")
1164 .nth(1)
1165 .unwrap_or(&path)
1166 .split('?')
1167 .next()
1168 .unwrap_or("")
1169 .to_string();
1170 *self
1171 .call_counts
1172 .lock()
1173 .unwrap()
1174 .entry(key.clone())
1175 .or_insert(0) += 1;
1176
1177 {
1181 let mut scripts = self.scripts.lock().unwrap();
1182 if let Some(list) = scripts.get_mut(&key) {
1183 let resp = if list.len() == 1 {
1184 list[0].clone()
1185 } else {
1186 list.remove(0)
1187 };
1188 let mut headers = proto_blue_common::fetch::HttpHeaders::new();
1189 headers.insert("content-type".into(), "application/json".into());
1190 return Ok(HttpResponse {
1191 status: resp.status,
1192 headers,
1193 body: resp.body,
1194 });
1195 }
1196 }
1197
1198 if key == "com.atproto.server.createSession" {
1200 let mut headers = proto_blue_common::fetch::HttpHeaders::new();
1201 headers.insert("content-type".into(), "application/json".into());
1202 return Ok(HttpResponse {
1203 status: 200,
1204 headers,
1205 body: self.createsession_body.clone(),
1206 });
1207 }
1208
1209 Err(FetchError::Other(format!("no script for {key}")))
1210 }
1211 }
1212
1213 fn login_body() -> Vec<u8> {
1214 br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a1","refreshJwt":"r1"}"#.to_vec()
1215 }
1216
1217 fn agent_with_fetcher(fetcher: Arc<ScriptedFetcher>) -> Agent {
1218 let client = XrpcClient::with_fetch_handler("https://example.com", fetcher).unwrap();
1219 Agent {
1220 client,
1221 session: Arc::new(RwLock::new(None)),
1222 listeners: Arc::new(Mutex::new(Vec::new())),
1223 refresh_lock: Arc::new(AsyncMutex::new(())),
1224 proxy: Arc::new(RwLock::new(None)),
1225 labelers: Arc::new(RwLock::new(Vec::new())),
1226 }
1227 }
1228
1229 #[tokio::test]
1230 async fn emits_create_on_successful_login() {
1231 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1232 let agent = agent_with_fetcher(fetcher);
1233
1234 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1235 let ev_clone = events.clone();
1236 agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
1237
1238 agent
1239 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1240 .await
1241 .unwrap();
1242 let got = events.lock().unwrap().clone();
1243 assert_eq!(got, vec![AtpSessionEvent::Create]);
1244 }
1245
1246 #[tokio::test]
1247 async fn emits_create_failed_on_login_rejection() {
1248 let fetcher = Arc::new(ScriptedFetcher::new(vec![]));
1249 fetcher.script(
1251 "com.atproto.server.createSession",
1252 vec![ScriptedResponse {
1253 status: 401,
1254 body: br#"{"error":"AuthenticationRequired","message":"bad pwd"}"#.to_vec(),
1255 }],
1256 );
1257 let agent = agent_with_fetcher(fetcher);
1258
1259 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1260 let ev_clone = events.clone();
1261 agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
1262
1263 let _ = agent
1267 .login(&AtIdentifier::new("alice.test").unwrap(), "bad")
1268 .await
1269 .unwrap_err();
1270 let got = events.lock().unwrap().clone();
1271 assert_eq!(got, vec![AtpSessionEvent::CreateFailed]);
1272 }
1273
1274 #[tokio::test]
1275 async fn auto_refreshes_on_expired_access_token() {
1276 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1277
1278 fetcher.script(
1281 "com.atproto.server.describeServer",
1282 vec![
1283 ScriptedResponse {
1284 status: 401,
1285 body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
1286 },
1287 ScriptedResponse {
1288 status: 200,
1289 body: br#"{"did":"did:plc:svr"}"#.to_vec(),
1290 },
1291 ],
1292 );
1293 fetcher.script(
1294 "com.atproto.server.refreshSession",
1295 vec![ScriptedResponse {
1296 status: 200,
1297 body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
1298 .to_vec(),
1299 }],
1300 );
1301
1302 let agent = agent_with_fetcher(fetcher.clone());
1303 agent
1304 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1305 .await
1306 .unwrap();
1307
1308 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1309 let ev_clone = events.clone();
1310 agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
1311
1312 let result = agent.describe_server().await.unwrap();
1313 assert_eq!(result["did"], "did:plc:svr");
1314
1315 assert_eq!(fetcher.call_count("com.atproto.server.describeServer"), 2);
1318 assert_eq!(fetcher.call_count("com.atproto.server.refreshSession"), 1);
1319
1320 let got = events.lock().unwrap().clone();
1322 assert_eq!(got, vec![AtpSessionEvent::Update]);
1323 }
1324
1325 #[tokio::test]
1326 async fn concurrent_expired_token_refreshes_once() {
1327 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1328
1329 fetcher.script(
1332 "com.atproto.server.describeServer",
1333 vec![
1334 ScriptedResponse {
1335 status: 401,
1336 body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
1337 },
1338 ScriptedResponse {
1339 status: 200,
1340 body: br#"{"did":"did:plc:svr"}"#.to_vec(),
1341 },
1342 ],
1343 );
1344 fetcher.script(
1345 "com.atproto.server.refreshSession",
1346 vec![ScriptedResponse {
1347 status: 200,
1348 body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
1349 .to_vec(),
1350 }],
1351 );
1352
1353 let agent = Arc::new(agent_with_fetcher(fetcher.clone()));
1354 agent
1355 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1356 .await
1357 .unwrap();
1358
1359 let mut handles = Vec::new();
1363 for _ in 0..5 {
1364 let a = agent.clone();
1365 handles.push(tokio::spawn(async move {
1366 a.describe_server().await.unwrap();
1367 }));
1368 }
1369 for h in handles {
1370 h.await.unwrap();
1371 }
1372
1373 assert_eq!(
1374 fetcher.call_count("com.atproto.server.refreshSession"),
1375 1,
1376 "concurrent callers must share one refreshSession call",
1377 );
1378 }
1379
1380 #[tokio::test]
1381 async fn configure_proxy_sets_header_on_next_call() {
1382 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1383 fetcher.script(
1384 "com.atproto.server.describeServer",
1385 vec![ScriptedResponse {
1386 status: 200,
1387 body: br#"{"did":"did:plc:svr"}"#.to_vec(),
1388 }],
1389 );
1390 let agent = agent_with_fetcher(fetcher.clone());
1391 agent
1392 .configure_proxy(Some("did:web:api.bsky.chat#bsky_chat"))
1393 .await;
1394
1395 agent.describe_server().await.unwrap();
1396
1397 let p = agent.proxy.read().await;
1400 assert_eq!(p.as_deref(), Some("did:web:api.bsky.chat#bsky_chat"));
1401 }
1402
1403 #[tokio::test]
1404 async fn configure_labelers_stores_list() {
1405 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1406 let agent = agent_with_fetcher(fetcher);
1407 agent
1408 .configure_labelers(&[
1409 LabelerOpts {
1410 did: Did::new("did:plc:a").unwrap(),
1411 redirect: false,
1412 },
1413 LabelerOpts {
1414 did: Did::new("did:plc:b").unwrap(),
1415 redirect: true,
1416 },
1417 ])
1418 .await;
1419 let l = agent.labelers.read().await;
1420 assert_eq!(l.len(), 2);
1421 assert_eq!(l[0].header_value(), "did:plc:a");
1422 assert_eq!(l[1].header_value(), "did:plc:b;redirect");
1423 }
1424
1425 #[tokio::test]
1426 async fn logout_clears_session() {
1427 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1428 fetcher.script(
1429 "com.atproto.server.deleteSession",
1430 vec![ScriptedResponse {
1431 status: 200,
1432 body: b"{}".to_vec(),
1433 }],
1434 );
1435 let agent = agent_with_fetcher(fetcher.clone());
1436 agent
1437 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1438 .await
1439 .unwrap();
1440 assert!(agent.session().await.is_some());
1441 agent.logout().await.unwrap();
1442 assert!(agent.session().await.is_none());
1443 assert_eq!(fetcher.call_count("com.atproto.server.deleteSession"), 1,);
1444 }
1445
1446 #[tokio::test]
1447 async fn logout_clears_session_even_on_server_error() {
1448 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1449 fetcher.script(
1450 "com.atproto.server.deleteSession",
1451 vec![ScriptedResponse {
1452 status: 500,
1453 body: br#"{"error":"InternalServerError"}"#.to_vec(),
1454 }],
1455 );
1456 let agent = agent_with_fetcher(fetcher);
1457 agent
1458 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1459 .await
1460 .unwrap();
1461 let _ = agent.logout().await;
1463 assert!(agent.session().await.is_none());
1464 }
1465
1466 #[tokio::test]
1467 async fn create_account_emits_create_on_success() {
1468 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1469 fetcher.script(
1470 "com.atproto.server.createAccount",
1471 vec![ScriptedResponse {
1472 status: 200,
1473 body:
1474 br#"{"did":"did:plc:new","handle":"newuser.test","accessJwt":"a","refreshJwt":"r"}"#
1475 .to_vec(),
1476 }],
1477 );
1478 let agent = agent_with_fetcher(fetcher);
1479
1480 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1481 let ev = events.clone();
1482 agent.on_session(move |e, _| ev.lock().unwrap().push(e));
1483
1484 let session = agent
1485 .create_account(
1486 &Handle::new("newuser.test").unwrap(),
1487 "pw",
1488 Some("new@example.com"),
1489 None,
1490 )
1491 .await
1492 .unwrap();
1493 assert_eq!(session.did.as_str(), "did:plc:new");
1494 assert_eq!(
1495 events.lock().unwrap().clone(),
1496 vec![AtpSessionEvent::Create]
1497 );
1498 }
1499
1500 #[tokio::test]
1501 async fn upsert_profile_creates_when_absent() {
1502 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1503 fetcher.script(
1505 "com.atproto.repo.getRecord",
1506 vec![ScriptedResponse {
1507 status: 400,
1508 body: br#"{"error":"RecordNotFound","message":"no such record"}"#.to_vec(),
1509 }],
1510 );
1511 fetcher.script(
1512 "com.atproto.repo.putRecord",
1513 vec![ScriptedResponse {
1514 status: 200,
1515 body: br#"{"uri":"at://did:plc:u/app.bsky.actor.profile/self","cid":"bafy"}"#
1516 .to_vec(),
1517 }],
1518 );
1519 let agent = agent_with_fetcher(fetcher);
1520 agent
1521 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1522 .await
1523 .unwrap();
1524
1525 let result = agent
1526 .upsert_profile(|prev| {
1527 assert!(prev.is_null(), "no existing profile");
1528 serde_json::json!({"$type": "app.bsky.actor.profile", "displayName": "Alice"})
1529 })
1530 .await
1531 .unwrap();
1532 assert_eq!(result["uri"], "at://did:plc:u/app.bsky.actor.profile/self");
1533 }
1534
1535 #[tokio::test]
1536 async fn emits_expired_when_refresh_itself_401s() {
1537 let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
1538 fetcher.script(
1539 "com.atproto.server.refreshSession",
1540 vec![ScriptedResponse {
1541 status: 401,
1542 body: br#"{"error":"AuthenticationRequired","message":"refresh expired"}"#.to_vec(),
1543 }],
1544 );
1545 let agent = agent_with_fetcher(fetcher);
1546 agent
1547 .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
1548 .await
1549 .unwrap();
1550
1551 let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
1552 let ev_clone = events.clone();
1553 agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
1554
1555 let _ = agent.refresh_session().await.unwrap_err();
1556 let got = events.lock().unwrap().clone();
1557 assert_eq!(got, vec![AtpSessionEvent::Expired]);
1558 assert!(
1559 agent.session().await.is_none(),
1560 "session cleared on expired refresh"
1561 );
1562 }
1563}