1#[cfg(feature = "streaming")]
14pub mod streaming;
15
16pub mod atproto;
18
19use alloc::string::{String, ToString};
20use alloc::vec::Vec;
21use ipld_core::ipld::Ipld;
22#[cfg(feature = "streaming")]
23pub use streaming::{
24 StreamingResponse, XrpcProcedureSend, XrpcProcedureStream, XrpcResponseStream, XrpcStreamResp,
25};
26
27#[cfg(feature = "websocket")]
28pub mod subscription;
29
30#[cfg(feature = "streaming")]
31use crate::StreamError;
32use crate::bos::BosStr;
33use crate::error::DecodeError;
34use crate::http_client::HttpClient;
35#[cfg(feature = "streaming")]
36use crate::http_client::HttpClientExt;
37use crate::types::value::Data;
38use crate::{AuthorizationToken, error::AuthError};
39use crate::{BorrowOrShare, DefaultStr};
40use crate::{CowStr, error::XrpcResult};
41use crate::{IntoStatic, types::value::RawData};
42use bytes::Bytes;
43use core::error::Error;
44use core::fmt::{self, Debug};
45use core::marker::PhantomData;
46use http::{
47 HeaderName, HeaderValue, Request, StatusCode,
48 header::{AUTHORIZATION, CONTENT_TYPE},
49};
50use serde::de::DeserializeOwned;
51use serde::{Deserialize, Serialize};
52use smol_str::SmolStr;
53
54use crate::deps::fluent_uri::Uri;
55#[cfg(feature = "websocket")]
56pub use subscription::{
57 BasicSubscriptionClient, MessageEncoding, SubscriptionCall, SubscriptionClient,
58 SubscriptionEndpoint, SubscriptionExt, SubscriptionOptions, SubscriptionResp,
59 SubscriptionStream, TungsteniteSubscriptionClient, XrpcSubscription,
60};
61
62pub fn normalize_base_uri(uri: Uri<String>) -> Uri<String> {
69 let s = uri.as_str();
70 if s.ends_with('/') && s.len() > 1 {
71 let trimmed = s.trim_end_matches('/');
72 Uri::parse(trimmed.to_string())
74 .expect("trimming trailing slash from valid URI yields valid URI")
75 } else {
76 uri
77 }
78}
79
/// Errors raised while encoding an XRPC request or response payload.
#[derive(Debug, thiserror::Error)]
#[cfg_attr(feature = "std", derive(miette::Diagnostic))]
#[non_exhaustive]
pub enum EncodeError {
    /// A value could not be serialised as a URL query string (`serde_html_form`).
    #[error("Failed to serialize query: {0}")]
    Query(
        #[from]
        #[source]
        serde_html_form::ser::Error,
    ),
    /// A value could not be serialised as JSON. This is the error produced by the
    /// default `XrpcRequest::encode_body` / `XrpcResp::encode_output` implementations.
    #[error("Failed to serialize JSON: {0}")]
    Json(
        #[from]
        #[source]
        serde_json::Error,
    ),
    /// Any other encoding failure, described by a free-form message.
    #[error("Encoding error: {0}")]
    Other(String),
}
103
/// HTTP verb selector for an XRPC call.
///
/// Queries are sent as `GET`. Procedures are sent as `POST` and carry the
/// content type of their request body.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XrpcMethod {
    /// Read-only query, sent with `GET` and no body.
    Query,
    /// Procedure, sent with `POST`; the payload is the body's content type.
    Procedure(&'static str),
}

impl XrpcMethod {
    /// Returns the HTTP method name: `"GET"` for queries, `"POST"` for procedures.
    pub const fn as_str(&self) -> &'static str {
        if matches!(*self, Self::Query) {
            "GET"
        } else {
            "POST"
        }
    }

    /// Returns the request-body content type for a procedure, or `None` for a query.
    pub const fn body_encoding(&self) -> Option<&'static str> {
        match *self {
            Self::Procedure(encoding) => Some(encoding),
            Self::Query => None,
        }
    }
}
130
131pub trait XrpcRequest {
138 const NSID: &'static str;
140
141 const METHOD: XrpcMethod;
143
144 type Response: XrpcResp;
146
147 fn encode_body(&self, buffer: &mut Vec<u8>) -> Result<(), EncodeError>
151 where
152 Self: Serialize,
153 {
154 Ok(serde_json::to_writer(buffer, self)?)
155 }
156
157 fn decode_body<'de>(body: &'de [u8]) -> Result<Self, DecodeError>
161 where
162 Self: Deserialize<'de>,
163 {
164 let body: Self = serde_json::from_slice(body)?;
165
166 Ok(body)
167 }
168}
169
170pub trait XrpcResp {
181 const NSID: &'static str;
183
184 const ENCODING: &'static str;
186
187 type Output<S: BosStr>;
189
190 type Err: Error + Serialize + DeserializeOwned;
192
193 fn encode_output<S: BosStr>(output: &Self::Output<S>) -> Result<Vec<u8>, EncodeError>
197 where
198 Self::Output<S>: Serialize,
199 {
200 Ok(serde_json::to_vec(output)?)
201 }
202
203 fn decode_output<'de, S>(body: &'de [u8]) -> core::result::Result<Self::Output<S>, DecodeError>
207 where
208 S: BosStr + Deserialize<'de>,
209 Self::Output<S>: Deserialize<'de>,
210 {
211 let body = serde_json::from_slice(body).map_err(DecodeError::Json)?;
212 Ok(body)
213 }
214}
215
/// Static description of an XRPC endpoint, pairing its path and method with
/// its request and response types.
pub trait XrpcEndpoint {
    /// Endpoint path. NOTE(review): presumably `/xrpc/<nsid>`; confirm against implementors.
    const PATH: &'static str;
    /// Whether the endpoint is a query or a procedure.
    const METHOD: XrpcMethod;
    /// Request type, generic over the string storage type `S`.
    type Request<S: BosStr>: XrpcRequest;
    /// Response marker type.
    type Response: XrpcResp;
}
233
/// Catch-all error payload that keeps an arbitrary error body as untyped [`Data`].
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct GenericError(Data<SmolStr>);

