use std::marker::PhantomData;
use futures_core::Stream;
use serde::de::DeserializeOwned;
use crate::client::{ApiClient, Method};
use crate::error::{ApiError, DefinedErrorBody};
use crate::sse::SseEvent;
/// Converts the status and body of a non-success response into an [`ApiError`].
///
/// The body is first tried as a JSON [`DefinedErrorBody`]. If it parses, is
/// flagged `defined`, and carries a non-empty `code`, the result is
/// [`ApiError::Defined`] with the parsed code and message. In every other
/// case (not JSON, not flagged, or empty code) the raw body text is returned
/// unchanged inside [`ApiError::Api`].
fn parse_error_response(status: u16, body: String) -> ApiError {
    // Keep the parsed body only when it is a usable structured error.
    let structured = serde_json::from_str::<DefinedErrorBody>(&body)
        .ok()
        .filter(|candidate| candidate.defined && !candidate.code.is_empty());

    match structured {
        Some(known) => ApiError::Defined {
            status,
            code: known.code,
            message: known.message,
        },
        None => ApiError::Api {
            status,
            message: body,
        },
    }
}
/// A prepared API request that has not been sent yet.
///
/// The type parameter `T` is the type `fetch` deserializes a successful
/// response body into; it is not stored, only tracked through `_marker`.
pub struct ApiRequest<T> {
/// HTTP method handed to the client.
pub method: Method,
/// Request path handed to the client.
pub path: String,
/// Raw query string, handed to the client as-is; `None` after `new`, set by `query_raw`.
pub query: Option<String>,
/// JSON text of the request body; `None` after `new`, set by `body_json` / `try_body_json`.
pub body: Option<String>,
/// Zero-sized marker that ties the request to its response type `T`.
pub _marker: PhantomData<T>,
}
impl<T> ApiRequest<T> {
    /// Creates a request for `method` and `path` with no query string and no body.
    pub fn new(method: Method, path: String) -> Self {
        Self {
            _marker: PhantomData,
            query: None,
            body: None,
            method,
            path,
        }
    }

    /// Sets the query string to `qs` verbatim, replacing any previous value.
    ///
    /// No encoding is applied here; the caller supplies an already-encoded string.
    pub fn query_raw(self, qs: impl Into<String>) -> Self {
        Self {
            query: Some(qs.into()),
            ..self
        }
    }

    /// Serializes `body` to JSON and stores it as the request body.
    ///
    /// # Panics
    ///
    /// Panics with "request body must be serializable" if serialization fails.
    /// Use [`ApiRequest::try_body_json`] to receive the failure as an error instead.
    pub fn body_json(self, body: &impl serde::Serialize) -> Self {
        let encoded = serde_json::to_string(body).expect("request body must be serializable");
        Self {
            body: Some(encoded),
            ..self
        }
    }

