1use bytes::Bytes;
14use http::{
15 HeaderName, HeaderValue, Request, StatusCode,
16 header::{AUTHORIZATION, CONTENT_TYPE},
17};
18use serde::{Deserialize, Serialize};
19use smol_str::SmolStr;
20use std::fmt::{self, Debug};
21use std::{error::Error, marker::PhantomData};
22use url::Url;
23
24use crate::IntoStatic;
25use crate::error::TransportError;
26use crate::http_client::HttpClient;
27use crate::types::value::Data;
28use crate::{AuthorizationToken, error::AuthError};
29use crate::{CowStr, error::XrpcResult};
30
31#[derive(Debug, thiserror::Error, miette::Diagnostic)]
33pub enum EncodeError {
34 #[error("Failed to serialize query: {0}")]
36 Query(
37 #[from]
38 #[source]
39 serde_html_form::ser::Error,
40 ),
41 #[error("Failed to serialize JSON: {0}")]
43 Json(
44 #[from]
45 #[source]
46 serde_json::Error,
47 ),
48 #[error("Encoding error: {0}")]
50 Other(String),
51}
52
/// The two kinds of XRPC call.
///
/// A `Query` is sent as HTTP `GET`. A `Procedure` is sent as HTTP `POST`, and its payload
/// is the body encoding (used as the request's `Content-Type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XrpcMethod {
    /// Read-only call; parameters travel in the query string.
    Query,
    /// Call with a request body; the string is the body's content encoding.
    Procedure(&'static str),
}

