use bytes::Bytes;
use http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, Method, Request, Response};
use serde::{Deserialize, Serialize};
use super::builder::Client;
/// Errors returned by [`RestClient`] and [`RequestBuilder`].
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
    /// The request could not be assembled into an `http::Request`, or a
    /// response body could not be decoded as text (see `RestClient::text`).
    /// Presumably also used by `Client::execute` for transport failures —
    /// TODO confirm against the builder module.
    ///
    /// NOTE(review): the message embeds the inner error via `{0}` while the
    /// field is also exposed as `#[source]`, so reporters that walk the error
    /// chain will print the inner message twice.
    #[error("transport error: {0}")]
    Transport(#[source] crate::BoxError),
    /// JSON encoding or decoding failed; converted automatically from
    /// `serde_json::Error` via `#[from]`.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// A request path could not be joined onto the base URL.
    ///
    /// NOTE(review): `RequestBuilder::header` also reports invalid header
    /// names and values through this variant.
    #[error("invalid URL: {0}")]
    InvalidUrl(String),
}
/// Convenience wrapper around [`Client`] for issuing REST-style requests
/// whose paths are resolved against a fixed base URL.
///
/// Every [`RequestBuilder`] it creates starts with a copy of the default
/// headers.
#[derive(Clone, Debug)]
pub struct RestClient {
    // Client that executes the built requests; cloned into each builder.
    client: Client,
    // URL that request paths are resolved against with `Url::join`.
    base_url: reqwest::Url,
    // Headers copied into every request created by this client.
    default_headers: HeaderMap,
}
impl RestClient {
pub(crate) fn new(client: Client, base_url: reqwest::Url, default_headers: HeaderMap) -> Self {
Self {
client,
base_url,
default_headers,
}
}
pub fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, ClientError> {
let url = self
.base_url
.join(path)
.map_err(|e| ClientError::InvalidUrl(e.to_string()))?;
Ok(RequestBuilder::new(
self.client.clone(),
method,
url,
self.default_headers.clone(),
))
}
pub fn get(&self, path: &str) -> Result<RequestBuilder, ClientError> {
self.request(Method::GET, path)
}
pub fn post(&self, path: &str) -> Result<RequestBuilder, ClientError> {
self.request(Method::POST, path)
}
pub fn put(&self, path: &str) -> Result<RequestBuilder, ClientError> {
self.request(Method::PUT, path)
}
pub fn delete(&self, path: &str) -> Result<RequestBuilder, ClientError> {
self.request(Method::DELETE, path)
}
pub fn patch(&self, path: &str) -> Result<RequestBuilder, ClientError> {
self.request(Method::PATCH, path)
}
pub async fn post_json<Req, Resp>(&self, path: &str, body: &Req) -> Result<Resp, ClientError>
where
Req: Serialize,
Resp: for<'de> Deserialize<'de>,
{
let resp = self
.post(path)?
.header(CONTENT_TYPE, HeaderValue::from_static("application/json"))?
.json(body)
.send()
.await?;
let body = resp.into_body();
Ok(serde_json::from_slice(&body)?)
}
pub async fn json<T>(&self, path: &str) -> Result<T, ClientError>
where
T: for<'de> Deserialize<'de>,
{
let resp = self.get(path)?.send().await?;
let body = resp.into_body();
Ok(serde_json::from_slice(&body)?)
}
pub async fn text(&self, path: &str) -> Result<String, ClientError> {
let resp = self.get(path)?.send().await?;
let body = resp.into_body();
String::from_utf8(body.to_vec()).map_err(|e| {
ClientError::Transport(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
e,
)))
})
}
pub async fn bytes(&self, path: &str) -> Result<Bytes, ClientError> {
let resp = self.get(path)?.send().await?;
Ok(resp.into_body())
}
}
#[derive(Clone, Debug)]
pub struct RequestBuilder {
client: Client,
method: Method,
url: reqwest::Url,
headers: HeaderMap,
body: Option<Bytes>,
}
impl RequestBuilder {
fn new(client: Client, method: Method, url: reqwest::Url, default_headers: HeaderMap) -> Self {
Self {
client,
method,
url,
headers: default_headers,
body: None,
}
}
pub fn header<N, V>(mut self, name: N, value: V) -> Result<Self, ClientError>
where
N: TryInto<HeaderName>,
N::Error: std::fmt::Display,
V: TryInto<HeaderValue>,
V::Error: std::fmt::Display,
{
let name = name
.try_into()
.map_err(|e| ClientError::InvalidUrl(e.to_string()))?;
let value = value
.try_into()
.map_err(|e| ClientError::InvalidUrl(e.to_string()))?;
self.headers.append(name, value);
Ok(self)
}
pub fn json<T: Serialize>(mut self, value: &T) -> Self {
self.body = Some(Bytes::from(serde_json::to_vec(value).unwrap_or_default()));
self
}
pub fn body(mut self, body: impl Into<Bytes>) -> Self {
self.body = Some(body.into());
self
}
pub async fn send(self) -> Result<Response<Bytes>, ClientError> {
let mut builder = Request::builder().method(self.method).uri(self.url.as_str());
for (name, value) in &self.headers {
builder = builder.header(name, value);
}
let req = builder
.body(self.body.unwrap_or_default())
.map_err(|e| ClientError::Transport(Box::new(e)))?;
self.client.execute(req).await
}
}
// Unit tests for `RestClient` / `RequestBuilder`, run against an in-process
// mock service instead of a real HTTP server.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::{header::CONTENT_TYPE, HeaderMap, Method, Request, Response, StatusCode};

    // Builds a `Client` backed by a `tower::service_fn` that routes on
    // (method, path). The base URL is `http://example.com` with no default
    // headers. Unknown routes get a 404 with an empty body.
    fn mock_client() -> Client {
        let service = tower::service_fn(|req: Request<Bytes>| async move {
            let (parts, body) = req.into_parts();
            let path = parts.uri.path().to_string();
            let method = parts.method.clone();
            Ok::<_, crate::BoxError>(match (method, path.as_str()) {
                (Method::GET, "/json") => Response::builder()
                    .header(CONTENT_TYPE, "application/json")
                    .body(Bytes::from_static(br#"{"hello":"world"}"#))
                    .unwrap(),
                (Method::GET, "/text") => Response::builder()
                    .body(Bytes::from_static(b"plain text"))
                    .unwrap(),
                (Method::GET, "/bytes") => Response::builder()
                    .body(Bytes::from_static(b"raw bytes"))
                    .unwrap(),
                // Echoes the request body back unchanged.
                (Method::POST, "/echo") => Response::builder().body(body).unwrap(),
                (Method::PUT, "/put") => Response::builder().status(204).body(Bytes::new()).unwrap(),
                (Method::DELETE, "/delete") => Response::builder().status(204).body(Bytes::new()).unwrap(),
                (Method::PATCH, "/patch") => Response::builder().status(204).body(Bytes::new()).unwrap(),
                // Reflects the `x-test` request header. Returns "missing" when
                // it is absent or not representable as a visible-ASCII `&str`.
                (Method::GET, "/headers") => {
                    let val = parts
                        .headers
                        .get("x-test")
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("missing");
                    Response::builder().body(Bytes::from(val.to_string())).unwrap()
                }
                _ => Response::builder()
                    .status(StatusCode::NOT_FOUND)
                    .body(Bytes::new())
                    .unwrap(),
            })
        });
        Client::from_service(
            crate::client::builder::BoxedClientService::new(service),
            reqwest::Url::parse("http://example.com").unwrap(),
            HeaderMap::new(),
        )
    }

    // `json` decodes the mock's JSON payload into a `serde_json::Value`.
    #[tokio::test]
    async fn test_rest_get_json() {
        let client = mock_client();
        let value: serde_json::Value = client.rest().json("/json").await.unwrap();
        assert_eq!(value["hello"], "world");
    }

    // `text` returns the body as a UTF-8 string.
    #[tokio::test]
    async fn test_rest_text() {
        let client = mock_client();
        let text = client.rest().text("/text").await.unwrap();
        assert_eq!(text, "plain text");
    }

    // `bytes` returns the raw body untouched.
    #[tokio::test]
    async fn test_rest_bytes() {
        let client = mock_client();
        let bytes = client.rest().bytes("/bytes").await.unwrap();
        assert_eq!(bytes.as_ref(), b"raw bytes");
    }

    // `post_json` round-trips through the echo route. The PUT/DELETE/PATCH
    // shorthands reach their routes with the expected method.
    #[tokio::test]
    async fn test_rest_post_json_and_methods() {
        let client = mock_client();
        #[derive(serde::Serialize)]
        struct Echo {
            message: String,
        }
        let resp: serde_json::Value = client
            .rest()
            .post_json("/echo", &Echo { message: "hi".into() })
            .await
            .unwrap();
        assert_eq!(resp["message"], "hi");
        let put_resp = client.rest().put("/put").unwrap().send().await.unwrap();
        assert_eq!(put_resp.status(), StatusCode::NO_CONTENT);
        let del_resp = client.rest().delete("/delete").unwrap().send().await.unwrap();
        assert_eq!(del_resp.status(), StatusCode::NO_CONTENT);
        let patch_resp = client.rest().patch("/patch").unwrap().send().await.unwrap();
        assert_eq!(patch_resp.status(), StatusCode::NO_CONTENT);
    }

    // Custom headers reach the service. The body set here is ignored by the
    // `/headers` route and only checks that a body on GET doesn't break sending.
    #[tokio::test]
    async fn test_rest_request_builder_header_and_body() {
        let client = mock_client();
        let resp = client
            .rest()
            .request(Method::GET, "/headers")
            .unwrap()
            .header("x-test", "present")
            .unwrap()
            .body("ignored")
            .send()
            .await
            .unwrap();
        assert_eq!(resp.into_body().as_ref(), b"present");
    }

    // A NUL byte is not a legal header value. `header` reports it through
    // `ClientError::InvalidUrl`.
    #[tokio::test]
    async fn test_rest_request_builder_invalid_header() {
        let client = mock_client();
        let err = client
            .rest()
            .get("/")
            .unwrap()
            .header("x-test", "\0")
            .unwrap_err();
        assert!(matches!(err, ClientError::InvalidUrl(_)));
    }

    // The `InvalidUrl` Display message includes its payload.
    #[test]
    fn test_client_error_display() {
        let err = ClientError::InvalidUrl("bad".to_string());
        assert!(format!("{err}").contains("bad"));
    }
}