use crate::config::Config;
use crate::error::{Error, Result};
use reqwest::header::{HeaderMap, HeaderValue};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
/// Maximum number of bytes of a response body that `truncate` keeps when the
/// body is embedded in a decode error.
const BODY_SNIPPET_LIMIT: usize = 2000;
/// Name of the header that carries the client's subscription key.
const SUBSCRIPTION_KEY_HEADER: &str = "Ocp-Apim-Subscription-Key";
/// Name of the header that carries the configured API key.
const API_KEY_HEADER: &str = "x-api-key";
/// HTTP client that sends JSON requests with the configured credential headers.
///
/// Cloning is cheap: the configuration and subscription key are reference
/// counted, and clones of [`reqwest::Client`] share the same underlying client.
#[derive(Clone)]
pub struct Client {
    config: Arc<Config>,
    subscription_key: Arc<str>,
    http: reqwest::Client,
}

// Hand-written instead of derived: a derived `Debug` would print the
// subscription key and the whole `Config` (including its API/access keys),
// leaking credentials into any log that formats the client.
impl std::fmt::Debug for Client {
    /// Formats the client without exposing credentials: only the base URL and
    /// timeout are shown, and the subscription key is redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.config.base_url)
            .field("timeout", &self.config.timeout)
            .field("subscription_key", &"<redacted>")
            .finish_non_exhaustive()
    }
}
impl Client {
    /// Creates a client from `config`.
    ///
    /// The subscription key, the API key and, when set, the access key are
    /// checked here. Malformed credentials are therefore reported at
    /// construction time instead of on the first request.
    ///
    /// The underlying HTTP client sends `Content-Type: application/json` and
    /// `Cache-Control: no-cache` by default. It is built with `config.timeout`
    /// as its timeout.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Configuration`] in two cases:
    /// - a credential contains characters that are not allowed in an HTTP
    ///   header value;
    /// - the HTTP client cannot be built.
    pub fn new(config: Config) -> Result<Self> {
        validate_header_value("subscription_key", &config.subscription_key)?;
        validate_header_value("api_key", &config.api_key)?;
        if let Some(access) = &config.access_key {
            validate_header_value("access_key", access)?;
        }
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );
        headers.insert(
            reqwest::header::CACHE_CONTROL,
            HeaderValue::from_static("no-cache"),
        );
        let http = reqwest::Client::builder()
            .default_headers(headers)
            .timeout(config.timeout)
            .build()
            .map_err(|e| Error::Configuration(format!("failed to build HTTP client: {e}")))?;
        // Kept separately from `config` so `with_subscription_key` can swap it
        // while sharing the same `Arc<Config>`.
        let subscription_key = Arc::from(config.subscription_key.as_str());
        Ok(Self {
            config: Arc::new(config),
            subscription_key,
            http,
        })
    }

    /// Returns a client that sends `subscription_key` instead of the current key.
    ///
    /// The configuration and HTTP client are shared with `self`; only the
    /// subscription key differs.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Configuration`] if `subscription_key` contains
    /// characters that are not allowed in an HTTP header value.
    pub fn with_subscription_key(&self, subscription_key: impl AsRef<str>) -> Result<Self> {
        let key = subscription_key.as_ref();
        validate_header_value("subscription_key", key)?;
        Ok(Self {
            config: self.config.clone(),
            subscription_key: Arc::from(key),
            http: self.http.clone(),
        })
    }

    /// Returns the configuration this client was created with.
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Returns the subscription key sent with each request.
    ///
    /// This may differ from the configured key if the client came from
    /// [`Client::with_subscription_key`].
    pub fn subscription_key(&self) -> &str {
        &self.subscription_key
    }

    /// Returns the underlying HTTP client.
    ///
    /// It carries the default headers and timeout set up in [`Client::new`].
    pub fn http_client(&self) -> &reqwest::Client {
        &self.http
    }

    /// Joins the configured base URL and `path` with exactly one `/` between
    /// them. Trailing slashes on the base and leading slashes on `path` are
    /// dropped.
    fn url(&self, path: &str) -> String {
        format!(
            "{}/{}",
            self.config.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        )
    }

    /// Credential headers attached to every request: the subscription key and
    /// the API key.
    fn auth_headers(&self) -> Vec<(&'static str, String)> {
        vec![
            (SUBSCRIPTION_KEY_HEADER, self.subscription_key.to_string()),
            (API_KEY_HEADER, self.config.api_key.clone()),
        ]
    }

    /// Sends a `GET` request and decodes the JSON response body as `T`.
    ///
    /// `path` is resolved against the configured base URL. `query` is
    /// URL-encoded onto the request when non-empty. The credential headers are
    /// attached, followed by `extra_headers`.
    ///
    /// # Errors
    ///
    /// - The error converted from reqwest if the request cannot be sent or its
    ///   body cannot be read.
    /// - [`Error::Http`] for a non-success status.
    /// - [`Error::Decode`] if the body is not valid JSON for `T`.
    pub async fn get_json<T>(
        &self,
        path: &str,
        query: &[(&str, &str)],
        extra_headers: &[(&'static str, String)],
    ) -> Result<T>
    where
        T: DeserializeOwned,
    {
        let mut req = self.http.get(self.url(path));
        if !query.is_empty() {
            req = req.query(query);
        }
        req = self.apply_all_headers(req, extra_headers);
        self.send(req).await
    }

    /// Sends a `POST` request with `body` serialized as JSON, and decodes the
    /// JSON response body as `T`.
    ///
    /// `path` is resolved against the configured base URL. The credential
    /// headers are attached, followed by `extra_headers`.
    ///
    /// # Errors
    ///
    /// Same as [`Client::get_json`].
    pub async fn post_json<B, T>(
        &self,
        path: &str,
        body: &B,
        extra_headers: &[(&'static str, String)],
    ) -> Result<T>
    where
        B: Serialize + ?Sized,
        T: DeserializeOwned,
    {
        let req = self.apply_all_headers(self.http.post(self.url(path)).json(body), extra_headers);
        self.send(req).await
    }

    /// Adds the credential headers, then `extra_headers`, to `req`.
    ///
    /// Credential values are marked sensitive. `HeaderValue`'s `Debug` output
    /// then masks them, and HTTP/2 encoders may choose not to compress them.
    /// This mirrors what reqwest does for its own auth helpers.
    fn apply_all_headers(
        &self,
        mut req: reqwest::RequestBuilder,
        extra_headers: &[(&'static str, String)],
    ) -> reqwest::RequestBuilder {
        for (name, value) in self.auth_headers() {
            req = match HeaderValue::from_str(&value) {
                Ok(mut header_value) => {
                    header_value.set_sensitive(true);
                    req.header(name, header_value)
                }
                // Keys are validated when the client is built, so this is not
                // expected. If it does happen, keep the original behaviour and
                // let reqwest report the invalid value when the request is sent.
                Err(_) => req.header(name, value),
            };
        }
        for (name, value) in extra_headers {
            req = req.header(*name, value);
        }
        req
    }

    /// Sends `req`, reads the whole response body as text and decodes it as
    /// JSON into `T`.
    ///
    /// On a non-success status, the complete body is returned in
    /// [`Error::Http`]. On a decode failure, the body is shortened with
    /// `truncate`.
    async fn send<T>(&self, req: reqwest::RequestBuilder) -> Result<T>
    where
        T: DeserializeOwned,
    {
        let response = req.send().await?;
        let status = response.status();
        let body = response.text().await?;
        if !status.is_success() {
            return Err(Error::Http { status, body });
        }
        serde_json::from_str::<T>(&body).map_err(|e| Error::Decode {
            message: e.to_string(),
            body: truncate(&body),
        })
    }

    /// Returns the `access` header built from the configured access key.
    ///
    /// When no access key is configured, returns an empty list. If `required`
    /// is true in that case, an error is returned instead.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Configuration`] when `required` is true and no access
    /// key is configured.
    pub(crate) fn access_headers(&self, required: bool) -> Result<Vec<(&'static str, String)>> {
        match &self.config.access_key {
            Some(key) => Ok(vec![("access", key.clone())]),
            None if required => Err(Error::Configuration(
                "this endpoint requires the bills/airtime `access` key; set it via \
                 Config::with_access_key(...)"
                    .into(),
            )),
            None => Ok(Vec::new()),
        }
    }
}
fn validate_header_value(name: &str, value: &str) -> Result<()> {
HeaderValue::from_str(value)
.map(|_| ())
.map_err(|_| Error::Configuration(format!("{name} contains invalid header characters")))
}
/// Shortens a response body to at most [`BODY_SNIPPET_LIMIT`] bytes for error
/// reporting.
///
/// A body within the limit is returned unchanged. A longer body is cut at the
/// last UTF-8 character boundary at or below the limit and suffixed with a
/// marker stating its original length in bytes.
fn truncate(s: &str) -> String {
    if s.len() <= BODY_SNIPPET_LIMIT {
        return s.to_string();
    }
    // Slicing at a fixed byte index panics if it lands inside a multi-byte
    // character (common in non-ASCII error bodies). Back off to the nearest
    // preceding boundary; index 0 is always a boundary, so the search succeeds.
    let end = (0..=BODY_SNIPPET_LIMIT)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    format!("{}… [truncated, {} bytes total]", &s[..end], s.len())
}