#[cfg(feature = "streaming")]
pub mod streaming;
use alloc::borrow::ToOwned;
use alloc::boxed::Box;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use ipld_core::ipld::Ipld;
#[cfg(feature = "streaming")]
pub use streaming::{
StreamingResponse, XrpcProcedureSend, XrpcProcedureStream, XrpcResponseStream, XrpcStreamResp,
};
#[cfg(feature = "websocket")]
pub mod subscription;
#[cfg(feature = "streaming")]
use crate::StreamError;
use crate::error::DecodeError;
use crate::http_client::HttpClient;
#[cfg(feature = "streaming")]
use crate::http_client::HttpClientExt;
use crate::types::value::Data;
use crate::{AuthorizationToken, error::AuthError};
use crate::{CowStr, error::XrpcResult};
use crate::{IntoStatic, types::value::RawData};
use bytes::Bytes;
use core::error::Error;
use core::fmt::{self, Debug};
use core::marker::PhantomData;
use http::{
HeaderName, HeaderValue, Request, StatusCode,
header::{AUTHORIZATION, CONTENT_TYPE},
};
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use crate::deps::fluent_uri::Uri;
#[cfg(feature = "websocket")]
pub use subscription::{
BasicSubscriptionClient, MessageEncoding, SubscriptionCall, SubscriptionClient,
SubscriptionEndpoint, SubscriptionExt, SubscriptionOptions, SubscriptionResp,
SubscriptionStream, TungsteniteSubscriptionClient, XrpcSubscription,
};
/// Removes trailing `/` characters from `uri` so that paths can be appended to it without
/// producing `//`.
///
/// A URI whose text is a single character is returned unchanged, as is any URI that does not
/// end in `/`. The trim is applied to the end of the whole URI string.
/// NOTE(review): a URI whose query or fragment ends in `/` would have that suffix trimmed too —
/// confirm callers only pass bare base URIs.
///
/// # Panics
///
/// Panics if the trimmed text fails to re-parse; the `expect` message records the assumption
/// that this cannot happen for an already-valid URI.
pub fn normalize_base_uri(uri: Uri<String>) -> Uri<String> {
    let raw = uri.as_str();
    let needs_trim = raw.len() > 1 && raw.ends_with('/');
    if !needs_trim {
        return uri;
    }
    let without_slashes = raw.trim_end_matches('/').to_string();
    Uri::parse(without_slashes).expect("trimming trailing slash from valid URI yields valid URI")
}
/// Errors raised while encoding an XRPC request body, query string, or response output.
#[derive(Debug, thiserror::Error)]
#[cfg_attr(feature = "std", derive(miette::Diagnostic))]
#[non_exhaustive]
pub enum EncodeError {
    /// Query parameters could not be serialized by `serde_html_form`.
    #[error("Failed to serialize query: {0}")]
    Query(
        #[from]
        #[source]
        serde_html_form::ser::Error,
    ),
    /// A JSON body could not be serialized by `serde_json`.
    #[error("Failed to serialize JSON: {0}")]
    Json(
        #[from]
        #[source]
        serde_json::Error,
    ),
    /// Any other encoding failure, described by the contained message.
    #[error("Encoding error: {0}")]
    Other(String),
}
/// How an XRPC endpoint is invoked over HTTP.
///
/// Queries are sent as `GET`; procedures are sent as `POST` and carry the MIME type of their
/// request body.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XrpcMethod {
    /// A `GET` call with no request body.
    Query,
    /// A `POST` call; the payload is the request body's content type.
    Procedure(&'static str),
}
impl XrpcMethod {
    /// HTTP method name: `"GET"` for [`XrpcMethod::Query`], `"POST"` for procedures.
    pub const fn as_str(&self) -> &'static str {
        if let Self::Query = self { "GET" } else { "POST" }
    }
    /// Request-body content type for procedures; `None` for queries.
    pub const fn body_encoding(&self) -> Option<&'static str> {
        if let Self::Procedure(encoding) = self {
            Some(*encoding)
        } else {
            None
        }
    }
}
pub trait XrpcRequest: Serialize {
const NSID: &'static str;
const METHOD: XrpcMethod;
type Response: XrpcResp;
fn encode_body(&self) -> Result<Vec<u8>, EncodeError> {
Ok(serde_json::to_vec(self)?)
}
fn decode_body<'de>(body: &'de [u8]) -> Result<Box<Self>, DecodeError>
where
Self: Deserialize<'de>,
{
let body: Self = serde_json::from_slice(body)?;
Ok(Box::new(body))
}
}
pub trait XrpcResp {
const NSID: &'static str;
const ENCODING: &'static str;
type Output<'de>: Serialize + Deserialize<'de> + IntoStatic;
type Err<'de>: Error + Deserialize<'de> + Serialize + IntoStatic;
fn encode_output(output: &Self::Output<'_>) -> Result<Vec<u8>, EncodeError> {
Ok(serde_json::to_vec(output)?)
}
fn decode_output<'de>(body: &'de [u8]) -> core::result::Result<Self::Output<'de>, DecodeError>
where
Self::Output<'de>: Deserialize<'de>,
{
let body = serde_json::from_slice(body).map_err(|e| DecodeError::Json(e))?;
Ok(body)
}
}
/// Type-level description of a single XRPC endpoint: a path constant, its HTTP method, and the
/// request/response types bound to it.
pub trait XrpcEndpoint {
    /// Path constant for the endpoint. NOTE(review): the expected format (e.g. whether it
    /// includes the `/xrpc/` prefix) is not established in this file — confirm with implementors.
    const PATH: &'static str;
    /// Whether the endpoint is a query or a procedure.
    const METHOD: XrpcMethod;
    /// Request type, deserializable from a buffer borrowed for `'de`.
    type Request<'de>: XrpcRequest + Deserialize<'de> + IntoStatic;
    /// Response description for this endpoint.
    type Response: XrpcResp;
}
/// Untyped error body: wraps whatever [`Data`] value the server returned.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct GenericError<'a>(#[serde(borrow)] Data<'a>);
// Formats the wrapped `Data`.
// NOTE(review): `self.0.fmt(f)` resolves through whichever trait providing `fmt` is in scope.
// This file imports `core::fmt::Debug` (not `Display`) by name, so this most likely produces
// the `Debug` rendering — confirm that is intended.
impl<'de> fmt::Display for GenericError<'de> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}
impl Error for GenericError<'_> {}
// Detaches the wrapped data from any borrowed buffer.
impl IntoStatic for GenericError<'_> {
    type Output = GenericError<'static>;
    fn into_static(self) -> Self::Output {
        GenericError(self.0.into_static())
    }
}
/// Per-call options applied when an XRPC request is built.
#[derive(Debug, Default, Clone)]
pub struct CallOptions<'a> {
    /// Token sent in the `Authorization` header, using the `Bearer` or `DPoP` scheme.
    pub auth: Option<AuthorizationToken<'a>>,
    /// Value of the `atproto-proxy` header.
    pub atproto_proxy: Option<CowStr<'a>>,
    /// Labelers joined with `", "` into the `atproto-accept-labelers` header; omitted when empty.
    pub atproto_accept_labelers: Option<Vec<CowStr<'a>>>,
    /// Additional headers appended after the headers derived from the fields above.
    pub extra_headers: Vec<(HeaderName, HeaderValue)>,
}
impl IntoStatic for CallOptions<'_> {
    type Output = CallOptions<'static>;
    // Converts every borrowed field to its owned form; headers are already owned.
    fn into_static(self) -> Self::Output {
        let CallOptions {
            auth,
            atproto_proxy,
            atproto_accept_labelers,
            extra_headers,
        } = self;
        CallOptions {
            auth: auth.map(|token| token.into_static()),
            atproto_proxy: atproto_proxy.map(|value| value.into_static()),
            atproto_accept_labelers: atproto_accept_labelers.map(|list| list.into_static()),
            extra_headers,
        }
    }
}
/// Extension trait giving every [`HttpClient`] a builder for one-off XRPC calls.
pub trait XrpcExt: HttpClient {
    /// Starts an [`XrpcCall`] against `base` with default [`CallOptions`]; chain the builder
    /// methods on the result and finish with `send`.
    fn xrpc<'a>(&'a self, base: Uri<String>) -> XrpcCall<'a, Self>
    where
        Self: Sized,
    {
        let opts = CallOptions::default();
        XrpcCall {
            base,
            client: self,
            opts,
        }
    }
}
impl<C> XrpcExt for C where C: HttpClient {}
/// Undecoded response type produced by sending request `R`.
pub type XrpcResponse<R> = Response<<R as XrpcRequest>::Response>;
/// Stateful XRPC client: an [`HttpClient`] that also exposes a base URI and default
/// [`CallOptions`], and sends typed [`XrpcRequest`]s.
///
/// The `send`/`send_with_opts` signatures are declared once per target so that only
/// non-wasm builds require `Self: Sync`.
/// NOTE(review): on non-wasm targets `trait_variant::make(Send)` is applied, which per the
/// `trait_variant` crate adds `Send` bounds to the returned futures — confirm against its docs.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait XrpcClient: HttpClient {
    /// Returns the base URI requests are sent against.
    fn base_uri(&self) -> impl Future<Output = Uri<String>>;
    /// Replaces the base URI. The default implementation ignores `uri` and does nothing.
    fn set_base_uri(&self, uri: Uri<String>) -> impl Future<Output = ()> {
        let _ = uri;
        async {}
    }
    /// Returns the client's default call options. The default implementation returns
    /// `CallOptions::default()`.
    fn opts(&self) -> impl Future<Output = CallOptions<'_>> {
        async { CallOptions::default() }
    }
    /// Replaces the default call options. The default implementation ignores `opts`.
    fn set_opts(&self, opts: CallOptions) -> impl Future<Output = ()> {
        let _ = opts;
        async {}
    }
    /// Sends `request`; presumably with the client's own options (see `opts`) — confirm in
    /// implementors. Non-wasm variant, which additionally requires `Self: Sync`.
    #[cfg(not(target_arch = "wasm32"))]
    fn send<R>(&self, request: R) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;
    /// Sends `request` (wasm variant, without the `Self: Sync` bound).
    #[cfg(target_arch = "wasm32")]
    fn send<R>(&self, request: R) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync;
    /// Sends `request` with the explicitly supplied `opts` (non-wasm variant).
    #[cfg(not(target_arch = "wasm32"))]
    fn send_with_opts<R>(
        &self,
        request: R,
        opts: CallOptions<'_>,
    ) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;
    /// Sends `request` with the explicitly supplied `opts` (wasm variant).
    #[cfg(target_arch = "wasm32")]
    fn send_with_opts<R>(
        &self,
        request: R,
        opts: CallOptions<'_>,
    ) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync;
}
/// Streaming counterpart of [`XrpcClient`]: streamed downloads and bidirectional streaming
/// procedures. Only available with the `streaming` feature.
///
/// Like `XrpcClient`, each method is declared once per target so that only non-wasm builds
/// require `Self: Sync`.
#[cfg(feature = "streaming")]
pub trait XrpcStreamingClient: XrpcClient + HttpClientExt {
    /// Sends `request` and returns the response body as a stream (non-wasm variant; the
    /// returned future is `Send`).
    #[cfg(not(target_arch = "wasm32"))]
    fn download<R>(
        &self,
        request: R,
    ) -> impl Future<Output = Result<StreamingResponse, StreamError>> + Send
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;
    /// Sends `request` and returns the response body as a stream (wasm variant).
    #[cfg(target_arch = "wasm32")]
    fn download<R>(
        &self,
        request: R,
    ) -> impl Future<Output = Result<StreamingResponse, StreamError>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync;
    /// Streams the frames in `stream` to a procedure and returns the typed response frames
    /// (non-wasm variant).
    /// NOTE(review): unlike `download`, this future is not declared `+ Send` — confirm intended.
    #[cfg(not(target_arch = "wasm32"))]
    fn stream<S>(
        &self,
        stream: XrpcProcedureSend<S::Frame<'static>>,
    ) -> impl Future<
        Output = Result<
            XrpcResponseStream<
                <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>,
            >,
            StreamError,
        >,
    >
    where
        S: XrpcProcedureStream + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>: XrpcStreamResp,
        Self: Sync;
    /// Streams the frames in `stream` to a procedure and returns the typed response frames
    /// (wasm variant).
    #[cfg(target_arch = "wasm32")]
    fn stream<S>(
        &self,
        stream: XrpcProcedureSend<S::Frame<'static>>,
    ) -> impl Future<
        Output = Result<
            XrpcResponseStream<
                <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>,
            >,
            StreamError,
        >,
    >
    where
        S: XrpcProcedureStream + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>: XrpcStreamResp;
}
pub struct XrpcCall<'a, C: HttpClient> {
pub(crate) client: &'a C,
pub(crate) base: Uri<String>,
pub(crate) opts: CallOptions<'a>,
}
impl<'a, C: HttpClient> XrpcCall<'a, C> {
pub fn auth(mut self, token: AuthorizationToken<'a>) -> Self {
self.opts.auth = Some(token);
self
}
pub fn proxy(mut self, proxy: CowStr<'a>) -> Self {
self.opts.atproto_proxy = Some(proxy);
self
}
pub fn accept_labelers(mut self, labelers: Vec<CowStr<'a>>) -> Self {
self.opts.atproto_accept_labelers = Some(labelers);
self
}
pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self {
self.opts.extra_headers.push((name, value));
self
}
pub fn with_options(mut self, opts: CallOptions<'a>) -> Self {
self.opts = opts;
self
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, request), fields(nsid = R::NSID)))]
pub async fn send<R>(self, request: &R) -> XrpcResult<Response<<R as XrpcRequest>::Response>>
where
R: XrpcRequest,
<R as XrpcRequest>::Response: Send + Sync,
{
let http_request = build_http_request(&self.base, request, &self.opts)?;
let http_response = self
.client
.send_http(http_request)
.await
.map_err(|e| crate::error::ClientError::transport(e).for_nsid(R::NSID))?;
process_response(http_response)
}
}
#[inline]
pub fn process_response<Resp>(http_response: http::Response<Vec<u8>>) -> XrpcResult<Response<Resp>>
where
Resp: XrpcResp,
{
let status = http_response.status();
#[allow(deprecated)]
if status.as_u16() == 401 {
if let Some(hv) = http_response.headers().get(http::header::WWW_AUTHENTICATE) {
return Err(
crate::error::ClientError::auth(crate::error::AuthError::Other(hv.clone()))
.for_nsid(Resp::NSID),
);
}
}
let buffer = Bytes::from(http_response.into_body());
if !status.is_success() && !matches!(status.as_u16(), 400 | 401) {
return Err(crate::error::ClientError::from(crate::error::HttpError {
status,
body: Some(buffer),
})
.for_nsid(Resp::NSID));
}
Ok(Response::new(buffer, status))
}
/// XRPC-related request headers that this module sets, convertible to [`HeaderName`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Header {
    /// `Content-Type`
    ContentType,
    /// `Authorization`
    Authorization,
    /// `atproto-proxy`
    AtprotoProxy,
    /// `atproto-accept-labelers`
    AtprotoAcceptLabelers,
}
impl From<Header> for HeaderName {
    fn from(value: Header) -> Self {
        match value {
            Header::ContentType => CONTENT_TYPE,
            Header::Authorization => AUTHORIZATION,
            Header::AtprotoProxy => HeaderName::from_static("atproto-proxy"),
            Header::AtprotoAcceptLabelers => HeaderName::from_static("atproto-accept-labelers"),
        }
    }
}
/// Builds `<scheme>://<authority><base path>/xrpc/<nsid>[?<query>]` from `base`.
///
/// Trailing `/` characters on `base`'s path are dropped, so the result never contains
/// `//xrpc`. Only the scheme, authority and path of `base` are used; any query or fragment on
/// `base` is discarded. `query` is appended verbatim, so it must already be encoded.
///
/// # Errors
///
/// Returns an invalid-request `ClientError` if the assembled string does not parse as a URI.
fn xrpc_endpoint_uri(
    base: &Uri<String>,
    nsid: &str,
    query: Option<&str>,
) -> XrpcResult<Uri<String>> {
    use crate::error::ClientError;
    let base_path = base.path().as_str().trim_end_matches('/');
    // Exact final length: scheme + "://" + authority + path + "/xrpc/" + nsid (+ "?" + query),
    // so the string is built with a single allocation.
    let capacity = base.scheme().as_str().len()
        + "://".len()
        + base.authority().map(|a| a.as_str().len()).unwrap_or(0)
        + base_path.len()
        + "/xrpc/".len()
        + nsid.len()
        + query.map(|q| q.len() + 1).unwrap_or(0);
    let mut uri_str = String::with_capacity(capacity);
    uri_str.push_str(base.scheme().as_str());
    uri_str.push_str("://");
    if let Some(authority) = base.authority() {
        uri_str.push_str(authority.as_str());
    }
    uri_str.push_str(base_path);
    uri_str.push_str("/xrpc/");
    uri_str.push_str(nsid);
    if let Some(q) = query {
        uri_str.push('?');
        uri_str.push_str(q);
    }
    // Parsing an owned `String` already yields a `Uri<String>`; no extra copy is needed.
    Uri::parse(uri_str)
        .map_err(|_| ClientError::invalid_request("Failed to construct XRPC endpoint URI"))
}
pub fn build_http_request<'s, R>(
base: &Uri<String>,
req: &R,
opts: &CallOptions<'_>,
) -> XrpcResult<Request<Vec<u8>>>
where
R: XrpcRequest,
{
use crate::error::ClientError;
let query_string = if let XrpcMethod::Query = <R as XrpcRequest>::METHOD {
let qs = serde_html_form::to_string(&req).map_err(|e| {
ClientError::invalid_request(format!("Failed to serialize query: {}", e))
})?;
if !qs.is_empty() { Some(qs) } else { None }
} else {
None
};
let uri = xrpc_endpoint_uri(base, <R as XrpcRequest>::NSID, query_string.as_deref())?;
let method = match <R as XrpcRequest>::METHOD {
XrpcMethod::Query => http::Method::GET,
XrpcMethod::Procedure(_) => http::Method::POST,
};
let mut builder = Request::builder().method(method).uri(uri.as_str());
let has_content_type = opts
.extra_headers
.iter()
.any(|(name, _)| name == CONTENT_TYPE);
if let XrpcMethod::Procedure(encoding) = <R as XrpcRequest>::METHOD {
if !has_content_type {
builder = builder.header(Header::ContentType, encoding);
}
}
let output_encoding = <R::Response as XrpcResp>::ENCODING;
builder = builder.header(http::header::ACCEPT, output_encoding);
if let Some(token) = &opts.auth {
let hv = match token {
AuthorizationToken::Bearer(t) => {
HeaderValue::from_str(&format!("Bearer {}", t.as_ref()))
}
AuthorizationToken::Dpop(t) => HeaderValue::from_str(&format!("DPoP {}", t.as_ref())),
}
.map_err(|e| ClientError::invalid_request(format!("Invalid authorization token: {}", e)))?;
builder = builder.header(Header::Authorization, hv);
}
if let Some(proxy) = &opts.atproto_proxy {
builder = builder.header(Header::AtprotoProxy, proxy.as_ref());
}
if let Some(labelers) = &opts.atproto_accept_labelers {
if !labelers.is_empty() {
let joined = labelers
.iter()
.map(|s| s.as_ref())
.collect::<Vec<_>>()
.join(", ");
builder = builder.header(Header::AtprotoAcceptLabelers, joined);
}
}
for (name, value) in &opts.extra_headers {
builder = builder.header(name, value);
}
let body = if let XrpcMethod::Procedure(_) = R::METHOD {
req.encode_body()
.map_err(|e| ClientError::invalid_request(format!("Failed to encode body: {}", e)))?
} else {
vec![]
};
builder
.body(body)
.map_err(|e| ClientError::invalid_request(format!("Failed to build request: {}", e)))
}
/// Undecoded XRPC response: the raw body and HTTP status, tagged with the response type
/// `Resp` so it can later be parsed into `Resp::Output` or `Resp::Err`.
pub struct Response<Resp>
where
    Resp: XrpcResp, {
    // Type-only marker; `fn() -> Resp` records `Resp` without storing a value of it.
    _marker: PhantomData<fn() -> Resp>,
    // Raw response body.
    buffer: Bytes,
    // HTTP status the response arrived with.
    status: StatusCode,
}
impl<R> Response<R>
where
R: XrpcResp,
{
pub fn new(buffer: Bytes, status: StatusCode) -> Self {
Self {
buffer,
status,
_marker: PhantomData,
}
}
pub fn status(&self) -> StatusCode {
self.status
}
pub fn buffer(&self) -> &Bytes {
&self.buffer
}
pub fn parse<'s>(&'s self) -> Result<RespOutput<'s, R>, XrpcError<RespErr<'s, R>>> {
if self.status.is_success() {
match R::decode_output(&self.buffer) {
Ok(output) => Ok(output),
Err(e) => Err(XrpcError::Decode(e)),
}
} else if self.status.as_u16() == 400 {
match serde_json::from_slice::<R::Err<'_>>(&self.buffer) {
Ok(error) => {
use alloc::string::ToString;
if error.to_string().contains("InvalidToken") {
Err(XrpcError::Auth(AuthError::InvalidToken))
} else if error.to_string().contains("ExpiredToken") {
Err(XrpcError::Auth(AuthError::TokenExpired))
} else {
Err(XrpcError::Xrpc(error))
}
}
Err(_) => {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
Ok(mut generic) => {
generic.nsid = R::NSID;
generic.method = ""; generic.http_status = self.status;
match generic.error.as_str() {
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
_ => Err(XrpcError::Generic(generic)),
}
}
Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
}
}
}
} else {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
Ok(mut generic) => {
generic.nsid = R::NSID;
generic.method = ""; generic.http_status = self.status;
match generic.error.as_str() {
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
_ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
}
}
Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
}
}
}
pub fn parse_data<'s>(&'s self) -> Result<Data<'s>, XrpcError<RespErr<'s, R>>> {
if self.status.is_success() {
match serde_json::from_slice::<_>(&self.buffer) {
Ok(output) => Ok(output),
Err(_) => {
if let Ok(data) = serde_ipld_dagcbor::from_slice::<Ipld>(&self.buffer) {
if let Ok(data) = Data::from_cbor(&data) {
Ok(data.into_static())
} else {
Ok(Data::Bytes(self.buffer.clone()))
}
} else {
Ok(Data::Bytes(self.buffer.clone()))
}
}
}
} else if self.status.as_u16() == 400 {
match serde_json::from_slice::<R::Err<'_>>(&self.buffer) {
Ok(error) => {
use alloc::string::ToString;
if error.to_string().contains("InvalidToken") {
Err(XrpcError::Auth(AuthError::InvalidToken))
} else if error.to_string().contains("ExpiredToken") {
Err(XrpcError::Auth(AuthError::TokenExpired))
} else {
Err(XrpcError::Xrpc(error))
}
}
Err(_) => {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
Ok(mut generic) => {
generic.nsid = R::NSID;
generic.method = ""; generic.http_status = self.status;
match generic.error.as_str() {
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
_ => Err(XrpcError::Generic(generic)),
}
}
Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
}
}
}
} else {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
Ok(mut generic) => {
generic.nsid = R::NSID;
generic.method = ""; generic.http_status = self.status;
match generic.error.as_str() {
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
_ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
}
}
Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
}
}
}
pub fn parse_raw<'s>(&'s self) -> Result<RawData<'s>, XrpcError<RespErr<'s, R>>> {
if self.status.is_success() {
match serde_json::from_slice::<_>(&self.buffer) {
Ok(output) => Ok(output),
Err(_) => {
if let Ok(data) = serde_ipld_dagcbor::from_slice::<Ipld>(&self.buffer) {
if let Ok(data) = RawData::from_cbor(&data) {
Ok(data.into_static())
} else {
Ok(RawData::Bytes(self.buffer.clone()))
}
} else {
Ok(RawData::Bytes(self.buffer.clone()))
}
}
}
} else if self.status.as_u16() == 400 {
match serde_json::from_slice::<R::Err<'_>>(&self.buffer) {
Ok(error) => {
use alloc::string::ToString;
if error.to_string().contains("InvalidToken") {
Err(XrpcError::Auth(AuthError::InvalidToken))
} else if error.to_string().contains("ExpiredToken") {
Err(XrpcError::Auth(AuthError::TokenExpired))
} else {
Err(XrpcError::Xrpc(error))
}
}
Err(_) => {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
Ok(mut generic) => {
generic.nsid = R::NSID;
generic.method = ""; generic.http_status = self.status;
match generic.error.as_str() {
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
_ => Err(XrpcError::Generic(generic)),
}
}
Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
}
}
}
} else {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
Ok(mut generic) => {
generic.nsid = R::NSID;
generic.method = ""; generic.http_status = self.status;
match generic.error.as_str() {
"ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
"InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
_ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
}
}
Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
}
}
}
pub fn transmute<NEW: XrpcResp>(self) -> Response<NEW> {
Response {
buffer: self.buffer,
status: self.status,
_marker: PhantomData,
}
}
}
/// Output type of response `Resp`; it may borrow from the response buffer for `'a`.
pub type RespOutput<'a, Resp> = <Resp as XrpcResp>::Output<'a>;
/// Typed error type of response `Resp`; it may borrow from the response buffer for `'a`.
pub type RespErr<'a, Resp> = <Resp as XrpcResp>::Err<'a>;
impl<R> Response<R>
where
R: XrpcResp,
{
pub fn into_output(self) -> Result<RespOutput<'static, R>, XrpcError<RespErr<'static, R>>>
where
for<'a> RespOutput<'a, R>: IntoStatic<Output = RespOutput<'static, R>>,
for<'a> RespErr<'a, R>: IntoStatic<Output = RespErr<'static, R>>,
{
fn parse_error<'b, R: XrpcResp>(buffer: &'b [u8]) -> Result<R::Err<'b>, serde_json::Error> {
serde_json::from_slice(buffer)
}
if self.status.is_success() {
match R::decode_output(&self.buffer) {
Ok(output) => Ok(output.into_static()),
Err(e) => Err(XrpcError::Decode(e)),
}
} else if self.status.as_u16() == 400 {
let error = match parse_error::<R>(&self.buffer) {
Ok(error) => {
use alloc::string::ToString;
if error.to_string().contains("InvalidToken") {
XrpcError::Auth(AuthError::InvalidToken)
} else if error.to_string().contains("ExpiredToken") {
XrpcError::Auth(AuthError::TokenExpired)
} else {
XrpcError::Xrpc(error)
}
}
Err(_) => {
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
Ok(mut generic) => {
generic.nsid = R::NSID;
generic.method = ""; generic.http_status = self.status;
match generic.error.as_ref() {
"ExpiredToken" => XrpcError::Auth(AuthError::TokenExpired),
"InvalidToken" => XrpcError::Auth(AuthError::InvalidToken),
_ => XrpcError::Generic(generic),
}
}
Err(e) => XrpcError::Decode(DecodeError::Json(e)),
}
}
};
Err(error.into_static())
} else {
let error: XrpcError<<R as XrpcResp>::Err<'_>> =
match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
Ok(mut generic) => {
let status = self.status;
generic.nsid = R::NSID;
generic.method = ""; generic.http_status = status;
match generic.error.as_ref() {
"ExpiredToken" => XrpcError::Auth(AuthError::TokenExpired),
"InvalidToken" => XrpcError::Auth(AuthError::InvalidToken),
_ => XrpcError::Auth(AuthError::NotAuthenticated),
}
}
Err(e) => XrpcError::Decode(DecodeError::Json(e)),
};
Err(error.into_static())
}
}
}
/// Untyped XRPC error body: the `error` code and optional `message` from the response JSON,
/// plus request context that is not part of the wire format.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenericXrpcError {
    /// Error code from the body's `error` field.
    pub error: SmolStr,
    /// Optional human-readable text from the body's `message` field.
    pub message: Option<SmolStr>,
    /// NSID of the endpoint that produced the error; filled in after decoding.
    #[serde(skip)]
    pub nsid: &'static str,
    /// Method label for the failed call; filled in after decoding (this module sets `""`).
    #[serde(skip)]
    pub method: &'static str,
    /// HTTP status of the failed response; filled in after decoding.
    #[serde(skip)]
    pub http_status: StatusCode,
}
impl core::fmt::Display for GenericXrpcError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self.message.as_ref() {
            Some(msg) => write!(
                f,
                "{}: {} (nsid={}, method={}, status={})",
                self.error, msg, self.nsid, self.method, self.http_status
            ),
            None => write!(
                f,
                "{} (nsid={}, method={}, status={})",
                self.error, self.nsid, self.method, self.http_status
            ),
        }
    }
}
// Owns all of its data already, so it is its own `'static` form.
impl IntoStatic for GenericXrpcError {
    type Output = Self;
    fn into_static(self) -> Self::Output {
        self
    }
}
impl core::error::Error for GenericXrpcError {}
/// Error produced when parsing an XRPC response.
///
/// `E` is the endpoint's typed error (`XrpcResp::Err`).
#[derive(Debug, thiserror::Error)]
#[cfg_attr(feature = "std", derive(miette::Diagnostic))]
#[non_exhaustive]
pub enum XrpcError<E: core::error::Error + IntoStatic> {
    /// Endpoint-specific typed error decoded from the response body.
    #[error("XRPC error: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard_common::xrpc::typed)))]
    Xrpc(E),
    /// Authentication failure (e.g. expired or invalid token, missing authentication).
    #[error("Authentication error: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard_common::xrpc::auth)))]
    Auth(#[from] AuthError),
    /// Error body that did not decode as `E` but did decode as a generic `{error, message}` object.
    #[error("XRPC error: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard_common::xrpc::generic)))]
    Generic(GenericXrpcError),
    /// The response body could not be decoded.
    #[error("Failed to decode response: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard_common::xrpc::decode)))]
    Decode(#[from] DecodeError),
}
/// Detaches an [`XrpcError`] from any borrowed response buffer by converting its typed error
/// to `E::Output`.
impl<E> IntoStatic for XrpcError<E>
where
    E: core::error::Error + IntoStatic,
    E::Output: core::error::Error + IntoStatic,
{
    type Output = XrpcError<E::Output>;
    fn into_static(self) -> Self::Output {
        match self {
            XrpcError::Xrpc(e) => XrpcError::Xrpc(e.into_static()),
            XrpcError::Auth(e) => XrpcError::Auth(e.into_static()),
            // These payload types do not depend on `E`, so they move across unchanged.
            XrpcError::Generic(e) => XrpcError::Generic(e),
            XrpcError::Decode(e) => XrpcError::Decode(e),
        }
    }
}
/// Serializes an [`XrpcError`] in XRPC error-body shape.
///
/// Typed and generic errors serialize as themselves. Auth and decode failures are written as
/// a two-field `XrpcError { error, message }` struct.
impl<E> Serialize for XrpcError<E>
where
    E: core::error::Error + IntoStatic + Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Writes the `XrpcError { error, message }` envelope.
        fn error_envelope<Ser>(
            serializer: Ser,
            code: &str,
            message: &str,
        ) -> Result<Ser::Ok, Ser::Error>
        where
            Ser: serde::Serializer,
        {
            use serde::ser::SerializeStruct;
            let mut fields = serializer.serialize_struct("XrpcError", 2)?;
            fields.serialize_field("error", code)?;
            fields.serialize_field("message", message)?;
            fields.end()
        }

        match self {
            XrpcError::Xrpc(typed) => typed.serialize(serializer),
            XrpcError::Generic(generic) => generic.serialize(serializer),
            XrpcError::Auth(auth) => {
                let (code, message) = match auth {
                    AuthError::TokenExpired => ("ExpiredToken", "Access token has expired"),
                    AuthError::InvalidToken => {
                        ("InvalidToken", "Access token is invalid or malformed")
                    }
                    AuthError::RefreshFailed => ("RefreshFailed", "Token refresh request failed"),
                    AuthError::NotAuthenticated => (
                        "AuthenticationRequired",
                        "Request requires authentication but none was provided",
                    ),
                    AuthError::DpopProofFailed => {
                        ("DpopProofFailed", "DPoP proof construction failed")
                    }
                    AuthError::DpopNonceFailed => {
                        ("DpopNonceFailed", "DPoP nonce negotiation failed")
                    }
                    AuthError::Other(challenge) => (
                        "AuthenticationError",
                        challenge.to_str().unwrap_or("[non-utf8 header]"),
                    ),
                };
                error_envelope(serializer, code, message)
            }
            XrpcError::Decode(decode_err) => {
                let message = format!("{:?}", decode_err);
                error_envelope(serializer, "ResponseDecodeError", &message)
            }
        }
    }
}
// Streaming entry points on the one-off call builder; only compiled with the `streaming` feature.
#[cfg(feature = "streaming")]
impl<'a, C: HttpClient + HttpClientExt> XrpcCall<'a, C> {
    /// Builds the request exactly like `send`, but returns the response body as a stream
    /// instead of buffering it. The status code is not inspected here.
    ///
    /// # Errors
    ///
    /// Request-building and transport failures, both wrapped with `StreamError::transport`.
    pub async fn download<R>(self, request: &R) -> Result<StreamingResponse, StreamError>
    where
        R: XrpcRequest,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        let http_request =
            build_http_request(&self.base, request, &self.opts).map_err(StreamError::transport)?;
        let http_response = self
            .client
            .send_http_streaming(http_request)
            .await
            .map_err(StreamError::transport)?;
        let (parts, body) = http_response.into_parts();
        Ok(StreamingResponse::new(parts, body))
    }
    /// POSTs the frames of `stream` as a streamed request body to the endpoint for
    /// `S::Request`, and returns the response body as a typed frame stream.
    ///
    /// Sends the same auth, proxy, labeler and extra headers as `build_http_request`.
    /// NOTE(review): unlike `build_http_request`, no `Content-Type` or `Accept` header is set
    /// here — confirm whether streaming endpoints expect them.
    ///
    /// # Errors
    ///
    /// `StreamError::protocol` for URI/header/request-building failures;
    /// `StreamError::transport` for send failures.
    pub async fn stream<S>(
        self,
        stream: XrpcProcedureSend<S::Frame<'static>>,
    ) -> Result<XrpcResponseStream<<S::Response as XrpcStreamResp>::Frame<'static>>, StreamError>
    where
        S: XrpcProcedureStream + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>: XrpcStreamResp,
    {
        use futures::TryStreamExt;
        // Procedures carry no query string, so the endpoint URI is just base + /xrpc/<nsid>.
        let uri = xrpc_endpoint_uri(&self.base, <S::Request as XrpcRequest>::NSID, None).map_err(
            |e| StreamError::protocol(format!("Failed to construct endpoint URI: {}", e)),
        )?;
        let mut builder = http::Request::post(uri.as_str());
        if let Some(token) = &self.opts.auth {
            let hv = match token {
                AuthorizationToken::Bearer(t) => {
                    HeaderValue::from_str(&format!("Bearer {}", t.as_ref()))
                }
                AuthorizationToken::Dpop(t) => {
                    HeaderValue::from_str(&format!("DPoP {}", t.as_ref()))
                }
            }
            .map_err(|e| StreamError::protocol(format!("Invalid authorization token: {}", e)))?;
            builder = builder.header(Header::Authorization, hv);
        }
        if let Some(proxy) = &self.opts.atproto_proxy {
            builder = builder.header(Header::AtprotoProxy, proxy.as_ref());
        }
        if let Some(labelers) = &self.opts.atproto_accept_labelers {
            if !labelers.is_empty() {
                let joined = labelers
                    .iter()
                    .map(|s| s.as_ref())
                    .collect::<Vec<_>>()
                    .join(", ");
                builder = builder.header(Header::AtprotoAcceptLabelers, joined);
            }
        }
        for (name, value) in &self.opts.extra_headers {
            builder = builder.header(name, value);
        }
        // Only the request head is kept; the body is supplied separately as a stream below.
        let (parts, _) = builder
            .body(())
            .map_err(|e| StreamError::protocol(e.to_string()))?
            .into_parts();
        // Each outgoing frame contributes its `buffer` as one body chunk.
        let body_stream = Box::pin(stream.0.map_ok(|f| f.buffer));
        let resp = self
            .client
            .send_http_bidirectional(parts, body_stream)
            .await
            .map_err(StreamError::transport)?;
        let (parts, body) = resp.into_parts();
        Ok(XrpcResponseStream::<
            <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>,
        >::from_typed_parts(parts, body))
    }
}
// Unit tests for response error classification and XRPC request URI construction.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    // Minimal procedure request/response pair used by the error-classification tests.
    #[derive(Serialize, Deserialize)]
    #[allow(dead_code)]
    struct DummyReq;
    #[derive(Deserialize, Serialize, Debug, thiserror::Error)]
    #[error("{0}")]
    struct DummyErr<'a>(#[serde(borrow)] CowStr<'a>);
    impl IntoStatic for DummyErr<'_> {
        type Output = DummyErr<'static>;
        fn into_static(self) -> Self::Output {
            DummyErr(self.0.into_static())
        }
    }
    struct DummyResp;
    impl XrpcResp for DummyResp {
        const NSID: &'static str = "test.dummy";
        const ENCODING: &'static str = "application/json";
        type Output<'de> = ();
        type Err<'de> = DummyErr<'de>;
    }
    impl XrpcRequest for DummyReq {
        const NSID: &'static str = "test.dummy";
        const METHOD: XrpcMethod = XrpcMethod::Procedure("application/json");
        type Response = DummyResp;
    }
    // A 400 body that is not a valid typed error (an object, not a string) falls back to
    // GenericXrpcError, which carries the NSID, method and status context.
    #[test]
    fn generic_error_carries_context() {
        let body = serde_json::json!({"error":"InvalidRequest","message":"missing"});
        let buf = Bytes::from(serde_json::to_vec(&body).unwrap());
        let resp: Response<DummyResp> = Response::new(buf, StatusCode::BAD_REQUEST);
        match resp.parse().unwrap_err() {
            XrpcError::Generic(g) => {
                assert_eq!(g.error.as_str(), "InvalidRequest");
                assert_eq!(g.message.as_deref(), Some("missing"));
                assert_eq!(g.nsid, DummyResp::NSID);
                assert_eq!(g.method, ""); assert_eq!(g.http_status, StatusCode::BAD_REQUEST);
            }
            other => panic!("unexpected: {other:?}"),
        }
    }
    // On 401, the well-known token error codes map to their AuthError variants.
    #[test]
    fn auth_error_mapping() {
        for (code, expect) in [
            ("ExpiredToken", AuthError::TokenExpired),
            ("InvalidToken", AuthError::InvalidToken),
        ] {
            let body = serde_json::json!({"error": code});
            let buf = Bytes::from(serde_json::to_vec(&body).unwrap());
            let resp: Response<DummyResp> = Response::new(buf, StatusCode::UNAUTHORIZED);
            match resp.parse().unwrap_err() {
                XrpcError::Auth(e) => match (e, expect) {
                    (AuthError::TokenExpired, AuthError::TokenExpired) => {}
                    (AuthError::InvalidToken, AuthError::InvalidToken) => {}
                    other => panic!("mismatch: {other:?}"),
                },
                other => panic!("unexpected: {other:?}"),
            }
        }
    }
    // The endpoint URI is base + /xrpc/<nsid>, preserving a base sub-path and not doubling
    // a trailing slash.
    #[test]
    fn xrpc_uri_construction_basic() {
        use crate::alloc::string::ToString;
        #[derive(Serialize, Deserialize)]
        struct Req;
        #[derive(Deserialize, Serialize, Debug, thiserror::Error)]
        #[error("{0}")]
        struct Err<'a>(#[serde(borrow)] CowStr<'a>);
        impl IntoStatic for Err<'_> {
            type Output = Err<'static>;
            fn into_static(self) -> Self::Output {
                Err(self.0.into_static())
            }
        }
        struct Resp;
        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<'de> = ();
            type Err<'de> = Err<'de>;
        }
        impl XrpcRequest for Req {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }
        let opts = CallOptions::default();
        let base1 = Uri::parse("https://pds.example.com")
            .expect("URI should be valid")
            .to_owned();
        let req1 = build_http_request(&base1, &Req, &opts).unwrap();
        let uri1 = req1.uri().to_string();
        assert!(
            uri1.contains("/xrpc/com.example.test"),
            "AC1.1: URI {} should contain '/xrpc/com.example.test'",
            uri1
        );
        assert_eq!(
            uri1, "https://pds.example.com/xrpc/com.example.test",
            "AC1.1: URI should be exact match"
        );
        let base2 = Uri::parse("https://pds.example.com/base")
            .expect("URI should be valid")
            .to_owned();
        let req2 = build_http_request(&base2, &Req, &opts).unwrap();
        let uri2 = req2.uri().to_string();
        assert!(
            uri2.contains("/base/xrpc/com.example.test"),
            "AC1.2: URI {} should contain '/base/xrpc/com.example.test'",
            uri2
        );
        assert_eq!(
            uri2, "https://pds.example.com/base/xrpc/com.example.test",
            "AC1.2: URI should preserve sub-path"
        );
        let base_with_slash = Uri::parse("https://pds.example.com/")
            .expect("URI should be valid")
            .to_owned();
        let req_slash = build_http_request(&base_with_slash, &Req, &opts).unwrap();
        let uri_slash = req_slash.uri().to_string();
        assert!(
            !uri_slash.contains("//xrpc"),
            "AC1.5: URI {} should not contain '//xrpc'",
            uri_slash
        );
        assert_eq!(
            uri_slash, "https://pds.example.com/xrpc/com.example.test",
            "AC1.5: URI should handle trailing slash"
        );
    }
    // Query parameters appear in the query string; when all are skipped, no `?` is emitted.
    #[test]
    fn xrpc_uri_query_parameters() {
        use crate::alloc::string::ToString;
        use serde::Serialize;
        #[derive(Serialize)]
        struct QueryReq {
            #[serde(skip_serializing_if = "Option::is_none")]
            param1: Option<String>,
            #[serde(skip_serializing_if = "Option::is_none")]
            param2: Option<String>,
        }
        #[derive(Serialize, Deserialize, Debug, thiserror::Error)]
        #[error("test error")]
        struct Err;
        impl IntoStatic for Err {
            type Output = Err;
            fn into_static(self) -> Self::Output {
                self
            }
        }
        struct Resp;
        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<'de> = ();
            type Err<'de> = Err;
        }
        impl XrpcRequest for QueryReq {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }
        let opts = CallOptions::default();
        let base = Uri::parse("https://pds.example.com")
            .expect("URI should be valid")
            .to_owned();
        let req_with_params = QueryReq {
            param1: Some("value1".to_string()),
            param2: Some("value2".to_string()),
        };
        let http_req = build_http_request(&base, &req_with_params, &opts).unwrap();
        let uri_str = http_req.uri().to_string();
        assert!(
            uri_str.contains("?"),
            "AC1.3: URI should contain query string"
        );
        assert!(
            uri_str.contains("param1=value1"),
            "AC1.3: URI should contain param1"
        );
        assert!(
            uri_str.contains("param2=value2"),
            "AC1.3: URI should contain param2"
        );
        let req_empty_params = QueryReq {
            param1: None,
            param2: None,
        };
        let http_req_empty = build_http_request(&base, &req_empty_params, &opts).unwrap();
        let uri_str_empty = http_req_empty.uri().to_string();
        assert!(
            !uri_str_empty.contains("?"),
            "AC1.4: URI {} should not contain '?' with empty params",
            uri_str_empty
        );
        assert_eq!(
            uri_str_empty, "https://pds.example.com/xrpc/com.example.test",
            "AC1.4: URI should have no query string"
        );
    }
    // Spaces, reserved characters and non-ASCII text in query values are encoded so that the
    // resulting URI still parses.
    #[test]
    fn xrpc_uri_special_characters_in_query() {
        use crate::alloc::string::ToString;
        use serde::Serialize;
        #[derive(Serialize)]
        struct QueryReq {
            #[serde(skip_serializing_if = "Option::is_none")]
            search: Option<String>,
            #[serde(skip_serializing_if = "Option::is_none")]
            filter: Option<String>,
            #[serde(skip_serializing_if = "Option::is_none")]
            unicode_param: Option<String>,
        }
        #[derive(Serialize, Deserialize, Debug, thiserror::Error)]
        #[error("test error")]
        struct Err;
        impl IntoStatic for Err {
            type Output = Err;
            fn into_static(self) -> Self::Output {
                self
            }
        }
        struct Resp;
        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<'de> = ();
            type Err<'de> = Err;
        }
        impl XrpcRequest for QueryReq {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }
        let opts = CallOptions::default();
        let base = Uri::parse("https://pds.example.com")
            .expect("URI should be valid")
            .to_owned();
        let req_spaces = QueryReq {
            search: Some("hello world".to_string()),
            filter: None,
            unicode_param: None,
        };
        let http_req_spaces = build_http_request(&base, &req_spaces, &opts).unwrap();
        let uri_spaces = http_req_spaces.uri().to_string();
        assert!(
            uri_spaces.contains("search=hello"),
            "AC1.3: URI should contain search param"
        );
        assert!(
            uri_spaces.contains("hello+world") || uri_spaces.contains("hello%20world"),
            "AC1.3: URI {} should encode space in 'hello world'",
            uri_spaces
        );
        let req_special = QueryReq {
            search: Some("a=b&c+d".to_string()),
            filter: None,
            unicode_param: None,
        };
        let http_req_special = build_http_request(&base, &req_special, &opts).unwrap();
        let uri_special = http_req_special.uri().to_string();
        assert!(
            uri_special.contains("?"),
            "AC1.3: URI should contain query string for special chars"
        );
        let parsed = Uri::parse(uri_special.clone());
        assert!(
            parsed.is_ok(),
            "AC1.3: URI {} should be parseable by fluent-uri",
            uri_special
        );
        let req_unicode = QueryReq {
            search: None,
            filter: None,
            unicode_param: Some("你好世界".to_string()),
        };
        let http_req_unicode = build_http_request(&base, &req_unicode, &opts).unwrap();
        let uri_unicode = http_req_unicode.uri().to_string();
        assert!(
            uri_unicode.contains("?"),
            "AC1.3: URI should contain query string for unicode"
        );
        let parsed_unicode = Uri::parse(uri_unicode.clone());
        assert!(
            parsed_unicode.is_ok(),
            "AC1.3: URI {} should be parseable for unicode params",
            uri_unicode
        );
    }
    // Bare-host and sub-path bases never produce `//xrpc`.
    #[test]
    fn no_double_slash_in_path() {
        use crate::alloc::string::ToString;
        #[derive(Serialize, Deserialize)]
        struct Req;
        #[derive(Deserialize, Serialize, Debug, thiserror::Error)]
        #[error("{0}")]
        struct Err<'a>(#[serde(borrow)] CowStr<'a>);
        impl IntoStatic for Err<'_> {
            type Output = Err<'static>;
            fn into_static(self) -> Self::Output {
                Err(self.0.into_static())
            }
        }
        struct Resp;
        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<'de> = ();
            type Err<'de> = Err<'de>;
        }
        impl XrpcRequest for Req {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }
        let opts = CallOptions::default();
        let base1 = Uri::parse("https://pds")
            .expect("URI should be valid")
            .to_owned();
        let req1 = build_http_request(&base1, &Req, &opts).unwrap();
        let uri1 = req1.uri().to_string();
        assert!(
            !uri1.contains("//xrpc"),
            "URI {} should not contain '//xrpc'",
            uri1
        );
        let base2 = Uri::parse("https://pds/base")
            .expect("URI should be valid")
            .to_owned();
        let req2 = build_http_request(&base2, &Req, &opts).unwrap();
        let uri2 = req2.uri().to_string();
        assert!(
            !uri2.contains("//xrpc"),
            "URI {} should not contain '//xrpc'",
            uri2
        );
    }
}