trs_mlflow/
config.rs

1//! This module contains client configuration.
2use anyhow::{Context as _, Result};
3use redact::Secret;
4use regex::RegexSet;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6
7/// Contains all configuration options for [`crate::Client`].
8#[derive(Clone, Debug, Serialize, Deserialize)]
9#[serde(rename_all = "kebab-case")]
10pub struct ClientConfig {
11    /// Mlflow related configuration.
12    pub mlflow: MlflowConfig,
13    /// Configuration which is related to artifact download.
14    pub download: DownloadConfig,
15}
16
17/// Mlflow related configuration.
18#[derive(Clone, Debug, Serialize, Deserialize)]
19#[serde(rename_all = "kebab-case")]
20pub struct MlflowConfig {
21    /// Url to mlflow tracking server API.
22    ///
23    /// Example: 'http://localhost:5000/api'
24    pub urlbase: String,
25    /// Basic configuration which can be used to mlflow authentication.
26    pub auth: Option<BasicAuth>,
27}
28
29/// Configuration which is related to artifact download.
30#[derive(Clone, Debug, Serialize, Deserialize)]
31#[serde(rename_all = "kebab-case")]
32pub struct DownloadConfig {
33    /// Enables retrying for artifact download.
34    ///
35    /// Note that retry is enabled only for connection related errors.
36    pub retry: bool,
37    /// Determines how many download task can run at once.
38    pub tasks: usize,
39    /// List of regexes which specifies which files should not be downloaded
40    /// from mlflow.
41    pub blacklist: Blacklist,
42    /// This option will disable download of artifacts which are already stored
43    /// on disk.
44    ///
45    /// Note that there is no check that would ensure that remote and local files
46    /// are identical.
47    ///
48    /// The main purpose of this method is to speed up download during testing phase.
49    #[serde(default)]
50    pub cache_local_artifacts: bool,
51    /// Flag determines whether the downloaded file's size will be compared to the expected size.
52    /// If the two do not match, the download will be retried.
53    #[serde(default)]
54    pub file_size_check: bool,
55}
56
57/// List of regexes which specifies which files should not be downloaded
58/// from mlflow.
59#[derive(Clone, Debug, Default, Serialize, Deserialize)]
60#[serde(transparent, rename_all = "kebab-case")]
61pub struct Blacklist(
62    #[serde(
63        serialize_with = "serialize_regex_set",
64        deserialize_with = "deserialize_regex_set"
65    )]
66    pub RegexSet,
67);
68
69/// Basic configuration which can be used to mlflow authentication.
70///
71/// See more here: <https://mlflow.org/docs/latest/auth/index.html>
72#[derive(Clone, Debug, Serialize, Deserialize)]
73#[serde(rename_all = "kebab-case")]
74pub struct BasicAuth {
75    #[serde(serialize_with = "redact::expose_secret")]
76    pub user: Secret<String>,
77    #[serde(serialize_with = "redact::expose_secret")]
78    pub password: Secret<String>,
79}
80
81impl Blacklist {
82    /// Creates a new blacklist from provided patterns.
83    pub fn new(patterns: impl IntoIterator<Item = impl AsRef<str>>) -> Result<Self> {
84        RegexSet::new(patterns)
85            .context("Cannot make blacklist from provided patterns")
86            .map(Blacklist)
87    }
88
89    /// Returns true if file is blacklisted for download.
90    pub fn is_blacklisted(&self, file: impl AsRef<str>) -> bool {
91        self.0.is_match(file.as_ref())
92    }
93}
94
95impl DownloadConfig {
96    /// Returns true if file is blacklisted for download.
97    pub fn is_blacklisted(&self, file: impl AsRef<str>) -> bool {
98        self.blacklist.is_blacklisted(file)
99    }
100}
101
102fn deserialize_regex_set<'de, D>(deserializer: D) -> Result<RegexSet, D::Error>
103where
104    D: Deserializer<'de>,
105{
106    use serde::de::Error;
107
108    let patterns = Vec::<String>::deserialize(deserializer)?;
109
110    RegexSet::new(patterns).map_err(|error| {
111        Error::custom(format!(
112            "Cannot load regex set from provided patterns: {error}",
113        ))
114    })
115}
116
117fn serialize_regex_set<S>(set: &RegexSet, serializer: S) -> Result<S::Ok, S::Error>
118where
119    S: Serializer,
120{
121    use serde::ser::SerializeSeq;
122
123    let mut array = serializer.serialize_seq(Some(set.len()))?;
124
125    for pattern in set.patterns() {
126        array.serialize_element(&pattern)?;
127    }
128
129    array.end()
130}
131
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Deserializes a sample JSON config, checks every parsed value and then
    /// verifies that a serialize -> deserialize round trip preserves them.
    #[rstest]
    fn test_config_serde() {
        let json_config = r#"{
            "mlflow": {
                "urlbase": "http://localhost:5000/",
                "auth": {
                    "user": "john",
                    "password": "doe"
                }
            },
            "download": {
                "retry": true,
                "tasks": 37,
                "blacklist": [
                    "kokot.json",
                    ".*bar.*"
                ]
            }
        }"#;

        let config = serde_json::from_str::<ClientConfig>(json_config)
            .expect("BUG: Unable to deserialize json config");

        // Values present in the input must be parsed as written.
        assert_eq!(config.mlflow.urlbase, "http://localhost:5000/");
        let auth = config
            .mlflow
            .auth
            .as_ref()
            .expect("BUG: Auth section was not deserialized");
        assert_eq!(auth.user.expose_secret().as_str(), "john");
        assert_eq!(auth.password.expose_secret().as_str(), "doe");
        assert!(config.download.retry);
        assert_eq!(config.download.tasks, 37);

        // Keys missing from the input fall back to `#[serde(default)]`.
        assert!(!config.download.cache_local_artifacts);
        assert!(!config.download.file_size_check);

        // Blacklist patterns must be compiled and actually match.
        assert!(config.download.is_blacklisted("kokot.json"));
        assert!(config.download.is_blacklisted("foo/bar/baz"));
        assert!(!config.download.is_blacklisted("model.pkl"));

        let serialized =
            serde_json::to_string(&config).expect("BUG: Unable to serialize client config");

        // The serialized form must be readable again and carry the same data.
        let restored = serde_json::from_str::<ClientConfig>(&serialized)
            .expect("BUG: Unable to deserialize serialized client config");

        assert_eq!(restored.mlflow.urlbase, config.mlflow.urlbase);
        let restored_auth = restored
            .mlflow
            .auth
            .as_ref()
            .expect("BUG: Auth section was lost during round trip");
        assert_eq!(restored_auth.user.expose_secret().as_str(), "john");
        assert_eq!(restored_auth.password.expose_secret().as_str(), "doe");
        assert_eq!(restored.download.retry, config.download.retry);
        assert_eq!(restored.download.tasks, config.download.tasks);
        assert_eq!(
            restored.download.blacklist.0.patterns(),
            config.download.blacklist.0.patterns()
        );
        assert_eq!(
            restored.download.cache_local_artifacts,
            config.download.cache_local_artifacts
        );
        assert_eq!(
            restored.download.file_size_check,
            config.download.file_size_check
        );
    }
}