//! ml-cellar 0.2.0
//!
//! CLI of an ML model registry for minimal MLOps.
//!
//! Documentation: rack configuration (`config.toml`) loading and validation.
use glob::Pattern;
use serde::{Deserialize, Serialize};
use std::env;
use std::path::{Path, PathBuf};

use chrono::NaiveDate;

/// Configuration for a rack (model family/algorithm), parsed from the rack's `config.toml`.
///
/// Contains all settings for managing a rack including artifact rules,
/// project versioning schemes, and documentation generation.
///
/// Because of `#[serde(default)]`, any section missing from the TOML falls back to
/// its `Default` value, so an empty `config.toml` is still a valid configuration.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct RackConfig {
    /// Basic rack identity (the `rack` section).
    pub rack: RackInfoConfig,
    /// Rules for which files an ML-bin must or may contain (the `artifact` section).
    pub artifact: ArtifactConfig,
    /// Per-project version-format rules (the `project` section).
    pub project: ProjectConfig,
    /// Documentation template/result settings (the `document` section).
    pub document: DocumentConfig,
}

/// Basic information about the rack (the `rack` section of `config.toml`).
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct RackInfoConfig {
    /// The name of the rack (e.g., "vit-l", "llm-model").
    /// Defaults to an empty string when omitted from the TOML.
    pub name: String,
}

/// Configuration for artifact file management (the `artifact` section of `config.toml`).
/// Defines which files are required and which are optional in an ML-bin.
///
/// Both lists default to empty when omitted from the TOML.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct ArtifactConfig {
    /// List of required file patterns (supports globs and directory paths ending with `/` or `\`).
    /// Files matching these patterns must exist in every ML-bin.
    /// Checked against relative file paths by `ArtifactConfig::is_required_file`.
    pub required_files: Vec<String>,
    /// List of optional file patterns (supports globs).
    /// Files matching these patterns are allowed but not required.
    /// Checked against relative file paths by `ArtifactConfig::is_optional_file`.
    pub optional_files: Vec<String>,
}

impl ArtifactConfig {
    /// Checks if a file matches any required file pattern.
    ///
    /// Supports both glob patterns and directory paths (ending with `/` or `\`).
    /// For directory patterns, any file under that directory is considered a match;
    /// a pattern consisting only of separators (e.g. `/`) matches every file.
    ///
    /// # Arguments
    ///
    /// - `filename` - The relative path of the file to check
    ///
    /// # Returns
    ///
    /// `true` if the file matches any required pattern, `false` otherwise.
    /// Patterns that are not valid glob syntax never match.
    pub fn is_required_file(&self, filename: &Path) -> bool {
        Self::matches_any(&self.required_files, filename)
    }

    /// Checks if a file matches any optional file pattern.
    ///
    /// Uses the same matching rules as [`ArtifactConfig::is_required_file`]: glob
    /// patterns, plus directory paths (ending with `/` or `\`) that match every file
    /// beneath that directory.
    ///
    /// # Arguments
    ///
    /// - `filename` - The relative path of the file to check
    ///
    /// # Returns
    ///
    /// `true` if the file matches any optional pattern, `false` otherwise.
    /// Patterns that are not valid glob syntax never match.
    pub fn is_optional_file(&self, filename: &Path) -> bool {
        Self::matches_any(&self.optional_files, filename)
    }

    /// Returns `true` if `filename` matches at least one entry of `patterns`.
    fn matches_any(patterns: &[String], filename: &Path) -> bool {
        patterns
            .iter()
            .any(|pattern| Self::matches_pattern(pattern, filename))
    }

    /// Matches one pattern: a directory prefix if it ends with a separator, a glob otherwise.
    fn matches_pattern(pattern: &str, filename: &Path) -> bool {
        let is_directory = pattern.ends_with('/') || pattern.ends_with('\\');
        if is_directory {
            let directory = pattern.trim_end_matches(&['/', '\\'][..]);
            // A pattern made only of separators (e.g. "/") covers the whole ML-bin.
            directory.is_empty() || filename.starts_with(Path::new(directory))
        } else {
            // Invalid glob syntax is treated as "no match" rather than as an error.
            Pattern::new(pattern)
                .map(|p| p.matches_path(filename))
                .unwrap_or(false)
        }
    }
}

