mod inner;
pub mod store;
use self::store::SessionStore;
use crate::client::Service;
use atrium_xrpc::error::Error;
use atrium_xrpc::XrpcClient;
use std::sync::Arc;
/// Session data held by an [`AtpAgent`]: the output of `com.atproto.server.createSession`.
pub type Session = crate::com::atproto::server::create_session::Output;
/// An AT Protocol agent: a generated XRPC API client bundled with the session store it
/// authenticates from.
///
/// `S` persists the session; `T` performs the underlying XRPC requests.
pub struct AtpAgent<S: SessionStore + Send + Sync, T: XrpcClient + Send + Sync> {
    /// Session/endpoint store; the same `Arc` is handed to the client inside `api`.
    store: Arc<inner::Store<S>>,
    /// Generated API namespace (e.g. `api.com.atproto.server`) bound to this agent's store.
    pub api: Service<inner::Client<S, T>>,
}
impl<S, T> AtpAgent<S, T>
where
S: SessionStore + Send + Sync,
T: XrpcClient + Send + Sync,
{
pub fn new(xrpc: T, store: S) -> Self {
let store = Arc::new(inner::Store::new(store, xrpc.base_uri()));
let api = Service::new(Arc::new(inner::Client::new(Arc::clone(&store), xrpc)));
Self { store, api }
}
pub async fn login(
&self,
identifier: impl AsRef<str>,
password: impl AsRef<str>,
) -> Result<Session, Error<crate::com::atproto::server::create_session::Error>> {
let result = self
.api
.com
.atproto
.server
.create_session(crate::com::atproto::server::create_session::Input {
identifier: identifier.as_ref().into(),
password: password.as_ref().into(),
})
.await?;
self.store.set_session(result.clone()).await;
if let Some(did_doc) = &result.did_doc {
self.store.update_endpoint(did_doc);
}
Ok(result)
}
pub async fn resume_session(
&self,
session: Session,
) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
self.store.set_session(session.clone()).await;
let result = self.api.com.atproto.server.get_session().await;
match result {
Ok(output) => {
assert_eq!(output.did, session.did);
if let Some(mut session) = self.store.get_session().await {
session.did_doc = output.did_doc.clone();
session.email = output.email;
session.email_confirmed = output.email_confirmed;
session.handle = output.handle;
self.store.set_session(session).await;
}
if let Some(did_doc) = &output.did_doc {
self.store.update_endpoint(did_doc);
}
Ok(())
}
Err(err) => {
self.store.clear_session().await;
Err(err)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::store::MemorySessionStore;
    use crate::did_doc::{DidDocument, Service, VerificationMethod};
    use async_trait::async_trait;
    use atrium_xrpc::HttpClient;
    use http::{Request, Response};
    use std::collections::HashMap;
    use tokio::sync::RwLock;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Canned bodies served by [`DummyClient`]; `None` makes that endpoint answer 401.
    #[derive(Default)]
    struct DummyResponses {
        create_session: Option<crate::com::atproto::server::create_session::Output>,
        get_session: Option<crate::com::atproto::server::get_session::Output>,
    }

    /// In-memory HTTP client emulating a few `com.atproto.server.*` endpoints.
    ///
    /// `counts` records, per NSID, how many `/xrpc/` requests got past the expired-token check.
    #[derive(Default)]
    struct DummyClient {
        responses: DummyResponses,
        counts: Arc<RwLock<HashMap<String, usize>>>,
    }

    impl DummyClient {
        /// Builds a client serving `responses`, with fresh (empty) request counters.
        fn with_responses(responses: DummyResponses) -> Self {
            Self {
                responses,
                ..Default::default()
            }
        }
    }

    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl HttpClient for DummyClient {
        async fn send_http(
            &self,
            request: Request<Vec<u8>>,
        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
            // Short yield; presumably so concurrently spawned requests overlap (non-wasm only).
            #[cfg(not(target_arch = "wasm32"))]
            tokio::time::sleep(std::time::Duration::from_micros(10)).await;
            let builder =
                Response::builder().header(http::header::CONTENT_TYPE, "application/json");
            // The token is the last space-separated word of the header, e.g. `Bearer access`.
            let token = request
                .headers()
                .get(http::header::AUTHORIZATION)
                .and_then(|value| value.to_str().ok())
                .and_then(|value| value.split(' ').next_back());
            if token == Some("expired") {
                return Ok(builder.status(http::StatusCode::BAD_REQUEST).body(
                    serde_json::to_vec(&atrium_xrpc::error::ErrorResponseBody {
                        error: Some(String::from("ExpiredToken")),
                        message: Some(String::from("Token has expired")),
                    })?,
                )?);
            }
            let mut body = Vec::new();
            if let Some(nsid) = request.uri().path().strip_prefix("/xrpc/") {
                *self.counts.write().await.entry(nsid.into()).or_default() += 1;
                match nsid {
                    "com.atproto.server.createSession" => {
                        if let Some(output) = &self.responses.create_session {
                            body.extend(serde_json::to_vec(output)?);
                        }
                    }
                    "com.atproto.server.getSession" => {
                        if token == Some("access") {
                            if let Some(output) = &self.responses.get_session {
                                body.extend(serde_json::to_vec(output)?);
                            }
                        }
                    }
                    "com.atproto.server.refreshSession" => {
                        if token == Some("refresh") {
                            body.extend(serde_json::to_vec(
                                &crate::com::atproto::server::refresh_session::Output {
                                    access_jwt: String::from("access"),
                                    did: "did:web:example.com".parse().expect("valid"),
                                    did_doc: None,
                                    handle: "example.com".parse().expect("valid"),
                                    refresh_jwt: String::from("refresh"),
                                },
                            )?);
                        }
                    }
                    _ => {}
                }
            }
            if body.is_empty() {
                Ok(builder
                    .status(http::StatusCode::UNAUTHORIZED)
                    .body(serde_json::to_vec(
                        &atrium_xrpc::error::ErrorResponseBody {
                            error: Some(String::from("AuthenticationRequired")),
                            message: Some(String::from("Invalid identifier or password")),
                        },
                    )?)?)
            } else {
                Ok(builder.status(http::StatusCode::OK).body(body)?)
            }
        }
    }

    impl XrpcClient for DummyClient {
        fn base_uri(&self) -> String {
            "http://localhost:8080".into()
        }
    }

    /// A valid session whose tokens are accepted by [`DummyClient`].
    fn session() -> Session {
        Session {
            access_jwt: String::from("access"),
            did: "did:web:example.com".parse().expect("valid"),
            did_doc: None,
            email: None,
            email_confirmed: None,
            handle: "example.com".parse().expect("valid"),
            refresh_jwt: String::from("refresh"),
        }
    }

    /// The `getSession` response describing the same account as `session`.
    fn get_session_output(session: &Session) -> crate::com::atproto::server::get_session::Output {
        crate::com::atproto::server::get_session::Output {
            did: session.did.clone(),
            did_doc: session.did_doc.clone(),
            email: session.email.clone(),
            email_confirmed: session.email_confirmed,
            handle: session.handle.clone(),
        }
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_new() {
        let agent = AtpAgent::new(DummyClient::default(), MemorySessionStore::default());
        assert_eq!(agent.store.get_session().await, None);
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_login() {
        let session = session();
        // Successful login stores the returned session.
        {
            let client = DummyClient::with_responses(DummyResponses {
                create_session: Some(session.clone()),
                ..Default::default()
            });
            let agent = AtpAgent::new(client, MemorySessionStore::default());
            agent
                .login("test", "pass")
                .await
                .expect("login should be succeeded");
            assert_eq!(agent.store.get_session().await, Some(session));
        }
        // Failed login leaves the store empty.
        {
            let agent = AtpAgent::new(DummyClient::default(), MemorySessionStore::default());
            agent
                .login("test", "bad")
                .await
                .expect_err("login should be failed");
            assert_eq!(agent.store.get_session().await, None);
        }
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_xrpc_get_session() {
        let session = session();
        let client = DummyClient::with_responses(DummyResponses {
            get_session: Some(get_session_output(&session)),
            ..Default::default()
        });
        let agent = AtpAgent::new(client, MemorySessionStore::default());
        agent.store.set_session(session).await;
        let output = agent
            .api
            .com
            .atproto
            .server
            .get_session()
            .await
            .expect("get session should be succeeded");
        assert_eq!(output.did.as_str(), "did:web:example.com");
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_xrpc_get_session_with_refresh() {
        let mut session = session();
        session.access_jwt = String::from("expired");
        let client = DummyClient::with_responses(DummyResponses {
            get_session: Some(get_session_output(&session)),
            ..Default::default()
        });
        let agent = AtpAgent::new(client, MemorySessionStore::default());
        agent.store.set_session(session).await;
        let output = agent
            .api
            .com
            .atproto
            .server
            .get_session()
            .await
            .expect("get session should be succeeded");
        assert_eq!(output.did.as_str(), "did:web:example.com");
        assert_eq!(
            agent
                .store
                .get_session()
                .await
                .map(|session| session.access_jwt),
            Some("access".into())
        );
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_xrpc_get_session_with_duplicated_refresh() {
        let mut session = session();
        session.access_jwt = String::from("expired");
        let client = DummyClient::with_responses(DummyResponses {
            get_session: Some(get_session_output(&session)),
            ..Default::default()
        });
        let counts = Arc::clone(&client.counts);
        let agent = Arc::new(AtpAgent::new(client, MemorySessionStore::default()));
        agent.store.set_session(session).await;
        let handles = (0..3).map(|_| {
            let agent = Arc::clone(&agent);
            tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
        });
        let results = futures::future::join_all(handles).await;
        for result in &results {
            let output = result
                .as_ref()
                .expect("task should be successfully executed")
                .as_ref()
                .expect("get session should be succeeded");
            assert_eq!(output.did.as_str(), "did:web:example.com");
        }
        assert_eq!(
            agent
                .store
                .get_session()
                .await
                .map(|session| session.access_jwt),
            Some("access".into())
        );
        // Three concurrent expired requests must share a single refresh.
        assert_eq!(
            counts.read().await.clone(),
            HashMap::from_iter([
                ("com.atproto.server.refreshSession".into(), 1),
                ("com.atproto.server.getSession".into(), 3)
            ])
        );
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_resume_session() {
        let session = session();
        // Successful resume overwrites profile fields with the server's view.
        {
            let client = DummyClient::with_responses(DummyResponses {
                get_session: Some(get_session_output(&session)),
                ..Default::default()
            });
            let agent = AtpAgent::new(client, MemorySessionStore::default());
            assert_eq!(agent.store.get_session().await, None);
            agent
                .resume_session(Session {
                    email: Some(String::from("test@example.com")),
                    ..session.clone()
                })
                .await
                .expect("resume_session should be succeeded");
            assert_eq!(agent.store.get_session().await, Some(session.clone()));
        }
        // Failed resume clears the stored session.
        {
            let agent = AtpAgent::new(DummyClient::default(), MemorySessionStore::default());
            assert_eq!(agent.store.get_session().await, None);
            agent
                .resume_session(session)
                .await
                .expect_err("resume_session should be failed");
            assert_eq!(agent.store.get_session().await, None);
        }
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_resume_session_with_refresh() {
        let session = session();
        let client = DummyClient::with_responses(DummyResponses {
            get_session: Some(get_session_output(&session)),
            ..Default::default()
        });
        let agent = AtpAgent::new(client, MemorySessionStore::default());
        agent
            .resume_session(Session {
                access_jwt: "expired".into(),
                ..session.clone()
            })
            .await
            .expect("resume_session should be succeeded");
        assert_eq!(agent.store.get_session().await, Some(session));
    }

    #[tokio::test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    async fn test_login_with_diddoc() {
        let session = session();
        let did_doc = DidDocument {
            id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(),
            also_known_as: Some(vec!["at://atproto.com".into()]),
            verification_method: Some(vec![VerificationMethod {
                id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz#atproto".into(),
                r#type: "Multikey".into(),
                controller: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(),
                public_key_multibase: Some(
                    "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9pribSF".into(),
                ),
            }]),
            service: Some(vec![Service {
                id: "#atproto_pds".into(),
                r#type: "AtprotoPersonalDataServer".into(),
                service_endpoint: "https://bsky.social".into(),
            }]),
        };
        // A valid `#atproto_pds` service endpoint becomes the agent's endpoint.
        {
            let client = DummyClient::with_responses(DummyResponses {
                create_session: Some(Session {
                    did_doc: Some(did_doc.clone()),
                    ..session.clone()
                }),
                ..Default::default()
            });
            let agent = AtpAgent::new(client, MemorySessionStore::default());
            agent
                .login("test", "pass")
                .await
                .expect("login should be succeeded");
            assert_eq!(agent.store.get_endpoint(), "https://bsky.social");
            assert_eq!(
                agent.api.com.atproto.server.xrpc.base_uri(),
                "https://bsky.social"
            );
        }
        // An invalid `#atproto_pds` URL leaves the original endpoint in place.
        {
            let client = DummyClient::with_responses(DummyResponses {
                create_session: Some(Session {
                    did_doc: Some(DidDocument {
                        service: Some(vec![
                            Service {
                                id: "#pds".into(),
                                r#type: "AtprotoPersonalDataServer".into(),
                                service_endpoint: "https://bsky.social".into(),
                            },
                            Service {
                                id: "#atproto_pds".into(),
                                r#type: "AtprotoPersonalDataServer".into(),
                                service_endpoint: "htps://bsky.social".into(),
                            },
                        ]),
                        ..did_doc.clone()
                    }),
                    ..session.clone()
                }),
                ..Default::default()
            });
            let agent = AtpAgent::new(client, MemorySessionStore::default());
            agent
                .login("test", "pass")
                .await
                .expect("login should be succeeded");
            assert_eq!(agent.store.get_endpoint(), "http://localhost:8080");
            assert_eq!(
                agent.api.com.atproto.server.xrpc.base_uri(),
                "http://localhost:8080"
            );
        }
    }
}