use crate::{
common::{ModelConfig, ModelType},
format::assets::{AssetKind, AssetSource, PlannedAsset},
generated::manifest::LuaLibs as ManifestLuaLibs,
transforms::{TransformSpec, convert_libs},
};
use anyhow::{Context, Result};
use crate::builder::config::EncoderfileConfig;
use prost::Message;
mod embedding;
mod sentence_embedding;
mod sequence_classification;
mod token_classification;
mod utils;
/// Extension of [`TransformSpec`] that adds validation of a transform against a model config.
pub trait TransformValidatorExt: TransformSpec {
    /// Validates this transform for the given build and model configuration.
    ///
    /// When `encoderfile_config.validate_transform` is `false`, nothing is checked and
    /// `Ok(())` is returned. Otherwise the transform must report a postprocessor through
    /// `has_postprocessor()`; if it does not, the outcome of `utils::validation_err` is
    /// propagated with `?`. The result of [`dry_run`](Self::dry_run) is returned last.
    fn validate(
        &self,
        encoderfile_config: &EncoderfileConfig,
        model_config: &ModelConfig,
    ) -> Result<()> {
        if encoderfile_config.validate_transform {
            if !self.has_postprocessor() {
                // Kept as a tail expression (no `;`) so the block's `()` type drives inference.
                utils::validation_err(
                    "Could not find `Postprocess` function in provided transform. Please make sure it exists.",
                )?
            }
            self.dry_run(model_config)
        } else {
            // Validation was switched off in the config.
            Ok(())
        }
    }

