ml_cellar/
rack.rs

1use glob::Pattern;
2use serde::{Deserialize, Serialize};
3use std::env;
4use std::path::{Path, PathBuf};
5
6use chrono::NaiveDate;
7
8/// Configuration for a rack (model family/algorithm).
9///
10/// Contains all settings for managing a rack including artifact rules,
11/// project versioning schemes, and documentation generation.
12#[derive(Debug, Deserialize, Serialize, Default)]
13#[serde(default)]
14pub struct RackConfig {
15    pub rack: RackInfoConfig,
16    pub artifact: ArtifactConfig,
17    pub project: ProjectConfig,
18    pub document: DocumentConfig,
19}
20
21/// Basic information about the rack.
22#[derive(Debug, Deserialize, Serialize, Default)]
23#[serde(default)]
24pub struct RackInfoConfig {
25    /// The name of the rack (e.g., "vit-l", "llm-model").
26    pub name: String,
27}
28
29/// Configuration for artifact file management.
30/// Defines which files are required and which are optional in an ML-bin.
31#[derive(Debug, Deserialize, Serialize, Default)]
32#[serde(default)]
33pub struct ArtifactConfig {
34    /// List of required file patterns (supports globs and directory paths ending with /).
35    /// Files matching these patterns must exist in every ML-bin.
36    pub required_files: Vec<String>,
37    /// List of optional file patterns (supports globs).
38    /// Files matching these patterns are allowed but not required.
39    pub optional_files: Vec<String>,
40}
41
42impl ArtifactConfig {
43    /// Checks if a file matches any required file pattern.
44    ///
45    /// Supports both glob patterns and directory paths (ending with `/` or `\`).
46    /// For directory patterns, any file under that directory is considered a match.
47    ///
48    /// # Arguments
49    ///
50    /// - `filename` - The relative path of the file to check
51    ///
52    /// # Returns
53    ///
54    /// `true` if the file matches any required pattern, `false` otherwise.
55    pub fn is_required_file(&self, filename: &Path) -> bool {
56        self.required_files.iter().any(|pattern| {
57            let is_directory = pattern.ends_with('/') || pattern.ends_with('\\');
58            if is_directory {
59                let directory = pattern.trim_end_matches(&['/', '\\'][..]);
60                if directory.is_empty() {
61                    return true;
62                }
63                let dir_path = Path::new(directory);
64                filename.starts_with(dir_path)
65            } else {
66                Pattern::new(pattern)
67                    .map(|p| p.matches_path(filename))
68                    .unwrap_or(false)
69            }
70        })
71    }
72
73    /// Checks if a file matches any optional file pattern.
74    ///
75    /// # Arguments
76    ///
77    /// - `filename` - The relative path of the file to check
78    ///
79    /// # Returns
80    ///
81    /// `true` if the file matches any optional pattern, `false` otherwise.
82    pub fn is_optional_file(&self, filename: &Path) -> bool {
83        self.optional_files.iter().any(|pattern| {
84            Pattern::new(pattern)
85                .map(|p| p.matches_path(filename))
86                .unwrap_or(false)
87        })
88    }
89}
90
91/// Configuration for project versioning schemes.
92///
93/// Allows defining different version formats (YYYYMMDD, X, X.Y, X.Y.Z) for different projects.
94/// Projects not explicitly configured accept any version format.
95#[derive(Debug, Deserialize, Serialize, Default)]
96#[serde(default)]
97pub struct ProjectConfig {
98    /// List of projects.
99    pub project: Option<Vec<String>>,
100    /// List of projects that use YYYYMMDD versioning format.
101    pub version_yyyymmdd: Option<Vec<String>>,
102    /// List of projects that use single integer versioning (e.g., "1", "2").
103    pub version_x: Option<Vec<String>>,
104    /// List of projects that use two-part versioning (e.g., "1.0", "2.3").
105    pub version_x_y: Option<Vec<String>>,
106    /// List of projects that use three-part semantic versioning (e.g., "1.0.0", "2.1.3").
107    pub version_x_y_z: Option<Vec<String>>,
108}
109
110impl ProjectConfig {
111    /// Validates whether a version string follows the configured format for a project.
112    ///
113    /// # Arguments
114    ///
115    /// - `project_name` - The name of the project
116    /// - `version` - The version string to validate
117    ///
118    /// # Returns
119    ///
120    /// `true` if the version is valid for the project, `false` otherwise.
121    /// Projects without explicit configuration accept any version format.
122    ///
123    pub fn is_valid_version(&self, project_name: &str, version: &str) -> bool {
124        let version_elements: Vec<String> = version.split('.').map(|s| s.to_string()).collect();
125
126        if self.is_yyyymmdd_project(project_name) {
127            is_yyyymmdd_format(version)
128        } else if self.is_x_project(project_name) {
129            version.parse::<i64>().is_ok()
130        } else if self.is_x_y_project(project_name) {
131            version.matches('.').count() == 1
132                && version_elements.len() == 2
133                && version_elements[0].parse::<i64>().is_ok()
134                && version_elements[1].parse::<i64>().is_ok()
135        } else if self.is_x_y_z_project(project_name) {
136            version.matches('.').count() == 2
137                && version_elements.len() == 3
138                && version_elements[0].parse::<i64>().is_ok()
139                && version_elements[1].parse::<i64>().is_ok()
140                && version_elements[2].parse::<i64>().is_ok()
141        } else {
142            true
143        }
144    }
145
146    /// Checks if a project uses YYYYMMDD versioning format.
147    fn is_yyyymmdd_project(&self, project_name: &str) -> bool {
148        self.version_yyyymmdd.is_some()
149            && self
150                .version_yyyymmdd
151                .as_ref()
152                .unwrap()
153                .iter()
154                .any(|s| s == project_name)
155    }
156
157    /// Checks if a project uses single integer (X) versioning format.
158    fn is_x_project(&self, project_name: &str) -> bool {
159        self.version_x.is_some()
160            && self
161                .version_x
162                .as_ref()
163                .unwrap()
164                .iter()
165                .any(|s| s == project_name)
166    }
167
168    /// Checks if a project uses two-part (X.Y) versioning format.
169    fn is_x_y_project(&self, project_name: &str) -> bool {
170        self.version_x_y.is_some()
171            && self
172                .version_x_y
173                .as_ref()
174                .unwrap()
175                .iter()
176                .any(|s| s == project_name)
177    }
178
179    /// Checks if a project uses three-part semantic (X.Y.Z) versioning format.
180    fn is_x_y_z_project(&self, project_name: &str) -> bool {
181        self.version_x_y_z.is_some()
182            && self
183                .version_x_y_z
184                .as_ref()
185                .unwrap()
186                .iter()
187                .any(|s| s == project_name)
188    }
189}
190
191/// Configuration for documentation generation.
192///
193/// Specifies template and result files used for automatically generating
194/// documentation for ML-bins.
195#[derive(Debug, Deserialize, Serialize, Default)]
196#[serde(default)]
197pub struct DocumentConfig {
198    /// The template file for documentation generation.
199    /// If you set template_file = "template.md" in config.toml for rack "my_rack", the structure is as follows:
200    ///
201    /// - model_registry_repository/
202    ///   - {rack_name}/
203    ///     - template.md
204    ///     - config.toml
205    ///     - 0.1/
206    ///     - 0.2/
207    ///
208    pub template_file: Option<String>,
209    /// The result file for documentation generation in each version directory.
210    /// For now, only JSON format is supported.
211    /// If you set result_file = "result.json" in config.toml for rack "my_rack", the structure is as follows:
212    ///
213    /// - model_registry_repository/
214    ///   - {rack_name}/
215    ///     - template.md
216    ///     - config.toml
217    ///     - 0.1/
218    ///       - result.json
219    ///     - 0.2/
220    ///       - result.json
221    ///
222    pub result_file: Option<String>,
223}
224
225/// Loads the rack configuration by searching for `config.toml` in the directory tree.
226///
227/// This function searches upward from the given path through parent directories
228/// until it finds a `config.toml` file. This allows ml-cellar commands to be
229/// run from any subdirectory within a rack.
230///
231/// # Arguments
232///
233/// - `path` - The starting path to search from (can be a file or directory)
234///
235/// # Returns
236///
237/// A tuple `(RackConfig, PathBuf)` where:
238/// - `RackConfig`: The parsed configuration from `config.toml`
239/// - `PathBuf`: The absolute path to the directory containing `config.toml`
240///   (this is the root of the rack)
241///
242/// # Panics
243///
244/// Panics if:
245/// - `config.toml` is not found in any parent directory up to the filesystem root
246/// - The configuration file cannot be read
247/// - The TOML content cannot be parsed
248///
249pub fn load_rack_config(path: &Path) -> (RackConfig, PathBuf) {
250    // Set directory to start searching
251    let relative_dir = if path.is_dir() {
252        path.to_path_buf()
253    } else {
254        path.parent().unwrap().to_path_buf()
255    };
256
257    // Convert to absolute path
258    let mut absolute_dir = if relative_dir.is_absolute() {
259        relative_dir
260    } else {
261        env::current_dir().unwrap().join(relative_dir)
262    };
263
264    // Search for config.toml in the directory and its parents
265    loop {
266        let candidate = absolute_dir.join("config.toml");
267        if candidate.is_file() {
268            // Found config.toml, read and parse it
269            log::info!("Loading config from {:?}", candidate);
270            let config_content = std::fs::read_to_string(&candidate).unwrap();
271            return (toml::from_str(&config_content).unwrap(), absolute_dir);
272        }
273
274        match absolute_dir.parent() {
275            Some(parent) => absolute_dir = parent.to_path_buf(),
276            None => {
277                log::error!(
278                    "config.toml not found; reached filesystem root at {:?}\n\
279                     Please ensure that config.toml exists in the directory tree starting from the directory of the provided path.",
280                    absolute_dir
281                );
282                panic!(
283                    "config.toml not found in the directory tree starting from {:?}",
284                    path
285                );
286            }
287        }
288    }
289}
290
291/// Validates whether a string is in YYYYMMDD date format.
292fn is_yyyymmdd_format(s: &str) -> bool {
293    if s.len() != 8 || !s.as_bytes().iter().all(|b| b.is_ascii_digit()) {
294        return false;
295    }
296
297    let y: i32 = match s[0..4].parse() {
298        Ok(v) => v,
299        Err(_) => return false,
300    };
301    let m: u32 = match s[4..6].parse() {
302        Ok(v) => v,
303        Err(_) => return false,
304    };
305    let d: u32 = match s[6..8].parse() {
306        Ok(v) => v,
307        Err(_) => return false,
308    };
309
310    NaiveDate::from_ymd_opt(y, m, d).is_some()
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use toml::to_string_pretty;

    /// A default configuration written into a temporary directory is found
    /// when the search starts from that same directory.
    #[test]
    fn test_load_rack_config() {
        let workspace = TempDir::new().unwrap();
        let rack_root = workspace.path();
        let config_path = rack_root.join("config.toml");

        // Serialize the default configuration and place it at the rack root.
        let serialized = to_string_pretty(&RackConfig::default())
            .expect("failed to serialize rack config to TOML");
        std::fs::write(&config_path, serialized).expect("failed to write config.toml");

        let (loaded, found_in) = load_rack_config(rack_root);
        assert_eq!(found_in, rack_root);
        assert_eq!(loaded.rack.name, "");
    }
}
333}