/// Configuration for project versioning schemes (the `project` section of `config.toml`).
///
/// Allows defining different version formats (YYYYMMDD, X, X.Y, X.Y.Z) for different projects.
/// Projects not explicitly configured accept any version format.
///
/// If a project name appears in more than one list, `ProjectConfig::is_valid_version`
/// uses the first match in the order `version_yyyymmdd`, `version_x`, `version_x_y`,
/// `version_x_y_z`. Every list defaults to `None` when omitted from the TOML.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct ProjectConfig {
    /// List of projects.
    /// NOTE(review): not consulted by `is_valid_version`; presumably used elsewhere to
    /// enumerate the rack's projects — confirm.
    pub project: Option<Vec<String>>,
    /// List of projects that use YYYYMMDD versioning format (e.g., "20240131").
    pub version_yyyymmdd: Option<Vec<String>>,
    /// List of projects that use single integer versioning (e.g., "1", "2").
    pub version_x: Option<Vec<String>>,
    /// List of projects that use two-part versioning (e.g., "1.0", "2.3").
    pub version_x_y: Option<Vec<String>>,
    /// List of projects that use three-part semantic versioning (e.g., "1.0.0", "2.1.3").
    pub version_x_y_z: Option<Vec<String>>,
}

impl ProjectConfig {
    /// Validates whether a version string follows the configured format for a project.
    ///
    /// The project is looked up in the version lists in this order, and the first list
    /// containing it decides the format:
    ///
    /// 1. `version_yyyymmdd` - an eight-digit calendar date, e.g. `"20240131"`
    /// 2. `version_x` - one number, e.g. `"3"`
    /// 3. `version_x_y` - two dot-separated numbers, e.g. `"1.2"`
    /// 4. `version_x_y_z` - three dot-separated numbers, e.g. `"1.2.3"`
    ///
    /// Every number must be a non-empty run of ASCII digits, so signed components
    /// (`"+1"`, `"-1"`, `"1.-2"`) and empty components (`"1."`, `".5"`) are rejected.
    ///
    /// # Arguments
    ///
    /// - `project_name` - The name of the project
    /// - `version` - The version string to validate
    ///
    /// # Returns
    ///
    /// `true` if the version is valid for the project, `false` otherwise.
    /// Projects without explicit configuration accept any version format.
    pub fn is_valid_version(&self, project_name: &str, version: &str) -> bool {
        if self.is_yyyymmdd_project(project_name) {
            is_yyyymmdd_format(version)
        } else if self.is_x_project(project_name) {
            Self::is_dotted_numeric(version, 1)
        } else if self.is_x_y_project(project_name) {
            Self::is_dotted_numeric(version, 2)
        } else if self.is_x_y_z_project(project_name) {
            Self::is_dotted_numeric(version, 3)
        } else {
            true
        }
    }

    /// Returns `true` if `version` has exactly `parts` dot-separated components and each
    /// component is a non-empty string of ASCII digits (no sign, no whitespace).
    fn is_dotted_numeric(version: &str, parts: usize) -> bool {
        let is_number = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
        version.split('.').count() == parts && version.split('.').all(is_number)
    }

    /// Returns `true` if the (optional) list of project names contains `project_name`.
    fn list_contains(names: Option<&[String]>, project_name: &str) -> bool {
        names.is_some_and(|names| names.iter().any(|name| name == project_name))
    }

    /// Checks if a project uses YYYYMMDD versioning format.
    fn is_yyyymmdd_project(&self, project_name: &str) -> bool {
        Self::list_contains(self.version_yyyymmdd.as_deref(), project_name)
    }

    /// Checks if a project uses single integer (X) versioning format.
    fn is_x_project(&self, project_name: &str) -> bool {
        Self::list_contains(self.version_x.as_deref(), project_name)
    }

    /// Checks if a project uses two-part (X.Y) versioning format.
    fn is_x_y_project(&self, project_name: &str) -> bool {
        Self::list_contains(self.version_x_y.as_deref(), project_name)
    }

    /// Checks if a project uses three-part semantic (X.Y.Z) versioning format.
    fn is_x_y_z_project(&self, project_name: &str) -> bool {
        Self::list_contains(self.version_x_y_z.as_deref(), project_name)
    }
}