impl fmt::Display for GenericError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // NOTE(review): this file imports `core::fmt::{self, Debug}` but not the
        // `Display` trait by name. This method call therefore likely resolves to
        // `Debug::fmt` (unless `Data` has an inherent `fmt`), so Display output
        // would be Debug-formatted. Confirm that is intended.
        self.0.fmt(f)
    }
}

impl Error for GenericError {}
247
/// Per-call options that `build_http_request` turns into request headers.
#[derive(Debug, Clone)]
pub struct CallOptions<S: BosStr = DefaultStr> {
    /// Credential sent in the `Authorization` header as `Bearer …` or `DPoP …`.
    pub auth: Option<AuthorizationToken<S>>,
    /// Value for the `atproto-proxy` header.
    pub atproto_proxy: Option<S>,
    /// Labelers joined with `", "` into the `atproto-accept-labelers` header;
    /// an empty list sends no header.
    pub atproto_accept_labelers: Option<Vec<S>>,
    /// Additional headers appended to the request after the ones above.
    pub extra_headers: Vec<(HeaderName, HeaderValue)>,
}
260
261impl Default for CallOptions {
262 fn default() -> Self {
263 Self {
264 auth: None,
265 atproto_proxy: None,
266 atproto_accept_labelers: None,
267 extra_headers: Vec::new(),
268 }
269 }
270}
271
272impl<S: BosStr> CallOptions<S> {
273 pub fn borrow(&self) -> CallOptions<&str> {
275 CallOptions {
276 auth: self.auth.as_ref().map(|auth| auth.borrow()),
277 atproto_proxy: self
278 .atproto_proxy
279 .as_ref()
280 .map(|proxy| proxy.borrow_or_share()),
281 atproto_accept_labelers: self
282 .atproto_accept_labelers
283 .as_ref()
284 .map(|labelers| labelers.iter().map(|l| l.as_ref()).collect()),
285 extra_headers: self.extra_headers.clone(),
286 }
287 }
288}
289
290impl<S: BosStr + IntoStatic> IntoStatic for CallOptions<S>
291where
292 <S as IntoStatic>::Output: BosStr + 'static,
293{
294 type Output = CallOptions<<S as IntoStatic>::Output>;
295
296 fn into_static(self) -> Self::Output {
297 CallOptions {
298 auth: self.auth.map(|auth| auth.into_static()),
299 atproto_proxy: self.atproto_proxy.map(|proxy| proxy.into_static()),
300 atproto_accept_labelers: self
301 .atproto_accept_labelers
302 .map(|labelers| labelers.into_static()),
303 extra_headers: self.extra_headers,
304 }
305 }
306}
307
308pub trait XrpcExt: HttpClient {
325 fn xrpc<'a>(&'a self, base: Uri<&'a str>) -> XrpcCall<'a, Self>
327 where
328 Self: Sized,
329 {
330 XrpcCall {
331 client: self,
332 base,
333 opts: CallOptions::default(),
334 }
335 }
336}
337
338impl<T: HttpClient> XrpcExt for T {}
339
/// Buffered response type produced for request type `R`.
pub type XrpcResponse<R> = Response<<R as XrpcRequest>::Response>;

/// Stateful XRPC client that owns its base URI and default call options.
///
/// On non-wasm32 targets a `Send` variant of this trait is generated by
/// `trait_variant::make`.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait XrpcClient: HttpClient {
    /// Returns the base URI requests are sent to.
    fn base_uri(&self) -> impl Future<Output = Uri<String>>;

    /// Replaces the base URI. The default implementation discards `uri` and does nothing.
    fn set_base_uri(&self, uri: Uri<String>) -> impl Future<Output = ()> {
        let _ = uri;
        async {}
    }

    /// Returns the client's default call options. The default implementation
    /// returns `CallOptions::default()`.
    fn opts(&self) -> impl Future<Output = CallOptions> {
        async { CallOptions::default() }
    }

    /// Replaces the default call options. The default implementation discards
    /// `opts` and does nothing.
    fn set_opts(&self, opts: CallOptions) -> impl Future<Output = ()> {
        let _ = opts;
        async {}
    }

    /// Sends `request` and returns the buffered response (non-wasm32: also requires `Self: Sync`).
    #[cfg(not(target_arch = "wasm32"))]
    fn send<R>(&self, request: R) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;

    /// Sends `request` and returns the buffered response (wasm32 variant).
    #[cfg(target_arch = "wasm32")]
    fn send<R>(&self, request: R) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync;

    /// Like `send`, but with caller-supplied `opts` for this call (non-wasm32).
    #[cfg(not(target_arch = "wasm32"))]
    fn send_with_opts<R>(
        &self,
        request: R,
        opts: CallOptions,
    ) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;

    /// Like `send`, but with caller-supplied `opts` for this call (wasm32 variant).
    #[cfg(target_arch = "wasm32")]
    fn send_with_opts<R>(
        &self,
        request: R,
        opts: CallOptions,
    ) -> impl Future<Output = XrpcResult<XrpcResponse<R>>>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync;
}
406
/// XRPC client that can return streamed responses and send streamed procedure bodies.
///
/// Each method has a `cfg`-selected declaration per target; the non-wasm32
/// versions additionally require `Self: Sync`.
#[cfg(feature = "streaming")]
pub trait XrpcStreamingClient: XrpcClient + HttpClientExt {
    /// Sends `request` and returns a [`StreamingResponse`] (non-wasm32; the future is `Send`).
    #[cfg(not(target_arch = "wasm32"))]
    fn download<R>(
        &self,
        request: R,
    ) -> impl Future<Output = Result<StreamingResponse, StreamError>> + Send
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync,
        Self: Sync;

