// encoderfile/config.rs

1use anyhow::{Context, Result, bail};
2use encoderfile_core::common::{Config as EmbeddedConfig, ModelConfig, ModelType};
3use schemars::JsonSchema;
4use std::{
5    fs::File,
6    io::{BufReader, Read},
7    path::PathBuf,
8};
9
10use super::model::ModelTypeExt as _;
11use figment::{
12    Figment,
13    providers::{Format, Yaml},
14};
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17
/// Root schema of an encoderfile build-config file.
///
/// The on-disk YAML nests everything under a single `encoderfile:` key.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BuildConfig {
    /// The build definition for one encoderfile artifact.
    pub encoderfile: EncoderfileConfig,
}
22
23impl BuildConfig {
24    pub fn load(path: &PathBuf) -> Result<Self> {
25        let config = Figment::new().merge(Yaml::file(path)).extract()?;
26
27        Ok(config)
28    }
29}
30
/// Build definition for a single encoderfile artifact.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EncoderfileConfig {
    /// Artifact name; also used to derive the default output filename and
    /// the per-model cache directory hash.
    pub name: String,
    /// Artifact version; defaults to "0.1.0" when omitted.
    #[serde(default = "default_version")]
    pub version: String,
    /// Where the model assets live (a directory or explicit per-file paths).
    pub path: ModelPath,
    /// Kind of model being packaged (e.g. embedding, sequence classification).
    pub model_type: ModelType,
    /// Destination for the built artifact; defaults to
    /// `<cwd>/<name>.encoderfile` when omitted.
    pub output_path: Option<PathBuf>,
    /// Cache directory for generated build files; platform default when omitted.
    pub cache_dir: Option<PathBuf>,
    /// Optional transform script, inline or loaded from a file.
    pub transform: Option<Transform>,
    /// Optional tokenizer build options (padding strategy).
    pub tokenizer: Option<TokenizerBuildConfig>,
    /// Whether to validate the transform before building; defaults to true.
    #[serde(default = "default_validate_transform")]
    pub validate_transform: bool,
    /// Whether to actually run the build; defaults to true.
    #[serde(default = "default_build")]
    pub build: bool,
}
47
48impl EncoderfileConfig {
49    pub fn embedded_config(&self) -> Result<EmbeddedConfig> {
50        let tokenizer = self.validate_tokenizer()?;
51        let config = EmbeddedConfig {
52            name: self.name.clone(),
53            version: self.version.clone(),
54            model_type: self.model_type.clone(),
55            transform: self.transform()?,
56            tokenizer,
57        };
58
59        Ok(config)
60    }
61
62    pub fn model_config(&self) -> Result<ModelConfig> {
63        let model_config_path = self.path.model_config_path()?;
64
65        let file = File::open(model_config_path)?;
66
67        let reader = BufReader::new(file);
68
69        serde_json::from_reader(reader).with_context(|| "Failed to deserialize model config")
70    }
71
72    pub fn output_path(&self) -> PathBuf {
73        match &self.output_path {
74            Some(p) => p.to_path_buf(),
75            None => {
76                println!("No output path detected. Saving to current directory...");
77                std::env::current_dir()
78                    .expect("Can't even find the current dir? Tragic. (no seriously please open an issue)")
79                    .join(format!("{}.encoderfile", self.name))
80            }
81        }
82    }
83
84    pub fn cache_dir(&self) -> PathBuf {
85        match &self.cache_dir {
86            Some(c) => c.to_path_buf(),
87            None => default_cache_dir(),
88        }
89    }
90
91    pub fn transform(&self) -> Result<Option<String>> {
92        let transform = match &self.transform {
93            None => None,
94            Some(s) => Some(s.transform()?),
95        };
96
97        Ok(transform)
98    }
99
100    pub fn to_tera_ctx(&self) -> Result<tera::Context> {
101        let mut ctx = tera::Context::new();
102        let embedded_config = self.embedded_config()?;
103
104        ctx.insert("version", embedded_config.version.as_str());
105        ctx.insert("config_str", &serde_json::to_string(&embedded_config)?);
106        ctx.insert("model_type", self.model_type.to_ident());
107        ctx.insert("model_weights_path", &self.path.model_weights_path()?);
108        ctx.insert("tokenizer_path", &self.path.tokenizer_path()?);
109        ctx.insert("model_config_path", &self.path.model_config_path()?);
110        ctx.insert("encoderfile_version_str", &encoderfile_core_version());
111
112        Ok(ctx)
113    }
114
115    pub fn get_generated_dir(&self) -> PathBuf {
116        let filename_hash = Sha256::digest(self.name.as_bytes());
117
118        self.cache_dir()
119            .join(format!("encoderfile-{:x}", filename_hash))
120    }
121}
122
/// Tokenizer-specific build options.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TokenizerBuildConfig {
    /// Padding strategy to bake into the tokenizer; tokenizer default when omitted.
    pub pad_strategy: Option<TokenizerPadStrategy>,
}
127
/// How sequences are padded when tokenizing a batch.
// NOTE(review): with `untagged`, serde matches variants by shape, so the unit
// variant `BatchLongest` deserializes from YAML null/unit — not from the
// string "batch_longest" — and `rename_all` has no effect on untagged unit
// variants. Confirm the intended config syntax against real config files.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged, rename_all = "snake_case")]
pub enum TokenizerPadStrategy {
    /// Pad each batch to the length of its longest sequence.
    BatchLongest,
    /// Pad every sequence to a fixed length (`fixed: <usize>` in config).
    Fixed { fixed: usize },
}
134
/// Source of the transform script: a file path or an inline string.
///
/// `untagged`: a mapping with a `path` key parses as `Path`; a plain string
/// parses as `Inline`.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum Transform {
    /// Load the transform from a file on disk.
    Path { path: PathBuf },
    /// Transform source provided directly in the config.
    Inline(String),
}
141
142impl Transform {
143    pub fn transform(&self) -> Result<String> {
144        match self {
145            Self::Path { path } => {
146                if !path.exists() {
147                    bail!("No such file: {:?}", &path);
148                }
149
150                let mut code = String::new();
151
152                File::open(path)?.read_to_string(&mut code)?;
153
154                Ok(code)
155            }
156            Self::Inline(s) => Ok(s.clone()),
157        }
158        .map(|i| i.trim().to_string())
159    }
160}
161
/// Location of the model assets.
///
/// `untagged`: a plain string parses as `Directory` (conventional filenames
/// are assumed inside it); a mapping with explicit keys parses as `Paths`.
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(untagged)]
pub enum ModelPath {
    /// A directory containing `config.json`, `tokenizer.json`, `model.onnx`,
    /// and optionally `tokenizer_config.json`.
    Directory(PathBuf),
    /// Explicit per-asset file paths.
    Paths {
        model_config_path: PathBuf,
        model_weights_path: PathBuf,
        tokenizer_path: PathBuf,
        /// Optional; only some tokenizers ship a separate config file.
        tokenizer_config_path: Option<PathBuf>,
    },
}
173
174impl ModelPath {
175    fn resolve(
176        &self,
177        explicit: Option<PathBuf>,
178        default: impl FnOnce(&PathBuf) -> PathBuf,
179        err: &str,
180    ) -> Result<Option<PathBuf>> {
181        let path = match self {
182            Self::Paths { .. } => explicit,
183            Self::Directory(dir) => {
184                if !dir.is_dir() {
185                    bail!("No such directory: {:?}", dir);
186                }
187                Some(default(dir))
188            }
189        };
190
191        match path {
192            Some(p) => {
193                if !p.try_exists()? {
194                    bail!("Could not locate {} at path: {:?}", err, p);
195                }
196                Ok(Some(p.canonicalize()?))
197            }
198            None => Ok(None),
199        }
200    }
201}
202
/// Generates an accessor per model asset, delegating to [`ModelPath::resolve`].
///
/// Two arms:
/// - `@Optional name, default_filename, err_label` → `fn name() -> Result<Option<PathBuf>>`
/// - `name, default_filename, err_label`           → `fn name() -> Result<PathBuf>`
///
/// In both arms `$name` doubles as the generated method name AND the struct
/// field matched out of `Self::Paths` — the two must stay in sync.
macro_rules! asset_path {
    (@Optional $name:ident, $default:expr, $err:expr) => {
        pub fn $name(&self) -> Result<Option<PathBuf>> {
            // Field is already Option<PathBuf>; pass it through as-is.
            let explicit = match self {
                Self::Paths { $name, .. } => $name.clone(),
                _ => None,
            };

            self.resolve(explicit, |dir| dir.join($default), $err)
        }
    };

    ($name:ident, $default:expr, $err:expr) => {
        pub fn $name(&self) -> Result<PathBuf> {
            // Required field: wrap in Some for the shared resolver.
            let explicit = match self {
                Self::Paths { $name, .. } => Some($name.clone()),
                _ => None,
            };

            // A required asset resolving to None is an error.
            self.resolve(explicit, |dir| dir.join($default), $err)?
                .ok_or_else(|| anyhow::anyhow!("Missing required path: {}", $err))
        }
    };
}
227
// Asset accessors: (method/field name, default filename inside a Directory,
// label used in error messages). See the `asset_path!` macro for expansion.
impl ModelPath {
    asset_path!(model_config_path, "config.json", "model config");
    asset_path!(tokenizer_path, "tokenizer.json", "tokenizer");
    asset_path!(model_weights_path, "model.onnx", "model weights");
    // Optional: not every tokenizer ships a tokenizer_config.json.
    asset_path!(@Optional tokenizer_config_path, "tokenizer_config.json", "tokenizer config");
}
234
235fn default_cache_dir() -> PathBuf {
236    directories::ProjectDirs::from("com", "mozilla-ai", "encoderfile")
237        .expect("Cannot locate")
238        .cache_dir()
239        .to_path_buf()
240}
241
/// serde default for `EncoderfileConfig::version` when the config omits it.
fn default_version() -> String {
    String::from("0.1.0")
}
245
/// serde default for `EncoderfileConfig::build`: build unless explicitly disabled.
fn default_build() -> bool {
    true
}
249
/// serde default for `EncoderfileConfig::validate_transform`: validate unless
/// explicitly disabled.
fn default_validate_transform() -> bool {
    true
}
253
/// Version string of the `encoderfile-core` dependency, captured at compile
/// time from the `ENCODERFILE_CORE_DEP_STR` environment variable.
// NOTE(review): presumably exported by a build script; compilation fails if
// the variable is unset at build time — confirm where it is defined.
fn encoderfile_core_version() -> &'static str {
    env!("ENCODERFILE_CORE_DEP_STR")
}
257
#[cfg(test)]
mod tests {
    use super::*;
    use std::{fs, path::PathBuf};

