Skip to main content

proto_blue_api/
agent.rs

1//! AT Protocol Agent — high-level client wrapping XRPC.
2//!
3//! Provides session management, convenience methods for common operations,
4//! and namespace accessors for the full Lexicon API surface.
5
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use proto_blue_xrpc::{CallOptions, HeadersMap, QueryParams, QueryValue, XrpcBody, XrpcClient};
10
11use crate::rich_text::RichText;
12
13/// Session data for an authenticated agent.
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct Session {
17    pub did: String,
18    pub handle: String,
19    pub access_jwt: String,
20    pub refresh_jwt: String,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub email: Option<String>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub email_confirmed: Option<bool>,
25}
26
27/// Errors from Agent operations.
28#[derive(Debug, thiserror::Error)]
29pub enum AgentError {
30    #[error("XRPC error: {0}")]
31    Xrpc(#[from] proto_blue_xrpc::Error),
32    #[error("Not authenticated")]
33    NotAuthenticated,
34    #[error("JSON error: {0}")]
35    Json(#[from] serde_json::Error),
36    #[error("{0}")]
37    Other(String),
38}
39
40/// High-level AT Protocol agent.
41///
42/// Auth state lives in a single `RwLock<Option<Session>>`. The XRPC client
43/// is never mutated after construction — auth headers are passed per-request.
44/// This avoids token leaks, giant-lock contention, and split-lock atomicity
45/// gaps that arise from storing auth in the client's default headers.
46pub struct Agent {
47    client: XrpcClient,
48    session: Arc<RwLock<Option<Session>>>,
49}
50
51impl Agent {
52    /// Create a new agent pointing at the given service URL.
53    pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
54        let client = XrpcClient::new(service)?;
55        Ok(Agent {
56            client,
57            session: Arc::new(RwLock::new(None)),
58        })
59    }
60
61    /// Get the service URL string.
62    pub fn service(&self) -> String {
63        self.client.service_url().to_string()
64    }
65
66    /// Get the current session's DID, if logged in.
67    pub async fn did(&self) -> Option<String> {
68        self.session.read().await.as_ref().map(|s| s.did.clone())
69    }
70
71    /// Get the current session, if any.
72    pub async fn session(&self) -> Option<Session> {
73        self.session.read().await.clone()
74    }
75
76    // --- Authentication ---
77
78    /// Build per-request `CallOptions` carrying the current access token.
79    /// Returns `None` if not authenticated.
80    async fn auth_call_options(&self) -> Option<CallOptions> {
81        let guard = self.session.read().await;
82        guard.as_ref().map(|s| {
83            let mut headers = HeadersMap::new();
84            headers.insert("Authorization".into(), format!("Bearer {}", s.access_jwt));
85            CallOptions {
86                encoding: None,
87                headers: Some(headers),
88            }
89        })
90    }
91
92    /// Log in with identifier (handle or DID) and password.
93    pub async fn login(&self, identifier: &str, password: &str) -> Result<Session, AgentError> {
94        let body = serde_json::json!({
95            "identifier": identifier,
96            "password": password,
97        });
98
99        let response = self
100            .client
101            .procedure(
102                "com.atproto.server.createSession",
103                None,
104                Some(XrpcBody::Json(body)),
105                None,
106            )
107            .await?;
108
109        let session: Session = serde_json::from_value(response.data)?;
110
111        // Atomically commit session in a single write lock
112        *self.session.write().await = Some(session.clone());
113        Ok(session)
114    }
115
116    /// Resume an existing session.
117    ///
118    /// Verifies the session with the server *before* updating internal state.
119    /// If verification fails, the agent remains unauthenticated.
120    pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
121        // Verify the session is valid by calling getSession with the provided token,
122        // WITHOUT updating the agent's state first. Use a per-request auth header.
123        let mut headers = HeadersMap::new();
124        headers.insert(
125            "Authorization".into(),
126            format!("Bearer {}", session.access_jwt),
127        );
128        let opts = CallOptions {
129            encoding: None,
130            headers: Some(headers),
131        };
132        let response = self
133            .client
134            .query("com.atproto.server.getSession", None, Some(&opts))
135            .await?;
136        let verified_did = response
137            .data
138            .get("did")
139            .and_then(|v| v.as_str())
140            .map(|s| s.to_string());
141
142        // Verification succeeded — atomically commit state in a single write lock
143        let mut committed = session;
144        if let Some(did) = verified_did {
145            committed.did = did;
146        }
147        *self.session.write().await = Some(committed);
148
149        Ok(())
150    }
151
152    /// Refresh the current session tokens.
153    ///
154    /// Uses a per-request header for the refresh call so the refresh JWT is
155    /// never exposed as the global auth state. The new session is committed
156    /// atomically in a single write lock.
157    pub async fn refresh_session(&self) -> Result<Session, AgentError> {
158        let refresh_jwt = {
159            let sess = self.session.read().await;
160            let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
161            sess.refresh_jwt.clone()
162        };
163
164        // Use per-request header for refresh — never mutate global auth state
165        let mut headers = HeadersMap::new();
166        headers.insert("Authorization".into(), format!("Bearer {}", refresh_jwt));
167        let opts = CallOptions {
168            encoding: None,
169            headers: Some(headers),
170        };
171
172        let response = self
173            .client
174            .procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
175            .await?;
176
177        let session: Session = serde_json::from_value(response.data)?;
178
179        // Atomically commit new session in a single write lock
180        *self.session.write().await = Some(session.clone());
181        Ok(session)
182    }
183
184    // --- Convenience helpers ---
185
186    /// Ensure the agent is authenticated, returning the DID.
187    async fn assert_did(&self) -> Result<String, AgentError> {
188        self.did().await.ok_or(AgentError::NotAuthenticated)
189    }
190
191    /// Helper: make a query call.
192    async fn xrpc_query(
193        &self,
194        nsid: &str,
195        params: Option<&QueryParams>,
196    ) -> Result<serde_json::Value, AgentError> {
197        let opts = self.auth_call_options().await;
198        let response = self.client.query(nsid, params, opts.as_ref()).await?;
199        Ok(response.data)
200    }
201
202    /// Helper: make a procedure call with JSON body.
203    async fn xrpc_procedure(
204        &self,
205        nsid: &str,
206        body: serde_json::Value,
207    ) -> Result<serde_json::Value, AgentError> {
208        let opts = self.auth_call_options().await;
209        let response = self
210            .client
211            .procedure(nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
212            .await?;
213        Ok(response.data)
214    }
215
216    /// Helper: create a record.
217    async fn create_record(
218        &self,
219        collection: &str,
220        record: serde_json::Value,
221    ) -> Result<serde_json::Value, AgentError> {
222        let did = self.assert_did().await?;
223        let body = serde_json::json!({
224            "repo": did,
225            "collection": collection,
226            "record": record,
227        });
228        self.xrpc_procedure("com.atproto.repo.createRecord", body)
229            .await
230    }
231
232    /// Helper: delete a record by AT-URI.
233    async fn delete_record(&self, collection: &str, uri: &str) -> Result<(), AgentError> {
234        let did = self.assert_did().await?;
235        let rkey = uri
236            .rsplit('/')
237            .next()
238            .ok_or_else(|| AgentError::Other("Invalid AT-URI".into()))?;
239
240        let body = serde_json::json!({
241            "repo": did,
242            "collection": collection,
243            "rkey": rkey,
244        });
245        self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
246            .await?;
247        Ok(())
248    }
249
250    /// Generate an ISO 8601 timestamp with millisecond precision.
251    fn now_iso() -> String {
252        chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
253    }
254
255    /// Resolve a timestamp: use the provided value or generate one.
256    fn resolve_timestamp(created_at: Option<&str>) -> String {
257        created_at.map(String::from).unwrap_or_else(Self::now_iso)
258    }
259
260    // --- Post operations ---
261
262    /// Create a new post.
263    ///
264    /// If `created_at` is `None`, the current time is used.
265    pub async fn post(
266        &self,
267        text: &str,
268        facets: Option<Vec<crate::rich_text::Facet>>,
269        created_at: Option<&str>,
270    ) -> Result<serde_json::Value, AgentError> {
271        let mut record = serde_json::json!({
272            "$type": "app.bsky.feed.post",
273            "text": text,
274            "createdAt": Self::resolve_timestamp(created_at),
275        });
276
277        if let Some(facets) = facets {
278            record["facets"] = serde_json::to_value(&facets)?;
279        }
280
281        self.create_record("app.bsky.feed.post", record).await
282    }
283
284    /// Create a post from RichText (includes detected facets).
285    pub async fn post_rich(
286        &self,
287        rt: &RichText,
288        created_at: Option<&str>,
289    ) -> Result<serde_json::Value, AgentError> {
290        let facets = if rt.facets().is_empty() {
291            None
292        } else {
293            Some(rt.facets().to_vec())
294        };
295        self.post(rt.text(), facets, created_at).await
296    }
297
298    /// Delete a post by AT-URI.
299    pub async fn delete_post(&self, uri: &str) -> Result<(), AgentError> {
300        self.delete_record("app.bsky.feed.post", uri).await
301    }
302
303    // --- Like / Repost ---
304
305    /// Like a post.
306    ///
307    /// If `created_at` is `None`, the current time is used.
308    pub async fn like(
309        &self,
310        uri: &str,
311        cid: &str,
312        created_at: Option<&str>,
313    ) -> Result<serde_json::Value, AgentError> {
314        let record = serde_json::json!({
315            "$type": "app.bsky.feed.like",
316            "subject": { "uri": uri, "cid": cid },
317            "createdAt": Self::resolve_timestamp(created_at),
318        });
319        self.create_record("app.bsky.feed.like", record).await
320    }
321
322    /// Unlike a post by AT-URI of the like record.
323    pub async fn delete_like(&self, like_uri: &str) -> Result<(), AgentError> {
324        self.delete_record("app.bsky.feed.like", like_uri).await
325    }
326
327    /// Repost a post.
328    ///
329    /// If `created_at` is `None`, the current time is used.
330    pub async fn repost(
331        &self,
332        uri: &str,
333        cid: &str,
334        created_at: Option<&str>,
335    ) -> Result<serde_json::Value, AgentError> {
336        let record = serde_json::json!({
337            "$type": "app.bsky.feed.repost",
338            "subject": { "uri": uri, "cid": cid },
339            "createdAt": Self::resolve_timestamp(created_at),
340        });
341        self.create_record("app.bsky.feed.repost", record).await
342    }
343
344    /// Delete a repost by AT-URI.
345    pub async fn delete_repost(&self, repost_uri: &str) -> Result<(), AgentError> {
346        self.delete_record("app.bsky.feed.repost", repost_uri).await
347    }
348
349    // --- Follow ---
350
351    /// Follow a user by DID.
352    ///
353    /// If `created_at` is `None`, the current time is used.
354    pub async fn follow(
355        &self,
356        subject_did: &str,
357        created_at: Option<&str>,
358    ) -> Result<serde_json::Value, AgentError> {
359        let record = serde_json::json!({
360            "$type": "app.bsky.graph.follow",
361            "subject": subject_did,
362            "createdAt": Self::resolve_timestamp(created_at),
363        });
364        self.create_record("app.bsky.graph.follow", record).await
365    }
366
367    /// Unfollow by AT-URI of the follow record.
368    pub async fn delete_follow(&self, follow_uri: &str) -> Result<(), AgentError> {
369        self.delete_record("app.bsky.graph.follow", follow_uri)
370            .await
371    }
372
373    // --- Query helpers ---
374
375    /// Get a user's profile.
376    pub async fn get_profile(&self, actor: &str) -> Result<serde_json::Value, AgentError> {
377        let mut params = QueryParams::new();
378        params.insert("actor".into(), QueryValue::String(actor.into()));
379        self.xrpc_query("app.bsky.actor.getProfile", Some(&params))
380            .await
381    }
382
383    /// Get the home timeline.
384    pub async fn get_timeline(
385        &self,
386        limit: Option<i64>,
387        cursor: Option<&str>,
388    ) -> Result<serde_json::Value, AgentError> {
389        let mut params = QueryParams::new();
390        if let Some(limit) = limit {
391            params.insert("limit".into(), QueryValue::Integer(limit));
392        }
393        if let Some(cursor) = cursor {
394            params.insert("cursor".into(), QueryValue::String(cursor.into()));
395        }
396        self.xrpc_query("app.bsky.feed.getTimeline", Some(&params))
397            .await
398    }
399
400    /// Get a post thread.
401    pub async fn get_post_thread(
402        &self,
403        uri: &str,
404        depth: Option<i64>,
405    ) -> Result<serde_json::Value, AgentError> {
406        let mut params = QueryParams::new();
407        params.insert("uri".into(), QueryValue::String(uri.into()));
408        if let Some(depth) = depth {
409            params.insert("depth".into(), QueryValue::Integer(depth));
410        }
411        self.xrpc_query("app.bsky.feed.getPostThread", Some(&params))
412            .await
413    }
414
415    /// Search actors.
416    pub async fn search_actors(
417        &self,
418        query: &str,
419        limit: Option<i64>,
420    ) -> Result<serde_json::Value, AgentError> {
421        let mut params = QueryParams::new();
422        params.insert("q".into(), QueryValue::String(query.into()));
423        if let Some(limit) = limit {
424            params.insert("limit".into(), QueryValue::Integer(limit));
425        }
426        self.xrpc_query("app.bsky.actor.searchActors", Some(&params))
427            .await
428    }
429
430    /// Resolve a handle to a DID.
431    pub async fn resolve_handle(&self, handle: &str) -> Result<String, AgentError> {
432        let mut params = QueryParams::new();
433        params.insert("handle".into(), QueryValue::String(handle.into()));
434        let data = self
435            .xrpc_query("com.atproto.identity.resolveHandle", Some(&params))
436            .await?;
437        data.get("did")
438            .and_then(|v| v.as_str())
439            .map(|s| s.to_string())
440            .ok_or_else(|| AgentError::Other("Missing DID in response".into()))
441    }
442
443    /// Get notifications.
444    pub async fn list_notifications(
445        &self,
446        limit: Option<i64>,
447        cursor: Option<&str>,
448    ) -> Result<serde_json::Value, AgentError> {
449        let mut params = QueryParams::new();
450        if let Some(limit) = limit {
451            params.insert("limit".into(), QueryValue::Integer(limit));
452        }
453        if let Some(cursor) = cursor {
454            params.insert("cursor".into(), QueryValue::String(cursor.into()));
455        }
456        self.xrpc_query("app.bsky.notification.listNotifications", Some(&params))
457            .await
458    }
459
460    /// Upload a blob (image, video, etc.).
461    pub async fn upload_blob(
462        &self,
463        data: Vec<u8>,
464        content_type: &str,
465    ) -> Result<serde_json::Value, AgentError> {
466        let mut headers = HeadersMap::new();
467        headers.insert("Content-Type".into(), content_type.into());
468
469        // Add auth header from session
470        if let Some(sess) = self.session.read().await.as_ref() {
471            headers.insert(
472                "Authorization".into(),
473                format!("Bearer {}", sess.access_jwt),
474            );
475        }
476
477        let opts = CallOptions {
478            encoding: Some(content_type.to_string()),
479            headers: Some(headers),
480        };
481
482        let response = self
483            .client
484            .procedure(
485                "com.atproto.repo.uploadBlob",
486                None,
487                Some(XrpcBody::Bytes(data)),
488                Some(&opts),
489            )
490            .await?;
491
492        Ok(response.data)
493    }
494
495    /// Describe the server.
496    pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
497        self.xrpc_query("com.atproto.server.describeServer", None)
498            .await
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    const SERVICE: &str = "https://bsky.social";

    /// Build an unauthenticated agent; panics if construction fails.
    fn fresh_agent() -> Agent {
        Agent::new(SERVICE).unwrap()
    }

    /// Check the rough shape of a generated RFC 3339 UTC timestamp.
    fn assert_iso_shape(ts: &str) {
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn agent_creation() {
        let _ = fresh_agent();
    }

    #[test]
    fn session_serde_roundtrip() {
        let original = Session {
            did: "did:plc:abc123".to_string(),
            handle: "alice.bsky.social".to_string(),
            access_jwt: "eyJ...".to_string(),
            refresh_jwt: "eyJ...".to_string(),
            email: Some("alice@example.com".to_string()),
            email_confirmed: Some(true),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Session = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.did, "did:plc:abc123");
        assert_eq!(decoded.handle, "alice.bsky.social");
        assert_eq!(decoded.email, Some("alice@example.com".to_string()));
    }

    #[test]
    fn agent_error_display() {
        let cases = [
            (AgentError::NotAuthenticated, "Not authenticated"),
            (AgentError::Other("test error".into()), "test error"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[tokio::test]
    async fn agent_no_session_by_default() {
        let agent = fresh_agent();
        assert!(agent.did().await.is_none());
        assert!(agent.session().await.is_none());
    }

    #[tokio::test]
    async fn agent_assert_did_fails_when_not_logged_in() {
        let outcome = fresh_agent().assert_did().await;
        assert!(matches!(outcome, Err(AgentError::NotAuthenticated)));
    }

    #[test]
    fn now_iso_format() {
        assert_iso_shape(&Agent::now_iso());
    }

    #[test]
    fn resolve_timestamp_with_provided() {
        let given = "2024-01-15T12:00:00.000Z";
        assert_eq!(Agent::resolve_timestamp(Some(given)), given);
    }

    #[test]
    fn resolve_timestamp_without_provided() {
        assert_iso_shape(&Agent::resolve_timestamp(None));
    }

    #[test]
    fn service_url_accessible_without_async() {
        assert_eq!(fresh_agent().service(), "https://bsky.social/");
    }

    #[tokio::test]
    async fn auth_call_options_none_when_not_authenticated() {
        let opts = fresh_agent().auth_call_options().await;
        assert!(opts.is_none());
    }
}