Skip to main content

proto_blue_api/
agent.rs

//! AT Protocol Agent — high-level client wrapping XRPC.
//!
//! Provides session management, convenience methods for common operations,
//! and namespace accessors for the full Lexicon API surface.
5
6use std::sync::{Arc, Mutex};
7use tokio::sync::{Mutex as AsyncMutex, RwLock};
8
9use proto_blue_lex_data::Cid;
10use proto_blue_syntax::{AtIdentifier, AtUri, Did, Handle};
11use proto_blue_xrpc::{
12    CallOptions, HeadersMap, QueryParams, QueryValue, ResponseType, XrpcBody, XrpcClient,
13};
14
15use crate::rich_text::RichText;
16
/// Session lifecycle events emitted by [`Agent`].
///
/// Mirrors TS `AtpSessionEvent`. Register a listener via
/// [`Agent::on_session`] to observe login / refresh / expiry. Typical
/// use is to persist the session on `Create` / `Update` and to clear
/// locally persisted state on `Expired`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AtpSessionEvent {
    /// A new session was established (successful login or resume).
    Create,
    /// A login attempt failed.
    CreateFailed,
    /// The session tokens were refreshed.
    Update,
    /// The session is gone and the user must log in again: either the
    /// server rejected the refresh token, or the agent logged out.
    Expired,
    /// A session refresh failed for a reason other than the refresh token
    /// being rejected (network failure, 5xx, ...). The session is kept so
    /// a later attempt can retry.
    NetworkError,
}
36
37/// Callback registered via [`Agent::on_session`].
38///
39/// Invoked synchronously on the task that produced the event; handlers
40/// should not block for long. The `Option<&Session>` is `Some` for
41/// `Create` / `Update` and `None` for `CreateFailed` / `Expired` /
42/// `NetworkError`.
43pub type SessionEventCallback = Arc<dyn Fn(AtpSessionEvent, Option<&Session>) + Send + Sync>;
44
45/// Session data for an authenticated agent.
46///
47/// `did` and `handle` are typed as the validated `proto_blue_syntax`
48/// newtypes so callers can't accidentally pass random strings; the
49/// JWTs and email stay as `String` (no validated newtype exists for
50/// either, and over-validating an opaque token would just mean
51/// re-parsing on every request).
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53#[serde(rename_all = "camelCase")]
54pub struct Session {
55    pub did: Did,
56    pub handle: Handle,
57    pub access_jwt: String,
58    pub refresh_jwt: String,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub email: Option<String>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub email_confirmed: Option<bool>,
63}
64
65/// Errors from Agent operations.
66#[derive(Debug, thiserror::Error)]
67pub enum AgentError {
68    #[error("XRPC error: {0}")]
69    Xrpc(#[from] proto_blue_xrpc::Error),
70    #[error("Not authenticated")]
71    NotAuthenticated,
72    #[error("JSON error: {0}")]
73    Json(#[from] serde_json::Error),
74    #[error("{0}")]
75    Other(String),
76}
77
78/// High-level AT Protocol agent.
79///
80/// Auth state lives in a single `RwLock<Option<Session>>`. The XRPC client
81/// is never mutated after construction — auth headers are passed per-request.
82/// This avoids token leaks, giant-lock contention, and split-lock atomicity
83/// gaps that arise from storing auth in the client's default headers.
84///
85/// ## Transparent refresh
86///
87/// Every XRPC call goes through `xrpc_query_with_refresh` /
88/// `xrpc_procedure_with_refresh`, which detect 401 /
89/// `ExpiredToken` responses, call [`Agent::refresh_session`], and
90/// retry once. Concurrent refresh attempts are deduplicated via an
91/// async `Mutex` so N in-flight calls that all see an expired token
92/// issue exactly one `/refreshSession` request. If the refresh itself
93/// fails, the agent fires [`AtpSessionEvent::Expired`] and the
94/// original error propagates.
95pub struct Agent {
96    client: XrpcClient,
97    session: Arc<RwLock<Option<Session>>>,
98    /// Session-event listeners. Called synchronously on the task that
99    /// produced the event.
100    listeners: Arc<Mutex<Vec<SessionEventCallback>>>,
101    /// Serialises concurrent refreshes. The first caller to see a 401
102    /// acquires this lock, performs the refresh, and writes the new
103    /// session back; subsequent callers block until that finishes and
104    /// then observe the updated session when their retry fires.
105    refresh_lock: Arc<AsyncMutex<()>>,
106    /// Optional `atproto-proxy` target (e.g. `did:web:api.bsky.chat#bsky_chat`
107    /// when calling the chat service).
108    proxy: Arc<RwLock<Option<String>>>,
109    /// Optional set of labeler DIDs to send as `atproto-accept-labelers`.
110    labelers: Arc<RwLock<Vec<LabelerOpts>>>,
111}
112
113/// A single labeler entry, mirroring TS `Agent`'s `AtprotoLabelerDef`.
114#[derive(Debug, Clone, PartialEq, Eq)]
115pub struct LabelerOpts {
116    /// The labeler's DID (e.g. `did:plc:<labeler>`).
117    pub did: Did,
118    /// When `true`, this labeler is redirected to (sent as
119    /// `atproto-accept-labelers: did;redirect`). Matches TS behaviour.
120    pub redirect: bool,
121}
122
123impl LabelerOpts {
124    /// Format a single labeler for the `atproto-accept-labelers` header.
125    fn header_value(&self) -> String {
126        if self.redirect {
127            format!("{};redirect", self.did)
128        } else {
129            self.did.to_string()
130        }
131    }
132}
133
134impl Agent {
135    /// Create a new agent pointing at the given service URL.
136    ///
137    /// Available whenever [`XrpcClient::new`] is — native requires
138    /// `fetch-reqwest`; on wasm the default is always the browser
139    /// `fetch` backend.
140    #[cfg(any(
141        all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
142        target_arch = "wasm32",
143    ))]
144    pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
145        let client = XrpcClient::new(service)?;
146        Ok(Self {
147            client,
148            session: Arc::new(RwLock::new(None)),
149            listeners: Arc::new(Mutex::new(Vec::new())),
150            refresh_lock: Arc::new(AsyncMutex::new(())),
151            proxy: Arc::new(RwLock::new(None)),
152            labelers: Arc::new(RwLock::new(Vec::new())),
153        })
154    }
155
156    /// Register a session-event listener.
157    ///
158    /// Returns `()`, not a handle — listener unregistration isn't
159    /// currently supported (the typical pattern is to register a
160    /// single persistence callback that lives for the Agent's
161    /// lifetime). Multiple listeners are fired in registration order.
162    pub fn on_session<F>(&self, callback: F)
163    where
164        F: Fn(AtpSessionEvent, Option<&Session>) + Send + Sync + 'static,
165    {
166        self.listeners.lock().unwrap().push(Arc::new(callback));
167    }
168
169    /// Fire an event to every registered listener.
170    fn emit(&self, event: AtpSessionEvent, session: Option<&Session>) {
171        // `listeners` is a sync Mutex; we clone the Arc<callback> list
172        // out from under the lock so the callbacks themselves run
173        // without holding it (they could be slow / could call back
174        // into `on_session`).
175        let listeners = self.listeners.lock().unwrap().clone();
176        for cb in listeners {
177            cb(event, session);
178        }
179    }
180
181    /// Get the service URL string.
182    #[must_use]
183    pub fn service(&self) -> String {
184        self.client.service_url().to_string()
185    }
186
187    /// Get the current session's DID, if logged in.
188    pub async fn did(&self) -> Option<Did> {
189        self.session.read().await.as_ref().map(|s| s.did.clone())
190    }
191
192    /// Get the current session, if any.
193    pub async fn session(&self) -> Option<Session> {
194        self.session.read().await.clone()
195    }
196
197    // --- Authentication ---
198
199    /// Build per-request `CallOptions` carrying the current access
200    /// token, proxy target, and labeler list. Returns `None` if not
201    /// authenticated (the proxy + labelers are still folded into the
202    /// call options of non-auth helpers via [`Self::anon_call_options`]).
203    async fn auth_call_options(&self) -> Option<CallOptions> {
204        // Narrow the read-lock guard: extract only the bearer string and
205        // drop the guard before we await `inject_proxy_and_labelers`,
206        // which separately re-locks state we don't want to deadlock on.
207        let access_jwt = {
208            let guard = self.session.read().await;
209            guard.as_ref()?.access_jwt.clone()
210        };
211        let mut headers = HeadersMap::new();
212        headers.insert("Authorization".into(), format!("Bearer {access_jwt}"));
213        self.inject_proxy_and_labelers(&mut headers).await;
214        Some(CallOptions {
215            encoding: None,
216            headers: Some(headers),
217            ..Default::default()
218        })
219    }
220
221    /// Build anonymous [`CallOptions`] carrying just the proxy and
222    /// labeler config, for methods that don't need auth.
223    ///
224    /// Exposed in case callers drive `XrpcClient::query` / `::procedure`
225    /// directly and want the agent's proxy / labeler headers folded in.
226    pub async fn anon_call_options(&self) -> Option<CallOptions> {
227        let mut headers = HeadersMap::new();
228        self.inject_proxy_and_labelers(&mut headers).await;
229        if headers.is_empty() {
230            None
231        } else {
232            Some(CallOptions {
233                encoding: None,
234                headers: Some(headers),
235                ..Default::default()
236            })
237        }
238    }
239
240    async fn inject_proxy_and_labelers(&self, headers: &mut HeadersMap) {
241        if let Some(proxy) = self.proxy.read().await.as_ref() {
242            headers.insert("atproto-proxy".into(), proxy.clone());
243        }
244        let labelers = self.labelers.read().await;
245        if !labelers.is_empty() {
246            let v = labelers
247                .iter()
248                .map(LabelerOpts::header_value)
249                .collect::<Vec<_>>()
250                .join(", ");
251            drop(labelers);
252            headers.insert("atproto-accept-labelers".into(), v);
253        }
254    }
255
256    /// Configure the service-proxy target (`atproto-proxy` header) for
257    /// every subsequent call. Pass `None` to clear.
258    ///
259    /// The canonical use case is chat, which runs on a different
260    /// service: `agent.configure_proxy(Some("did:web:api.bsky.chat#bsky_chat"))`.
261    pub async fn configure_proxy(&self, target: Option<&str>) {
262        *self.proxy.write().await = target.map(String::from);
263    }
264
265    /// Return a new [`Agent`] configured with the given proxy target.
266    /// Shares session state with this agent (cheap clone of internals).
267    pub async fn with_proxy(&self, target: &str) -> Self {
268        let cloned = self.shallow_clone();
269        cloned.configure_proxy(Some(target)).await;
270        cloned
271    }
272
273    /// Configure the set of labelers sent as `atproto-accept-labelers`.
274    /// Passing an empty slice clears the header.
275    pub async fn configure_labelers(&self, labelers: &[LabelerOpts]) {
276        *self.labelers.write().await = labelers.to_vec();
277    }
278
279    /// Shallow-clone the agent: shares session / listener / refresh
280    /// state but receives independent proxy + labeler config. Used by
281    /// [`Self::with_proxy`].
282    fn shallow_clone(&self) -> Self {
283        Self {
284            client: self.client.clone(),
285            session: self.session.clone(),
286            listeners: self.listeners.clone(),
287            refresh_lock: self.refresh_lock.clone(),
288            proxy: Arc::new(RwLock::new(None)),
289            labelers: self.labelers.clone(),
290        }
291    }
292
293    /// Log in with identifier (handle or DID) and password.
294    ///
295    /// Emits [`AtpSessionEvent::Create`] on success, or
296    /// [`AtpSessionEvent::CreateFailed`] if the server rejected the
297    /// credentials.
298    pub async fn login(
299        &self,
300        identifier: &AtIdentifier,
301        password: &str,
302    ) -> Result<Session, AgentError> {
303        let body = serde_json::json!({
304            "identifier": identifier,
305            "password": password,
306        });
307
308        let response = match self
309            .client
310            .procedure(
311                "com.atproto.server.createSession",
312                None,
313                Some(XrpcBody::Json(body)),
314                None,
315            )
316            .await
317        {
318            Ok(r) => r,
319            Err(e) => {
320                self.emit(AtpSessionEvent::CreateFailed, None);
321                return Err(AgentError::Xrpc(e));
322            }
323        };
324
325        let session: Session = serde_json::from_value(response.data)?;
326
327        // Atomically commit session in a single write lock
328        *self.session.write().await = Some(session.clone());
329        self.emit(AtpSessionEvent::Create, Some(&session));
330        Ok(session)
331    }
332
333    /// Resume an existing session.
334    ///
335    /// Verifies the session with the server *before* updating internal state.
336    /// If verification fails, the agent remains unauthenticated.
337    pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
338        // Verify the session is valid by calling getSession with the provided token,
339        // WITHOUT updating the agent's state first. Use a per-request auth header.
340        let mut headers = HeadersMap::new();
341        headers.insert(
342            "Authorization".into(),
343            format!("Bearer {}", session.access_jwt),
344        );
345        let opts = CallOptions {
346            encoding: None,
347            headers: Some(headers),
348            ..Default::default()
349        };
350        let response = self
351            .client
352            .query("com.atproto.server.getSession", None, Some(&opts))
353            .await?;
354        let verified_did = response
355            .data
356            .get("did")
357            .and_then(|v| v.as_str())
358            .map(Did::new)
359            .transpose()
360            .map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))?;
361
362        // Verification succeeded — atomically commit state in a single write lock
363        let mut committed = session;
364        if let Some(did) = verified_did {
365            committed.did = did;
366        }
367        *self.session.write().await = Some(committed.clone());
368        self.emit(AtpSessionEvent::Create, Some(&committed));
369
370        Ok(())
371    }
372
373    /// Refresh the current session tokens.
374    ///
375    /// Emits [`AtpSessionEvent::Update`] on success or
376    /// [`AtpSessionEvent::Expired`] if the refresh token was
377    /// rejected. Uses a per-request header for the refresh call so the
378    /// refresh JWT is never exposed as the global auth state. The new
379    /// session is committed atomically in a single write lock.
380    // The read guard's scope is bounded by an inner block (sess is dropped
381    // before the first .await downstream). Clippy flags the `?` early-return
382    // as "drop could be tighter" but reshaping the flow with explicit
383    // drop(sess) would add LoC without changing observable behavior. The
384    // attribute is fn-level because the inner-expression lint isn't silenced
385    // when the attribute is placed on the `let` binding.
386    #[allow(clippy::significant_drop_tightening)]
387    pub async fn refresh_session(&self) -> Result<Session, AgentError> {
388        let refresh_jwt = {
389            let sess = self.session.read().await;
390            let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
391            sess.refresh_jwt.clone()
392        };
393
394        // Use per-request header for refresh — never mutate global auth state
395        let mut headers = HeadersMap::new();
396        headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
397        let opts = CallOptions {
398            encoding: None,
399            headers: Some(headers),
400            ..Default::default()
401        };
402
403        let response = match self
404            .client
405            .procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
406            .await
407        {
408            Ok(r) => r,
409            Err(e) => {
410                // Any 401 during refresh means the refresh token
411                // itself is rejected — drop the session and signal
412                // Expired. Other errors (network failure, 5xx, etc.)
413                // surface as NetworkError and leave the session in
414                // place so a later attempt can retry.
415                if is_refresh_rejected(&e) {
416                    *self.session.write().await = None;
417                    self.emit(AtpSessionEvent::Expired, None);
418                } else {
419                    self.emit(AtpSessionEvent::NetworkError, None);
420                }
421                return Err(AgentError::Xrpc(e));
422            }
423        };
424
425        let session: Session = serde_json::from_value(response.data)?;
426
427        // Atomically commit new session in a single write lock
428        *self.session.write().await = Some(session.clone());
429        self.emit(AtpSessionEvent::Update, Some(&session));
430        Ok(session)
431    }
432
433    // --- Convenience helpers ---
434
435    /// Ensure the agent is authenticated, returning the DID.
436    async fn assert_did(&self) -> Result<Did, AgentError> {
437        self.did().await.ok_or(AgentError::NotAuthenticated)
438    }
439
440    /// Helper: make a query call with transparent 401-refresh retry.
441    ///
442    /// When the first attempt returns `ExpiredToken`, try to refresh
443    /// the session and replay the call once with the fresh access
444    /// token. Concurrent refreshes are deduplicated via
445    /// [`Agent::refresh_lock`].
446    async fn xrpc_query(
447        &self,
448        nsid: &str,
449        params: Option<&QueryParams>,
450    ) -> Result<serde_json::Value, AgentError> {
451        let opts = self.auth_call_options().await;
452        let first = self.client.query(nsid, params, opts.as_ref()).await;
453        match first {
454            Ok(r) => Ok(r.data),
455            Err(e) if is_auth_expired(&e) => {
456                self.refresh_and_retry(|opts| {
457                    let c = self.client.clone();
458                    let nsid = nsid.to_string();
459                    let params = params.cloned();
460                    async move { c.query(&nsid, params.as_ref(), opts.as_ref()).await }
461                })
462                .await
463            }
464            Err(e) => Err(AgentError::Xrpc(e)),
465        }
466    }
467
468    /// Helper: make a procedure call with transparent 401-refresh retry.
469    async fn xrpc_procedure(
470        &self,
471        nsid: &str,
472        body: serde_json::Value,
473    ) -> Result<serde_json::Value, AgentError> {
474        let opts = self.auth_call_options().await;
475        let first = self
476            .client
477            .procedure(
478                nsid,
479                None,
480                Some(XrpcBody::Json(body.clone())),
481                opts.as_ref(),
482            )
483            .await;
484        match first {
485            Ok(r) => Ok(r.data),
486            Err(e) if is_auth_expired(&e) => {
487                self.refresh_and_retry(|opts| {
488                    let c = self.client.clone();
489                    let nsid = nsid.to_string();
490                    let body = body.clone();
491                    async move {
492                        c.procedure(&nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
493                            .await
494                    }
495                })
496                .await
497            }
498            Err(e) => Err(AgentError::Xrpc(e)),
499        }
500    }
501
502    /// Shared refresh-and-retry driver.
503    ///
504    /// Acquires the `refresh_lock`, refreshes the session if the
505    /// access token in `self.session` is still the one that produced
506    /// the 401, rebuilds `CallOptions` from the new token, and runs
507    /// `replay(new_opts)`. Concurrent callers that arrive after the
508    /// lock is held observe the refreshed session when they get to
509    /// build their own opts — only one `/refreshSession` HTTP call
510    /// fires per refresh cycle.
511    async fn refresh_and_retry<F, Fut>(&self, replay: F) -> Result<serde_json::Value, AgentError>
512    where
513        F: FnOnce(Option<CallOptions>) -> Fut,
514        Fut: std::future::Future<
515                Output = Result<proto_blue_xrpc::XrpcResponse, proto_blue_xrpc::Error>,
516            >,
517    {
518        // Snapshot the access token the caller's first attempt used.
519        // After we acquire the refresh lock, compare — if a peer
520        // already refreshed, skip the redundant refresh.
521        let pre_refresh_jwt = self
522            .session
523            .read()
524            .await
525            .as_ref()
526            .map(|s| s.access_jwt.clone());
527        let guard = self.refresh_lock.lock().await;
528        let current_jwt = self
529            .session
530            .read()
531            .await
532            .as_ref()
533            .map(|s| s.access_jwt.clone());
534        if pre_refresh_jwt == current_jwt {
535            // No peer did the refresh — we must.
536            self.refresh_session().await?;
537        }
538        drop(guard);
539
540        let opts = self.auth_call_options().await;
541        let response = replay(opts).await?;
542        Ok(response.data)
543    }
544
545    /// Helper: create a record.
546    async fn create_record(
547        &self,
548        collection: &str,
549        record: serde_json::Value,
550    ) -> Result<serde_json::Value, AgentError> {
551        let did = self.assert_did().await?;
552        let body = serde_json::json!({
553            "repo": did,
554            "collection": collection,
555            "record": record,
556        });
557        self.xrpc_procedure("com.atproto.repo.createRecord", body)
558            .await
559    }
560
561    /// Helper: delete a record by AT-URI.
562    async fn delete_record(&self, collection: &str, uri: &AtUri) -> Result<(), AgentError> {
563        let did = self.assert_did().await?;
564        let rkey = uri
565            .rkey()
566            .ok_or_else(|| AgentError::Other("AT-URI has no rkey segment".into()))?;
567
568        let body = serde_json::json!({
569            "repo": did,
570            "collection": collection,
571            "rkey": rkey,
572        });
573        self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
574            .await?;
575        Ok(())
576    }
577
578    /// Generate an ISO 8601 timestamp with millisecond precision.
579    fn now_iso() -> String {
580        chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
581    }
582
583    /// Resolve a timestamp: use the provided value or generate one.
584    fn resolve_timestamp(created_at: Option<&str>) -> String {
585        created_at.map_or_else(Self::now_iso, String::from)
586    }
587
588    // --- Post operations ---
589
590    /// Create a new post.
591    ///
592    /// If `created_at` is `None`, the current time is used.
593    pub async fn post(
594        &self,
595        text: &str,
596        facets: Option<Vec<crate::rich_text::Facet>>,
597        created_at: Option<&str>,
598    ) -> Result<serde_json::Value, AgentError> {
599        let mut record = serde_json::json!({
600            "$type": "app.bsky.feed.post",
601            "text": text,
602            "createdAt": Self::resolve_timestamp(created_at),
603        });
604
605        if let Some(facets) = facets {
606            record["facets"] = serde_json::to_value(&facets)?;
607        }
608
609        self.create_record("app.bsky.feed.post", record).await
610    }
611
612    /// Create a post from `RichText` (includes detected facets).
613    pub async fn post_rich(
614        &self,
615        rt: &RichText,
616        created_at: Option<&str>,
617    ) -> Result<serde_json::Value, AgentError> {
618        let facets = if rt.facets().is_empty() {
619            None
620        } else {
621            Some(rt.facets().to_vec())
622        };
623        self.post(rt.text(), facets, created_at).await
624    }
625
626    /// Delete a post by AT-URI.
627    pub async fn delete_post(&self, uri: &AtUri) -> Result<(), AgentError> {
628        self.delete_record("app.bsky.feed.post", uri).await
629    }
630
631    // --- Like / Repost ---
632
633    /// Like a post.
634    ///
635    /// If `created_at` is `None`, the current time is used.
636    pub async fn like(
637        &self,
638        uri: &AtUri,
639        cid: &Cid,
640        created_at: Option<&str>,
641    ) -> Result<serde_json::Value, AgentError> {
642        let record = serde_json::json!({
643            "$type": "app.bsky.feed.like",
644            "subject": { "uri": uri, "cid": cid },
645            "createdAt": Self::resolve_timestamp(created_at),
646        });
647        self.create_record("app.bsky.feed.like", record).await
648    }
649
650    /// Unlike a post by AT-URI of the like record.
651    pub async fn delete_like(&self, like_uri: &AtUri) -> Result<(), AgentError> {
652        self.delete_record("app.bsky.feed.like", like_uri).await
653    }
654
655    /// Repost a post.
656    ///
657    /// If `created_at` is `None`, the current time is used.
658    pub async fn repost(
659        &self,
660        uri: &AtUri,
661        cid: &Cid,
662        created_at: Option<&str>,
663    ) -> Result<serde_json::Value, AgentError> {
664        let record = serde_json::json!({
665            "$type": "app.bsky.feed.repost",
666            "subject": { "uri": uri, "cid": cid },
667            "createdAt": Self::resolve_timestamp(created_at),
668        });
669        self.create_record("app.bsky.feed.repost", record).await
670    }
671
672    /// Delete a repost by AT-URI.
673    pub async fn delete_repost(&self, repost_uri: &AtUri) -> Result<(), AgentError> {
674        self.delete_record("app.bsky.feed.repost", repost_uri).await
675    }
676
677    // --- Follow ---
678
679    /// Follow a user by DID.
680    ///
681    /// If `created_at` is `None`, the current time is used.
682    pub async fn follow(
683        &self,
684        subject_did: &Did,
685        created_at: Option<&str>,
686    ) -> Result<serde_json::Value, AgentError> {
687        let record = serde_json::json!({
688            "$type": "app.bsky.graph.follow",
689            "subject": subject_did,
690            "createdAt": Self::resolve_timestamp(created_at),
691        });
692        self.create_record("app.bsky.graph.follow", record).await
693    }
694
695    /// Unfollow by AT-URI of the follow record.
696    pub async fn delete_follow(&self, follow_uri: &AtUri) -> Result<(), AgentError> {
697        self.delete_record("app.bsky.graph.follow", follow_uri)
698            .await
699    }
700
701    // --- Query helpers ---
702
703    /// Get a user's profile.
704    pub async fn get_profile(&self, actor: &AtIdentifier) -> Result<serde_json::Value, AgentError> {
705        let mut params = QueryParams::new();
706        params.insert("actor".into(), QueryValue::String(actor.to_string()));
707        self.xrpc_query("app.bsky.actor.getProfile", Some(&params))
708            .await
709    }
710
711    /// Get the home timeline.
712    pub async fn get_timeline(
713        &self,
714        limit: Option<i64>,
715        cursor: Option<&str>,
716    ) -> Result<serde_json::Value, AgentError> {
717        let mut params = QueryParams::new();
718        if let Some(limit) = limit {
719            params.insert("limit".into(), QueryValue::Integer(limit));
720        }
721        if let Some(cursor) = cursor {
722            params.insert("cursor".into(), QueryValue::String(cursor.into()));
723        }
724        self.xrpc_query("app.bsky.feed.getTimeline", Some(&params))
725            .await
726    }
727
728    /// Get a post thread.
729    pub async fn get_post_thread(
730        &self,
731        uri: &AtUri,
732        depth: Option<i64>,
733    ) -> Result<serde_json::Value, AgentError> {
734        let mut params = QueryParams::new();
735        params.insert("uri".into(), QueryValue::String(uri.to_string()));
736        if let Some(depth) = depth {
737            params.insert("depth".into(), QueryValue::Integer(depth));
738        }
739        self.xrpc_query("app.bsky.feed.getPostThread", Some(&params))
740            .await
741    }
742
743    /// Search actors.
744    pub async fn search_actors(
745        &self,
746        query: &str,
747        limit: Option<i64>,
748    ) -> Result<serde_json::Value, AgentError> {
749        let mut params = QueryParams::new();
750        params.insert("q".into(), QueryValue::String(query.into()));
751        if let Some(limit) = limit {
752            params.insert("limit".into(), QueryValue::Integer(limit));
753        }
754        self.xrpc_query("app.bsky.actor.searchActors", Some(&params))
755            .await
756    }
757
758    /// Resolve a handle to a DID.
759    pub async fn resolve_handle(&self, handle: &Handle) -> Result<Did, AgentError> {
760        let mut params = QueryParams::new();
761        params.insert("handle".into(), QueryValue::String(handle.to_string()));
762        let data = self
763            .xrpc_query("com.atproto.identity.resolveHandle", Some(&params))
764            .await?;
765        let did_str = data
766            .get("did")
767            .and_then(|v| v.as_str())
768            .ok_or_else(|| AgentError::Other("Missing DID in response".into()))?;
769        Did::new(did_str)
770            .map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))
771    }
772
773    /// Get notifications.
774    pub async fn list_notifications(
775        &self,
776        limit: Option<i64>,
777        cursor: Option<&str>,
778    ) -> Result<serde_json::Value, AgentError> {
779        let mut params = QueryParams::new();
780        if let Some(limit) = limit {
781            params.insert("limit".into(), QueryValue::Integer(limit));
782        }
783        if let Some(cursor) = cursor {
784            params.insert("cursor".into(), QueryValue::String(cursor.into()));
785        }
786        self.xrpc_query("app.bsky.notification.listNotifications", Some(&params))
787            .await
788    }
789
790    /// Upload a blob (image, video, etc.).
791    pub async fn upload_blob(
792        &self,
793        data: Vec<u8>,
794        content_type: &str,
795    ) -> Result<serde_json::Value, AgentError> {
796        let mut headers = HeadersMap::new();
797        headers.insert("Content-Type".into(), content_type.into());
798
799        // Add auth header from session
800        if let Some(sess) = self.session.read().await.as_ref() {
801            headers.insert(
802                "Authorization".into(),
803                format!("Bearer {}", sess.access_jwt),
804            );
805        }
806
807        let opts = CallOptions {
808            encoding: Some(content_type.to_string()),
809            headers: Some(headers),
810            ..Default::default()
811        };
812
813        let response = self
814            .client
815            .procedure(
816                "com.atproto.repo.uploadBlob",
817                None,
818                Some(XrpcBody::Bytes(data)),
819                Some(&opts),
820            )
821            .await?;
822
823        Ok(response.data)
824    }
825
826    /// Describe the server.
827    pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
828        self.xrpc_query("com.atproto.server.describeServer", None)
829            .await
830    }
831
832    // --- Account lifecycle ---
833
834    /// Log out of the current session.
835    ///
836    /// Sends a best-effort `deleteSession` call using the current
837    /// **refresh** token (TS matches this — `deleteSession` requires
838    /// the refresh JWT, not the access JWT). Clears local session
839    /// state whether or not the server call succeeds, so the agent
840    /// always ends up unauthenticated.
841    pub async fn logout(&self) -> Result<(), AgentError> {
842        let refresh_jwt = {
843            let guard = self.session.read().await;
844            guard.as_ref().map(|s| s.refresh_jwt.clone())
845        };
846
847        let server_result = if let Some(refresh_jwt) = refresh_jwt {
848            let mut headers = HeadersMap::new();
849            headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
850            let opts = CallOptions {
851                encoding: None,
852                headers: Some(headers),
853                ..Default::default()
854            };
855            self.client
856                .procedure("com.atproto.server.deleteSession", None, None, Some(&opts))
857                .await
858                .map(|_| ())
859        } else {
860            Ok(())
861        };
862
863        // Always clear local state.
864        *self.session.write().await = None;
865        self.emit(AtpSessionEvent::Expired, None);
866
867        server_result.map_err(AgentError::Xrpc)
868    }
869
870    /// Create a new account on the current service.
871    ///
872    /// `extra` is merged into the request body — useful for passing
873    /// `inviteCode`, `verificationCode`, or custom provider-specific
874    /// fields without this method's signature needing to know every
875    /// option the server supports.
876    ///
877    /// On success, the new session is stored and `Create` is emitted.
878    pub async fn create_account(
879        &self,
880        handle: &Handle,
881        password: &str,
882        email: Option<&str>,
883        extra: Option<serde_json::Value>,
884    ) -> Result<Session, AgentError> {
885        let mut body = serde_json::json!({
886            "handle": handle,
887            "password": password,
888        });
889        if let Some(email) = email {
890            body["email"] = serde_json::Value::String(email.to_string());
891        }
892        if let Some(extra) = extra
893            && let Some(extra_map) = extra.as_object()
894            && let Some(body_map) = body.as_object_mut()
895        {
896            for (k, v) in extra_map {
897                body_map.insert(k.clone(), v.clone());
898            }
899        }
900
901        let response = match self
902            .client
903            .procedure(
904                "com.atproto.server.createAccount",
905                None,
906                Some(XrpcBody::Json(body)),
907                None,
908            )
909            .await
910        {
911            Ok(r) => r,
912            Err(e) => {
913                self.emit(AtpSessionEvent::CreateFailed, None);
914                return Err(AgentError::Xrpc(e));
915            }
916        };
917
918        let session: Session = serde_json::from_value(response.data)?;
919        *self.session.write().await = Some(session.clone());
920        self.emit(AtpSessionEvent::Create, Some(&session));
921        Ok(session)
922    }
923
924    /// Create-or-update the signed-in user's `app.bsky.actor.profile`
925    /// record.
926    ///
927    /// The `mutate` closure receives the existing profile record (or
928    /// `serde_json::Value::Null` if none exists) and returns the
929    /// desired next state. This pattern mirrors TS
930    /// `AtpAgent.upsertProfile(updateFn)`.
931    ///
932    /// The write uses `putRecord` with `swapRecord` for CAS safety;
933    /// if the swap fails we retry up to 5 times with a fresh read.
934    pub async fn upsert_profile<F>(&self, mutate: F) -> Result<serde_json::Value, AgentError>
935    where
936        F: Fn(serde_json::Value) -> serde_json::Value,
937    {
938        const MAX_RETRIES: u32 = 5;
939
940        let did = self.assert_did().await?;
941
942        for _ in 0..MAX_RETRIES {
943            // Read the existing profile (may 404 — that's fine).
944            let existing_result = self
945                .xrpc_query(
946                    "com.atproto.repo.getRecord",
947                    Some(&{
948                        let mut p = QueryParams::new();
949                        p.insert("repo".into(), QueryValue::String(did.to_string()));
950                        p.insert(
951                            "collection".into(),
952                            QueryValue::String("app.bsky.actor.profile".into()),
953                        );
954                        p.insert("rkey".into(), QueryValue::String("self".into()));
955                        p
956                    }),
957                )
958                .await;
959
960            let (existing_record, swap_cid) = match existing_result {
961                Ok(r) => {
962                    let record = r.get("value").cloned().unwrap_or(serde_json::Value::Null);
963                    let cid = r.get("cid").and_then(|v| v.as_str()).map(String::from);
964                    (record, cid)
965                }
966                Err(AgentError::Xrpc(ref e)) if is_not_found(e) => (serde_json::Value::Null, None),
967                Err(e) => return Err(e),
968            };
969
970            let updated = mutate(existing_record);
971            let mut body = serde_json::json!({
972                "repo": did,
973                "collection": "app.bsky.actor.profile",
974                "rkey": "self",
975                "record": updated,
976            });
977            if let Some(cid) = swap_cid {
978                body["swapRecord"] = serde_json::Value::String(cid);
979            }
980
981            match self
982                .xrpc_procedure("com.atproto.repo.putRecord", body)
983                .await
984            {
985                Ok(r) => return Ok(r),
986                Err(AgentError::Xrpc(ref e)) if is_invalid_swap(e) => {
987                    // Race lost — someone else updated between our read
988                    // and write. Fall through to the next loop iteration
989                    // with a fresh read.
990                }
991                Err(e) => return Err(e),
992            }
993        }
994
995        Err(AgentError::Other(
996            "upsert_profile: exceeded maximum retries due to concurrent writes".into(),
997        ))
998    }
999}
1000
1001/// `true` if an XRPC error is a 4xx that specifically indicates the
1002/// record does not exist. `getRecord` uses `RecordNotFound`.
1003fn is_not_found(err: &proto_blue_xrpc::Error) -> bool {
1004    match err {
1005        proto_blue_xrpc::Error::Xrpc(x) => x.is_error("RecordNotFound"),
1006        _ => false,
1007    }
1008}
1009
1010/// `true` if the server rejected a `putRecord` because the `swapRecord`
1011/// CID didn't match — caller should re-read and retry.
1012fn is_invalid_swap(err: &proto_blue_xrpc::Error) -> bool {
1013    match err {
1014        proto_blue_xrpc::Error::Xrpc(x) => x.is_error("InvalidSwap"),
1015        _ => false,
1016    }
1017}
1018
1019/// `true` if an XRPC error signals that the access token is expired
1020/// and the caller should try to refresh. Looks for
1021/// `AuthenticationRequired` (401) with the specific `ExpiredToken`
1022/// error name — other 401 variants aren't necessarily caused by
1023/// expiry (e.g. wrong credentials, app-password rejection) and
1024/// shouldn't trigger the refresh-and-retry path.
1025fn is_auth_expired(err: &proto_blue_xrpc::Error) -> bool {
1026    match err {
1027        proto_blue_xrpc::Error::Xrpc(x) => {
1028            matches!(x.status, ResponseType::AuthenticationRequired) && x.is_error("ExpiredToken")
1029        }
1030        _ => false,
1031    }
1032}
1033
1034/// `true` if an error from `/refreshSession` signals that the refresh
1035/// token is rejected (rather than a transient network problem). Any
1036/// 401 from the refresh endpoint is authoritative — the token is
1037/// dead — regardless of the specific error-name code.
1038const fn is_refresh_rejected(err: &proto_blue_xrpc::Error) -> bool {
1039    match err {
1040        proto_blue_xrpc::Error::Xrpc(x) => {
1041            matches!(x.status, ResponseType::AuthenticationRequired)
1042        }
1043        _ => false,
1044    }
1045}
1046
// Unit tests. The later ones inject `ScriptedFetcher` (defined below), an
// in-memory `FetchHandler` that returns canned responses keyed by NSID.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn agent_creation() {
        let _agent = Agent::new("https://bsky.social").unwrap();
    }

    #[test]
    fn session_serde_roundtrip() {
        let session = Session {
            did: Did::new("did:plc:abc123").unwrap(),
            handle: Handle::new("alice.bsky.social").unwrap(),
            access_jwt: "eyJ...".to_string(),
            refresh_jwt: "eyJ...".to_string(),
            email: Some("alice@example.com".to_string()),
            email_confirmed: Some(true),
        };

        let json = serde_json::to_string(&session).unwrap();
        let parsed: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.did.as_str(), "did:plc:abc123");
        assert_eq!(parsed.handle.as_str(), "alice.bsky.social");
        assert_eq!(parsed.email, Some("alice@example.com".to_string()));
    }

    #[test]
    fn agent_error_display() {
        let err = AgentError::NotAuthenticated;
        assert_eq!(err.to_string(), "Not authenticated");

        let err = AgentError::Other("test error".into());
        assert_eq!(err.to_string(), "test error");
    }

    #[tokio::test]
    async fn agent_no_session_by_default() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.did().await.is_none());
        assert!(agent.session().await.is_none());
    }

    #[tokio::test]
    async fn agent_assert_did_fails_when_not_logged_in() {
        let agent = Agent::new("https://bsky.social").unwrap();
        let err = agent.assert_did().await.unwrap_err();
        assert!(matches!(err, AgentError::NotAuthenticated));
    }

    #[test]
    fn now_iso_format() {
        let ts = Agent::now_iso();
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn resolve_timestamp_with_provided() {
        let ts = Agent::resolve_timestamp(Some("2024-01-15T12:00:00.000Z"));
        assert_eq!(ts, "2024-01-15T12:00:00.000Z");
    }

    #[test]
    fn resolve_timestamp_without_provided() {
        let ts = Agent::resolve_timestamp(None);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn service_url_accessible_without_async() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert_eq!(agent.service(), "https://bsky.social/");
    }

    #[tokio::test]
    async fn auth_call_options_none_when_not_authenticated() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.auth_call_options().await.is_none());
    }

    // ── Session events + auto-refresh ────────────────────────────────

    use async_trait::async_trait;
    use proto_blue_common::fetch::{FetchError, FetchHandler, HttpRequest, HttpResponse};

    /// Fetcher that scripts a sequence of responses for each NSID.
    /// Calls to a scripted NSID consume its responses in order; once a
    /// single response remains, it is returned for every later call.
    /// Unscripted `createSession` calls return `createsession_body`; any
    /// other unscripted NSID fails with `FetchError::Other`. Every call
    /// (scripted or not) is counted per NSID for assertions.
    struct ScriptedFetcher {
        createsession_body: Vec<u8>,
        /// NSID -> ordered responses (the last one is repeated forever).
        scripts: std::sync::Mutex<std::collections::HashMap<String, Vec<ScriptedResponse>>>,
        call_counts: std::sync::Mutex<std::collections::HashMap<String, usize>>,
    }

    /// One canned HTTP response: status plus raw body bytes (always served
    /// with `content-type: application/json`).
    #[derive(Clone)]
    struct ScriptedResponse {
        status: u16,
        body: Vec<u8>,
    }

    impl ScriptedFetcher {
        fn new(createsession_body: Vec<u8>) -> Self {
            Self {
                createsession_body,
                scripts: Default::default(),
                call_counts: Default::default(),
            }
        }
        /// Replace the response script for the NSID `path`.
        fn script(&self, path: &str, responses: Vec<ScriptedResponse>) {
            self.scripts
                .lock()
                .unwrap()
                .insert(path.to_string(), responses);
        }
        /// Number of requests seen so far for the NSID `path`.
        fn call_count(&self, path: &str) -> usize {
            *self.call_counts.lock().unwrap().get(path).unwrap_or(&0)
        }
    }

    #[async_trait]
    impl FetchHandler for ScriptedFetcher {
        async fn fetch(&self, req: HttpRequest) -> Result<HttpResponse, FetchError> {
            // Key = NSID: the URL segment after `/xrpc/`, minus any query
            // string.
            let path = req.url.clone();
            let key = path
                .split("/xrpc/")
                .nth(1)
                .unwrap_or(&path)
                .split('?')
                .next()
                .unwrap_or("")
                .to_string();
            *self
                .call_counts
                .lock()
                .unwrap()
                .entry(key.clone())
                .or_insert(0) += 1;

            // Scripted responses always take precedence; the
            // createSession short-circuit only fires when the caller
            // hasn't explicitly scripted it.
            {
                let mut scripts = self.scripts.lock().unwrap();
                if let Some(list) = scripts.get_mut(&key) {
                    // Pop responses in order, but keep (and repeat) the
                    // final one. NOTE: an empty script would panic here.
                    let resp = if list.len() == 1 {
                        list[0].clone()
                    } else {
                        list.remove(0)
                    };
                    let mut headers = proto_blue_common::fetch::HttpHeaders::new();
                    headers.insert("content-type".into(), "application/json".into());
                    return Ok(HttpResponse {
                        status: resp.status,
                        headers,
                        body: resp.body,
                    });
                }
            }

            // Default: createSession always succeeds.
            if key == "com.atproto.server.createSession" {
                let mut headers = proto_blue_common::fetch::HttpHeaders::new();
                headers.insert("content-type".into(), "application/json".into());
                return Ok(HttpResponse {
                    status: 200,
                    headers,
                    body: self.createsession_body.clone(),
                });
            }

            Err(FetchError::Other(format!("no script for {key}")))
        }
    }

    /// `createSession` response body for user `did:plc:u` with tokens
    /// `a1` / `r1`.
    fn login_body() -> Vec<u8> {
        br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a1","refreshJwt":"r1"}"#.to_vec()
    }

    /// Build an `Agent` directly from its fields so the scripted fetcher
    /// can be injected; starts with no session, proxy, or labelers.
    fn agent_with_fetcher(fetcher: Arc<ScriptedFetcher>) -> Agent {
        let client = XrpcClient::with_fetch_handler("https://example.com", fetcher).unwrap();
        Agent {
            client,
            session: Arc::new(RwLock::new(None)),
            listeners: Arc::new(Mutex::new(Vec::new())),
            refresh_lock: Arc::new(AsyncMutex::new(())),
            proxy: Arc::new(RwLock::new(None)),
            labelers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    #[tokio::test]
    async fn emits_create_on_successful_login() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Create]);
    }

    #[tokio::test]
    async fn emits_create_failed_on_login_rejection() {
        let fetcher = Arc::new(ScriptedFetcher::new(vec![]));
        // Override createSession to fail:
        fetcher.script(
            "com.atproto.server.createSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"bad pwd"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        // Scripted responses take precedence over ScriptedFetcher's
        // default createSession success (which would return
        // `createsession_body`), so the 401 scripted above flows through.
        let _ = agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "bad")
            .await
            .unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::CreateFailed]);
    }

    #[tokio::test]
    async fn auto_refreshes_on_expired_access_token() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));

        // First call to describeServer returns 401 ExpiredToken,
        // second call (post-refresh) returns 200.
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );

        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let result = agent.describe_server().await.unwrap();
        assert_eq!(result["did"], "did:plc:svr");

        // describeServer was called twice (first 401, second success
        // after refresh); refreshSession was called exactly once.
        assert_eq!(fetcher.call_count("com.atproto.server.describeServer"), 2);
        assert_eq!(fetcher.call_count("com.atproto.server.refreshSession"), 1);

        // One Update event fired during the refresh.
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Update]);
    }

    #[tokio::test]
    async fn concurrent_expired_token_refreshes_once() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));

        // Only the first describeServer call gets a 401 ExpiredToken;
        // every later call reuses the final (200 OK) entry.
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );

        let agent = Arc::new(agent_with_fetcher(fetcher.clone()));
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();

        // 5 concurrent callers; refresh must fire exactly once — the
        // dedup lock + access-token staleness check are meant to
        // guarantee this. NOTE(review): the script above serves only one
        // 401, so at most one caller ever sees an expired token; this does
        // not force several callers into the refresh path simultaneously.
        let mut handles = Vec::new();
        for _ in 0..5 {
            let a = agent.clone();
            handles.push(tokio::spawn(async move {
                a.describe_server().await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(
            fetcher.call_count("com.atproto.server.refreshSession"),
            1,
            "concurrent callers must share one refreshSession call",
        );
    }

    #[tokio::test]
    async fn configure_proxy_sets_header_on_next_call() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:svr"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .configure_proxy(Some("did:web:api.bsky.chat#bsky_chat"))
            .await;

        agent.describe_server().await.unwrap();

        // ScriptedFetcher doesn't record request headers, so despite the
        // test name this only asserts the stored proxy value, not the
        // header actually sent.
        let p = agent.proxy.read().await;
        assert_eq!(p.as_deref(), Some("did:web:api.bsky.chat#bsky_chat"));
    }

    #[tokio::test]
    async fn configure_labelers_stores_list() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);
        agent
            .configure_labelers(&[
                LabelerOpts {
                    did: Did::new("did:plc:a").unwrap(),
                    redirect: false,
                },
                LabelerOpts {
                    did: Did::new("did:plc:b").unwrap(),
                    redirect: true,
                },
            ])
            .await;
        let l = agent.labelers.read().await;
        assert_eq!(l.len(), 2);
        assert_eq!(l[0].header_value(), "did:plc:a");
        assert_eq!(l[1].header_value(), "did:plc:b;redirect");
    }

    #[tokio::test]
    async fn logout_clears_session() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 200,
                body: b"{}".to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        assert!(agent.session().await.is_some());
        agent.logout().await.unwrap();
        assert!(agent.session().await.is_none());
        assert_eq!(fetcher.call_count("com.atproto.server.deleteSession"), 1,);
    }

    #[tokio::test]
    async fn logout_clears_session_even_on_server_error() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 500,
                body: br#"{"error":"InternalServerError"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        // Server call fails, but local state must still be cleared.
        let _ = agent.logout().await;
        assert!(agent.session().await.is_none());
    }

    #[tokio::test]
    async fn create_account_emits_create_on_success() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.createAccount",
            vec![ScriptedResponse {
                status: 200,
                body:
                    br#"{"did":"did:plc:new","handle":"newuser.test","accessJwt":"a","refreshJwt":"r"}"#
                        .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev = events.clone();
        agent.on_session(move |e, _| ev.lock().unwrap().push(e));

        let session = agent
            .create_account(
                &Handle::new("newuser.test").unwrap(),
                "pw",
                Some("new@example.com"),
                None,
            )
            .await
            .unwrap();
        assert_eq!(session.did.as_str(), "did:plc:new");
        assert_eq!(
            events.lock().unwrap().clone(),
            vec![AtpSessionEvent::Create]
        );
    }

    #[tokio::test]
    async fn upsert_profile_creates_when_absent() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        // getRecord fails with a 400 RecordNotFound (no profile yet).
        fetcher.script(
            "com.atproto.repo.getRecord",
            vec![ScriptedResponse {
                status: 400,
                body: br#"{"error":"RecordNotFound","message":"no such record"}"#.to_vec(),
            }],
        );
        fetcher.script(
            "com.atproto.repo.putRecord",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"uri":"at://did:plc:u/app.bsky.actor.profile/self","cid":"bafy"}"#
                    .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();

        let result = agent
            .upsert_profile(|prev| {
                assert!(prev.is_null(), "no existing profile");
                serde_json::json!({"$type": "app.bsky.actor.profile", "displayName": "Alice"})
            })
            .await
            .unwrap();
        assert_eq!(result["uri"], "at://did:plc:u/app.bsky.actor.profile/self");
    }

    #[tokio::test]
    async fn emits_expired_when_refresh_itself_401s() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"refresh expired"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let _ = agent.refresh_session().await.unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Expired]);
        assert!(
            agent.session().await.is_none(),
            "session cleared on expired refresh"
        );
    }
}