use super::store::MemorySessionStore;
use crate::{
DpopClient,
store::{session::SessionStore, session_registry::SessionRegistry},
};
use atrium_api::{
agent::{
CloneWithProxy, Configure,
utils::{SessionClient, SessionWithEndpointStore},
},
did_doc::DidDocument,
types::string::{Did, Handle},
};
use atrium_common::resolver::Resolver;
use atrium_xrpc::{
Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
http::{Request, Response},
};
use serde::{Serialize, de::DeserializeOwned};
use std::{fmt::Debug, sync::Arc};
pub struct Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
{
inner: SessionClient<MemorySessionStore, DpopClient<T>, String>,
store: Arc<SessionWithEndpointStore<MemorySessionStore, String>>,
sub: Did,
session_registry: Arc<SessionRegistry<S, T, D, H>>,
}
impl<S, T, D, H> Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
{
pub fn new(
store: Arc<SessionWithEndpointStore<MemorySessionStore, String>>,
xrpc: DpopClient<T>,
sub: Did,
session_registry: Arc<SessionRegistry<S, T, D, H>>,
) -> Self {
let inner = SessionClient::new(Arc::clone(&store), xrpc);
Self { inner, store, sub, session_registry }
}
}
impl<S, T, D, H> Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
{
fn is_invalid_token_response<O, E>(result: &Result<OutputDataOrBytes<O>, Error<E>>) -> bool
where
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
{
match result {
Err(Error::Authentication(value)) => value
.to_str()
.is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
_ => false,
}
}
async fn refresh_token(&self) {
if let Ok(session) = self.session_registry.get(&self.sub, true).await {
let _ = self.store.set(session.token_set.access_token.clone()).await;
}
}
}
impl<S, T, D, H> HttpClient for Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
D: Send + Sync,
H: Send + Sync,
{
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
self.inner.send_http(request).await
}
}
impl<S, T, D, H> XrpcClient for Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
{
fn base_uri(&self) -> String {
self.inner.base_uri()
}
async fn send_xrpc<P, I, O, E>(
&self,
request: &XrpcRequest<P, I>,
) -> Result<OutputDataOrBytes<O>, Error<E>>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
{
let result = self.inner.send_xrpc(request).await;
if Self::is_invalid_token_response(&result) {
self.refresh_token().await;
self.inner.send_xrpc(request).await
} else {
result
}
}
}
impl<S, T, D, H> Configure for Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
{
fn configure_endpoint(&self, endpoint: String) {
self.inner.configure_endpoint(endpoint)
}
fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
self.inner.configure_labelers_header(labeler_dids)
}
fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
self.inner.configure_proxy_header(did, service_type)
}
}
impl<S, T, D, H> CloneWithProxy for Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
SessionClient<S, T, String>: CloneWithProxy,
{
fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
Self {
inner: self.inner.clone_with_proxy(did, service_type),
store: Arc::clone(&self.store),
sub: self.sub.clone(),
session_registry: Arc::clone(&self.session_registry),
}
}
}