1use anyhow::{Context, Result, bail};
2use encoderfile_core::common::{Config as EmbeddedConfig, ModelConfig, ModelType};
3use schemars::JsonSchema;
use std::{
    fs::File,
    io::{BufReader, Read},
    path::{Path, PathBuf},
};
9
10use super::model::ModelTypeExt as _;
11use figment::{
12 Figment,
13 providers::{Format, Yaml},
14};
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17
/// Top-level build configuration, deserialized from a YAML file.
///
/// The YAML document is expected to contain a single `encoderfile`
/// key holding an [`EncoderfileConfig`].
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BuildConfig {
    pub encoderfile: EncoderfileConfig,
}
22
23impl BuildConfig {
24 pub fn load(path: &PathBuf) -> Result<Self> {
25 let config = Figment::new().merge(Yaml::file(path)).extract()?;
26
27 Ok(config)
28 }
29}
30
/// Configuration for building a single encoderfile.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EncoderfileConfig {
    /// Model name; also used for the default output filename and the
    /// cache-directory hash.
    pub name: String,
    /// Model version string; defaults to `"0.1.0"` when omitted.
    #[serde(default = "default_version")]
    pub version: String,
    /// Location of the model assets (directory or explicit paths).
    pub path: ModelPath,
    /// Kind of model being embedded (e.g. embedding, classification).
    pub model_type: ModelType,
    /// Output path for the built artifact; when unset, falls back to
    /// `<cwd>/<name>.encoderfile`.
    pub output_path: Option<PathBuf>,
    /// Cache directory for generated sources; platform default if unset.
    pub cache_dir: Option<PathBuf>,
    /// Optional transform source (inline string or file path).
    pub transform: Option<Transform>,
    /// Optional tokenizer build options.
    pub tokenizer: Option<TokenizerBuildConfig>,
    /// Whether to validate the transform before building; defaults to true.
    #[serde(default = "default_validate_transform")]
    pub validate_transform: bool,
    /// Whether to actually run the build; defaults to true.
    #[serde(default = "default_build")]
    pub build: bool,
}
47
48impl EncoderfileConfig {
49 pub fn embedded_config(&self) -> Result<EmbeddedConfig> {
50 let tokenizer = self.validate_tokenizer()?;
51 let config = EmbeddedConfig {
52 name: self.name.clone(),
53 version: self.version.clone(),
54 model_type: self.model_type.clone(),
55 transform: self.transform()?,
56 tokenizer,
57 };
58
59 Ok(config)
60 }
61
62 pub fn model_config(&self) -> Result<ModelConfig> {
63 let model_config_path = self.path.model_config_path()?;
64
65 let file = File::open(model_config_path)?;
66
67 let reader = BufReader::new(file);
68
69 serde_json::from_reader(reader).with_context(|| "Failed to deserialize model config")
70 }
71
72 pub fn output_path(&self) -> PathBuf {
73 match &self.output_path {
74 Some(p) => p.to_path_buf(),
75 None => {
76 println!("No output path detected. Saving to current directory...");
77 std::env::current_dir()
78 .expect("Can't even find the current dir? Tragic. (no seriously please open an issue)")
79 .join(format!("{}.encoderfile", self.name))
80 }
81 }
82 }
83
84 pub fn cache_dir(&self) -> PathBuf {
85 match &self.cache_dir {
86 Some(c) => c.to_path_buf(),
87 None => default_cache_dir(),
88 }
89 }
90
91 pub fn transform(&self) -> Result<Option<String>> {
92 let transform = match &self.transform {
93 None => None,
94 Some(s) => Some(s.transform()?),
95 };
96
97 Ok(transform)
98 }
99
100 pub fn to_tera_ctx(&self) -> Result<tera::Context> {
101 let mut ctx = tera::Context::new();
102 let embedded_config = self.embedded_config()?;
103
104 ctx.insert("version", embedded_config.version.as_str());
105 ctx.insert("config_str", &serde_json::to_string(&embedded_config)?);
106 ctx.insert("model_type", self.model_type.to_ident());
107 ctx.insert("model_weights_path", &self.path.model_weights_path()?);
108 ctx.insert("tokenizer_path", &self.path.tokenizer_path()?);
109 ctx.insert("model_config_path", &self.path.model_config_path()?);
110 ctx.insert("encoderfile_version_str", &encoderfile_core_version());
111
112 Ok(ctx)
113 }
114
115 pub fn get_generated_dir(&self) -> PathBuf {
116 let filename_hash = Sha256::digest(self.name.as_bytes());
117
118 self.cache_dir()
119 .join(format!("encoderfile-{:x}", filename_hash))
120 }
121}
122
/// Tokenizer options applied when building the encoderfile.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TokenizerBuildConfig {
    /// Padding strategy for the tokenizer; `None` leaves padding unset.
    pub pad_strategy: Option<TokenizerPadStrategy>,
}
127
/// Padding strategy for tokenization.
///
/// NOTE(review): with `#[serde(untagged)]` variant names are never
/// consulted during deserialization, so `rename_all = "snake_case"`
/// only affects serialization/schema output — and the unit variant
/// `BatchLongest` matches a YAML null, not the string
/// `"batch_longest"`. Confirm this is the intended config shape.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged, rename_all = "snake_case")]
pub enum TokenizerPadStrategy {
    /// Pad each batch to its longest sequence.
    BatchLongest,
    /// Pad every sequence to a fixed length of `fixed` tokens.
    Fixed { fixed: usize },
}
134
/// Source of the transform code: either a path to a file on disk or an
/// inline string embedded directly in the config.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum Transform {
    /// Read the transform source from the file at `path`.
    Path { path: PathBuf },
    /// Use the string itself as the transform source.
    Inline(String),
}
141
142impl Transform {
143 pub fn transform(&self) -> Result<String> {
144 match self {
145 Self::Path { path } => {
146 if !path.exists() {
147 bail!("No such file: {:?}", &path);
148 }
149
150 let mut code = String::new();
151
152 File::open(path)?.read_to_string(&mut code)?;
153
154 Ok(code)
155 }
156 Self::Inline(s) => Ok(s.clone()),
157 }
158 .map(|i| i.trim().to_string())
159 }
160}
161
/// Location of the model assets: either a single directory containing
/// conventionally-named files, or explicit per-asset paths.
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(untagged)]
pub enum ModelPath {
    /// Directory expected to contain `config.json`, `tokenizer.json`,
    /// `model.onnx`, and optionally `tokenizer_config.json`.
    Directory(PathBuf),
    /// Explicit path for each asset.
    Paths {
        model_config_path: PathBuf,
        model_weights_path: PathBuf,
        tokenizer_path: PathBuf,
        tokenizer_config_path: Option<PathBuf>,
    },
}
173
174impl ModelPath {
175 fn resolve(
176 &self,
177 explicit: Option<PathBuf>,
178 default: impl FnOnce(&PathBuf) -> PathBuf,
179 err: &str,
180 ) -> Result<Option<PathBuf>> {
181 let path = match self {
182 Self::Paths { .. } => explicit,
183 Self::Directory(dir) => {
184 if !dir.is_dir() {
185 bail!("No such directory: {:?}", dir);
186 }
187 Some(default(dir))
188 }
189 };
190
191 match path {
192 Some(p) => {
193 if !p.try_exists()? {
194 bail!("Could not locate {} at path: {:?}", err, p);
195 }
196 Ok(Some(p.canonicalize()?))
197 }
198 None => Ok(None),
199 }
200 }
201}
202
/// Generates an accessor on [`ModelPath`] for a single model asset.
///
/// Each generated method pulls the explicit path from the `Paths`
/// variant (when present) and falls back to joining `$default` onto the
/// directory for the `Directory` variant, delegating to `resolve`.
macro_rules! asset_path {
    // `@Optional` arm: the generated method returns Ok(None) when the
    // `Paths` variant omits the field.
    (@Optional $name:ident, $default:expr, $err:expr) => {
        pub fn $name(&self) -> Result<Option<PathBuf>> {
            let explicit = match self {
                Self::Paths { $name, .. } => $name.clone(),
                _ => None,
            };

            self.resolve(explicit, |dir| dir.join($default), $err)
        }
    };

    // Required arm: a missing asset becomes an error naming `$err`.
    ($name:ident, $default:expr, $err:expr) => {
        pub fn $name(&self) -> Result<PathBuf> {
            let explicit = match self {
                Self::Paths { $name, .. } => Some($name.clone()),
                _ => None,
            };

            self.resolve(explicit, |dir| dir.join($default), $err)?
                .ok_or_else(|| anyhow::anyhow!("Missing required path: {}", $err))
        }
    };
}
227
impl ModelPath {
    // Required assets: resolution fails if any cannot be located.
    asset_path!(model_config_path, "config.json", "model config");
    asset_path!(tokenizer_path, "tokenizer.json", "tokenizer");
    asset_path!(model_weights_path, "model.onnx", "model weights");
    // Optional: tokenizer_config.json may be absent for some models.
    asset_path!(@Optional tokenizer_config_path, "tokenizer_config.json", "tokenizer config");
}
234
235fn default_cache_dir() -> PathBuf {
236 directories::ProjectDirs::from("com", "mozilla-ai", "encoderfile")
237 .expect("Cannot locate")
238 .cache_dir()
239 .to_path_buf()
240}
241
/// Serde default for `EncoderfileConfig::version`.
fn default_version() -> String {
    String::from("0.1.0")
}
245
/// Serde default for `EncoderfileConfig::build`: build unless told not to.
fn default_build() -> bool {
    true
}
249
/// Serde default for `EncoderfileConfig::validate_transform`: validate
/// transforms unless explicitly disabled.
fn default_validate_transform() -> bool {
    true
}
253
/// Version string of the `encoderfile-core` dependency, baked in at
/// compile time from the `ENCODERFILE_CORE_DEP_STR` env var
/// (presumably set by a build script — TODO confirm).
fn encoderfile_core_version() -> &'static str {
    env!("ENCODERFILE_CORE_DEP_STR")
}
257
#[cfg(test)]
mod tests {
    use super::*;
    use std::{fs, path::PathBuf};

