Skip to main content

lexe_api/
rest.rs

1use std::{
2    borrow::Cow,
3    time::{Duration, Instant},
4};
5
6use bytes::Bytes;
7use http::{
8    Method,
9    header::{CONTENT_TYPE, HeaderValue},
10};
11use lexe_api_core::error::{
12    ApiError, CommonApiError, CommonErrorKind, ErrorCode, ErrorResponse,
13};
14use lexe_common::time::DisplayMs;
15use lexe_crypto::ed25519;
16use lexe_std::backoff;
17use lightning::util::ser::Writeable;
18use reqwest::IntoUrl;
19use serde::{Serialize, de::DeserializeOwned};
20use tracing::{Instrument, debug, warn};
21
22use crate::{trace, trace::TraceId};
23
24/// The CONTENT-TYPE header for signed BCS-serialized structs.
25pub static CONTENT_TYPE_ED25519_BCS: HeaderValue =
26    HeaderValue::from_static("application/ed25519-bcs");
27
/// Default per-request timeout set by [`RestClient::client_builder`].
///
/// Kept generous because opening a channel with an external peer apparently
/// takes more than 15 seconds.
pub const API_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
30
31// Avoid `Method::` prefix. Associated constants can't be imported
32pub const GET: Method = Method::GET;
33pub const PUT: Method = Method::PUT;
34pub const POST: Method = Method::POST;
35pub const DELETE: Method = Method::DELETE;
36
37/// A generic RestClient which conforms to Lexe's API.
38#[derive(Clone)]
39pub struct RestClient {
40    client: reqwest::Client,
41    /// The process that this [`RestClient`] is being called from, e.g. "app"
42    from: Cow<'static, str>,
43    /// The process that this [`RestClient`] is calling, e.g. "node-run"
44    to: &'static str,
45}
46
47impl RestClient {
48    /// Builds a new [`RestClient`] with the given TLS config and safe defaults.
49    ///
50    /// The `from` and `to` fields should succinctly specify the client and
51    /// server components of the API trait that this [`RestClient`] is used for,
52    /// e.g. `from`="app", `to`="node-run" or `from`="node", `to`="backend".
53    /// The [`RestClient`] will log both fields so that requests from this
54    /// client can be differentiated from those made by other clients in the
55    /// same process, and propagate the `from` field to the server via the user
56    /// agent header so that servers can identify requesting clients.
57    pub fn new(
58        from: impl Into<Cow<'static, str>>,
59        to: &'static str,
60        tls_config: rustls::ClientConfig,
61    ) -> Self {
62        fn inner(
63            from: Cow<'static, str>,
64            to: &'static str,
65            tls_config: rustls::ClientConfig,
66        ) -> RestClient {
67            let client = RestClient::client_builder(&from)
68                .use_preconfigured_tls(tls_config)
69                .https_only(true)
70                .build()
71                .expect("Failed to build reqwest Client");
72            RestClient { client, from, to }
73        }
74        inner(from.into(), to, tls_config)
75    }
76
77    /// [`RestClient::new`] but without TLS.
78    /// This should only be used for non-security-critical endpoints.
79    pub fn new_insecure(
80        from: impl Into<Cow<'static, str>>,
81        to: &'static str,
82    ) -> Self {
83        fn inner(from: Cow<'static, str>, to: &'static str) -> RestClient {
84            let client = RestClient::client_builder(&from)
85                .https_only(false)
86                .build()
87                .expect("Failed to build reqwest Client");
88            RestClient { client, from, to }
89        }
90        inner(from.into(), to)
91    }
92
93    /// Get a [`reqwest::ClientBuilder`] with some defaults set.
94    /// NOTE that for safety, `https_only` is set to `true`, but you can
95    /// override it if needed.
96    pub fn client_builder(from: impl AsRef<str>) -> reqwest::ClientBuilder {
97        fn inner(from: &str) -> reqwest::ClientBuilder {
98            reqwest::Client::builder()
99                .user_agent(from)
100                .https_only(true)
101                .timeout(API_REQUEST_TIMEOUT)
102        }
103        inner(from.as_ref())
104    }
105
106    /// Construct a [`RestClient`] from a [`reqwest::Client`].
107    pub fn from_inner(
108        client: reqwest::Client,
109        from: impl Into<Cow<'static, str>>,
110        to: &'static str,
111    ) -> Self {
112        Self {
113            client,
114            from: from.into(),
115            to,
116        }
117    }
118
119    #[inline]
120    pub fn user_agent(&self) -> &Cow<'static, str> {
121        &self.from
122    }
123
124    // --- RequestBuilder helpers --- //
125
126    #[inline]
127    pub fn get<U, T>(&self, url: U, data: &T) -> reqwest::RequestBuilder
128    where
129        U: IntoUrl,
130        T: Serialize + ?Sized,
131    {
132        self.builder(GET, url).query(data)
133    }
134
135    #[inline]
136    pub fn post<U, T>(&self, url: U, data: &T) -> reqwest::RequestBuilder
137    where
138        U: IntoUrl,
139        T: Serialize + ?Sized,
140    {
141        self.builder(POST, url).json(data)
142    }
143
144    #[inline]
145    pub fn put<U, T>(&self, url: U, data: &T) -> reqwest::RequestBuilder
146    where
147        U: IntoUrl,
148        T: Serialize + ?Sized,
149    {
150        self.builder(PUT, url).json(data)
151    }
152
153    #[inline]
154    pub fn delete<U, T>(&self, url: U, data: &T) -> reqwest::RequestBuilder
155    where
156        U: IntoUrl,
157        T: Serialize + ?Sized,
158    {
159        self.builder(DELETE, url).json(data)
160    }
161
162    /// Serializes a LDK [`Writeable`] object into the request body.
163    #[inline]
164    pub fn serialize_ldk_writeable<U, W>(
165        &self,
166        method: Method,
167        url: U,
168        data: &W,
169    ) -> reqwest::RequestBuilder
170    where
171        U: IntoUrl,
172        W: Writeable,
173    {
174        let bytes = {
175            let mut buf = Vec::new();
176            data.write(&mut buf)
177                .expect("Serializing into in-memory buf shouldn't fail");
178            Bytes::from(buf)
179        };
180        self.builder(method, url).body(bytes)
181    }
182
183    /// A clean slate [`reqwest::RequestBuilder`] for non-standard requests.
184    /// Otherwise prefer to use the ready-made `get`, `post`, ..., etc helpers.
185    pub fn builder(
186        &self,
187        method: Method,
188        url: impl IntoUrl,
189    ) -> reqwest::RequestBuilder {
190        self.client.request(method, url)
191    }
192
193    // --- Request send/recv --- //
194
195    /// Sends the built HTTP request.
196    /// Tries to JSON deserialize the response body to `T`.
197    pub async fn send<T: DeserializeOwned, E: ApiError>(
198        &self,
199        request_builder: reqwest::RequestBuilder,
200    ) -> Result<T, E> {
201        let bytes = self.send_no_deserialize::<E>(request_builder).await?;
202        Self::json_deserialize(bytes)
203    }
204
205    /// Sends the HTTP request, but *doesn't* JSON-deserialize the response.
206    pub async fn send_no_deserialize<E: ApiError>(
207        &self,
208        request_builder: reqwest::RequestBuilder,
209    ) -> Result<Bytes, E> {
210        let request = request_builder.build().map_err(CommonApiError::from)?;
211        let (request_span, trace_id) =
212            trace::client::request_span(&request, &self.from, self.to);
213        let response = self
214            .send_inner(request, &trace_id)
215            .instrument(request_span)
216            .await;
217        let res = match response {
218            Ok(Ok(resp)) => resp.read_bytes().await.map(Ok),
219            Ok(Err(api_error)) => Ok(Err(api_error)),
220            Err(common_error) => Err(common_error),
221        };
222        Self::map_response_errors::<Bytes, E>(res)
223    }
224
225    /// Sends the HTTP request, but returns a [`StreamBody`] that yields
226    /// [`Bytes`] chunks as they arrive.
227    pub async fn send_and_stream_response<E: ApiError>(
228        &self,
229        request_builder: reqwest::RequestBuilder,
230    ) -> Result<StreamBody, E> {
231        let request = request_builder.build().map_err(CommonApiError::from)?;
232        let (request_span, trace_id) =
233            trace::client::request_span(&request, &self.from, self.to);
234        let response = self
235            .send_inner(request, &trace_id)
236            .instrument(request_span)
237            .await;
238        Self::map_response_errors::<SuccessResponse, E>(response)
239            .map(|resp| resp.into_stream_body())
240    }
241
242    /// Sends the built HTTP request, retrying up to `retries` times. Tries to
243    /// JSON deserialize the response body to `T`.
244    ///
245    /// If one of the request attempts yields an error code in `stop_codes`, we
246    /// will immediately stop retrying and return that error.
247    ///
248    /// See also: [`RestClient::send`]
249    pub async fn send_with_retries<T: DeserializeOwned, E: ApiError>(
250        &self,
251        request_builder: reqwest::RequestBuilder,
252        retries: usize,
253        stop_codes: &[ErrorCode],
254    ) -> Result<T, E> {
255        let request = request_builder.build().map_err(CommonApiError::from)?;
256        let (request_span, trace_id) =
257            trace::client::request_span(&request, &self.from, self.to);
258        let response = self
259            .send_with_retries_inner(request, retries, stop_codes, &trace_id)
260            .instrument(request_span)
261            .await;
262        let bytes = Self::map_response_errors::<Bytes, E>(response)?;
263        Self::json_deserialize(bytes)
264    }
265
266    // the `send_inner` and `send_with_retries_inner` intentionally use zero
267    // generics in their function signatures to minimize code bloat.
268
269    async fn send_with_retries_inner(
270        &self,
271        request: reqwest::Request,
272        retries: usize,
273        stop_codes: &[ErrorCode],
274        trace_id: &TraceId,
275    ) -> Result<Result<Bytes, ErrorResponse>, CommonApiError> {
276        let mut backoff_durations = backoff::get_backoff_iter();
277        let mut attempts_left = retries + 1;
278
279        let mut request = Some(request);
280
281        // Do the 'retries' first.
282        for _ in 0..retries {
283            tracing::Span::current().record("attempts_left", attempts_left);
284
285            // clone the request. the request body is cheaply cloneable. the
286            // headers and url are not :'(
287            let maybe_request_clone = request
288                .as_ref()
289                .expect(
290                    "This should never happen; we only take() the original \
291                     request on the last attempt",
292                )
293                .try_clone();
294
295            let request_clone = match maybe_request_clone {
296                Some(request_clone) => request_clone,
297                // We only get None if the request body is streamed and not set
298                // up front. In this case, we can't send more than once.
299                None => break,
300            };
301
302            // send the request and look for any error codes in the response
303            // that we should bail on and stop retrying.
304            match self.send_inner(request_clone, trace_id).await {
305                Ok(Ok(resp)) => match resp.read_bytes().await {
306                    Ok(bytes) => {
307                        return Ok(Ok(bytes));
308                    }
309                    Err(common_error) => {
310                        if stop_codes.contains(&common_error.to_code()) {
311                            return Err(common_error);
312                        }
313                    }
314                },
315                Ok(Err(api_error)) =>
316                    if stop_codes.contains(&api_error.code) {
317                        return Ok(Err(api_error));
318                    },
319                Err(common_error) => {
320                    if stop_codes.contains(&common_error.to_code()) {
321                        return Err(common_error);
322                    }
323                }
324            }
325
326            // sleep for a bit before next retry
327            tokio::time::sleep(backoff_durations.next().unwrap()).await;
328            attempts_left -= 1;
329        }
330
331        // We ran out of retries; return the result of the 'main' attempt.
332        assert_eq!(attempts_left, 1);
333        tracing::Span::current().record("attempts_left", attempts_left);
334
335        let resp = self.send_inner(request.take().unwrap(), trace_id).await?;
336        match resp {
337            Ok(resp_succ) => resp_succ.read_bytes().await.map(Ok),
338            Err(api_error) => Ok(Err(api_error)),
339        }
340    }
341
342    async fn send_inner(
343        &self,
344        mut request: reqwest::Request,
345        trace_id: &TraceId,
346    ) -> Result<Result<SuccessResponse, ErrorResponse>, CommonApiError> {
347        let start = tokio::time::Instant::now().into_std();
348        // This message should mirror `LxOnRequest`.
349        debug!(target: trace::TARGET, "New client request");
350
351        // Add the trace id header to the request.
352        match request.headers_mut().try_insert(
353            trace::TRACE_ID_HEADER_NAME.clone(),
354            trace_id.to_header_value(),
355        ) {
356            Ok(None) => (),
357            Ok(Some(_)) => warn!(target: trace::TARGET, "Trace id existed?"),
358            Err(e) => warn!(target: trace::TARGET, "Header map full?: {e:#}"),
359        }
360
361        // send the request, await the response headers
362        let resp = self.client.execute(request).await.inspect_err(|e| {
363            let req_time = DisplayMs(start.elapsed());
364            warn!(
365                target: trace::TARGET,
366                %req_time,
367                "Done (error)(sending) Error sending request: {e:#}"
368            );
369        })?;
370
371        // add the response http status to the current request span
372        let status = resp.status().as_u16();
373
374        if resp.status().is_success() {
375            Ok(Ok(SuccessResponse { resp, start }))
376        } else {
377            // http error => await response json and convert to ErrorResponse
378            let error =
379                resp.json::<ErrorResponse>().await.inspect_err(|e| {
380                    let req_time = DisplayMs(start.elapsed());
381                    warn!(
382                        target: trace::TARGET,
383                        %req_time,
384                        %status,
385                        "Done (error)(receiving) \
386                         Couldn't receive ErrorResponse: {e:#}",
387                    );
388                })?;
389
390            let req_time = DisplayMs(start.elapsed());
391            warn!(
392                target: trace::TARGET,
393                %req_time,
394                %status,
395                error_code = %error.code,
396                error_msg = %error.msg,
397                "Done (error)(response) Server returned error response",
398            );
399            Ok(Err(error))
400        }
401    }
402
403    /// Converts the [`Result<Result<T, ErrorResponse>, CommonApiError>`]
404    /// returned by [`Self::send_inner`] to [`Result<T, E>`].
405    fn map_response_errors<T, E: ApiError>(
406        response: Result<Result<T, ErrorResponse>, CommonApiError>,
407    ) -> Result<T, E> {
408        match response {
409            Ok(Ok(resp)) => Ok(resp),
410            Ok(Err(err_api)) => Err(E::from(err_api)),
411            Err(err_client) => Err(E::from(err_client)),
412        }
413    }
414
415    /// JSON-deserializes the REST response bytes.
416    fn json_deserialize<T: DeserializeOwned, E: ApiError>(
417        bytes: Bytes,
418    ) -> Result<T, E> {
419        serde_json::from_slice::<T>(&bytes)
420            .map_err(|err| {
421                let kind = CommonErrorKind::Decode;
422                let mut msg = format!("JSON deserialization failed: {err:#}");
423
424                // If we're in debug, append the response str to the error msg.
425                // TODO(max): Try to find a way to do this safely in prod.
426                if cfg!(any(debug_assertions, test, feature = "test-utils")) {
427                    let resp_msg = String::from_utf8_lossy(&bytes);
428                    msg.push_str(&format!(": '{resp_msg}'"));
429                }
430
431                CommonApiError::new(kind, msg)
432            })
433            .map_err(E::from)
434    }
435}
436
437// -- impl SuccessResponse -- //
438
439/// A successful [`reqwest::Response`], though we haven't read the body yet.
440struct SuccessResponse {
441    resp: reqwest::Response,
442    start: Instant,
443}
444
445impl SuccessResponse {
446    /// Convert into a streaming response body.
447    fn into_stream_body(self) -> StreamBody {
448        StreamBody {
449            resp: self.resp,
450            start: self.start,
451        }
452    }
453
454    /// Read the successful response body into a single raw [`Bytes`].
455    async fn read_bytes(self) -> Result<Bytes, CommonApiError> {
456        let status = self.resp.status().as_u16();
457        let bytes = self.resp.bytes().await.inspect_err(|e| {
458            let req_time = DisplayMs(self.start.elapsed());
459            warn!(
460                target: trace::TARGET,
461                %req_time,
462                %status,
463                "Done (error)(receiving) \
464                 Couldn't receive response body: {e:#}",
465            );
466        })?;
467
468        let req_time = DisplayMs(self.start.elapsed());
469        // NOTE: This client request log can be at INFO.
470        // It's cluttering our logs though, so we're suppressing.
471        debug!(target: trace::TARGET, %req_time, %status, "Done (success)");
472        Ok(bytes)
473    }
474}
475
// -- impl StreamBody -- //
477
478/// A streaming response body which yields chunks of the body as raw [`Bytes`]
479/// as they arrive.
480pub struct StreamBody {
481    resp: reqwest::Response,
482    start: Instant,
483}
484
485impl StreamBody {
486    /// Stream a chunk of the response body. Returns `Ok(None)` when the stream
487    /// is complete.
488    pub async fn next_chunk(
489        &mut self,
490    ) -> Result<Option<Bytes>, CommonApiError> {
491        match self.resp.chunk().await {
492            Ok(Some(chunk)) => Ok(Some(chunk)),
493            Ok(None) => {
494                // Done, log how long it took.
495                let status = self.resp.status().as_u16();
496                let req_time = DisplayMs(self.start.elapsed());
497                debug!(target: trace::TARGET, %req_time, %status, "Done (success)");
498                Ok(None)
499            }
500            Err(e) => {
501                // Error receiving next chunk.
502                let status = self.resp.status().as_u16();
503                let req_time = DisplayMs(self.start.elapsed());
504                warn!(
505                    target: trace::TARGET,
506                    %req_time,
507                    %status,
508                    "Done (error)(receiving) \
509                     Couldn't receive streaming response chunk: {e:#}",
510                );
511                Err(CommonApiError::from(e))
512            }
513        }
514    }
515}
516
517// -- impl RequestBuilderExt -- //
518
519/// Extension trait on [`reqwest::RequestBuilder`] for easily modifying requests
520/// as they're constructed.
521pub trait RequestBuilderExt: Sized {
522    /// Set the request body to a [`ed25519::Signed<T>`] serialized to BCS with
523    /// corresponding content type header.
524    fn signed_bcs<T>(
525        self,
526        signed_bcs: &ed25519::Signed<&T>,
527    ) -> Result<Self, bcs::Error>
528    where
529        T: ed25519::Signable + Serialize;
530}
531
532impl RequestBuilderExt for reqwest::RequestBuilder {
533    fn signed_bcs<T>(
534        self,
535        signed_bcs: &ed25519::Signed<&T>,
536    ) -> Result<Self, bcs::Error>
537    where
538        T: ed25519::Signable + Serialize,
539    {
540        let bytes = signed_bcs.serialize()?;
541        let request = self
542            .header(CONTENT_TYPE, CONTENT_TYPE_ED25519_BCS.clone())
543            .body(bytes);
544        Ok(request)
545    }
546}