use log::{info, warn};
use std::sync::{Arc, OnceLock};
use crate::certificates::{Cert, CertificatesError};
use crate::datastream::Datastream;
use crate::error::DshError;
use crate::utils;
use crate::*;
#[cfg(feature = "kafka")]
use crate::protocol_adapters::kafka_protocol::config::KafkaConfig;
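/// Entry point to the DSH platform properties for this service instance.
///
/// Holds the config host, task id, tenant name, the datastream configuration
/// loaded at startup, and the DSH certificates when they could be obtained.
/// Obtain the process-wide instance via [`Dsh::get`].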
#[derive(Debug, Clone)]
pub struct Dsh {
config_host: String,
task_id: String,
tenant_name: String,
datastream: Arc<Datastream>,
certificates: Option<Cert>,
#[cfg(feature = "kafka")]
kafka_config: KafkaConfig,
}
impl Dsh {
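    /// Creates a `Dsh` instance from already-resolved properties. When the
    /// `kafka` feature is enabled, the Kafka configuration is derived from
    /// the given datastream.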
pub(crate) fn new(
config_host: String,
task_id: String,
tenant_name: String,
datastream: Datastream,
certificates: Option<Cert>,
) -> Self {
let datastream = Arc::new(datastream);
Self {
config_host,
task_id,
tenant_name,
datastream: datastream.clone(),
certificates,
#[cfg(feature = "kafka")]
kafka_config: KafkaConfig::new(Some(datastream)),
}
}
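    /// Returns the process-wide `Dsh` instance, initializing it on first use.
    ///
    /// Initialization may perform blocking network I/O (certificate bootstrap
    /// and fetching `datastreams.json`), so it runs inside
    /// `tokio::task::block_in_place`. Note that `block_in_place` panics when
    /// called from a `current_thread` Tokio runtime.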
pub fn get() -> &'static Self {
static PROPERTIES: OnceLock<Dsh> = OnceLock::new();
PROPERTIES.get_or_init(|| tokio::task::block_in_place(Self::init))
}
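    /// Resolves properties from the environment: reads the tenant name, task
    /// id, and config host; tries PKI certificates first and falls back to a
    /// bootstrap; then fetches `datastreams.json`, defaulting to local or
    /// default datastreams when the platform is unreachable.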
fn init() -> Self {
let tenant_name = utils::tenant_name().unwrap_or_else(|_| "local_tenant".to_string());
let task_id =
utils::get_env_var(VAR_TASK_ID).unwrap_or_else(|_| "local_task_id".to_string());
let config_host =
utils::get_env_var(VAR_DSH_KAFKA_CONFIG_ENDPOINT).map(utils::ensure_https_prefix);
let certificates = if let Ok(cert) = Cert::from_pki_config_dir::<std::path::PathBuf>(None) {
Some(cert)
} else if let Ok(config_host) = &config_host {
Cert::from_bootstrap(config_host, &tenant_name, &task_id)
                .inspect_err(|e| warn!("Could not bootstrap to DSH: {}", e))
.inspect(|_| info!("Successfully connected to DSH"))
.ok()
} else {
None
};
let config_host = config_host.unwrap_or_else(|_| DEFAULT_CONFIG_HOST.to_string());
        let fetched_datastream = certificates.as_ref().and_then(|cert| {
            cert.reqwest_blocking_client_config()
                .build()
                .ok()
                .and_then(|client| {
                    Datastream::fetch_blocking(&client, &config_host, &tenant_name, &task_id).ok()
                })
        });
        let datastream = fetched_datastream.unwrap_or_else(|| {
            warn!("Could not fetch datastreams.json; using local or default datastreams");
            Datastream::load_local_datastreams().unwrap_or_default()
        });
Self::new(config_host, task_id, tenant_name, datastream, certificates)
}
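    /// Returns an async `reqwest::ClientBuilder`, preconfigured with the DSH
    /// certificates when available and a plain builder otherwise.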
pub fn reqwest_client_config(&self) -> reqwest::ClientBuilder {
        if let Ok(certificates) = self.certificates() {
certificates.reqwest_client_config()
} else {
reqwest::Client::builder()
}
}
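    /// Returns a blocking `reqwest::blocking::ClientBuilder`, preconfigured
    /// with the DSH certificates when available and a plain builder otherwise.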
pub fn reqwest_blocking_client_config(&self) -> reqwest::blocking::ClientBuilder {
        if let Ok(certificates) = self.certificates() {
certificates.reqwest_blocking_client_config()
} else {
reqwest::blocking::Client::builder()
}
}
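    /// Returns the DSH certificates, or `CertificatesError::NoCertificates`
    /// when none were loaded.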
pub fn certificates(&self) -> Result<&Cert, DshError> {
match &self.certificates {
Some(cert) => Ok(cert),
None => Err(CertificatesError::NoCertificates.into()),
}
}
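    /// Returns the client id of this instance (currently the task id).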
pub fn client_id(&self) -> &str {
&self.task_id
}
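    /// Returns the tenant name.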
pub fn tenant_name(&self) -> &str {
&self.tenant_name
}
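    /// Returns the task id.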
pub fn task_id(&self) -> &str {
&self.task_id
}
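    /// Returns the datastream configuration that was loaded at initialization.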
pub fn datastream(&self) -> &Datastream {
self.datastream.as_ref()
}
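    /// Fetches a fresh `datastreams.json` from the config host.
    ///
    /// The async client is built once from the first caller's configuration
    /// and cached in a process-wide `OnceLock` for reuse.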
pub async fn fetch_datastream(&self) -> Result<Datastream, DshError> {
static ASYNC_CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
let client = ASYNC_CLIENT.get_or_init(|| {
self.reqwest_client_config()
.build()
.expect("Could not build reqwest client for fetching datastream")
});
Ok(Datastream::fetch(client, &self.config_host, &self.tenant_name, &self.task_id).await?)
}
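    /// Blocking variant of [`Dsh::fetch_datastream`]; the blocking client is
    /// likewise built once and cached process-wide.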
pub fn fetch_datastream_blocking(&self) -> Result<Datastream, DshError> {
static BLOCKING_CLIENT: OnceLock<reqwest::blocking::Client> = OnceLock::new();
let client = BLOCKING_CLIENT.get_or_init(|| {
self.reqwest_blocking_client_config()
.build()
.expect("Could not build reqwest client for fetching datastream")
});
Ok(Datastream::fetch_blocking(
client,
&self.config_host,
&self.tenant_name,
&self.task_id,
)?)
}
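    /// Returns the schema registry endpoint from the datastream configuration.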
pub fn schema_registry_host(&self) -> &str {
self.datastream().schema_store()
}
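    /// Returns the Kafka configuration derived from the datastream.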
#[cfg(feature = "kafka")]
pub fn kafka_config(&self) -> &KafkaConfig {
&self.kafka_config
}
}
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
impl Default for Dsh {
fn default() -> Self {
let datastream = Arc::new(Datastream::load_local_datastreams().unwrap_or_default());
Self {
task_id: "local_task_id".to_string(),
tenant_name: "local_tenant".to_string(),
config_host: "http://localhost/".to_string(),
datastream,
certificates: None,
#[cfg(feature = "kafka")]
kafka_config: KafkaConfig::default(),
}
}
}
    fn datastreams_json() -> String {
        std::fs::read_to_string("test_resources/valid_datastreams.json").unwrap()
    }
fn datastream() -> Datastream {
serde_json::from_str(&datastreams_json()).unwrap()
}
#[test]
#[serial(env_dependency)]
fn test_get_or_init() {
let properties = Dsh::get();
assert_eq!(properties.client_id(), "local_task_id");
assert_eq!(properties.task_id, "local_task_id");
assert_eq!(properties.tenant_name, "local_tenant");
assert_eq!(
properties.config_host,
"https://pikachu.dsh.marathon.mesos:4443"
);
assert!(properties.certificates.is_none());
}
#[test]
#[serial(env_dependency)]
fn test_reqwest_client_config() {
let properties = Dsh::default();
        // Building the client config should not panic, even without certificates.
        let _ = properties.reqwest_client_config();
}
#[test]
#[serial(env_dependency)]
fn test_client_id() {
let properties = Dsh::default();
assert_eq!(properties.client_id(), "local_task_id");
}
#[test]
#[serial(env_dependency)]
fn test_tenant_name() {
let properties = Dsh::default();
assert_eq!(properties.tenant_name(), "local_tenant");
}
#[test]
#[serial(env_dependency)]
fn test_task_id() {
let properties = Dsh::default();
assert_eq!(properties.task_id(), "local_task_id");
}
#[test]
#[serial(env_dependency)]
fn test_schema_registry_host() {
let properties = Dsh::default();
assert_eq!(
properties.schema_registry_host(),
"http://localhost:8081/apis/ccompat/v7"
);
}
#[tokio::test]
async fn test_fetch_datastream() {
let mut server = mockito::Server::new_async().await;
let tenant = "test-tenant";
let task_id = "test-task-id";
let host = server.url();
let prop = Dsh::new(
host,
task_id.to_string(),
tenant.to_string(),
Datastream::default(),
None,
);
        server
            .mock("GET", "/kafka/config/test-tenant/test-task-id")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(datastreams_json())
            .create_async()
            .await;
let fetched_datastream = prop.fetch_datastream().await.unwrap();
assert_eq!(fetched_datastream, datastream());
}
#[test]
    fn test_fetch_datastream_blocking() {
        let mut server = mockito::Server::new();
        let tenant = "test-tenant";
        let task_id = "test-task-id";
        let host = server.url();
        let prop = Dsh::new(
            host,
            task_id.to_string(),
            tenant.to_string(),
            Datastream::default(),
            None,
        );
        server
            .mock("GET", "/kafka/config/test-tenant/test-task-id")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(datastreams_json())
            .create();
let fetched_datastream = prop.fetch_datastream_blocking().unwrap();
assert_eq!(fetched_datastream, datastream());
}
}