Skip to main content

proto_blue_api/
agent.rs

1//! AT Protocol Agent — high-level client wrapping XRPC.
2//!
3//! Provides session management, convenience methods for common operations,
4//! and namespace accessors for the full Lexicon API surface.
5
6use std::sync::{Arc, Mutex};
7use tokio::sync::{Mutex as AsyncMutex, RwLock};
8
9use proto_blue_lex_data::Cid;
10use proto_blue_syntax::{AtIdentifier, AtUri, Did, Handle};
11use proto_blue_xrpc::{
12    CallOptions, HeadersMap, QueryParams, QueryValue, ResponseType, XrpcBody, XrpcClient,
13};
14
15use crate::rich_text::RichText;
16
/// Session lifecycle events emitted by [`Agent`].
///
/// Mirrors TS `AtpSessionEvent`. Subscribe with [`Agent::on_session`] to be
/// told about login, token refresh, and session loss. A common pattern is
/// to persist the session on `Create` / `Update` and wipe local state on
/// `Expired`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AtpSessionEvent {
    /// A session was established (successful login or resume).
    Create,
    /// A login attempt did not succeed.
    CreateFailed,
    /// The session tokens were replaced by a refresh.
    Update,
    /// The session is gone: the server rejected the refresh token, or the
    /// agent logged out. The user must log in again.
    Expired,
    /// A transport-level failure happened during a session-affecting call.
    NetworkError,
}
36
37/// Callback registered via [`Agent::on_session`].
38///
39/// Invoked synchronously on the task that produced the event; handlers
40/// should not block for long. The `Option<&Session>` is `Some` for
41/// `Create` / `Update` and `None` for `CreateFailed` / `Expired` /
42/// `NetworkError`.
43pub type SessionEventCallback = Arc<dyn Fn(AtpSessionEvent, Option<&Session>) + Send + Sync>;
44
45/// Session data for an authenticated agent.
46///
47/// `did` and `handle` are typed as the validated `proto_blue_syntax`
48/// newtypes so callers can't accidentally pass random strings; the
49/// JWTs and email stay as `String` (no validated newtype exists for
50/// either, and over-validating an opaque token would just mean
51/// re-parsing on every request).
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53#[serde(rename_all = "camelCase")]
54pub struct Session {
55    pub did: Did,
56    pub handle: Handle,
57    pub access_jwt: String,
58    pub refresh_jwt: String,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub email: Option<String>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub email_confirmed: Option<bool>,
63}
64
65/// Errors from Agent operations.
66#[derive(Debug, thiserror::Error)]
67pub enum AgentError {
68    #[error("XRPC error: {0}")]
69    Xrpc(#[from] proto_blue_xrpc::Error),
70    #[error("Not authenticated")]
71    NotAuthenticated,
72    #[error("JSON error: {0}")]
73    Json(#[from] serde_json::Error),
74    #[error("{0}")]
75    Other(String),
76}
77
78/// High-level AT Protocol agent.
79///
80/// Auth state lives in a single `RwLock<Option<Session>>`. The XRPC client
81/// is never mutated after construction — auth headers are passed per-request.
82/// This avoids token leaks, giant-lock contention, and split-lock atomicity
83/// gaps that arise from storing auth in the client's default headers.
84///
85/// ## Transparent refresh
86///
87/// Every XRPC call goes through `xrpc_query_with_refresh` /
88/// `xrpc_procedure_with_refresh`, which detect 401 /
89/// `ExpiredToken` responses, call [`Agent::refresh_session`], and
90/// retry once. Concurrent refresh attempts are deduplicated via an
91/// async `Mutex` so N in-flight calls that all see an expired token
92/// issue exactly one `/refreshSession` request. If the refresh itself
93/// fails, the agent fires [`AtpSessionEvent::Expired`] and the
94/// original error propagates.
95pub struct Agent {
96    client: XrpcClient,
97    session: Arc<RwLock<Option<Session>>>,
98    /// Session-event listeners. Called synchronously on the task that
99    /// produced the event.
100    listeners: Arc<Mutex<Vec<SessionEventCallback>>>,
101    /// Serialises concurrent refreshes. The first caller to see a 401
102    /// acquires this lock, performs the refresh, and writes the new
103    /// session back; subsequent callers block until that finishes and
104    /// then observe the updated session when their retry fires.
105    refresh_lock: Arc<AsyncMutex<()>>,
106    /// Optional `atproto-proxy` target (e.g. `did:web:api.bsky.chat#bsky_chat`
107    /// when calling the chat service).
108    proxy: Arc<RwLock<Option<String>>>,
109    /// Optional set of labeler DIDs to send as `atproto-accept-labelers`.
110    labelers: Arc<RwLock<Vec<LabelerOpts>>>,
111}
112
113/// A single labeler entry, mirroring TS `Agent`'s `AtprotoLabelerDef`.
114#[derive(Debug, Clone, PartialEq, Eq)]
115pub struct LabelerOpts {
116    /// The labeler's DID (e.g. `did:plc:<labeler>`).
117    pub did: Did,
118    /// When `true`, this labeler is redirected to (sent as
119    /// `atproto-accept-labelers: did;redirect`). Matches TS behaviour.
120    pub redirect: bool,
121}
122
123impl LabelerOpts {
124    /// Format a single labeler for the `atproto-accept-labelers` header.
125    fn header_value(&self) -> String {
126        if self.redirect {
127            format!("{};redirect", self.did)
128        } else {
129            self.did.to_string()
130        }
131    }
132}
133
134impl Agent {
135    /// Create a new agent pointing at the given service URL.
136    ///
137    /// Available whenever [`XrpcClient::new`] is — native requires
138    /// `fetch-reqwest`; on wasm the default is always the browser
139    /// `fetch` backend.
140    #[cfg(any(
141        all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
142        target_arch = "wasm32",
143    ))]
144    pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
145        let client = XrpcClient::new(service)?;
146        Ok(Self {
147            client,
148            session: Arc::new(RwLock::new(None)),
149            listeners: Arc::new(Mutex::new(Vec::new())),
150            refresh_lock: Arc::new(AsyncMutex::new(())),
151            proxy: Arc::new(RwLock::new(None)),
152            labelers: Arc::new(RwLock::new(Vec::new())),
153        })
154    }
155
156    /// Register a session-event listener.
157    ///
158    /// Returns `()`, not a handle — listener unregistration isn't
159    /// currently supported (the typical pattern is to register a
160    /// single persistence callback that lives for the Agent's
161    /// lifetime). Multiple listeners are fired in registration order.
162    pub fn on_session<F>(&self, callback: F)
163    where
164        F: Fn(AtpSessionEvent, Option<&Session>) + Send + Sync + 'static,
165    {
166        self.listeners.lock().unwrap().push(Arc::new(callback));
167    }
168
169    /// Fire an event to every registered listener.
170    fn emit(&self, event: AtpSessionEvent, session: Option<&Session>) {
171        // `listeners` is a sync Mutex; we clone the Arc<callback> list
172        // out from under the lock so the callbacks themselves run
173        // without holding it (they could be slow / could call back
174        // into `on_session`).
175        let listeners = self.listeners.lock().unwrap().clone();
176        for cb in listeners {
177            cb(event, session);
178        }
179    }
180
181    /// Get the service URL string.
182    #[must_use]
183    pub fn service(&self) -> String {
184        self.client.service_url().to_string()
185    }
186
187    /// Get the current session's DID, if logged in.
188    pub async fn did(&self) -> Option<Did> {
189        self.session.read().await.as_ref().map(|s| s.did.clone())
190    }
191
192    /// Get the current session, if any.
193    pub async fn session(&self) -> Option<Session> {
194        self.session.read().await.clone()
195    }
196
197    // --- Authentication ---
198
199    /// Build per-request `CallOptions` carrying the current access
200    /// token, proxy target, and labeler list. Returns `None` if not
201    /// authenticated (the proxy + labelers are still folded into the
202    /// call options of non-auth helpers via [`Self::anon_call_options`]).
203    async fn auth_call_options(&self) -> Option<CallOptions> {
204        let guard = self.session.read().await;
205        let session = guard.as_ref()?;
206        let mut headers = HeadersMap::new();
207        headers.insert(
208            "Authorization".into(),
209            format!("Bearer {}", session.access_jwt),
210        );
211        self.inject_proxy_and_labelers(&mut headers).await;
212        Some(CallOptions {
213            encoding: None,
214            headers: Some(headers),
215            ..Default::default()
216        })
217    }
218
219    /// Build anonymous [`CallOptions`] carrying just the proxy and
220    /// labeler config, for methods that don't need auth.
221    ///
222    /// Exposed in case callers drive `XrpcClient::query` / `::procedure`
223    /// directly and want the agent's proxy / labeler headers folded in.
224    pub async fn anon_call_options(&self) -> Option<CallOptions> {
225        let mut headers = HeadersMap::new();
226        self.inject_proxy_and_labelers(&mut headers).await;
227        if headers.is_empty() {
228            None
229        } else {
230            Some(CallOptions {
231                encoding: None,
232                headers: Some(headers),
233                ..Default::default()
234            })
235        }
236    }
237
238    async fn inject_proxy_and_labelers(&self, headers: &mut HeadersMap) {
239        if let Some(proxy) = self.proxy.read().await.as_ref() {
240            headers.insert("atproto-proxy".into(), proxy.clone());
241        }
242        let labelers = self.labelers.read().await;
243        if !labelers.is_empty() {
244            let v = labelers
245                .iter()
246                .map(LabelerOpts::header_value)
247                .collect::<Vec<_>>()
248                .join(", ");
249            headers.insert("atproto-accept-labelers".into(), v);
250        }
251    }
252
253    /// Configure the service-proxy target (`atproto-proxy` header) for
254    /// every subsequent call. Pass `None` to clear.
255    ///
256    /// The canonical use case is chat, which runs on a different
257    /// service: `agent.configure_proxy(Some("did:web:api.bsky.chat#bsky_chat"))`.
258    pub async fn configure_proxy(&self, target: Option<&str>) {
259        *self.proxy.write().await = target.map(String::from);
260    }
261
262    /// Return a new [`Agent`] configured with the given proxy target.
263    /// Shares session state with this agent (cheap clone of internals).
264    pub async fn with_proxy(&self, target: &str) -> Self {
265        let cloned = self.shallow_clone();
266        cloned.configure_proxy(Some(target)).await;
267        cloned
268    }
269
270    /// Configure the set of labelers sent as `atproto-accept-labelers`.
271    /// Passing an empty slice clears the header.
272    pub async fn configure_labelers(&self, labelers: &[LabelerOpts]) {
273        *self.labelers.write().await = labelers.to_vec();
274    }
275
276    /// Shallow-clone the agent: shares session / listener / refresh
277    /// state but receives independent proxy + labeler config. Used by
278    /// [`Self::with_proxy`].
279    fn shallow_clone(&self) -> Self {
280        Self {
281            client: self.client.clone(),
282            session: self.session.clone(),
283            listeners: self.listeners.clone(),
284            refresh_lock: self.refresh_lock.clone(),
285            proxy: Arc::new(RwLock::new(None)),
286            labelers: self.labelers.clone(),
287        }
288    }
289
290    /// Log in with identifier (handle or DID) and password.
291    ///
292    /// Emits [`AtpSessionEvent::Create`] on success, or
293    /// [`AtpSessionEvent::CreateFailed`] if the server rejected the
294    /// credentials.
295    pub async fn login(
296        &self,
297        identifier: &AtIdentifier,
298        password: &str,
299    ) -> Result<Session, AgentError> {
300        let body = serde_json::json!({
301            "identifier": identifier,
302            "password": password,
303        });
304
305        let response = match self
306            .client
307            .procedure(
308                "com.atproto.server.createSession",
309                None,
310                Some(XrpcBody::Json(body)),
311                None,
312            )
313            .await
314        {
315            Ok(r) => r,
316            Err(e) => {
317                self.emit(AtpSessionEvent::CreateFailed, None);
318                return Err(AgentError::Xrpc(e));
319            }
320        };
321
322        let session: Session = serde_json::from_value(response.data)?;
323
324        // Atomically commit session in a single write lock
325        *self.session.write().await = Some(session.clone());
326        self.emit(AtpSessionEvent::Create, Some(&session));
327        Ok(session)
328    }
329
330    /// Resume an existing session.
331    ///
332    /// Verifies the session with the server *before* updating internal state.
333    /// If verification fails, the agent remains unauthenticated.
334    pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
335        // Verify the session is valid by calling getSession with the provided token,
336        // WITHOUT updating the agent's state first. Use a per-request auth header.
337        let mut headers = HeadersMap::new();
338        headers.insert(
339            "Authorization".into(),
340            format!("Bearer {}", session.access_jwt),
341        );
342        let opts = CallOptions {
343            encoding: None,
344            headers: Some(headers),
345            ..Default::default()
346        };
347        let response = self
348            .client
349            .query("com.atproto.server.getSession", None, Some(&opts))
350            .await?;
351        let verified_did = response
352            .data
353            .get("did")
354            .and_then(|v| v.as_str())
355            .map(Did::new)
356            .transpose()
357            .map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))?;
358
359        // Verification succeeded — atomically commit state in a single write lock
360        let mut committed = session;
361        if let Some(did) = verified_did {
362            committed.did = did;
363        }
364        *self.session.write().await = Some(committed.clone());
365        self.emit(AtpSessionEvent::Create, Some(&committed));
366
367        Ok(())
368    }
369
370    /// Refresh the current session tokens.
371    ///
372    /// Emits [`AtpSessionEvent::Update`] on success or
373    /// [`AtpSessionEvent::Expired`] if the refresh token was
374    /// rejected. Uses a per-request header for the refresh call so the
375    /// refresh JWT is never exposed as the global auth state. The new
376    /// session is committed atomically in a single write lock.
377    pub async fn refresh_session(&self) -> Result<Session, AgentError> {
378        let refresh_jwt = {
379            let sess = self.session.read().await;
380            let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
381            sess.refresh_jwt.clone()
382        };
383
384        // Use per-request header for refresh — never mutate global auth state
385        let mut headers = HeadersMap::new();
386        headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
387        let opts = CallOptions {
388            encoding: None,
389            headers: Some(headers),
390            ..Default::default()
391        };
392
393        let response = match self
394            .client
395            .procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
396            .await
397        {
398            Ok(r) => r,
399            Err(e) => {
400                // Any 401 during refresh means the refresh token
401                // itself is rejected — drop the session and signal
402                // Expired. Other errors (network failure, 5xx, etc.)
403                // surface as NetworkError and leave the session in
404                // place so a later attempt can retry.
405                if is_refresh_rejected(&e) {
406                    *self.session.write().await = None;
407                    self.emit(AtpSessionEvent::Expired, None);
408                } else {
409                    self.emit(AtpSessionEvent::NetworkError, None);
410                }
411                return Err(AgentError::Xrpc(e));
412            }
413        };
414
415        let session: Session = serde_json::from_value(response.data)?;
416
417        // Atomically commit new session in a single write lock
418        *self.session.write().await = Some(session.clone());
419        self.emit(AtpSessionEvent::Update, Some(&session));
420        Ok(session)
421    }
422
423    // --- Convenience helpers ---
424
425    /// Ensure the agent is authenticated, returning the DID.
426    async fn assert_did(&self) -> Result<Did, AgentError> {
427        self.did().await.ok_or(AgentError::NotAuthenticated)
428    }
429
430    /// Helper: make a query call with transparent 401-refresh retry.
431    ///
432    /// When the first attempt returns `ExpiredToken`, try to refresh
433    /// the session and replay the call once with the fresh access
434    /// token. Concurrent refreshes are deduplicated via
435    /// [`Agent::refresh_lock`].
436    async fn xrpc_query(
437        &self,
438        nsid: &str,
439        params: Option<&QueryParams>,
440    ) -> Result<serde_json::Value, AgentError> {
441        let opts = self.auth_call_options().await;
442        let first = self.client.query(nsid, params, opts.as_ref()).await;
443        match first {
444            Ok(r) => Ok(r.data),
445            Err(e) if is_auth_expired(&e) => {
446                self.refresh_and_retry(|opts| {
447                    let c = self.client.clone();
448                    let nsid = nsid.to_string();
449                    let params = params.cloned();
450                    async move { c.query(&nsid, params.as_ref(), opts.as_ref()).await }
451                })
452                .await
453            }
454            Err(e) => Err(AgentError::Xrpc(e)),
455        }
456    }
457
458    /// Helper: make a procedure call with transparent 401-refresh retry.
459    async fn xrpc_procedure(
460        &self,
461        nsid: &str,
462        body: serde_json::Value,
463    ) -> Result<serde_json::Value, AgentError> {
464        let opts = self.auth_call_options().await;
465        let first = self
466            .client
467            .procedure(
468                nsid,
469                None,
470                Some(XrpcBody::Json(body.clone())),
471                opts.as_ref(),
472            )
473            .await;
474        match first {
475            Ok(r) => Ok(r.data),
476            Err(e) if is_auth_expired(&e) => {
477                self.refresh_and_retry(|opts| {
478                    let c = self.client.clone();
479                    let nsid = nsid.to_string();
480                    let body = body.clone();
481                    async move {
482                        c.procedure(&nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
483                            .await
484                    }
485                })
486                .await
487            }
488            Err(e) => Err(AgentError::Xrpc(e)),
489        }
490    }
491
492    /// Shared refresh-and-retry driver.
493    ///
494    /// Acquires the `refresh_lock`, refreshes the session if the
495    /// access token in `self.session` is still the one that produced
496    /// the 401, rebuilds `CallOptions` from the new token, and runs
497    /// `replay(new_opts)`. Concurrent callers that arrive after the
498    /// lock is held observe the refreshed session when they get to
499    /// build their own opts — only one `/refreshSession` HTTP call
500    /// fires per refresh cycle.
501    async fn refresh_and_retry<F, Fut>(&self, replay: F) -> Result<serde_json::Value, AgentError>
502    where
503        F: FnOnce(Option<CallOptions>) -> Fut,
504        Fut: std::future::Future<
505                Output = Result<proto_blue_xrpc::XrpcResponse, proto_blue_xrpc::Error>,
506            >,
507    {
508        // Snapshot the access token the caller's first attempt used.
509        // After we acquire the refresh lock, compare — if a peer
510        // already refreshed, skip the redundant refresh.
511        let pre_refresh_jwt = self
512            .session
513            .read()
514            .await
515            .as_ref()
516            .map(|s| s.access_jwt.clone());
517        let _guard = self.refresh_lock.lock().await;
518        let current_jwt = self
519            .session
520            .read()
521            .await
522            .as_ref()
523            .map(|s| s.access_jwt.clone());
524        if pre_refresh_jwt == current_jwt {
525            // No peer did the refresh — we must.
526            self.refresh_session().await?;
527        }
528        drop(_guard);
529
530        let opts = self.auth_call_options().await;
531        let response = replay(opts).await?;
532        Ok(response.data)
533    }
534
535    /// Helper: create a record.
536    async fn create_record(
537        &self,
538        collection: &str,
539        record: serde_json::Value,
540    ) -> Result<serde_json::Value, AgentError> {
541        let did = self.assert_did().await?;
542        let body = serde_json::json!({
543            "repo": did,
544            "collection": collection,
545            "record": record,
546        });
547        self.xrpc_procedure("com.atproto.repo.createRecord", body)
548            .await
549    }
550
551    /// Helper: delete a record by AT-URI.
552    async fn delete_record(&self, collection: &str, uri: &AtUri) -> Result<(), AgentError> {
553        let did = self.assert_did().await?;
554        let rkey = uri
555            .rkey()
556            .ok_or_else(|| AgentError::Other("AT-URI has no rkey segment".into()))?;
557
558        let body = serde_json::json!({
559            "repo": did,
560            "collection": collection,
561            "rkey": rkey,
562        });
563        self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
564            .await?;
565        Ok(())
566    }
567
568    /// Generate an ISO 8601 timestamp with millisecond precision.
569    fn now_iso() -> String {
570        chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
571    }
572
573    /// Resolve a timestamp: use the provided value or generate one.
574    fn resolve_timestamp(created_at: Option<&str>) -> String {
575        created_at.map_or_else(Self::now_iso, String::from)
576    }
577
578    // --- Post operations ---
579
580    /// Create a new post.
581    ///
582    /// If `created_at` is `None`, the current time is used.
583    pub async fn post(
584        &self,
585        text: &str,
586        facets: Option<Vec<crate::rich_text::Facet>>,
587        created_at: Option<&str>,
588    ) -> Result<serde_json::Value, AgentError> {
589        let mut record = serde_json::json!({
590            "$type": "app.bsky.feed.post",
591            "text": text,
592            "createdAt": Self::resolve_timestamp(created_at),
593        });
594
595        if let Some(facets) = facets {
596            record["facets"] = serde_json::to_value(&facets)?;
597        }
598
599        self.create_record("app.bsky.feed.post", record).await
600    }
601
602    /// Create a post from `RichText` (includes detected facets).
603    pub async fn post_rich(
604        &self,
605        rt: &RichText,
606        created_at: Option<&str>,
607    ) -> Result<serde_json::Value, AgentError> {
608        let facets = if rt.facets().is_empty() {
609            None
610        } else {
611            Some(rt.facets().to_vec())
612        };
613        self.post(rt.text(), facets, created_at).await
614    }
615
616    /// Delete a post by AT-URI.
617    pub async fn delete_post(&self, uri: &AtUri) -> Result<(), AgentError> {
618        self.delete_record("app.bsky.feed.post", uri).await
619    }
620
621    // --- Like / Repost ---
622
623    /// Like a post.
624    ///
625    /// If `created_at` is `None`, the current time is used.
626    pub async fn like(
627        &self,
628        uri: &AtUri,
629        cid: &Cid,
630        created_at: Option<&str>,
631    ) -> Result<serde_json::Value, AgentError> {
632        let record = serde_json::json!({
633            "$type": "app.bsky.feed.like",
634            "subject": { "uri": uri, "cid": cid },
635            "createdAt": Self::resolve_timestamp(created_at),
636        });
637        self.create_record("app.bsky.feed.like", record).await
638    }
639
640    /// Unlike a post by AT-URI of the like record.
641    pub async fn delete_like(&self, like_uri: &AtUri) -> Result<(), AgentError> {
642        self.delete_record("app.bsky.feed.like", like_uri).await
643    }
644
645    /// Repost a post.
646    ///
647    /// If `created_at` is `None`, the current time is used.
648    pub async fn repost(
649        &self,
650        uri: &AtUri,
651        cid: &Cid,
652        created_at: Option<&str>,
653    ) -> Result<serde_json::Value, AgentError> {
654        let record = serde_json::json!({
655            "$type": "app.bsky.feed.repost",
656            "subject": { "uri": uri, "cid": cid },
657            "createdAt": Self::resolve_timestamp(created_at),
658        });
659        self.create_record("app.bsky.feed.repost", record).await
660    }
661
662    /// Delete a repost by AT-URI.
663    pub async fn delete_repost(&self, repost_uri: &AtUri) -> Result<(), AgentError> {
664        self.delete_record("app.bsky.feed.repost", repost_uri).await
665    }
666
667    // --- Follow ---
668
669    /// Follow a user by DID.
670    ///
671    /// If `created_at` is `None`, the current time is used.
672    pub async fn follow(
673        &self,
674        subject_did: &Did,
675        created_at: Option<&str>,
676    ) -> Result<serde_json::Value, AgentError> {
677        let record = serde_json::json!({
678            "$type": "app.bsky.graph.follow",
679            "subject": subject_did,
680            "createdAt": Self::resolve_timestamp(created_at),
681        });
682        self.create_record("app.bsky.graph.follow", record).await
683    }
684
685    /// Unfollow by AT-URI of the follow record.
686    pub async fn delete_follow(&self, follow_uri: &AtUri) -> Result<(), AgentError> {
687        self.delete_record("app.bsky.graph.follow", follow_uri)
688            .await
689    }
690
691    // --- Query helpers ---
692
693    /// Get a user's profile.
694    pub async fn get_profile(&self, actor: &AtIdentifier) -> Result<serde_json::Value, AgentError> {
695        let mut params = QueryParams::new();
696        params.insert("actor".into(), QueryValue::String(actor.to_string()));
697        self.xrpc_query("app.bsky.actor.getProfile", Some(&params))
698            .await
699    }
700
701    /// Get the home timeline.
702    pub async fn get_timeline(
703        &self,
704        limit: Option<i64>,
705        cursor: Option<&str>,
706    ) -> Result<serde_json::Value, AgentError> {
707        let mut params = QueryParams::new();
708        if let Some(limit) = limit {
709            params.insert("limit".into(), QueryValue::Integer(limit));
710        }
711        if let Some(cursor) = cursor {
712            params.insert("cursor".into(), QueryValue::String(cursor.into()));
713        }
714        self.xrpc_query("app.bsky.feed.getTimeline", Some(&params))
715            .await
716    }
717
718    /// Get a post thread.
719    pub async fn get_post_thread(
720        &self,
721        uri: &AtUri,
722        depth: Option<i64>,
723    ) -> Result<serde_json::Value, AgentError> {
724        let mut params = QueryParams::new();
725        params.insert("uri".into(), QueryValue::String(uri.to_string()));
726        if let Some(depth) = depth {
727            params.insert("depth".into(), QueryValue::Integer(depth));
728        }
729        self.xrpc_query("app.bsky.feed.getPostThread", Some(&params))
730            .await
731    }
732
733    /// Search actors.
734    pub async fn search_actors(
735        &self,
736        query: &str,
737        limit: Option<i64>,
738    ) -> Result<serde_json::Value, AgentError> {
739        let mut params = QueryParams::new();
740        params.insert("q".into(), QueryValue::String(query.into()));
741        if let Some(limit) = limit {
742            params.insert("limit".into(), QueryValue::Integer(limit));
743        }
744        self.xrpc_query("app.bsky.actor.searchActors", Some(&params))
745            .await
746    }
747
748    /// Resolve a handle to a DID.
749    pub async fn resolve_handle(&self, handle: &Handle) -> Result<Did, AgentError> {
750        let mut params = QueryParams::new();
751        params.insert("handle".into(), QueryValue::String(handle.to_string()));
752        let data = self
753            .xrpc_query("com.atproto.identity.resolveHandle", Some(&params))
754            .await?;
755        let did_str = data
756            .get("did")
757            .and_then(|v| v.as_str())
758            .ok_or_else(|| AgentError::Other("Missing DID in response".into()))?;
759        Did::new(did_str)
760            .map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))
761    }
762
763    /// Get notifications.
764    pub async fn list_notifications(
765        &self,
766        limit: Option<i64>,
767        cursor: Option<&str>,
768    ) -> Result<serde_json::Value, AgentError> {
769        let mut params = QueryParams::new();
770        if let Some(limit) = limit {
771            params.insert("limit".into(), QueryValue::Integer(limit));
772        }
773        if let Some(cursor) = cursor {
774            params.insert("cursor".into(), QueryValue::String(cursor.into()));
775        }
776        self.xrpc_query("app.bsky.notification.listNotifications", Some(&params))
777            .await
778    }
779
780    /// Upload a blob (image, video, etc.).
781    pub async fn upload_blob(
782        &self,
783        data: Vec<u8>,
784        content_type: &str,
785    ) -> Result<serde_json::Value, AgentError> {
786        let mut headers = HeadersMap::new();
787        headers.insert("Content-Type".into(), content_type.into());
788
789        // Add auth header from session
790        if let Some(sess) = self.session.read().await.as_ref() {
791            headers.insert(
792                "Authorization".into(),
793                format!("Bearer {}", sess.access_jwt),
794            );
795        }
796
797        let opts = CallOptions {
798            encoding: Some(content_type.to_string()),
799            headers: Some(headers),
800            ..Default::default()
801        };
802
803        let response = self
804            .client
805            .procedure(
806                "com.atproto.repo.uploadBlob",
807                None,
808                Some(XrpcBody::Bytes(data)),
809                Some(&opts),
810            )
811            .await?;
812
813        Ok(response.data)
814    }
815
816    /// Describe the server.
817    pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
818        self.xrpc_query("com.atproto.server.describeServer", None)
819            .await
820    }
821
822    // --- Account lifecycle ---
823
824    /// Log out of the current session.
825    ///
826    /// Sends a best-effort `deleteSession` call using the current
827    /// **refresh** token (TS matches this — `deleteSession` requires
828    /// the refresh JWT, not the access JWT). Clears local session
829    /// state whether or not the server call succeeds, so the agent
830    /// always ends up unauthenticated.
831    pub async fn logout(&self) -> Result<(), AgentError> {
832        let refresh_jwt = {
833            let guard = self.session.read().await;
834            guard.as_ref().map(|s| s.refresh_jwt.clone())
835        };
836
837        let server_result = if let Some(refresh_jwt) = refresh_jwt {
838            let mut headers = HeadersMap::new();
839            headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
840            let opts = CallOptions {
841                encoding: None,
842                headers: Some(headers),
843                ..Default::default()
844            };
845            self.client
846                .procedure("com.atproto.server.deleteSession", None, None, Some(&opts))
847                .await
848                .map(|_| ())
849        } else {
850            Ok(())
851        };
852
853        // Always clear local state.
854        *self.session.write().await = None;
855        self.emit(AtpSessionEvent::Expired, None);
856
857        server_result.map_err(AgentError::Xrpc)
858    }
859
860    /// Create a new account on the current service.
861    ///
862    /// `extra` is merged into the request body — useful for passing
863    /// `inviteCode`, `verificationCode`, or custom provider-specific
864    /// fields without this method's signature needing to know every
865    /// option the server supports.
866    ///
867    /// On success, the new session is stored and `Create` is emitted.
868    pub async fn create_account(
869        &self,
870        handle: &Handle,
871        password: &str,
872        email: Option<&str>,
873        extra: Option<serde_json::Value>,
874    ) -> Result<Session, AgentError> {
875        let mut body = serde_json::json!({
876            "handle": handle,
877            "password": password,
878        });
879        if let Some(email) = email {
880            body["email"] = serde_json::Value::String(email.to_string());
881        }
882        if let Some(extra) = extra
883            && let Some(extra_map) = extra.as_object()
884            && let Some(body_map) = body.as_object_mut()
885        {
886            for (k, v) in extra_map {
887                body_map.insert(k.clone(), v.clone());
888            }
889        }
890
891        let response = match self
892            .client
893            .procedure(
894                "com.atproto.server.createAccount",
895                None,
896                Some(XrpcBody::Json(body)),
897                None,
898            )
899            .await
900        {
901            Ok(r) => r,
902            Err(e) => {
903                self.emit(AtpSessionEvent::CreateFailed, None);
904                return Err(AgentError::Xrpc(e));
905            }
906        };
907
908        let session: Session = serde_json::from_value(response.data)?;
909        *self.session.write().await = Some(session.clone());
910        self.emit(AtpSessionEvent::Create, Some(&session));
911        Ok(session)
912    }
913
914    /// Create-or-update the signed-in user's `app.bsky.actor.profile`
915    /// record.
916    ///
917    /// The `mutate` closure receives the existing profile record (or
918    /// `serde_json::Value::Null` if none exists) and returns the
919    /// desired next state. This pattern mirrors TS
920    /// `AtpAgent.upsertProfile(updateFn)`.
921    ///
922    /// The write uses `putRecord` with `swapRecord` for CAS safety;
923    /// if the swap fails we retry up to 5 times with a fresh read.
924    pub async fn upsert_profile<F>(&self, mutate: F) -> Result<serde_json::Value, AgentError>
925    where
926        F: Fn(serde_json::Value) -> serde_json::Value,
927    {
928        let did = self.assert_did().await?;
929        const MAX_RETRIES: u32 = 5;
930
931        for _ in 0..MAX_RETRIES {
932            // Read the existing profile (may 404 — that's fine).
933            let existing_result = self
934                .xrpc_query(
935                    "com.atproto.repo.getRecord",
936                    Some(&{
937                        let mut p = QueryParams::new();
938                        p.insert("repo".into(), QueryValue::String(did.to_string()));
939                        p.insert(
940                            "collection".into(),
941                            QueryValue::String("app.bsky.actor.profile".into()),
942                        );
943                        p.insert("rkey".into(), QueryValue::String("self".into()));
944                        p
945                    }),
946                )
947                .await;
948
949            let (existing_record, swap_cid) = match existing_result {
950                Ok(r) => {
951                    let record = r.get("value").cloned().unwrap_or(serde_json::Value::Null);
952                    let cid = r.get("cid").and_then(|v| v.as_str()).map(String::from);
953                    (record, cid)
954                }
955                Err(AgentError::Xrpc(ref e)) if is_not_found(e) => (serde_json::Value::Null, None),
956                Err(e) => return Err(e),
957            };
958
959            let updated = mutate(existing_record);
960            let mut body = serde_json::json!({
961                "repo": did,
962                "collection": "app.bsky.actor.profile",
963                "rkey": "self",
964                "record": updated,
965            });
966            if let Some(cid) = swap_cid {
967                body["swapRecord"] = serde_json::Value::String(cid);
968            }
969
970            match self
971                .xrpc_procedure("com.atproto.repo.putRecord", body)
972                .await
973            {
974                Ok(r) => return Ok(r),
975                Err(AgentError::Xrpc(ref e)) if is_invalid_swap(e) => {
976                    // Race lost — someone else updated between our read
977                    // and write. Loop and try again with a fresh read.
978                    continue;
979                }
980                Err(e) => return Err(e),
981            }
982        }
983
984        Err(AgentError::Other(
985            "upsert_profile: exceeded maximum retries due to concurrent writes".into(),
986        ))
987    }
988}
989
990/// `true` if an XRPC error is a 4xx that specifically indicates the
991/// record does not exist. `getRecord` uses `RecordNotFound`.
992fn is_not_found(err: &proto_blue_xrpc::Error) -> bool {
993    match err {
994        proto_blue_xrpc::Error::Xrpc(x) => x.is_error("RecordNotFound"),
995        _ => false,
996    }
997}
998
999/// `true` if the server rejected a `putRecord` because the `swapRecord`
1000/// CID didn't match — caller should re-read and retry.
1001fn is_invalid_swap(err: &proto_blue_xrpc::Error) -> bool {
1002    match err {
1003        proto_blue_xrpc::Error::Xrpc(x) => x.is_error("InvalidSwap"),
1004        _ => false,
1005    }
1006}
1007
1008/// `true` if an XRPC error signals that the access token is expired
1009/// and the caller should try to refresh. Looks for
1010/// `AuthenticationRequired` (401) with the specific `ExpiredToken`
1011/// error name — other 401 variants aren't necessarily caused by
1012/// expiry (e.g. wrong credentials, app-password rejection) and
1013/// shouldn't trigger the refresh-and-retry path.
1014fn is_auth_expired(err: &proto_blue_xrpc::Error) -> bool {
1015    match err {
1016        proto_blue_xrpc::Error::Xrpc(x) => {
1017            matches!(x.status, ResponseType::AuthenticationRequired) && x.is_error("ExpiredToken")
1018        }
1019        _ => false,
1020    }
1021}
1022
1023/// `true` if an error from `/refreshSession` signals that the refresh
1024/// token is rejected (rather than a transient network problem). Any
1025/// 401 from the refresh endpoint is authoritative — the token is
1026/// dead — regardless of the specific error-name code.
1027const fn is_refresh_rejected(err: &proto_blue_xrpc::Error) -> bool {
1028    match err {
1029        proto_blue_xrpc::Error::Xrpc(x) => {
1030            matches!(x.status, ResponseType::AuthenticationRequired)
1031        }
1032        _ => false,
1033    }
1034}
1035
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn agent_creation() {
        let _agent = Agent::new("https://bsky.social").unwrap();
    }

    #[test]
    fn session_serde_roundtrip() {
        let session = Session {
            did: Did::new("did:plc:abc123").unwrap(),
            handle: Handle::new("alice.bsky.social").unwrap(),
            access_jwt: "eyJ...".to_string(),
            refresh_jwt: "eyJ...".to_string(),
            email: Some("alice@example.com".to_string()),
            email_confirmed: Some(true),
        };

        let json = serde_json::to_string(&session).unwrap();
        let parsed: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.did.as_str(), "did:plc:abc123");
        assert_eq!(parsed.handle.as_str(), "alice.bsky.social");
        assert_eq!(parsed.email, Some("alice@example.com".to_string()));
    }

    #[test]
    fn agent_error_display() {
        let err = AgentError::NotAuthenticated;
        assert_eq!(err.to_string(), "Not authenticated");

        let err = AgentError::Other("test error".into());
        assert_eq!(err.to_string(), "test error");
    }

    #[tokio::test]
    async fn agent_no_session_by_default() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.did().await.is_none());
        assert!(agent.session().await.is_none());
    }

    #[tokio::test]
    async fn agent_assert_did_fails_when_not_logged_in() {
        let agent = Agent::new("https://bsky.social").unwrap();
        let err = agent.assert_did().await.unwrap_err();
        assert!(matches!(err, AgentError::NotAuthenticated));
    }

    #[test]
    fn now_iso_format() {
        let ts = Agent::now_iso();
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn resolve_timestamp_with_provided() {
        let ts = Agent::resolve_timestamp(Some("2024-01-15T12:00:00.000Z"));
        assert_eq!(ts, "2024-01-15T12:00:00.000Z");
    }

    #[test]
    fn resolve_timestamp_without_provided() {
        let ts = Agent::resolve_timestamp(None);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn service_url_accessible_without_async() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert_eq!(agent.service(), "https://bsky.social/");
    }

    #[tokio::test]
    async fn auth_call_options_none_when_not_authenticated() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.auth_call_options().await.is_none());
    }

    // ── Session events + auto-refresh ────────────────────────────────

    use async_trait::async_trait;
    use proto_blue_common::fetch::{FetchError, FetchHandler, HttpRequest, HttpResponse};

    /// In-memory [`FetchHandler`] that replays scripted responses per NSID.
    ///
    /// The NSID is the part of the request URL after `/xrpc/`, with any
    /// query string stripped. For a scripted NSID, each call consumes the
    /// next entry in its list, except the final entry, which is never
    /// removed and is returned for every subsequent call. An unscripted
    /// `com.atproto.server.createSession` returns `createsession_body` with
    /// status 200; any other unscripted NSID fails with
    /// `FetchError::Other`. Every call is counted per NSID (see
    /// [`ScriptedFetcher::call_count`]). Request headers and bodies are not
    /// inspected or recorded.
    struct ScriptedFetcher {
        /// Body returned for an unscripted `createSession` call.
        createsession_body: Vec<u8>,
        /// NSID → remaining scripted responses (the last entry is sticky).
        scripts: std::sync::Mutex<std::collections::HashMap<String, Vec<ScriptedResponse>>>,
        /// NSID → number of fetches seen so far.
        call_counts: std::sync::Mutex<std::collections::HashMap<String, usize>>,
    }

    /// One canned HTTP response; always served with a JSON content type.
    #[derive(Clone)]
    struct ScriptedResponse {
        status: u16,
        body: Vec<u8>,
    }

    impl ScriptedFetcher {
        fn new(createsession_body: Vec<u8>) -> Self {
            Self {
                createsession_body,
                scripts: Default::default(),
                call_counts: Default::default(),
            }
        }
        /// Replace the response sequence for `path` (an NSID).
        fn script(&self, path: &str, responses: Vec<ScriptedResponse>) {
            self.scripts
                .lock()
                .unwrap()
                .insert(path.to_string(), responses);
        }
        /// How many times `path` (an NSID) has been fetched; 0 if never.
        fn call_count(&self, path: &str) -> usize {
            *self.call_counts.lock().unwrap().get(path).unwrap_or(&0)
        }
    }

    #[async_trait]
    impl FetchHandler for ScriptedFetcher {
        async fn fetch(&self, req: HttpRequest) -> Result<HttpResponse, FetchError> {
            // Derive the NSID key: text after `/xrpc/`, minus `?query`.
            let path = req.url.clone();
            let key = path
                .split("/xrpc/")
                .nth(1)
                .unwrap_or(&path)
                .split('?')
                .next()
                .unwrap_or("")
                .to_string();
            *self
                .call_counts
                .lock()
                .unwrap()
                .entry(key.clone())
                .or_insert(0) += 1;

            // Scripted responses always take precedence; the
            // createSession short-circuit only fires when the caller
            // hasn't explicitly scripted it.
            {
                let mut scripts = self.scripts.lock().unwrap();
                if let Some(list) = scripts.get_mut(&key) {
                    // Pop entries until one remains, then keep replaying it.
                    let resp = if list.len() == 1 {
                        list[0].clone()
                    } else {
                        list.remove(0)
                    };
                    let mut headers = proto_blue_common::fetch::HttpHeaders::new();
                    headers.insert("content-type".into(), "application/json".into());
                    return Ok(HttpResponse {
                        status: resp.status,
                        headers,
                        body: resp.body,
                    });
                }
            }

            // Default: createSession always succeeds.
            if key == "com.atproto.server.createSession" {
                let mut headers = proto_blue_common::fetch::HttpHeaders::new();
                headers.insert("content-type".into(), "application/json".into());
                return Ok(HttpResponse {
                    status: 200,
                    headers,
                    body: self.createsession_body.clone(),
                });
            }

            Err(FetchError::Other(format!("no script for {key}")))
        }
    }

    /// A minimal `createSession` response body for `alice.test`.
    fn login_body() -> Vec<u8> {
        br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a1","refreshJwt":"r1"}"#.to_vec()
    }

    /// Build an unauthenticated `Agent` whose XRPC client uses `fetcher`.
    fn agent_with_fetcher(fetcher: Arc<ScriptedFetcher>) -> Agent {
        let client = XrpcClient::with_fetch_handler("https://example.com", fetcher).unwrap();
        Agent {
            client,
            session: Arc::new(RwLock::new(None)),
            listeners: Arc::new(Mutex::new(Vec::new())),
            refresh_lock: Arc::new(AsyncMutex::new(())),
            proxy: Arc::new(RwLock::new(None)),
            labelers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    #[tokio::test]
    async fn emits_create_on_successful_login() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Create]);
    }

    #[tokio::test]
    async fn emits_create_failed_on_login_rejection() {
        let fetcher = Arc::new(ScriptedFetcher::new(vec![]));
        // Override createSession to fail:
        fetcher.script(
            "com.atproto.server.createSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"bad pwd"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        // Scripts take precedence over the default `createsession_body`
        // handler: ScriptedFetcher's createSession short-circuit only fires
        // when createSession is NOT scripted, so the 401 flows through.
        let _ = agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "bad")
            .await
            .unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::CreateFailed]);
    }

    #[tokio::test]
    async fn auto_refreshes_on_expired_access_token() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));

        // First call to describeServer returns 401 ExpiredToken,
        // second call (post-refresh) returns 200.
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );

        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();

        // Register after login so the Create event is not captured.
        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let result = agent.describe_server().await.unwrap();
        assert_eq!(result["did"], "did:plc:svr");

        // describeServer was called twice (first 401, second success
        // after refresh); refreshSession was called exactly once.
        assert_eq!(fetcher.call_count("com.atproto.server.describeServer"), 2);
        assert_eq!(fetcher.call_count("com.atproto.server.refreshSession"), 1);

        // One Update event fired during the refresh.
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Update]);
    }

    #[tokio::test]
    async fn concurrent_expired_token_refreshes_once() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));

        // Only the first describeServer fetch is scripted to return 401
        // ExpiredToken; every later fetch reuses the final 200 entry
        // (ScriptedFetcher never removes the last entry of a script).
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );

        let agent = Arc::new(agent_with_fetcher(fetcher.clone()));
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();

        // Five concurrent describeServer calls; refreshSession must be
        // called exactly once.
        // NOTE(review): with a single 401 scripted, at most one caller ever
        // sees an expired token, so this does not by itself prove that
        // simultaneous 401s share one refresh via the dedup lock +
        // access-token staleness check. Consider scripting one 401 per
        // caller to exercise that path.
        let mut handles = Vec::new();
        for _ in 0..5 {
            let a = agent.clone();
            handles.push(tokio::spawn(async move {
                a.describe_server().await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(
            fetcher.call_count("com.atproto.server.refreshSession"),
            1,
            "concurrent callers must share one refreshSession call",
        );
    }

    #[tokio::test]
    async fn configure_proxy_sets_header_on_next_call() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:svr"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .configure_proxy(Some("did:web:api.bsky.chat#bsky_chat"))
            .await;

        agent.describe_server().await.unwrap();

        // ScriptedFetcher does not record request headers, so this only
        // checks that the configured proxy value is stored and the call
        // still succeeds. It does not verify that the proxy header was
        // actually sent.
        let p = agent.proxy.read().await;
        assert_eq!(p.as_deref(), Some("did:web:api.bsky.chat#bsky_chat"));
    }

    #[tokio::test]
    async fn configure_labelers_stores_list() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);
        agent
            .configure_labelers(&[
                LabelerOpts {
                    did: Did::new("did:plc:a").unwrap(),
                    redirect: false,
                },
                LabelerOpts {
                    did: Did::new("did:plc:b").unwrap(),
                    redirect: true,
                },
            ])
            .await;
        let l = agent.labelers.read().await;
        assert_eq!(l.len(), 2);
        assert_eq!(l[0].header_value(), "did:plc:a");
        assert_eq!(l[1].header_value(), "did:plc:b;redirect");
    }

    #[tokio::test]
    async fn logout_clears_session() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 200,
                body: b"{}".to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        assert!(agent.session().await.is_some());
        agent.logout().await.unwrap();
        assert!(agent.session().await.is_none());
        assert_eq!(fetcher.call_count("com.atproto.server.deleteSession"), 1,);
    }

    #[tokio::test]
    async fn logout_clears_session_even_on_server_error() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 500,
                body: br#"{"error":"InternalServerError"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        // Server call fails, but local state must still be cleared.
        let _ = agent.logout().await;
        assert!(agent.session().await.is_none());
    }

    #[tokio::test]
    async fn create_account_emits_create_on_success() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.createAccount",
            vec![ScriptedResponse {
                status: 200,
                body:
                    br#"{"did":"did:plc:new","handle":"newuser.test","accessJwt":"a","refreshJwt":"r"}"#
                        .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev = events.clone();
        agent.on_session(move |e, _| ev.lock().unwrap().push(e));

        let session = agent
            .create_account(
                &Handle::new("newuser.test").unwrap(),
                "pw",
                Some("new@example.com"),
                None,
            )
            .await
            .unwrap();
        assert_eq!(session.did.as_str(), "did:plc:new");
        assert_eq!(
            events.lock().unwrap().clone(),
            vec![AtpSessionEvent::Create]
        );
    }

    #[tokio::test]
    async fn upsert_profile_creates_when_absent() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        // getRecord reports RecordNotFound (scripted here as HTTP 400).
        fetcher.script(
            "com.atproto.repo.getRecord",
            vec![ScriptedResponse {
                status: 400,
                body: br#"{"error":"RecordNotFound","message":"no such record"}"#.to_vec(),
            }],
        );
        fetcher.script(
            "com.atproto.repo.putRecord",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"uri":"at://did:plc:u/app.bsky.actor.profile/self","cid":"bafy"}"#
                    .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();

        let result = agent
            .upsert_profile(|prev| {
                assert!(prev.is_null(), "no existing profile");
                serde_json::json!({"$type": "app.bsky.actor.profile", "displayName": "Alice"})
            })
            .await
            .unwrap();
        assert_eq!(result["uri"], "at://did:plc:u/app.bsky.actor.profile/self");
    }

    #[tokio::test]
    async fn emits_expired_when_refresh_itself_401s() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"refresh expired"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();

        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));

        let _ = agent.refresh_session().await.unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Expired]);
        assert!(
            agent.session().await.is_none(),
            "session cleared on expired refresh"
        );
    }
}