use std::str::FromStr;
use http::{Uri, uri::Scheme};
use ocm_types::{
discovery::Discovery,
error::{Error, ValidationError},
};
use crate::common::HttpClient;
/// Path of the current OCM discovery document, tried first.
pub const DISCOVERY_ENDPOINT: &str = "/.well-known/ocm";
/// Older discovery path, tried only when [`DISCOVERY_ENDPOINT`] fails.
pub const LEGACY_DISCOVERY_ENDPOINT: &str = "/ocm-provider";
/// Errors that can occur while discovering the capabilities of a remote OCM server.
#[derive(Debug)]
pub enum DiscoveryError {
    /// No discovery URL could be derived from the supplied server address
    /// (for example because it has no host). Holds the offending address.
    InvalidOcmServerAddress(String),
    /// The HTTP client failed to fetch a discovery endpoint. Holds the error
    /// text reported by the client.
    RequestError(String),
    /// An endpoint answered, but its body could not be deserialized as a
    /// discovery document.
    DeserializationFailed(serde_json::Error),
}

impl std::fmt::Display for DiscoveryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidOcmServerAddress(target) => {
                write!(f, "invalid OCM server address '{target}'")
            }
            Self::RequestError(e) => write!(f, "discovery request failed: {e}"),
            Self::DeserializationFailed(e) => {
                write!(f, "failed to deserialize discovery response: {e}")
            }
        }
    }
}

impl std::error::Error for DiscoveryError {
    /// Exposes the underlying JSON error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DeserializationFailed(e) => Some(e),
            Self::InvalidOcmServerAddress(_) | Self::RequestError(_) => None,
        }
    }
}
impl DiscoveryError {
    /// HTTP status code to answer a client with when discovery fails with this error.
    ///
    /// An unusable server address yields `406 Not Acceptable`. Both failing to reach
    /// the remote OCM server and failing to parse its answer yield `502 Bad Gateway`.
    pub fn status_code(&self) -> http::StatusCode {
        match self {
            Self::InvalidOcmServerAddress(_) => http::StatusCode::NOT_ACCEPTABLE,
            Self::RequestError(_) | Self::DeserializationFailed(_) => {
                http::StatusCode::BAD_GATEWAY
            }
        }
    }
}
/// Converts a discovery failure into an `ocm_types` [`Error`] body.
///
/// Each variant maps to an upper-snake-case `message` code plus exactly one
/// [`ValidationError`] describing the cause.
impl From<DiscoveryError> for Error {
    fn from(value: DiscoveryError) -> Self {
        match value {
            DiscoveryError::InvalidOcmServerAddress(target_uri) => Error {
                // NOTE: previously misspelled "INVALID_OCM_SERVER_ADDRRESS"; clients that
                // matched on the old spelling must be updated.
                message: "INVALID_OCM_SERVER_ADDRESS".to_string(),
                validation_errors: vec![ValidationError {
                    name: Some("Missing Host".to_string()),
                    message: Some(format!("'{target_uri}' does not contain a host")),
                }],
            },
            DiscoveryError::RequestError(e) => Error {
                message: "REQUEST_ERROR".to_string(),
                validation_errors: vec![ValidationError {
                    name: Some("OCM Server rejected request".to_string()),
                    // `e` is already an owned String; move it instead of copying.
                    message: Some(e),
                }],
            },
            DiscoveryError::DeserializationFailed(e) => Error {
                message: "INVALID_DISCOVERY_RESPONSE".to_string(),
                validation_errors: vec![ValidationError {
                    name: Some("Failed to deserialize Discovery Response".to_string()),
                    message: Some(e.to_string()),
                }],
            },
        }
    }
}
impl From<serde_json::Error> for DiscoveryError {
fn from(value: serde_json::Error) -> Self {
Self::DeserializationFailed(value)
}
}
/// Fetches the OCM discovery document of the server identified by `target`.
///
/// The server is contacted over HTTPS first. If that attempt fails and the client
/// permits plain HTTP (`http_client.allow_http()`), the lookup is repeated over HTTP,
/// and the HTTP result is returned, whether it succeeds or fails.
///
/// # Errors
///
/// Returns the [`DiscoveryError`] of the last attempt made.
pub async fn discover(
    http_client: &impl HttpClient,
    target: &Uri,
) -> Result<Discovery, DiscoveryError> {
    let secure = Scheme::from_str("https").expect("\"https\" is a valid scheme literal");
    match fetch_discovery(secure, http_client, target).await {
        // The guard runs only on failure, so `allow_http` is consulted lazily.
        Err(_) if http_client.allow_http() => {
            let plain = Scheme::from_str("http").expect("\"http\" is a valid scheme literal");
            fetch_discovery(plain, http_client, target).await
        }
        outcome => outcome,
    }
}
/// Fetches the discovery document of `target` using `scheme`.
///
/// [`DISCOVERY_ENDPOINT`] is tried first. If fetching or parsing it fails for any
/// reason, [`LEGACY_DISCOVERY_ENDPOINT`] is tried instead.
///
/// # Errors
///
/// Returns [`DiscoveryError::InvalidOcmServerAddress`] if an endpoint URI cannot be
/// derived from `target`. When both endpoints fail, returns the legacy endpoint's
/// error; the primary endpoint's error is discarded.
async fn fetch_discovery(
    scheme: Scheme,
    http_client: &impl HttpClient,
    target: &Uri,
) -> Result<Discovery, DiscoveryError> {
    let uri = derive_discovery_endpoint(scheme.clone(), target, DISCOVERY_ENDPOINT)?;
    match fetch_discovery_document(http_client, &uri).await {
        Ok(discovery) => Ok(discovery),
        Err(_) => {
            let legacy_uri =
                derive_discovery_endpoint(scheme, target, LEGACY_DISCOVERY_ENDPOINT)?;
            fetch_discovery_document(http_client, &legacy_uri).await
        }
    }
}

/// GETs `uri` and deserializes the response body as a [`Discovery`] document.
async fn fetch_discovery_document(
    http_client: &impl HttpClient,
    uri: &Uri,
) -> Result<Discovery, DiscoveryError> {
    let body = http_client
        .get(uri)
        .await
        .map_err(DiscoveryError::RequestError)?;
    // `?` converts serde_json::Error via `From` into DeserializationFailed.
    Ok(serde_json::from_str(&body)?)
}
/// Builds the absolute URI of a discovery endpoint on the server named by `target`.
///
/// Only the host and the optional explicit port of `target` are kept. Any scheme,
/// user info, path or query in `target` is discarded and replaced by `scheme` and
/// `endpoint_path`.
///
/// # Errors
///
/// Returns [`DiscoveryError::InvalidOcmServerAddress`] (carrying `target` as text)
/// if `target` has no host, or if the resulting URI cannot be assembled. The
/// original code panicked in the second case.
fn derive_discovery_endpoint(
    scheme: Scheme,
    target: &Uri,
    endpoint_path: &str,
) -> Result<Uri, DiscoveryError> {
    // Built lazily so the address is only stringified on the error path.
    let invalid_address = || DiscoveryError::InvalidOcmServerAddress(target.to_string());

    let host = target.host().ok_or_else(invalid_address)?;
    let authority = match target.port() {
        Some(port) => format!("{host}:{}", port.as_str()),
        None => host.to_string(),
    };

    Uri::builder()
        .scheme(scheme)
        .authority(authority.as_str())
        .path_and_query(endpoint_path)
        .build()
        .map_err(|_| invalid_address())
}