// lettermint_rs/client.rs

1use std::borrow::Cow;
2use std::future::Future;
3
4use bytes::Bytes;
5use http::{Request, Response, StatusCode};
6use std::error::Error;
7use thiserror::Error;
8use tracing::Instrument;
9
10/// A trait for providing the necessary information for a single REST API endpoint.
11pub trait Endpoint {
12    type Request: serde::Serialize + Send + Sync;
13    type Response: serde::de::DeserializeOwned + Send + Sync;
14
15    /// The path to the endpoint.
16    fn endpoint(&self) -> Cow<'static, str>;
17    /// The body for the endpoint.
18    fn body(&self) -> &Self::Request;
19    /// The HTTP method for the endpoint.
20    fn method(&self) -> http::Method {
21        http::Method::POST
22    }
23    /// Optional extra headers (e.g., Idempotency-Key).
24    fn extra_headers(&self) -> Vec<(Cow<'static, str>, Cow<'static, str>)> {
25        vec![]
26    }
27    /// Parse the raw response body into [`Self::Response`].
28    ///
29    /// The default implementation deserializes JSON. Override this for
30    /// endpoints that return non-JSON responses (e.g. plain text).
31    ///
32    /// # Errors
33    ///
34    /// Returns a [`serde_json::Error`] if the body cannot be parsed.
35    fn parse_response(&self, body: &[u8]) -> Result<Self::Response, serde_json::Error> {
36        serde_json::from_slice(body)
37    }
38}
39
/// An asynchronous operation that can be run against a Lettermint client of
/// type `C`.
///
/// Any `Send + Sync` [`Endpoint`] gets this trait for any `Send + Sync`
/// [`Client`] through a blanket implementation.
pub trait Query<C> {
    /// Value produced once the query has completed.
    type Result;