    /// Serializes `body` to JSON and stores it as the request body.
    ///
    /// # Errors
    ///
    /// Returns the serialization failure converted into [`ApiError`]; the
    /// request is consumed in that case.
    pub fn try_body_json(self, body: &impl serde::Serialize) -> Result<Self, ApiError> {
        let encoded = serde_json::to_string(body)?;
        Ok(Self {
            body: Some(encoded),
            ..self
        })
    }
}
impl<T: DeserializeOwned> ApiRequest<T> {
    /// Sends the request through `client` and deserializes the JSON response into `T`.
    ///
    /// An empty response body is decoded as the JSON value `null`, so types
    /// such as `Option<_>` or `()` accept bodiless success responses.
    ///
    /// # Errors
    ///
    /// * Any error returned by `client.request` is propagated.
    /// * A non-success status yields the error built by `parse_error_response`
    ///   from the status code and body; if the error body cannot be read it
    ///   is treated as empty.
    /// * A failure to read a success body is propagated.
    /// * A body that does not deserialize into `T` yields the converted
    ///   `serde_json` error.
    pub async fn fetch(self, client: &(impl ApiClient + ?Sized)) -> Result<T, ApiError> {
        let response = client
            .request(self.method, &self.path, self.query.as_deref(), self.body)
            .await?;
        let status = response.status();

        if status.is_success() {
            let payload = response.text().await?;
            // An empty body stands in for JSON `null`.
            let json: &str = if payload.is_empty() {
                "null"
            } else {
                payload.as_str()
            };
            serde_json::from_str(json).map_err(ApiError::from)
        } else {
            // Best effort: an unreadable error body is reported as empty text.
            let detail = response.text().await.unwrap_or_default();
            Err(parse_error_response(status.as_u16(), detail))
        }
    }
}
impl ApiRequest<()> {
    /// Sends the request through `client` and checks only the response status.
    ///
    /// On a success status the response body is not read at all.
    ///
    /// # Errors
    ///
    /// * Any error returned by `client.request` is propagated.
    /// * A non-success status yields the error built by `parse_error_response`
    ///   from the status code and body; if the error body cannot be read it
    ///   is treated as empty.
    pub async fn fetch_empty(self, client: &(impl ApiClient + ?Sized)) -> Result<(), ApiError> {
        let response = client
            .request(self.method, &self.path, self.query.as_deref(), self.body)
            .await?;
        let status = response.status();

        if status.is_success() {
            Ok(())
        } else {
            // Best effort: an unreadable error body is reported as empty text.
            let detail = response.text().await.unwrap_or_default();
            Err(parse_error_response(status.as_u16(), detail))
        }
    }
}
impl<T> ApiRequest<T> {
    /// Opens the request as a server-sent-event stream through `client`.
    ///
    /// Only the method, path and query string are forwarded: `request_stream`
    /// takes no body argument, so any body set on the request is not sent.
    /// The type parameter `T` plays no role here.
    ///
    /// # Errors
    ///
    /// Returns whatever error `client.request_stream` reports.
    pub async fn fetch_stream(
        self,
        client: &(impl ApiClient + ?Sized),
    ) -> Result<impl Stream<Item = Result<SseEvent, ApiError>>, ApiError> {
        client
            .request_stream(self.method, &self.path, self.query.as_deref())
            .await
    }
}
/// Percent-encodes `input` for use as a query-string key or value.
///
/// ASCII letters, ASCII digits and `-`, `_`, `.`, `~` (the RFC 3986
/// "unreserved" set) are copied through unchanged. Every other byte of the
/// UTF-8 encoding is written as `%XX` with two uppercase hexadecimal digits,
/// so a space becomes `%20` and multi-byte characters expand to one escape
/// per byte.
fn percent_encode(input: &str) -> String {
    // Digit table for the escapes; indexing it avoids building a temporary
    // `String` with `format!` for every escaped byte.
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";

    // The output is at least as long as the input; escapes may grow it.
    let mut out = String::with_capacity(input.len());
    for byte in input.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                out.push(char::from(byte));
            }
            _ => {
                out.push('%');
                out.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
                out.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
            }
        }
    }
    out
}
/// Builds a query string of the form `k1=v1&k2=v2` from `pairs`.
///
/// Each key and each value (rendered through `ToString`) is percent-encoded
/// with `percent_encode`, so `&` and `=` inside keys or values cannot break
/// the pair structure. Pairs keep their input order; an empty slice yields
/// an empty string. No leading `?` is added.
pub fn build_query_string(pairs: &[(&str, &dyn ToString)]) -> String {
    let mut query = String::new();
    for (index, (key, value)) in pairs.iter().enumerate() {
        // Separator goes between pairs, never before the first one.
        if index > 0 {
            query.push('&');
        }
        query.push_str(&percent_encode(key));
        query.push('=');
        query.push_str(&percent_encode(&value.to_string()));
    }
    query
}
#[cfg(test)]
#[allow(clippy::manual_async_fn)]
mod tests {
use super::*;
use crate::sse::SseStream;
use tokio::io::AsyncWriteExt;
use tokio::time::{Duration, sleep};
/// Produces a real `reqwest::Response` with the given `status` and `body`.
///
/// A throwaway mockito server is started, a `GET /mock` mock with a JSON
/// content type is registered on it, and one request is issued against it.
/// Panics if that request fails.
async fn mock_response(status: u16, body: &str) -> reqwest::Response {
    let mut server = mockito::Server::new_async().await;

    let definition = server
        .mock("GET", "/mock")
        .with_status(status as usize)
        .with_header("content-type", "application/json")
        .with_body(body);
    // The handle stays bound for the rest of the function.
    let _mock = definition.create_async().await;

    let url = format!("{}/mock", server.url());
    reqwest::get(&url).await.unwrap()
}
/// Manufactures a `reqwest::Error` without sending anything.
///
/// The request is only built, never sent. It relies on the header name
/// `bad\0header` being rejected so that `build()` returns `Err`; the test
/// helper panics if the build unexpectedly succeeds.
fn make_reqwest_error() -> reqwest::Error {
    let client = reqwest::Client::new();
    let invalid = client
        .get("http://localhost:1/x")
        .header("bad\0header", "v");
    invalid.build().unwrap_err()
}
/// Returns a `reqwest::Response` whose body is shorter than its headers claim.
///
/// A one-shot local TCP server answers the first connection with a
/// `200 OK` that declares `Content-Length: 40` but writes only the two
/// bytes `{}` before shutting the socket down.
///
/// NOTE(review): despite the name, the response is not chunked; it is a
/// truncated fixed-length body.
async fn malformed_chunked_response() -> reqwest::Response {
// Port 0 lets the OS choose a free port; the actual address is read back below.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Serve exactly one connection on a background task.
let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
// Headers promise 40 body bytes; only `{}` follows.
socket
.write_all(
b"HTTP/1.1 200 OK\r\n\
Content-Type: application/json\r\n\
Content-Length: 40\r\n\
Connection: close\r\n\
\r\n\
{}",
)
.await
.unwrap();
// Hold the connection open briefly before closing it
// (presumably so the client has the headers before the stream ends).
sleep(Duration::from_millis(250)).await;
socket.shutdown().await.unwrap();
});
let resp = reqwest::get(format!("http://{addr}")).await.unwrap();
// Wait for the server task to finish; its join result is deliberately ignored.
let _ = server.await;
resp
}
/// Test client whose `request` answers with a fixed status and body.
struct MockClient {
/// HTTP status code of the canned response.
status: u16,
/// Body text of the canned response.
body: String,
}
/// Test client whose `request` and `request_stream` both fail with `ApiError::Http`.
struct FailingClient;
/// Test client whose `request` returns the response from `malformed_chunked_response`.
struct MalformedBodyClient;
impl ApiClient for MockClient {
fn request(
&self,
_: Method,
_: &str,
_: Option<&str>,
_: Option<String>,
) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
let status = self.status;
let body = self.body.clone();
async move { Ok(mock_response(status, &body).await) }
}
fn request_stream(
&self,
_: Method,
_: &str,
_: Option<&str>,
) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
async move {
let chunks: Vec<Result<bytes::Bytes, reqwest::Error>> =
vec![Ok(bytes::Bytes::from(&b"data: hi\n\n"[..]))];
Ok(SseStream::new(Box::pin(futures_util::stream::iter(chunks))))
}
}
}
impl ApiClient for FailingClient {
    /// Always fails with `ApiError::Http`, regardless of the arguments.
    fn request(
        &self,
        _: Method,
        _: &str,
        _: Option<&str>,
        _: Option<String>,
    ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
        async {
            let cause = make_reqwest_error();
            Err(ApiError::Http(cause))
        }
    }

