use std::time::Duration;
use serde::Serialize;
use serde::de::DeserializeOwned;
use uuid::Uuid;
use crate::cache::TtlCache;
use crate::error::Error;
use crate::evaluation::{
EvaluationRequest, EvaluationResponse, EvaluationsRequest, EvaluationsResponse,
};
use crate::http::{HttpClient, HttpResponse, Method};
use crate::search::{
ActionSearchRequest, ActionSearchResponse, ResourceSearchRequest, ResourceSearchResponse,
SubjectSearchRequest, SubjectSearchResponse,
};
/// Well-known path suffix of the AuthZEN PDP metadata document.
/// `build_discovery_url` appends it to the PDP identifier's path (after trimming trailing `/`).
const AUTHZEN_WELL_KNOWN_PATH: &str = "/.well-known/authzen-configuration";
/// PDP metadata document fetched from the AuthZEN well-known discovery URL.
///
/// Every `Option` field that is `None` is left out when the value is serialized.
#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct PdpConfiguration {
    /// Identifier the PDP claims for itself; discovery compares it (ignoring
    /// trailing `/`) with the identifier that was requested.
    pub policy_decision_point: String,
    /// URL of the single access-evaluation endpoint (always present).
    pub access_evaluation_endpoint: String,
    /// URL of the batch access-evaluations endpoint, when advertised.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub access_evaluations_endpoint: Option<String>,
    /// URL of the subject-search endpoint, when advertised.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_subject_endpoint: Option<String>,
    /// URL of the resource-search endpoint, when advertised.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_resource_endpoint: Option<String>,
    /// URL of the action-search endpoint, when advertised.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_action_endpoint: Option<String>,
    /// Capability identifiers advertised by the PDP, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub capabilities: Option<Vec<String>>,
    /// Signed metadata as published by the PDP. `AuthZenClient` deserializes
    /// it but never verifies or reads it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub signed_metadata: Option<String>,
}
/// AuthZEN client: discovers PDP metadata, caches it with a TTL, and calls the
/// PDP's evaluation and search endpoints through a pluggable [`HttpClient`].
pub struct AuthZenClient<C>
where
    C: HttpClient,
{
    /// Transport used for every outgoing request.
    http: C,
    /// Discovered PDP configurations, keyed by the PDP identifier string.
    cache: TtlCache<PdpConfiguration>,
}
impl<C: HttpClient> AuthZenClient<C> {
    /// Creates a client that sends requests through `http` and keeps discovered
    /// PDP configurations in a cache with time-to-live `cache_ttl`.
    #[must_use]
    pub fn new(http: C, cache_ttl: Duration) -> Self {
        Self {
            http,
            cache: TtlCache::new(cache_ttl),
        }
    }

    /// Fetches, validates and caches the metadata document for `pdp_id`.
    ///
    /// The metadata request carries no credentials. On success the parsed
    /// configuration is stored in the cache under `pdp_id` and returned.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidPdpUrl`] if `pdp_id` is not an `https` URL free of
    ///   query and fragment.
    /// - [`Error::HttpStatus`] if the metadata endpoint answers with status >= 400.
    /// - [`Error::InvalidResponse`] if the body is not a valid [`PdpConfiguration`].
    /// - [`Error::PdpMismatch`] if the document's `policy_decision_point` differs
    ///   from `pdp_id` (trailing `/` ignored).
    /// - Any error converted from the underlying HTTP client.
    pub async fn discover(&self, pdp_id: &str) -> Result<PdpConfiguration, Error> {
        let url = Self::build_discovery_url(pdp_id)?;
        let resp = self.unauthenticated_get(&url).await?;
        let config: PdpConfiguration =
            serde_json::from_slice(&resp.body).map_err(Error::InvalidResponse)?;
        Self::validate_pdp_match(pdp_id, &config)?;
        self.cache.insert(pdp_id.to_owned(), config.clone()).await;
        Ok(config)
    }

    /// Returns the cached configuration for `pdp_id`, falling back to
    /// [`Self::discover`] when the cache has no entry.
    ///
    /// # Errors
    ///
    /// Same as [`Self::discover`] when a discovery round-trip is needed.
    pub async fn get_pdp_config(&self, pdp_id: &str) -> Result<PdpConfiguration, Error> {
        if let Some(config) = self.cache.get(pdp_id).await {
            return Ok(config);
        }
        self.discover(pdp_id).await
    }

    /// Drops the cached configuration for `pdp_id`.
    ///
    /// Returns `true` if an entry was observed just before removal. The check
    /// and the removal are two separate cache calls, so a concurrent insert or
    /// expiry between them may make the returned flag stale.
    pub async fn invalidate_pdp_config(&self, pdp_id: &str) -> bool {
        let existed = self.cache.get(pdp_id).await.is_some();
        self.cache.invalidate(pdp_id).await;
        existed
    }

