trs-mlflow 0.7.0

This crate contains an asynchronous client that implements the 2.0 REST API of the MLflow server.
Documentation
//! This module contains client configuration.
use anyhow::{Context as _, Result};
use redact::Secret;
use regex::RegexSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// Top-level configuration of [`crate::Client`].
///
/// Groups the settings needed to reach the MLflow tracking server
/// ([`MlflowConfig`]) with the settings that control artifact download
/// ([`DownloadConfig`]). Keys are spelled in kebab-case when (de)serialized.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct ClientConfig {
    /// Connection and authentication settings of the MLflow server.
    pub mlflow: MlflowConfig,
    /// Settings that control how artifacts are downloaded.
    pub download: DownloadConfig,
}

/// Settings used to reach and authenticate against the MLflow tracking server.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct MlflowConfig {
    /// Base URL of the MLflow tracking server API.
    ///
    /// Example: `http://localhost:5000/api`
    pub urlbase: String,
    /// Optional basic-auth credentials used to authenticate against MLflow.
    ///
    /// `None` means that no credentials are configured.
    pub auth: Option<BasicAuth>,
    /// HTTP status codes that must not be retried. A request failing with one
    /// of these codes returns an error immediately instead of being retried
    /// with exponential backoff.
    ///
    /// When the key is absent from the input, this defaults to the common
    /// non-transient 4xx codes: 400, 401, 403, 404, 409, 422.
    #[serde(default = "default_no_retry_status_codes")]
    pub no_retry_status_codes: Vec<u16>,
}

/// Serde default for the `no-retry-status-codes` key of [`MlflowConfig`].
///
/// These are the common non-transient client errors: Bad Request,
/// Unauthorized, Forbidden, Not Found, Conflict and Unprocessable Entity.
fn default_no_retry_status_codes() -> Vec<u16> {
    Vec::from([400, 401, 403, 404, 409, 422])
}

/// Configuration which is related to artifact download.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct DownloadConfig {
    /// Enables retrying for artifact download.
    ///
    /// Note that retry is enabled only for connection related errors.
    pub retry: bool,
    /// Determines how many download tasks can run at once.
    pub tasks: usize,
    /// List of regexes which specifies which files should not be downloaded
    /// from mlflow.
    ///
    /// Optional: when the key is missing, an empty blacklist is used, so no
    /// file is excluded from download.
    // `Blacklist` already implements `Default` (an empty `RegexSet`), so there
    // is no reason to make this key mandatory.
    #[serde(default)]
    pub blacklist: Blacklist,
    /// This option will disable download of artifacts which are already stored
    /// on disk.
    ///
    /// Note that there is no check that would ensure that remote and local files
    /// are identical.
    ///
    /// The main purpose of this option is to speed up download during testing phase.
    /// Defaults to `false` when the key is missing.
    #[serde(default)]
    pub cache_local_artifacts: bool,
    /// Enables comparing the downloaded file's size with the expected size.
    /// If the two do not match, the download will be retried.
    /// Defaults to `false` when the key is missing.
    #[serde(default)]
    pub file_size_check: bool,
}

/// List of regexes which specifies which files should not be downloaded
/// from mlflow.
///
/// (De)serialized transparently as a plain list of pattern strings, e.g.
/// `["model.json", ".*bar.*"]`. Matching uses [`RegexSet::is_match`], so
/// patterns are not anchored: a file is blacklisted when any pattern matches
/// anywhere in the given file string. Use `^`/`$` to require a full match.
// `transparent` alone is enough here: `rename_all` has no effect on a tuple
// struct, which has no named fields to rename.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Blacklist(
    #[serde(
        serialize_with = "serialize_regex_set",
        deserialize_with = "deserialize_regex_set"
    )]
    pub RegexSet,
);

/// Credentials for MLflow basic authentication.
///
/// See more here: <https://mlflow.org/docs/latest/auth/index.html>
///
/// Both values are stored as [`Secret`]s. Serialization deliberately exposes
/// them in plain text (through `redact::expose_secret`) so that a config can be
/// round-tripped. A serialized config therefore contains the credentials and
/// must be stored and logged with care.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct BasicAuth {
    /// User name used for MLflow basic authentication.
    #[serde(serialize_with = "redact::expose_secret")]
    pub user: Secret<String>,
    /// Password belonging to [`BasicAuth::user`].
    #[serde(serialize_with = "redact::expose_secret")]
    pub password: Secret<String>,
}

