use http::header::{ACCEPT, CONTENT_TYPE};
use http::HeaderMap;
use reqwest::{Method, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::client::{Auth, SdkClient};
use crate::expiry::parse_ch_expiry;
use crate::message::parse_error_messages_slice;
use crate::negotiated::NegotiatedResponse;
use crate::SdkError;
/// Maximum length, in bytes, of the response-body excerpt that `snippet` embeds in
/// errors; longer (trimmed) bodies are cut down and suffixed with `…`.
const BODY_SNIPPET_MAX: usize = 2048;
/// A single API request being assembled against an [`SdkClient`].
///
/// Created with `SdkRequest::new`, configured through the consuming builder methods
/// (`query_pair`, `query`, `accept_mime`, `vendor_json_body`) and finally consumed by
/// `send_json` or `send_empty`.
pub struct SdkRequest<'a> {
    /// Client supplying the base URL, HTTP transport, auth settings and optional limiter.
    client: &'a SdkClient,
    /// HTTP method to send.
    method: Method,
    /// Request URL (base URL joined with the request path), before `query_pairs` are appended.
    url: url::Url,
    /// Decoded query parameters, appended to `url` in insertion order at send time.
    query_pairs: Vec<(String, String)>,
    /// Value for the `Accept` header, if one was requested.
    accept: Option<String>,
    /// Request body as `(content type, serialized bytes)`, if any.
    body: Option<(String, Vec<u8>)>,
}
impl<'a> SdkRequest<'a> {
    /// Starts a request for `path`, resolved against the client's base URL.
    ///
    /// Leading `/` characters are stripped first: `Url::join` lets an absolute path
    /// replace the base URL's entire path, whereas a relative path is resolved
    /// relative to the base (keeping its prefix when the base ends in `/`).
    ///
    /// # Errors
    ///
    /// Returns an error if the stripped path cannot be joined onto the base URL.
    pub(crate) fn new(
        client: &'a SdkClient,
        method: Method,
        path: impl AsRef<str>,
    ) -> crate::SdkResult<Self> {
        let path = path.as_ref().trim_start_matches('/');
        let url = client.inner.base_url.join(path)?;
        Ok(Self {
            client,
            method,
            url,
            query_pairs: Vec::new(),
            accept: None,
            body: None,
        })
    }

    /// Appends one query parameter.
    ///
    /// The key and value are stored unencoded and form-encoded when the request is
    /// sent. Repeated keys are kept, and parameters are emitted in insertion order.
    pub fn query_pair(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
        self.query_pairs
            .push((key.as_ref().to_string(), value.as_ref().to_string()));
        self
    }

    /// Serializes `q` with `serde_urlencoded` and appends every resulting pair.
    ///
    /// # Errors
    ///
    /// Returns an error if `q` cannot be represented as a flat urlencoded query.
    pub fn query(mut self, q: &impl Serialize) -> crate::SdkResult<Self> {
        let encoded = serde_urlencoded::to_string(q)?;
        // Decode back into plain pairs so they are encoded exactly once, together with
        // any `query_pair` values, when the URL is built in `send_raw`.
        self.query_pairs.extend(
            url::form_urlencoded::parse(encoded.as_bytes())
                .map(|(k, v)| (k.into_owned(), v.into_owned())),
        );
        Ok(self)
    }

    /// Sets the `Accept` header value, replacing any previously set value.
    pub fn accept_mime(mut self, mime: impl Into<String>) -> Self {
        self.accept = Some(mime.into());
        self
    }

    /// Serializes `body` as JSON and attaches it with the given `Content-Type`.
    ///
    /// Replaces any previously attached body.
    ///
    /// # Errors
    ///
    /// Returns an error if `body` fails to serialize to JSON.
    pub fn vendor_json_body(
        mut self,
        content_type: impl Into<String>,
        body: &impl Serialize,
    ) -> crate::SdkResult<Self> {
        let bytes = serde_json::to_vec(body)?;
        self.body = Some((content_type.into(), bytes));
        Ok(self)
    }

