1use super::{AuthorizationProvider, CloneWithProxy, Configure};
4use crate::{did_doc::DidDocument, types::string::Did};
5use atrium_common::store::Store;
6use atrium_xrpc::{types::AuthorizationToken, HttpClient, XrpcClient};
7use http::{Request, Response};
8use std::{
9 marker::PhantomData,
10 sync::{Arc, RwLock},
11};
12
13pub struct SessionClient<S, T, U> {
17 store: Arc<SessionWithEndpointStore<S, U>>,
18 proxy_header: RwLock<Option<String>>,
19 labelers_header: Arc<RwLock<Option<Vec<String>>>>,
20 inner: Arc<T>,
21}
22
23impl<S, T, U> SessionClient<S, T, U> {
24 pub fn new(store: Arc<SessionWithEndpointStore<S, U>>, http_client: T) -> Self {
25 Self {
26 store: Arc::clone(&store),
27 labelers_header: Arc::new(RwLock::new(None)),
28 proxy_header: RwLock::new(None),
29 inner: Arc::new(http_client),
30 }
31 }
32}
33
34impl<S, T, U> Configure for SessionClient<S, T, U> {
35 fn configure_endpoint(&self, endpoint: String) {
36 *self.store.endpoint.write().expect("failed to write endpoint") = endpoint;
37 }
38 fn configure_labelers_header(&self, labelers_dids: Option<Vec<(Did, bool)>>) {
39 *self.labelers_header.write().expect("failed to write labelers header") =
40 labelers_dids.map(|dids| {
41 dids.iter()
42 .map(|(did, redact)| {
43 if *redact {
44 format!("{};redact", did.as_ref())
45 } else {
46 did.as_ref().into()
47 }
48 })
49 .collect()
50 })
51 }
52 fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
53 self.proxy_header.write().expect("failed to write proxy header").replace(format!(
54 "{}#{}",
55 did.as_ref(),
56 service_type.as_ref()
57 ));
58 }
59}
60
61impl<S, T, U> CloneWithProxy for SessionClient<S, T, U> {
62 fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
63 let cloned = self.clone();
64 cloned.configure_proxy_header(did, service_type);
65 cloned
66 }
67}
68
69impl<S, T, U> Clone for SessionClient<S, T, U> {
70 fn clone(&self) -> Self {
71 Self {
72 store: self.store.clone(),
73 labelers_header: self.labelers_header.clone(),
74 proxy_header: RwLock::new(
75 self.proxy_header.read().expect("failed to read proxy header").clone(),
76 ),
77 inner: self.inner.clone(),
78 }
79 }
80}
81
82impl<S, T, U> HttpClient for SessionClient<S, T, U>
83where
84 S: Store<(), U> + Send + Sync,
85 T: HttpClient + Send + Sync,
86 U: Clone + Send + Sync,
87{
88 async fn send_http(
89 &self,
90 request: Request<Vec<u8>>,
91 ) -> core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>
92 {
93 self.inner.send_http(request).await
94 }
95}
96
97impl<S, T, U> XrpcClient for SessionClient<S, T, U>
98where
99 S: Store<(), U> + AuthorizationProvider + Send + Sync,
100 T: HttpClient + Send + Sync,
101 U: Clone + Send + Sync,
102{
103 fn base_uri(&self) -> String {
104 self.store.get_endpoint()
105 }
106 async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
107 self.store.authorization_token(is_refresh).await
108 }
109 async fn atproto_proxy_header(&self) -> Option<String> {
110 self.proxy_header.read().expect("failed to read proxy header").clone()
111 }
112 async fn atproto_accept_labelers_header(&self) -> Option<Vec<String>> {
113 self.labelers_header.read().expect("failed to read labelers header").clone()
114 }
115}
116
/// Wraps an inner store `S` together with a mutable endpoint string.
///
/// `U` is the value type associated with the store; it is only tracked at the
/// type level through `PhantomData` and is not stored in this struct.
pub struct SessionWithEndpointStore<S, U> {
    /// The wrapped store that the methods of this type delegate to.
    inner: S,
    /// Current endpoint, guarded by a lock so it can be changed through `&self`.
    pub endpoint: RwLock<String>,
    /// Marker tying the otherwise unused parameter `U` to this type.
    _phantom: PhantomData<U>,
}
125
126impl<S, U> SessionWithEndpointStore<S, U> {
127 pub fn new(inner: S, initial_endpoint: String) -> Self {
128 Self { inner, endpoint: RwLock::new(initial_endpoint), _phantom: PhantomData }
129 }
130 pub fn get_endpoint(&self) -> String {
131 self.endpoint.read().expect("failed to read endpoint").clone()
132 }
133 pub fn update_endpoint(&self, did_doc: &DidDocument) {
134 if let Some(endpoint) = did_doc.get_pds_endpoint() {
135 *self.endpoint.write().expect("failed to write endpoint") = endpoint;
136 }
137 }
138}
139
140impl<S, U> SessionWithEndpointStore<S, U>
141where
142 S: Store<(), U>,
143 U: Clone,
144{
145 pub async fn get(&self) -> Result<Option<U>, S::Error> {
146 self.inner.get(&()).await
147 }
148 pub async fn set(&self, value: U) -> Result<(), S::Error> {
149 self.inner.set((), value).await
150 }
151 pub async fn clear(&self) -> Result<(), S::Error> {
152 self.inner.clear().await
153 }
154}
155
156impl<S, U> AuthorizationProvider for SessionWithEndpointStore<S, U>
157where
158 S: Store<(), U> + AuthorizationProvider + Send + Sync,
159 U: Clone + Send + Sync,
160{
161 async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
162 self.inner.authorization_token(is_refresh).await
163 }
164}