1use std::borrow::Cow;
2use std::future::Future;
3
4use bytes::Bytes;
5use http::{Request, Response, StatusCode};
6use std::error::Error;
7use thiserror::Error;
8use tracing::Instrument;
9
10pub trait Endpoint {
12 type Request: serde::Serialize + Send + Sync;
13 type Response: serde::de::DeserializeOwned + Send + Sync;
14
15 fn endpoint(&self) -> Cow<'static, str>;
17 fn body(&self) -> &Self::Request;
19 fn method(&self) -> http::Method {
21 http::Method::POST
22 }
23 fn extra_headers(&self) -> Vec<(Cow<'static, str>, Cow<'static, str>)> {
25 vec![]
26 }
27 fn parse_response(&self, body: &[u8]) -> Result<Self::Response, serde_json::Error> {
36 serde_json::from_slice(body)
37 }
38}
39
/// An operation that can be executed against a client of type `C`.
///
/// Every [`Endpoint`] implements this through the blanket implementation in this module, with
/// `C` being any [`Client`].
pub trait Query<C> {
    /// Value the query resolves to once it has finished executing.
    type Result;

    /// Consumes the query and runs it using `client`.
    ///
    /// The returned future is `Send` and resolves to [`Query::Result`].
    fn execute(self, client: &C) -> impl Send + Future<Output = Self::Result>;
}
46
47#[derive(Debug, Error)]
52#[non_exhaustive]
53pub enum QueryError<E>
54where
55 E: Error + Send + Sync + 'static,
56{
57 #[error("client error: {}", source)]
58 Client { source: E },
59
60 #[error("failed to serialize request body: {}", source)]
61 SerializeBody { source: serde_json::Error },
62
63 #[error("could not parse JSON response: {}", source)]
64 DeserializeResponse { source: serde_json::Error },
65
66 #[error("failed to build request: {}", source)]
67 Body {
68 #[from]
69 source: http::Error,
70 },
71
72 #[error("validation error: {message:?}")]
74 Validation {
75 error_type: Option<String>,
76 message: Option<String>,
77 errors: Option<std::collections::HashMap<String, Vec<String>>>,
79 body: Bytes,
80 },
81
82 #[error("authentication error: {message:?}")]
84 Authentication {
85 message: Option<String>,
86 body: Bytes,
87 },
88
89 #[error("rate limit exceeded: {message:?}")]
91 RateLimit {
92 message: Option<String>,
93 body: Bytes,
94 },
95
96 #[error("api error: status={status}, error_type={error_type:?}, message={message:?}")]
98 Api {
99 status: StatusCode,
100 error_type: Option<String>,
101 message: Option<String>,
102 body: Bytes,
103 },
104}
105
106impl<E> QueryError<E>
107where
108 E: Error + Send + Sync + 'static,
109{
110 pub fn client(source: E) -> Self {
111 QueryError::Client { source }
112 }
113}
114
115impl<T, C> Query<C> for T
116where
117 T: Endpoint + Send + Sync,
118 C: Client + Send + Sync,
119{
120 type Result = Result<T::Response, QueryError<C::Error>>;
121
122 async fn execute(self, client: &C) -> Self::Result {
123 let method = self.method();
124 let endpoint = self.endpoint();
125
126 let span = tracing::debug_span!(
127 "lettermint.request",
128 method = %method,
129 endpoint = %endpoint,
130 status = tracing::field::Empty,
131 );
132
133 async {
134 let uri = format!("/{}", endpoint.trim_start_matches('/'));
137 let mut req_builder = http::Request::builder()
138 .method(method.clone())
139 .uri(uri)
140 .header("Accept", "application/json");
141
142 for (name, value) in self.extra_headers() {
143 req_builder = req_builder.header(name.as_ref(), value.as_ref());
144 }
145
146 let body = match method {
147 http::Method::GET | http::Method::DELETE | http::Method::HEAD => Bytes::new(),
148 _ => {
149 req_builder = req_builder.header("Content-Type", "application/json");
150 serde_json::to_vec(self.body())
151 .map_err(|e| {
152 tracing::error!(error = %e, "failed to serialize request body");
153 QueryError::SerializeBody { source: e }
154 })?
155 .into()
156 }
157 };
158
159 let http_req = req_builder.body(body)?;
160 let response = client.execute(http_req).await.map_err(|e| {
161 tracing::error!(error = %e, "client transport error");
162 QueryError::client(e)
163 })?;
164
165 let status = response.status();
166 tracing::Span::current().record("status", status.as_u16());
167
168 if !status.is_success() {
169 #[derive(serde::Deserialize)]
170 struct LettermintErrorBody {
171 error_type: Option<String>,
172 error: Option<String>,
173 message: Option<String>,
174 errors: Option<std::collections::HashMap<String, Vec<String>>>,
175 }
176
177 let body = response.body().clone();
178 let parsed = serde_json::from_slice::<LettermintErrorBody>(&body).ok();
179 let error_type = parsed
180 .as_ref()
181 .and_then(|p| p.error_type.clone().or_else(|| p.error.clone()));
182 let message = parsed.as_ref().and_then(|p| p.message.clone());
183
184 tracing::warn!(
185 status = status.as_u16(),
186 error_type = error_type.as_deref(),
187 message = message.as_deref(),
188 "API error response",
189 );
190
191 return Err(match status.as_u16() {
192 422 => QueryError::Validation {
193 error_type,
194 message,
195 errors: parsed.and_then(|p| p.errors),
196 body,
197 },
198 401 | 403 => QueryError::Authentication { message, body },
199 429 => QueryError::RateLimit { message, body },
200 _ => QueryError::Api {
201 status,
202 error_type,
203 message,
204 body,
205 },
206 });
207 }
208
209 tracing::debug!(status = status.as_u16(), "request completed");
210
211 self.parse_response(response.body()).map_err(|e| {
212 tracing::error!(error = %e, "failed to deserialize response body");
213 QueryError::DeserializeResponse { source: e }
214 })
215 }
216 .instrument(span)
217 .await
218 }
219}
220
221pub trait Client {
223 type Error: Error + Send + Sync + 'static;
224 fn execute(
225 &self,
226 req: Request<Bytes>,
227 ) -> impl Future<Output = Result<Response<Bytes>, Self::Error>> + Send;
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::borrow::Cow;
    use std::sync::{Arc, Mutex};

    /// Error type for [`MockClient`]; never actually produced by these tests.
    #[derive(Debug, thiserror::Error)]
    #[error("test client error")]
    struct MockClientError;

    /// In-memory [`Client`] that records the last request it received. It answers every
    /// request with the same canned status and body.
    #[derive(Clone)]
    struct MockClient {
        last_request: Arc<Mutex<Option<Request<Bytes>>>>,
        response_status: StatusCode,
        response_body: Bytes,
    }

    impl MockClient {
        /// Mock that replies with `status` and `body`.
        fn with_status(status: StatusCode, body: &'static [u8]) -> Self {
            Self {
                last_request: Arc::new(Mutex::new(None)),
                response_status: status,
                response_body: Bytes::from_static(body),
            }
        }

        /// Mock that replies `200 OK` with `body`.
        fn ok(body: &'static [u8]) -> Self {
            Self::with_status(StatusCode::OK, body)
        }

        /// Mock that replies with the (error) `status` and `body`.
        fn error(status: StatusCode, body: &'static [u8]) -> Self {
            Self::with_status(status, body)
        }

        /// Returns a copy of the most recent request; panics if none has been sent yet.
        fn last_request(&self) -> Request<Bytes> {
            self.last_request
                .lock()
                .expect("lock")
                .clone()
                .expect("request present")
        }
    }

    impl Client for MockClient {
        type Error = MockClientError;

        async fn execute(&self, req: Request<Bytes>) -> Result<Response<Bytes>, Self::Error> {
            *self.last_request.lock().expect("lock") = Some(req);
            Ok(Response::builder()
                .status(self.response_status)
                .body(self.response_body.clone())
                .expect("response"))
        }
    }

    #[derive(serde::Serialize)]
    struct TestBody {
        value: &'static str,
    }

    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct TestResponse {
        ok: bool,
    }

    /// `POST send` endpoint with a fixed body and configurable extra headers.
    struct PostEndpoint {
        body: TestBody,
        extra: Vec<(Cow<'static, str>, Cow<'static, str>)>,
    }

    impl PostEndpoint {
        fn new() -> Self {
            Self {
                body: TestBody { value: "hello" },
                extra: vec![],
            }
        }

        fn with_extra_header(mut self, name: &'static str, value: impl Into<String>) -> Self {
            self.extra
                .push((Cow::Borrowed(name), Cow::Owned(value.into())));
            self
        }
    }

    impl Endpoint for PostEndpoint {
        type Request = TestBody;
        type Response = TestResponse;

        fn endpoint(&self) -> Cow<'static, str> {
            "send".into()
        }

        fn body(&self) -> &Self::Request {
            &self.body
        }

        fn extra_headers(&self) -> Vec<(Cow<'static, str>, Cow<'static, str>)> {
            self.extra.clone()
        }
    }

    #[derive(serde::Serialize)]
    struct NoBody;

    /// `GET messages` endpoint with no payload.
    struct GetEndpoint;

    impl Endpoint for GetEndpoint {
        type Request = NoBody;
        type Response = TestResponse;

        fn endpoint(&self) -> Cow<'static, str> {
            "messages".into()
        }

        fn body(&self) -> &Self::Request {
            static BODY: NoBody = NoBody;
            &BODY
        }

        fn method(&self) -> http::Method {
            http::Method::GET
        }
    }

    /// Runs the default [`PostEndpoint`] against `client` and expects it to fail.
    async fn post_error(client: &MockClient) -> QueryError<MockClientError> {
        PostEndpoint::new()
            .execute(client)
            .await
            .expect_err("should fail")
    }

    /// Returns header `name` of `req` as a string. Panics if the header is absent or its value
    /// is not visible ASCII.
    fn header<'a>(req: &'a Request<Bytes>, name: &str) -> &'a str {
        req.headers()
            .get(name)
            .unwrap_or_else(|| panic!("missing header {name}"))
            .to_str()
            .expect("header value is visible ASCII")
    }

    #[tokio::test]
    async fn post_request_has_json_body_and_content_type() {
        let client = MockClient::ok(br#"{"ok":true}"#);
        let resp = PostEndpoint::new().execute(&client).await.expect("execute");
        assert!(resp.ok);

        let req = client.last_request();
        assert_eq!(req.method(), http::Method::POST);
        assert_eq!(req.body(), &Bytes::from_static(br#"{"value":"hello"}"#));
        assert_eq!(header(&req, "Content-Type"), "application/json");
        assert_eq!(header(&req, "Accept"), "application/json");
    }

    #[tokio::test]
    async fn get_request_has_no_body_or_content_type() {
        let client = MockClient::ok(br#"{"ok":true}"#);
        let resp = GetEndpoint.execute(&client).await.expect("execute");
        assert!(resp.ok);

        let req = client.last_request();
        assert_eq!(req.method(), http::Method::GET);
        assert!(req.body().is_empty());
        assert!(!req.headers().contains_key("Content-Type"));
        assert!(req.headers().contains_key("Accept"));
    }

    #[tokio::test]
    async fn extra_headers_are_applied() {
        let client = MockClient::ok(br#"{"ok":true}"#);
        PostEndpoint::new()
            .with_extra_header("Idempotency-Key", "test-key")
            .execute(&client)
            .await
            .expect("execute");

        let req = client.last_request();
        assert_eq!(header(&req, "Idempotency-Key"), "test-key");
    }

    #[tokio::test]
    async fn validation_error_422() {
        let client = MockClient::error(
            StatusCode::UNPROCESSABLE_ENTITY,
            br#"{"error_type":"DailyLimitExceeded","message":"Limit reached"}"#,
        );

        match post_error(&client).await {
            QueryError::Validation {
                error_type,
                message,
                ..
            } => {
                assert_eq!(error_type.as_deref(), Some("DailyLimitExceeded"));
                assert_eq!(message.as_deref(), Some("Limit reached"));
            }
            err => panic!("expected Validation error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn authentication_error_401() {
        let client = MockClient::error(
            StatusCode::UNAUTHORIZED,
            br#"{"message":"Invalid API token"}"#,
        );

        match post_error(&client).await {
            QueryError::Authentication { message, .. } => {
                assert_eq!(message.as_deref(), Some("Invalid API token"));
            }
            err => panic!("expected Authentication error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn rate_limit_error_429() {
        let client = MockClient::error(
            StatusCode::TOO_MANY_REQUESTS,
            br#"{"message":"Rate limit exceeded"}"#,
        );

        match post_error(&client).await {
            QueryError::RateLimit { message, .. } => {
                assert_eq!(message.as_deref(), Some("Rate limit exceeded"));
            }
            err => panic!("expected RateLimit error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn api_error_with_non_json_body() {
        let client = MockClient::error(StatusCode::BAD_GATEWAY, b"gateway timeout");

        match post_error(&client).await {
            QueryError::Api {
                status,
                error_type,
                message,
                body,
            } => {
                assert_eq!(status, StatusCode::BAD_GATEWAY);
                assert_eq!(error_type, None);
                assert_eq!(message, None);
                assert_eq!(body, Bytes::from_static(b"gateway timeout"));
            }
            err => panic!("expected Api error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn success_with_invalid_json_returns_deserialize_error() {
        let client = MockClient::ok(b"not json");
        let err = post_error(&client).await;

        assert!(matches!(err, QueryError::DeserializeResponse { .. }));
    }

    #[tokio::test]
    async fn api_error_with_error_field_fallback() {
        let client = MockClient::error(
            StatusCode::BAD_REQUEST,
            br#"{"error":"invalid_request","message":"Bad from address"}"#,
        );

        match post_error(&client).await {
            QueryError::Api {
                status,
                error_type,
                message,
                ..
            } => {
                assert_eq!(status, StatusCode::BAD_REQUEST);
                assert_eq!(error_type.as_deref(), Some("invalid_request"));
                assert_eq!(message.as_deref(), Some("Bad from address"));
            }
            err => panic!("expected Api error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn authentication_error_403() {
        let client = MockClient::error(StatusCode::FORBIDDEN, br#"{"message":"Access denied"}"#);
        let err = post_error(&client).await;

        assert!(matches!(err, QueryError::Authentication { .. }));
    }
}