use std::sync::{Arc, Mutex};
use tokio::sync::{Mutex as AsyncMutex, RwLock};
use proto_blue_lex_data::Cid;
use proto_blue_syntax::{AtIdentifier, AtUri, Did, Handle};
use proto_blue_xrpc::{
CallOptions, HeadersMap, QueryParams, QueryValue, ResponseType, XrpcBody, XrpcClient,
};
use crate::rich_text::RichText;
/// Session lifecycle notifications delivered to listeners registered with `Agent::on_session`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AtpSessionEvent {
    /// A session was established (login, account creation, or resume).
    Create,
    /// An attempt to establish a session failed.
    CreateFailed,
    /// The session tokens were refreshed.
    Update,
    /// The session ended: the refresh token was rejected, or the user logged out.
    Expired,
    /// A token refresh failed for a reason other than the server rejecting the refresh token.
    NetworkError,
}
/// Shared, thread-safe session listener: called with the event and, for events that carry one,
/// the session it concerns.
pub type SessionEventCallback = Arc<dyn Send + Sync + Fn(AtpSessionEvent, Option<&Session>)>;
/// An authenticated AT Protocol session.
///
/// Serialized with camelCase keys (`accessJwt`, `refreshJwt`, `emailConfirmed`, ...). A value
/// obtained from the agent can be stored and later passed back to `Agent::resume_session`.
///
/// `Debug` is implemented by hand so the bearer tokens never end up in logs or panic messages.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Session {
    pub did: Did,
    pub handle: Handle,
    pub access_jwt: String,
    pub refresh_jwt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email_confirmed: Option<bool>,
}

impl std::fmt::Debug for Session {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The JWTs are credentials: show that they are present, never their contents.
        f.debug_struct("Session")
            .field("did", &self.did)
            .field("handle", &self.handle)
            .field("access_jwt", &"<redacted>")
            .field("refresh_jwt", &"<redacted>")
            .field("email", &self.email)
            .field("email_confirmed", &self.email_confirmed)
            .finish()
    }
}
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
#[error("XRPC error: {0}")]
Xrpc(#[from] proto_blue_xrpc::Error),
#[error("Not authenticated")]
NotAuthenticated,
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
#[error("{0}")]
Other(String),
}
/// High-level AT Protocol client.
///
/// Wraps an [`XrpcClient`] with session management, automatic access-token refresh, and
/// optional proxy / labeler request headers. Mutable state lives behind `Arc`s so that a
/// derived agent (see `Agent::with_proxy`) can share the same session.
pub struct Agent {
    /// Transport used for every XRPC call.
    client: XrpcClient,
    /// The active session; `None` when logged out.
    session: Arc<RwLock<Option<Session>>>,
    /// Callbacks notified of session lifecycle events.
    listeners: Arc<Mutex<Vec<SessionEventCallback>>>,
    /// Held while an automatic token refresh is in progress.
    refresh_lock: Arc<AsyncMutex<()>>,
    /// Value sent in the `atproto-proxy` header, if configured.
    proxy: Arc<RwLock<Option<String>>>,
    /// Entries sent in the `atproto-accept-labelers` header.
    labelers: Arc<RwLock<Vec<LabelerOpts>>>,
}
/// One labeler entry for the `atproto-accept-labelers` request header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LabelerOpts {
    /// DID of the labeler.
    pub did: Did,
    /// When `true`, the header entry carries a `;redirect` parameter.
    pub redirect: bool,
}

