ml_cellar/rack.rs
1use glob::Pattern;
2use serde::{Deserialize, Serialize};
3use std::env;
4use std::path::{Path, PathBuf};
5
6use chrono::NaiveDate;
7
/// Top-level configuration for a rack (model family/algorithm).
///
/// This is the root of the structure that [`load_rack_config`] deserializes
/// from a rack's `config.toml`. It groups the rack's identity, its artifact
/// rules, its project versioning schemes, and its documentation generation
/// settings.
///
/// Because of `#[serde(default)]`, any section that is missing from the input
/// is filled in from that section's `Default` implementation instead of
/// causing a deserialization error.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct RackConfig {
    /// Basic identifying information about the rack.
    pub rack: RackInfoConfig,
    /// Rules describing which files are required or optional in an ML-bin.
    pub artifact: ArtifactConfig,
    /// Version format rules, keyed by project name.
    pub project: ProjectConfig,
    /// Settings for documentation generation.
    pub document: DocumentConfig,
}
20
/// Basic information about the rack.
///
/// Because of `#[serde(default)]`, a missing `name` deserializes to an empty
/// string rather than producing an error.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct RackInfoConfig {
    /// The name of the rack (e.g., "vit-l", "llm-model").
    pub name: String,
}
28
/// Configuration for artifact file management.
///
/// Defines which files are required and which are optional in an ML-bin.
/// Both lists default to empty when they are absent from the input
/// (`#[serde(default)]`).
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct ArtifactConfig {
    /// List of required file patterns.
    ///
    /// Files matching these patterns must exist in every ML-bin.
    /// An entry ending with `/` or `\` is treated by `is_required_file` as a
    /// directory prefix; every other entry is interpreted as a glob pattern.
    pub required_files: Vec<String>,
    /// List of optional file patterns.
    ///
    /// Files matching these patterns are allowed but not required.
    /// Entries are interpreted as glob patterns only; `is_optional_file`
    /// applies no special handling to directory-style entries.
    pub optional_files: Vec<String>,
}
41
42impl ArtifactConfig {
43 /// Checks if a file matches any required file pattern.
44 ///
45 /// Supports both glob patterns and directory paths (ending with `/` or `\`).
46 /// For directory patterns, any file under that directory is considered a match.
47 ///
48 /// # Arguments
49 ///
50 /// - `filename` - The relative path of the file to check
51 ///
52 /// # Returns
53 ///
54 /// `true` if the file matches any required pattern, `false` otherwise.
55 pub fn is_required_file(&self, filename: &Path) -> bool {
56 self.required_files.iter().any(|pattern| {
57 let is_directory = pattern.ends_with('/') || pattern.ends_with('\\');
58 if is_directory {
59 let directory = pattern.trim_end_matches(&['/', '\\'][..]);
60 if directory.is_empty() {
61 return true;
62 }
63 let dir_path = Path::new(directory);
64 filename.starts_with(dir_path)
65 } else {
66 Pattern::new(pattern)
67 .map(|p| p.matches_path(filename))
68 .unwrap_or(false)
69 }
70 })
71 }
72
73 /// Checks if a file matches any optional file pattern.
74 ///
75 /// # Arguments
76 ///
77 /// - `filename` - The relative path of the file to check
78 ///
79 /// # Returns
80 ///
81 /// `true` if the file matches any optional pattern, `false` otherwise.
82 pub fn is_optional_file(&self, filename: &Path) -> bool {
83 self.optional_files.iter().any(|pattern| {
84 Pattern::new(pattern)
85 .map(|p| p.matches_path(filename))
86 .unwrap_or(false)
87 })
88 }
89}
90
/// Configuration for project versioning schemes.
///
/// Allows defining different version formats (YYYYMMDD, X, X.Y, X.Y.Z) for different projects.
/// Projects not explicitly configured accept any version format.
///
/// `is_valid_version` consults the scheme lists in the order YYYYMMDD, X,
/// X.Y, X.Y.Z and applies the first one that names the project, so a project
/// listed under several schemes is validated against the earliest of them.
///
/// Every field is optional; an absent list is equivalent to an empty one as
/// far as version validation is concerned.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct ProjectConfig {
    /// List of projects.
    ///
    /// This list is not consulted by the version validation in this file.
    pub project: Option<Vec<String>>,
    /// List of projects that use YYYYMMDD versioning format.
    pub version_yyyymmdd: Option<Vec<String>>,
    /// List of projects that use single integer versioning (e.g., "1", "2").
    pub version_x: Option<Vec<String>>,
    /// List of projects that use two-part versioning (e.g., "1.0", "2.3").
    pub version_x_y: Option<Vec<String>>,
    /// List of projects that use three-part semantic versioning (e.g., "1.0.0", "2.1.3").
    pub version_x_y_z: Option<Vec<String>>,
}
109
110impl ProjectConfig {
111 /// Validates whether a version string follows the configured format for a project.
112 ///
113 /// # Arguments
114 ///
115 /// - `project_name` - The name of the project
116 /// - `version` - The version string to validate
117 ///
118 /// # Returns
119 ///
120 /// `true` if the version is valid for the project, `false` otherwise.
121 /// Projects without explicit configuration accept any version format.
122 ///
123 pub fn is_valid_version(&self, project_name: &str, version: &str) -> bool {
124 let version_elements: Vec<String> = version.split('.').map(|s| s.to_string()).collect();
125
126 if self.is_yyyymmdd_project(project_name) {
127 is_yyyymmdd_format(version)
128 } else if self.is_x_project(project_name) {
129 version.parse::<i64>().is_ok()
130 } else if self.is_x_y_project(project_name) {
131 version.matches('.').count() == 1
132 && version_elements.len() == 2
133 && version_elements[0].parse::<i64>().is_ok()
134 && version_elements[1].parse::<i64>().is_ok()
135 } else if self.is_x_y_z_project(project_name) {
136 version.matches('.').count() == 2
137 && version_elements.len() == 3
138 && version_elements[0].parse::<i64>().is_ok()
139 && version_elements[1].parse::<i64>().is_ok()
140 && version_elements[2].parse::<i64>().is_ok()
141 } else {
142 true
143 }
144 }
145
146 /// Checks if a project uses YYYYMMDD versioning format.
147 fn is_yyyymmdd_project(&self, project_name: &str) -> bool {
148 self.version_yyyymmdd.is_some()
149 && self
150 .version_yyyymmdd
151 .as_ref()
152 .unwrap()
153 .iter()
154 .any(|s| s == project_name)
155 }
156
157 /// Checks if a project uses single integer (X) versioning format.
158 fn is_x_project(&self, project_name: &str) -> bool {
159 self.version_x.is_some()
160 && self
161 .version_x
162 .as_ref()
163 .unwrap()
164 .iter()
165 .any(|s| s == project_name)
166 }
167
168 /// Checks if a project uses two-part (X.Y) versioning format.
169 fn is_x_y_project(&self, project_name: &str) -> bool {
170 self.version_x_y.is_some()
171 && self
172 .version_x_y
173 .as_ref()
174 .unwrap()
175 .iter()
176 .any(|s| s == project_name)
177 }
178
179 /// Checks if a project uses three-part semantic (X.Y.Z) versioning format.
180 fn is_x_y_z_project(&self, project_name: &str) -> bool {
181 self.version_x_y_z.is_some()
182 && self
183 .version_x_y_z
184 .as_ref()
185 .unwrap()
186 .iter()
187 .any(|s| s == project_name)
188 }
189}
190
/// Configuration for documentation generation.
///
/// Specifies template and result files used for automatically generating
/// documentation for ML-bins. Both settings are optional and default to
/// `None` when absent from the input (`#[serde(default)]`).
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct DocumentConfig {
    /// The template file for documentation generation.
    ///
    /// If you set `template_file = "template.md"` in `config.toml` for rack
    /// "my_rack", the structure is as follows:
    ///
    /// ```text
    /// model_registry_repository/
    ///   {rack_name}/
    ///     template.md
    ///     config.toml
    ///     0.1/
    ///     0.2/
    /// ```
    pub template_file: Option<String>,
    /// The result file for documentation generation in each version directory.
    ///
    /// For now, only JSON format is supported.
    /// If you set `result_file = "result.json"` in `config.toml` for rack
    /// "my_rack", the structure is as follows:
    ///
    /// ```text
    /// model_registry_repository/
    ///   {rack_name}/
    ///     template.md
    ///     config.toml
    ///     0.1/
    ///       result.json
    ///     0.2/
    ///       result.json
    /// ```
    pub result_file: Option<String>,
}
224
225/// Loads the rack configuration by searching for `config.toml` in the directory tree.
226///
227/// This function searches upward from the given path through parent directories
228/// until it finds a `config.toml` file. This allows ml-cellar commands to be
229/// run from any subdirectory within a rack.
230///
231/// # Arguments
232///
233/// - `path` - The starting path to search from (can be a file or directory)
234///
235/// # Returns
236///
237/// A tuple `(RackConfig, PathBuf)` where:
238/// - `RackConfig`: The parsed configuration from `config.toml`
239/// - `PathBuf`: The absolute path to the directory containing `config.toml`
240/// (this is the root of the rack)
241///
242/// # Panics
243///
244/// Panics if:
245/// - `config.toml` is not found in any parent directory up to the filesystem root
246/// - The configuration file cannot be read
247/// - The TOML content cannot be parsed
248///
249pub fn load_rack_config(path: &Path) -> (RackConfig, PathBuf) {
250 // Set directory to start searching
251 let relative_dir = if path.is_dir() {
252 path.to_path_buf()
253 } else {
254 path.parent().unwrap().to_path_buf()
255 };
256
257 // Convert to absolute path
258 let mut absolute_dir = if relative_dir.is_absolute() {
259 relative_dir
260 } else {
261 env::current_dir().unwrap().join(relative_dir)
262 };
263
264 // Search for config.toml in the directory and its parents
265 loop {
266 let candidate = absolute_dir.join("config.toml");
267 if candidate.is_file() {
268 // Found config.toml, read and parse it
269 log::info!("Loading config from {:?}", candidate);
270 let config_content = std::fs::read_to_string(&candidate).unwrap();
271 return (toml::from_str(&config_content).unwrap(), absolute_dir);
272 }
273
274 match absolute_dir.parent() {
275 Some(parent) => absolute_dir = parent.to_path_buf(),
276 None => {
277 log::error!(
278 "config.toml not found; reached filesystem root at {:?}\n\
279 Please ensure that config.toml exists in the directory tree starting from the directory of the provided path.",
280 absolute_dir
281 );
282 panic!(
283 "config.toml not found in the directory tree starting from {:?}",
284 path
285 );
286 }
287 }
288 }
289}
290
291/// Validates whether a string is in YYYYMMDD date format.
292fn is_yyyymmdd_format(s: &str) -> bool {
293 if s.len() != 8 || !s.as_bytes().iter().all(|b| b.is_ascii_digit()) {
294 return false;
295 }
296
297 let y: i32 = match s[0..4].parse() {
298 Ok(v) => v,
299 Err(_) => return false,
300 };
301 let m: u32 = match s[4..6].parse() {
302 Ok(v) => v,
303 Err(_) => return false,
304 };
305 let d: u32 = match s[6..8].parse() {
306 Ok(v) => v,
307 Err(_) => return false,
308 };
309
310 NaiveDate::from_ymd_opt(y, m, d).is_some()
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use toml::to_string_pretty;

    /// A default configuration written to `config.toml` in a temporary
    /// directory is found when searching from that same directory, and it
    /// round-trips with an empty rack name.
    #[test]
    fn test_load_rack_config() {
        let sandbox = TempDir::new().unwrap();
        let rack_root = sandbox.path();
        let config_path = rack_root.join("config.toml");

        // Persist the default configuration as the rack's config file.
        let serialized = to_string_pretty(&RackConfig::default())
            .expect("failed to serialize rack config to TOML");
        std::fs::write(&config_path, serialized).expect("failed to write config.toml");

        let (loaded, found_in) = load_rack_config(rack_root);
        assert_eq!(found_in, rack_root);
        assert_eq!(loaded.rack.name, "");
    }
}
333}