    /// Sends the request and deserializes a successful (2xx) JSON response into `T`.
    ///
    /// # Errors
    ///
    /// Transport failures, non-2xx statuses (mapped by `map_error_status`) and JSON
    /// deserialization failures are all returned as errors.
    pub async fn send_json<T: DeserializeOwned>(self) -> crate::SdkResult<NegotiatedResponse<T>> {
        let (status, headers, bytes) = self.send_raw().await?;
        map_success_response(status, &headers, &bytes)
    }

    /// Sends the request and expects a 2xx response with an empty body.
    ///
    /// # Errors
    ///
    /// Non-2xx statuses are mapped by `map_error_status`. A successful response that
    /// nevertheless carries a body yields `SdkError::UnexpectedResponse`.
    pub async fn send_empty(self) -> crate::SdkResult<NegotiatedResponse<()>> {
        let (status, headers, bytes) = self.send_raw().await?;
        if !status.is_success() {
            return Err(map_error_status(status, &headers, &bytes));
        }
        if !bytes.is_empty() {
            return Err(SdkError::UnexpectedResponse {
                status,
                body_snippet: snippet(&bytes),
            });
        }
        Ok(NegotiatedResponse {
            body: (),
            content_type: content_type_header(&headers),
            deprecation: parse_ch_expiry(&headers),
        })
    }

