atrium_api/agent/
utils.rs

1//! Utilities for managing sessions and endpoints.
2
3use super::{AuthorizationProvider, CloneWithProxy, Configure};
4use crate::{did_doc::DidDocument, types::string::Did};
5use atrium_common::store::Store;
6use atrium_xrpc::{types::AuthorizationToken, HttpClient, XrpcClient};
7use http::{Request, Response};
8use std::{
9    marker::PhantomData,
10    sync::{Arc, RwLock},
11};
12
13/// A client that maintains session data and manages endpoints and XRPC headers.  
14///
15/// It is recommended to use this struct internally in higher-level clients such as [`XrpcClient`], which can automatically update tokens.
16pub struct SessionClient<S, T, U> {
17    store: Arc<SessionWithEndpointStore<S, U>>,
18    proxy_header: RwLock<Option<String>>,
19    labelers_header: Arc<RwLock<Option<Vec<String>>>>,
20    inner: Arc<T>,
21}
22
23impl<S, T, U> SessionClient<S, T, U> {
24    pub fn new(store: Arc<SessionWithEndpointStore<S, U>>, http_client: T) -> Self {
25        Self {
26            store: Arc::clone(&store),
27            labelers_header: Arc::new(RwLock::new(None)),
28            proxy_header: RwLock::new(None),
29            inner: Arc::new(http_client),
30        }
31    }
32}
33
34impl<S, T, U> Configure for SessionClient<S, T, U> {
35    fn configure_endpoint(&self, endpoint: String) {
36        *self.store.endpoint.write().expect("failed to write endpoint") = endpoint;
37    }
38    fn configure_labelers_header(&self, labelers_dids: Option<Vec<(Did, bool)>>) {
39        *self.labelers_header.write().expect("failed to write labelers header") =
40            labelers_dids.map(|dids| {
41                dids.iter()
42                    .map(|(did, redact)| {
43                        if *redact {
44                            format!("{};redact", did.as_ref())
45                        } else {
46                            did.as_ref().into()
47                        }
48                    })
49                    .collect()
50            })
51    }
52    fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
53        self.proxy_header.write().expect("failed to write proxy header").replace(format!(
54            "{}#{}",
55            did.as_ref(),
56            service_type.as_ref()
57        ));
58    }
59}
60
61impl<S, T, U> CloneWithProxy for SessionClient<S, T, U> {
62    fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
63        let cloned = self.clone();
64        cloned.configure_proxy_header(did, service_type);
65        cloned
66    }
67}
68
69impl<S, T, U> Clone for SessionClient<S, T, U> {
70    fn clone(&self) -> Self {
71        Self {
72            store: self.store.clone(),
73            labelers_header: self.labelers_header.clone(),
74            proxy_header: RwLock::new(
75                self.proxy_header.read().expect("failed to read proxy header").clone(),
76            ),
77            inner: self.inner.clone(),
78        }
79    }
80}
81
82impl<S, T, U> HttpClient for SessionClient<S, T, U>
83where
84    S: Store<(), U> + Send + Sync,
85    T: HttpClient + Send + Sync,
86    U: Clone + Send + Sync,
87{
88    async fn send_http(
89        &self,
90        request: Request<Vec<u8>>,
91    ) -> core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>
92    {
93        self.inner.send_http(request).await
94    }
95}
96
97impl<S, T, U> XrpcClient for SessionClient<S, T, U>
98where
99    S: Store<(), U> + AuthorizationProvider + Send + Sync,
100    T: HttpClient + Send + Sync,
101    U: Clone + Send + Sync,
102{
103    fn base_uri(&self) -> String {
104        self.store.get_endpoint()
105    }
106    async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
107        self.store.authorization_token(is_refresh).await
108    }
109    async fn atproto_proxy_header(&self) -> Option<String> {
110        self.proxy_header.read().expect("failed to read proxy header").clone()
111    }
112    async fn atproto_accept_labelers_header(&self) -> Option<Vec<String>> {
113        self.labelers_header.read().expect("failed to read labelers header").clone()
114    }
115}
116
117/// A store that wraps an underlying store providing authorization token and adds endpoint management functionality.
118///
119/// This struct is intended to be used when creating a [`SessionClient`].
120pub struct SessionWithEndpointStore<S, U> {
121    inner: S,
122    pub endpoint: RwLock<String>,
123    _phantom: PhantomData<U>,
124}
125
126impl<S, U> SessionWithEndpointStore<S, U> {
127    pub fn new(inner: S, initial_endpoint: String) -> Self {
128        Self { inner, endpoint: RwLock::new(initial_endpoint), _phantom: PhantomData }
129    }
130    pub fn get_endpoint(&self) -> String {
131        self.endpoint.read().expect("failed to read endpoint").clone()
132    }
133    pub fn update_endpoint(&self, did_doc: &DidDocument) {
134        if let Some(endpoint) = did_doc.get_pds_endpoint() {
135            *self.endpoint.write().expect("failed to write endpoint") = endpoint;
136        }
137    }
138}
139
140impl<S, U> SessionWithEndpointStore<S, U>
141where
142    S: Store<(), U>,
143    U: Clone,
144{
145    pub async fn get(&self) -> Result<Option<U>, S::Error> {
146        self.inner.get(&()).await
147    }
148    pub async fn set(&self, value: U) -> Result<(), S::Error> {
149        self.inner.set((), value).await
150    }
151    pub async fn clear(&self) -> Result<(), S::Error> {
152        self.inner.clear().await
153    }
154}
155
156impl<S, U> AuthorizationProvider for SessionWithEndpointStore<S, U>
157where
158    S: Store<(), U> + AuthorizationProvider + Send + Sync,
159    U: Clone + Send + Sync,
160{
161    async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
162        self.inner.authorization_token(is_refresh).await
163    }
164}