    /// Creates a unique, labelled temp directory for a single test.
    fn create_test_dir(name: &str) -> PathBuf {
        let base = std::env::temp_dir().join(format!(
            "encoderfile-test-{}-{}",
            name,
            uuid::Uuid::new_v4()
        ));
        fs::create_dir_all(&base).unwrap();
        base
    }

    /// Empty temp directory used as an output / cache dir.
    fn create_temp_output_dir() -> PathBuf {
        // Labelled "output" (was "model", a copy-paste slip) so leftover
        // dirs from failed runs are distinguishable from model fixtures.
        create_test_dir("output")
    }

    /// Temp directory pre-populated with all four model asset files.
    fn create_temp_model_dir() -> PathBuf {
        let base = create_test_dir("model");
        fs::write(base.join("config.json"), "{}").expect("Failed to create config.json");
        fs::write(base.join("tokenizer.json"), "{}").expect("Failed to create tokenizer.json");
        fs::write(base.join("model.onnx"), "onnx").expect("Failed to create model.onnx");
        fs::write(base.join("tokenizer_config.json"), "{}")
            .expect("Failed to create tokenizer_config.json");
        base
    }

    /// Best-effort removal of a test directory; failures are ignored.
    fn cleanup(path: &PathBuf) {
        let _ = fs::remove_dir_all(path);
    }