    /// Always fails with `ApiError::Http`, regardless of the arguments.
    fn request_stream(
        &self,
        _: Method,
        _: &str,
        _: Option<&str>,
    ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
        async {
            let cause = make_reqwest_error();
            Err(ApiError::Http(cause))
        }
    }
}
impl ApiClient for MalformedBodyClient {
    /// Ignores every argument and returns the response produced by
    /// `malformed_chunked_response`, whose body is shorter than declared.
    fn request(
        &self,
        _: Method,
        _: &str,
        _: Option<&str>,
        _: Option<String>,
    ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
        async {
            let response = malformed_chunked_response().await;
            Ok(response)
        }
    }

    /// Always fails with `ApiError::Http`, regardless of the arguments.
    fn request_stream(
        &self,
        _: Method,
        _: &str,
        _: Option<&str>,
    ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
        async {
            let cause = make_reqwest_error();
            Err(ApiError::Http(cause))
        }
    }
}
#[test]
fn api_request_builder() {
    // A freshly created request carries only the method and the path.
    let bare = ApiRequest::<String>::new(Method::GET, "/test".into());
    assert_eq!(bare.method, Method::GET);
    assert_eq!(bare.path, "/test");
    assert_eq!(bare.query, None);
    assert_eq!(bare.body, None);

    // The builder methods fill in the raw query string and the JSON body.
    let payload = serde_json::json!({"x": 1});
    let built = ApiRequest::<String>::new(Method::POST, "/x".into())
        .query_raw("q=1")
        .body_json(&payload);
    assert_eq!(built.query.as_deref(), Some("q=1"));
    assert_eq!(built.body.as_deref(), Some(r#"{"x":1}"#));
}
#[test]
fn body_serialization() {
    // A value whose `Serialize` implementation always fails.
    #[derive(Debug)]
    struct Bad;
    impl serde::Serialize for Bad {
        fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
            Err(serde::ser::Error::custom("fail"))
        }
    }

    // A serializable value is accepted and stored as the body.
    let accepted = ApiRequest::<String>::new(Method::POST, "/t".into())
        .try_body_json(&serde_json::json!({"x": 1}))
        .unwrap();
    assert!(accepted.body.is_some());

    // A failing serializer surfaces as an error instead of a panic.
    let rejected = ApiRequest::<String>::new(Method::POST, "/t".into()).try_body_json(&Bad);
    assert!(rejected.is_err());
}
#[test]
#[should_panic(expected = "request body must be serializable")]
fn body_json_panics_on_bad_input() {
    // A value whose `Serialize` implementation always fails.
    #[derive(Debug)]
    struct Bad;
    impl serde::Serialize for Bad {
        fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
            Err(serde::ser::Error::custom("fail"))
        }
    }

    // The infallible builder is expected to panic on this input.
    let request = ApiRequest::<String>::new(Method::POST, "/t".into());
    let _ = request.body_json(&Bad);
}
#[test]
fn query_string_and_percent_encode() {
    // Unreserved characters pass through; reserved ones are escaped.
    for (raw, encoded) in [("abc-_.~123", "abc-_.~123"), ("&=", "%26%3D")] {
        assert_eq!(percent_encode(raw), encoded);
    }

    // No pairs, one pair, and several pairs of mixed value types.
    assert_eq!(build_query_string(&[]), "");
    assert_eq!(build_query_string(&[("limit", &10)]), "limit=10");
    assert_eq!(
        build_query_string(&[("a", &"hello"), ("b", &42)]),
        "a=hello&b=42"
    );

    // Values containing separators are escaped so the structure survives.
    let escaped = build_query_string(&[("q", &"hello world"), ("x", &"a&b=c")]);
    assert_eq!(escaped, "q=hello%20world&x=a%26b%3Dc");
}
#[tokio::test]
async fn fetch_success_and_edge_cases() {
    // A JSON string body deserializes into `String`.
    let json_string = MockClient {
        status: 200,
        body: r#""hello""#.into(),
    };
    let greeting = ApiRequest::<String>::new(Method::GET, "/t".into())
        .fetch(&json_string)
        .await
        .unwrap();
    assert_eq!(greeting, "hello");

    // An empty body is decoded as JSON `null`, which is `None` for an `Option`.
    let empty_body = MockClient {
        status: 200,
        body: String::new(),
    };
    let nothing: Option<String> = ApiRequest::new(Method::GET, "/t".into())
        .fetch(&empty_body)
        .await
        .unwrap();
    assert_eq!(nothing, None);

    // A body that is not valid JSON is reported as a serialization error.
    let garbage = MockClient {
        status: 200,
        body: "not-json".into(),
    };
    let failure = ApiRequest::<i32>::new(Method::GET, "/t".into())
        .fetch(&garbage)
        .await
        .unwrap_err();
    assert!(failure.to_string().starts_with("serialization error:"));
}
#[tokio::test]
async fn fetch_error_responses() {
let client = MockClient {
status: 403,
body: "forbidden".into(),
};
let err = ApiRequest::<String>::new(Method::GET, "/t".into())
.fetch(&client)
.await
.unwrap_err();
assert!(matches!(err, ApiError::Api { status: 403, .. }));
let client = MockClient {
status: 404,
body: r#"{"defined":true,"code":"TEAM_NOT_FOUND","message":"Team not found"}"#.into(),
};
let err = ApiRequest::<String>::new(Method::GET, "/t".into())
.fetch(&client)
.await
.unwrap_err();
assert!(err.is_code("TEAM_NOT_FOUND"));
assert_eq!(err.status(), Some(404));
let client = MockClient {
status: 400,
body: r#"{"defined":false,"code":"NOPE","message":"nope"}"#.into(),
};
let err = ApiRequest::<String>::new(Method::GET, "/t".into())
.fetch(&client)
.await
.unwrap_err();
assert!(matches!(err, ApiError::Api { status: 400, .. }));
assert_eq!(err.code(), None);
}
#[tokio::test]
async fn fetch_empty_success_and_errors() {
    // A success status without a body is accepted.
    let no_content = MockClient {
        status: 204,
        body: String::new(),
    };
    let outcome = ApiRequest::<()>::new(Method::DELETE, "/t".into())
        .fetch_empty(&no_content)
        .await;
    assert!(outcome.is_ok());

    // A plain-text failure maps to the generic `Api` variant.
    let server_error = MockClient {
        status: 500,
        body: "oops".into(),
    };
    let generic = ApiRequest::<()>::new(Method::DELETE, "/t".into())
        .fetch_empty(&server_error)
        .await
        .unwrap_err();
    assert!(matches!(generic, ApiError::Api { status: 500, .. }));

    // A structured failure exposes its error code.
    let forbidden = MockClient {
        status: 403,
        body: r#"{"defined":true,"code":"FORBIDDEN","message":"no access"}"#.into(),
    };
    let defined = ApiRequest::<()>::new(Method::DELETE, "/t".into())
        .fetch_empty(&forbidden)
        .await
        .unwrap_err();
    assert!(defined.is_code("FORBIDDEN"));

    // A client-level failure is passed through unchanged.
    let transport = ApiRequest::<()>::new(Method::DELETE, "/t".into())
        .fetch_empty(&FailingClient)
        .await
        .unwrap_err();
    assert!(transport.to_string().starts_with("HTTP error:"));
}
#[tokio::test]
async fn fetch_stream_success_and_errors() {
    use futures_util::StreamExt;

    // `MockClient::request_stream` ignores status and body and emits one `hi` event.
    let streaming = MockClient {
        status: 200,
        body: String::new(),
    };
    let mut events = ApiRequest::<()>::new(Method::GET, "/sse".into())
        .fetch_stream(&streaming)
        .await
        .unwrap();
    let first = events.next().await.unwrap().unwrap();
    assert_eq!(first.data, "hi");

    // A client that cannot open the stream surfaces its error.
    let refused = ApiRequest::<()>::new(Method::GET, "/sse".into())
        .fetch_stream(&FailingClient)
        .await;
    assert!(refused.is_err());
}
#[tokio::test]
async fn fetch_propagates_body_read_error() {
    // The client returns a 200 response whose body ends before the declared
    // length, so reading the body fails and that failure must reach the caller.
    let request = ApiRequest::<String>::new(Method::GET, "/t".into());
    let failure = request.fetch(&MalformedBodyClient).await.unwrap_err();
    let rendered = failure.to_string();
    assert!(rendered.starts_with("HTTP error:"));
}
}