1use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use proto_blue_xrpc::{CallOptions, HeadersMap, QueryParams, QueryValue, XrpcBody, XrpcClient};
10
11use crate::rich_text::RichText;
12
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct Session {
17 pub did: String,
18 pub handle: String,
19 pub access_jwt: String,
20 pub refresh_jwt: String,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub email: Option<String>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub email_confirmed: Option<bool>,
25}
26
27#[derive(Debug, thiserror::Error)]
29pub enum AgentError {
30 #[error("XRPC error: {0}")]
31 Xrpc(#[from] proto_blue_xrpc::Error),
32 #[error("Not authenticated")]
33 NotAuthenticated,
34 #[error("JSON error: {0}")]
35 Json(#[from] serde_json::Error),
36 #[error("{0}")]
37 Other(String),
38}
39
40pub struct Agent {
47 client: XrpcClient,
48 session: Arc<RwLock<Option<Session>>>,
49}
50
51impl Agent {
52 pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
54 let client = XrpcClient::new(service)?;
55 Ok(Agent {
56 client,
57 session: Arc::new(RwLock::new(None)),
58 })
59 }
60
61 pub fn service(&self) -> String {
63 self.client.service_url().to_string()
64 }
65
66 pub async fn did(&self) -> Option<String> {
68 self.session.read().await.as_ref().map(|s| s.did.clone())
69 }
70
71 pub async fn session(&self) -> Option<Session> {
73 self.session.read().await.clone()
74 }
75
76 async fn auth_call_options(&self) -> Option<CallOptions> {
81 let guard = self.session.read().await;
82 guard.as_ref().map(|s| {
83 let mut headers = HeadersMap::new();
84 headers.insert("Authorization".into(), format!("Bearer {}", s.access_jwt));
85 CallOptions {
86 encoding: None,
87 headers: Some(headers),
88 }
89 })
90 }
91
92 pub async fn login(&self, identifier: &str, password: &str) -> Result<Session, AgentError> {
94 let body = serde_json::json!({
95 "identifier": identifier,
96 "password": password,
97 });
98
99 let response = self
100 .client
101 .procedure(
102 "com.atproto.server.createSession",
103 None,
104 Some(XrpcBody::Json(body)),
105 None,
106 )
107 .await?;
108
109 let session: Session = serde_json::from_value(response.data)?;
110
111 *self.session.write().await = Some(session.clone());
113 Ok(session)
114 }
115
116 pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
121 let mut headers = HeadersMap::new();
124 headers.insert(
125 "Authorization".into(),
126 format!("Bearer {}", session.access_jwt),
127 );
128 let opts = CallOptions {
129 encoding: None,
130 headers: Some(headers),
131 };
132 let response = self
133 .client
134 .query("com.atproto.server.getSession", None, Some(&opts))
135 .await?;
136 let verified_did = response
137 .data
138 .get("did")
139 .and_then(|v| v.as_str())
140 .map(|s| s.to_string());
141
142 let mut committed = session;
144 if let Some(did) = verified_did {
145 committed.did = did;
146 }
147 *self.session.write().await = Some(committed);
148
149 Ok(())
150 }
151
152 pub async fn refresh_session(&self) -> Result<Session, AgentError> {
158 let refresh_jwt = {
159 let sess = self.session.read().await;
160 let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
161 sess.refresh_jwt.clone()
162 };
163
164 let mut headers = HeadersMap::new();
166 headers.insert("Authorization".into(), format!("Bearer {}", refresh_jwt));
167 let opts = CallOptions {
168 encoding: None,
169 headers: Some(headers),
170 };
171
172 let response = self
173 .client
174 .procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
175 .await?;
176
177 let session: Session = serde_json::from_value(response.data)?;
178
179 *self.session.write().await = Some(session.clone());
181 Ok(session)
182 }
183
184 async fn assert_did(&self) -> Result<String, AgentError> {
188 self.did().await.ok_or(AgentError::NotAuthenticated)
189 }
190
191 async fn xrpc_query(
193 &self,
194 nsid: &str,
195 params: Option<&QueryParams>,
196 ) -> Result<serde_json::Value, AgentError> {
197 let opts = self.auth_call_options().await;
198 let response = self.client.query(nsid, params, opts.as_ref()).await?;
199 Ok(response.data)
200 }
201
202 async fn xrpc_procedure(
204 &self,
205 nsid: &str,
206 body: serde_json::Value,
207 ) -> Result<serde_json::Value, AgentError> {
208 let opts = self.auth_call_options().await;
209 let response = self
210 .client
211 .procedure(nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
212 .await?;
213 Ok(response.data)
214 }
215
216 async fn create_record(
218 &self,
219 collection: &str,
220 record: serde_json::Value,
221 ) -> Result<serde_json::Value, AgentError> {
222 let did = self.assert_did().await?;
223 let body = serde_json::json!({
224 "repo": did,
225 "collection": collection,
226 "record": record,
227 });
228 self.xrpc_procedure("com.atproto.repo.createRecord", body)
229 .await
230 }
231
232 async fn delete_record(&self, collection: &str, uri: &str) -> Result<(), AgentError> {
234 let did = self.assert_did().await?;
235 let rkey = uri
236 .rsplit('/')
237 .next()
238 .ok_or_else(|| AgentError::Other("Invalid AT-URI".into()))?;
239
240 let body = serde_json::json!({
241 "repo": did,
242 "collection": collection,
243 "rkey": rkey,
244 });
245 self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
246 .await?;
247 Ok(())
248 }
249
250 fn now_iso() -> String {
252 chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
253 }
254
255 fn resolve_timestamp(created_at: Option<&str>) -> String {
257 created_at.map(String::from).unwrap_or_else(Self::now_iso)
258 }
259
260 pub async fn post(
266 &self,
267 text: &str,
268 facets: Option<Vec<crate::rich_text::Facet>>,
269 created_at: Option<&str>,
270 ) -> Result<serde_json::Value, AgentError> {
271 let mut record = serde_json::json!({
272 "$type": "app.bsky.feed.post",
273 "text": text,
274 "createdAt": Self::resolve_timestamp(created_at),
275 });
276
277 if let Some(facets) = facets {
278 record["facets"] = serde_json::to_value(&facets)?;
279 }
280
281 self.create_record("app.bsky.feed.post", record).await
282 }
283
284 pub async fn post_rich(
286 &self,
287 rt: &RichText,
288 created_at: Option<&str>,
289 ) -> Result<serde_json::Value, AgentError> {
290 let facets = if rt.facets().is_empty() {
291 None
292 } else {
293 Some(rt.facets().to_vec())
294 };
295 self.post(rt.text(), facets, created_at).await
296 }
297
298 pub async fn delete_post(&self, uri: &str) -> Result<(), AgentError> {
300 self.delete_record("app.bsky.feed.post", uri).await
301 }
302
303 pub async fn like(
309 &self,
310 uri: &str,
311 cid: &str,
312 created_at: Option<&str>,
313 ) -> Result<serde_json::Value, AgentError> {
314 let record = serde_json::json!({
315 "$type": "app.bsky.feed.like",
316 "subject": { "uri": uri, "cid": cid },
317 "createdAt": Self::resolve_timestamp(created_at),
318 });
319 self.create_record("app.bsky.feed.like", record).await
320 }
321
322 pub async fn delete_like(&self, like_uri: &str) -> Result<(), AgentError> {
324 self.delete_record("app.bsky.feed.like", like_uri).await
325 }
326
327 pub async fn repost(
331 &self,
332 uri: &str,
333 cid: &str,
334 created_at: Option<&str>,
335 ) -> Result<serde_json::Value, AgentError> {
336 let record = serde_json::json!({
337 "$type": "app.bsky.feed.repost",
338 "subject": { "uri": uri, "cid": cid },
339 "createdAt": Self::resolve_timestamp(created_at),
340 });
341 self.create_record("app.bsky.feed.repost", record).await
342 }
343
344 pub async fn delete_repost(&self, repost_uri: &str) -> Result<(), AgentError> {
346 self.delete_record("app.bsky.feed.repost", repost_uri).await
347 }
348
349 pub async fn follow(
355 &self,
356 subject_did: &str,
357 created_at: Option<&str>,
358 ) -> Result<serde_json::Value, AgentError> {
359 let record = serde_json::json!({
360 "$type": "app.bsky.graph.follow",
361 "subject": subject_did,
362 "createdAt": Self::resolve_timestamp(created_at),
363 });
364 self.create_record("app.bsky.graph.follow", record).await
365 }
366
367 pub async fn delete_follow(&self, follow_uri: &str) -> Result<(), AgentError> {
369 self.delete_record("app.bsky.graph.follow", follow_uri)
370 .await
371 }
372
373 pub async fn get_profile(&self, actor: &str) -> Result<serde_json::Value, AgentError> {
377 let mut params = QueryParams::new();
378 params.insert("actor".into(), QueryValue::String(actor.into()));
379 self.xrpc_query("app.bsky.actor.getProfile", Some(¶ms))
380 .await
381 }
382
383 pub async fn get_timeline(
385 &self,
386 limit: Option<i64>,
387 cursor: Option<&str>,
388 ) -> Result<serde_json::Value, AgentError> {
389 let mut params = QueryParams::new();
390 if let Some(limit) = limit {
391 params.insert("limit".into(), QueryValue::Integer(limit));
392 }
393 if let Some(cursor) = cursor {
394 params.insert("cursor".into(), QueryValue::String(cursor.into()));
395 }
396 self.xrpc_query("app.bsky.feed.getTimeline", Some(¶ms))
397 .await
398 }
399
400 pub async fn get_post_thread(
402 &self,
403 uri: &str,
404 depth: Option<i64>,
405 ) -> Result<serde_json::Value, AgentError> {
406 let mut params = QueryParams::new();
407 params.insert("uri".into(), QueryValue::String(uri.into()));
408 if let Some(depth) = depth {
409 params.insert("depth".into(), QueryValue::Integer(depth));
410 }
411 self.xrpc_query("app.bsky.feed.getPostThread", Some(¶ms))
412 .await
413 }
414
415 pub async fn search_actors(
417 &self,
418 query: &str,
419 limit: Option<i64>,
420 ) -> Result<serde_json::Value, AgentError> {
421 let mut params = QueryParams::new();
422 params.insert("q".into(), QueryValue::String(query.into()));
423 if let Some(limit) = limit {
424 params.insert("limit".into(), QueryValue::Integer(limit));
425 }
426 self.xrpc_query("app.bsky.actor.searchActors", Some(¶ms))
427 .await
428 }
429
430 pub async fn resolve_handle(&self, handle: &str) -> Result<String, AgentError> {
432 let mut params = QueryParams::new();
433 params.insert("handle".into(), QueryValue::String(handle.into()));
434 let data = self
435 .xrpc_query("com.atproto.identity.resolveHandle", Some(¶ms))
436 .await?;
437 data.get("did")
438 .and_then(|v| v.as_str())
439 .map(|s| s.to_string())
440 .ok_or_else(|| AgentError::Other("Missing DID in response".into()))
441 }
442
443 pub async fn list_notifications(
445 &self,
446 limit: Option<i64>,
447 cursor: Option<&str>,
448 ) -> Result<serde_json::Value, AgentError> {
449 let mut params = QueryParams::new();
450 if let Some(limit) = limit {
451 params.insert("limit".into(), QueryValue::Integer(limit));
452 }
453 if let Some(cursor) = cursor {
454 params.insert("cursor".into(), QueryValue::String(cursor.into()));
455 }
456 self.xrpc_query("app.bsky.notification.listNotifications", Some(¶ms))
457 .await
458 }
459
460 pub async fn upload_blob(
462 &self,
463 data: Vec<u8>,
464 content_type: &str,
465 ) -> Result<serde_json::Value, AgentError> {
466 let mut headers = HeadersMap::new();
467 headers.insert("Content-Type".into(), content_type.into());
468
469 if let Some(sess) = self.session.read().await.as_ref() {
471 headers.insert(
472 "Authorization".into(),
473 format!("Bearer {}", sess.access_jwt),
474 );
475 }
476
477 let opts = CallOptions {
478 encoding: Some(content_type.to_string()),
479 headers: Some(headers),
480 };
481
482 let response = self
483 .client
484 .procedure(
485 "com.atproto.repo.uploadBlob",
486 None,
487 Some(XrpcBody::Bytes(data)),
488 Some(&opts),
489 )
490 .await?;
491
492 Ok(response.data)
493 }
494
495 pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
497 self.xrpc_query("com.atproto.server.describeServer", None)
498 .await
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Agent pointed at the service URL every test below uses.
    fn bsky_agent() -> Agent {
        Agent::new("https://bsky.social").unwrap()
    }

    /// Checks the coarse shape shared by all generated timestamps.
    fn assert_utc_rfc3339(ts: &str) {
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn agent_creation() {
        let _agent = bsky_agent();
    }

    #[test]
    fn session_serde_roundtrip() {
        let original = Session {
            did: "did:plc:abc123".into(),
            handle: "alice.bsky.social".into(),
            access_jwt: "eyJ...".into(),
            refresh_jwt: "eyJ...".into(),
            email: Some("alice@example.com".into()),
            email_confirmed: Some(true),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Session = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.did, "did:plc:abc123");
        assert_eq!(decoded.handle, "alice.bsky.social");
        assert_eq!(decoded.email.as_deref(), Some("alice@example.com"));
    }

    #[test]
    fn agent_error_display() {
        let cases = [
            (AgentError::NotAuthenticated, "Not authenticated"),
            (AgentError::Other("test error".into()), "test error"),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[tokio::test]
    async fn agent_no_session_by_default() {
        let agent = bsky_agent();
        assert_eq!(agent.did().await, None);
        assert!(agent.session().await.is_none());
    }

    #[tokio::test]
    async fn agent_assert_did_fails_when_not_logged_in() {
        let agent = bsky_agent();
        let outcome = agent.assert_did().await;
        assert!(matches!(outcome, Err(AgentError::NotAuthenticated)));
    }

    #[test]
    fn now_iso_format() {
        assert_utc_rfc3339(&Agent::now_iso());
    }

    #[test]
    fn resolve_timestamp_with_provided() {
        let fixed = "2024-01-15T12:00:00.000Z";
        assert_eq!(Agent::resolve_timestamp(Some(fixed)), fixed);
    }

    #[test]
    fn resolve_timestamp_without_provided() {
        assert_utc_rfc3339(&Agent::resolve_timestamp(None));
    }

    #[test]
    fn service_url_accessible_without_async() {
        assert_eq!(bsky_agent().service(), "https://bsky.social/");
    }

    #[tokio::test]
    async fn auth_call_options_none_when_not_authenticated() {
        let agent = bsky_agent();
        assert!(agent.auth_call_options().await.is_none());
    }
}