    // Create a stable, normal directory under the system temp dir.
    // A UUID suffix keeps concurrent test runs from colliding.
    fn create_test_dir(name: &str) -> PathBuf {
        let base = std::env::temp_dir().join(format!(
            "encoderfile-test-{}-{}",
            name,
            uuid::Uuid::new_v4()
        ));
        fs::create_dir_all(&base).unwrap();
        base
    }

    // Create temp output dir
    fn create_temp_output_dir() -> PathBuf {
        create_test_dir("model")
    }

    // Create a model dir populated with the required files
    // (the conventional filenames ModelPath::Directory expects).
    fn create_temp_model_dir() -> PathBuf {
        let base = create_test_dir("model");
        fs::write(base.join("config.json"), "{}").expect("Failed to create config.json");
        fs::write(base.join("tokenizer.json"), "{}").expect("Failed to create tokenizer.json");
        fs::write(base.join("model.onnx"), "onnx").expect("Failed to create model.onnx");
        fs::write(base.join("tokenizer_config.json"), "{}")
            .expect("Failed to create tokenizer_config.json");
        base
    }

    // Clean up (best-effort, don't panic)
    fn cleanup(path: &PathBuf) {
        let _ = fs::remove_dir_all(path);
    }

    // Smoke test: the compile-time env var is present and readable.
    #[test]
    fn test_get_encoderfile_core_version() {
        encoderfile_core_version();
    }

