use anyhow::{Context, Result, bail};
use encoderfile::builder::cli::GlobalArguments;
use std::{
fs,
path::Path,
process::{Child, Command},
thread::sleep,
time::{Duration, Instant},
};
use tempfile::tempdir;
use encoderfile::common;
tonic::include_proto!("encoderfile.metadata");
use encoderfile::generated::token_classification;
const BINARY_NAME: &str = "test.encoderfile";
fn config(model_name: &String, model_path: &Path, output_path: &Path) -> String {
format!(
r##"
encoderfile:
name: {:?}
path: {:?}
model_type: token_classification
output_path: {:?}
lua_libs:
- coroutine
- table
- io
- os
- package
- string
- utf8
- math
transform: |
--- Applies a softmax across token classification logits.
--- Each token classification is normalized independently.
---
--- Args:
--- arr (Tensor): A tensor of shape [batch_size, n_tokens, n_labels].
--- The softmax is applied along the third axis (n_labels).
---
--- Returns:
--- Tensor: The input tensor with softmax-normalized embeddings.
---@param arr Tensor
---@return Tensor
p = package.path
function MyCoroutine()
return Tensor({{1, 2, 3}})
end
function Postprocess(arr)
local mycor = coroutine.create(MyCoroutine)
local _, tensor = coroutine.resume(mycor)
io.stderr:write("This is a test of the IO library\n")
local fp_values = {{}}
for point in utf8.codes("hello") do
table.insert(fp_values, point)
end
local t = os.time()
return arr:softmax(3)
end
"##,
model_name, model_path, output_path
)
}
const MODEL_ASSETS_PATH: &str = "../models/token_classification";
#[tokio::test]
async fn test_build_encoderfile() -> Result<()> {
let dir = tempdir()?;
let path = dir
.path()
.canonicalize()
.expect("Failed to canonicalize temp path");
let tmp_model_path = path.join("models").join("token_classification");
let ef_config_path = path.join("encoderfile.yml");
let encoderfile_path = path.join(BINARY_NAME);
let http_port = "8080";
let grpc_port = "9090";
let sample_text =
"Hugging Face is a technology company based in New York and Paris.".to_string();
copy_dir_all(MODEL_ASSETS_PATH, tmp_model_path.as_path())
.expect("Failed to copy model assets to temp directory");
if !tmp_model_path.join("model.onnx").exists() {
bail!(
"Path {:?} does not exist",
tmp_model_path.join("model.onnx")
);
}
let _ = Command::new("cargo")
.args(["build", "-p", "encoderfile-runtime"])
.status()
.expect("Failed to build encoderfile-runtime");
let base_binary_path = fs::canonicalize("../target/debug/encoderfile-runtime")
.expect("Failed to canonicalize base binary path");
let config = config(
&String::from("test-model"),
tmp_model_path.as_path(),
Path::new(".").join(BINARY_NAME).as_path(), );
fs::write(ef_config_path.as_path(), config.as_bytes())
.expect("Failed to write encoderfile config");
let dummy_dir = path.join("dummy");
fs::create_dir(&dummy_dir).expect("Failed to create dummy directory");
std::env::set_current_dir(&dummy_dir)
.expect("Failed to change current directory to dummy directory");
let build_args = encoderfile::builder::cli::test_build_args_working_dir(
ef_config_path.as_path(),
base_binary_path,
&path,
);
let global_args = GlobalArguments::default();
build_args
.run(&global_args)
.context("Failed to build encoderfile")?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(encoderfile_path.as_path())
.expect("Failed to get path for built encoderfile")
.permissions();
perms.set_mode(0o755);
fs::set_permissions(encoderfile_path.as_path(), perms).expect("Failed to set permissions");
}
let mut child = spawn_encoderfile(
encoderfile_path
.to_str()
.expect("Failed to create encoderfile binary path"),
http_port,
grpc_port,
)?;
wait_for_http(
format!("http://localhost:{http_port}/health").as_str(),
Duration::from_secs(10),
)
.await?;
send_http_inference(&sample_text, http_port.to_string()).await?;
send_grpc_inference(&sample_text, grpc_port.to_string()).await?;
child.kill()?;
child.wait().ok();
Ok(())
}
async fn wait_for_http(url: &str, timeout: Duration) -> Result<()> {
let client = reqwest::Client::new();
let start = Instant::now();
loop {
if start.elapsed() > timeout {
anyhow::bail!("server did not become ready in time");
}
if let Ok(resp) = client.get(url).send().await
&& resp.status().is_success()
{
return Ok(());
}
sleep(Duration::from_millis(200));
}
}
async fn send_http_inference(sample_text: &str, http_port: String) -> Result<()> {
let client = reqwest::Client::new();
let req = common::TokenClassificationRequest {
inputs: vec![sample_text.to_owned()],
metadata: None,
};
let res = client
.post(format!("http://localhost:{http_port}/predict"))
.json(&req)
.send()
.await?;
assert!(
res.status().is_success(),
"HTTP inference request failed with status: {}",
res.status()
);
Ok(())
}
async fn send_grpc_inference(sample_text: &str, grpc_port: String) -> Result<()> {
let mut client = token_classification::token_classification_inference_client::TokenClassificationInferenceClient::connect(format!("http://[::]:{grpc_port}/predict")).await?;
let req = token_classification::TokenClassificationRequest {
inputs: vec![sample_text.to_owned()],
metadata: std::collections::HashMap::new(),
};
client.predict(req).await?;
Ok(())
}
fn spawn_encoderfile(path: &str, http_port: &str, grpc_port: &str) -> Result<Child> {
Command::new(path)
.arg("serve")
.arg("--grpc-port")
.arg(grpc_port)
.arg("--http-port")
.arg(http_port)
.spawn()
.context("failed to spawn encoderfile process")
}
fn copy_dir_all(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> anyhow::Result<()> {
let src = src.as_ref();
let dst = dst.as_ref();
fs::create_dir_all(dst).context(format!("Failed to create directory {:?}", &dst))?;
for entry in fs::read_dir(src)? {
let entry = entry?;
let ty = entry.file_type()?;
let dest_path = dst.join(entry.file_name());
if ty.is_dir() {
copy_dir_all(entry.path(), dest_path.as_path()).context(format!(
"Failed to copy {:?} to {:?}",
entry.path(),
dest_path.as_path()
))?;
} else {
fs::copy(entry.path(), dest_path.as_path()).context(format!(
"Failed to copy {:?} to {:?}",
entry.path(),
dest_path.as_path()
))?;
}
}
Ok(())
}