use super::{
TransformValidatorExt,
utils::{
BATCH_SIZE, HIDDEN_DIM, SEQ_LEN, create_dummy_attention_mask, random_tensor,
validation_err, validation_err_ctx,
},
};
use crate::{
common::ModelConfig,
transforms::{Postprocessor, SentenceEmbeddingTransform},
};
use anyhow::{Context, Result};
impl TransformValidatorExt for SentenceEmbeddingTransform {
    /// Dry-runs the postprocessing transform against randomly generated dummy
    /// inputs to catch misbehaving user transforms at build time.
    ///
    /// Feeds the transform hidden states of shape
    /// `[BATCH_SIZE, SEQ_LEN, HIDDEN_DIM]` sampled uniformly from `-1.0..1.0`,
    /// together with a dummy attention mask (the third argument `3` is
    /// presumably a padding length — confirm against
    /// `create_dummy_attention_mask`).
    ///
    /// # Errors
    /// Returns an error if postprocessing itself fails, or if the output is
    /// not a rank-2 tensor of shape `[BATCH_SIZE, dim]` with `dim >= 1`.
    fn dry_run(&self, _model_config: &ModelConfig) -> Result<()> {
        let dummy_hidden_states = random_tensor(&[BATCH_SIZE, SEQ_LEN, HIDDEN_DIM], (-1.0, 1.0))?;
        let dummy_attention_mask = create_dummy_attention_mask(BATCH_SIZE, SEQ_LEN, 3)?;
        // Capture the input shape before ownership moves into `postprocess`,
        // so the error context can report it.
        let shape = dummy_hidden_states.shape().to_owned();
        let res = self
            .postprocess((dummy_hidden_states, dummy_attention_mask))
            .with_context(|| {
                validation_err_ctx(
                    format!(
                        "Failed to run postprocessing on dummy hidden states (randomly generated in range -1.0..1.0) of shape {:?}",
                        shape.as_slice(),
                    )
                )
            })?;
        // A sentence-embedding transform must pool the sequence dimension
        // away, producing a [batch, embedding_dim] matrix.
        // NOTE: the check is ndim != 2; the message previously said "rank 3",
        // contradicting the condition — fixed to "rank 2".
        if res.ndim() != 2 {
            validation_err(format!(
                "Transform must return tensor of rank 2. Got tensor of shape {:?}.",
                res.shape()
            ))?
        }
        if res.shape()[0] != BATCH_SIZE {
            validation_err(format!(
                "Transform must preserve batch size [{}, *]. Got shape {:?}",
                BATCH_SIZE,
                res.shape()
            ))?
        }
        // For a rank-2 result, index 1 is the last (embedding) dimension.
        if res.shape()[1] < 1 {
            validation_err(format!(
                "Transform returned a tensor with last dimension 0. Shape: {:?}",
                res.shape()
            ))?
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::config::{EncoderfileConfig, ModelPath};
    use crate::common::ModelType;
    use crate::transforms::DEFAULT_LIBS;

    /// Minimal builder config pointing at the bundled sentence-embedding
    /// fixture model, with transform validation enabled.
    fn test_encoderfile_config() -> EncoderfileConfig {
        let model_dir = std::path::PathBuf::from("models/sentence_embedding");
        EncoderfileConfig {
            name: "my-model".to_string(),
            version: "0.0.1".to_string(),
            path: ModelPath::Directory(model_dir),
            model_type: ModelType::SentenceEmbedding,
            cache_dir: None,
            output_path: None,
            transform: None,
            validate_transform: true,
            lua_libs: None,
            tokenizer: None,
            base_binary_path: None,
            target: None,
        }
    }

    /// Parses the fixture model's `config.json`, baked in at compile time.
    fn test_model_config() -> ModelConfig {
        let raw = include_str!("../../../../../models/sentence_embedding/config.json");
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn test_successful_mean_pool() {
        // A well-behaved transform that pools to [batch, dim] must validate.
        let script = "function Postprocess(arr, mask) return arr:mean_pool(mask) end";
        let transform =
            SentenceEmbeddingTransform::new(DEFAULT_LIBS.to_vec(), Some(script.to_string()))
                .expect("Failed to create transform");
        transform
            .validate(&test_encoderfile_config(), &test_model_config())
            .expect("Failed to validate");
    }

    #[test]
    fn test_bad_return_type() {
        // Returning a plain number instead of a tensor must fail validation.
        let script = "function Postprocess(arr, mask) return 1 end";
        let transform =
            SentenceEmbeddingTransform::new(DEFAULT_LIBS.to_vec(), Some(script.to_string()))
                .expect("Failed to create transform");
        let outcome = transform.validate(&test_encoderfile_config(), &test_model_config());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_bad_dimensionality() {
        // Passing the rank-3 hidden states through unpooled must fail.
        let script = "function Postprocess(arr, mask) return arr end";
        let transform =
            SentenceEmbeddingTransform::new(DEFAULT_LIBS.to_vec(), Some(script.to_string()))
                .expect("Failed to create transform");
        let outcome = transform.validate(&test_encoderfile_config(), &test_model_config());
        assert!(outcome.is_err());
    }
}