1#[cfg(feature = "streaming")]
14pub mod streaming;
15
16use ipld_core::ipld::Ipld;
17#[cfg(feature = "streaming")]
18pub use streaming::{
19 StreamingResponse, XrpcProcedureSend, XrpcProcedureStream, XrpcResponseStream, XrpcStreamResp,
20};
21
22#[cfg(feature = "websocket")]
23pub mod subscription;
24
25#[cfg(feature = "streaming")]
26use crate::StreamError;
27use crate::error::DecodeError;
28use crate::http_client::HttpClient;
29#[cfg(feature = "streaming")]
30use crate::http_client::HttpClientExt;
31use crate::types::value::Data;
32use crate::{AuthorizationToken, error::AuthError};
33use crate::{CowStr, error::XrpcResult};
34use crate::{IntoStatic, types::value::RawData};
35use bytes::Bytes;
36use http::{
37 HeaderName, HeaderValue, Request, StatusCode,
38 header::{AUTHORIZATION, CONTENT_TYPE},
39};
40use serde::{Deserialize, Serialize};
41use smol_str::SmolStr;
42use std::fmt::{self, Debug};
43use std::{error::Error, marker::PhantomData};
44#[cfg(feature = "websocket")]
45pub use subscription::{
46 BasicSubscriptionClient, MessageEncoding, SubscriptionCall, SubscriptionClient,
47 SubscriptionEndpoint, SubscriptionExt, SubscriptionOptions, SubscriptionResp,
48 SubscriptionStream, TungsteniteSubscriptionClient, XrpcSubscription,
49};
50use url::Url;
51
52#[derive(Debug, thiserror::Error, miette::Diagnostic)]
54pub enum EncodeError {
55 #[error("Failed to serialize query: {0}")]
57 Query(
58 #[from]
59 #[source]
60 serde_html_form::ser::Error,
61 ),
62 #[error("Failed to serialize JSON: {0}")]
64 Json(
65 #[from]
66 #[source]
67 serde_json::Error,
68 ),
69 #[error("Encoding error: {0}")]
71 Other(String),
72}
73
/// The HTTP shape of an XRPC endpoint.
///
/// Queries are sent as `GET` with no body. Procedures are sent as `POST` and carry
/// the MIME type of their request body.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XrpcMethod {
    /// Read-only call, sent as `GET`.
    Query,
    /// Mutating call, sent as `POST`; the payload is the request body encoding.
    Procedure(&'static str),
}