    /// Implementor-provided check that `validate` runs as its final step.
    ///
    /// NOTE(review): presumably exercises the transform against `model_config`; the exact
    /// checks live in the implementations, not in this trait.
    fn dry_run(&self, model_config: &ModelConfig) -> Result<()>;
}
// Builds `crate::transforms::$kind` from the config's Lua libs and a clone of the transform
// source, then runs `validate` on it. The expansion evaluates to the `Result` returned by
// `validate`.
//
// Both `?` operators inside the expansion return from the *enclosing* function: one when
// `lua_libs()` fails, one when the constructor fails (after attaching the
// "Failed to create transform" context).
macro_rules! validate_transform {
    ($kind:ident, $source:expr, $config:expr, $model:expr) => {{
        let transform = crate::transforms::$kind::new(
            convert_libs($config.lua_libs()?.as_ref()),
            Some($source.clone()),
        )
        .with_context(|| utils::validation_err_ctx("Failed to create transform"))?;
        transform.validate($config, $model)
    }};
}
/// Validates the transform configured in `encoderfile_config` and packages it as an asset.
///
/// Returns `Ok(None)` when no transform is configured. Otherwise the transform source is
/// resolved, a transform of the type matching `encoderfile_config.model_type` is built and
/// validated through the `validate_transform!` macro, and the source together with the
/// configured `lua_libs` is encoded as a `manifest::Transform` message. The encoded bytes
/// are wrapped in an in-memory [`PlannedAsset`] of kind [`AssetKind::Transform`].
///
/// # Errors
///
/// Propagates failures from resolving the transform source, from creating or validating
/// the transform, and from `PlannedAsset::from_asset_source`.
pub fn validate_transform<'a>(
    encoderfile_config: &'a EncoderfileConfig,
    model_config: &'a ModelConfig,
) -> Result<Option<PlannedAsset<'a>>> {
    // Nothing to validate or package when the config carries no transform.
    let Some(configured) = &encoderfile_config.transform else {
        return Ok(None);
    };
    let lua_source = configured.transform()?;

    // Pick the transform type that corresponds to the model type and validate it.
    let validation: Result<()> = match encoderfile_config.model_type {
        ModelType::Embedding => {
            validate_transform!(EmbeddingTransform, lua_source, encoderfile_config, model_config)
        }
        ModelType::SequenceClassification => validate_transform!(
            SequenceClassificationTransform,
            lua_source,
            encoderfile_config,
            model_config
        ),
        ModelType::TokenClassification => validate_transform!(
            TokenClassificationTransform,
            lua_source,
            encoderfile_config,
            model_config
        ),
        ModelType::SentenceEmbedding => validate_transform!(
            SentenceEmbeddingTransform,
            lua_source,
            encoderfile_config,
            model_config
        ),
    };
    validation?;

    // Serialize the validated source and the configured libs into the manifest message.
    let manifest_libs: Option<ManifestLuaLibs> = encoderfile_config
        .lua_libs
        .as_ref()
        .map(|libs| ManifestLuaLibs { libs: libs.clone() });
    let encoded = crate::generated::manifest::Transform {
        transform_type: crate::generated::manifest::TransformType::Lua.into(),
        transform: lua_source,
        lua_libs: manifest_libs,
    }
    .encode_to_vec();

    let asset = PlannedAsset::from_asset_source(
        AssetSource::InMemory(std::borrow::Cow::Owned(encoded)),
        AssetKind::Transform,
    )?;
    Ok(Some(asset))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::config::{ModelPath, Transform};
    use crate::transforms::{DEFAULT_LIBS, EmbeddingTransform};

    /// Baseline embedding-model config: no transform, no Lua libs, validation enabled.
    /// Tests override individual fields with struct update syntax.
    fn test_encoderfile_config() -> EncoderfileConfig {
        EncoderfileConfig {
            name: "my-model".to_string(),
            version: "0.0.1".to_string(),
            path: ModelPath::Directory(std::path::PathBuf::from("models/embedding")),
            model_type: ModelType::Embedding,
            cache_dir: None,
            output_path: None,
            transform: None,
            lua_libs: None,
            validate_transform: true,
            tokenizer: None,
            base_binary_path: None,
            target: None,
        }
    }

    /// Parses the embedding model's `config.json`, embedded at compile time.
    fn test_model_config() -> ModelConfig {
        let config_json = include_str!("../../../../../models/embedding/config.json");
        serde_json::from_str(config_json).expect("Failed to create model config")
    }

    /// A transform created without any source has no postprocessor, so validation must fail.
    #[test]
    fn test_empty_transform() {
        let result = EmbeddingTransform::new(DEFAULT_LIBS.to_vec(), None)
            .expect("Failed to make embedding transform")
            .validate(&test_encoderfile_config(), &test_model_config());
        assert!(
            result.is_err(),
            "a transform without a postprocessor must be rejected"
        );
    }

    /// With `validate_transform` disabled, even an empty transform is accepted.
    #[test]
    fn test_no_validation() {
        let config = EncoderfileConfig {
            validate_transform: false,
            ..test_encoderfile_config()
        };
        EmbeddingTransform::new(DEFAULT_LIBS.to_vec(), None)
            .expect("Failed to make embedding transform")
            .validate(&config, &test_model_config())
            .expect("Should be ok")
    }

    /// A valid inline transform passes validation and yields a planned asset.
    #[test]
    fn test_validate() {
        let encoderfile_config = EncoderfileConfig {
            transform: Some(Transform::Inline(
                "function Postprocess(arr) return arr end".to_string(),
            )),
            ..test_encoderfile_config()
        };
        let model_config = test_model_config();
        let asset =
            validate_transform(&encoderfile_config, &model_config).expect("Failed to validate");
        assert!(
            asset.is_some(),
            "a configured transform should produce an asset"
        );
    }

    /// Without a configured transform, validation succeeds and yields no asset.
    #[test]
    fn test_validate_empty() {
        let encoderfile_config = test_encoderfile_config();
        let model_config = test_model_config();
        let asset =
            validate_transform(&encoderfile_config, &model_config).expect("Failed to validate");
        assert!(
            asset.is_none(),
            "no asset should be planned when no transform is configured"
        );
    }
}