    /// Sends `request` and returns a [`StreamingResponse`] (wasm32 variant).
    #[cfg(target_arch = "wasm32")]
    fn download<R>(
        &self,
        request: R,
    ) -> impl Future<Output = Result<StreamingResponse, StreamError>>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync;

    /// Sends a stream of procedure frames and returns the typed response frame stream.
    ///
    /// NOTE(review): unlike the non-wasm32 `download`, this non-wasm32 future is
    /// not declared `+ Send`; confirm whether that asymmetry is intended.
    #[cfg(not(target_arch = "wasm32"))]
    fn stream<S, B>(
        &self,
        stream: XrpcProcedureSend<S::Frame<B>>,
    ) -> impl Future<
        Output = Result<
            XrpcResponseStream<<<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<B>>,
            StreamError,
        >,
    >
    where
        B: BosStr + 'static,
        S: XrpcProcedureStream + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<B>: XrpcStreamResp,
        Self: Sync;

    /// Sends a stream of procedure frames and returns the typed response frame stream (wasm32).
    #[cfg(target_arch = "wasm32")]
    fn stream<S, B>(
        &self,
        stream: XrpcProcedureSend<S::Frame<B>>,
    ) -> impl Future<
        Output = Result<
            XrpcResponseStream<<<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<B>>,
            StreamError,
        >,
    >
    where
        B: BosStr + 'static,
        S: XrpcProcedureStream + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<B>: XrpcStreamResp;
}
464
/// One-shot XRPC call builder returned by `XrpcExt::xrpc`.
///
/// Borrows the HTTP client and base URI, and accumulates the [`CallOptions`]
/// set through its builder methods until `send` consumes it.
pub struct XrpcCall<'a, C: HttpClient> {
    /// Client used to perform the HTTP request.
    pub(crate) client: &'a C,
    /// Base URI; `/xrpc/<nsid>` is appended to its path.
    pub(crate) base: Uri<&'a str>,
    /// Options turned into request headers.
    pub(crate) opts: CallOptions,
}
491
492impl<'a, C: HttpClient> XrpcCall<'a, C> {
493 pub fn auth(mut self, token: AuthorizationToken) -> Self {
495 self.opts.auth = Some(token);
496 self
497 }
498 pub fn proxy(mut self, proxy: DefaultStr) -> Self {
500 self.opts.atproto_proxy = Some(proxy);
501 self
502 }
503 pub fn accept_labelers(mut self, labelers: Vec<DefaultStr>) -> Self {
505 self.opts.atproto_accept_labelers = Some(labelers);
506 self
507 }
508 pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self {
510 self.opts.extra_headers.push((name, value));
511 self
512 }
513 pub fn with_options(mut self, opts: CallOptions) -> Self {
515 self.opts = opts;
516 self
517 }
518
519 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, request), fields(nsid = R::NSID)))]
528 pub async fn send<R>(self, request: &R) -> XrpcResult<Response<<R as XrpcRequest>::Response>>
529 where
530 R: XrpcRequest + Serialize,
531 <R as XrpcRequest>::Response: Send + Sync,
532 {
533 let http_request = build_http_request(&self.base, request, &self.opts)?;
534
535 let http_response = self
536 .client
537 .send_http(http_request)
538 .await
539 .map_err(|e| crate::error::ClientError::transport(e).for_nsid(R::NSID))?;
540
541 process_response(http_response)
542 }
543}
544
545#[inline]
549pub fn process_response<Resp>(http_response: http::Response<Vec<u8>>) -> XrpcResult<Response<Resp>>
550where
551 Resp: XrpcResp,
552{
553 let status = http_response.status();
554
555 #[allow(deprecated)]
558 if status.as_u16() == 401 {
559 if let Some(hv) = http_response.headers().get(http::header::WWW_AUTHENTICATE) {
560 return Err(
561 crate::error::ClientError::auth(crate::error::AuthError::Other(hv.clone()))
562 .for_nsid(Resp::NSID),
563 );
564 }
565 }
566 let buffer = Bytes::from(http_response.into_body());
567
568 if !status.is_success() && !matches!(status.as_u16(), 400 | 401) {
569 return Err(crate::error::ClientError::from(crate::error::HttpError {
570 status,
571 body: Some(buffer),
572 })
573 .for_nsid(Resp::NSID));
574 }
575
576 Ok(Response::new(buffer, status))
577}
578
579pub enum Header {
581 ContentType,
583 Authorization,
585 AtprotoProxy,
589 AtprotoAcceptLabelers,
591}
592
593impl From<Header> for HeaderName {
594 fn from(value: Header) -> Self {
595 match value {
596 Header::ContentType => CONTENT_TYPE,
597 Header::Authorization => AUTHORIZATION,
598 Header::AtprotoProxy => HeaderName::from_static("atproto-proxy"),
599 Header::AtprotoAcceptLabelers => HeaderName::from_static("atproto-accept-labelers"),
600 }
601 }
602}
603
604fn xrpc_endpoint_uri(base: &Uri<&str>, nsid: &str, query: Option<&str>) -> XrpcResult<Uri<String>> {
613 use crate::error::ClientError;
614
615 let base_path = base.path().as_str().trim_end_matches('/');
616
617 let capacity = base.scheme().as_str().len()
619 + 3 + base.authority().map(|a| a.as_str().len()).unwrap_or(0)
621 + base_path.len()
622 + 6 + nsid.len()
624 + query.map(|q| q.len() + 1).unwrap_or(0); let mut uri_str = String::with_capacity(capacity);
628 uri_str.push_str(base.scheme().as_str());
629 uri_str.push_str("://");
630
631 if let Some(authority) = base.authority() {
632 uri_str.push_str(authority.as_str());
633 }
634
635 uri_str.push_str(base_path);
636 uri_str.push_str("/xrpc/");
637 uri_str.push_str(nsid);
638
639 if let Some(q) = query {
640 uri_str.push('?');
641 uri_str.push_str(q);
642 }
643
644 Uri::parse(uri_str)
645 .map_err(|_| ClientError::invalid_request("Failed to construct XRPC endpoint URI"))
646}
647
648pub fn build_http_request<'s, R>(
650 base: &Uri<&str>,
651 req: &R,
652 opts: &CallOptions,
653) -> XrpcResult<Request<Vec<u8>>>
654where
655 R: XrpcRequest + Serialize,
656{
657 use crate::error::ClientError;
658
659 let query_string = if let XrpcMethod::Query = <R as XrpcRequest>::METHOD {
661 let qs = serde_html_form::to_string(&req).map_err(|e| {
662 ClientError::invalid_request(format!("Failed to serialize query: {}", e))
663 })?;
664 if !qs.is_empty() { Some(qs) } else { None }
665 } else {
666 None
667 };
668
669 let uri = xrpc_endpoint_uri(base, <R as XrpcRequest>::NSID, query_string.as_deref())?;
671
672 let method = match <R as XrpcRequest>::METHOD {
673 XrpcMethod::Query => http::Method::GET,
674 XrpcMethod::Procedure(_) => http::Method::POST,
675 };
676
677 let mut builder = Request::builder().method(method).uri(uri.as_str());
678
679 let has_content_type = opts
680 .extra_headers
681 .iter()
682 .any(|(name, _)| name == CONTENT_TYPE);
683
684 if let XrpcMethod::Procedure(encoding) = <R as XrpcRequest>::METHOD {
685 if !has_content_type {
687 builder = builder.header(Header::ContentType, encoding);
688 }
689 }
690 let output_encoding = <R::Response as XrpcResp>::ENCODING;
691 builder = builder.header(http::header::ACCEPT, output_encoding);
692
693 if let Some(token) = &opts.auth {
694 let hv = match token {
695 AuthorizationToken::Bearer(t) => {
696 HeaderValue::from_str(&format!("Bearer {}", t.as_str()))
697 }
698 AuthorizationToken::Dpop(t) => HeaderValue::from_str(&format!("DPoP {}", t.as_str())),
699 }
700 .map_err(|e| ClientError::invalid_request(format!("Invalid authorization token: {}", e)))?;
701 builder = builder.header(Header::Authorization, hv);
702 }
703
704 if let Some(proxy) = &opts.atproto_proxy {
705 builder = builder.header(Header::AtprotoProxy, proxy.as_str());
706 }
707 if let Some(labelers) = &opts.atproto_accept_labelers {
708 if !labelers.is_empty() {
709 let joined = labelers
710 .iter()
711 .map(|s| s.as_ref())
712 .collect::<Vec<_>>()
713 .join(", ");
714 builder = builder.header(Header::AtprotoAcceptLabelers, joined);
715 }
716 }
717 for (name, value) in &opts.extra_headers {
718 builder = builder.header(name, value);
719 }
720
721 let body = if let XrpcMethod::Procedure(_) = R::METHOD {
722 let mut buf = Vec::with_capacity(300);
723 req.encode_body(&mut buf)
724 .map_err(|e| ClientError::invalid_request(format!("Failed to encode body: {}", e)))?;
725 buf
726 } else {
727 vec![]
728 };
729
730 builder
731 .body(body)
732 .map_err(|e| ClientError::invalid_request(format!("Failed to build request: {}", e)))
733}
734
/// Buffered XRPC response: the raw body bytes and HTTP status, typed by the
/// response marker `Resp` so the body can be parsed as `Resp::Output` or `Resp::Err`.
pub struct Response<Resp>
where
    Resp: XrpcResp, {
    // `fn() -> Resp` is always `Send + Sync`, so the marker adds no auto-trait
    // requirements on `Resp` while still tying the response to it.
    _marker: PhantomData<fn() -> Resp>,
    // Response body exactly as received.
    buffer: Bytes,
    // HTTP status used to choose between output and error parsing.
    status: StatusCode,
}
747
748impl<R> Response<R>
749where
750 R: XrpcResp,
751{
752 pub fn new(buffer: Bytes, status: StatusCode) -> Self {
754 Self {
755 buffer,
756 status,
757 _marker: PhantomData,
758 }
759 }
760
761 pub fn status(&self) -> StatusCode {
763 self.status
764 }
765
766 pub fn buffer(&self) -> &Bytes {
768 &self.buffer
769 }
770
771 pub fn parse<'s, S>(&'s self) -> Result<R::Output<S>, XrpcError<R::Err>>
776 where
777 S: BosStr + Deserialize<'s>,
778 R::Output<S>: Deserialize<'s>,
779 {
780 if self.status.is_success() {
781 R::decode_output::<S>(&self.buffer).map_err(XrpcError::Decode)
782 } else {
783 Err(self.parse_error())
784 }
785 }
786
787 pub fn parse_data(&self) -> Result<Data<CowStr<'_>>, XrpcError<R::Err>> {
792 if self.status.is_success() {
793 match serde_json::from_slice::<Data<CowStr<'_>>>(&self.buffer) {
794 Ok(output) => Ok(output),
795 Err(_) => {
796 if let Ok(ipld) = serde_ipld_dagcbor::from_slice::<Ipld>(&self.buffer) {
797 if let Ok(data) = RawData::from_cbor(&ipld) {
798 Ok(data
800 .into_static()
801 .try_into()
802 .unwrap_or(Data::Bytes(self.buffer.clone())))
803 } else {
804 Ok(Data::Bytes(self.buffer.clone()))
805 }
806 } else {
807 Ok(Data::Bytes(self.buffer.clone()))
808 }
809 }
810 }
811 } else {
812 Err(self.parse_error())
813 }
814 }
815
816 pub fn parse_raw(&self) -> Result<RawData<'_>, XrpcError<R::Err>> {
820 if self.status.is_success() {
821 match serde_json::from_slice::<RawData<'_>>(&self.buffer) {
822 Ok(output) => Ok(output),
823 Err(_) => {
824 if let Ok(ipld) = serde_ipld_dagcbor::from_slice::<Ipld>(&self.buffer) {
825 if let Ok(data) = RawData::from_cbor(&ipld) {
826 Ok(data.into_static())
827 } else {
828 Ok(RawData::Bytes(self.buffer.clone()))
829 }
830 } else {
831 Ok(RawData::Bytes(self.buffer.clone()))
832 }
833 }
834 }
835 } else {
836 Err(self.parse_error())
837 }
838 }
839
840 fn parse_error(&self) -> XrpcError<R::Err> {
842 if self.status.as_u16() == 400 {
844 match serde_json::from_slice::<R::Err>(&self.buffer) {
845 Ok(error) => {
846 use alloc::string::ToString;
847 if error.to_string().contains("InvalidToken") {
848 XrpcError::Auth(AuthError::InvalidToken)
849 } else if error.to_string().contains("ExpiredToken") {
850 XrpcError::Auth(AuthError::TokenExpired)
851 } else {
852 XrpcError::Xrpc(error)
853 }
854 }
855 Err(_) => self.parse_generic_error(),
856 }
857 } else {
859 self.parse_generic_error()
860 }
861 }
862
863 fn parse_generic_error(&self) -> XrpcError<R::Err> {
865 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
866 Ok(mut generic) => {
867 generic.nsid = R::NSID;
868 generic.method = "";
869 generic.http_status = self.status;
870 match generic.error.as_str() {
871 "ExpiredToken" => XrpcError::Auth(AuthError::TokenExpired),
872 "InvalidToken" => XrpcError::Auth(AuthError::InvalidToken),
873 _ => XrpcError::Generic(generic),
874 }
875 }
876 Err(e) => XrpcError::Decode(DecodeError::Json(e)),
877 }
878 }
879
880 pub fn transmute<NEW: XrpcResp>(self) -> Response<NEW> {
892 Response {
893 buffer: self.buffer,
894 status: self.status,
895 _marker: PhantomData,
896 }
897 }
898}
899
/// Success output type of response marker `Resp` with string storage `S`.
pub type RespOutput<S, Resp> = <Resp as XrpcResp>::Output<S>;
/// Typed error type of response marker `Resp`.
pub type RespErr<Resp> = <Resp as XrpcResp>::Err;
904
905impl<R> Response<R>
906where
907 R: XrpcResp,
908{
909 pub fn into_output(self) -> Result<R::Output<SmolStr>, XrpcError<R::Err>>
911 where
912 R::Output<SmolStr>: DeserializeOwned,
913 {
914 if self.status.is_success() {
915 R::decode_output::<SmolStr>(&self.buffer).map_err(XrpcError::Decode)
916 } else {
917 Err(self.parse_error())
918 }
919 }
920}
921
922#[derive(Debug, Clone, Deserialize, Serialize)]
926pub struct GenericXrpcError {
927 pub error: SmolStr,
929 pub message: Option<SmolStr>,
931 #[serde(skip)]
933 pub nsid: &'static str,
934 #[serde(skip)]
936 pub method: &'static str,
937 #[serde(skip)]
939 pub http_status: StatusCode,
940}
941
942impl core::fmt::Display for GenericXrpcError {
943 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
944 if let Some(msg) = &self.message {
945 write!(
946 f,
947 "{}: {} (nsid={}, method={}, status={})",
948 self.error, msg, self.nsid, self.method, self.http_status
949 )
950 } else {
951 write!(
952 f,
953 "{} (nsid={}, method={}, status={})",
954 self.error, self.nsid, self.method, self.http_status
955 )
956 }
957 }
958}
959
960impl core::error::Error for GenericXrpcError {}
961
/// Failure of a typed XRPC call, produced when parsing a [`Response`].
#[derive(Debug, thiserror::Error)]
#[cfg_attr(feature = "std", derive(miette::Diagnostic))]
#[non_exhaustive]
pub enum XrpcError<E: core::error::Error> {
    /// The endpoint's typed error, decoded from a `400` body.
    #[error("XRPC error: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard_common::xrpc::typed)))]
    Xrpc(E),

