1pub mod atp_agent;
4#[cfg(feature = "bluesky")]
5pub mod bluesky;
6mod inner;
7mod session_manager;
8pub mod utils;
9
10pub use self::session_manager::SessionManager;
11use crate::{client::Service, types::string::Did};
12use atrium_xrpc::types::AuthorizationToken;
13use std::{future::Future, sync::Arc};
14
15#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
17pub trait AuthorizationProvider {
18 #[allow(unused_variables)]
19 fn authorization_token(
20 &self,
21 is_refresh: bool,
22 ) -> impl Future<Output = Option<AuthorizationToken>>;
23}
24
25pub trait Configure {
27 fn configure_endpoint(&self, endpoint: String);
29 fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>);
31 fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>);
33}
34
35pub trait CloneWithProxy {
37 fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self;
38}
39
/// Service types that can be named in the `atproto-proxy` header.
///
/// With the `bluesky` feature enabled, the enum defined in the `bluesky` module is used.
#[cfg(feature = "bluesky")]
pub type AtprotoServiceType = self::bluesky::AtprotoServiceType;

/// Service types that can be named in the `atproto-proxy` header.
///
/// Without the `bluesky` feature, only the AT Protocol labeler service is available.
#[cfg(not(feature = "bluesky"))]
pub enum AtprotoServiceType {
    /// A labeler service, identified as `atproto_labeler`.
    AtprotoLabeler,
}

