1use anyhow::{Context as _, Result};
3use redact::Secret;
4use regex::RegexSet;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6
7#[derive(Clone, Debug, Serialize, Deserialize)]
9#[serde(rename_all = "kebab-case")]
10pub struct ClientConfig {
11 pub mlflow: MlflowConfig,
13 pub download: DownloadConfig,
15}
16
17#[derive(Clone, Debug, Serialize, Deserialize)]
19#[serde(rename_all = "kebab-case")]
20pub struct MlflowConfig {
21 pub urlbase: String,
25 pub auth: Option<BasicAuth>,
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
31#[serde(rename_all = "kebab-case")]
32pub struct DownloadConfig {
33 pub retry: bool,
37 pub tasks: usize,
39 pub blacklist: Blacklist,
42 #[serde(default)]
50 pub cache_local_artifacts: bool,
51 #[serde(default)]
54 pub file_size_check: bool,
55}
56
57#[derive(Clone, Debug, Default, Serialize, Deserialize)]
60#[serde(transparent, rename_all = "kebab-case")]
61pub struct Blacklist(
62 #[serde(
63 serialize_with = "serialize_regex_set",
64 deserialize_with = "deserialize_regex_set"
65 )]
66 pub RegexSet,
67);
68
69#[derive(Clone, Debug, Serialize, Deserialize)]
73#[serde(rename_all = "kebab-case")]
74pub struct BasicAuth {
75 #[serde(serialize_with = "redact::expose_secret")]
76 pub user: Secret<String>,
77 #[serde(serialize_with = "redact::expose_secret")]
78 pub password: Secret<String>,
79}
80
81impl Blacklist {
82 pub fn new(patterns: impl IntoIterator<Item = impl AsRef<str>>) -> Result<Self> {
84 RegexSet::new(patterns)
85 .context("Cannot make blacklist from provided patterns")
86 .map(Blacklist)
87 }
88
89 pub fn is_blacklisted(&self, file: impl AsRef<str>) -> bool {
91 self.0.is_match(file.as_ref())
92 }
93}
94
95impl DownloadConfig {
96 pub fn is_blacklisted(&self, file: impl AsRef<str>) -> bool {
98 self.blacklist.is_blacklisted(file)
99 }
100}
101
102fn deserialize_regex_set<'de, D>(deserializer: D) -> Result<RegexSet, D::Error>
103where
104 D: Deserializer<'de>,
105{
106 use serde::de::Error;
107
108 let patterns = Vec::<String>::deserialize(deserializer)?;
109
110 RegexSet::new(patterns).map_err(|error| {
111 Error::custom(format!(
112 "Cannot load regex set from provided patterns: {error}",
113 ))
114 })
115}
116
117fn serialize_regex_set<S>(set: &RegexSet, serializer: S) -> Result<S::Ok, S::Error>
118where
119 S: Serializer,
120{
121 use serde::ser::SerializeSeq;
122
123 let mut array = serializer.serialize_seq(Some(set.len()))?;
124
125 for pattern in set.patterns() {
126 array.serialize_element(&pattern)?;
127 }
128
129 array.end()
130}
131
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Parses a representative JSON config and checks the parsed values. It then checks
    /// the serialized form and that serialize -> deserialize -> serialize is stable.
    #[rstest]
    fn test_config_serde() {
        let json_config = r#"{
            "mlflow": {
                "urlbase": "http://localhost:5000/",
                "auth": {
                    "user": "john",
                    "password": "doe"
                }
            },
            "download": {
                "retry": true,
                "tasks": 37,
                "blacklist": [
                    "kokot.json",
                    ".*bar.*"
                ]
            }
        }"#;

        let config = serde_json::from_str::<ClientConfig>(json_config)
            .expect("BUG: Unable to deserialize json config");

        // Parsed values.
        assert_eq!(config.mlflow.urlbase, "http://localhost:5000/");
        assert!(config.mlflow.auth.is_some());
        assert!(config.download.retry);
        assert_eq!(config.download.tasks, 37);
        // Keys omitted from the JSON fall back to their `#[serde(default)]` value.
        assert!(!config.download.cache_local_artifacts);
        assert!(!config.download.file_size_check);

        // Blacklist behaviour.
        assert!(config.download.is_blacklisted("kokot.json"));
        assert!(config.download.is_blacklisted("foobar.txt"));
        assert!(!config.download.is_blacklisted("model.pkl"));

        // Serialized shape: kebab-case keys, patterns as a list, secrets exposed.
        let serialized =
            serde_json::to_string(&config).expect("BUG: Unable to serialize client config");
        let value: serde_json::Value =
            serde_json::from_str(&serialized).expect("BUG: Serialized config is not valid JSON");
        assert_eq!(value["mlflow"]["auth"]["user"], "john");
        assert_eq!(value["mlflow"]["auth"]["password"], "doe");
        assert_eq!(
            value["download"]["blacklist"],
            serde_json::json!(["kokot.json", ".*bar.*"])
        );
        assert_eq!(value["download"]["cache-local-artifacts"], false);
        assert_eq!(value["download"]["file-size-check"], false);

        // Round trip must be stable.
        let reparsed = serde_json::from_str::<ClientConfig>(&serialized)
            .expect("BUG: Unable to deserialize serialized client config");
        let reserialized = serde_json::to_string(&reparsed)
            .expect("BUG: Unable to serialize round-tripped client config");
        assert_eq!(serialized, reserialized);
    }
}