    /// Authentication problem, e.g. an `ExpiredToken` or `InvalidToken` error code.
    #[error("Authentication error: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard_common::xrpc::auth)))]
    Auth(#[from] AuthError),

    /// Error body in the generic `{error, message}` shape.
    #[error("XRPC error: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard_common::xrpc::generic)))]
    Generic(GenericXrpcError),

    /// The response body could not be decoded.
    #[error("Failed to decode response: {0}")]
    #[cfg_attr(feature = "std", diagnostic(code(jacquard_common::xrpc::decode)))]
    Decode(#[from] DecodeError),
}
991
992impl<E> Serialize for XrpcError<E>
993where
994 E: core::error::Error + Serialize,
995{
996 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
997 where
998 S: serde::Serializer,
999 {
1000 use serde::ser::SerializeStruct;
1001
1002 match self {
1003 XrpcError::Xrpc(e) => e.serialize(serializer),
1005 XrpcError::Generic(g) => g.serialize(serializer),
1007 XrpcError::Auth(auth) => {
1009 let mut state = serializer.serialize_struct("XrpcError", 2)?;
1010 let (error, message) = match auth {
1011 AuthError::TokenExpired => ("ExpiredToken", Some("Access token has expired")),
1012 AuthError::InvalidToken => {
1013 ("InvalidToken", Some("Access token is invalid or malformed"))
1014 }
1015 AuthError::RefreshFailed => {
1016 ("RefreshFailed", Some("Token refresh request failed"))
1017 }
1018 AuthError::NotAuthenticated => (
1019 "AuthenticationRequired",
1020 Some("Request requires authentication but none was provided"),
1021 ),
1022 AuthError::DpopProofFailed => {
1023 ("DpopProofFailed", Some("DPoP proof construction failed"))
1024 }
1025 AuthError::DpopNonceFailed => {
1026 ("DpopNonceFailed", Some("DPoP nonce negotiation failed"))
1027 }
1028 AuthError::Other(hv) => {
1029 let msg = hv.to_str().unwrap_or("[non-utf8 header]");
1030 ("AuthenticationError", Some(msg))
1031 }
1032 };
1033 state.serialize_field("error", error)?;
1034 if let Some(msg) = message {
1035 state.serialize_field("message", msg)?;
1036 }
1037 state.end()
1038 }
1039 XrpcError::Decode(decode_err) => {
1040 let mut state = serializer.serialize_struct("XrpcError", 2)?;
1041 state.serialize_field("error", "ResponseDecodeError")?;
1042 let msg = format!("{:?}", decode_err);
1044 state.serialize_field("message", &msg)?;
1045 state.end()
1046 }
1047 }
1048 }
1049}
1050
#[cfg(feature = "streaming")]
impl<'a, C: HttpClient + HttpClientExt> XrpcCall<'a, C> {
    /// Builds the request like `send`, but sends it with `send_http_streaming`
    /// and returns the response head and body as a [`StreamingResponse`].
    ///
    /// Unlike `send`, the response status is not inspected here.
    ///
    /// # Errors
    ///
    /// Request-building and transport failures are mapped to `StreamError::transport`.
    pub async fn download<R>(self, request: &R) -> Result<StreamingResponse, StreamError>
    where
        R: XrpcRequest + Serialize,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        let http_request =
            build_http_request(&self.base, request, &self.opts).map_err(StreamError::transport)?;

        let http_response = self
            .client
            .send_http_streaming(http_request)
            .await
            .map_err(StreamError::transport)?;
        let (parts, body) = http_response.into_parts();

        Ok(StreamingResponse::new(parts, body))
    }

    /// POSTs a stream of procedure frames to `S::Request::NSID` and returns the
    /// response as a typed frame stream.
    ///
    /// Headers mirror `build_http_request` (authorization, proxy, labelers,
    /// extra headers). No `Content-Type` or `Accept` header is set here.
    ///
    /// NOTE(review): the header-building code duplicates `build_http_request`;
    /// consider a shared helper.
    ///
    /// # Errors
    ///
    /// URI, token and builder failures map to `StreamError::protocol`;
    /// transport failures map to `StreamError::transport`.
    pub async fn stream<S, B>(
        self,
        stream: XrpcProcedureSend<S::Frame<B>>,
    ) -> Result<XrpcResponseStream<<S::Response as XrpcStreamResp>::Frame<B>>, StreamError>
    where
        S: XrpcProcedureStream + 'static,
        B: BosStr + 'static,
        <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<B>: XrpcStreamResp,
    {
        use alloc::boxed::Box;
        use futures::TryStreamExt;

        let uri = xrpc_endpoint_uri(&self.base, <S::Request as XrpcRequest>::NSID, None).map_err(
            |e| StreamError::protocol(format!("Failed to construct endpoint URI: {}", e)),
        )?;

        let mut builder = http::Request::post(uri.as_str());

        if let Some(token) = &self.opts.auth {
            let hv = match token {
                AuthorizationToken::Bearer(t) => {
                    HeaderValue::from_str(&format!("Bearer {}", t.as_str()))
                }
                AuthorizationToken::Dpop(t) => {
                    HeaderValue::from_str(&format!("DPoP {}", t.as_str()))
                }
            }
            .map_err(|e| StreamError::protocol(format!("Invalid authorization token: {}", e)))?;
            builder = builder.header(Header::Authorization, hv);
        }

        if let Some(proxy) = &self.opts.atproto_proxy {
            builder = builder.header(Header::AtprotoProxy, proxy.as_str());
        }
        if let Some(labelers) = &self.opts.atproto_accept_labelers {
            if !labelers.is_empty() {
                let joined = labelers
                    .iter()
                    .map(|s| s.as_ref())
                    .collect::<Vec<_>>()
                    .join(", ");
                builder = builder.header(Header::AtprotoAcceptLabelers, joined);
            }
        }

        for (name, value) in &self.opts.extra_headers {
            builder = builder.header(name, value);
        }

        // Only the request head is needed; the body is supplied as a stream below.
        let (parts, _) = builder
            .body(())
            .map_err(|e| StreamError::protocol(e.to_string()))?
            .into_parts();

        // Each outgoing frame contributes its `buffer` to the request body stream.
        let body_stream = Box::pin(stream.0.map_ok(|f| f.buffer));

        let resp = self
            .client
            .send_http_bidirectional(parts, body_stream)
            .await
            .map_err(StreamError::transport)?;

        let (parts, body) = resp.into_parts();

        Ok(XrpcResponseStream::<
            <<S as XrpcProcedureStream>::Response as XrpcStreamResp>::Frame<B>,
        >::from_typed_parts::<B>(parts, body))
    }
}
1147
#[cfg(test)]
mod tests {
    // Unit tests for response error classification and XRPC URI/request construction.
    use super::*;
    use serde::{Deserialize, Serialize};