#[cfg(not(feature = "bluesky"))]
impl AsRef<str> for AtprotoServiceType {
    /// The service identifier, used after the `#` in an `atproto-proxy` header value.
    fn as_ref(&self) -> &str {
        match *self {
            AtprotoServiceType::AtprotoLabeler => "atproto_labeler",
        }
    }
}
57
58pub struct Agent<M>
79where
80 M: SessionManager + Send + Sync,
81{
82 session_manager: Arc<inner::Wrapper<M>>,
83 pub api: Service<inner::Wrapper<M>>,
84}
85
86impl<M> Agent<M>
87where
88 M: SessionManager + Send + Sync,
89{
90 pub fn new(session_manager: M) -> Self {
92 let session_manager = Arc::new(inner::Wrapper::new(session_manager));
93 let api = Service::new(session_manager.clone());
94 Self { session_manager, api }
95 }
96 pub async fn did(&self) -> Option<Did> {
98 self.session_manager.did().await
99 }
100}
101
102impl<M> Agent<M>
103where
104 M: CloneWithProxy + SessionManager + Send + Sync,
105{
106 pub fn api_with_proxy(
110 &self,
111 did: Did,
112 service_type: impl AsRef<str>,
113 ) -> Service<inner::Wrapper<M>> {
114 Service::new(Arc::new(self.session_manager.clone_with_proxy(did, service_type)))
115 }
116}
117
118impl<M> Configure for Agent<M>
119where
120 M: Configure + SessionManager + Send + Sync,
121{
122 fn configure_endpoint(&self, endpoint: String) {
123 self.session_manager.configure_endpoint(endpoint);
124 }
125 fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
126 self.session_manager.configure_labelers_header(labeler_dids);
127 }
128 fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
129 self.session_manager.configure_proxy_header(did, service_type);
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::inner::Wrapper;
    use super::utils::{SessionClient, SessionWithEndpointStore};
    use super::*;
    use atrium_common::store::Store;
    use atrium_xrpc::{Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
    use http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, Request, Response};
    use serde::{de::DeserializeOwned, Serialize};
    use std::fmt::Debug;
    use tokio::sync::Mutex;

    /// Snapshot of the most recent HTTP request that reached [`MockClient`].
    #[derive(Default)]
    struct RecordData {
        /// Host component of the request URI (`None` if the URI had no host).
        host: Option<String>,
        /// Every header sent with the request.
        headers: HeaderMap<HeaderValue>,
    }

    /// HTTP client stub. It records each request into `data` and answers with a canned
    /// `com.atproto.server.getServiceAuth` JSON response instead of touching the network.
    struct MockClient {
        data: Arc<Mutex<Option<RecordData>>>,
    }

    impl HttpClient for MockClient {
        async fn send_http(
            &self,
            request: Request<Vec<u8>>,
        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
            // Overwrite any previous record so assertions always see the latest request.
            self.data.lock().await.replace(RecordData {
                host: request.uri().host().map(String::from),
                headers: request.headers().clone(),
            });
            // Same body for every request; `call_service` checks for this token.
            let output = crate::com::atproto::server::get_service_auth::OutputData {
                token: String::from("fake_token"),
            };
            Response::builder()
                .header(CONTENT_TYPE, "application/json")
                .body(serde_json::to_vec(&output)?)
                .map_err(|e| e.into())
        }
    }

    impl XrpcClient for MockClient {
        // NOTE(review): panics if called; the tests rely on `SessionClient` taking the base
        // URI from its endpoint store instead — confirm if this mock is reused elsewhere.
        fn base_uri(&self) -> String {
            unimplemented!()
        }
    }

    /// Store error type with no variants: the mock store never reports an error.
    #[derive(thiserror::Error, Debug)]
    enum MockStoreError {}

    /// Session store stub used only to satisfy `SessionWithEndpointStore`.
    struct MockStore;

    // Every `Store` method is `unimplemented!()`; the tests rely on none of them being called.
    impl Store<(), ()> for MockStore {
        type Error = MockStoreError;

        async fn get(&self, _key: &()) -> Result<Option<()>, Self::Error> {
            unimplemented!()
        }
        async fn set(&self, _key: (), _value: ()) -> Result<(), Self::Error> {
            unimplemented!()
        }
        async fn del(&self, _key: &()) -> Result<(), Self::Error> {
            unimplemented!()
        }
        async fn clear(&self) -> Result<(), Self::Error> {
            unimplemented!()
        }
    }

    impl AuthorizationProvider for MockStore {
        // Never supplies a token; the "not configured" assertions below therefore expect
        // requests with no headers at all.
        async fn authorization_token(&self, _: bool) -> Option<AuthorizationToken> {
            None
        }
    }

    /// Session manager under test. It forwards HTTP/XRPC traffic, configuration and proxy
    /// cloning to a `SessionClient` over `MockStore` + `MockClient`, and reports a fixed DID.
    struct MockSessionManager {
        inner: SessionClient<MockStore, MockClient, ()>,
    }

    impl HttpClient for MockSessionManager {
        async fn send_http(
            &self,
            request: Request<Vec<u8>>,
        ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
            self.inner.send_http(request).await
        }
    }

    impl XrpcClient for MockSessionManager {
        fn base_uri(&self) -> String {
            self.inner.base_uri()
        }
        // Delegated explicitly so the `SessionClient`'s own `send_xrpc` is used for requests.
        async fn send_xrpc<P, I, O, E>(
            &self,
            request: &XrpcRequest<P, I>,
        ) -> Result<OutputDataOrBytes<O>, Error<E>>
        where
            P: Serialize + Send + Sync,
            I: Serialize + Send + Sync,
            O: DeserializeOwned + Send + Sync,
            E: DeserializeOwned + Send + Sync + Debug,
        {
            self.inner.send_xrpc(request).await
        }
    }

    impl SessionManager for MockSessionManager {
        // Fixed DID checked by `test_new`.
        async fn did(&self) -> Option<Did> {
            Did::new(String::from("did:fake:handle.test")).ok()
        }
    }

    impl Configure for MockSessionManager {
        fn configure_endpoint(&self, endpoint: String) {
            self.inner.configure_endpoint(endpoint);
        }
        fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
            self.inner.configure_labelers_header(labeler_dids);
        }
        fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
            self.inner.configure_proxy_header(did, service_type);
        }
    }

    impl CloneWithProxy for MockSessionManager {
        fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
            Self { inner: self.inner.clone_with_proxy(did, service_type) }
        }
    }

    /// Builds an agent whose initial endpoint is `https://example.com` and whose HTTP layer
    /// records every request into `data`.
    fn agent(data: Arc<Mutex<Option<RecordData>>>) -> Agent<MockSessionManager> {
        let inner = SessionClient::new(
            Arc::new(SessionWithEndpointStore::new(
                MockStore {},
                String::from("https://example.com"),
            )),
            MockClient { data },
        );
        Agent::new(MockSessionManager { inner })
    }

    /// Sends a `com.atproto.server.getServiceAuth` request through `service` and checks the
    /// mock's canned token comes back. The tests inspect the request recorded as a side effect.
    async fn call_service(
        service: &Service<Wrapper<MockSessionManager>>,
    ) -> Result<(), Error<crate::com::atproto::server::get_service_auth::Error>> {
        let output = service
            .com
            .atproto
            .server
            .get_service_auth(
                crate::com::atproto::server::get_service_auth::ParametersData {
                    aud: Did::new(String::from("did:fake:handle.test"))
                        .expect("did should be valid"),
                    exp: None,
                    lxm: None,
                }
                .into(),
            )
            .await?;
        assert_eq!(output.token, "fake_token");
        Ok(())
    }

    // The agent reports the DID provided by its session manager.
    #[tokio::test]
    async fn test_new() -> Result<(), Box<dyn std::error::Error>> {
        let agent = agent(Arc::new(Mutex::new(Default::default())));
        assert_eq!(agent.did().await, Some(Did::new(String::from("did:fake:handle.test"))?));
        Ok(())
    }

    // Changing the endpoint through the agent changes the host used by `agent.api`.
    #[tokio::test]
    async fn test_configure_endpoint() -> Result<(), Box<dyn std::error::Error>> {
        let data = Arc::new(Mutex::new(Default::default()));
        let agent = agent(data.clone());
        call_service(&agent.api).await?;
        assert_eq!(
            data.lock().await.as_ref().expect("data should be recorded").host.as_deref(),
            Some("example.com")
        );
        agent.configure_endpoint(String::from("https://pds.example.com"));
        call_service(&agent.api).await?;
        assert_eq!(
            data.lock().await.as_ref().expect("data should be recorded").host.as_deref(),
            Some("pds.example.com")
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_configure_labelers_header() -> Result<(), Box<dyn std::error::Error>> {
        let data = Arc::new(Mutex::new(Default::default()));
        let agent = agent(data.clone());
        {
            // Not configured: no extra headers are sent.
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::new()
            );
        }
        {
            // A single labeler without redaction.
            agent.configure_labelers_header(Some(vec![(
                Did::new(String::from("did:fake:labeler.test"))?,
                false,
            )]));
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-accept-labelers"),
                    HeaderValue::from_static("did:fake:labeler.test"),
                )])
            );
        }
        {
            // Several labelers: `true` appends `;redact`, entries are joined with ", ".
            agent.configure_labelers_header(Some(vec![
                (Did::new(String::from("did:fake:labeler.test_redact"))?, true),
                (Did::new(String::from("did:fake:labeler.test"))?, false),
            ]));
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-accept-labelers"),
                    HeaderValue::from_static(
                        "did:fake:labeler.test_redact;redact, did:fake:labeler.test"
                    ),
                )])
            );
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_configure_proxy_header() -> Result<(), Box<dyn std::error::Error>> {
        let data = Arc::new(Mutex::new(Default::default()));
        let agent = agent(data.clone());
        {
            // Not configured: no `atproto-proxy` header.
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::new()
            );
        }
        {
            // Proxy configured with a predefined service type.
            agent.configure_proxy_header(
                Did::new(String::from("did:fake:service.test"))?,
                AtprotoServiceType::AtprotoLabeler,
            );
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-proxy"),
                    HeaderValue::from_static("did:fake:service.test#atproto_labeler"),
                )])
            );
        }
        {
            // Proxy configured with an arbitrary service string; it replaces the previous one.
            agent.configure_proxy_header(
                Did::new(String::from("did:fake:service.test"))?,
                "custom_service",
            );
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-proxy"),
                    HeaderValue::from_static("did:fake:service.test#custom_service"),
                )])
            );
        }
        {
            // `api_with_proxy` uses a temporary proxy; afterwards `agent.api` still sends
            // the proxy configured above.
            call_service(
                &agent.api_with_proxy(
                    Did::new(String::from("did:fake:service.test"))?,
                    "temp_service",
                ),
            )
            .await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-proxy"),
                    HeaderValue::from_static("did:fake:service.test#temp_service"),
                )])
            );
            call_service(&agent.api).await?;
            assert_eq!(
                data.lock().await.as_ref().expect("data should be recorded").headers,
                HeaderMap::from_iter([(
                    HeaderName::from_static("atproto-proxy"),
                    HeaderValue::from_static("did:fake:service.test#custom_service"),
                )])
            );
        }
        Ok(())
    }
}