use std::{
borrow::Cow,
time::{Duration, Instant},
};
use bytes::Bytes;
use http::{
Method,
header::{CONTENT_TYPE, HeaderValue},
};
use lexe_api_core::error::{
ApiError, CommonApiError, CommonErrorKind, ErrorCode, ErrorResponse,
};
use lexe_common::time::DisplayMs;
use lexe_crypto::ed25519;
use lexe_std::backoff;
use lightning::util::ser::Writeable;
use reqwest::IntoUrl;
use serde::{Serialize, de::DeserializeOwned};
use tracing::{Instrument, debug, warn};
use crate::{trace, trace::TraceId};
/// `Content-Type` value for request bodies carrying a BCS-serialized,
/// ed25519-signed payload, as set by [`RequestBuilderExt::signed_bcs`].
pub static CONTENT_TYPE_ED25519_BCS: HeaderValue =
    HeaderValue::from_static("application/ed25519-bcs");

/// Per-request timeout configured by [`RestClient::client_builder`].
pub const API_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

/// Shorthand for [`Method::GET`].
pub const GET: Method = Method::GET;
/// Shorthand for [`Method::PUT`].
pub const PUT: Method = Method::PUT;
/// Shorthand for [`Method::POST`].
pub const POST: Method = Method::POST;
/// Shorthand for [`Method::DELETE`].
pub const DELETE: Method = Method::DELETE;
/// A thin wrapper around [`reqwest::Client`].
///
/// Requests sent through its `send*` methods get a tracing span and a trace
/// id header, and their outcomes are mapped into the caller's [`ApiError`].
#[derive(Clone)]
pub struct RestClient {
    /// The underlying HTTP client.
    client: reqwest::Client,
    /// Name of the caller; used as the `User-Agent` and in request spans.
    from: Cow<'static, str>,
    /// Name of the remote service; used in request spans.
    to: &'static str,
}
impl RestClient {
    /// Creates a [`RestClient`] that only speaks HTTPS, using `tls_config`
    /// for TLS.
    ///
    /// `from` identifies this client: it is sent as the `User-Agent` and,
    /// together with `to` (the remote service), is used to build the tracing
    /// span of every request.
    ///
    /// # Panics
    ///
    /// Panics if the underlying [`reqwest::Client`] fails to build.
    pub fn new(
        from: impl Into<Cow<'static, str>>,
        to: &'static str,
        tls_config: rustls::ClientConfig,
    ) -> Self {
        // Non-generic inner fn: the body is compiled once, regardless of the
        // concrete type of `from`.
        fn inner(
            from: Cow<'static, str>,
            to: &'static str,
            tls_config: rustls::ClientConfig,
        ) -> RestClient {
            let client = RestClient::client_builder(&from)
                .use_preconfigured_tls(tls_config)
                .https_only(true)
                .build()
                .expect("Failed to build reqwest Client");
            RestClient { client, from, to }
        }
        inner(from.into(), to, tls_config)
    }

    /// Creates a [`RestClient`] that also permits plain `http://` URLs and
    /// does not install a custom TLS config.
    ///
    /// # Panics
    ///
    /// Panics if the underlying [`reqwest::Client`] fails to build.
    pub fn new_insecure(
        from: impl Into<Cow<'static, str>>,
        to: &'static str,
    ) -> Self {
        // Non-generic inner fn, compiled once for every `from` type.
        fn inner(from: Cow<'static, str>, to: &'static str) -> RestClient {
            let client = RestClient::client_builder(&from)
                .https_only(false)
                .build()
                .expect("Failed to build reqwest Client");
            RestClient { client, from, to }
        }
        inner(from.into(), to)
    }