    // Minimal procedure request/response pair used by the error-parsing tests.
    #[derive(Serialize, Deserialize)]
    #[allow(dead_code)]
    struct DummyReq;

    // Typed error that is a bare newtype string, so an `{error, message}` object
    // does not deserialize into it and parsing falls back to the generic shape.
    #[derive(Deserialize, Serialize, Debug, thiserror::Error)]
    #[error("{0}")]
    struct DummyErr(SmolStr);

    struct DummyResp;

    impl XrpcResp for DummyResp {
        const NSID: &'static str = "test.dummy";
        const ENCODING: &'static str = "application/json";
        type Output<S: BosStr> = ();
        type Err = DummyErr;
    }

    impl XrpcRequest for DummyReq {
        const NSID: &'static str = "test.dummy";
        const METHOD: XrpcMethod = XrpcMethod::Procedure("application/json");
        type Response = DummyResp;
    }

    // A 400 body in the generic shape yields XrpcError::Generic with NSID and status filled in.
    #[test]
    fn generic_error_carries_context() {
        let body = serde_json::json!({"error":"InvalidRequest","message":"missing"});
        let buf = Bytes::from(serde_json::to_vec(&body).unwrap());
        let resp: Response<DummyResp> = Response::new(buf, StatusCode::BAD_REQUEST);
        match resp.parse::<SmolStr>().unwrap_err() {
            XrpcError::Generic(g) => {
                assert_eq!(g.error.as_str(), "InvalidRequest");
                assert_eq!(g.message.as_deref(), Some("missing"));
                assert_eq!(g.nsid, DummyResp::NSID);
                assert_eq!(g.method, "");
                assert_eq!(g.http_status, StatusCode::BAD_REQUEST);
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    // The ExpiredToken / InvalidToken codes on a 401 body map to the matching AuthError.
    #[test]
    fn auth_error_mapping() {
        for (code, expect) in [
            ("ExpiredToken", AuthError::TokenExpired),
            ("InvalidToken", AuthError::InvalidToken),
        ] {
            let body = serde_json::json!({"error": code});
            let buf = Bytes::from(serde_json::to_vec(&body).unwrap());
            let resp: Response<DummyResp> = Response::new(buf, StatusCode::UNAUTHORIZED);
            match resp.parse::<SmolStr>().unwrap_err() {
                XrpcError::Auth(e) => match (e, expect) {
                    (AuthError::TokenExpired, AuthError::TokenExpired) => {}
                    (AuthError::InvalidToken, AuthError::InvalidToken) => {}
                    other => panic!("mismatch: {other:?}"),
                },
                other => panic!("unexpected: {other:?}"),
            }
        }
    }

    // Endpoint URIs append /xrpc/<nsid>, keep base sub-paths, and drop a trailing slash.
    #[test]
    fn xrpc_uri_construction_basic() {
        use crate::alloc::string::ToString;
        #[derive(Serialize, Deserialize)]
        struct Req;
        #[derive(Deserialize, Serialize, Debug, thiserror::Error)]
        #[error("test error")]
        struct Err;
        struct Resp;
        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<S: BosStr> = ();
            type Err = Err;
        }
        impl XrpcRequest for Req {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }

        let opts = CallOptions::default();

        let base1 = Uri::parse("https://pds.example.com").expect("URI should be valid");
        let req1 = build_http_request(&base1, &Req, &opts).unwrap();
        let uri1 = req1.uri().to_string();
        assert!(
            uri1.contains("/xrpc/com.example.test"),
            "AC1.1: URI {} should contain '/xrpc/com.example.test'",
            uri1
        );
        assert_eq!(
            uri1, "https://pds.example.com/xrpc/com.example.test",
            "AC1.1: URI should be exact match"
        );

        let base2 = Uri::parse("https://pds.example.com/base").expect("URI should be valid");
        let req2 = build_http_request(&base2, &Req, &opts).unwrap();
        let uri2 = req2.uri().to_string();
        assert!(
            uri2.contains("/base/xrpc/com.example.test"),
            "AC1.2: URI {} should contain '/base/xrpc/com.example.test'",
            uri2
        );
        assert_eq!(
            uri2, "https://pds.example.com/base/xrpc/com.example.test",
            "AC1.2: URI should preserve sub-path"
        );

        let base_with_slash = Uri::parse("https://pds.example.com/").expect("URI should be valid");
        let req_slash = build_http_request(&base_with_slash, &Req, &opts).unwrap();
        let uri_slash = req_slash.uri().to_string();
        assert!(
            !uri_slash.contains("//xrpc"),
            "AC1.5: URI {} should not contain '//xrpc'",
            uri_slash
        );
        assert_eq!(
            uri_slash, "https://pds.example.com/xrpc/com.example.test",
            "AC1.5: URI should handle trailing slash"
        );
    }

    // Query parameters appear in the query string; an all-None request omits the '?'.
    #[test]
    fn xrpc_uri_query_parameters() {
        use crate::alloc::string::ToString;
        use serde::Serialize;

        #[derive(Serialize)]
        struct QueryReq {
            #[serde(skip_serializing_if = "Option::is_none")]
            param1: Option<String>,
            #[serde(skip_serializing_if = "Option::is_none")]
            param2: Option<String>,
        }

        #[derive(Serialize, Deserialize, Debug, thiserror::Error)]
        #[error("test error")]
        struct Err;
        struct Resp;
        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<S: BosStr> = ();
            type Err = Err;
        }
        impl XrpcRequest for QueryReq {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }

        let opts = CallOptions::default();
        let base = Uri::parse("https://pds.example.com").expect("URI should be valid");

        let req_with_params = QueryReq {
            param1: Some("value1".to_string()),
            param2: Some("value2".to_string()),
        };
        let http_req = build_http_request(&base, &req_with_params, &opts).unwrap();
        let uri_str = http_req.uri().to_string();
        assert!(
            uri_str.contains("?"),
            "AC1.3: URI should contain query string"
        );
        assert!(
            uri_str.contains("param1=value1"),
            "AC1.3: URI should contain param1"
        );
        assert!(
            uri_str.contains("param2=value2"),
            "AC1.3: URI should contain param2"
        );

        let req_empty_params = QueryReq {
            param1: None,
            param2: None,
        };
        let http_req_empty = build_http_request(&base, &req_empty_params, &opts).unwrap();
        let uri_str_empty = http_req_empty.uri().to_string();
        assert!(
            !uri_str_empty.contains("?"),
            "AC1.4: URI {} should not contain '?' with empty params",
            uri_str_empty
        );
        assert_eq!(
            uri_str_empty, "https://pds.example.com/xrpc/com.example.test",
            "AC1.4: URI should have no query string"
        );
    }

    // Spaces, reserved characters and non-ASCII text are encoded into a parseable URI.
    #[test]
    fn xrpc_uri_special_characters_in_query() {
        use crate::alloc::string::ToString;
        use serde::Serialize;

        #[derive(Serialize)]
        struct QueryReq {
            #[serde(skip_serializing_if = "Option::is_none")]
            search: Option<String>,
            #[serde(skip_serializing_if = "Option::is_none")]
            filter: Option<String>,
            #[serde(skip_serializing_if = "Option::is_none")]
            unicode_param: Option<String>,
        }

        #[derive(Serialize, Deserialize, Debug, thiserror::Error)]
        #[error("test error")]
        struct Err;
        struct Resp;
        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<S: BosStr> = ();
            type Err = Err;
        }
        impl XrpcRequest for QueryReq {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }

        let opts = CallOptions::default();
        let base = Uri::parse("https://pds.example.com").expect("URI should be valid");

        let req_spaces = QueryReq {
            search: Some("hello world".to_string()),
            filter: None,
            unicode_param: None,
        };
        let http_req_spaces = build_http_request(&base, &req_spaces, &opts).unwrap();
        let uri_spaces = http_req_spaces.uri().to_string();
        assert!(
            uri_spaces.contains("search=hello"),
            "AC1.3: URI should contain search param"
        );
        // Either form-style '+' or percent-encoding is acceptable for the space.
        assert!(
            uri_spaces.contains("hello+world") || uri_spaces.contains("hello%20world"),
            "AC1.3: URI {} should encode space in 'hello world'",
            uri_spaces
        );

        let req_special = QueryReq {
            search: Some("a=b&c+d".to_string()),
            filter: None,
            unicode_param: None,
        };
        let http_req_special = build_http_request(&base, &req_special, &opts).unwrap();
        let uri_special = http_req_special.uri().to_string();
        assert!(
            uri_special.contains("?"),
            "AC1.3: URI should contain query string for special chars"
        );
        let parsed = Uri::parse(uri_special.clone());
        assert!(
            parsed.is_ok(),
            "AC1.3: URI {} should be parseable by fluent-uri",
            uri_special
        );

        let req_unicode = QueryReq {
            search: None,
            filter: None,
            unicode_param: Some("你好世界".to_string()),
        };
        let http_req_unicode = build_http_request(&base, &req_unicode, &opts).unwrap();
        let uri_unicode = http_req_unicode.uri().to_string();
        assert!(
            uri_unicode.contains("?"),
            "AC1.3: URI should contain query string for unicode"
        );
        let parsed_unicode = Uri::parse(uri_unicode.clone());
        assert!(
            parsed_unicode.is_ok(),
            "AC1.3: URI {} should be parseable for unicode params",
            uri_unicode
        );
    }