    /// Builds and sends the HTTP request, buffering the full response.
    ///
    /// Waits on the client's limiter (if one is configured), appends the query
    /// parameters, applies auth, `Accept` and body, then returns the status, headers
    /// and body bytes. The status is not checked here; callers map it.
    async fn send_raw(self) -> Result<(StatusCode, HeaderMap, Vec<u8>), SdkError> {
        if let Some(lim) = &self.client.inner.limiter {
            lim.acquire().await;
        }

        // `self` is consumed, so the URL can be moved instead of cloned.
        let mut url = self.url;
        // `Url::query_pairs_mut` inserts a `?` even when nothing is appended, which
        // would send e.g. `/company/123?`; only touch the query when there are pairs.
        if !self.query_pairs.is_empty() {
            let mut serializer = url.query_pairs_mut();
            for (key, value) in &self.query_pairs {
                serializer.append_pair(key, value);
            }
        }

        let rb = self.client.inner.http.request(self.method, url);
        let mut rb = match &self.client.inner.auth {
            // API keys travel as HTTP Basic auth: key as username, empty password.
            Auth::ApiKey { key } => rb.basic_auth(key, Some("")),
            Auth::Bearer { token } => rb.bearer_auth(token),
        };
        if let Some(accept) = self.accept {
            rb = rb.header(ACCEPT, accept);
        }
        if let Some((ct, body)) = self.body {
            rb = rb.header(CONTENT_TYPE, ct).body(body);
        }

        let resp = rb.send().await?;
        let status = resp.status();
        let headers = resp.headers().clone();
        let bytes = resp.bytes().await?.to_vec();
        Ok((status, headers, bytes))
    }
}
/// Returns the `Content-Type` header as an owned `String`.
///
/// Yields `None` when the header is missing or its value cannot be viewed as a
/// string (`HeaderValue::to_str` fails).
fn content_type_header(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get(CONTENT_TYPE)?;
    raw.to_str().ok().map(str::to_owned)
}
/// Produces a short, human-readable excerpt of a response body for error values.
///
/// The bytes are decoded lossily (invalid UTF-8 becomes U+FFFD) and trimmed. If the
/// result exceeds `BODY_SNIPPET_MAX` bytes, it is cut at the last character
/// boundary at or before that limit and suffixed with `…`.
fn snippet(bytes: &[u8]) -> String {
    let text = String::from_utf8_lossy(bytes);
    let text = text.trim();
    if text.len() <= BODY_SNIPPET_MAX {
        return text.to_string();
    }
    // Slicing a `str` at a byte index inside a multi-byte character panics, so step
    // back to the nearest char boundary. Index 0 is always a boundary, so this ends.
    let mut end = BODY_SNIPPET_MAX;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}…", &text[..end])
}
/// Reads a `Retry-After` header expressed as a whole number of seconds.
///
/// Returns `None` when the header is absent, not a valid string, or not an unsigned
/// integer. Note that the HTTP-date form of `Retry-After` is therefore not recognised.
fn parse_retry_after(headers: &HeaderMap) -> Option<std::time::Duration> {
    headers
        .get(http::header::RETRY_AFTER)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok())
        .map(std::time::Duration::from_secs)
}
/// Maps a non-success HTTP response onto the most specific `SdkError` variant.
///
/// - `401` → `SdkError::Unauthorized`
/// - `429` → `SdkError::RateLimited`, carrying `Retry-After` when given in seconds
/// - `406` → `SdkError::NotAcceptable` with a body snippet
/// - `410` → `SdkError::Gone` with a body snippet
/// - any other status → `SdkError::Api` if `parse_error_messages_slice` finds
///   messages in the body, otherwise `SdkError::UnexpectedResponse`
fn map_error_status(status: StatusCode, headers: &HeaderMap, bytes: &[u8]) -> SdkError {
    // Built up front for every status, including those whose variant ignores it.
    let body_snippet = snippet(bytes);
    match status {
        StatusCode::UNAUTHORIZED => SdkError::Unauthorized,
        StatusCode::TOO_MANY_REQUESTS => SdkError::RateLimited {
            retry_after: parse_retry_after(headers),
        },
        StatusCode::NOT_ACCEPTABLE => SdkError::NotAcceptable { body_snippet },
        StatusCode::GONE => SdkError::Gone { body_snippet },
        other => {
            let messages = parse_error_messages_slice(bytes);
            if messages.is_empty() {
                SdkError::UnexpectedResponse {
                    status: other,
                    body_snippet,
                }
            } else {
                SdkError::Api {
                    status: other,
                    messages,
                }
            }
        }
    }
}
/// Turns a buffered response into a typed `NegotiatedResponse<T>`.
///
/// On a 2xx status, the body is decoded as JSON into `T`. The `Content-Type` header
/// and the result of `parse_ch_expiry` are captured alongside it as `content_type`
/// and `deprecation`.
///
/// # Errors
///
/// Non-2xx statuses are mapped through `map_error_status`; JSON decoding failures are
/// propagated with `?`.
fn map_success_response<T: DeserializeOwned>(
    status: StatusCode,
    headers: &HeaderMap,
    bytes: &[u8],
) -> crate::SdkResult<NegotiatedResponse<T>> {
    if status.is_success() {
        let body: T = serde_json::from_slice(bytes)?;
        let content_type = content_type_header(headers);
        let deprecation = parse_ch_expiry(headers);
        Ok(NegotiatedResponse {
            body,
            content_type,
            deprecation,
        })
    } else {
        Err(map_error_status(status, headers, bytes))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Query pairs are appended after the joined path, in insertion order.
    #[test]
    fn query_pairs_merge_into_url() {
        let base = url::Url::parse("https://example.com/base/").unwrap();
        let mut url = base.join("search/companies").unwrap();
        url.query_pairs_mut()
            .append_pair("q", "test")
            .append_pair("items_per_page", "10");
        assert_eq!(
            url.as_str(),
            "https://example.com/base/search/companies?q=test&items_per_page=10"
        );
    }

    /// Values are form-encoded: spaces become `+` and reserved characters are escaped.
    #[test]
    fn query_values_are_form_encoded() {
        let mut url = url::Url::parse("https://example.com/search").unwrap();
        url.query_pairs_mut().append_pair("q", "acme & sons");
        assert_eq!(url.query(), Some("q=acme+%26+sons"));
    }

    /// Short bodies are trimmed; long ASCII bodies are cut at the limit with an ellipsis.
    #[test]
    fn snippet_trims_and_truncates() {
        assert_eq!(snippet(b"  not found \n"), "not found");
        let long = "x".repeat(BODY_SNIPPET_MAX + 10);
        assert_eq!(
            snippet(long.as_bytes()),
            format!("{}…", "x".repeat(BODY_SNIPPET_MAX))
        );
    }
}