    #[test]
    fn test_get_encoderfile_core_version() {
        // Smoke test: the compile-time env var was present at build time.
        encoderfile_core_version();
    }

    #[test]
    fn test_modelpath_directory_valid() {
        let base = create_temp_model_dir();
        let mp = ModelPath::Directory(base.clone());

        assert!(mp.model_config_path().unwrap().ends_with("config.json"));
        assert!(mp.tokenizer_path().unwrap().ends_with("tokenizer.json"));
        assert!(mp.model_weights_path().unwrap().ends_with("model.onnx"));
        assert!(
            mp.tokenizer_config_path()
                .unwrap()
                .unwrap()
                .ends_with("tokenizer_config.json")
        );

        cleanup(&base);
    }

    #[test]
    fn test_modelpath_directory_missing_file() {
        let base = create_test_dir("missing");
        let mp = ModelPath::Directory(base.clone());

        let err = mp.model_config_path().unwrap_err();
        assert!(err.to_string().contains("model config"));

        cleanup(&base);
    }

    #[test]
    fn test_modelpath_explicit_paths() {
        let base = create_temp_model_dir();
        let mp = ModelPath::Paths {
            model_config_path: base.join("config.json"),
            tokenizer_path: base.join("tokenizer.json"),
            model_weights_path: base.join("model.onnx"),
            tokenizer_config_path: Some(base.join("tokenizer_config.json")),
        };

        assert!(mp.model_config_path().is_ok());

        cleanup(&base);
    }