    /// Returns a [`reqwest::ClientBuilder`] preconfigured with:
    /// - `from` as the `User-Agent`,
    /// - HTTPS-only mode, and
    /// - a per-request timeout of [`API_REQUEST_TIMEOUT`].
    pub fn client_builder(from: impl AsRef<str>) -> reqwest::ClientBuilder {
        // Non-generic inner fn, compiled once for every `from` type.
        fn inner(from: &str) -> reqwest::ClientBuilder {
            reqwest::Client::builder()
                .user_agent(from)
                .https_only(true)
                .timeout(API_REQUEST_TIMEOUT)
        }
        inner(from.as_ref())
    }

    /// Wraps an already-built [`reqwest::Client`] as-is.
    ///
    /// Unlike [`RestClient::new`], none of the client's settings (TLS,
    /// timeout, user agent) are configured here.
    pub fn from_inner(
        client: reqwest::Client,
        from: impl Into<Cow<'static, str>>,
        to: &'static str,
    ) -> Self {
        Self {
            client,
            from: from.into(),
            to,
        }
    }

    /// Returns the `from` name this client was created with. Clients built
    /// through [`RestClient::client_builder`] also send it as `User-Agent`.
    #[inline]
    pub fn user_agent(&self) -> &Cow<'static, str> {
        &self.from
    }

    /// Starts a `GET` request with `data` serialized into the URL query.
    #[inline]
    pub fn get<U, T>(&self, url: U, data: &T) -> reqwest::RequestBuilder
    where
        U: IntoUrl,
        T: Serialize + ?Sized,
    {
        self.builder(GET, url).query(data)
    }

    /// Starts a `POST` request with `data` serialized as a JSON body.
    #[inline]
    pub fn post<U, T>(&self, url: U, data: &T) -> reqwest::RequestBuilder
    where
        U: IntoUrl,
        T: Serialize + ?Sized,
    {
        self.builder(POST, url).json(data)
    }

    /// Starts a `PUT` request with `data` serialized as a JSON body.
    #[inline]
    pub fn put<U, T>(&self, url: U, data: &T) -> reqwest::RequestBuilder
    where
        U: IntoUrl,
        T: Serialize + ?Sized,
    {
        self.builder(PUT, url).json(data)
    }

    /// Starts a `DELETE` request with `data` serialized as a JSON body.
    #[inline]
    pub fn delete<U, T>(&self, url: U, data: &T) -> reqwest::RequestBuilder
    where
        U: IntoUrl,
        T: Serialize + ?Sized,
    {
        self.builder(DELETE, url).json(data)
    }

    /// Starts a `method` request whose body is `data`, serialized with LDK's
    /// [`Writeable`].
    ///
    /// # Panics
    ///
    /// Panics if writing `data` into an in-memory buffer fails.
    #[inline]
    pub fn serialize_ldk_writeable<U, W>(
        &self,
        method: Method,
        url: U,
        data: &W,
    ) -> reqwest::RequestBuilder
    where
        U: IntoUrl,
        W: Writeable,
    {
        let bytes = {
            let mut buf = Vec::new();
            data.write(&mut buf)
                .expect("Serializing into in-memory buf shouldn't fail");
            Bytes::from(buf)
        };
        self.builder(method, url).body(bytes)
    }

    /// Starts a `method` request to `url` with no body or extra headers.
    pub fn builder(
        &self,
        method: Method,
        url: impl IntoUrl,
    ) -> reqwest::RequestBuilder {
        self.client.request(method, url)
    }

    /// Sends the request and JSON-deserializes a successful response body
    /// into `T`.
    ///
    /// # Errors
    ///
    /// Returns `E` in any of these cases:
    /// - the request can't be built or sent,
    /// - the server returns an error response,
    /// - the body can't be read, or
    /// - JSON decoding fails.
    pub async fn send<T: DeserializeOwned, E: ApiError>(
        &self,
        request_builder: reqwest::RequestBuilder,
    ) -> Result<T, E> {
        let bytes = self.send_no_deserialize::<E>(request_builder).await?;
        Self::json_deserialize(bytes)
    }

    /// Like [`RestClient::send`], but returns the raw response body instead
    /// of deserializing it.
    pub async fn send_no_deserialize<E: ApiError>(
        &self,
        request_builder: reqwest::RequestBuilder,
    ) -> Result<Bytes, E> {
        let request = request_builder.build().map_err(CommonApiError::from)?;
        let (request_span, trace_id) =
            trace::client::request_span(&request, &self.from, self.to);
        // Read the body inside the request span as well. This way the "Done"
        // logs emitted while receiving the body keep the span context, just
        // like in `send_with_retries`.
        let response = async {
            match self.send_inner(request, &trace_id).await {
                Ok(Ok(resp)) => resp.read_bytes().await.map(Ok),
                Ok(Err(api_error)) => Ok(Err(api_error)),
                Err(common_error) => Err(common_error),
            }
        }
        .instrument(request_span)
        .await;
        Self::map_response_errors::<Bytes, E>(response)
    }

    /// Sends the request and returns a [`StreamBody`] for reading a
    /// successful response body chunk-by-chunk.
    pub async fn send_and_stream_response<E: ApiError>(
        &self,
        request_builder: reqwest::RequestBuilder,
    ) -> Result<StreamBody, E> {
        let request = request_builder.build().map_err(CommonApiError::from)?;
        let (request_span, trace_id) =
            trace::client::request_span(&request, &self.from, self.to);
        let response = self
            .send_inner(request, &trace_id)
            .instrument(request_span)
            .await;
        Self::map_response_errors::<SuccessResponse, E>(response)
            .map(|resp| resp.into_stream_body())
    }

    /// Like [`RestClient::send`], but retries a failed attempt up to
    /// `retries` more times, sleeping for a backoff duration between
    /// attempts.
    ///
    /// - Retrying stops early when an attempt fails with an [`ErrorCode`]
    ///   contained in `stop_codes`.
    /// - A request whose body can't be cloned is sent exactly once.
    pub async fn send_with_retries<T: DeserializeOwned, E: ApiError>(
        &self,
        request_builder: reqwest::RequestBuilder,
        retries: usize,
        stop_codes: &[ErrorCode],
    ) -> Result<T, E> {
        let request = request_builder.build().map_err(CommonApiError::from)?;
        let (request_span, trace_id) =
            trace::client::request_span(&request, &self.from, self.to);
        let response = self
            .send_with_retries_inner(request, retries, stop_codes, &trace_id)
            .instrument(request_span)
            .await;
        let bytes = Self::map_response_errors::<Bytes, E>(response)?;
        Self::json_deserialize(bytes)
    }

    /// Retry loop behind [`RestClient::send_with_retries`].
    ///
    /// Up to `retries` attempts send a clone of `request`. The final attempt
    /// sends the original. The `attempts_left` field of the current span is
    /// updated before each attempt.
    async fn send_with_retries_inner(
        &self,
        request: reqwest::Request,
        retries: usize,
        stop_codes: &[ErrorCode],
        trace_id: &TraceId,
    ) -> Result<Result<Bytes, ErrorResponse>, CommonApiError> {
        let mut backoff_durations = backoff::get_backoff_iter();
        let mut attempts_left = retries + 1;

        for _ in 0..retries {
            tracing::Span::current().record("attempts_left", attempts_left);

            // `try_clone` yields `None` when the body can't be cloned
            // (e.g. a streaming body), so the request can't be retried.
            // Skip straight to the single final attempt with the original.
            let Some(request_clone) = request.try_clone() else {
                attempts_left = 1;
                break;
            };

            match self.send_inner(request_clone, trace_id).await {
                Ok(Ok(resp)) => match resp.read_bytes().await {
                    Ok(bytes) => return Ok(Ok(bytes)),
                    Err(common_error) => {
                        if stop_codes.contains(&common_error.to_code()) {
                            return Err(common_error);
                        }
                    }
                },
                Ok(Err(api_error)) => {
                    if stop_codes.contains(&api_error.code) {
                        return Ok(Err(api_error));
                    }
                }
                Err(common_error) => {
                    if stop_codes.contains(&common_error.to_code()) {
                        return Err(common_error);
                    }
                }
            }

            // NOTE(review): assumes the backoff iterator yields at least
            // `retries` durations.
            tokio::time::sleep(backoff_durations.next().unwrap()).await;
            attempts_left -= 1;
        }

        // Final attempt: the original request is sent, so no clone is needed.
        assert_eq!(attempts_left, 1);
        tracing::Span::current().record("attempts_left", attempts_left);
        match self.send_inner(request, trace_id).await? {
            Ok(resp_succ) => resp_succ.read_bytes().await.map(Ok),
            Err(api_error) => Ok(Err(api_error)),
        }
    }

    /// Sends a single request carrying `trace_id` in its trace id header.
    ///
    /// Return values:
    /// - `Ok(Ok(_))`: the response has a success status. The body is still
    ///   unread.
    /// - `Ok(Err(_))`: the server's decoded [`ErrorResponse`].
    /// - `Err(_)`: sending failed, or the error body couldn't be decoded.
    async fn send_inner(
        &self,
        mut request: reqwest::Request,
        trace_id: &TraceId,
    ) -> Result<Result<SuccessResponse, ErrorResponse>, CommonApiError> {
        // Taken from tokio's clock so that elapsed times follow tokio time.
        let start = tokio::time::Instant::now().into_std();
        debug!(target: trace::TARGET, "New client request");

        match request.headers_mut().try_insert(
            trace::TRACE_ID_HEADER_NAME.clone(),
            trace_id.to_header_value(),
        ) {
            Ok(None) => (),
            Ok(Some(_)) => warn!(target: trace::TARGET, "Trace id existed?"),
            Err(e) => warn!(target: trace::TARGET, "Header map full?: {e:#}"),
        }

        let resp = self.client.execute(request).await.inspect_err(|e| {
            let req_time = DisplayMs(start.elapsed());
            warn!(
                target: trace::TARGET,
                %req_time,
                "Done (error)(sending) Error sending request: {e:#}"
            );
        })?;

        let status = resp.status().as_u16();
        if resp.status().is_success() {
            Ok(Ok(SuccessResponse { resp, start }))
        } else {
            let error =
                resp.json::<ErrorResponse>().await.inspect_err(|e| {
                    let req_time = DisplayMs(start.elapsed());
                    warn!(
                        target: trace::TARGET,
                        %req_time,
                        %status,
                        "Done (error)(receiving) \
                         Couldn't receive ErrorResponse: {e:#}",
                    );
                })?;
            let req_time = DisplayMs(start.elapsed());
            warn!(
                target: trace::TARGET,
                %req_time,
                %status,
                error_code = %error.code,
                error_msg = %error.msg,
                "Done (error)(response) Server returned error response",
            );
            Ok(Err(error))
        }
    }

    /// Flattens the nested send result into the caller's error type `E`.
    fn map_response_errors<T, E: ApiError>(
        response: Result<Result<T, ErrorResponse>, CommonApiError>,
    ) -> Result<T, E> {
        match response {
            Ok(Ok(resp)) => Ok(resp),
            Ok(Err(err_api)) => Err(E::from(err_api)),
            Err(err_client) => Err(E::from(err_client)),
        }
    }

    /// Deserializes `bytes` as JSON into `T`. A failure is mapped into a
    /// [`CommonErrorKind::Decode`] error.
    ///
    /// In debug builds, tests, or with the `test-utils` feature, the error
    /// message also includes the response body (lossily decoded as UTF-8).
    fn json_deserialize<T: DeserializeOwned, E: ApiError>(
        bytes: Bytes,
    ) -> Result<T, E> {
        serde_json::from_slice::<T>(&bytes)
            .map_err(|err| {
                let kind = CommonErrorKind::Decode;
                let mut msg = format!("JSON deserialization failed: {err:#}");
                if cfg!(any(debug_assertions, test, feature = "test-utils")) {
                    let resp_msg = String::from_utf8_lossy(&bytes);
                    msg.push_str(&format!(": '{resp_msg}'"));
                }
                CommonApiError::new(kind, msg)
            })
            .map_err(E::from)
    }
}
/// A response with a success status whose body hasn't been read yet.
struct SuccessResponse {
    /// The response, body still unread.
    resp: reqwest::Response,
    /// When the request started; used to log the total request time.
    start: Instant,
}
impl SuccessResponse {
fn into_stream_body(self) -> StreamBody {
StreamBody {
resp: self.resp,
start: self.start,
}
}
async fn read_bytes(self) -> Result<Bytes, CommonApiError> {
let status = self.resp.status().as_u16();
let bytes = self.resp.bytes().await.inspect_err(|e| {
let req_time = DisplayMs(self.start.elapsed());
warn!(
target: trace::TARGET,
%req_time,
%status,
"Done (error)(receiving) \
Couldn't receive response body: {e:#}",
);
})?;
let req_time = DisplayMs(self.start.elapsed());
debug!(target: trace::TARGET, %req_time, %status, "Done (success)");
Ok(bytes)
}
}
/// A successful response whose body is read incrementally with
/// [`StreamBody::next_chunk`].
pub struct StreamBody {
    /// The in-flight response.
    resp: reqwest::Response,
    /// When the request started; used to log the total request time.
    start: Instant,
}
impl StreamBody {
    /// Returns the next chunk of the response body. Returns `Ok(None)` once
    /// the body has been fully received.
    ///
    /// Logs a "Done" line when the body ends or when receiving a chunk fails.
    pub async fn next_chunk(
        &mut self,
    ) -> Result<Option<Bytes>, CommonApiError> {
        let maybe_chunk = match self.resp.chunk().await {
            Ok(maybe_chunk) => maybe_chunk,
            Err(e) => {
                let status = self.resp.status().as_u16();
                let req_time = DisplayMs(self.start.elapsed());
                warn!(
                    target: trace::TARGET,
                    %req_time,
                    %status,
                    "Done (error)(receiving) \
                     Couldn't receive streaming response chunk: {e:#}",
                );
                return Err(CommonApiError::from(e));
            }
        };

        // `None` marks the end of the body, i.e. the request is complete.
        if maybe_chunk.is_none() {
            let status = self.resp.status().as_u16();
            let req_time = DisplayMs(self.start.elapsed());
            debug!(
                target: trace::TARGET,
                %req_time,
                %status,
                "Done (success)"
            );
        }

        Ok(maybe_chunk)
    }
}
/// Extension methods for HTTP request builders.
pub trait RequestBuilderExt: Sized {
    /// Sets the serialized `signed_bcs` as the request body, with
    /// `Content-Type` set to [`CONTENT_TYPE_ED25519_BCS`].
    ///
    /// # Errors
    ///
    /// Returns a [`bcs::Error`] if serializing `signed_bcs` fails.
    fn signed_bcs<T>(
        self,
        signed_bcs: &ed25519::Signed<&T>,
    ) -> Result<Self, bcs::Error>
    where
        T: ed25519::Signable + Serialize;
}
impl RequestBuilderExt for reqwest::RequestBuilder {
    /// Serializes `signed_bcs` into the body and tags it with the
    /// ed25519-BCS content type.
    fn signed_bcs<T>(
        self,
        signed_bcs: &ed25519::Signed<&T>,
    ) -> Result<Self, bcs::Error>
    where
        T: ed25519::Signable + Serialize,
    {
        let body = signed_bcs.serialize()?;
        Ok(self
            .header(CONTENT_TYPE, CONTENT_TYPE_ED25519_BCS.clone())
            .body(body))
    }
}