    // Neither a bare host nor a host with a sub-path produces a '//xrpc' segment.
    #[test]
    fn no_double_slash_in_path() {
        use crate::alloc::string::ToString;
        #[derive(Serialize, Deserialize)]
        struct Req;
        #[derive(Deserialize, Serialize, Debug, thiserror::Error)]
        #[error("test error")]
        struct Err;
        struct Resp;
        impl XrpcResp for Resp {
            const NSID: &'static str = "com.example.test";
            const ENCODING: &'static str = "application/json";
            type Output<S: BosStr> = ();
            type Err = Err;
        }
        impl XrpcRequest for Req {
            const NSID: &'static str = "com.example.test";
            const METHOD: XrpcMethod = XrpcMethod::Query;
            type Response = Resp;
        }

        let opts = CallOptions::default();

        let base1 = Uri::parse("https://pds").expect("URI should be valid");
        let req1 = build_http_request(&base1, &Req, &opts).unwrap();
        let uri1 = req1.uri().to_string();
        assert!(
            !uri1.contains("//xrpc"),
            "URI {} should not contain '//xrpc'",
            uri1
        );

        let base2 = Uri::parse("https://pds/base").expect("URI should be valid");
        let req2 = build_http_request(&base2, &Req, &opts).unwrap();
        let uri2 = req2.uri().to_string();
        assert!(
            !uri2.contains("//xrpc"),
            "URI {} should not contain '//xrpc'",
            uri2
        );
    }
}