atrium_xrpc/
traits.rs

1use crate::error::{Error, XrpcError, XrpcErrorKind};
2use crate::types::{AuthorizationToken, Header, NSID_REFRESH_SESSION};
3use crate::{InputDataOrBytes, OutputDataOrBytes, XrpcRequest};
4use http::{header::WWW_AUTHENTICATE, Method, Request, Response};
5use serde::{de::DeserializeOwned, Serialize};
6use std::{fmt::Debug, future::Future};
7
/// An abstract HTTP client.
///
/// Implementors supply the actual transport. Request and response bodies are
/// exchanged as raw byte vectors (`Vec<u8>`).
///
/// On every target except `wasm32` the trait is passed through
/// `trait_variant::make(Send)`; on `wasm32` it is left as written.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait HttpClient {
    /// Send an HTTP request and return the response.
    ///
    /// # Errors
    ///
    /// The returned future resolves to a boxed `std::error::Error` when the
    /// implementor fails to produce a response. The concrete error type is
    /// chosen by the implementor.
    fn send_http(
        &self,
        request: Request<Vec<u8>>,
    ) -> impl Future<
        Output = core::result::Result<
            Response<Vec<u8>>,
            Box<dyn std::error::Error + Send + Sync + 'static>,
        >,
    >;
}
22
/// Result of an XRPC call: on success the output as [`OutputDataOrBytes`]
/// (decoded data of type `O` or raw bytes), on failure an [`Error`]
/// parameterized by the error payload type `E`.
type XrpcResult<O, E> = core::result::Result<OutputDataOrBytes<O>, self::Error<E>>;
24
/// An abstract XRPC client.
///
/// The [`send_xrpc()`](XrpcClient::send_xrpc) method has a default implementation,
/// which wraps the [`HttpClient::send_http()`] method to handle input and output as an XRPC request.
///
/// Only [`base_uri()`](XrpcClient::base_uri) is required; the header-related
/// methods default to `None`, in which case the corresponding header is not
/// added to the request.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait XrpcClient: HttpClient {
    /// The base URI of the XRPC server.
    ///
    /// The default `send_xrpc` appends `/xrpc/<nsid>` to this value.
    fn base_uri(&self) -> String;
    /// Get the authorization token to use for the `Authorization` header.
    ///
    /// `is_refresh` is `true` when the default `send_xrpc` is issuing a `POST`
    /// request whose NSID is `NSID_REFRESH_SESSION`, and `false` otherwise.
    ///
    /// The default implementation returns `None`.
    // The default implementation ignores `is_refresh`; the parameter exists for implementors.
    #[allow(unused_variables)]
    fn authorization_token(
        &self,
        is_refresh: bool,
    ) -> impl Future<Output = Option<AuthorizationToken>> {
        async { None }
    }
    /// Get the `atproto-proxy` header.
    ///
    /// The default implementation returns `None`.
    fn atproto_proxy_header(&self) -> impl Future<Output = Option<String>> {
        async { None }
    }
    /// Get the `atproto-accept-labelers` header.
    ///
    /// The default `send_xrpc` joins the returned values with `", "` to form the
    /// header value. The default implementation returns `None`.
    fn atproto_accept_labelers_header(&self) -> impl Future<Output = Option<Vec<String>>> {
        async { None }
    }
    /// Send an XRPC request and return the response.
    ///
    /// This is the variant compiled for targets other than `wasm32`; it
    /// additionally requires `Self: Sync`.
    #[cfg(not(target_arch = "wasm32"))]
    fn send_xrpc<P, I, O, E>(
        &self,
        request: &XrpcRequest<P, I>,
    ) -> impl Future<Output = XrpcResult<O, E>>
    where
        P: Serialize + Send + Sync,
        I: Serialize + Send + Sync,
        O: DeserializeOwned + Send + Sync,
        E: DeserializeOwned + Send + Sync + Debug,
        // This code is duplicated because of this trait bound.
        // `Self` has to be `Sync` for `Future` to be `Send`.
        Self: Sync,
    {
        // Delegates to the module-level `send_xrpc` function shared by both variants.
        send_xrpc(self, request)
    }
    /// Send an XRPC request and return the response.
    ///
    /// This is the variant compiled for `wasm32`; it has no `Self: Sync` bound.
    #[cfg(target_arch = "wasm32")]
    fn send_xrpc<P, I, O, E>(
        &self,
        request: &XrpcRequest<P, I>,
    ) -> impl Future<Output = XrpcResult<O, E>>
    where
        P: Serialize + Send + Sync,
        I: Serialize + Send + Sync,
        O: DeserializeOwned + Send + Sync,
        E: DeserializeOwned + Send + Sync + Debug,
    {
        // Delegates to the module-level `send_xrpc` function shared by both variants.
        send_xrpc(self, request)
    }
}
80
81#[inline(always)]
82async fn send_xrpc<P, I, O, E, C: XrpcClient + ?Sized>(
83    client: &C,
84    request: &XrpcRequest<P, I>,
85) -> XrpcResult<O, E>
86where
87    P: Serialize + Send + Sync,
88    I: Serialize + Send + Sync,
89    O: DeserializeOwned + Send + Sync,
90    E: DeserializeOwned + Send + Sync + Debug,
91{
92    let mut uri = format!("{}/xrpc/{}", client.base_uri(), request.nsid);
93    // Query parameters
94    if let Some(p) = &request.parameters {
95        serde_html_form::to_string(p).map(|qs| {
96            uri += "?";
97            uri += &qs;
98        })?;
99    };
100    let mut builder = Request::builder().method(&request.method).uri(&uri);
101    // Headers
102    if let Some(encoding) = &request.encoding {
103        builder = builder.header(Header::ContentType, encoding);
104    }
105    if let Some(token) = client
106        .authorization_token(request.method == Method::POST && request.nsid == NSID_REFRESH_SESSION)
107        .await
108    {
109        builder = builder.header(Header::Authorization, token);
110    }
111    if let Some(proxy) = client.atproto_proxy_header().await {
112        builder = builder.header(Header::AtprotoProxy, proxy);
113    }
114    if let Some(accept_labelers) = client.atproto_accept_labelers_header().await {
115        builder = builder.header(Header::AtprotoAcceptLabelers, accept_labelers.join(", "));
116    }
117    // Body
118    let body = if let Some(input) = &request.input {
119        match input {
120            InputDataOrBytes::Data(data) => serde_json::to_vec(&data)?,
121            InputDataOrBytes::Bytes(bytes) => bytes.clone(),
122        }
123    } else {
124        Vec::new()
125    };
126    // Send
127    let (parts, body) =
128        client.send_http(builder.body(body)?).await.map_err(Error::HttpClient)?.into_parts();
129    if parts.status.is_success() {
130        if parts
131            .headers
132            .get(http::header::CONTENT_TYPE)
133            .and_then(|value| value.to_str().ok())
134            .is_some_and(|content_type| content_type.starts_with("application/json"))
135        {
136            Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?))
137        } else {
138            Ok(OutputDataOrBytes::Bytes(body))
139        }
140    } else if let Some(value) = parts.headers.get(WWW_AUTHENTICATE) {
141        Err(Error::Authentication(value.clone()))
142    } else {
143        Err(Error::XrpcResponse(XrpcError {
144            status: parts.status,
145            error: serde_json::from_slice::<XrpcErrorKind<E>>(&body).ok(),
146        }))
147    }
148}