use anyhow::{Context, Result, bail};
use encoderfile::builder::cli::{GlobalArguments, inspect_encoderfile};
use std::{
fs,
path::Path,
process::{Command, Output},
};
use tempfile::tempdir;
/// File name of the encoderfile produced by the test's build step. It is joined
/// onto the per-test temporary directory, used as the config's `output_path`,
/// and then passed to both the library and CLI `inspect` paths.
const BINARY_NAME: &str = "test.encoderfile";
/// Renders the `encoderfile.yml` consumed by the builder in the inspect test.
///
/// `model_name`, `model_path` and `output_path` are interpolated with `{:?}`,
/// i.e. as double-quoted, escaped strings, which are valid YAML double-quoted
/// scalars for ordinary names and paths. The config selects
/// `model_type: token_classification` and embeds a Lua `Postprocess` transform
/// that applies `softmax(3)` to its input.
///
/// Indentation inside the template is significant:
/// - the keys must be nested under `encoderfile:`;
/// - the `transform: |` block-scalar body must be indented. Otherwise its
///   leading `---` lines would be parsed as YAML document markers.
///
/// Accepts `&str`; existing `&String` arguments still work via deref coercion.
fn config(model_name: &str, model_path: &Path, output_path: &Path) -> String {
    format!(
        r##"
encoderfile:
  name: {:?}
  path: {:?}
  model_type: token_classification
  output_path: {:?}
  transform: |
    --- Applies a softmax across token classification logits.
    --- Each token classification is normalized independently.
    ---
    --- Args:
    --- arr (Tensor): A tensor of shape [batch_size, n_tokens, n_labels].
    --- The softmax is applied along the third axis (n_labels).
    ---
    --- Returns:
    --- Tensor: The input tensor with softmax-normalized embeddings.
    ---@param arr Tensor
    ---@return Tensor
    function Postprocess(arr)
      return arr:softmax(3)
    end
"##,
        model_name, model_path, output_path
    )
}
/// Directory holding the token-classification model assets. The test copies
/// it into its temporary directory before building.
/// NOTE(review): this is a relative path, resolved against the test process's
/// working directory (presumably the package root under `cargo test` — confirm).
const MODEL_ASSETS_PATH: &str = "../models/token_classification";
/// End-to-end check of encoderfile inspection.
///
/// Steps:
/// 1. Copy the token-classification model assets into a temporary directory.
/// 2. Run `cargo build`.
/// 3. Package an encoderfile with the builder API, using
///    `../target/debug/encoderfile-runtime` as the base binary.
/// 4. Inspect the result through the library ([`inspect_encoderfile`]) and
///    through the `encoderfile inspect` CLI.
///
/// The CLI must exit successfully and print JSON whose
/// `encoderfile_config.name` equals the configured model name.
#[test]
fn test_inspect_encoderfile() -> Result<()> {
    // `dir` must outlive everything below: dropping the TempDir deletes it.
    let dir = tempdir()?;
    let path = dir
        .path()
        .canonicalize()
        .context("Failed to canonicalize temp path")?;
    let tmp_model_path = path.join("models").join("token_classification");
    let ef_config_path = path.join("encoderfile.yml");
    let encoderfile_path = path.join(BINARY_NAME);
    let model_name = String::from("some-custom-name");

    copy_dir_all(MODEL_ASSETS_PATH, tmp_model_path.as_path())
        .context("Failed to copy model assets to temp directory")?;
    let model_file = tmp_model_path.join("model.onnx");
    if !model_file.exists() {
        bail!("Path {:?} does not exist", model_file);
    }

    // A failed build must fail the test. Otherwise stale (or missing)
    // artifacts under ../target/debug would silently be used below.
    let build_status = Command::new("cargo")
        .args(["build"])
        .status()
        .context("Failed to build encoderfile-runtime")?;
    if !build_status.success() {
        bail!("`cargo build` failed with {}", build_status);
    }

    let base_binary_path = fs::canonicalize("../target/debug/encoderfile-runtime")
        .context("Failed to canonicalize base binary path")?;
    let ef_binary_path = fs::canonicalize("../target/debug/encoderfile")
        .context("Failed to canonicalize encoderfile binary path")?;

    let config_yaml = config(
        &model_name,
        tmp_model_path.as_path(),
        encoderfile_path.as_path(),
    );
    fs::write(ef_config_path.as_path(), config_yaml.as_bytes())
        .context("Failed to write encoderfile config")?;

    let build_args =
        encoderfile::builder::cli::test_build_args(ef_config_path.as_path(), base_binary_path);
    let global_args = GlobalArguments::default();
    build_args
        .run(&global_args)
        .context("Failed to build encoderfile")?;

    let ef_path_str = encoderfile_path
        .to_str()
        .context("Encoderfile path name failed to convert to string")?
        .to_owned();

    // Library-level inspection must succeed; its value is not compared here.
    let _inspect_output = inspect_encoderfile(&ef_path_str)?;

    let ef_binary_str = ef_binary_path
        .to_str()
        .context("Failed to create encoderfile binary path")?;
    let output = run_inspect_encoderfile(ef_binary_str, &ef_path_str)?;
    let stdout = String::from_utf8(output.stdout)?;
    let stderr = String::from_utf8(output.stderr)?;
    println!("STDOUT: {}", stdout);
    println!("STDERR: {}", stderr);
    // Check the exit status before parsing, so a CLI failure reports its real
    // cause rather than a JSON parse error on empty stdout.
    if !output.status.success() {
        bail!(
            "`encoderfile inspect` exited with {}: {}",
            output.status,
            stderr
        );
    }

    let inspect_output_json = serde_json::from_str::<serde_json::Value>(&stdout)
        .context("Failed to parse inspect output as JSON")?;
    let reported_name = inspect_output_json
        .get("encoderfile_config")
        .and_then(|efc| efc.get("name"))
        .and_then(|name| name.as_str());
    if reported_name != Some(model_name.as_str()) {
        bail!(
            "Model name in inspect output does not match expected: expected {:?}, got {:?}",
            model_name,
            reported_name
        );
    }
    Ok(())
}
fn copy_dir_all(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> anyhow::Result<()> {
let src = src.as_ref();
let dst = dst.as_ref();
fs::create_dir_all(dst).context(format!("Failed to create directory {:?}", &dst))?;
for entry in fs::read_dir(src)? {
let entry = entry?;
let ty = entry.file_type()?;
let dest_path = dst.join(entry.file_name());
if ty.is_dir() {
copy_dir_all(entry.path(), dest_path.as_path()).context(format!(
"Failed to copy {:?} to {:?}",
entry.path(),
dest_path.as_path()
))?;
} else {
fs::copy(entry.path(), dest_path.as_path()).context(format!(
"Failed to copy {:?} to {:?}",
entry.path(),
dest_path.as_path()
))?;
}
}
Ok(())
}
/// Runs the `encoderfile` CLI located at `path` as `<path> inspect <ef_path>`.
///
/// The configured command is echoed to stdout in its `Debug` form before it
/// runs. The exit status, stdout and stderr are captured and returned as-is;
/// whether the command succeeded is left for the caller to check.
///
/// # Errors
/// Fails with "Failed inspect command" if the process cannot be run.
fn run_inspect_encoderfile(path: &str, ef_path: &str) -> Result<Output> {
    let mut inspect = Command::new(path);
    inspect.args(["inspect", ef_path]);
    println!("{:?}", inspect);
    let captured = inspect.output().context("Failed inspect command")?;
    Ok(captured)
}