use crate::{
header::{TRAILER, X_STREAM_ERROR_KEY},
read::{JsonLineDecoder, StreamReader},
ApiError, ApiRequest,
};
use async_trait::async_trait;
use bytes::Bytes;
use common_multipart_rfc7578::client::multipart;
use futures::{future, FutureExt, Stream, StreamExt, TryStreamExt};
use http::{
header::{HeaderName, HeaderValue},
StatusCode,
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use tokio_util::codec::{Decoder, FramedRead};
/// Transport abstraction for issuing API requests over HTTP.
///
/// Implementors supply the client-specific primitives: building a request, sending it, reading
/// response headers, and exposing the response body as a byte stream. The provided methods
/// build on those primitives to check status codes, decode error bodies, and turn responses
/// into JSON values, strings, or streams.
#[async_trait(?Send)]
pub trait Backend {
    /// Client-specific request type, as produced by [`Backend::build_base_request`].
    type HttpRequest;

    /// Client-specific response type, as consumed by the stream helpers.
    type HttpResponse;

    /// Error type returned by all backend operations.
    ///
    /// It must be constructible from an [`ApiError`] and from a [`crate::Error`], which is
    /// used here for local failures such as JSON or UTF-8 decoding errors.
    type Error: Display + From<ApiError> + From<crate::Error> + 'static;

    /// Builds a client request for `req`, with an optional multipart `form`.
    fn build_base_request<Req>(
        &self,
        req: &Req,
        form: Option<multipart::Form<'static>>,
    ) -> Result<Self::HttpRequest, Self::Error>
    where
        Req: ApiRequest;

    /// Looks up the header `key` on `res`, returning `None` when it is absent.
    fn get_header(res: &Self::HttpResponse, key: HeaderName) -> Option<&HeaderValue>;

    /// Sends `req` (with an optional multipart `form`) and returns the response status
    /// together with the complete response body.
    async fn request_raw<Req>(
        &self,
        req: Req,
        form: Option<multipart::Form<'static>>,
    ) -> Result<(StatusCode, Bytes), Self::Error>
    where
        Req: ApiRequest + Serialize;

    /// Exposes the body of `res` as a stream of byte chunks.
    fn response_to_byte_stream(
        res: Self::HttpResponse,
    ) -> Box<dyn Stream<Item = Result<Bytes, Self::Error>> + Unpin>;

    /// Sends `req` and turns the response into a stream of items using `process`.
    fn request_stream<Res, F, OutStream>(
        &self,
        req: Self::HttpRequest,
        process: F,
    ) -> Box<dyn Stream<Item = Result<Res, Self::Error>> + Unpin>
    where
        OutStream: Stream<Item = Result<Res, Self::Error>> + Unpin,
        F: 'static + Fn(Self::HttpResponse) -> OutStream;

    /// Converts an error response body into [`Self::Error`].
    ///
    /// The body is first parsed as a JSON [`ApiError`]. If that fails, the raw body is reported
    /// as [`crate::Error::UnrecognizedApiError`] when it is valid UTF-8; otherwise the UTF-8
    /// decoding error itself is returned.
    #[inline]
    fn process_error_from_body(body: Bytes) -> Self::Error {
        match serde_json::from_slice::<ApiError>(&body) {
            Ok(e) => e.into(),
            Err(_) => {
                let err = match String::from_utf8(body.to_vec()) {
                    Ok(s) => crate::Error::UnrecognizedApiError(s),
                    Err(e) => crate::Error::from(e),
                };
                err.into()
            }
        }
    }

    /// Interprets a fully buffered response as JSON.
    ///
    /// A `200 OK` body is deserialized into `Res`. Any other status, including other 2xx codes,
    /// is treated as an error body and decoded with [`Backend::process_error_from_body`].
    fn process_json_response<Res>(status: StatusCode, body: Bytes) -> Result<Res, Self::Error>
    where
        for<'de> Res: 'static + Deserialize<'de>,
    {
        match status {
            StatusCode::OK => serde_json::from_slice(&body)
                .map_err(crate::Error::from)
                .map_err(Self::Error::from),
            _ => Err(Self::process_error_from_body(body)),
        }
    }

    /// Wraps the body of `res` in a [`FramedRead`] that splits it into frames using `decoder`.
    fn process_stream_response<D, Res>(
        res: Self::HttpResponse,
        decoder: D,
    ) -> FramedRead<StreamReader<Box<dyn Stream<Item = Result<Bytes, Self::Error>> + Unpin>>, D>
    where
        D: Decoder<Item = Res, Error = crate::Error>,
    {
        FramedRead::new(
            StreamReader::new(Self::response_to_byte_stream(res)),
            decoder,
        )
    }

    /// Sends `req` and deserializes a `200 OK` JSON body into `Res`.
    ///
    /// Non-`200 OK` responses are decoded with [`Backend::process_error_from_body`].
    async fn request<Req, Res>(
        &self,
        req: Req,
        form: Option<multipart::Form<'static>>,
    ) -> Result<Res, Self::Error>
    where
        Req: ApiRequest + Serialize,
        for<'de> Res: 'static + Deserialize<'de>,
    {
        let (status, chunk) = self.request_raw(req, form).await?;
        Self::process_json_response(status, chunk)
    }

    /// Sends `req` and returns `Ok(())` on `200 OK`, ignoring the body.
    ///
    /// Any other status is decoded with [`Backend::process_error_from_body`].
    async fn request_empty<Req>(
        &self,
        req: Req,
        form: Option<multipart::Form<'static>>,
    ) -> Result<(), Self::Error>
    where
        Req: ApiRequest + Serialize,
    {
        let (status, chunk) = self.request_raw(req, form).await?;
        match status {
            StatusCode::OK => Ok(()),
            _ => Err(Self::process_error_from_body(chunk)),
        }
    }

    /// Sends `req` and returns a `200 OK` body as a UTF-8 `String`.
    ///
    /// A body that is not valid UTF-8 yields the conversion error. Any other status is decoded
    /// with [`Backend::process_error_from_body`].
    async fn request_string<Req>(
        &self,
        req: Req,
        form: Option<multipart::Form<'static>>,
    ) -> Result<String, Self::Error>
    where
        Req: ApiRequest + Serialize,
    {
        let (status, chunk) = self.request_raw(req, form).await?;
        match status {
            StatusCode::OK => String::from_utf8(chunk.to_vec())
                .map_err(crate::Error::from)
                .map_err(Self::Error::from),
            _ => Err(Self::process_error_from_body(chunk)),
        }
    }

    /// Streams the raw body of the response to `req` as byte chunks.
    fn request_stream_bytes(
        &self,
        req: Self::HttpRequest,
    ) -> Box<dyn Stream<Item = Result<Bytes, Self::Error>> + Unpin> {
        // NOTE(review): a closure is used rather than the bare path
        // `Self::response_to_byte_stream`. A function item carries `Self` in its type and may
        // not satisfy the `'static` bound on `process` — confirm before simplifying.
        self.request_stream(req, |res| Self::response_to_byte_stream(res))
    }

    /// Streams values decoded from the response to `req` with [`JsonLineDecoder`].
    ///
    /// If the response carries a `Trailer` header, that header must name `X_STREAM_ERROR_KEY`.
    /// The name is matched ASCII case-insensitively, and a match tells the decoder to parse
    /// stream errors. Any other `Trailer` value produces a stream that yields a single
    /// [`crate::Error::UnrecognizedTrailerHeader`] error.
    fn request_stream_json<Res>(
        &self,
        req: Self::HttpRequest,
    ) -> Box<dyn Stream<Item = Result<Res, Self::Error>> + Unpin>
    where
        for<'de> Res: 'static + Deserialize<'de>,
    {
        self.request_stream(req, |res| {
            let parse_stream_error = match Self::get_header(&res, TRAILER) {
                None => false,
                // `Trailer` lists header field names, and HTTP field names are
                // case-insensitive (RFC 7230 §3.2), so casing must not matter here.
                Some(trailer)
                    if trailer
                        .as_bytes()
                        .eq_ignore_ascii_case(X_STREAM_ERROR_KEY.as_bytes()) =>
                {
                    true
                }
                Some(trailer) => {
                    let err = crate::Error::UnrecognizedTrailerHeader(
                        String::from_utf8_lossy(trailer.as_bytes()).into(),
                    );
                    return future::err(err).into_stream().err_into().left_stream();
                }
            };

            Self::process_stream_response(res, JsonLineDecoder::new(parse_stream_error))
                .err_into()
                .right_stream()
        })
    }
}