impl XrpcMethod {
    /// HTTP method name for this kind of call: `"GET"` or `"POST"`.
    pub const fn as_str(&self) -> &'static str {
        if let Self::Query = self {
            "GET"
        } else {
            "POST"
        }
    }

    /// Encoding of the request body, or `None` for queries (which have no body).
    pub const fn body_encoding(&self) -> Option<&'static str> {
        match *self {
            Self::Procedure(encoding) => Some(encoding),
            Self::Query => None,
        }
    }
}
100
101pub trait XrpcRequest: Serialize {
108 const NSID: &'static str;
110
111 const METHOD: XrpcMethod;
113
114 type Response: XrpcResp;
116
117 fn encode_body(&self) -> Result<Vec<u8>, EncodeError> {
121 Ok(serde_json::to_vec(self)?)
122 }
123
124 fn decode_body<'de>(body: &'de [u8]) -> Result<Box<Self>, DecodeError>
128 where
129 Self: Deserialize<'de>,
130 {
131 let body: Self = serde_json::from_slice(body)?;
132
133 Ok(Box::new(body))
134 }
135}
136
137pub trait XrpcResp {
141 const NSID: &'static str;
143
144 const ENCODING: &'static str;
146
147 type Output<'de>: Serialize + Deserialize<'de> + IntoStatic;
149
150 type Err<'de>: Error + Deserialize<'de> + Serialize + IntoStatic;
152
153 fn encode_output(output: &Self::Output<'_>) -> Result<Vec<u8>, EncodeError> {
155 Ok(serde_json::to_vec(output)?)
156 }
157
158 fn decode_output<'de>(body: &'de [u8]) -> core::result::Result<Self::Output<'de>, DecodeError>
162 where
163 Self::Output<'de>: Deserialize<'de>,
164 {
165 #[allow(deprecated)]
166 let body = serde_json::from_slice(body).map_err(|e| DecodeError::Json(e))?;
167
168 Ok(body)
169 }
170}
171
172pub trait XrpcEndpoint {
180 const PATH: &'static str;
182 const METHOD: XrpcMethod;
184 type Request<'de>: XrpcRequest + Deserialize<'de> + IntoStatic;
186 type Response: XrpcResp;
188}
189
190#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
192pub struct GenericError<'a>(#[serde(borrow)] Data<'a>);
193
194impl<'de> fmt::Display for GenericError<'de> {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 self.0.fmt(f)
197 }
198}
199
200impl Error for GenericError<'_> {}
201
202impl IntoStatic for GenericError<'_> {
203 type Output = GenericError<'static>;
204 fn into_static(self) -> Self::Output {
205 GenericError(self.0.into_static())
206 }
207}
208
209#[derive(Debug, Default, Clone)]
211pub struct CallOptions<'a> {
212 pub auth: Option<AuthorizationToken<'a>>,
214 pub atproto_proxy: Option<CowStr<'a>>,
216 pub atproto_accept_labelers: Option<Vec<CowStr<'a>>>,
218 pub extra_headers: Vec<(HeaderName, HeaderValue)>,
220}
221
222impl IntoStatic for CallOptions<'_> {
223 type Output = CallOptions<'static>;
224
225 fn into_static(self) -> Self::Output {
226 CallOptions {
227 auth: self.auth.map(|auth| auth.into_static()),
228 atproto_proxy: self.atproto_proxy.map(|proxy| proxy.into_static()),
229 atproto_accept_labelers: self
230 .atproto_accept_labelers
231 .map(|labelers| labelers.into_static()),
232 extra_headers: self.extra_headers,
233 }
234 }
235}
236
237pub trait XrpcExt: HttpClient {
253 fn xrpc<'a>(&'a self, base: Url) -> XrpcCall<'a, Self>
255 where
256 Self: Sized,
257 {
258 XrpcCall {
259 client: self,
260 base,
261 opts: CallOptions::default(),
262 }
263 }
264}
265
266impl<T: HttpClient> XrpcExt for T {}
267
268pub type XrpcResponse<R> = Response<<R as XrpcRequest>::Response>;
270
/// Stateful XRPC client: an [`HttpClient`] that keeps a base URI and default
/// [`CallOptions`] and can send typed [`XrpcRequest`]s.
///
/// On non-wasm32 targets the trait goes through `trait_variant::make(Send)`, which
/// presumably adds `Send` bounds to the returned futures (see the `trait_variant`
/// docs). The `send*` methods are declared once per target family; the wasm32
/// versions omit the `Self: Sync` bound.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait XrpcClient: HttpClient {
    /// Returns the base URI requests are sent to.
    fn base_uri(&self) -> impl Future<Output = Url>;

    /// Replaces the base URI.
    ///
    /// The default implementation ignores `url` and does nothing.
    fn set_base_uri(&self, url: Url) -> impl Future<Output = ()> {
        let _ = url;
        async {}
    }

    /// Returns the client's default call options.
    ///
    /// The default implementation returns `CallOptions::default()`.
    fn opts(&self) -> impl Future<Output = CallOptions<'_>> {
        async { CallOptions::default() }
    }

    /// Replaces the client's default call options.
    ///
    /// The default implementation ignores `opts` and does nothing.
    fn set_opts(&self, opts: CallOptions) -> impl Future<Output = ()> {
        let _ = opts;
        async {}
    }

    /// Sends a typed request and returns the buffered [`Response`].
    #[cfg(not(target_arch = "wasm32"))]
    fn send<R>(&self, request: R) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;

    /// Sends a typed request and returns the buffered [`Response`] (wasm32: no `Self: Sync`).
    #[cfg(target_arch = "wasm32")]
    fn send<R>(&self, request: R) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync;

    /// Sends a typed request using the explicitly supplied `opts`.
    #[cfg(not(target_arch = "wasm32"))]
    fn send_with_opts<R>(
        &self,
        request: R,
        opts: CallOptions<'_>,
    ) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;

    /// Sends a typed request using the explicitly supplied `opts` (wasm32: no `Self: Sync`).
    #[cfg(target_arch = "wasm32")]
    fn send_with_opts<R>(
        &self,
        request: R,
        opts: CallOptions<'_>,
    ) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync;
}
332
#[cfg(feature = "streaming")]
/// Streaming extension of [`XrpcClient`]: responses whose body is consumed as a
/// stream, and procedures whose request body is itself a stream of frames.
///
/// NOTE(review): on non-wasm32 targets `download`'s future is declared `+ Send`
/// but `stream`'s is not — confirm whether that asymmetry is intended.
pub trait XrpcStreamingClient: XrpcClient + HttpClientExt {
    /// Sends `request` and returns the response with an unbuffered, streaming body.
    #[cfg(not(target_arch = "wasm32"))]
    fn download<R>(
        &self,
        request: R,
    ) -> impl Future<Output = Result<StreamingResponse, StreamError>> + Send
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;