impl LabelerOpts {
    /// Renders this entry as it appears in the header: the DID, suffixed with `;redirect`
    /// when `redirect` is set.
    fn header_value(&self) -> String {
        let mut entry = self.did.to_string();
        if self.redirect {
            entry.push_str(";redirect");
        }
        entry
    }
}
impl Agent {
#[cfg(any(
all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
target_arch = "wasm32",
))]
pub fn new(service: impl AsRef<str>) -> Result<Self, AgentError> {
let client = XrpcClient::new(service)?;
Ok(Self {
client,
session: Arc::new(RwLock::new(None)),
listeners: Arc::new(Mutex::new(Vec::new())),
refresh_lock: Arc::new(AsyncMutex::new(())),
proxy: Arc::new(RwLock::new(None)),
labelers: Arc::new(RwLock::new(Vec::new())),
})
}
pub fn on_session<F>(&self, callback: F)
where
F: Fn(AtpSessionEvent, Option<&Session>) + Send + Sync + 'static,
{
self.listeners.lock().unwrap().push(Arc::new(callback));
}
fn emit(&self, event: AtpSessionEvent, session: Option<&Session>) {
let listeners = self.listeners.lock().unwrap().clone();
for cb in listeners {
cb(event, session);
}
}
#[must_use]
pub fn service(&self) -> String {
self.client.service_url().to_string()
}
pub async fn did(&self) -> Option<Did> {
self.session.read().await.as_ref().map(|s| s.did.clone())
}
pub async fn session(&self) -> Option<Session> {
self.session.read().await.clone()
}
async fn auth_call_options(&self) -> Option<CallOptions> {
let guard = self.session.read().await;
let session = guard.as_ref()?;
let mut headers = HeadersMap::new();
headers.insert(
"Authorization".into(),
format!("Bearer {}", session.access_jwt),
);
self.inject_proxy_and_labelers(&mut headers).await;
Some(CallOptions {
encoding: None,
headers: Some(headers),
..Default::default()
})
}
pub async fn anon_call_options(&self) -> Option<CallOptions> {
let mut headers = HeadersMap::new();
self.inject_proxy_and_labelers(&mut headers).await;
if headers.is_empty() {
None
} else {
Some(CallOptions {
encoding: None,
headers: Some(headers),
..Default::default()
})
}
}
async fn inject_proxy_and_labelers(&self, headers: &mut HeadersMap) {
if let Some(proxy) = self.proxy.read().await.as_ref() {
headers.insert("atproto-proxy".into(), proxy.clone());
}
let labelers = self.labelers.read().await;
if !labelers.is_empty() {
let v = labelers
.iter()
.map(LabelerOpts::header_value)
.collect::<Vec<_>>()
.join(", ");
headers.insert("atproto-accept-labelers".into(), v);
}
}
pub async fn configure_proxy(&self, target: Option<&str>) {
*self.proxy.write().await = target.map(String::from);
}
pub async fn with_proxy(&self, target: &str) -> Self {
let cloned = self.shallow_clone();
cloned.configure_proxy(Some(target)).await;
cloned
}
pub async fn configure_labelers(&self, labelers: &[LabelerOpts]) {
*self.labelers.write().await = labelers.to_vec();
}
fn shallow_clone(&self) -> Self {
Self {
client: self.client.clone(),
session: self.session.clone(),
listeners: self.listeners.clone(),
refresh_lock: self.refresh_lock.clone(),
proxy: Arc::new(RwLock::new(None)),
labelers: self.labelers.clone(),
}
}
pub async fn login(
&self,
identifier: &AtIdentifier,
password: &str,
) -> Result<Session, AgentError> {
let body = serde_json::json!({
"identifier": identifier,
"password": password,
});
let response = match self
.client
.procedure(
"com.atproto.server.createSession",
None,
Some(XrpcBody::Json(body)),
None,
)
.await
{
Ok(r) => r,
Err(e) => {
self.emit(AtpSessionEvent::CreateFailed, None);
return Err(AgentError::Xrpc(e));
}
};
let session: Session = serde_json::from_value(response.data)?;
*self.session.write().await = Some(session.clone());
self.emit(AtpSessionEvent::Create, Some(&session));
Ok(session)
}
pub async fn resume_session(&self, session: Session) -> Result<(), AgentError> {
let mut headers = HeadersMap::new();
headers.insert(
"Authorization".into(),
format!("Bearer {}", session.access_jwt),
);
let opts = CallOptions {
encoding: None,
headers: Some(headers),
..Default::default()
};
let response = self
.client
.query("com.atproto.server.getSession", None, Some(&opts))
.await?;
let verified_did = response
.data
.get("did")
.and_then(|v| v.as_str())
.map(Did::new)
.transpose()
.map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))?;
let mut committed = session;
if let Some(did) = verified_did {
committed.did = did;
}
*self.session.write().await = Some(committed.clone());
self.emit(AtpSessionEvent::Create, Some(&committed));
Ok(())
}
pub async fn refresh_session(&self) -> Result<Session, AgentError> {
let refresh_jwt = {
let sess = self.session.read().await;
let sess = sess.as_ref().ok_or(AgentError::NotAuthenticated)?;
sess.refresh_jwt.clone()
};
let mut headers = HeadersMap::new();
headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
let opts = CallOptions {
encoding: None,
headers: Some(headers),
..Default::default()
};
let response = match self
.client
.procedure("com.atproto.server.refreshSession", None, None, Some(&opts))
.await
{
Ok(r) => r,
Err(e) => {
if is_refresh_rejected(&e) {
*self.session.write().await = None;
self.emit(AtpSessionEvent::Expired, None);
} else {
self.emit(AtpSessionEvent::NetworkError, None);
}
return Err(AgentError::Xrpc(e));
}
};
let session: Session = serde_json::from_value(response.data)?;
*self.session.write().await = Some(session.clone());
self.emit(AtpSessionEvent::Update, Some(&session));
Ok(session)
}
async fn assert_did(&self) -> Result<Did, AgentError> {
self.did().await.ok_or(AgentError::NotAuthenticated)
}
async fn xrpc_query(
&self,
nsid: &str,
params: Option<&QueryParams>,
) -> Result<serde_json::Value, AgentError> {
let opts = self.auth_call_options().await;
let first = self.client.query(nsid, params, opts.as_ref()).await;
match first {
Ok(r) => Ok(r.data),
Err(e) if is_auth_expired(&e) => {
self.refresh_and_retry(|opts| {
let c = self.client.clone();
let nsid = nsid.to_string();
let params = params.cloned();
async move { c.query(&nsid, params.as_ref(), opts.as_ref()).await }
})
.await
}
Err(e) => Err(AgentError::Xrpc(e)),
}
}
async fn xrpc_procedure(
&self,
nsid: &str,
body: serde_json::Value,
) -> Result<serde_json::Value, AgentError> {
let opts = self.auth_call_options().await;
let first = self
.client
.procedure(
nsid,
None,
Some(XrpcBody::Json(body.clone())),
opts.as_ref(),
)
.await;
match first {
Ok(r) => Ok(r.data),
Err(e) if is_auth_expired(&e) => {
self.refresh_and_retry(|opts| {
let c = self.client.clone();
let nsid = nsid.to_string();
let body = body.clone();
async move {
c.procedure(&nsid, None, Some(XrpcBody::Json(body)), opts.as_ref())
.await
}
})
.await
}
Err(e) => Err(AgentError::Xrpc(e)),
}
}
async fn refresh_and_retry<F, Fut>(&self, replay: F) -> Result<serde_json::Value, AgentError>
where
F: FnOnce(Option<CallOptions>) -> Fut,
Fut: std::future::Future<
Output = Result<proto_blue_xrpc::XrpcResponse, proto_blue_xrpc::Error>,
>,
{
let pre_refresh_jwt = self
.session
.read()
.await
.as_ref()
.map(|s| s.access_jwt.clone());
let _guard = self.refresh_lock.lock().await;
let current_jwt = self
.session
.read()
.await
.as_ref()
.map(|s| s.access_jwt.clone());
if pre_refresh_jwt == current_jwt {
self.refresh_session().await?;
}
drop(_guard);
let opts = self.auth_call_options().await;
let response = replay(opts).await?;
Ok(response.data)
}
async fn create_record(
&self,
collection: &str,
record: serde_json::Value,
) -> Result<serde_json::Value, AgentError> {
let did = self.assert_did().await?;
let body = serde_json::json!({
"repo": did,
"collection": collection,
"record": record,
});
self.xrpc_procedure("com.atproto.repo.createRecord", body)
.await
}
async fn delete_record(&self, collection: &str, uri: &AtUri) -> Result<(), AgentError> {
let did = self.assert_did().await?;
let rkey = uri
.rkey()
.ok_or_else(|| AgentError::Other("AT-URI has no rkey segment".into()))?;
let body = serde_json::json!({
"repo": did,
"collection": collection,
"rkey": rkey,
});
self.xrpc_procedure("com.atproto.repo.deleteRecord", body)
.await?;
Ok(())
}
fn now_iso() -> String {
chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
}
fn resolve_timestamp(created_at: Option<&str>) -> String {
created_at.map_or_else(Self::now_iso, String::from)
}
pub async fn post(
&self,
text: &str,
facets: Option<Vec<crate::rich_text::Facet>>,
created_at: Option<&str>,
) -> Result<serde_json::Value, AgentError> {
let mut record = serde_json::json!({
"$type": "app.bsky.feed.post",
"text": text,
"createdAt": Self::resolve_timestamp(created_at),
});
if let Some(facets) = facets {
record["facets"] = serde_json::to_value(&facets)?;
}
self.create_record("app.bsky.feed.post", record).await
}
pub async fn post_rich(
&self,
rt: &RichText,
created_at: Option<&str>,
) -> Result<serde_json::Value, AgentError> {
let facets = if rt.facets().is_empty() {
None
} else {
Some(rt.facets().to_vec())
};
self.post(rt.text(), facets, created_at).await
}
pub async fn delete_post(&self, uri: &AtUri) -> Result<(), AgentError> {
self.delete_record("app.bsky.feed.post", uri).await
}
pub async fn like(
&self,
uri: &AtUri,
cid: &Cid,
created_at: Option<&str>,
) -> Result<serde_json::Value, AgentError> {
let record = serde_json::json!({
"$type": "app.bsky.feed.like",
"subject": { "uri": uri, "cid": cid },
"createdAt": Self::resolve_timestamp(created_at),
});
self.create_record("app.bsky.feed.like", record).await
}
pub async fn delete_like(&self, like_uri: &AtUri) -> Result<(), AgentError> {
self.delete_record("app.bsky.feed.like", like_uri).await
}
pub async fn repost(
&self,
uri: &AtUri,
cid: &Cid,
created_at: Option<&str>,
) -> Result<serde_json::Value, AgentError> {
let record = serde_json::json!({
"$type": "app.bsky.feed.repost",
"subject": { "uri": uri, "cid": cid },
"createdAt": Self::resolve_timestamp(created_at),
});
self.create_record("app.bsky.feed.repost", record).await
}
pub async fn delete_repost(&self, repost_uri: &AtUri) -> Result<(), AgentError> {
self.delete_record("app.bsky.feed.repost", repost_uri).await
}
pub async fn follow(
&self,
subject_did: &Did,
created_at: Option<&str>,
) -> Result<serde_json::Value, AgentError> {
let record = serde_json::json!({
"$type": "app.bsky.graph.follow",
"subject": subject_did,
"createdAt": Self::resolve_timestamp(created_at),
});
self.create_record("app.bsky.graph.follow", record).await
}
pub async fn delete_follow(&self, follow_uri: &AtUri) -> Result<(), AgentError> {
self.delete_record("app.bsky.graph.follow", follow_uri)
.await
}
pub async fn get_profile(&self, actor: &AtIdentifier) -> Result<serde_json::Value, AgentError> {
let mut params = QueryParams::new();
params.insert("actor".into(), QueryValue::String(actor.to_string()));
self.xrpc_query("app.bsky.actor.getProfile", Some(¶ms))
.await
}
pub async fn get_timeline(
&self,
limit: Option<i64>,
cursor: Option<&str>,
) -> Result<serde_json::Value, AgentError> {
let mut params = QueryParams::new();
if let Some(limit) = limit {
params.insert("limit".into(), QueryValue::Integer(limit));
}
if let Some(cursor) = cursor {
params.insert("cursor".into(), QueryValue::String(cursor.into()));
}
self.xrpc_query("app.bsky.feed.getTimeline", Some(¶ms))
.await
}
pub async fn get_post_thread(
&self,
uri: &AtUri,
depth: Option<i64>,
) -> Result<serde_json::Value, AgentError> {
let mut params = QueryParams::new();
params.insert("uri".into(), QueryValue::String(uri.to_string()));
if let Some(depth) = depth {
params.insert("depth".into(), QueryValue::Integer(depth));
}
self.xrpc_query("app.bsky.feed.getPostThread", Some(¶ms))
.await
}
pub async fn search_actors(
&self,
query: &str,
limit: Option<i64>,
) -> Result<serde_json::Value, AgentError> {
let mut params = QueryParams::new();
params.insert("q".into(), QueryValue::String(query.into()));
if let Some(limit) = limit {
params.insert("limit".into(), QueryValue::Integer(limit));
}
self.xrpc_query("app.bsky.actor.searchActors", Some(¶ms))
.await
}
pub async fn resolve_handle(&self, handle: &Handle) -> Result<Did, AgentError> {
let mut params = QueryParams::new();
params.insert("handle".into(), QueryValue::String(handle.to_string()));
let data = self
.xrpc_query("com.atproto.identity.resolveHandle", Some(¶ms))
.await?;
let did_str = data
.get("did")
.and_then(|v| v.as_str())
.ok_or_else(|| AgentError::Other("Missing DID in response".into()))?;
Did::new(did_str)
.map_err(|e| AgentError::Other(format!("server returned invalid DID: {e}")))
}
pub async fn list_notifications(
&self,
limit: Option<i64>,
cursor: Option<&str>,
) -> Result<serde_json::Value, AgentError> {
let mut params = QueryParams::new();
if let Some(limit) = limit {
params.insert("limit".into(), QueryValue::Integer(limit));
}
if let Some(cursor) = cursor {
params.insert("cursor".into(), QueryValue::String(cursor.into()));
}
self.xrpc_query("app.bsky.notification.listNotifications", Some(¶ms))
.await
}
pub async fn upload_blob(
&self,
data: Vec<u8>,
content_type: &str,
) -> Result<serde_json::Value, AgentError> {
let mut headers = HeadersMap::new();
headers.insert("Content-Type".into(), content_type.into());
if let Some(sess) = self.session.read().await.as_ref() {
headers.insert(
"Authorization".into(),
format!("Bearer {}", sess.access_jwt),
);
}
let opts = CallOptions {
encoding: Some(content_type.to_string()),
headers: Some(headers),
..Default::default()
};
let response = self
.client
.procedure(
"com.atproto.repo.uploadBlob",
None,
Some(XrpcBody::Bytes(data)),
Some(&opts),
)
.await?;
Ok(response.data)
}
pub async fn describe_server(&self) -> Result<serde_json::Value, AgentError> {
self.xrpc_query("com.atproto.server.describeServer", None)
.await
}
pub async fn logout(&self) -> Result<(), AgentError> {
let refresh_jwt = {
let guard = self.session.read().await;
guard.as_ref().map(|s| s.refresh_jwt.clone())
};
let server_result = if let Some(refresh_jwt) = refresh_jwt {
let mut headers = HeadersMap::new();
headers.insert("Authorization".into(), format!("Bearer {refresh_jwt}"));
let opts = CallOptions {
encoding: None,
headers: Some(headers),
..Default::default()
};
self.client
.procedure("com.atproto.server.deleteSession", None, None, Some(&opts))
.await
.map(|_| ())
} else {
Ok(())
};
*self.session.write().await = None;
self.emit(AtpSessionEvent::Expired, None);
server_result.map_err(AgentError::Xrpc)
}
pub async fn create_account(
&self,
handle: &Handle,
password: &str,
email: Option<&str>,
extra: Option<serde_json::Value>,
) -> Result<Session, AgentError> {
let mut body = serde_json::json!({
"handle": handle,
"password": password,
});
if let Some(email) = email {
body["email"] = serde_json::Value::String(email.to_string());
}
if let Some(extra) = extra
&& let Some(extra_map) = extra.as_object()
&& let Some(body_map) = body.as_object_mut()
{
for (k, v) in extra_map {
body_map.insert(k.clone(), v.clone());
}
}
let response = match self
.client
.procedure(
"com.atproto.server.createAccount",
None,
Some(XrpcBody::Json(body)),
None,
)
.await
{
Ok(r) => r,
Err(e) => {
self.emit(AtpSessionEvent::CreateFailed, None);
return Err(AgentError::Xrpc(e));
}
};
let session: Session = serde_json::from_value(response.data)?;
*self.session.write().await = Some(session.clone());
self.emit(AtpSessionEvent::Create, Some(&session));
Ok(session)
}
pub async fn upsert_profile<F>(&self, mutate: F) -> Result<serde_json::Value, AgentError>
where
F: Fn(serde_json::Value) -> serde_json::Value,
{
let did = self.assert_did().await?;
const MAX_RETRIES: u32 = 5;
for _ in 0..MAX_RETRIES {
let existing_result = self
.xrpc_query(
"com.atproto.repo.getRecord",
Some(&{
let mut p = QueryParams::new();
p.insert("repo".into(), QueryValue::String(did.to_string()));
p.insert(
"collection".into(),
QueryValue::String("app.bsky.actor.profile".into()),
);
p.insert("rkey".into(), QueryValue::String("self".into()));
p
}),
)
.await;
let (existing_record, swap_cid) = match existing_result {
Ok(r) => {
let record = r.get("value").cloned().unwrap_or(serde_json::Value::Null);
let cid = r.get("cid").and_then(|v| v.as_str()).map(String::from);
(record, cid)
}
Err(AgentError::Xrpc(ref e)) if is_not_found(e) => (serde_json::Value::Null, None),
Err(e) => return Err(e),
};
let updated = mutate(existing_record);
let mut body = serde_json::json!({
"repo": did,
"collection": "app.bsky.actor.profile",
"rkey": "self",
"record": updated,
});
if let Some(cid) = swap_cid {
body["swapRecord"] = serde_json::Value::String(cid);
}
match self
.xrpc_procedure("com.atproto.repo.putRecord", body)
.await
{
Ok(r) => return Ok(r),
Err(AgentError::Xrpc(ref e)) if is_invalid_swap(e) => {
continue;
}
Err(e) => return Err(e),
}
}
Err(AgentError::Other(
"upsert_profile: exceeded maximum retries due to concurrent writes".into(),
))
}
}
fn is_not_found(err: &proto_blue_xrpc::Error) -> bool {
match err {
proto_blue_xrpc::Error::Xrpc(x) => x.is_error("RecordNotFound"),
_ => false,
}
}
fn is_invalid_swap(err: &proto_blue_xrpc::Error) -> bool {
match err {
proto_blue_xrpc::Error::Xrpc(x) => x.is_error("InvalidSwap"),
_ => false,
}
}
fn is_auth_expired(err: &proto_blue_xrpc::Error) -> bool {
match err {
proto_blue_xrpc::Error::Xrpc(x) => {
matches!(x.status, ResponseType::AuthenticationRequired) && x.is_error("ExpiredToken")
}
_ => false,
}
}
const fn is_refresh_rejected(err: &proto_blue_xrpc::Error) -> bool {
match err {
proto_blue_xrpc::Error::Xrpc(x) => {
matches!(x.status, ResponseType::AuthenticationRequired)
}
_ => false,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use proto_blue_common::fetch::{
        FetchError, FetchHandler, HttpHeaders, HttpRequest, HttpResponse,
    };

    // The tests below that call `Agent::new` carry the same cfg as `Agent::new` itself. Builds
    // without a default fetch backend can then still compile and run the rest of the suite.

    #[cfg(any(
        all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
        target_arch = "wasm32",
    ))]
    #[test]
    fn agent_creation() {
        let _agent = Agent::new("https://bsky.social").unwrap();
    }

    #[test]
    fn session_serde_roundtrip() {
        let session = Session {
            did: Did::new("did:plc:abc123").unwrap(),
            handle: Handle::new("alice.bsky.social").unwrap(),
            access_jwt: "eyJ...".to_string(),
            refresh_jwt: "eyJ...".to_string(),
            email: Some("alice@example.com".to_string()),
            email_confirmed: Some(true),
        };
        let json = serde_json::to_string(&session).unwrap();
        let parsed: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.did.as_str(), "did:plc:abc123");
        assert_eq!(parsed.handle.as_str(), "alice.bsky.social");
        assert_eq!(parsed.email, Some("alice@example.com".to_string()));
    }

    #[test]
    fn agent_error_display() {
        let err = AgentError::NotAuthenticated;
        assert_eq!(err.to_string(), "Not authenticated");
        let err = AgentError::Other("test error".into());
        assert_eq!(err.to_string(), "test error");
    }

    #[cfg(any(
        all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
        target_arch = "wasm32",
    ))]
    #[tokio::test]
    async fn agent_no_session_by_default() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.did().await.is_none());
        assert!(agent.session().await.is_none());
    }

    #[cfg(any(
        all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
        target_arch = "wasm32",
    ))]
    #[tokio::test]
    async fn agent_assert_did_fails_when_not_logged_in() {
        let agent = Agent::new("https://bsky.social").unwrap();
        let err = agent.assert_did().await.unwrap_err();
        assert!(matches!(err, AgentError::NotAuthenticated));
    }

    #[test]
    fn now_iso_format() {
        let ts = Agent::now_iso();
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn resolve_timestamp_with_provided() {
        let ts = Agent::resolve_timestamp(Some("2024-01-15T12:00:00.000Z"));
        assert_eq!(ts, "2024-01-15T12:00:00.000Z");
    }

    #[test]
    fn resolve_timestamp_without_provided() {
        let ts = Agent::resolve_timestamp(None);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[cfg(any(
        all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
        target_arch = "wasm32",
    ))]
    #[test]
    fn service_url_accessible_without_async() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert_eq!(agent.service(), "https://bsky.social/");
    }

    #[cfg(any(
        all(feature = "fetch-reqwest", not(target_arch = "wasm32")),
        target_arch = "wasm32",
    ))]
    #[tokio::test]
    async fn auth_call_options_none_when_not_authenticated() {
        let agent = Agent::new("https://bsky.social").unwrap();
        assert!(agent.auth_call_options().await.is_none());
    }

    /// Fake transport answering XRPC calls from per-method scripts.
    ///
    /// Each scripted method pops responses in order and keeps repeating the last one. Unscripted
    /// `createSession` calls return `createsession_body`; anything else is a fetch error.
    struct ScriptedFetcher {
        createsession_body: Vec<u8>,
        scripts: std::sync::Mutex<std::collections::HashMap<String, Vec<ScriptedResponse>>>,
        call_counts: std::sync::Mutex<std::collections::HashMap<String, usize>>,
    }

    #[derive(Clone)]
    struct ScriptedResponse {
        status: u16,
        body: Vec<u8>,
    }

    impl ScriptedFetcher {
        fn new(createsession_body: Vec<u8>) -> Self {
            Self {
                createsession_body,
                scripts: Default::default(),
                call_counts: Default::default(),
            }
        }

        /// Installs the response sequence for the XRPC method `path`.
        fn script(&self, path: &str, responses: Vec<ScriptedResponse>) {
            self.scripts
                .lock()
                .unwrap()
                .insert(path.to_string(), responses);
        }

        /// Number of requests seen for the XRPC method `path`.
        fn call_count(&self, path: &str) -> usize {
            *self.call_counts.lock().unwrap().get(path).unwrap_or(&0)
        }

        /// Next scripted response for `key`, or `None` if nothing (or an empty list) is scripted.
        fn next_scripted(&self, key: &str) -> Option<ScriptedResponse> {
            let mut scripts = self.scripts.lock().unwrap();
            let list = scripts.get_mut(key)?;
            match list.len() {
                0 => None,
                1 => Some(list[0].clone()),
                _ => Some(list.remove(0)),
            }
        }
    }

    /// Extracts the XRPC method name from a request URL (`.../xrpc/<nsid>?query`).
    fn xrpc_method(url: &str) -> String {
        url.split("/xrpc/")
            .nth(1)
            .unwrap_or(url)
            .split('?')
            .next()
            .unwrap_or("")
            .to_string()
    }

    /// JSON-typed response with the given status and body.
    fn json_response(status: u16, body: Vec<u8>) -> HttpResponse {
        let mut headers = HttpHeaders::new();
        headers.insert("content-type".into(), "application/json".into());
        HttpResponse {
            status,
            headers,
            body,
        }
    }

    #[async_trait]
    impl FetchHandler for ScriptedFetcher {
        async fn fetch(&self, req: HttpRequest) -> Result<HttpResponse, FetchError> {
            let key = xrpc_method(&req.url);
            *self
                .call_counts
                .lock()
                .unwrap()
                .entry(key.clone())
                .or_default() += 1;
            if let Some(resp) = self.next_scripted(&key) {
                return Ok(json_response(resp.status, resp.body));
            }
            if key == "com.atproto.server.createSession" {
                return Ok(json_response(200, self.createsession_body.clone()));
            }
            Err(FetchError::Other(format!("no script for {key}")))
        }
    }

    fn login_body() -> Vec<u8> {
        br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a1","refreshJwt":"r1"}"#.to_vec()
    }

    fn agent_with_fetcher(fetcher: Arc<ScriptedFetcher>) -> Agent {
        let client = XrpcClient::with_fetch_handler("https://example.com", fetcher).unwrap();
        Agent {
            client,
            session: Arc::new(RwLock::new(None)),
            listeners: Arc::new(Mutex::new(Vec::new())),
            refresh_lock: Arc::new(AsyncMutex::new(())),
            proxy: Arc::new(RwLock::new(None)),
            labelers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    #[tokio::test]
    async fn emits_create_on_successful_login() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);
        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Create]);
    }

    #[tokio::test]
    async fn emits_create_failed_on_login_rejection() {
        let fetcher = Arc::new(ScriptedFetcher::new(vec![]));
        fetcher.script(
            "com.atproto.server.createSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"bad pwd"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
        let _ = agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "bad")
            .await
            .unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::CreateFailed]);
    }

    #[tokio::test]
    async fn auto_refreshes_on_expired_access_token() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
        let result = agent.describe_server().await.unwrap();
        assert_eq!(result["did"], "did:plc:svr");
        assert_eq!(fetcher.call_count("com.atproto.server.describeServer"), 2);
        assert_eq!(fetcher.call_count("com.atproto.server.refreshSession"), 1);
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Update]);
    }

    #[tokio::test]
    async fn concurrent_expired_token_refreshes_once() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![
                ScriptedResponse {
                    status: 401,
                    body: br#"{"error":"ExpiredToken","message":"expired"}"#.to_vec(),
                },
                ScriptedResponse {
                    status: 200,
                    body: br#"{"did":"did:plc:svr"}"#.to_vec(),
                },
            ],
        );
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:u","handle":"alice.test","accessJwt":"a2","refreshJwt":"r2"}"#
                    .to_vec(),
            }],
        );
        let agent = Arc::new(agent_with_fetcher(fetcher.clone()));
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        let mut handles = Vec::new();
        for _ in 0..5 {
            let a = agent.clone();
            handles.push(tokio::spawn(async move {
                a.describe_server().await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        assert_eq!(
            fetcher.call_count("com.atproto.server.refreshSession"),
            1,
            "concurrent callers must share one refreshSession call",
        );
    }

    // NOTE: `ScriptedFetcher` does not record request headers, so this only checks that the
    // proxy target is stored and that a call still succeeds with it configured.
    #[tokio::test]
    async fn configure_proxy_sets_header_on_next_call() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.describeServer",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"did":"did:plc:svr"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .configure_proxy(Some("did:web:api.bsky.chat#bsky_chat"))
            .await;
        agent.describe_server().await.unwrap();
        let p = agent.proxy.read().await;
        assert_eq!(p.as_deref(), Some("did:web:api.bsky.chat#bsky_chat"));
    }

    #[tokio::test]
    async fn configure_labelers_stores_list() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        let agent = agent_with_fetcher(fetcher);
        agent
            .configure_labelers(&[
                LabelerOpts {
                    did: Did::new("did:plc:a").unwrap(),
                    redirect: false,
                },
                LabelerOpts {
                    did: Did::new("did:plc:b").unwrap(),
                    redirect: true,
                },
            ])
            .await;
        let l = agent.labelers.read().await;
        assert_eq!(l.len(), 2);
        assert_eq!(l[0].header_value(), "did:plc:a");
        assert_eq!(l[1].header_value(), "did:plc:b;redirect");
    }

    #[tokio::test]
    async fn logout_clears_session() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 200,
                body: b"{}".to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher.clone());
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        assert!(agent.session().await.is_some());
        agent.logout().await.unwrap();
        assert!(agent.session().await.is_none());
        assert_eq!(fetcher.call_count("com.atproto.server.deleteSession"), 1);
    }

    #[tokio::test]
    async fn logout_clears_session_even_on_server_error() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.deleteSession",
            vec![ScriptedResponse {
                status: 500,
                body: br#"{"error":"InternalServerError"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        let _ = agent.logout().await;
        assert!(agent.session().await.is_none());
    }

    #[tokio::test]
    async fn create_account_emits_create_on_success() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.createAccount",
            vec![ScriptedResponse {
                status: 200,
                body:
                    br#"{"did":"did:plc:new","handle":"newuser.test","accessJwt":"a","refreshJwt":"r"}"#
                        .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev = events.clone();
        agent.on_session(move |e, _| ev.lock().unwrap().push(e));
        let session = agent
            .create_account(
                &Handle::new("newuser.test").unwrap(),
                "pw",
                Some("new@example.com"),
                None,
            )
            .await
            .unwrap();
        assert_eq!(session.did.as_str(), "did:plc:new");
        assert_eq!(
            events.lock().unwrap().clone(),
            vec![AtpSessionEvent::Create]
        );
    }

    #[tokio::test]
    async fn upsert_profile_creates_when_absent() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.repo.getRecord",
            vec![ScriptedResponse {
                status: 400,
                body: br#"{"error":"RecordNotFound","message":"no such record"}"#.to_vec(),
            }],
        );
        fetcher.script(
            "com.atproto.repo.putRecord",
            vec![ScriptedResponse {
                status: 200,
                body: br#"{"uri":"at://did:plc:u/app.bsky.actor.profile/self","cid":"bafy"}"#
                    .to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        let result = agent
            .upsert_profile(|prev| {
                assert!(prev.is_null(), "no existing profile");
                serde_json::json!({"$type": "app.bsky.actor.profile", "displayName": "Alice"})
            })
            .await
            .unwrap();
        assert_eq!(result["uri"], "at://did:plc:u/app.bsky.actor.profile/self");
    }

    #[tokio::test]
    async fn emits_expired_when_refresh_itself_401s() {
        let fetcher = Arc::new(ScriptedFetcher::new(login_body()));
        fetcher.script(
            "com.atproto.server.refreshSession",
            vec![ScriptedResponse {
                status: 401,
                body: br#"{"error":"AuthenticationRequired","message":"refresh expired"}"#.to_vec(),
            }],
        );
        let agent = agent_with_fetcher(fetcher);
        agent
            .login(&AtIdentifier::new("alice.test").unwrap(), "secret")
            .await
            .unwrap();
        let events: Arc<Mutex<Vec<AtpSessionEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let ev_clone = events.clone();
        agent.on_session(move |e, _| ev_clone.lock().unwrap().push(e));
        let _ = agent.refresh_session().await.unwrap_err();
        let got = events.lock().unwrap().clone();
        assert_eq!(got, vec![AtpSessionEvent::Expired]);
        assert!(
            agent.session().await.is_none(),
            "session cleared on expired refresh"
        );
    }
}