use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "stream")]
use async_trait::async_trait;
#[cfg(feature = "stream")]
use futures::Stream;
use http::header::{HeaderMap, HeaderName, HeaderValue};
use http::Error as HttpError;
use log::trace;
use reqwest::{Body, Client, Method, Request, RequestBuilder as HttpRequestBuilder, Response, Url};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use static_assertions::assert_eq_size;
#[cfg(feature = "stream")]
use super::stream::{paginated, FetchNext, PaginatedResource};
use super::url as url_utils;
use super::{AuthType, EndpointFilters, Error};
/// A ready-made `None` of type `Option<&'static str>`.
///
/// NOTE(review): presumably meant to be passed where an optional path argument is
/// expected and no extra path should be used — confirm against callers.
pub const NO_PATH: Option<&'static str> = None;
/// An HTTP client paired with an authentication method.
///
/// The authentication method is held behind an `Arc`, so clones of this client
/// share the same authentication instance.
#[derive(Debug, Clone)]
pub struct AuthenticatedClient {
/// The underlying `reqwest` client that executes requests.
client: Client,
/// The authentication method used to authenticate requests sent through this client.
auth: Arc<dyn AuthType>,
}
// Compile-time assertion that `Option<AuthenticatedClient>` has exactly the same
// size as `AuthenticatedClient` itself.
assert_eq_size!(AuthenticatedClient, Option<AuthenticatedClient>);
impl AuthenticatedClient {
pub async fn new<Auth: AuthType + 'static>(
client: Client,
auth_type: Auth,
) -> Result<AuthenticatedClient, Error> {
auth_type.refresh(&client).await?;
Ok(AuthenticatedClient::new_internal(
client,
Arc::new(auth_type),
))
}
#[inline]
pub(crate) fn new_internal(client: Client, auth: Arc<dyn AuthType>) -> AuthenticatedClient {
AuthenticatedClient { client, auth }
}
#[inline]
pub fn auth_type(&self) -> &dyn AuthType {
self.auth.as_ref()
}
#[inline]
async fn authenticate(&self, request: HttpRequestBuilder) -> Result<Request, Error> {
self.auth
.authenticate(&self.client, request)
.await?
.build()
.map_err(Error::from)
}
#[inline]
pub async fn get_endpoint(
&self,
service_type: &str,
filters: &EndpointFilters,
) -> Result<Url, Error> {
self.auth
.get_endpoint(&self.client, service_type, filters)
.await
}
#[inline]
pub fn inner(&self) -> &Client {
&self.client
}
#[inline]
pub async fn refresh(&mut self) -> Result<(), Error> {
self.auth.refresh(&self.client).await
}
#[inline]
pub fn set_auth_type<Auth: AuthType + 'static>(&mut self, auth_type: Auth) {
self.auth = Arc::new(auth_type);
}
#[inline]
pub fn set_inner(&mut self, client: Client) {
self.client = client;
}
#[inline]
pub fn request(&self, method: Method, url: Url) -> RequestBuilder {
RequestBuilder {
inner: self.client.request(method, url),
client: self.clone(),
}
}
#[cfg(test)]
pub(crate) async fn new_noauth(endpoint: &str) -> AuthenticatedClient {
use crate::NoAuth;
AuthenticatedClient::new(Client::new(), NoAuth::new(endpoint).unwrap())
.await
.unwrap()
}
}
impl From<AuthenticatedClient> for Client {
fn from(value: AuthenticatedClient) -> Client {
value.client
}
}
/// A request under construction, bound to the `AuthenticatedClient` that will
/// authenticate and send it.
///
/// Nothing is sent until one of the sending methods (such as `send` or `fetch`)
/// is awaited.
#[derive(Debug)]
#[must_use = "preparing a request is not enough to run it"]
pub struct RequestBuilder {
/// The wrapped `reqwest` request builder.
inner: HttpRequestBuilder,
/// The client used to authenticate and execute the request.
client: AuthenticatedClient,
}
/// Fields of a JSON error body that may carry a human-readable message.
///
/// All fields are optional. The struct does not use `deny_unknown_fields`, so any
/// other fields in the body are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct Message {
/// A `message` field; checked first by `convert`.
message: Option<String>,
/// A `faultstring` field; checked second by `convert`.
faultstring: Option<String>,
/// A `title` field; checked third by `convert`.
title: Option<String>,
/// May contain a JSON-encoded message object as a string; `convert` decodes it
/// only when none of the fields above is set.
error_message: Option<String>,
}
impl Message {
    /// Returns the human-readable text carried by this message, if any.
    ///
    /// `message`, `faultstring` and `title` are checked in that order, and the
    /// first one present wins. If none is present and `recursive` is `true`,
    /// `error_message` is parsed as a JSON-encoded `Message` and inspected the
    /// same way. Recursion is disabled on that second pass, so at most one
    /// level is unwrapped.
    fn convert(self, recursive: bool) -> Option<String> {
        let direct = self.message.or(self.faultstring).or(self.title);
        if direct.is_some() || !recursive {
            return direct;
        }
        let encoded = self.error_message?;
        let nested: Message = serde_json::from_str(&encoded).ok()?;
        nested.convert(false)
    }
}
impl From<Message> for Option<String> {
fn from(value: Message) -> Option<String> {
value.convert(true)
}
}
/// The shapes of JSON error bodies understood by `extract_message`.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. An
/// object whose values all deserialize as `Message` is therefore parsed as `Map`
/// before `Message` is attempted.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ErrorResponse {
/// An object wrapping message objects under arbitrary keys,
/// e.g. `{"SomethingFailed": {"message": "..."}}`.
Map(HashMap<String, Message>),
/// A single message object, e.g. `{"message": "..."}`.
Message(Message),
}
/// Extracts a human-readable error message from an error response body.
///
/// Two JSON shapes are understood (see `ErrorResponse`):
/// - an object wrapping message objects under arbitrary keys, e.g.
///   `{"itemNotFound": {"message": "..."}}`;
/// - a single message object, e.g. `{"message": "..."}`.
///
/// If the body is not JSON of either shape, or no message can be found in it,
/// the original `text` is returned unchanged.
fn extract_message(text: String) -> String {
    let parsed = match serde_json::from_str::<ErrorResponse>(&text) {
        Ok(parsed) => parsed,
        Err(_) => return text,
    };
    let message: Option<String> = match parsed {
        // Scan every entry rather than only the first one: `HashMap` iteration
        // order is unspecified, and an entry without a recognisable message must
        // not hide another entry that has one.
        ErrorResponse::Map(map) => map
            .into_iter()
            .find_map(|(_key, value)| Option::<String>::from(value)),
        ErrorResponse::Message(msg) => msg.into(),
    };
    message.unwrap_or(text)
}
pub async fn check(response: Response) -> Result<Response, Error> {
let status = response.status();
if status.is_client_error() || status.is_server_error() {
let message = extract_message(response.text().await?);
trace!("HTTP request returned {}; error: {}", status, message);
Err(Error::new(status.into(), message).with_status(status))
} else {
trace!(
"HTTP request to {} returned {}",
response.url(),
response.status()
);
Ok(response)
}
}
impl RequestBuilder {
#[inline]
pub fn client(&self) -> &AuthenticatedClient {
&self.client
}
pub fn body<T: Into<Body>>(self, body: T) -> RequestBuilder {
RequestBuilder {
inner: self.inner.body(body),
..self
}
}
pub fn header<K, V>(self, key: K, value: V) -> RequestBuilder
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
{
RequestBuilder {
inner: self.inner.header(key, value),
..self
}
}
pub fn headers(self, headers: HeaderMap) -> RequestBuilder {
RequestBuilder {
inner: self.inner.headers(headers),
..self
}
}
pub fn json<T: Serialize + ?Sized>(self, json: &T) -> RequestBuilder {
RequestBuilder {
inner: self.inner.json(json),
..self
}
}
pub fn query<T: Serialize + ?Sized>(self, query: &T) -> RequestBuilder {
RequestBuilder {
inner: self.inner.query(query),
..self
}
}
pub fn timeout(self, timeout: Duration) -> RequestBuilder {
RequestBuilder {
inner: self.inner.timeout(timeout),
..self
}
}
pub async fn fetch<T>(self) -> Result<T, Error>
where
T: DeserializeOwned + Send,
{
self.send().await?.json::<T>().await.map_err(Error::from)
}
pub async fn send(self) -> Result<Response, Error> {
check(self.send_unchecked().await?).await
}
pub async fn send_unchecked(self) -> Result<Response, Error> {
let req = self.client.authenticate(self.inner).await?;
trace!("Sending HTTP {} request to {}", req.method(), req.url());
self.client.client.execute(req).await.map_err(Error::from)
}
pub(crate) async fn send_unchecked_to(self, url: &Url) -> Result<Response, Error> {
let mut req = self.client.authenticate(self.inner).await?;
url_utils::merge(req.url_mut(), url);
trace!("Sending HTTP {} request to {}", req.method(), req.url());
self.client.client.execute(req).await.map_err(Error::from)
}
#[cfg(test)]
pub(crate) fn build(self) -> Result<Request, Error> {
self.inner.build().map_err(From::from)
}
#[cfg(feature = "stream")]
pub async fn fetch_paginated<T>(
self,
limit: Option<usize>,
starting_with: Option<<T as PaginatedResource>::Id>,
) -> impl Stream<Item = Result<T, Error>>
where
T: PaginatedResource + Unpin,
<T as PaginatedResource>::Root: Into<Vec<T>> + Send,
{
paginated(self, limit, starting_with)
}
pub fn try_clone(&self) -> Option<RequestBuilder> {
self.inner.try_clone().map(|inner| RequestBuilder {
inner,
client: self.client.clone(),
})
}
}
#[cfg(feature = "stream")]
#[async_trait]
impl FetchNext for RequestBuilder {
    /// Fetches one page: clones this builder, appends `query` and decodes the result.
    ///
    /// # Panics
    ///
    /// Panics if the builder cannot be cloned (it has a streaming body).
    async fn fetch_next<Q: Serialize + Send, T: DeserializeOwned + Send>(
        &self,
        query: Q,
    ) -> Result<T, Error> {
        let builder = self
            .try_clone()
            .expect("Builder with a streaming body cannot be used");
        // Apply the query in its own statement so the borrow of `query` ends
        // before the await point.
        let with_query = builder.query(&query);
        with_query.fetch().await
    }
}
#[cfg(test)]
mod test_extract_message {
    use super::extract_message;

    /// Feeds `body` to `extract_message` and compares the outcome with `expected`.
    fn assert_extracted(body: &str, expected: &str) {
        let extracted = extract_message(body.to_owned());
        assert_eq!(extracted, expected);
    }

    /// Non-JSON bodies come back verbatim.
    #[test]
    fn test_plain() {
        let html = "<html><body>I failed</body></html>";
        assert_extracted(html, html);
    }

    /// A top-level `message` field is used directly.
    #[test]
    fn test_simple_message() {
        assert_extracted(r#"{"message": "I failed"}"#, "I failed");
    }

    /// A message object wrapped under a single key is unwrapped.
    #[test]
    fn test_nested_message() {
        assert_extracted(r#"{"SomethingFailed": {"message": "I failed"}}"#, "I failed");
    }

    /// `error_message` holding an object yields that object's `faultstring`.
    #[test]
    fn test_ironic_message() {
        assert_extracted(r#"{"error_message": {"faultstring": "I failed"}}"#, "I failed");
    }

    /// `error_message` holding a JSON-encoded string is decoded once more.
    #[test]
    fn test_ironic_legacy() {
        assert_extracted(
            r#"{"error_message": "{\"faultstring\": \"I failed\"}"}"#,
            "I failed",
        );
    }
}