    // Directory variant resolves all four conventional asset paths.
    #[test]
    fn test_modelpath_directory_valid() {
        let base = create_temp_model_dir();
        let mp = ModelPath::Directory(base.clone());

        assert!(mp.model_config_path().unwrap().ends_with("config.json"));
        assert!(mp.tokenizer_path().unwrap().ends_with("tokenizer.json"));
        assert!(mp.model_weights_path().unwrap().ends_with("model.onnx"));
        assert!(
            mp.tokenizer_config_path()
                .unwrap()
                .unwrap()
                .ends_with("tokenizer_config.json")
        );

        cleanup(&base);
    }

    // A missing required asset errors and names the asset in the message.
    #[test]
    fn test_modelpath_directory_missing_file() {
        let base = create_test_dir("missing");
        let mp = ModelPath::Directory(base.clone());

        let err = mp.model_config_path().unwrap_err();
        assert!(err.to_string().contains("model config"));

        cleanup(&base);
    }

    // Paths variant resolves when every explicit file exists.
    #[test]
    fn test_modelpath_explicit_paths() {
        let base = create_temp_model_dir();
        let mp = ModelPath::Paths {
            model_config_path: base.join("config.json"),
            tokenizer_path: base.join("tokenizer.json"),
            model_weights_path: base.join("model.onnx"),
            tokenizer_config_path: Some(base.join("tokenizer_config.json")),
        };

        assert!(mp.model_config_path().is_ok());

        cleanup(&base);
    }

