use std::borrow::Cow;
use std::sync::Arc;
use async_trait::async_trait;
use hive_console_sdk::circuit_breaker::CircuitBreakerBuilder;
use hive_console_sdk::persisted_documents::{PersistedDocumentsError, PersistedDocumentsManager};
use hive_router_config::persisted_documents::PersistedDocumentsHiveStorageConfig;
use thiserror::Error;
use crate::consts::ROUTER_VERSION;
use super::{
PersistedDocumentResolveInput, PersistedDocumentResolver, PersistedDocumentResolverError,
ResolvedDocument,
};
/// Persisted-document resolver that delegates lookups to the Hive SDK's
/// [`PersistedDocumentsManager`].
pub struct HiveCDNResolver {
/// SDK manager used to fetch the text of a persisted document by its id.
manager: PersistedDocumentsManager,
}
/// Remediation hint interpolated into the client-identity error messages of
/// `HiveResolverError` (the "missing" and "partial" variants).
static CLIENT_INSTRUCTIONS: &str = "Provide both client name and version headers, or send persisted document id in 'appName~appVersion~documentId' format";
/// Errors produced while configuring the Hive resolver or while turning a
/// request into a Hive app-scoped document id.
#[derive(Debug, Error)]
pub enum HiveResolverError {
/// The storage configuration has no `endpoint` value.
#[error("persisted_documents.storage.hive.endpoint is not configured")]
MissingEndpoint,
/// The storage configuration has no `key` value.
#[error("persisted_documents.storage.hive.key is not configured")]
MissingKey,
/// The document id contains `~` but is not made of exactly three non-empty
/// `~`-separated segments. Carries the offending id.
#[error("Document id format is invalid. Either 'appName~appVersion~documentId' or 'documentId' is accepted, received: {0}")]
InvalidDocumentIdFormat(String),
/// A plain document id was sent and neither client name nor client version
/// is available.
#[error("Client identity is missing. {CLIENT_INSTRUCTIONS}")]
ClientIdentityMissing,
/// A plain document id was sent and only one of client name / client
/// version is available.
#[error("Client identity is partial. {CLIENT_INSTRUCTIONS}")]
ClientIdentityPartial,
/// Building the SDK manager failed. Carries the stringified builder error.
#[error("Initialization failed: {0}")]
ManagerInit(String),
/// The SDK reported an error other than "document not found" during
/// resolution. Carries the stringified SDK error.
#[error("SDK error: {0}")]
SDKError(String),
}
/// Document id in `appName~appVersion~documentId` form, as handed to the SDK
/// manager. Borrowed when the request already carried that form, owned when it
/// had to be assembled from the client identity and a plain document id.
struct AppDocumentId<'a>(Cow<'a, str>);
/// Classification of a raw persisted-document id by its use of the `~`
/// separator. Both variants borrow the complete, unmodified input string.
enum DocumentIdSyntax<'a> {
/// Exactly three non-empty segments separated by `~`
/// (`appName~appVersion~documentId`).
App(&'a str),
/// No `~` anywhere in the id.
Plain(&'a str),
}
impl<'a> TryFrom<&'a str> for DocumentIdSyntax<'a> {
    type Error = HiveResolverError;

    /// Classifies a raw document id by how it uses the `~` separator.
    ///
    /// * No `~` at all yields [`DocumentIdSyntax::Plain`].
    /// * Exactly two `~` delimiting three non-empty segments yields
    ///   [`DocumentIdSyntax::App`].
    ///
    /// # Errors
    ///
    /// Any other use of `~` (empty segment, a single separator, or more than
    /// two separators) is rejected with
    /// [`HiveResolverError::InvalidDocumentIdFormat`] carrying the input.
    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        let haystack = value.as_bytes();
        // Absolute index of the next `~` at or after `from`, if any.
        let next_tilde = |from: usize| {
            memchr::memchr(b'~', &haystack[from..]).map(|offset| from + offset)
        };

        let Some(name_end) = next_tilde(0) else {
            return Ok(Self::Plain(value));
        };

        // `name_end > 0` guards the first segment; the match arm guards the
        // second and third segments and forbids a third separator.
        let well_formed = name_end > 0
            && match next_tilde(name_end + 1) {
                Some(version_end) => {
                    version_end > name_end + 1
                        && version_end + 1 < haystack.len()
                        && next_tilde(version_end + 1).is_none()
                }
                None => false,
            };

        if well_formed {
            Ok(Self::App(value))
        } else {
            Err(HiveResolverError::InvalidDocumentIdFormat(
                value.to_string(),
            ))
        }
    }
}
impl<'a> TryFrom<PersistedDocumentResolveInput<'a>> for AppDocumentId<'a> {
    type Error = HiveResolverError;

    /// Derives the app-scoped document id for a resolve request.
    ///
    /// An id already written as `appName~appVersion~documentId` is borrowed
    /// unchanged and the client identity is ignored. A plain id is prefixed
    /// with the client name and version, which must both be present.
    ///
    /// # Errors
    ///
    /// * [`HiveResolverError::InvalidDocumentIdFormat`] when the id uses `~`
    ///   incorrectly.
    /// * [`HiveResolverError::ClientIdentityPartial`] when a plain id comes
    ///   with only one of client name / version.
    /// * [`HiveResolverError::ClientIdentityMissing`] when a plain id comes
    ///   with neither.
    fn try_from(input: PersistedDocumentResolveInput<'a>) -> Result<Self, Self::Error> {
        let raw_id = input.persisted_document_id.as_ref();

        let document_id = match DocumentIdSyntax::try_from(raw_id)? {
            // Already fully qualified: no allocation needed.
            DocumentIdSyntax::App(qualified) => return Ok(Self(Cow::Borrowed(qualified))),
            DocumentIdSyntax::Plain(plain) => plain,
        };

        // NOTE(review): `name` and `version` are interpolated as-is. A `~` or
        // an empty value in either would yield an id that the explicit-syntax
        // check above would reject — confirm they are validated upstream.
        match (input.client_identity.name, input.client_identity.version) {
            (Some(name), Some(version)) => {
                Ok(Self(Cow::Owned(format!("{name}~{version}~{document_id}"))))
            }
            (None, None) => Err(HiveResolverError::ClientIdentityMissing),
            (Some(_), None) | (None, Some(_)) => Err(HiveResolverError::ClientIdentityPartial),
        }
    }
}
impl AsRef<str> for AppDocumentId<'_> {
    /// Borrows the id text, whether the inner value is borrowed or owned.
    fn as_ref(&self) -> &str {
        &*self.0
    }
}
impl HiveCDNResolver {
    /// Builds a resolver from the Hive storage section of the router
    /// configuration.
    ///
    /// The configured key, timeouts, retry count, cache size, circuit-breaker
    /// thresholds and every configured endpoint are forwarded to the SDK
    /// manager builder. A negative-cache TTL is set only when
    /// `negative_cache.enabled_config()` returns a value.
    ///
    /// # Errors
    ///
    /// * [`HiveResolverError::MissingEndpoint`] when no endpoint is configured.
    /// * [`HiveResolverError::MissingKey`] when no key is configured.
    /// * [`HiveResolverError::ManagerInit`] when the SDK builder fails.
    pub fn from_storage_config(
        config: &PersistedDocumentsHiveStorageConfig,
    ) -> Result<Self, HiveResolverError> {
        let Some(endpoint) = config.endpoint.clone() else {
            return Err(HiveResolverError::MissingEndpoint);
        };
        let endpoints: Vec<String> = endpoint.into();

        let Some(key) = config.key.clone() else {
            return Err(HiveResolverError::MissingKey);
        };

        let breaker_settings = &config.circuit_breaker;
        let circuit_breaker = CircuitBreakerBuilder::default()
            .error_threshold(breaker_settings.error_threshold)
            .volume_threshold(breaker_settings.volume_threshold)
            .reset_timeout(breaker_settings.reset_timeout);

        let builder = PersistedDocumentsManager::builder()
            .key(key)
            .accept_invalid_certs(config.accept_invalid_certs)
            .connect_timeout(config.connect_timeout)
            .request_timeout(config.request_timeout)
            .max_retries(config.retry_policy.max_retries)
            .cache_size(config.cache_size)
            .circuit_breaker(circuit_breaker)
            .user_agent(format!("hive-router/{ROUTER_VERSION}"));

        // The negative-cache TTL is optional; leave the builder untouched
        // when no enabled configuration is reported.
        let builder = match config.negative_cache.enabled_config() {
            Some(negative_cache) => builder.negative_cache_ttl(negative_cache.ttl),
            None => builder,
        };

        // Register endpoints in their configured order.
        let builder = endpoints
            .into_iter()
            .fold(builder, |acc, endpoint| acc.add_endpoint(endpoint));

        builder
            .build()
            .map(|manager| Self { manager })
            .map_err(|err| HiveResolverError::ManagerInit(err.to_string()))
    }
}
#[async_trait]
impl PersistedDocumentResolver for HiveCDNResolver {
    /// Resolves a persisted document through the SDK manager.
    ///
    /// The request is first converted into an app-scoped document id; that id
    /// is then looked up via `resolve_document`.
    ///
    /// # Errors
    ///
    /// * Any [`HiveResolverError`] from building the app-scoped id, converted
    ///   into [`PersistedDocumentResolverError`].
    /// * [`PersistedDocumentResolverError::NotFound`] (carrying the app-scoped
    ///   id) when the SDK reports `DocumentNotFound`.
    /// * [`HiveResolverError::SDKError`] (converted) for any other SDK error.
    async fn resolve(
        &self,
        input: PersistedDocumentResolveInput<'_>,
    ) -> Result<ResolvedDocument, PersistedDocumentResolverError> {
        let app_document_id = AppDocumentId::try_from(input)?;
        let lookup_key: &str = app_document_id.as_ref();

        match self.manager.resolve_document(lookup_key).await {
            Ok(text) => Ok(ResolvedDocument {
                text: Arc::<str>::from(text),
            }),
            Err(PersistedDocumentsError::DocumentNotFound) => Err(
                PersistedDocumentResolverError::NotFound(lookup_key.to_string()),
            ),
            Err(other) => Err(HiveResolverError::SDKError(other.to_string()).into()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{AppDocumentId, PersistedDocumentResolveInput};
    use crate::pipeline::persisted_documents::types::{ClientIdentity, PersistedDocumentId};

    /// One row of the conversion matrix.
    struct Case {
        /// Document id exactly as the client sent it.
        raw_id: &'static str,
        client_name: Option<&'static str>,
        client_version: Option<&'static str>,
        /// `Ok` holds the expected app-scoped id; `Err` holds a fragment that
        /// the error message must contain.
        expected: Result<&'static str, &'static str>,
    }

    /// Positional shorthand that keeps the table below to one line per case.
    fn case(
        raw_id: &'static str,
        client_name: Option<&'static str>,
        client_version: Option<&'static str>,
        expected: Result<&'static str, &'static str>,
    ) -> Case {
        Case {
            raw_id,
            client_name,
            client_version,
            expected,
        }
    }

    #[test]
    fn app_document_id_conversion_matrix() {
        let cases = [
            // Plain id plus full client identity is expanded.
            case("documentId", Some("app"), Some("1.0.0"), Ok("app~1.0.0~documentId")),
            // Fully qualified id passes through, with or without identity.
            case("app~1.0.0~documentId", None, None, Ok("app~1.0.0~documentId")),
            case("app~1.0.0~documentId", Some("app"), Some("1.2.3"), Ok("app~1.0.0~documentId")),
            // Plain id without a complete client identity.
            case("documentId", None, None, Err("missing")),
            case("documentId", Some("app"), None, Err("partial")),
            case("documentId", None, Some("1.0.0"), Err("partial")),
            // Malformed use of the `~` separator.
            case("app~documentId", None, None, Err("invalid")),
            case("app~~documentId", None, None, Err("invalid")),
            case("~1.0.0~documentId", None, None, Err("invalid")),
            case("app~1.0.0~", None, None, Err("invalid")),
            case("a~b~c~d", None, None, Err("invalid")),
        ];

        for (idx, row) in cases.into_iter().enumerate() {
            let Case {
                raw_id,
                client_name,
                client_version,
                expected,
            } = row;

            let document_id =
                PersistedDocumentId::try_from(raw_id).expect("fixture id should parse");

            // Normalise both sides to strings so the comparison below does
            // not depend on the converted value's borrow.
            let outcome: Result<String, String> =
                AppDocumentId::try_from(PersistedDocumentResolveInput {
                    persisted_document_id: &document_id,
                    client_identity: ClientIdentity {
                        name: client_name,
                        version: client_version,
                    },
                })
                .map(|id| {
                    let rendered: &str = id.as_ref();
                    rendered.to_owned()
                })
                .map_err(|err| err.to_string());

            match (outcome, expected) {
                (Ok(actual), Ok(wanted)) => {
                    assert_eq!(actual.as_str(), wanted, "case_index={idx}")
                }
                (Err(message), Err(fragment)) => assert!(
                    message.contains(fragment),
                    "case_index={}, err={}",
                    idx,
                    message
                ),
                (Ok(actual), Err(fragment)) => panic!(
                    "case_index={} expected err containing '{}' but got Ok({})",
                    idx, fragment, actual
                ),
                (Err(message), Ok(wanted)) => panic!(
                    "case_index={} expected Ok({}) but got Err({})",
                    idx, wanted, message
                ),
            }
        }
    }
}