    /// Consumes the query and runs it using `client`.
    fn execute(self, client: &C) -> impl Future<Output = Self::Result> + Send;
}
46
47/// An error thrown by the [`Query`] trait.
48///
49/// This enum is `#[non_exhaustive]` — new variants may be added in future
50/// releases without a semver-breaking change.
51#[derive(Debug, Error)]
52#[non_exhaustive]
53pub enum QueryError<E>
54where
55    E: Error + Send + Sync + 'static,
56{
57    #[error("client error: {}", source)]
58    Client { source: E },
59
60    #[error("failed to serialize request body: {}", source)]
61    SerializeBody { source: serde_json::Error },
62
63    #[error("could not parse JSON response: {}", source)]
64    DeserializeResponse { source: serde_json::Error },
65
66    #[error("failed to build request: {}", source)]
67    Body {
68        #[from]
69        source: http::Error,
70    },
71
72    /// Validation error (HTTP 422) with per-field details.
73    #[error("validation error: {message:?}")]
74    Validation {
75        error_type: Option<String>,
76        message: Option<String>,
77        /// Per-field validation errors (e.g., `{"from": ["domain not verified"]}`)
78        errors: Option<std::collections::HashMap<String, Vec<String>>>,
79        body: Bytes,
80    },
81
82    /// Authentication or authorization error (HTTP 401/403).
83    #[error("authentication error: {message:?}")]
84    Authentication {
85        message: Option<String>,
86        body: Bytes,
87    },
88
89    /// Rate limit exceeded (HTTP 429).
90    #[error("rate limit exceeded: {message:?}")]
91    RateLimit {
92        message: Option<String>,
93        body: Bytes,
94    },
95
96    /// Any other non-success API response.
97    #[error("api error: status={status}, error_type={error_type:?}, message={message:?}")]
98    Api {
99        status: StatusCode,
100        error_type: Option<String>,
101        message: Option<String>,
102        body: Bytes,
103    },
104}
105
106impl<E> QueryError<E>
107where
108    E: Error + Send + Sync + 'static,
109{
110    pub fn client(source: E) -> Self {
111        QueryError::Client { source }
112    }
113}
114
115impl<T, C> Query<C> for T
116where
117    T: Endpoint + Send + Sync,
118    C: Client + Send + Sync,
119{
120    type Result = Result<T::Response, QueryError<C::Error>>;
121
122    async fn execute(self, client: &C) -> Self::Result {
123        let method = self.method();
124        let endpoint = self.endpoint();
125
126        let span = tracing::debug_span!(
127            "lettermint.request",
128            method = %method,
129            endpoint = %endpoint,
130            status = tracing::field::Empty,
131        );
132
133        async {
134            // Always format as an absolute path so http::Uri parses it correctly.
135            // The Client implementation joins this with its base URL.
136            let uri = format!("/{}", endpoint.trim_start_matches('/'));
137            let mut req_builder = http::Request::builder()
138                .method(method.clone())
139                .uri(uri)
140                .header("Accept", "application/json");
141
142            for (name, value) in self.extra_headers() {
143                req_builder = req_builder.header(name.as_ref(), value.as_ref());
144            }
145
146            let body = match method {
147                http::Method::GET | http::Method::DELETE | http::Method::HEAD => Bytes::new(),
148                _ => {
149                    req_builder = req_builder.header("Content-Type", "application/json");
150                    serde_json::to_vec(self.body())
151                        .map_err(|e| {
152                            tracing::error!(error = %e, "failed to serialize request body");
153                            QueryError::SerializeBody { source: e }
154                        })?
155                        .into()
156                }
157            };
158
159            let http_req = req_builder.body(body)?;
160            let response = client.execute(http_req).await.map_err(|e| {
161                tracing::error!(error = %e, "client transport error");
162                QueryError::client(e)
163            })?;
164
165            let status = response.status();
166            tracing::Span::current().record("status", status.as_u16());
167
168            if !status.is_success() {
169                #[derive(serde::Deserialize)]
170                struct LettermintErrorBody {
171                    error_type: Option<String>,
172                    error: Option<String>,
173                    message: Option<String>,
174                    errors: Option<std::collections::HashMap<String, Vec<String>>>,
175                }
176
177                let body = response.body().clone();
178                let parsed = serde_json::from_slice::<LettermintErrorBody>(&body).ok();
179                let error_type = parsed
180                    .as_ref()
181                    .and_then(|p| p.error_type.clone().or_else(|| p.error.clone()));
182                let message = parsed.as_ref().and_then(|p| p.message.clone());
183
184                tracing::warn!(
185                    status = status.as_u16(),
186                    error_type = error_type.as_deref(),
187                    message = message.as_deref(),
188                    "API error response",
189                );
190
191                return Err(match status.as_u16() {
192                    422 => QueryError::Validation {
193                        error_type,
194                        message,
195                        errors: parsed.and_then(|p| p.errors),
196                        body,
197                    },
198                    401 | 403 => QueryError::Authentication { message, body },
199                    429 => QueryError::RateLimit { message, body },
200                    _ => QueryError::Api {
201                        status,
202                        error_type,
203                        message,
204                        body,
205                    },
206                });
207            }
208
209            tracing::debug!(status = status.as_u16(), "request completed");
210
211            self.parse_response(response.body()).map_err(|e| {
212                tracing::error!(error = %e, "failed to deserialize response body");
213                QueryError::DeserializeResponse { source: e }
214            })
215        }
216        .instrument(span)
217        .await
218    }
219}
220
221/// A trait representing a client which can communicate with a Lettermint instance.
222pub trait Client {
223    type Error: Error + Send + Sync + 'static;
224    fn execute(
225        &self,
226        req: Request<Bytes>,
227    ) -> impl Future<Output = Result<Response<Bytes>, Self::Error>> + Send;
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::borrow::Cow;
    use std::sync::{Arc, Mutex};

    #[derive(Debug, thiserror::Error)]
    #[error("test client error")]
    struct MockClientError;

    /// In-memory [`Client`] that records the last request it received and
    /// answers every call with a canned status and body.
    #[derive(Clone)]
    struct MockClient {
        last_request: Arc<Mutex<Option<Request<Bytes>>>>,
        response_status: StatusCode,
        response_body: Bytes,
    }

    impl MockClient {
        /// Mock answering every request with `status` and `body`.
        fn with_status(status: StatusCode, body: &'static [u8]) -> Self {
            Self {
                last_request: Arc::new(Mutex::new(None)),
                response_status: status,
                response_body: Bytes::from_static(body),
            }
        }

        /// Mock answering `200 OK` with `body`.
        fn ok(body: &'static [u8]) -> Self {
            Self::with_status(StatusCode::OK, body)
        }

        /// Mock answering with an error `status` and `body`.
        fn error(status: StatusCode, body: &'static [u8]) -> Self {
            Self::with_status(status, body)
        }

        /// Whether any request reached the mock.
        fn was_called(&self) -> bool {
            self.last_request.lock().expect("lock").is_some()
        }

        /// The most recent request; panics if none was sent.
        fn last_request(&self) -> Request<Bytes> {
            self.last_request
                .lock()
                .expect("lock")
                .clone()
                .expect("request present")
        }
    }

    impl Client for MockClient {
        type Error = MockClientError;

        async fn execute(&self, req: Request<Bytes>) -> Result<Response<Bytes>, Self::Error> {
            *self.last_request.lock().expect("lock") = Some(req);
            Ok(Response::builder()
                .status(self.response_status)
                .body(self.response_body.clone())
                .expect("response"))
        }
    }

    #[derive(serde::Serialize)]
    struct TestBody {
        value: &'static str,
    }

    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct TestResponse {
        ok: bool,
    }

    /// POST endpoint at `send` with a fixed JSON body and optional headers.
    struct PostEndpoint {
        body: TestBody,
        extra: Vec<(Cow<'static, str>, Cow<'static, str>)>,
    }

    impl PostEndpoint {
        fn new() -> Self {
            Self {
                body: TestBody { value: "hello" },
                extra: vec![],
            }
        }

        fn with_extra_header(mut self, name: &'static str, value: impl Into<String>) -> Self {
            self.extra
                .push((Cow::Borrowed(name), Cow::Owned(value.into())));
            self
        }
    }

    impl Endpoint for PostEndpoint {
        type Request = TestBody;
        type Response = TestResponse;

        fn endpoint(&self) -> Cow<'static, str> {
            "send".into()
        }

        fn body(&self) -> &Self::Request {
            &self.body
        }

        fn extra_headers(&self) -> Vec<(Cow<'static, str>, Cow<'static, str>)> {
            self.extra.clone()
        }
    }

    #[derive(serde::Serialize)]
    struct NoBody;

    /// GET endpoint at `messages`.
    struct GetEndpoint;
    impl Endpoint for GetEndpoint {
        type Request = NoBody;
        type Response = TestResponse;

        fn endpoint(&self) -> Cow<'static, str> {
            "messages".into()
        }

        fn body(&self) -> &Self::Request {
            static BODY: NoBody = NoBody;
            &BODY
        }

        fn method(&self) -> http::Method {
            http::Method::GET
        }
    }

    #[tokio::test]
    async fn post_request_has_json_body_and_content_type() {
        let client = MockClient::ok(br#"{"ok":true}"#);
        let resp = PostEndpoint::new().execute(&client).await.expect("execute");
        assert!(resp.ok);

        let req = client.last_request();
        assert_eq!(req.method(), http::Method::POST);
        assert_eq!(req.body(), &Bytes::from_static(br#"{"value":"hello"}"#));
        assert_eq!(
            req.headers().get("Content-Type").unwrap().to_str().unwrap(),
            "application/json"
        );
        assert_eq!(
            req.headers().get("Accept").unwrap().to_str().unwrap(),
            "application/json"
        );
    }

    #[tokio::test]
    async fn request_uri_is_absolute_path() {
        let client = MockClient::ok(br#"{"ok":true}"#);
        PostEndpoint::new().execute(&client).await.expect("execute");

        assert_eq!(client.last_request().uri().path(), "/send");
    }

    #[tokio::test]
    async fn get_request_has_no_body_or_content_type() {
        let client = MockClient::ok(br#"{"ok":true}"#);
        let resp = GetEndpoint.execute(&client).await.expect("execute");
        assert!(resp.ok);

        let req = client.last_request();
        assert_eq!(req.method(), http::Method::GET);
        assert!(req.body().is_empty());
        assert!(req.headers().get("Content-Type").is_none());
        assert!(req.headers().get("Accept").is_some());
    }

    #[tokio::test]
    async fn extra_headers_are_applied() {
        let client = MockClient::ok(br#"{"ok":true}"#);
        PostEndpoint::new()
            .with_extra_header("Idempotency-Key", "test-key")
            .execute(&client)
            .await
            .expect("execute");

        let req = client.last_request();
        assert_eq!(
            req.headers()
                .get("Idempotency-Key")
                .unwrap()
                .to_str()
                .unwrap(),
            "test-key"
        );
    }

    #[tokio::test]
    async fn invalid_extra_header_fails_before_sending() {
        let client = MockClient::ok(br#"{"ok":true}"#);
        let err = PostEndpoint::new()
            .with_extra_header("Idempotency-Key", "bad\nvalue")
            .execute(&client)
            .await
            .expect_err("should fail");

        assert!(matches!(err, QueryError::Body { .. }));
        assert!(!client.was_called());
    }

    #[tokio::test]
    async fn validation_error_422() {
        let client = MockClient::error(
            StatusCode::UNPROCESSABLE_ENTITY,
            br#"{"error_type":"DailyLimitExceeded","message":"Limit reached"}"#,
        );

        let err = PostEndpoint::new()
            .execute(&client)
            .await
            .expect_err("should fail");

        match err {
            QueryError::Validation {
                error_type,
                message,
                ..
            } => {
                assert_eq!(error_type.as_deref(), Some("DailyLimitExceeded"));
                assert_eq!(message.as_deref(), Some("Limit reached"));
            }
            _ => panic!("expected Validation error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn validation_error_exposes_field_errors() {
        let client = MockClient::error(
            StatusCode::UNPROCESSABLE_ENTITY,
            br#"{"message":"The given data was invalid.","errors":{"from":["domain not verified"]}}"#,
        );

        let err = PostEndpoint::new()
            .execute(&client)
            .await
            .expect_err("should fail");

        match err {
            QueryError::Validation {
                message,
                errors,
                body,
                ..
            } => {
                assert_eq!(message.as_deref(), Some("The given data was invalid."));
                let errors = errors.expect("errors present");
                assert_eq!(errors["from"], vec!["domain not verified"]);
                assert!(!body.is_empty());
            }
            _ => panic!("expected Validation error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn authentication_error_401() {
        let client = MockClient::error(
            StatusCode::UNAUTHORIZED,
            br#"{"message":"Invalid API token"}"#,
        );

        let err = PostEndpoint::new()
            .execute(&client)
            .await
            .expect_err("should fail");

        match err {
            QueryError::Authentication { message, .. } => {
                assert_eq!(message.as_deref(), Some("Invalid API token"));
            }
            _ => panic!("expected Authentication error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn rate_limit_error_429() {
        let client = MockClient::error(
            StatusCode::TOO_MANY_REQUESTS,
            br#"{"message":"Rate limit exceeded"}"#,
        );

        let err = PostEndpoint::new()
            .execute(&client)
            .await
            .expect_err("should fail");

        match err {
            QueryError::RateLimit { message, .. } => {
                assert_eq!(message.as_deref(), Some("Rate limit exceeded"));
            }
            _ => panic!("expected RateLimit error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn api_error_with_non_json_body() {
        let client = MockClient::error(StatusCode::BAD_GATEWAY, b"gateway timeout");

        let err = PostEndpoint::new()
            .execute(&client)
            .await
            .expect_err("should fail");

        match err {
            QueryError::Api {
                status,
                error_type,
                message,
                body,
            } => {
                assert_eq!(status, StatusCode::BAD_GATEWAY);
                assert_eq!(error_type, None);
                assert_eq!(message, None);
                assert_eq!(body, Bytes::from_static(b"gateway timeout"));
            }
            _ => panic!("expected Api error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn success_with_invalid_json_returns_deserialize_error() {
        let client = MockClient::ok(b"not json");
        let err = PostEndpoint::new()
            .execute(&client)
            .await
            .expect_err("should fail");

        assert!(matches!(err, QueryError::DeserializeResponse { .. }));
    }

    #[tokio::test]
    async fn api_error_with_error_field_fallback() {
        let client = MockClient::error(
            StatusCode::BAD_REQUEST,
            br#"{"error":"invalid_request","message":"Bad from address"}"#,
        );

        let err = PostEndpoint::new()
            .execute(&client)
            .await
            .expect_err("should fail");

        match err {
            QueryError::Api {
                status,
                error_type,
                message,
                ..
            } => {
                assert_eq!(status, StatusCode::BAD_REQUEST);
                assert_eq!(error_type.as_deref(), Some("invalid_request"));
                assert_eq!(message.as_deref(), Some("Bad from address"));
            }
            _ => panic!("expected Api error, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn authentication_error_403() {
        let client = MockClient::error(StatusCode::FORBIDDEN, br#"{"message":"Access denied"}"#);

        let err = PostEndpoint::new()
            .execute(&client)
            .await
            .expect_err("should fail");

        assert!(matches!(err, QueryError::Authentication { .. }));
    }
}