    /// Sends one access-evaluation request to the PDP's `access_evaluation_endpoint`.
    ///
    /// The PDP configuration must already be cached (see [`Self::discover`] or
    /// [`Self::get_pdp_config`]); this method never triggers discovery.
    ///
    /// # Errors
    ///
    /// - [`Error::NotCached`] if no configuration is cached for `pdp_id`.
    /// - [`Error::InvalidPdpUrl`] if the advertised endpoint is not an `https` URL.
    /// - [`Error::Serialization`], [`Error::HttpStatus`] or [`Error::InvalidResponse`]
    ///   for request encoding, status >= 400, or response decoding failures.
    pub async fn evaluate(
        &self,
        pdp_id: &str,
        token: &str,
        request: &EvaluationRequest,
    ) -> Result<EvaluationResponse, Error> {
        let url = self
            .resolve_required_endpoint(pdp_id, |c| c.access_evaluation_endpoint.as_str())
            .await?;
        self.post_json(&url, token, request).await
    }

    /// Sends a batch evaluation request.
    ///
    /// Uses `access_evaluations_endpoint` when advertised, otherwise
    /// `<policy_decision_point>/access/v1/evaluations`. Requires a cached configuration.
    ///
    /// # Errors
    ///
    /// Same as [`Self::evaluate`].
    pub async fn evaluate_batch(
        &self,
        pdp_id: &str,
        token: &str,
        request: &EvaluationsRequest,
    ) -> Result<EvaluationsResponse, Error> {
        let url = self
            .resolve_optional_endpoint(
                pdp_id,
                |c| c.access_evaluations_endpoint.as_deref(),
                "/access/v1/evaluations",
            )
            .await?;
        self.post_json(&url, token, request).await
    }

    /// Runs a subject search.
    ///
    /// Uses `search_subject_endpoint` when advertised, otherwise
    /// `<policy_decision_point>/access/v1/search/subject`. Requires a cached configuration.
    ///
    /// # Errors
    ///
    /// Same as [`Self::evaluate`].
    pub async fn search_subjects(
        &self,
        pdp_id: &str,
        token: &str,
        request: &SubjectSearchRequest,
    ) -> Result<SubjectSearchResponse, Error> {
        let url = self
            .resolve_optional_endpoint(
                pdp_id,
                |c| c.search_subject_endpoint.as_deref(),
                "/access/v1/search/subject",
            )
            .await?;
        self.post_json(&url, token, request).await
    }

    /// Runs a resource search.
    ///
    /// Uses `search_resource_endpoint` when advertised, otherwise
    /// `<policy_decision_point>/access/v1/search/resource`. Requires a cached configuration.
    ///
    /// # Errors
    ///
    /// Same as [`Self::evaluate`].
    pub async fn search_resources(
        &self,
        pdp_id: &str,
        token: &str,
        request: &ResourceSearchRequest,
    ) -> Result<ResourceSearchResponse, Error> {
        let url = self
            .resolve_optional_endpoint(
                pdp_id,
                |c| c.search_resource_endpoint.as_deref(),
                "/access/v1/search/resource",
            )
            .await?;
        self.post_json(&url, token, request).await
    }

    /// Runs an action search.
    ///
    /// Uses `search_action_endpoint` when advertised, otherwise
    /// `<policy_decision_point>/access/v1/search/action`. Requires a cached configuration.
    ///
    /// # Errors
    ///
    /// Same as [`Self::evaluate`].
    pub async fn search_actions(
        &self,
        pdp_id: &str,
        token: &str,
        request: &ActionSearchRequest,
    ) -> Result<ActionSearchResponse, Error> {
        let url = self
            .resolve_optional_endpoint(
                pdp_id,
                |c| c.search_action_endpoint.as_deref(),
                "/access/v1/search/action",
            )
            .await?;
        self.post_json(&url, token, request).await
    }

    /// Parses `pdp_id` and checks it is acceptable as a PDP identifier:
    /// `https` scheme, no query string, no fragment.
    fn validate_pdp_url(pdp_id: &str) -> Result<url::Url, Error> {
        let parsed =
            url::Url::parse(pdp_id).map_err(|e| Error::InvalidPdpUrl(e.to_string()))?;
        if parsed.scheme() != "https" {
            return Err(Error::InvalidPdpUrl(format!(
                "scheme must be https, got {}",
                parsed.scheme()
            )));
        }
        if parsed.query().is_some() || parsed.fragment().is_some() {
            return Err(Error::InvalidPdpUrl(
                "PDP URL must not contain query or fragment".to_owned(),
            ));
        }
        Ok(parsed)
    }

    /// Checks that an endpoint URL advertised in PDP metadata is an `https` URL.
    ///
    /// The metadata is fetched without authentication, so its endpoint URLs are
    /// untrusted. Without this check a bearer token could be sent over plain
    /// `http`, even though the PDP identifier itself must be `https`.
    /// Returns the endpoint string unchanged (not re-serialized) on success.
    fn validate_endpoint_url(endpoint: &str) -> Result<String, Error> {
        let parsed = url::Url::parse(endpoint)
            .map_err(|e| Error::InvalidPdpUrl(format!("invalid endpoint {}: {}", endpoint, e)))?;
        if parsed.scheme() != "https" {
            return Err(Error::InvalidPdpUrl(format!(
                "endpoint scheme must be https, got {}",
                parsed.scheme()
            )));
        }
        Ok(endpoint.to_owned())
    }

