use anyhow::{Context as _, Result};
use redact::Secret;
use regex::RegexSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Top-level client configuration.
///
/// Serialized keys use kebab-case (`mlflow`, `download`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct ClientConfig {
    /// Connection settings for the MLflow server.
    pub mlflow: MlflowConfig,
    /// Settings for artifact downloads.
    pub download: DownloadConfig,
}
/// MLflow connection settings.
///
/// Serialized keys use kebab-case (`urlbase`, `auth`, `no-retry-status-codes`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct MlflowConfig {
    /// Base URL of the MLflow server.
    pub urlbase: String,
    /// Optional basic-auth credentials; `None` when the `auth` key is absent.
    pub auth: Option<BasicAuth>,
    /// HTTP status codes presumably treated as non-retryable by the client (confirm at
    /// call sites). Falls back to `default_no_retry_status_codes()` when absent.
    #[serde(default = "default_no_retry_status_codes")]
    pub no_retry_status_codes: Vec<u16>,
}
/// Serde default for `MlflowConfig::no_retry_status_codes`.
///
/// Yields 400, 401, 403, 404, 409 and 422, in that order.
fn default_no_retry_status_codes() -> Vec<u16> {
    const CODES: [u16; 6] = [400, 401, 403, 404, 409, 422];
    Vec::from(CODES)
}
/// Settings for artifact downloads.
///
/// Serialized keys use kebab-case (`retry`, `tasks`, `blacklist`,
/// `cache-local-artifacts`, `file-size-check`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct DownloadConfig {
    /// Presumably whether failed downloads are retried; confirm against the downloader.
    pub retry: bool,
    /// Task count for downloading; presumably a concurrency limit, so confirm at call sites.
    pub tasks: usize,
    /// Regex patterns of files to exclude; see [`DownloadConfig::is_blacklisted`].
    ///
    /// Optional: when the key is absent an empty blacklist (`Blacklist::default()`,
    /// an empty `RegexSet` that matches nothing) is used instead of rejecting the config.
    #[serde(default)]
    pub blacklist: Blacklist,
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub cache_local_artifacts: bool,
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub file_size_check: bool,
}
/// Compiled set of regex patterns naming files to exclude.
///
/// (De)serializes transparently as a plain list of pattern strings. The default value is an
/// empty set, which matches nothing.
// `rename_all` is intentionally absent: a transparent tuple newtype has no field or variant
// names for it to rename.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Blacklist(
    #[serde(
        deserialize_with = "deserialize_regex_set",
        serialize_with = "serialize_regex_set"
    )]
    pub RegexSet,
);
/// HTTP basic-auth credentials.
///
/// Both values are wrapped in `redact::Secret` so they are redacted in `Debug` output.
/// Serialization goes through `redact::expose_secret`, which writes the underlying
/// plaintext values, so serialized configs contain the credentials in the clear.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct BasicAuth {
    /// User name.
    #[serde(serialize_with = "redact::expose_secret")]
    pub user: Secret<String>,
    /// Password.
    #[serde(serialize_with = "redact::expose_secret")]
    pub password: Secret<String>,
}
impl Blacklist {
pub fn new(patterns: impl IntoIterator<Item = impl AsRef<str>>) -> Result<Self> {
RegexSet::new(patterns)
.context("Cannot make blacklist from provided patterns")
.map(Blacklist)
}
pub fn is_blacklisted(&self, file: impl AsRef<str>) -> bool {
self.0.is_match(file.as_ref())
}
}
impl DownloadConfig {
    /// Returns `true` when `file` is matched by this config's blacklist.
    ///
    /// This delegates to [`Blacklist::is_blacklisted`].
    pub fn is_blacklisted(&self, file: impl AsRef<str>) -> bool {
        let Self { blacklist, .. } = self;
        blacklist.is_blacklisted(file)
    }
}
fn deserialize_regex_set<'de, D>(deserializer: D) -> Result<RegexSet, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error;
let patterns = Vec::<String>::deserialize(deserializer)?;
RegexSet::new(patterns).map_err(|error| {
Error::custom(format!(
"Cannot load regex set from provided patterns: {error}",
))
})
}
/// Serde `serialize_with` helper that writes a [`RegexSet`] as a sequence of its source
/// pattern strings, in the order they were given.
fn serialize_regex_set<S>(set: &RegexSet, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    // The slice iterator is exact-size, so `collect_seq` reports the same length hint as an
    // explicit `serialize_seq(Some(set.len()))` followed by one element per pattern.
    serializer.collect_seq(set.patterns())
}
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Parses a representative JSON config and checks every field, including serde
    /// defaults. It then serializes the config, checks the emitted JSON shape, and parses
    /// the output again to confirm the round trip is lossless.
    #[rstest]
    fn test_config_serde() {
        let json_config = r#"{
            "mlflow": {
                "urlbase": "http://localhost:5000/",
                "auth": {
                    "user": "john",
                    "password": "doe"
                }
            },
            "download": {
                "retry": true,
                "tasks": 37,
                "blacklist": [
                    "kokot.json",
                    ".*bar.*"
                ]
            }
        }"#;
        let config = serde_json::from_str::<ClientConfig>(json_config)
            .expect("BUG: Unable to deserialize json config");
        assert_fixture_values(&config);

        let serialized =
            serde_json::to_string(&config).expect("BUG: Unable to serialize client config");

        // Check the serialized shape: keys follow `rename_all = "kebab-case"`, and
        // credentials are written in plaintext via `redact::expose_secret`.
        let value: serde_json::Value = serde_json::from_str(&serialized)
            .expect("BUG: Serialized client config is not valid json");
        assert_eq!(value["mlflow"]["auth"]["user"], "john");
        assert_eq!(value["mlflow"]["auth"]["password"], "doe");
        assert_eq!(
            value["mlflow"]["no-retry-status-codes"],
            serde_json::json!([400, 401, 403, 404, 409, 422])
        );
        assert_eq!(value["download"]["cache-local-artifacts"], false);
        assert_eq!(value["download"]["file-size-check"], false);
        assert_eq!(
            value["download"]["blacklist"],
            serde_json::json!(["kokot.json", ".*bar.*"])
        );

        // The serialized form must parse back into an equivalent config.
        let reparsed = serde_json::from_str::<ClientConfig>(&serialized)
            .expect("BUG: Unable to deserialize re-serialized client config");
        assert_fixture_values(&reparsed);
    }

    /// Asserts that `config` holds the (non-secret) values of the JSON fixture in
    /// `test_config_serde`, plus the serde defaults for keys the fixture omits.
    fn assert_fixture_values(config: &ClientConfig) {
        assert_eq!(config.mlflow.urlbase, "http://localhost:5000/");
        assert!(config.mlflow.auth.is_some());
        // Not present in the fixture, so filled in by `default_no_retry_status_codes`.
        assert_eq!(
            config.mlflow.no_retry_status_codes,
            vec![400, 401, 403, 404, 409, 422]
        );

        assert!(config.download.retry);
        assert_eq!(config.download.tasks, 37);
        assert!(!config.download.cache_local_artifacts);
        assert!(!config.download.file_size_check);

        let patterns: Vec<&str> = config
            .download
            .blacklist
            .0
            .patterns()
            .iter()
            .map(String::as_str)
            .collect();
        assert_eq!(patterns, ["kokot.json", ".*bar.*"]);
        assert!(config.download.is_blacklisted("kokot.json"));
        assert!(config.download.is_blacklisted("foobarbaz"));
        assert!(!config.download.is_blacklisted("model.pkl"));
    }
}