use crate::error::{CliError, Result};
use aprender::autograd::Tensor;
use aprender::format::v2::AprV2Reader;
use aprender::models::bert::{BertConfig, BertEmbeddings, BertEncoder};
use std::collections::HashMap;
use std::path::Path;
/// Loads a token → id vocabulary, choosing the parser from the file extension.
///
/// A path whose extension equals `json` (ASCII case-insensitive) is handed to
/// `load_tokenizer_json`; every other path — including one with no extension or a
/// non-UTF-8 extension — is handed to `load_vocab_txt`.
///
/// # Errors
/// Propagates whatever error the selected loader returns.
fn load_vocab(path: &Path) -> Result<HashMap<String, u32>> {
    match path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) if ext.eq_ignore_ascii_case("json") => load_tokenizer_json(path),
        _ => load_vocab_txt(path),
    }
}
/// Loads a plain-text vocabulary with one token per line.
///
/// The token on line `n` (0-based) is assigned id `n`. Trailing whitespace is
/// stripped from each token. Blank lines still consume an id, and if a token
/// appears more than once the id of its last occurrence wins.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when the file cannot be read as UTF-8 text.
fn load_vocab_txt(path: &Path) -> Result<HashMap<String, u32>> {
    let contents = match std::fs::read_to_string(path) {
        Ok(contents) => contents,
        Err(e) => {
            return Err(CliError::ValidationFailed(format!(
                "Failed to read vocab {}: {e}",
                path.display()
            )))
        }
    };
    // Collecting into a map inserts pairs in order, so later duplicates overwrite
    // earlier ones exactly like repeated `insert` calls would.
    let vocab = contents
        .lines()
        .enumerate()
        .map(|(line_no, token)| (token.trim_end().to_string(), line_no as u32))
        .collect();
    Ok(vocab)
}
/// Loads a token → id vocabulary from a `tokenizer.json` file.
///
/// Two sections of the JSON document are read:
///
/// 1. `added_tokens` (optional array): every entry with a string `content` and an
///    unsigned-integer `id` contributes one mapping. Entries lacking either field
///    are skipped.
/// 2. `model.vocab` (required object): every key whose value is an unsigned
///    integer contributes one mapping. Other values are skipped.
///
/// `model.vocab` is applied last, so it overwrites an added token with the same
/// text.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when the file cannot be read, is not
/// valid JSON, has no object at `model.vocab`, or contains a token id that does
/// not fit in `u32` (previously such ids were silently truncated, which could
/// alias an unrelated token).
fn load_tokenizer_json(path: &Path) -> Result<HashMap<String, u32>> {
    let text = std::fs::read_to_string(path).map_err(|e| {
        CliError::ValidationFailed(format!(
            "Failed to read tokenizer.json {}: {e}",
            path.display()
        ))
    })?;
    let root: serde_json::Value = serde_json::from_str(&text).map_err(|e| {
        CliError::ValidationFailed(format!(
            "tokenizer.json {} is not valid JSON: {e}",
            path.display()
        ))
    })?;

    // Checked narrowing: JSON integers are read as u64, but ids are stored as u32.
    let checked_id = |token: &str, raw: u64| -> Result<u32> {
        u32::try_from(raw).map_err(|_| {
            CliError::ValidationFailed(format!(
                "tokenizer.json {}: token {token:?} has id {raw} which does not fit in u32",
                path.display()
            ))
        })
    };

    let mut map: HashMap<String, u32> = HashMap::new();

    if let Some(added) = root.get("added_tokens").and_then(|v| v.as_array()) {
        for entry in added {
            if let (Some(content), Some(raw)) = (
                entry.get("content").and_then(|v| v.as_str()),
                entry.get("id").and_then(|v| v.as_u64()),
            ) {
                map.insert(content.to_string(), checked_id(content, raw)?);
            }
        }
    }

    let vocab_obj = root
        .get("model")
        .and_then(|m| m.get("vocab"))
        .and_then(|v| v.as_object())
        .ok_or_else(|| {
            CliError::ValidationFailed(format!(
                "tokenizer.json {}: missing or non-object `model.vocab`",
                path.display()
            ))
        })?;
    for (token, id_val) in vocab_obj {
        if let Some(raw) = id_val.as_u64() {
            map.insert(token.to_string(), checked_id(token.as_str(), raw)?);
        }
    }
    Ok(map)
}
/// Tokenizes one text into BERT-style input ids plus token-type ids.
///
/// The vocabulary at `vocab_path` is loaded on every call and must contain the
/// `[CLS]`, `[SEP]` and `[UNK]` tokens. The returned `input_ids` are
/// `[CLS] + encode(text) + [SEP]`; the returned `token_type_ids` are all zero
/// and have the same length as `input_ids`.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when the vocabulary cannot be loaded,
/// a required special token is missing, or the tokenizer reports an error.
fn tokenize_single(text: &str, vocab_path: &Path) -> Result<(Vec<u32>, Vec<u32>)> {
    use aprender::text::tokenize::WordPieceTokenizer;

    let vocab = load_vocab(vocab_path)?;

    let Some(&cls_id) = vocab.get("[CLS]") else {
        return Err(CliError::ValidationFailed(format!(
            "vocab {} missing required token [CLS]",
            vocab_path.display()
        )));
    };
    let Some(&sep_id) = vocab.get("[SEP]") else {
        return Err(CliError::ValidationFailed(format!(
            "vocab {} missing required token [SEP]",
            vocab_path.display()
        )));
    };
    // The [UNK] id itself is not used here; only its presence is required.
    // NOTE(review): presumably the tokenizer relies on it for unknown words — confirm.
    if !vocab.contains_key("[UNK]") {
        return Err(CliError::ValidationFailed(format!(
            "vocab {} missing required token [UNK]",
            vocab_path.display()
        )));
    }

    let tokenizer = WordPieceTokenizer::from_vocab(vocab);
    let body_ids = match tokenizer.encode(text) {
        Ok(ids) => ids,
        Err(e) => {
            return Err(CliError::ValidationFailed(format!(
                "tokenisation failed: {e:?}"
            )))
        }
    };

    // Two extra slots for the surrounding [CLS] / [SEP] markers.
    let total_len = body_ids.len() + 2;
    let mut input_ids = Vec::with_capacity(total_len);
    input_ids.push(cls_id);
    input_ids.extend(&body_ids);
    input_ids.push(sep_id);

    let token_type_ids = vec![0u32; input_ids.len()];
    Ok((input_ids, token_type_ids))
}
/// Reduces encoder output to a single vector of length `hidden_dim`.
///
/// The first `seq_len * hidden_dim` values of `hidden.data()` are read as
/// `seq_len` consecutive rows of `hidden_dim` values each.
///
/// * `mode == "cls"` returns the first row.
/// * `mode == "mean"` returns the element-wise average of all `seq_len` rows.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when `seq_len` is zero, when the tensor
/// holds fewer than `seq_len * hidden_dim` values, or when `mode` is neither
/// `cls` nor `mean`.
fn pool(hidden: &Tensor, seq_len: usize, hidden_dim: usize, mode: &str) -> Result<Vec<f32>> {
    // With zero rows the size check below passes trivially; `cls` would then
    // slice past the data (panic) and `mean` would divide by zero (NaN).
    if seq_len == 0 {
        return Err(CliError::ValidationFailed(
            "cannot pool an empty sequence (seq_len is 0)".to_string(),
        ));
    }
    let data = hidden.data();
    let need = seq_len * hidden_dim;
    if data.len() < need {
        return Err(CliError::ValidationFailed(format!(
            "encoder output {} smaller than seq_len*hidden_dim {}",
            data.len(),
            need
        )));
    }
    match mode {
        "cls" => Ok(data[..hidden_dim].to_vec()),
        "mean" => {
            let mut acc = vec![0.0f32; hidden_dim];
            for t in 0..seq_len {
                let row = &data[t * hidden_dim..(t + 1) * hidden_dim];
                for (sum, &value) in acc.iter_mut().zip(row) {
                    *sum += value;
                }
            }
            let denom = seq_len as f32;
            for v in &mut acc {
                *v /= denom;
            }
            Ok(acc)
        }
        other => Err(CliError::ValidationFailed(format!(
            "--pool must be `cls` or `mean`, got {other:?}"
        ))),
    }
}
/// Scales `v` in place so that its Euclidean (L2) norm becomes 1.
///
/// The slice is left untouched when the computed norm is not strictly positive,
/// which covers the empty slice, the all-zero vector, and a NaN norm.
fn l2_normalize(v: &mut [f32]) {
    let sum_of_squares: f32 = v.iter().map(|&c| c * c).sum();
    let norm = sum_of_squares.sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|c| *c /= norm);
    }
}
/// Reads input texts from a file, one text per line.
///
/// Trailing whitespace is stripped from every line. Lines that are then empty,
/// or that begin with `#`, are dropped. Leading whitespace is preserved, so an
/// indented `#` line is kept as a text.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when the file cannot be read as UTF-8 text.
fn load_text_file(path: &Path) -> Result<Vec<String>> {
    let content = match std::fs::read_to_string(path) {
        Ok(content) => content,
        Err(e) => {
            return Err(CliError::ValidationFailed(format!(
                "Failed to read --text-file {}: {e}",
                path.display()
            )))
        }
    };
    let rows = content
        .lines()
        .map(str::trim_end)
        .filter(|row| !row.is_empty() && !row.starts_with('#'))
        .map(String::from)
        .collect();
    Ok(rows)
}
/// Embeds every input text with a BERT model stored in an APR v2 file and prints
/// the resulting vectors to stdout.
///
/// Inputs are the entries of `texts` followed by the rows read from `text_file`
/// (when given). Each text is tokenized against `vocab`, run through the
/// embeddings layer and the encoder, pooled according to `pool_mode`, and
/// L2-normalized when `normalize` is set.
///
/// With `json` set, one pretty-printed JSON document is printed containing the
/// model path, pooling mode, normalize flag and a `results` array of
/// `{text, embedding, dim}` objects. Otherwise one line per text is printed
/// showing the text, the vector length and its first four components.
///
/// `hidden_dim`, `num_layers`, `num_heads`, `intermediate_dim`, `vocab_size`,
/// `max_position_embeddings` and `type_vocab_size` are copied into the
/// `BertConfig`; `layer_norm_eps` and `pad_token_id` are fixed at `1e-12` and `0`.
///
/// NOTE(review): `tokenize_single` reloads the vocabulary file on each call, so
/// the vocab is parsed once per input text.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when there are no input texts, when the
/// text file or model file cannot be read, when the model cannot be parsed or
/// its weights cannot be loaded, or when tokenization or pooling fails.
#[allow(clippy::too_many_arguments)]
pub(crate) fn run(
    model: &Path,
    texts: &[String],
    text_file: Option<&Path>,
    vocab: &Path,
    pool_mode: &str,
    normalize: bool,
    hidden_dim: usize,
    num_layers: usize,
    num_heads: usize,
    intermediate_dim: usize,
    vocab_size: usize,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    json: bool,
) -> Result<()> {
    // Command-line texts come first, then the rows from the optional file.
    let file_rows = match text_file {
        Some(path) => load_text_file(path)?,
        None => Vec::new(),
    };
    let inputs: Vec<String> = texts.iter().cloned().chain(file_rows).collect();
    if inputs.is_empty() {
        return Err(CliError::ValidationFailed(
            "apr embed requires at least one --text or --text-file row".to_string(),
        ));
    }

    // Load and parse the model container.
    let model_bytes = std::fs::read(model).map_err(|e| {
        CliError::ValidationFailed(format!("Failed to read {}: {e}", model.display()))
    })?;
    let reader = AprV2Reader::from_bytes(&model_bytes).map_err(|e| {
        CliError::ValidationFailed(format!(
            "Failed to parse APR v2 at {}: {e:?}",
            model.display()
        ))
    })?;

    let config = BertConfig {
        hidden_dim,
        num_layers,
        num_heads,
        intermediate_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        layer_norm_eps: 1e-12,
        pad_token_id: 0,
    };

    // Construct both model halves, then fill them from the reader.
    let mut embeddings = BertEmbeddings::new(&config);
    embeddings
        .load_from_reader(&reader, &config)
        .map_err(|e| CliError::ValidationFailed(format!("BERT embeddings load failed: {e}")))?;
    let mut encoder = BertEncoder::new(&config);
    encoder
        .load_from_reader(&reader, &config)
        .map_err(|e| CliError::ValidationFailed(format!("BERT encoder load failed: {e}")))?;

    // One forward pass per input text.
    let mut results: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
    for sentence in &inputs {
        let (input_ids, token_type_ids) = tokenize_single(sentence, vocab)?;
        let seq_len = input_ids.len();
        let embedded = embeddings.forward(&input_ids, &token_type_ids);
        let hidden = encoder.forward(&embedded, None);
        let mut vector = pool(&hidden, seq_len, hidden_dim, pool_mode)?;
        if normalize {
            l2_normalize(&mut vector);
        }
        results.push((sentence.clone(), vector));
    }

    // Human-readable output: one summary line per text.
    if !json {
        for (text, v) in &results {
            let preview: Vec<String> = v.iter().take(4).map(|x| format!("{x:+.4}")).collect();
            println!(
                "text={text:?} dim={} preview=[{}, …]",
                v.len(),
                preview.join(", ")
            );
        }
        return Ok(());
    }

    // NOTE(review): the allow presumably covers something the `json!` macro
    // expands to that is on the project's clippy disallowed list — confirm.
    #[allow(clippy::disallowed_methods)]
    {
        let mut array: Vec<serde_json::Value> = Vec::with_capacity(results.len());
        for (t, v) in &results {
            array.push(serde_json::json!({
                "text": t,
                "embedding": v,
                "dim": v.len(),
            }));
        }
        let payload = serde_json::json!({
            "model": model.display().to_string(),
            "pool": pool_mode,
            "normalize": normalize,
            "results": array,
        });
        // A serialization failure falls back to printing an empty line.
        println!(
            "{}",
            serde_json::to_string_pretty(&payload).unwrap_or_default()
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Unit tests for pooling, L2 normalization and text-file loading.

    use super::*;
    use aprender::autograd::Tensor;

    /// Creates a fresh temporary directory and writes `body` to `name` inside it.
    ///
    /// The directory guard is returned alongside the path so the caller keeps the
    /// directory alive for as long as the file is needed.
    fn temp_file_with(name: &str, body: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join(name);
        std::fs::write(&path, body).unwrap();
        (dir, path)
    }

    /// `cls` pooling returns exactly the first row.
    #[test]
    fn pool_cls_returns_first_token() {
        let (seq_len, hidden_dim) = (3usize, 4usize);
        let values: Vec<f32> = (0..(seq_len * hidden_dim)).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(values, &[1, seq_len, hidden_dim]);
        let pooled = pool(&tensor, seq_len, hidden_dim, "cls").unwrap();
        assert_eq!(pooled, vec![0.0f32, 1.0, 2.0, 3.0]);
    }

    /// `mean` pooling averages each column over all rows.
    #[test]
    fn pool_mean_averages_all_tokens() {
        let (seq_len, hidden_dim) = (3usize, 2usize);
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(values, &[1, seq_len, hidden_dim]);
        let pooled = pool(&tensor, seq_len, hidden_dim, "mean").unwrap();
        assert_eq!(pooled, vec![3.0f32, 4.0]);
    }

    /// Any mode other than `cls` / `mean` is rejected with a descriptive message.
    #[test]
    fn pool_rejects_unknown_mode() {
        let tensor = Tensor::from_vec(vec![0.0f32; 4], &[1, 2, 2]);
        let err = pool(&tensor, 2, 2, "max").expect_err("max not yet supported");
        let CliError::ValidationFailed(msg) = err else {
            panic!("expected ValidationFailed");
        };
        assert!(msg.contains("`cls` or `mean`"), "{msg}");
    }

    /// A 3-4-5 triangle normalizes to (0.6, 0.8) with unit length.
    #[test]
    fn l2_normalize_produces_unit_norm() {
        let mut v = vec![3.0f32, 4.0];
        l2_normalize(&mut v);
        let squared: f32 = v.iter().map(|x| x * x).sum::<f32>();
        let n = squared.sqrt();
        assert!((n - 1.0).abs() < 1e-6);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }

    /// The zero vector is left unchanged instead of being divided by zero.
    #[test]
    fn l2_normalize_handles_zero_vector() {
        let mut zeros = vec![0.0f32; 4];
        l2_normalize(&mut zeros);
        assert_eq!(zeros, vec![0.0f32; 4]);
    }

    /// Comments and blank lines are skipped; trailing whitespace is trimmed.
    #[test]
    fn load_text_file_reads_one_per_line() {
        let (_dir, path) = temp_file_with(
            "docs.txt",
            "first document\n\
             second document\n\
             # this is a comment\n\
             \n\
             third document \n",
        );
        let texts = load_text_file(&path).expect("load");
        let expected: Vec<String> = ["first document", "second document", "third document"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        assert_eq!(texts, expected);
    }

    /// An empty file yields no texts.
    #[test]
    fn load_text_file_empty_returns_empty() {
        let (_dir, path) = temp_file_with("empty.txt", "");
        let texts = load_text_file(&path).expect("load");
        assert!(texts.is_empty());
    }

    /// A file holding only comments and blank lines yields no texts.
    #[test]
    fn load_text_file_only_comments_returns_empty() {
        let (_dir, path) = temp_file_with(
            "comments.txt",
            "# header comment\n\
             \n\
             # another\n\
             \n",
        );
        let texts = load_text_file(&path).expect("load");
        assert!(texts.is_empty());
    }

    /// A missing file produces a `ValidationFailed` error that names the path.
    #[test]
    fn load_text_file_missing_path_errors_with_path() {
        let dir = tempfile::tempdir().expect("tempdir");
        let missing = dir.path().join("does-not-exist.txt");
        let err = load_text_file(&missing).expect_err("missing path must error");
        let CliError::ValidationFailed(msg) = err else {
            panic!("expected ValidationFailed");
        };
        assert!(msg.contains("does-not-exist.txt"), "{msg}");
    }
}