    /// wasm32 variant of `download`: no `Send` future and no `Self: Sync` bound.
    #[cfg(target_arch = "wasm32")]
    fn download<R>(
        &self,
        request: R,
    ) -> impl Future<Output = Result<StreamingResponse, StreamError>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync;

    /// Sends the frames of `stream` as the procedure body and returns a stream of
    /// typed response frames.
    #[cfg(not(target_arch = "wasm32"))]
    fn stream<S>(
        &self,
        stream: XrpcProcedureSend<S::Frame<'static>>,
    ) -> impl Future<
        Output = Result<
            XrpcResponseStream<
                <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>,
            >,
            StreamError,
        >,
    >
    where
        S: XrpcProcedureStream + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>: XrpcStreamResp,
        Self: Sync;

    /// wasm32 variant of `stream`: identical except for the missing `Self: Sync` bound.
    #[cfg(target_arch = "wasm32")]
    fn stream<S>(
        &self,
        stream: XrpcProcedureSend<S::Frame<'static>>,
    ) -> impl Future<
        Output = Result<
            XrpcResponseStream<
                <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>,
            >,
            StreamError,
        >,
    >
    where
        S: XrpcProcedureStream + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>: XrpcStreamResp;
}
392
393pub struct XrpcCall<'a, C: HttpClient> {
414 pub(crate) client: &'a C,
415 pub(crate) base: Url,
416 pub(crate) opts: CallOptions<'a>,
417}
418
419impl<'a, C: HttpClient> XrpcCall<'a, C> {
420 pub fn auth(mut self, token: AuthorizationToken<'a>) -> Self {
422 self.opts.auth = Some(token);
423 self
424 }
425 pub fn proxy(mut self, proxy: CowStr<'a>) -> Self {
427 self.opts.atproto_proxy = Some(proxy);
428 self
429 }
430 pub fn accept_labelers(mut self, labelers: Vec<CowStr<'a>>) -> Self {
432 self.opts.atproto_accept_labelers = Some(labelers);
433 self
434 }
435 pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self {
437 self.opts.extra_headers.push((name, value));
438 self
439 }
440 pub fn with_options(mut self, opts: CallOptions<'a>) -> Self {
442 self.opts = opts;
443 self
444 }
445
446 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, request), fields(nsid = R::NSID)))]
455 pub async fn send<R>(self, request: &R) -> XrpcResult<Response<<R as XrpcRequest>::Response>>
456 where
457 R: XrpcRequest,
458 <R as XrpcRequest>::Response: Send + Sync,
459 {
460 let http_request = build_http_request(&self.base, request, &self.opts)?;
461
462 let http_response = self
463 .client
464 .send_http(http_request)
465 .await
466 .map_err(|e| crate::error::ClientError::transport(e))?;
467
468 process_response(http_response)
469 }
470}
471
472#[inline]
476pub fn process_response<Resp>(http_response: http::Response<Vec<u8>>) -> XrpcResult<Response<Resp>>
477where
478 Resp: XrpcResp,
479{
480 let status = http_response.status();
481 #[allow(deprecated)]
484 if status.as_u16() == 401 {
485 if let Some(hv) = http_response.headers().get(http::header::WWW_AUTHENTICATE) {
486 return Err(crate::error::ClientError::auth(
487 crate::error::AuthError::Other(hv.clone()),
488 ));
489 }
490 }
491 let buffer = Bytes::from(http_response.into_body());
492
493 if !status.is_success() && !matches!(status.as_u16(), 400 | 401) {
494 return Err(crate::error::HttpError {
495 status,
496 body: Some(buffer),
497 }
498 .into());
499 }
500
501 Ok(Response::new(buffer, status))
502}
503
504pub enum Header {
506 ContentType,
508 Authorization,
510 AtprotoProxy,
514 AtprotoAcceptLabelers,
516}
517
518impl From<Header> for HeaderName {
519 fn from(value: Header) -> Self {
520 match value {
521 Header::ContentType => CONTENT_TYPE,
522 Header::Authorization => AUTHORIZATION,
523 Header::AtprotoProxy => HeaderName::from_static("atproto-proxy"),
524 Header::AtprotoAcceptLabelers => HeaderName::from_static("atproto-accept-labelers"),
525 }
526 }
527}
528
529pub fn build_http_request<'s, R>(
531 base: &Url,
532 req: &R,
533 opts: &CallOptions<'_>,
534) -> XrpcResult<Request<Vec<u8>>>
535where
536 R: XrpcRequest,
537{
538 use crate::error::ClientError;
539
540 let mut url = base.clone();
541 let mut path = url.path().trim_end_matches('/').to_owned();
542 path.push_str("/xrpc/");
543 path.push_str(<R as XrpcRequest>::NSID);
544 url.set_path(&path);
545 if let XrpcMethod::Query = <R as XrpcRequest>::METHOD {
548 let qs = serde_html_form::to_string(&req).map_err(|e| {
549 ClientError::invalid_request(format!("Failed to serialize query: {}", e))
550 })?;
551 if !qs.is_empty() {
552 url.set_query(Some(&qs));
553 } else {
554 url.set_query(None);
555 }
556 }
557
558 let method = match <R as XrpcRequest>::METHOD {
559 XrpcMethod::Query => http::Method::GET,
560 XrpcMethod::Procedure(_) => http::Method::POST,
561 };
562
563 let mut builder = Request::builder().method(method).uri(url.as_str());
564
565 let has_content_type = opts
566 .extra_headers
567 .iter()
568 .any(|(name, _)| name == CONTENT_TYPE);
569
570 if let XrpcMethod::Procedure(encoding) = <R as XrpcRequest>::METHOD {
571 if !has_content_type {
573 builder = builder.header(Header::ContentType, encoding);
574 }
575 }
576 let output_encoding = <R::Response as XrpcResp>::ENCODING;
577 builder = builder.header(http::header::ACCEPT, output_encoding);
578
579 if let Some(token) = &opts.auth {
580 let hv = match token {
581 AuthorizationToken::Bearer(t) => {
582 HeaderValue::from_str(&format!("Bearer {}", t.as_ref()))
583 }
584 AuthorizationToken::Dpop(t) => HeaderValue::from_str(&format!("DPoP {}", t.as_ref())),
585 }
586 .map_err(|e| ClientError::invalid_request(format!("Invalid authorization token: {}", e)))?;
587 builder = builder.header(Header::Authorization, hv);
588 }
589
590 if let Some(proxy) = &opts.atproto_proxy {
591 builder = builder.header(Header::AtprotoProxy, proxy.as_ref());
592 }
593 if let Some(labelers) = &opts.atproto_accept_labelers {
594 if !labelers.is_empty() {
595 let joined = labelers
596 .iter()
597 .map(|s| s.as_ref())
598 .collect::<Vec<_>>()
599 .join(", ");
600 builder = builder.header(Header::AtprotoAcceptLabelers, joined);
601 }
602 }
603 for (name, value) in &opts.extra_headers {
604 builder = builder.header(name, value);
605 }
606
607 let body = if let XrpcMethod::Procedure(_) = R::METHOD {
608 req.encode_body()
609 .map_err(|e| ClientError::invalid_request(format!("Failed to encode body: {}", e)))?
610 } else {
611 vec![]
612 };
613
614 builder
615 .body(body)
616 .map_err(|e| ClientError::invalid_request(format!("Failed to build request: {}", e)))
617}
618
619pub struct Response<Resp>
624where
625 Resp: XrpcResp, {
627 _marker: PhantomData<fn() -> Resp>,
628 buffer: Bytes,
629 status: StatusCode,
630}
631
632impl<R> Response<R>
633where
634 R: XrpcResp,
635{
636 pub fn new(buffer: Bytes, status: StatusCode) -> Self {
638 Self {
639 buffer,
640 status,
641 _marker: PhantomData,
642 }
643 }
644
645 pub fn status(&self) -> StatusCode {
647 self.status
648 }
649
650 pub fn buffer(&self) -> &Bytes {
652 &self.buffer
653 }
654
655 pub fn parse<'s>(&'s self) -> Result<RespOutput<'s, R>, XrpcError<RespErr<'s, R>>> {
657 if self.status.is_success() {
659 match R::decode_output(&self.buffer) {
660 Ok(output) => Ok(output),
661 Err(e) => Err(XrpcError::Decode(e)),
662 }
663 } else if self.status.as_u16() == 400 {
665 match serde_json::from_slice::<_>(&self.buffer) {
666 Ok(error) => Err(XrpcError::Xrpc(error)),
667 Err(_) => {
668 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
670 Ok(mut generic) => {
671 generic.nsid = R::NSID;
672 generic.method = ""; generic.http_status = self.status;
674 match generic.error.as_str() {
676 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
677 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
678 _ => Err(XrpcError::Generic(generic)),
679 }
680 }
681 Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
682 }
683 }
684 }
685 } else {
687 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
688 Ok(mut generic) => {
689 generic.nsid = R::NSID;
690 generic.method = ""; generic.http_status = self.status;
692 match generic.error.as_str() {
693 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
694 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
695 _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
696 }
697 }
698 Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
699 }
700 }
701 }
702
703 pub fn parse_data<'s>(&'s self) -> Result<Data<'s>, XrpcError<RespErr<'s, R>>> {
707 if self.status.is_success() {
709 match serde_json::from_slice::<_>(&self.buffer) {
710 Ok(output) => Ok(output),
711 Err(_) => {
712 if let Ok(data) = serde_ipld_dagcbor::from_slice::<Ipld>(&self.buffer) {
713 if let Ok(data) = Data::from_cbor(&data) {
714 Ok(data.into_static())
715 } else {
716 Ok(Data::Bytes(self.buffer.clone()))
717 }
718 } else {
719 Ok(Data::Bytes(self.buffer.clone()))
720 }
721 }
722 }
723 } else if self.status.as_u16() == 400 {
725 match serde_json::from_slice::<_>(&self.buffer) {
726 Ok(error) => Err(XrpcError::Xrpc(error)),
727 Err(_) => {
728 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
730 Ok(mut generic) => {
731 generic.nsid = R::NSID;
732 generic.method = ""; generic.http_status = self.status;
734 match generic.error.as_str() {
736 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
737 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
738 _ => Err(XrpcError::Generic(generic)),
739 }
740 }
741 Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
742 }
743 }
744 }
745 } else {
747 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
748 Ok(mut generic) => {
749 generic.nsid = R::NSID;
750 generic.method = ""; generic.http_status = self.status;
752 match generic.error.as_str() {
753 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
754 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
755 _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
756 }
757 }
758 Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
759 }
760 }
761 }
762
763 pub fn parse_raw<'s>(&'s self) -> Result<RawData<'s>, XrpcError<RespErr<'s, R>>> {
767 if self.status.is_success() {
769 match serde_json::from_slice::<_>(&self.buffer) {
770 Ok(output) => Ok(output),
771 Err(_) => {
772 if let Ok(data) = serde_ipld_dagcbor::from_slice::<Ipld>(&self.buffer) {
773 if let Ok(data) = RawData::from_cbor(&data) {
774 Ok(data.into_static())
775 } else {
776 Ok(RawData::Bytes(self.buffer.clone()))
777 }
778 } else {
779 Ok(RawData::Bytes(self.buffer.clone()))
780 }
781 }
782 }
783 } else if self.status.as_u16() == 400 {
785 match serde_json::from_slice::<_>(&self.buffer) {
786 Ok(error) => Err(XrpcError::Xrpc(error)),
787 Err(_) => {
788 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
790 Ok(mut generic) => {
791 generic.nsid = R::NSID;
792 generic.method = ""; generic.http_status = self.status;
794 match generic.error.as_str() {
796 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
797 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
798 _ => Err(XrpcError::Generic(generic)),
799 }
800 }
801 Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
802 }
803 }
804 }
805 } else {
807 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
808 Ok(mut generic) => {
809 generic.nsid = R::NSID;
810 generic.method = ""; generic.http_status = self.status;
812 match generic.error.as_str() {
813 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
814 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
815 _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
816 }
817 }
818 Err(e) => Err(XrpcError::Decode(DecodeError::Json(e))),
819 }
820 }
821 }
822
823 pub fn transmute<NEW: XrpcResp>(self) -> Response<NEW> {
835 Response {
836 buffer: self.buffer,
837 status: self.status,
838 _marker: PhantomData,
839 }
840 }
841}
842
843pub type RespOutput<'a, Resp> = <Resp as XrpcResp>::Output<'a>;
845pub type RespErr<'a, Resp> = <Resp as XrpcResp>::Err<'a>;
847
848impl<R> Response<R>
849where
850 R: XrpcResp,
851{
852 pub fn into_output(self) -> Result<RespOutput<'static, R>, XrpcError<RespErr<'static, R>>>
854 where
855 for<'a> RespOutput<'a, R>: IntoStatic<Output = RespOutput<'static, R>>,
856 for<'a> RespErr<'a, R>: IntoStatic<Output = RespErr<'static, R>>,
857 {
858 fn parse_error<'b, R: XrpcResp>(buffer: &'b [u8]) -> Result<R::Err<'b>, serde_json::Error> {
859 serde_json::from_slice(buffer)
860 }
861
862 if self.status.is_success() {
864 match R::decode_output(&self.buffer) {
865 Ok(output) => Ok(output.into_static()),
866 Err(e) => Err(XrpcError::Decode(e)),
867 }
868 } else if self.status.as_u16() == 400 {
870 let error = match parse_error::<R>(&self.buffer) {
871 Ok(error) => XrpcError::Xrpc(error),
872 Err(_) => {
873 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
875 Ok(mut generic) => {
876 generic.nsid = R::NSID;
877 generic.method = ""; generic.http_status = self.status;
879 match generic.error.as_ref() {
881 "ExpiredToken" => XrpcError::Auth(AuthError::TokenExpired),
882 "InvalidToken" => XrpcError::Auth(AuthError::InvalidToken),
883 _ => XrpcError::Generic(generic),
884 }
885 }
886 Err(e) => XrpcError::Decode(DecodeError::Json(e)),
887 }
888 }
889 };
890 Err(error.into_static())
891 } else {
893 let error: XrpcError<<R as XrpcResp>::Err<'_>> =
894 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
895 Ok(mut generic) => {
896 let status = self.status;
897 generic.nsid = R::NSID;
898 generic.method = ""; generic.http_status = status;
900 match generic.error.as_ref() {
901 "ExpiredToken" => XrpcError::Auth(AuthError::TokenExpired),
902 "InvalidToken" => XrpcError::Auth(AuthError::InvalidToken),
903 _ => XrpcError::Auth(AuthError::NotAuthenticated),
904 }
905 }
906 Err(e) => XrpcError::Decode(DecodeError::Json(e)),
907 };
908
909 Err(error.into_static())
910 }
911 }
912}
913
914#[derive(Debug, Clone, Deserialize, Serialize)]
918pub struct GenericXrpcError {
919 pub error: SmolStr,
921 pub message: Option<SmolStr>,
923 #[serde(skip)]
925 pub nsid: &'static str,
926 #[serde(skip)]
928 pub method: &'static str,
929 #[serde(skip)]
931 pub http_status: StatusCode,
932}
933
934impl std::fmt::Display for GenericXrpcError {
935 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
936 if let Some(msg) = &self.message {
937 write!(
938 f,
939 "{}: {} (nsid={}, method={}, status={})",
940 self.error, msg, self.nsid, self.method, self.http_status
941 )
942 } else {
943 write!(
944 f,
945 "{} (nsid={}, method={}, status={})",
946 self.error, self.nsid, self.method, self.http_status
947 )
948 }
949 }
950}
951
952impl IntoStatic for GenericXrpcError {
953 type Output = Self;
954
955 fn into_static(self) -> Self::Output {
956 self
957 }
958}
959
960impl std::error::Error for GenericXrpcError {}
961
962#[derive(Debug, thiserror::Error, miette::Diagnostic)]
967pub enum XrpcError<E: std::error::Error + IntoStatic> {
968 #[error("XRPC error: {0}")]
970 #[diagnostic(code(jacquard_common::xrpc::typed))]
971 Xrpc(E),
972
973 #[error("Authentication error: {0}")]
975 #[diagnostic(code(jacquard_common::xrpc::auth))]
976 Auth(#[from] AuthError),
977
978 #[error("XRPC error: {0}")]
980 #[diagnostic(code(jacquard_common::xrpc::generic))]
981 Generic(GenericXrpcError),
982
983 #[error("Failed to decode response: {0}")]
985 #[diagnostic(code(jacquard_common::xrpc::decode))]
986 Decode(#[from] DecodeError),
987}
988
989impl<E> IntoStatic for XrpcError<E>
990where
991 E: std::error::Error + IntoStatic,
992 E::Output: std::error::Error + IntoStatic,
993 <E as IntoStatic>::Output: std::error::Error + IntoStatic,
994{
995 type Output = XrpcError<E::Output>;
996 fn into_static(self) -> Self::Output {
997 match self {
998 XrpcError::Xrpc(e) => XrpcError::Xrpc(e.into_static()),
999 XrpcError::Auth(e) => XrpcError::Auth(e.into_static()),
1000 XrpcError::Generic(e) => XrpcError::Generic(e),
1001 XrpcError::Decode(e) => XrpcError::Decode(e),
1002 }
1003 }
1004}
1005
1006impl<E> Serialize for XrpcError<E>
1007where
1008 E: std::error::Error + IntoStatic + Serialize,
1009{
1010 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1011 where
1012 S: serde::Serializer,
1013 {
1014 use serde::ser::SerializeStruct;
1015
1016 match self {
1017 XrpcError::Xrpc(e) => e.serialize(serializer),
1019 XrpcError::Generic(g) => g.serialize(serializer),
1021 XrpcError::Auth(auth) => {
1023 let mut state = serializer.serialize_struct("XrpcError", 2)?;
1024 let (error, message) = match auth {
1025 AuthError::TokenExpired => ("ExpiredToken", Some("Access token has expired")),
1026 AuthError::InvalidToken => {
1027 ("InvalidToken", Some("Access token is invalid or malformed"))
1028 }
1029 AuthError::RefreshFailed => {
1030 ("RefreshFailed", Some("Token refresh request failed"))
1031 }
1032 AuthError::NotAuthenticated => (
1033 "AuthenticationRequired",
1034 Some("Request requires authentication but none was provided"),
1035 ),
1036 AuthError::Other(hv) => {
1037 let msg = hv.to_str().unwrap_or("[non-utf8 header]");
1038 ("AuthenticationError", Some(msg))
1039 }
1040 };
1041 state.serialize_field("error", error)?;
1042 if let Some(msg) = message {
1043 state.serialize_field("message", msg)?;
1044 }
1045 state.end()
1046 }
1047 XrpcError::Decode(decode_err) => {
1048 let mut state = serializer.serialize_struct("XrpcError", 2)?;
1049 state.serialize_field("error", "ResponseDecodeError")?;
1050 let msg = format!("{:?}", decode_err);
1052 state.serialize_field("message", &msg)?;
1053 state.end()
1054 }
1055 }
1056 }
1057}
1058
#[cfg(feature = "streaming")]
impl<'a, C: HttpClient + HttpClientExt> XrpcCall<'a, C> {
    /// Sends `request` and returns the response with an unbuffered, streaming body.
    ///
    /// # Errors
    /// Request-building and transport failures, both wrapped with `StreamError::transport`.
    pub async fn download<R>(self, request: &R) -> Result<StreamingResponse, StreamError>
    where
        R: XrpcRequest,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        let XrpcCall { client, base, opts } = self;

        let outgoing =
            build_http_request(&base, request, &opts).map_err(StreamError::transport)?;

        let (parts, body) = client
            .send_http_streaming(outgoing)
            .await
            .map_err(StreamError::transport)?
            .into_parts();

        Ok(StreamingResponse::new(parts, body))
    }

    /// Sends the frames of `stream` as a streamed `POST` body and returns the
    /// response as a stream of typed frames.
    ///
    /// Only the authorization, `atproto-proxy`, `atproto-accept-labelers` and extra
    /// headers from the call options are set here.
    ///
    /// # Errors
    /// An invalid auth token or header set (`StreamError::protocol`), or a
    /// transport failure (`StreamError::transport`).
    pub async fn stream<S>(
        self,
        stream: XrpcProcedureSend<S::Frame<'static>>,
    ) -> Result<XrpcResponseStream<<S::Response as XrpcStreamResp>::Frame<'static>>, StreamError>
    where
        S: XrpcProcedureStream + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>: XrpcStreamResp,
    {
        use futures::TryStreamExt;

        let XrpcCall {
            client,
            base: mut url,
            opts,
        } = self;

        let endpoint_path = format!(
            "{}/xrpc/{}",
            url.path().trim_end_matches('/'),
            <S::Request as XrpcRequest>::NSID
        );
        url.set_path(&endpoint_path);

        let mut builder = http::Request::post(url.to_string());

        if let Some(token) = &opts.auth {
            let credentials = match token {
                AuthorizationToken::Bearer(t) => format!("Bearer {}", t.as_ref()),
                AuthorizationToken::Dpop(t) => format!("DPoP {}", t.as_ref()),
            };
            let value = HeaderValue::from_str(&credentials).map_err(|e| {
                StreamError::protocol(format!("Invalid authorization token: {}", e))
            })?;
            builder = builder.header(Header::Authorization, value);
        }

        if let Some(proxy) = &opts.atproto_proxy {
            builder = builder.header(Header::AtprotoProxy, proxy.as_ref());
        }

        let labelers = opts
            .atproto_accept_labelers
            .as_ref()
            .filter(|labelers| !labelers.is_empty());
        if let Some(labelers) = labelers {
            let joined = labelers
                .iter()
                .map(|s| s.as_ref())
                .collect::<Vec<_>>()
                .join(", ");
            builder = builder.header(Header::AtprotoAcceptLabelers, joined);
        }

        for (name, value) in &opts.extra_headers {
            builder = builder.header(name, value);
        }

        // Only the head is needed; the body is supplied as a stream below.
        let (head, _) = builder
            .body(())
            .map_err(|e| StreamError::protocol(e.to_string()))?
            .into_parts();

        let frames = Box::pin(stream.0.map_ok(|frame| frame.buffer));

        let (parts, body) = client
            .send_http_bidirectional(head, frames)
            .await
            .map_err(StreamError::transport)?
            .into_parts();

        Ok(XrpcResponseStream::<
            <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<'static>,
        >::from_typed_parts(parts, body))
    }
}
1155
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    #[allow(dead_code)]
    struct DummyReq;

    #[derive(Deserialize, Serialize, Debug, thiserror::Error)]
    #[error("{0}")]
    struct DummyErr<'a>(#[serde(borrow)] CowStr<'a>);

    impl IntoStatic for DummyErr<'_> {
        type Output = DummyErr<'static>;
        fn into_static(self) -> Self::Output {
            let DummyErr(text) = self;
            DummyErr(text.into_static())
        }
    }

    struct DummyResp;

    impl XrpcResp for DummyResp {
        const NSID: &'static str = "test.dummy";
        const ENCODING: &'static str = "application/json";
        type Output<'de> = ();
        type Err<'de> = DummyErr<'de>;
    }

    impl XrpcRequest for DummyReq {
        const NSID: &'static str = "test.dummy";
        const METHOD: XrpcMethod = XrpcMethod::Procedure("application/json");
        type Response = DummyResp;
    }

    /// Wraps a JSON value as a buffered `Response<DummyResp>` with the given status.
    fn json_response(body: serde_json::Value, status: StatusCode) -> Response<DummyResp> {
        let bytes = serde_json::to_vec(&body).expect("a json! literal always serializes");
        Response::new(Bytes::from(bytes), status)
    }

    #[test]
    fn generic_error_carries_context() {
        let resp = json_response(
            serde_json::json!({"error": "InvalidRequest", "message": "missing"}),
            StatusCode::BAD_REQUEST,
        );
        match resp.parse().unwrap_err() {
            XrpcError::Generic(generic) => {
                assert_eq!(generic.error.as_str(), "InvalidRequest");
                assert_eq!(generic.message.as_deref(), Some("missing"));
                assert_eq!(generic.nsid, DummyResp::NSID);
                assert_eq!(generic.method, "");
                assert_eq!(generic.http_status, StatusCode::BAD_REQUEST);
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn auth_error_mapping() {
        let cases = [
            ("ExpiredToken", AuthError::TokenExpired),
            ("InvalidToken", AuthError::InvalidToken),
        ];
        for (code, expected) in cases {
            let resp = json_response(serde_json::json!({ "error": code }), StatusCode::UNAUTHORIZED);
            match resp.parse().unwrap_err() {
                XrpcError::Auth(actual) => assert!(
                    matches!(
                        (&actual, &expected),
                        (AuthError::TokenExpired, AuthError::TokenExpired)
                            | (AuthError::InvalidToken, AuthError::InvalidToken)
                    ),
                    "mismatch: {actual:?} vs {expected:?}"
                ),
                other => panic!("unexpected: {other:?}"),
            }
        }
    }

    #[test]
    fn no_double_slash_in_path() {
        #[derive(Serialize, Deserialize)]
        struct Req;

        // Named `ReqErr` (not `Err`) so it does not shadow `Result::Err` in this scope.
        #[derive(Deserialize, Serialize, Debug, thiserror::Error)]
        #[error("{0}")]
        struct ReqErr<'a>(#[serde(borrow)] CowStr<'a>);

        impl IntoStatic for ReqErr<'_> {
            type Output = ReqErr<'static>;
            fn into_static(self) -> Self::Output {
                ReqErr(self.0.into_static())
            }
        }

        struct Resp;

        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<'de> = ();
            type Err<'de> = ReqErr<'de>;
        }

        impl XrpcRequest for Req {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }

        let opts = CallOptions::default();
        for raw in ["https://pds", "https://pds/", "https://pds/base/"] {
            let base = Url::parse(raw).expect("test base URLs are valid");
            let uri = build_http_request(&base, &Req, &opts)
                .expect("a unit query request always builds")
                .uri()
                .to_string();
            assert!(uri.contains("/xrpc/com.example.test"), "bad uri for {raw}: {uri}");
            assert!(!uri.contains("//xrpc"), "double slash for {raw}: {uri}");
        }
    }
}