    /// Appends `suffix` to `url`'s path after stripping trailing `/`, so that
    /// `https://h/` + `/x` becomes `https://h/x` rather than `https://h//x`.
    fn append_path(url: &mut url::Url, suffix: &str) {
        let joined = format!("{}{}", url.path().trim_end_matches('/'), suffix);
        url.set_path(&joined);
    }

    /// Builds the metadata URL by appending `AUTHZEN_WELL_KNOWN_PATH` to the
    /// validated PDP identifier's path.
    ///
    /// NOTE(review): RFC 8414-style metadata discovery inserts the well-known
    /// segment between host and path instead of appending it. Confirm which form
    /// the target PDPs expect when `pdp_id` has a non-empty path.
    fn build_discovery_url(pdp_id: &str) -> Result<String, Error> {
        let mut discovery = Self::validate_pdp_url(pdp_id)?;
        Self::append_path(&mut discovery, AUTHZEN_WELL_KNOWN_PATH);
        Ok(discovery.to_string())
    }

    /// Ensures the metadata's `policy_decision_point` equals the requested
    /// identifier, ignoring trailing `/` on either side.
    fn validate_pdp_match(expected: &str, config: &PdpConfiguration) -> Result<(), Error> {
        let expected_normalized = expected.trim_end_matches('/');
        let got_normalized = config.policy_decision_point.trim_end_matches('/');
        if expected_normalized != got_normalized {
            return Err(Error::PdpMismatch {
                expected: expected.to_owned(),
                got: config.policy_decision_point.clone(),
            });
        }
        Ok(())
    }

    /// Returns the cached configuration for `pdp_id` without triggering discovery.
    async fn cached_config(&self, pdp_id: &str) -> Result<PdpConfiguration, Error> {
        self.cache
            .get(pdp_id)
            .await
            .ok_or_else(|| Error::NotCached(pdp_id.to_owned()))
    }

    /// Resolves an endpoint that the metadata must always provide, and checks
    /// that it is an `https` URL.
    async fn resolve_required_endpoint(
        &self,
        pdp_id: &str,
        extract: fn(&PdpConfiguration) -> &str,
    ) -> Result<String, Error> {
        let config = self.cached_config(pdp_id).await?;
        Self::validate_endpoint_url(extract(&config))
    }

    /// Resolves an optional endpoint.
    ///
    /// An advertised endpoint is used after an `https` check. Otherwise
    /// `default_path` is appended to the cached `policy_decision_point`.
    async fn resolve_optional_endpoint(
        &self,
        pdp_id: &str,
        extract: fn(&PdpConfiguration) -> Option<&str>,
        default_path: &str,
    ) -> Result<String, Error> {
        let config = self.cached_config(pdp_id).await?;
        if let Some(endpoint) = extract(&config) {
            return Self::validate_endpoint_url(endpoint);
        }
        let mut base = Self::validate_pdp_url(&config.policy_decision_point)?;
        Self::append_path(&mut base, default_path);
        Ok(base.to_string())
    }

    /// Maps a response with status >= 400 to [`Error::HttpStatus`], whose body is
    /// the response body decoded lossily as UTF-8. Other responses pass through.
    fn ensure_success(resp: HttpResponse) -> Result<HttpResponse, Error> {
        if resp.status >= 400 {
            return Err(Error::HttpStatus {
                status: resp.status,
                body: String::from_utf8_lossy(&resp.body).into_owned(),
            });
        }
        Ok(resp)
    }

    /// Sends a request with a `Bearer` authorization header and a fresh random
    /// `x-request-id`, rejecting responses with status >= 400.
    async fn authenticated_request(
        &self,
        method: Method,
        url: &str,
        token: &str,
        body: Option<Vec<u8>>,
    ) -> Result<HttpResponse, Error> {
        let auth = format!("Bearer {}", token);
        let request_id = Uuid::new_v4().to_string();
        let headers = [
            ("authorization", auth.as_str()),
            ("x-request-id", request_id.as_str()),
        ];
        let resp = self.http.request(method, url, &headers, body).await?;
        Self::ensure_success(resp)
    }

    /// Sends a credential-free GET with a fresh random `x-request-id`,
    /// rejecting responses with status >= 400.
    async fn unauthenticated_get(&self, url: &str) -> Result<HttpResponse, Error> {
        let request_id = Uuid::new_v4().to_string();
        let headers = [("x-request-id", request_id.as_str())];
        let resp = self.http.request(Method::Get, url, &headers, None).await?;
        Self::ensure_success(resp)
    }

    /// Serializes `body` as JSON, POSTs it with bearer auth, and decodes the
    /// JSON response into `T`.
    async fn post_json<T: DeserializeOwned, B: Serialize>(
        &self,
        url: &str,
        token: &str,
        body: &B,
    ) -> Result<T, Error> {
        let bytes = serde_json::to_vec(body).map_err(Error::Serialization)?;
        let resp = self
            .authenticated_request(Method::Post, url, token, Some(bytes))
            .await?;
        serde_json::from_slice(&resp.body).map_err(Error::InvalidResponse)
    }
}