    // Inline transform is returned trimmed.
    #[test]
    fn test_transform_inline() {
        let t = Transform::Inline("  hello world   ".into());
        assert_eq!(t.transform().unwrap(), "hello world");
    }

    // Path transform reads the file and trims the contents.
    #[test]
    fn test_transform_path() {
        let dir = create_test_dir("transform");
        let file = dir.join("script.txt");

        fs::write(&file, "   goodbye world ").unwrap();

        let t = Transform::Path { path: file };
        assert_eq!(t.transform().unwrap(), "goodbye world");

        cleanup(&dir);
    }

    // A nonexistent transform file yields the dedicated "No such file" error.
    #[test]
    fn test_transform_missing_file() {
        let bogus = PathBuf::from("totally-does-not-exist.txt");
        let t = Transform::Path {
            path: bogus.clone(),
        };

        let err = t.transform().unwrap_err();
        assert!(err.to_string().contains("No such file"));
    }

    // Generated dir lives under the cache dir with an "encoderfile-" prefix.
    #[test]
    fn test_encoderfile_generated_dir() {
        let base = create_temp_output_dir();

        let cfg = EncoderfileConfig {
            name: "my-cool-model".into(),
            version: "1.0".into(),
            path: ModelPath::Directory("../models/embedding".into()),
            model_type: ModelType::Embedding,
            output_path: Some(base.clone()),
            cache_dir: Some(base.clone()),
            validate_transform: false,
            transform: None,
            tokenizer: None,
            build: true,
        };

        let generated = cfg.get_generated_dir();
        assert!(generated.to_string_lossy().contains("encoderfile-"));

        cleanup(&base);
    }

    // Template context builds without error for a classification config.
    // NOTE(review): relies on "../models/embedding" existing relative to the
    // test working directory — verify this fixture path in CI.
    #[test]
    fn test_encoderfile_to_tera_ctx() {
        let base = create_temp_output_dir();
        let cfg = EncoderfileConfig {
            name: "sadness".into(),
            version: "0.1.0".into(),
            path: ModelPath::Directory("../models/embedding".into()),
            model_type: ModelType::SequenceClassification,
            output_path: Some(base.clone()),
            cache_dir: Some(base.clone()),
            validate_transform: false,
            transform: Some(Transform::Inline("1+1".into())),
            tokenizer: None,
            build: true,
        };

        let _ctx = cfg.to_tera_ctx().expect("Tera ctx error");

        cleanup(&base);
    }

    // Round-trip: YAML on disk → Figment → BuildConfig fields.
    #[test]
    fn test_config_loading() {
        let dir = create_test_dir("config");
        let path = dir.join("config.yml");

        let yaml = r#"
encoderfile:
  name: testy
  version: "0.9.0"
  path: "./"
  model_type: embedding
"#;

        fs::write(&path, yaml).unwrap();

        let cfg = BuildConfig::load(&path).unwrap();
        assert_eq!(cfg.encoderfile.name, "testy");
        assert_eq!(cfg.encoderfile.version, "0.9.0");

        cleanup(&dir);
    }
}