use std::fmt::Debug;
use async_trait::async_trait;
use reqwest::{Client, RequestBuilder, Url};
use static_assertions::{assert_impl_all, assert_obj_safe};
use super::{EndpointFilters, Error, ErrorKind};
/// Authentication method: applies credentials to requests and resolves
/// service endpoints.
///
/// Implementors must be `Debug + Sync + Send`. `#[async_trait]` rewrites each
/// `async fn` into a method returning a boxed future, which keeps the trait
/// object safe so it can be used as `dyn AuthType`.
#[async_trait]
pub trait AuthType: Debug + Sync + Send {
    /// Apply this authentication method to `request`.
    ///
    /// Returns the builder that should actually be sent; an implementation
    /// may return it unchanged (as [`NoAuth`] does).
    async fn authenticate(
        &self,
        client: &Client,
        request: RequestBuilder,
    ) -> Result<RequestBuilder, Error>;

    /// Resolve the URL of the endpoint serving `service_type`.
    ///
    /// NOTE(review): how `filters` narrows the choice is implementation
    /// specific and not visible here; [`NoAuth`] ignores it entirely.
    async fn get_endpoint(
        &self,
        client: &Client,
        service_type: &str,
        filters: &EndpointFilters,
    ) -> Result<Url, Error>;

    /// Refresh the authentication state.
    ///
    /// NOTE(review): presumably re-acquires cached credentials (e.g. a token)
    /// for implementors that hold them — confirm; [`NoAuth`] does nothing.
    async fn refresh(&self, client: &Client) -> Result<(), Error>;
}

// Compile-time check that `AuthType` can be used as a trait object.
assert_obj_safe!(AuthType);
/// Authentication method that performs no authentication.
///
/// Requests are passed through unmodified, and `get_endpoint` returns the
/// single endpoint supplied at construction (if any) for every service type.
#[derive(Clone, Debug)]
pub struct NoAuth {
    /// Endpoint returned for every service; `None` makes `get_endpoint` fail
    /// with `ErrorKind::EndpointNotFound`.
    endpoint: Option<Url>,
}

// Compile-time check that `NoAuth` is `Send + Sync`, as `AuthType` requires.
assert_impl_all!(NoAuth: Send, Sync);
impl NoAuth {
#[inline]
pub fn new<U>(endpoint: U) -> Result<NoAuth, Error>
where
U: AsRef<str>,
{
let endpoint = Url::parse(endpoint.as_ref())
.map_err(|e| Error::new(ErrorKind::InvalidInput, e.to_string()))?;
Ok(NoAuth {
endpoint: Some(endpoint),
})
}
#[inline]
pub fn new_without_endpoint() -> NoAuth {
NoAuth { endpoint: None }
}
}
#[async_trait]
impl AuthType for NoAuth {
async fn authenticate(
&self,
_client: &Client,
request: RequestBuilder,
) -> Result<RequestBuilder, Error> {
Ok(request)
}
async fn get_endpoint(
&self,
_client: &Client,
service_type: &str,
_filters: &EndpointFilters,
) -> Result<Url, Error> {
self.endpoint.clone().ok_or_else(|| {
Error::new(
ErrorKind::EndpointNotFound,
format!(
"None authentication without an endpoint, use an override for {}",
service_type
),
)
})
}
async fn refresh(&self, _client: &Client) -> Result<(), Error> {
Ok(())
}
}
#[cfg(test)]
pub mod test {
    use reqwest::{Client, Url};

    use super::{AuthType, NoAuth};

    /// Endpoint shared by the tests below.
    const ENDPOINT: &str = "http://127.0.0.1:8080/v1";

    /// Assert that `url` is the parsed form of [`ENDPOINT`].
    fn assert_is_test_endpoint(url: &Url) {
        assert_eq!(url.scheme(), "http");
        assert_eq!(url.host_str(), Some("127.0.0.1"));
        assert_eq!(url.port(), Some(8080));
        assert_eq!(url.path(), "/v1");
    }

    #[test]
    fn test_noauth_new() {
        let a = NoAuth::new(ENDPOINT).unwrap();
        let e = a.endpoint.expect("NoAuth::new must store the endpoint");
        assert_is_test_endpoint(&e);
    }

    #[test]
    fn test_noauth_new_fail() {
        assert!(NoAuth::new("foo bar").is_err());
    }

    #[test]
    fn test_noauth_new_without_endpoint() {
        assert!(NoAuth::new_without_endpoint().endpoint.is_none());
    }

    #[tokio::test]
    async fn test_noauth_get_endpoint() {
        let a = NoAuth::new(ENDPOINT).unwrap();
        let e = a
            .get_endpoint(&Client::new(), "foobar", &Default::default())
            .await
            .unwrap();
        assert_is_test_endpoint(&e);
    }

    #[tokio::test]
    async fn test_noauth_get_endpoint_without_endpoint() {
        // Without a configured endpoint every lookup must fail.
        let a = NoAuth::new_without_endpoint();
        let res = a
            .get_endpoint(&Client::new(), "foobar", &Default::default())
            .await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_noauth_authenticate_passthrough() {
        // NoAuth must not alter the request, so its URL survives unchanged.
        let client = Client::new();
        let a = NoAuth::new_without_endpoint();
        let request = a
            .authenticate(&client, client.get(ENDPOINT))
            .await
            .unwrap()
            .build()
            .unwrap();
        assert_is_test_endpoint(request.url());
    }

    #[tokio::test]
    async fn test_noauth_refresh() {
        NoAuth::new_without_endpoint()
            .refresh(&Client::new())
            .await
            .unwrap();
    }
}