use std::str::FromStr;
use serde::{de::DeserializeOwned, Deserialize, Deserializer};
use tracing::{error, warn};
use ureq::{
http::{
header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
HeaderName, HeaderValue, Response, StatusCode
},
tls::TlsConfig,
Agent, Body, RequestBuilder, ResponseExt
};
use crate::errors::{Error, Result};
/// Turns an HTTP response into a crate-level [`Result`].
pub(crate) trait ResponseToOption: Sized {
    /// Decodes the response body as JSON into `T`.
    ///
    /// The implementation for `Response<Body>` returns `Ok(None)` for HTTP 404.
    fn to_option<T: DeserializeOwned>(&mut self) -> Result<Option<T>>;

    /// Hands the response back untouched on a success status, otherwise turns it into an error.
    fn check_error(self) -> Result<Self>;
}
impl ResponseToOption for Response<Body> {
fn to_option<T>(&mut self) -> Result<Option<T>>
where
T: DeserializeOwned
{
match self.status() {
StatusCode::OK => {
let body = self.body_mut().read_to_string()?;
let obj: T = serde_json::from_str(&body)?;
Ok(Some(obj))
}
StatusCode::NOT_FOUND => {
warn!("Record doesn't exist: {}", self.get_uri());
Ok(None)
}
_ => {
let body = self.body_mut().read_to_string()?;
Err(Error::ApiError(format!("Api Error: {} -> {body}", self.status())))
}
}
}
fn check_error(mut self) -> Result<Self> {
let code = self.status();
if code.is_success() {
return Ok(self)
}
let err = self.body_mut()
.read_to_string()?;
error!("REST op failed: {code} {err:?}");
Err(Error::HttpError(format!("REST op failed: {code} {err:?}")))
}
}
/// Fluent header helpers for ureq request builders.
pub(crate) trait WithHeaders<T> {
    /// Sets the `Authorization` header to `auth`, passed through as-is.
    fn with_auth(self, auth: String) -> RequestBuilder<T>;

    /// Sets both `Accept` and `Content-Type` to `application/json`.
    fn with_json_headers(self) -> RequestBuilder<T>;

    /// Adds each `(name, value)` pair as a header.
    ///
    /// Fails if a name or value is not a valid HTTP header.
    fn with_headers(self, headers: Vec<(&str, String)>) -> Result<RequestBuilder<T>>;
}
impl<Any> WithHeaders<Any> for RequestBuilder<Any> {
fn with_headers(mut self, headers: Vec<(&str, String)>) -> Result<Self> {
let reqh = self.headers_mut()
.ok_or(Error::HttpError("Failed to get headers from ureq".to_string()))?;
for (k, v) in headers {
reqh.insert(HeaderName::from_str(k)?, HeaderValue::from_str(&v)?);
}
Ok(self)
}
fn with_auth(self, auth: String) -> Self {
self.header(AUTHORIZATION, auth)
}
fn with_json_headers(self) -> Self {
self.header(ACCEPT, "application/json")
.header(CONTENT_TYPE, "application/json")
}
}
/// Creates a ureq [`Agent`] that uses the native-tls provider.
///
/// `http_status_as_error(false)` stops ureq from turning 4xx/5xx statuses into errors, so
/// those responses reach the caller and can be inspected directly.
pub(crate) fn client() -> Agent {
    let tls = TlsConfig::builder()
        .provider(ureq::tls::TlsProvider::NativeTls)
        .build();

    let config = Agent::config_builder()
        .http_status_as_error(false)
        .tls_config(tls)
        .build();

    config.new_agent()
}
/// Deserializes a string value, then parses it into `T` via `T`'s [`FromStr`] impl.
///
/// The signature matches serde's `deserialize_with` contract. If parsing fails, the `Display`
/// text of `T::Err` becomes a custom error of the deserializer.
pub(crate) fn de_str<'de, T, D>(destr: D) -> std::result::Result<T, D::Error>
where
    T: FromStr,
    T::Err: std::fmt::Display,
    D: Deserializer<'de>,
{
    String::deserialize(destr)?
        .parse::<T>()
        .map_err(serde::de::Error::custom)
}