/// Configuration for documentation generation (the `document` section of `config.toml`).
///
/// Specifies template and result files used for automatically generating
/// documentation for ML-bins. Both fields default to `None` when omitted.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct DocumentConfig {
    /// The template file for documentation generation, stored at the rack root.
    /// If you set `template_file = "template.md"` in a rack's `config.toml`,
    /// the structure is as follows:
    ///
    /// - model_registry_repository/
    ///   - {rack_name}/
    ///     - template.md
    ///     - config.toml
    ///     - 0.1/
    ///     - 0.2/
    ///
    pub template_file: Option<String>,
    /// The result file for documentation generation in each version directory.
    /// For now, only JSON format is supported.
    /// If you set `result_file = "result.json"` in a rack's `config.toml`,
    /// the structure is as follows:
    ///
    /// - model_registry_repository/
    ///   - {rack_name}/
    ///     - template.md
    ///     - config.toml
    ///     - 0.1/
    ///       - result.json
    ///     - 0.2/
    ///       - result.json
    ///
    pub result_file: Option<String>,
}

/// Loads the rack configuration by searching for `config.toml` in the directory tree.
///
/// This function searches upward from the given path through parent directories
/// until it finds a `config.toml` file. This allows ml-cellar commands to be
/// run from any subdirectory within a rack.
///
/// The starting directory is first made absolute. Relative paths are resolved against the
/// current working directory, and `.`/`..` components are then removed lexically. Because
/// of this, a path such as `../other_rack/0.1` walks up through the real ancestors of
/// `other_rack`. Without it, the walk would pass through `<cwd>/..` and then re-enter the
/// current directory.
///
/// # Arguments
///
/// - `path` - The starting path to search from (can be a file or directory).
///   If it is not an existing directory, the search starts at its parent.
///   A path with no parent (such as an empty path) starts at the current working directory.
///
/// # Returns
///
/// A tuple `(RackConfig, PathBuf)` where:
/// - `RackConfig`: The parsed configuration from `config.toml`
/// - `PathBuf`: The absolute, lexically normalized path to the directory containing
///   `config.toml` (this is the root of the rack)
///
/// # Panics
///
/// Panics if:
/// - The current working directory cannot be determined (only for a relative `path`)
/// - `config.toml` is not found in any parent directory up to the filesystem root
/// - The configuration file cannot be read
/// - The TOML content cannot be parsed
///
pub fn load_rack_config(path: &Path) -> (RackConfig, PathBuf) {
    // Set directory to start searching
    let start_dir = if path.is_dir() {
        path
    } else {
        path.parent().unwrap_or_else(|| Path::new(""))
    };

    // Convert to absolute path
    let absolute_dir = if start_dir.is_absolute() {
        start_dir.to_path_buf()
    } else {
        env::current_dir()
            .expect("failed to determine the current working directory")
            .join(start_dir)
    };

    // Strip `.`/`..` so that `parent()` below walks through real ancestors only.
    let mut search_dir = normalize_lexically(&absolute_dir);

    // Search for config.toml in the directory and its parents
    loop {
        let candidate = search_dir.join("config.toml");
        if candidate.is_file() {
            // Found config.toml, read and parse it
            log::info!("Loading config from {:?}", candidate);
            let config_content = std::fs::read_to_string(&candidate)
                .unwrap_or_else(|e| panic!("failed to read {:?}: {}", candidate, e));
            let config: RackConfig = toml::from_str(&config_content)
                .unwrap_or_else(|e| panic!("failed to parse {:?}: {}", candidate, e));
            return (config, search_dir);
        }

        match search_dir.parent() {
            Some(parent) => search_dir = parent.to_path_buf(),
            None => {
                log::error!(
                    "config.toml not found; reached filesystem root at {:?}\n\
                     Please ensure that config.toml exists in the directory tree starting from the directory of the provided path.",
                    search_dir
                );
                panic!(
                    "config.toml not found in the directory tree starting from {:?}",
                    path
                );
            }
        }
    }
}

/// Lexically removes `.` and `..` components from `path` without touching the filesystem.
///
/// `..` drops the preceding component; at the root it is ignored, mirroring `/..` == `/`.
/// Symlinks are not resolved, so the result can differ from `std::fs::canonicalize`.
/// Intended for absolute paths; a leading `..` in a relative path would be discarded.
fn normalize_lexically(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            std::path::Component::CurDir => {}
            std::path::Component::ParentDir => {
                normalized.pop();
            }
            other => normalized.push(other.as_os_str()),
        }
    }
    normalized
}