impl XrpcMethod {
    /// HTTP method name for this kind of call: `"GET"` for queries, `"POST"` for procedures.
    pub const fn as_str(&self) -> &'static str {
        if let Self::Query = self {
            "GET"
        } else {
            "POST"
        }
    }

    /// Body content encoding, or `None` for queries, which carry no body.
    pub const fn body_encoding(&self) -> Option<&'static str> {
        if let Self::Procedure(encoding) = self {
            Some(*encoding)
        } else {
            None
        }
    }
}
79
80pub trait XrpcRequest: Serialize {
87 const NSID: &'static str;
89
90 const METHOD: XrpcMethod;
92
93 const OUTPUT_ENCODING: &'static str;
95
96 type Output<'de>: Deserialize<'de> + IntoStatic;
98
99 type Err<'de>: Error + Deserialize<'de> + IntoStatic;
101
102 fn encode_body(&self) -> Result<Vec<u8>, EncodeError> {
106 Ok(serde_json::to_vec(self)?)
107 }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
112#[serde(bound(deserialize = "'de: 'a"))]
113pub struct GenericError<'a>(Data<'a>);
114
115impl fmt::Display for GenericError<'_> {
116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117 self.0.fmt(f)
118 }
119}
120
121impl Error for GenericError<'_> {}
122
123impl IntoStatic for GenericError<'_> {
124 type Output = GenericError<'static>;
125 fn into_static(self) -> Self::Output {
126 GenericError(self.0.into_static())
127 }
128}
129
130#[derive(Debug, Default, Clone)]
132pub struct CallOptions<'a> {
133 pub auth: Option<AuthorizationToken<'a>>,
135 pub atproto_proxy: Option<CowStr<'a>>,
137 pub atproto_accept_labelers: Option<Vec<CowStr<'a>>>,
139 pub extra_headers: Vec<(HeaderName, HeaderValue)>,
141}
142
143impl IntoStatic for CallOptions<'_> {
144 type Output = CallOptions<'static>;
145
146 fn into_static(self) -> Self::Output {
147 CallOptions {
148 auth: self.auth.map(|auth| auth.into_static()),
149 atproto_proxy: self.atproto_proxy.map(|proxy| proxy.into_static()),
150 atproto_accept_labelers: self
151 .atproto_accept_labelers
152 .map(|labelers| labelers.into_static()),
153 extra_headers: self.extra_headers,
154 }
155 }
156}
157
158pub trait XrpcExt: HttpClient {
186 fn xrpc<'a>(&'a self, base: Url) -> XrpcCall<'a, Self>
188 where
189 Self: Sized,
190 {
191 XrpcCall {
192 client: self,
193 base,
194 opts: CallOptions::default(),
195 }
196 }
197}
198
199impl<T: HttpClient> XrpcExt for T {}
200
201pub struct XrpcCall<'a, C: HttpClient> {
233 pub(crate) client: &'a C,
234 pub(crate) base: Url,
235 pub(crate) opts: CallOptions<'a>,
236}
237
238impl<'a, C: HttpClient> XrpcCall<'a, C> {
239 pub fn auth(mut self, token: AuthorizationToken<'a>) -> Self {
241 self.opts.auth = Some(token);
242 self
243 }
244 pub fn proxy(mut self, proxy: CowStr<'a>) -> Self {
246 self.opts.atproto_proxy = Some(proxy);
247 self
248 }
249 pub fn accept_labelers(mut self, labelers: Vec<CowStr<'a>>) -> Self {
251 self.opts.atproto_accept_labelers = Some(labelers);
252 self
253 }
254 pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self {
256 self.opts.extra_headers.push((name, value));
257 self
258 }
259 pub fn with_options(mut self, opts: CallOptions<'a>) -> Self {
261 self.opts = opts;
262 self
263 }
264
265 pub async fn send<R: XrpcRequest + Send>(self, request: &R) -> XrpcResult<Response<R>> {
274 let http_request = build_http_request(&self.base, request, &self.opts)
275 .map_err(crate::error::TransportError::from)?;
276
277 let http_response = self
278 .client
279 .send_http(http_request)
280 .await
281 .map_err(|e| crate::error::TransportError::Other(Box::new(e)))?;
282
283 process_response(http_response)
284 }
285}
286
287#[inline]
291pub fn process_response<R: XrpcRequest + Send>(
292 http_response: http::Response<Vec<u8>>,
293) -> XrpcResult<Response<R>> {
294 let status = http_response.status();
295 if status.as_u16() == 401 {
298 if let Some(hv) = http_response.headers().get(http::header::WWW_AUTHENTICATE) {
299 return Err(crate::error::ClientError::Auth(
300 crate::error::AuthError::Other(hv.clone()),
301 ));
302 }
303 }
304 let buffer = Bytes::from(http_response.into_body());
305
306 if !status.is_success() && !matches!(status.as_u16(), 400 | 401) {
307 return Err(crate::error::HttpError {
308 status,
309 body: Some(buffer),
310 }
311 .into());
312 }
313
314 Ok(Response::new(buffer, status))
315}
316
317pub enum Header {
319 ContentType,
321 Authorization,
323 AtprotoProxy,
327 AtprotoAcceptLabelers,
329}
330
331impl From<Header> for HeaderName {
332 fn from(value: Header) -> Self {
333 match value {
334 Header::ContentType => CONTENT_TYPE,
335 Header::Authorization => AUTHORIZATION,
336 Header::AtprotoProxy => HeaderName::from_static("atproto-proxy"),
337 Header::AtprotoAcceptLabelers => HeaderName::from_static("atproto-accept-labelers"),
338 }
339 }
340}
341
342pub fn build_http_request<R: XrpcRequest>(
344 base: &Url,
345 req: &R,
346 opts: &CallOptions<'_>,
347) -> core::result::Result<Request<Vec<u8>>, crate::error::TransportError> {
348 let mut url = base.clone();
349 let mut path = url.path().trim_end_matches('/').to_owned();
350 path.push_str("/xrpc/");
351 path.push_str(R::NSID);
352 url.set_path(&path);
353
354 if let XrpcMethod::Query = R::METHOD {
355 let qs = serde_html_form::to_string(&req)
356 .map_err(|e| crate::error::TransportError::InvalidRequest(e.to_string()))?;
357 if !qs.is_empty() {
358 url.set_query(Some(&qs));
359 } else {
360 url.set_query(None);
361 }
362 }
363
364 let method = match R::METHOD {
365 XrpcMethod::Query => http::Method::GET,
366 XrpcMethod::Procedure(_) => http::Method::POST,
367 };
368
369 let mut builder = Request::builder().method(method).uri(url.as_str());
370
371 if let XrpcMethod::Procedure(encoding) = R::METHOD {
372 builder = builder.header(Header::ContentType, encoding);
373 }
374 builder = builder.header(http::header::ACCEPT, R::OUTPUT_ENCODING);
375
376 if let Some(token) = &opts.auth {
377 let hv = match token {
378 AuthorizationToken::Bearer(t) => {
379 HeaderValue::from_str(&format!("Bearer {}", t.as_ref()))
380 }
381 AuthorizationToken::Dpop(t) => HeaderValue::from_str(&format!("DPoP {}", t.as_ref())),
382 }
383 .map_err(|e| {
384 TransportError::InvalidRequest(format!("Invalid authorization token: {}", e))
385 })?;
386 builder = builder.header(Header::Authorization, hv);
387 }
388
389 if let Some(proxy) = &opts.atproto_proxy {
390 builder = builder.header(Header::AtprotoProxy, proxy.as_ref());
391 }
392 if let Some(labelers) = &opts.atproto_accept_labelers {
393 if !labelers.is_empty() {
394 let joined = labelers
395 .iter()
396 .map(|s| s.as_ref())
397 .collect::<Vec<_>>()
398 .join(", ");
399 builder = builder.header(Header::AtprotoAcceptLabelers, joined);
400 }
401 }
402 for (name, value) in &opts.extra_headers {
403 builder = builder.header(name, value);
404 }
405
406 let body = if let XrpcMethod::Procedure(_) = R::METHOD {
407 req.encode_body()
408 .map_err(|e| TransportError::InvalidRequest(e.to_string()))?
409 } else {
410 vec![]
411 };
412
413 builder
414 .body(body)
415 .map_err(|e| TransportError::InvalidRequest(e.to_string()))
416}
417
418pub struct Response<R: XrpcRequest> {
423 buffer: Bytes,
424 status: StatusCode,
425 _marker: PhantomData<R>,
426}
427
428impl<R: XrpcRequest> Response<R> {
429 pub fn new(buffer: Bytes, status: StatusCode) -> Self {
431 Self {
432 buffer,
433 status,
434 _marker: PhantomData,
435 }
436 }
437
438 pub fn status(&self) -> StatusCode {
440 self.status
441 }
442
443 pub fn parse(&self) -> Result<R::Output<'_>, XrpcError<R::Err<'_>>> {
445 fn parse_output<'b, R: XrpcRequest>(
447 buffer: &'b [u8],
448 ) -> Result<R::Output<'b>, serde_json::Error> {
449 serde_json::from_slice(buffer)
450 }
451
452 fn parse_error<'b, R: XrpcRequest>(
453 buffer: &'b [u8],
454 ) -> Result<R::Err<'b>, serde_json::Error> {
455 serde_json::from_slice(buffer)
456 }
457
458 if self.status.is_success() {
460 match parse_output::<R>(&self.buffer) {
461 Ok(output) => Ok(output),
462 Err(e) => Err(XrpcError::Decode(e)),
463 }
464 } else if self.status.as_u16() == 400 {
466 match parse_error::<R>(&self.buffer) {
467 Ok(error) => Err(XrpcError::Xrpc(error)),
468 Err(_) => {
469 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
471 Ok(mut generic) => {
472 generic.nsid = R::NSID;
473 generic.method = R::METHOD.as_str();
474 generic.http_status = self.status;
475 match generic.error.as_str() {
477 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
478 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
479 _ => Err(XrpcError::Generic(generic)),
480 }
481 }
482 Err(e) => Err(XrpcError::Decode(e)),
483 }
484 }
485 }
486 } else {
488 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
489 Ok(mut generic) => {
490 generic.nsid = R::NSID;
491 generic.method = R::METHOD.as_str();
492 generic.http_status = self.status;
493 match generic.error.as_str() {
494 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
495 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
496 _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
497 }
498 }
499 Err(e) => Err(XrpcError::Decode(e)),
500 }
501 }
502 }
503
504 pub fn into_output(self) -> Result<R::Output<'static>, XrpcError<R::Err<'static>>>
506 where
507 for<'a> R::Output<'a>: IntoStatic<Output = R::Output<'static>>,
508 for<'a> R::Err<'a>: IntoStatic<Output = R::Err<'static>>,
509 {
510 fn parse_output<'b, R: XrpcRequest>(
512 buffer: &'b [u8],
513 ) -> Result<R::Output<'b>, serde_json::Error> {
514 serde_json::from_slice(buffer)
515 }
516
517 fn parse_error<'b, R: XrpcRequest>(
518 buffer: &'b [u8],
519 ) -> Result<R::Err<'b>, serde_json::Error> {
520 serde_json::from_slice(buffer)
521 }
522
523 if self.status.is_success() {
525 match parse_output::<R>(&self.buffer) {
526 Ok(output) => Ok(output.into_static()),
527 Err(e) => Err(XrpcError::Decode(e)),
528 }
529 } else if self.status.as_u16() == 400 {
531 match parse_error::<R>(&self.buffer) {
532 Ok(error) => Err(XrpcError::Xrpc(error.into_static())),
533 Err(_) => {
534 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
536 Ok(mut generic) => {
537 generic.nsid = R::NSID;
538 generic.method = R::METHOD.as_str();
539 generic.http_status = self.status;
540 match generic.error.as_ref() {
542 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
543 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
544 _ => Err(XrpcError::Generic(generic)),
545 }
546 }
547 Err(e) => Err(XrpcError::Decode(e)),
548 }
549 }
550 }
551 } else {
553 match serde_json::from_slice::<GenericXrpcError>(&self.buffer) {
554 Ok(mut generic) => {
555 let status = self.status;
556 generic.nsid = R::NSID;
557 generic.method = R::METHOD.as_str();
558 generic.http_status = status;
559 match generic.error.as_ref() {
560 "ExpiredToken" => Err(XrpcError::Auth(AuthError::TokenExpired)),
561 "InvalidToken" => Err(XrpcError::Auth(AuthError::InvalidToken)),
562 _ => Err(XrpcError::Auth(AuthError::NotAuthenticated)),
563 }
564 }
565 Err(e) => Err(XrpcError::Decode(e)),
566 }
567 }
568 }
569
570 pub fn buffer(&self) -> &Bytes {
572 &self.buffer
573 }
574}
575
576#[derive(Debug, Clone, Deserialize)]
580pub struct GenericXrpcError {
581 pub error: SmolStr,
583 pub message: Option<SmolStr>,
585 #[serde(skip)]
587 pub nsid: &'static str,
588 #[serde(skip)]
590 pub method: &'static str,
591 #[serde(skip)]
593 pub http_status: StatusCode,
594}
595
596impl std::fmt::Display for GenericXrpcError {
597 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598 if let Some(msg) = &self.message {
599 write!(
600 f,
601 "{}: {} (nsid={}, method={}, status={})",
602 self.error, msg, self.nsid, self.method, self.http_status
603 )
604 } else {
605 write!(
606 f,
607 "{} (nsid={}, method={}, status={})",
608 self.error, self.nsid, self.method, self.http_status
609 )
610 }
611 }
612}
613
614impl std::error::Error for GenericXrpcError {}
615
616#[derive(Debug, thiserror::Error, miette::Diagnostic)]
621pub enum XrpcError<E: std::error::Error + IntoStatic> {
622 #[error("XRPC error: {0}")]
624 #[diagnostic(code(jacquard_common::xrpc::typed))]
625 Xrpc(E),
626
627 #[error("Authentication error: {0}")]
629 #[diagnostic(code(jacquard_common::xrpc::auth))]
630 Auth(#[from] AuthError),
631
632 #[error("XRPC error: {0}")]
634 #[diagnostic(code(jacquard_common::xrpc::generic))]
635 Generic(GenericXrpcError),
636
637 #[error("Failed to decode response: {0}")]
639 #[diagnostic(code(jacquard_common::xrpc::decode))]
640 Decode(#[from] serde_json::Error),
641}
642
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Body-less procedure used for response parsing and request building.
    #[derive(Serialize)]
    struct DummyReq;

    /// Typed error whose payload is a bare string. Object-shaped error bodies fail to decode
    /// into it and fall through to the generic error path.
    #[derive(Deserialize, Debug, thiserror::Error)]
    #[error("{0}")]
    struct DummyErr<'a>(#[serde(borrow)] CowStr<'a>);

    impl IntoStatic for DummyErr<'_> {
        type Output = DummyErr<'static>;
        fn into_static(self) -> Self::Output {
            DummyErr(self.0.into_static())
        }
    }

    impl XrpcRequest for DummyReq {
        const NSID: &'static str = "test.dummy";
        const METHOD: XrpcMethod = XrpcMethod::Procedure("application/json");
        const OUTPUT_ENCODING: &'static str = "application/json";
        type Output<'de> = ();
        type Err<'de> = DummyErr<'de>;
    }

    /// Parameterless query, used to check path joining.
    #[derive(Serialize)]
    struct EmptyQuery;

    impl XrpcRequest for EmptyQuery {
        const NSID: &'static str = "com.example.test";
        const METHOD: XrpcMethod = XrpcMethod::Query;
        const OUTPUT_ENCODING: &'static str = "application/json";
        type Output<'de> = ();
        type Err<'de> = DummyErr<'de>;
    }

    /// Query with parameters, used to check query-string encoding.
    #[derive(Serialize)]
    struct ParamQuery {
        actor: &'static str,
        limit: u32,
    }

    impl XrpcRequest for ParamQuery {
        const NSID: &'static str = "test.paramQuery";
        const METHOD: XrpcMethod = XrpcMethod::Query;
        const OUTPUT_ENCODING: &'static str = "application/json";
        type Output<'de> = ();
        type Err<'de> = DummyErr<'de>;
    }

    #[test]
    fn generic_error_carries_context() {
        let body = serde_json::json!({"error":"InvalidRequest","message":"missing"});
        let buf = Bytes::from(serde_json::to_vec(&body).unwrap());
        let resp: Response<DummyReq> = Response::new(buf, StatusCode::BAD_REQUEST);
        match resp.parse().unwrap_err() {
            XrpcError::Generic(g) => {
                assert_eq!(g.error.as_str(), "InvalidRequest");
                assert_eq!(g.message.as_deref(), Some("missing"));
                assert_eq!(g.nsid, DummyReq::NSID);
                assert_eq!(g.method, DummyReq::METHOD.as_str());
                assert_eq!(g.http_status, StatusCode::BAD_REQUEST);
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn auth_error_mapping() {
        for (code, expect) in [
            ("ExpiredToken", AuthError::TokenExpired),
            ("InvalidToken", AuthError::InvalidToken),
        ] {
            let body = serde_json::json!({"error": code});
            let buf = Bytes::from(serde_json::to_vec(&body).unwrap());
            let resp: Response<DummyReq> = Response::new(buf, StatusCode::UNAUTHORIZED);
            match resp.parse().unwrap_err() {
                XrpcError::Auth(e) => match (e, expect) {
                    (AuthError::TokenExpired, AuthError::TokenExpired) => {}
                    (AuthError::InvalidToken, AuthError::InvalidToken) => {}
                    other => panic!("mismatch: {other:?}"),
                },
                other => panic!("unexpected: {other:?}"),
            }
        }
    }

    #[test]
    fn no_double_slash_in_path() {
        let opts = CallOptions::default();
        for base in [
            Url::parse("https://pds").unwrap(),
            Url::parse("https://pds/").unwrap(),
            Url::parse("https://pds/base/").unwrap(),
        ] {
            let req = build_http_request(&base, &EmptyQuery, &opts).unwrap();
            let uri = req.uri().to_string();
            assert!(uri.contains("/xrpc/com.example.test"));
            assert!(!uri.contains("//xrpc"));
        }
    }

    #[test]
    fn query_parameters_are_encoded() {
        let base = Url::parse("https://pds").unwrap();
        let query = ParamQuery {
            actor: "alice",
            limit: 5,
        };
        let req = build_http_request(&base, &query, &CallOptions::default()).unwrap();
        assert_eq!(*req.method(), http::Method::GET);
        assert_eq!(
            req.uri().to_string(),
            "https://pds/xrpc/test.paramQuery?actor=alice&limit=5"
        );
        assert!(req.headers().get(CONTENT_TYPE).is_none());
        assert!(req.body().is_empty());
    }

    #[test]
    fn procedure_sets_content_type_and_body() {
        let base = Url::parse("https://pds/").unwrap();
        let req = build_http_request(&base, &DummyReq, &CallOptions::default()).unwrap();
        assert_eq!(*req.method(), http::Method::POST);
        assert_eq!(req.uri().to_string(), "https://pds/xrpc/test.dummy");
        assert_eq!(req.headers()[CONTENT_TYPE], "application/json");
        assert_eq!(req.headers()[http::header::ACCEPT], "application/json");
        // A unit struct serializes to JSON `null`.
        assert_eq!(req.body().as_slice(), b"null".as_slice());
    }
}
741
742pub trait XrpcClient: HttpClient {
744 fn base_uri(&self) -> Url;
746
747 fn opts(&self) -> impl Future<Output = CallOptions<'_>> {
749 async { CallOptions::default() }
750 }
751 fn send<R: XrpcRequest + Send>(
753 self,
754 request: &R,
755 ) -> impl Future<Output = XrpcResult<Response<R>>>;
756}