    #[test]
    fn test_transform_inline() {
        let t = Transform::Inline(" hello world ".into());
        assert_eq!(t.transform().unwrap(), "hello world");
    }

    #[test]
    fn test_transform_path() {
        let dir = create_test_dir("transform");
        let file = dir.join("script.txt");

        fs::write(&file, " goodbye world ").unwrap();

        let t = Transform::Path { path: file };
        assert_eq!(t.transform().unwrap(), "goodbye world");

        cleanup(&dir);
    }

    #[test]
    fn test_transform_missing_file() {
        let bogus = PathBuf::from("totally-does-not-exist.txt");
        let t = Transform::Path {
            path: bogus.clone(),
        };

        let err = t.transform().unwrap_err();
        assert!(err.to_string().contains("No such file"));
    }

    #[test]
    fn test_encoderfile_generated_dir() {
        let base = create_temp_output_dir();

        let cfg = EncoderfileConfig {
            name: "my-cool-model".into(),
            version: "1.0".into(),
            path: ModelPath::Directory("../models/embedding".into()),
            model_type: ModelType::Embedding,
            output_path: Some(base.clone()),
            cache_dir: Some(base.clone()),
            validate_transform: false,
            transform: None,
            tokenizer: None,
            build: true,
        };

        let generated = cfg.get_generated_dir();
        assert!(generated.to_string_lossy().contains("encoderfile-"));

        cleanup(&base);
    }

    #[test]
    fn test_encoderfile_to_tera_ctx() {
        let base = create_temp_output_dir();
        let cfg = EncoderfileConfig {
            name: "sadness".into(),
            version: "0.1.0".into(),
            path: ModelPath::Directory("../models/embedding".into()),
            model_type: ModelType::SequenceClassification,
            output_path: Some(base.clone()),
            cache_dir: Some(base.clone()),
            validate_transform: false,
            transform: Some(Transform::Inline("1+1".into())),
            tokenizer: None,
            build: true,
        };

        let _ctx = cfg.to_tera_ctx().expect("Tera ctx error");

        cleanup(&base);
    }

    #[test]
    fn test_config_loading() {
        let dir = create_test_dir("config");
        let path = dir.join("config.yml");

        let yaml = r#"
encoderfile:
  name: testy
  version: "0.9.0"
  path: "./"
  model_type: embedding
"#;

        fs::write(&path, yaml).unwrap();

        let cfg = BuildConfig::load(&path).unwrap();
        assert_eq!(cfg.encoderfile.name, "testy");
        assert_eq!(cfg.encoderfile.version, "0.9.0");

        cleanup(&dir);
    }
}