/// Returns `true` if `s` is an eight-digit `YYYYMMDD` string naming a real calendar date.
///
/// Both the shape (exactly eight ASCII digits) and the date itself are checked.
/// For example, `"20240229"` is accepted, while `"20230229"` and `"20241301"` are rejected.
fn is_yyyymmdd_format(s: &str) -> bool {
    let well_formed = s.len() == 8 && s.bytes().all(|byte| byte.is_ascii_digit());
    if !well_formed {
        return false;
    }

    // Byte-index slicing is safe: every byte was just verified to be ASCII.
    let fields = (
        s[..4].parse::<i32>().ok(),
        s[4..6].parse::<u32>().ok(),
        s[6..].parse::<u32>().ok(),
    );
    match fields {
        (Some(year), Some(month), Some(day)) => {
            NaiveDate::from_ymd_opt(year, month, day).is_some()
        }
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use toml::to_string_pretty;

    /// Writes a `config.toml` containing the default `RackConfig` into `dir`.
    fn write_default_config(dir: &Path) {
        let toml_str = to_string_pretty(&RackConfig::default())
            .expect("failed to serialize rack config to TOML");
        std::fs::write(dir.join("config.toml"), toml_str).expect("failed to write config.toml");
    }

    #[test]
    fn test_load_rack_config() {
        let temp = TempDir::new().unwrap();
        let root_directory = temp.path();
        write_default_config(root_directory);

        let (config, config_dir) = load_rack_config(root_directory);
        assert_eq!(config_dir, root_directory);
        assert_eq!(config.rack.name, "");
    }

    #[test]
    fn test_load_rack_config_from_subdirectory() {
        let temp = TempDir::new().unwrap();
        let root_directory = temp.path();
        write_default_config(root_directory);

        // The search must walk up from a version directory to the rack root.
        let version_dir = root_directory.join("0.1");
        std::fs::create_dir(&version_dir).expect("failed to create version dir");

        let (_, config_dir) = load_rack_config(&version_dir);
        assert_eq!(config_dir, root_directory);
    }

    #[test]
    fn test_load_rack_config_from_file_path() {
        let temp = TempDir::new().unwrap();
        let root_directory = temp.path();
        write_default_config(root_directory);

        // A file path starts the search from its parent directory.
        let version_dir = root_directory.join("0.1");
        std::fs::create_dir(&version_dir).expect("failed to create version dir");
        let model_file = version_dir.join("model.onnx");
        std::fs::write(&model_file, b"").expect("failed to write model file");

        let (_, config_dir) = load_rack_config(&model_file);
        assert_eq!(config_dir, root_directory);
    }

    #[test]
    fn test_is_valid_version() {
        let config = ProjectConfig {
            version_yyyymmdd: Some(vec!["daily".to_string()]),
            version_x: Some(vec!["major".to_string()]),
            version_x_y: Some(vec!["minor".to_string()]),
            version_x_y_z: Some(vec!["semver".to_string()]),
            ..Default::default()
        };

        assert!(config.is_valid_version("daily", "20240229"));
        assert!(!config.is_valid_version("daily", "20230229"));
        assert!(config.is_valid_version("major", "3"));
        assert!(!config.is_valid_version("major", "3.0"));
        assert!(config.is_valid_version("minor", "1.2"));
        assert!(!config.is_valid_version("minor", "1.2.3"));
        assert!(config.is_valid_version("semver", "1.2.3"));
        assert!(!config.is_valid_version("semver", "1.2"));
        // Projects that are not listed anywhere accept any version string.
        assert!(config.is_valid_version("unconfigured", "anything-goes"));
    }

    #[test]
    fn test_artifact_patterns() {
        let artifact = ArtifactConfig {
            required_files: vec!["model.onnx".to_string(), "weights/".to_string()],
            optional_files: vec!["*.md".to_string()],
        };

        assert!(artifact.is_required_file(Path::new("model.onnx")));
        assert!(artifact.is_required_file(Path::new("weights/layer0.bin")));
        assert!(!artifact.is_required_file(Path::new("other.bin")));
        assert!(artifact.is_optional_file(Path::new("README.md")));
        assert!(!artifact.is_optional_file(Path::new("model.onnx")));
    }
}