impl Blacklist {
    /// Creates a new blacklist from provided patterns.
    pub fn new(patterns: impl IntoIterator<Item = impl AsRef<str>>) -> Result<Self> {
        RegexSet::new(patterns)
            .context("Cannot make blacklist from provided patterns")
            .map(Blacklist)
    }

    /// Returns true if file is blacklisted for download.
    pub fn is_blacklisted(&self, file: impl AsRef<str>) -> bool {
        self.0.is_match(file.as_ref())
    }
}

impl DownloadConfig {
    /// Checks whether `file` should be skipped during download.
    ///
    /// This is a shortcut for [`Blacklist::is_blacklisted`] on the configured
    /// blacklist.
    pub fn is_blacklisted(&self, file: impl AsRef<str>) -> bool {
        let blacklist = &self.blacklist;
        blacklist.is_blacklisted(file)
    }
}

fn deserialize_regex_set<'de, D>(deserializer: D) -> Result<RegexSet, D::Error>
where
    D: Deserializer<'de>,
{
    use serde::de::Error;

    let patterns = Vec::<String>::deserialize(deserializer)?;

    RegexSet::new(patterns).map_err(|error| {
        Error::custom(format!(
            "Cannot load regex set from provided patterns: {error}",
        ))
    })
}

/// Serializes a [`RegexSet`] as the sequence of its source patterns, in the
/// order in which they were compiled.
///
/// This is the inverse of `deserialize_regex_set`.
fn serialize_regex_set<S>(set: &RegexSet, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    // `patterns()` is a slice, so its iterator reports an exact length and the
    // sequence is started with `Some(len)`, just like a manual seq loop.
    serializer.collect_seq(set.patterns())
}

#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Sample configuration. The optional keys (`no-retry-status-codes`,
    /// `cache-local-artifacts`, `file-size-check`) are deliberately omitted so
    /// that their serde defaults are exercised as well.
    const JSON_CONFIG: &str = r#"{
        "mlflow": {
            "urlbase": "http://localhost:5000/",
            "auth": {
                "user": "john",
                "password": "doe"
            }
        },
        "download": {
            "retry": true,
            "tasks": 37,
            "blacklist": [
                "model.json",
                ".*bar.*"
            ]
        }
    }"#;

    /// Asserts that `config` holds exactly the values described by
    /// [`JSON_CONFIG`], including the defaults of the omitted keys.
    fn assert_sample_config(config: &ClientConfig) {
        assert_eq!(config.mlflow.urlbase, "http://localhost:5000/");

        let auth = config
            .mlflow
            .auth
            .as_ref()
            .expect("BUG: auth section was not deserialized");
        assert_eq!(auth.user.expose_secret(), "john");
        assert_eq!(auth.password.expose_secret(), "doe");

        assert_eq!(
            config.mlflow.no_retry_status_codes,
            vec![400_u16, 401, 403, 404, 409, 422]
        );

        assert!(config.download.retry);
        assert_eq!(config.download.tasks, 37);
        assert!(!config.download.cache_local_artifacts);
        assert!(!config.download.file_size_check);

        assert_eq!(
            config.download.blacklist.0.patterns(),
            ["model.json", ".*bar.*"]
        );
        assert!(config.download.is_blacklisted("model.json"));
        assert!(config.download.is_blacklisted("foobar.txt"));
        assert!(!config.download.is_blacklisted("weights.bin"));
    }

    #[rstest]
    fn test_config_serde() {
        let config = serde_json::from_str::<ClientConfig>(JSON_CONFIG)
            .expect("BUG: Unable to deserialize json config");
        assert_sample_config(&config);

        // A serialize/deserialize round trip must not lose or alter anything,
        // including the secrets and the blacklist patterns.
        let serialized =
            serde_json::to_string(&config).expect("BUG: Unable to serialize client config");
        let round_tripped = serde_json::from_str::<ClientConfig>(&serialized)
            .expect("BUG: Unable to deserialize serialized client config");
        assert_sample_config(&round_tripped);
    }

    #[rstest]
    fn test_invalid_blacklist_pattern_is_rejected() {
        // "(" is an unclosed group and therefore not a valid regex.
        let error = serde_json::from_str::<Blacklist>(r#"["("]"#)
            .expect_err("BUG: invalid regex must be rejected");
        assert!(error
            .to_string()
            .contains("Cannot load regex set from provided patterns"));

        assert!(Blacklist::new(["("]).is_err());
    }
}