use crate::error::{CliError, Result};
use aprender::format::v2::AprV2Reader;
use aprender::models::bert::{BertConfig, CrossEncoder};
use std::collections::HashMap;
use std::path::Path;
/// Parses a comma-separated list of unsigned 32-bit ids.
///
/// Each piece is trimmed of surrounding whitespace and empty pieces (for example the one
/// produced by a trailing comma) are skipped. `flag` is the CLI flag name without the leading
/// dashes; it is only used to build the error message.
///
/// # Errors
/// Returns `CliError::ValidationFailed` for the first piece that does not parse as a `u32`.
fn parse_id_list(s: &str, flag: &str) -> Result<Vec<u32>> {
    let mut ids = Vec::new();
    for piece in s.split(',') {
        let t = piece.trim();
        if t.is_empty() {
            continue;
        }
        match t.parse::<u32>() {
            Ok(id) => ids.push(id),
            Err(e) => {
                return Err(CliError::ValidationFailed(format!(
                    "--{flag}: invalid u32 token {t:?}: {e}"
                )));
            }
        }
    }
    Ok(ids)
}
/// Loads a one-token-per-line vocabulary file into a token → id map.
///
/// The id of a token is its zero-based line index. Trailing whitespace is stripped from every
/// line; if the same token occurs on several lines the last occurrence determines its id.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when the file cannot be read as UTF-8 text.
fn load_vocab_txt(path: &Path) -> Result<HashMap<String, u32>> {
    let contents = match std::fs::read_to_string(path) {
        Ok(contents) => contents,
        Err(e) => {
            return Err(CliError::ValidationFailed(format!(
                "Failed to read vocab {}: {e}",
                path.display()
            )));
        }
    };
    let vocab: HashMap<String, u32> = contents
        .lines()
        .enumerate()
        .map(|(line_no, raw)| (raw.trim_end().to_string(), line_no as u32))
        .collect();
    Ok(vocab)
}
/// Loads a token → id map from a `tokenizer.json` file.
///
/// Entries of the optional top-level `added_tokens` array are inserted first, followed by
/// every entry of the `model.vocab` object. When a token appears in both places the
/// `model.vocab` id wins because it is inserted last.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when:
/// - the file cannot be read or is not valid JSON,
/// - an `added_tokens` entry lacks a string `content` or an unsigned integer `id`,
/// - `model.vocab` is missing or is not an object,
/// - a `model.vocab` value is not an unsigned integer,
/// - any id does not fit in a `u32` (previously such ids were silently truncated).
fn load_tokenizer_json(path: &Path) -> Result<HashMap<String, u32>> {
    let text = std::fs::read_to_string(path).map_err(|e| {
        CliError::ValidationFailed(format!(
            "Failed to read tokenizer.json {}: {e}",
            path.display()
        ))
    })?;
    let root: serde_json::Value = serde_json::from_str(&text).map_err(|e| {
        CliError::ValidationFailed(format!(
            "tokenizer.json {} is not valid JSON: {e}",
            path.display()
        ))
    })?;

    // JSON integers are read as u64 but ids are stored as u32. Reject out-of-range values
    // instead of wrapping them with `as`, which would map a token to an unrelated id.
    let narrow_id = |token: &str, raw: u64| -> Result<u32> {
        u32::try_from(raw).map_err(|_| {
            CliError::ValidationFailed(format!(
                "tokenizer.json {}: token {token:?} has id {raw} which does not fit in u32",
                path.display()
            ))
        })
    };

    let mut map: HashMap<String, u32> = HashMap::new();

    // `added_tokens` is optional: a missing or non-array value is simply skipped.
    if let Some(added) = root.get("added_tokens").and_then(|v| v.as_array()) {
        for entry in added {
            let content = entry
                .get("content")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    CliError::ValidationFailed(format!(
                        "tokenizer.json {}: added_tokens entry missing `content`",
                        path.display()
                    ))
                })?;
            let id = entry.get("id").and_then(|v| v.as_u64()).ok_or_else(|| {
                CliError::ValidationFailed(format!(
                    "tokenizer.json {}: added_tokens entry missing `id`",
                    path.display()
                ))
            })?;
            map.insert(content.to_string(), narrow_id(content, id)?);
        }
    }

    // `model.vocab` is mandatory and must be a JSON object of token → integer id.
    let vocab_obj = root
        .get("model")
        .and_then(|m| m.get("vocab"))
        .and_then(|v| v.as_object())
        .ok_or_else(|| {
            CliError::ValidationFailed(format!(
                "tokenizer.json {}: missing or non-object `model.vocab` \
                 (only WordPiece-style tokenizer.json is supported)",
                path.display()
            ))
        })?;
    for (token, id_val) in vocab_obj {
        let id = id_val.as_u64().ok_or_else(|| {
            CliError::ValidationFailed(format!(
                "tokenizer.json {}: model.vocab entry {token:?} has non-integer id",
                path.display()
            ))
        })?;
        map.insert(token.to_string(), narrow_id(token, id)?);
    }
    Ok(map)
}
/// Loads a vocabulary, choosing the parser from the file extension.
///
/// A path whose extension is `json` (compared ASCII case-insensitively) is read with
/// `load_tokenizer_json`; every other path, including one without an extension, is read
/// with `load_vocab_txt`.
fn load_vocab(path: &Path) -> Result<HashMap<String, u32>> {
    match path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) if ext.eq_ignore_ascii_case("json") => load_tokenizer_json(path),
        _ => load_vocab_txt(path),
    }
}
/// Tokenises a (query, passage) pair into model inputs.
///
/// The vocabulary is loaded from `vocab_path` on every call. The returned tuple is
/// `(input_ids, token_type_ids)` where `input_ids` is laid out as
/// `[CLS] query… [SEP] passage… [SEP]` and `token_type_ids` is `0` for the
/// `[CLS] query… [SEP]` span and `1` for the `passage… [SEP]` span. Both vectors have the
/// same length.
///
/// # Errors
/// Returns `CliError::ValidationFailed` when the vocabulary cannot be loaded, when it lacks
/// `[CLS]`, `[SEP]` or `[UNK]` (checked in that order), or when encoding either text fails.
fn tokenize_query_passage(
    query: &str,
    passage: &str,
    vocab_path: &Path,
) -> Result<(Vec<u32>, Vec<u32>)> {
    use aprender::text::tokenize::WordPieceTokenizer;

    let vocab = load_vocab(vocab_path)?;

    // The special tokens are looked up before the map is handed over to the tokenizer.
    let Some(&cls_id) = vocab.get("[CLS]") else {
        return Err(CliError::ValidationFailed(format!(
            "vocab {} missing required token [CLS]",
            vocab_path.display()
        )));
    };
    let Some(&sep_id) = vocab.get("[SEP]") else {
        return Err(CliError::ValidationFailed(format!(
            "vocab {} missing required token [SEP]",
            vocab_path.display()
        )));
    };
    if !vocab.contains_key("[UNK]") {
        return Err(CliError::ValidationFailed(format!(
            "vocab {} missing required token [UNK]",
            vocab_path.display()
        )));
    }

    let tokenizer = WordPieceTokenizer::from_vocab(vocab);
    let q_ids = match tokenizer.encode(query) {
        Ok(ids) => ids,
        Err(e) => {
            return Err(CliError::ValidationFailed(format!(
                "query tokenisation failed: {e:?}"
            )));
        }
    };
    let p_ids = match tokenizer.encode(passage) {
        Ok(ids) => ids,
        Err(e) => {
            return Err(CliError::ValidationFailed(format!(
                "passage tokenisation failed: {e:?}"
            )));
        }
    };

    // Segment A covers `[CLS] query… [SEP]`, segment B covers `passage… [SEP]`.
    let segment_a_len = 1 + q_ids.len() + 1;
    let segment_b_len = p_ids.len() + 1;

    let input_ids: Vec<u32> = std::iter::once(cls_id)
        .chain(q_ids.iter().copied())
        .chain(std::iter::once(sep_id))
        .chain(p_ids.iter().copied())
        .chain(std::iter::once(sep_id))
        .collect();

    let mut token_type_ids = vec![0u32; segment_a_len];
    token_type_ids.resize(segment_a_len + segment_b_len, 1u32);

    Ok((input_ids, token_type_ids))
}
/// Entry point of the rerank command.
///
/// Reads the model file, builds a `CrossEncoder` from the supplied BERT hyper-parameters and
/// loads its weights, then either:
/// - delegates to `run_batch` when `passages` is non-empty, or
/// - scores a single input given as (`--input-ids`, `--token-type-ids`) or as
///   (`--query`, `--passage`, `--vocab`).
///
/// For the single-input case every logit is printed, either raw (`raw_logit`) or passed
/// through a sigmoid, as plain text lines or as a pretty-printed JSON object (`json`).
///
/// # Errors
/// Returns `CliError::ValidationFailed` when the model cannot be read or parsed, when weight
/// loading fails, when the flag combination is not one of the supported ones, when the id
/// list is empty, or when the two id lists differ in length.
#[allow(clippy::too_many_arguments)]
pub(crate) fn run(
    model: &Path,
    input_ids_str: Option<&str>,
    token_type_ids_str: Option<&str>,
    query: Option<&str>,
    passage: Option<&str>,
    passages: &[String],
    sort: bool,
    top_k: usize,
    vocab: Option<&Path>,
    hidden_dim: usize,
    num_layers: usize,
    num_heads: usize,
    intermediate_dim: usize,
    vocab_size: usize,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    num_labels: usize,
    with_pooler: bool,
    raw_logit: bool,
    json: bool,
) -> Result<()> {
    // Load the model container from disk and parse it.
    let model_bytes = match std::fs::read(model) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(CliError::ValidationFailed(format!(
                "Failed to read {}: {e}",
                model.display()
            )));
        }
    };
    let reader = match AprV2Reader::from_bytes(&model_bytes) {
        Ok(reader) => reader,
        Err(e) => {
            return Err(CliError::ValidationFailed(format!(
                "Failed to parse APR v2 at {}: {e:?}",
                model.display()
            )));
        }
    };

    // Build the encoder from the CLI-provided architecture and populate its weights.
    let config = BertConfig {
        hidden_dim,
        num_layers,
        num_heads,
        intermediate_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        layer_norm_eps: 1e-12,
        pad_token_id: 0,
    };
    let mut cross_encoder = CrossEncoder::new(&config, num_labels, with_pooler);
    if let Err(e) = cross_encoder.load_from_reader(&reader, &config) {
        return Err(CliError::ValidationFailed(format!(
            "BERT weight loading failed: {e}"
        )));
    }

    // Multi-passage mode takes precedence over every single-input flag.
    if !passages.is_empty() {
        return run_batch(
            &cross_encoder,
            model,
            query,
            passages,
            vocab,
            sort,
            top_k,
            raw_logit,
            json,
        );
    }

    // Exactly one of the two single-input flag sets must be present, with nothing from the
    // other set mixed in.
    let (input_ids, token_type_ids) =
        match (input_ids_str, token_type_ids_str, query, passage, vocab) {
            (Some(raw_ids), Some(raw_types), None, None, None) => (
                parse_id_list(raw_ids, "input-ids")?,
                parse_id_list(raw_types, "token-type-ids")?,
            ),
            (None, None, Some(query_text), Some(passage_text), Some(vocab_path)) => {
                tokenize_query_passage(query_text, passage_text, vocab_path)?
            }
            _ => {
                return Err(CliError::ValidationFailed(
                    "apr rerank requires EITHER \
                     (--input-ids AND --token-type-ids) \
                     OR (--query AND --passage AND --vocab) \
                     OR (--query AND --passages... AND --vocab)"
                        .to_string(),
                ));
            }
        };

    if input_ids.is_empty() {
        return Err(CliError::ValidationFailed(
            "--input-ids must be non-empty".to_string(),
        ));
    }
    if input_ids.len() != token_type_ids.len() {
        return Err(CliError::ValidationFailed(format!(
            "--input-ids ({}) and --token-type-ids ({}) must have the same length",
            input_ids.len(),
            token_type_ids.len()
        )));
    }

    let logit_tensor = cross_encoder.forward(&input_ids, &token_type_ids);
    let logits: &[f32] = logit_tensor.data();

    // Logistic function used to turn a logit into a score.
    let sigmoid = |l: f32| 1.0 / (1.0 + (-l).exp());

    if json {
        #[allow(clippy::disallowed_methods)]
        {
            let payload = match raw_logit {
                true => serde_json::json!({
                    "model": model.display().to_string(),
                    "input_ids": input_ids,
                    "token_type_ids": token_type_ids,
                    "logits": logits,
                }),
                false => {
                    let probs: Vec<f32> = logits.iter().copied().map(sigmoid).collect();
                    serde_json::json!({
                        "model": model.display().to_string(),
                        "input_ids": input_ids,
                        "token_type_ids": token_type_ids,
                        "scores": probs,
                    })
                }
            };
            println!(
                "{}",
                serde_json::to_string_pretty(&payload).unwrap_or_default()
            );
        }
        return Ok(());
    }

    // Plain-text output: one line per logit.
    for (i, &l) in logits.iter().enumerate() {
        if raw_logit {
            println!("logit[{i}] = {l:.6}");
        } else {
            let score = sigmoid(l);
            println!("score[{i}] = {score:.6}");
        }
    }
    Ok(())
}
/// Scores `query` against every entry of `passages` and prints the results.
///
/// Each passage is tokenised with `tokenize_query_passage` and run through `cross_encoder`;
/// the first logit of the output is kept together with its sigmoid score. Results are ranked
/// in descending order when `sort` is set or `top_k > 0`, and truncated to `top_k` entries when
/// `top_k > 0`. Output is either a pretty-printed JSON object (`json`) or one text line per
/// result showing the raw logit (`raw_logit`) or the score. The index shown is always the
/// passage's position in `passages`.
///
/// Ranking compares logits rather than sigmoid scores: the sigmoid is monotonic, so the order
/// is the same, but in `f32` it saturates to exactly `1.0`/`0.0` for large magnitudes and would
/// turn distinct logits into ties. NaN logits are ranked last. The comparator is a total
/// order, which `sort_by` requires (a non-total comparator may panic or misorder).
///
/// # Errors
/// Returns `CliError::ValidationFailed` when `query` or `vocab` is missing, when tokenisation
/// fails, or when the model returns an empty logit tensor for a passage.
#[allow(clippy::too_many_arguments)]
fn run_batch(
    cross_encoder: &CrossEncoder,
    model: &Path,
    query: Option<&str>,
    passages: &[String],
    vocab: Option<&Path>,
    sort: bool,
    top_k: usize,
    raw_logit: bool,
    json: bool,
) -> Result<()> {
    let (Some(query), Some(vocab)) = (query, vocab) else {
        return Err(CliError::ValidationFailed(
            "--passages requires both --query and --vocab".to_string(),
        ));
    };

    // (original passage index, logit, sigmoid score)
    let mut scored: Vec<(usize, f32, f32)> = Vec::with_capacity(passages.len());
    for (i, p) in passages.iter().enumerate() {
        // NOTE: `tokenize_query_passage` reloads the vocabulary file on every call.
        let (input_ids, token_type_ids) = tokenize_query_passage(query, p, vocab)?;
        let logit_tensor = cross_encoder.forward(&input_ids, &token_type_ids);
        // Report an empty output as an error instead of panicking on `[0]`.
        let logit = logit_tensor.data().first().copied().ok_or_else(|| {
            CliError::ValidationFailed(format!("model produced no logits for passage {i}"))
        })?;
        let score = 1.0 / (1.0 + (-logit).exp());
        scored.push((i, logit, score));
    }

    let do_sort = sort || top_k > 0;
    if do_sort {
        // Descending by logit, NaN last. The sort is stable, so equal logits keep input order.
        scored.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (false, false) => b.1.total_cmp(&a.1),
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        });
    }

    let limit = if top_k > 0 {
        top_k.min(scored.len())
    } else {
        scored.len()
    };
    let scored = &scored[..limit];

    if json {
        #[allow(clippy::disallowed_methods)]
        {
            let array: Vec<serde_json::Value> = scored
                .iter()
                .map(|&(i, logit, score)| {
                    serde_json::json!({
                        "index": i,
                        "passage": passages[i],
                        "logit": logit,
                        "score": score,
                    })
                })
                .collect();
            let payload = serde_json::json!({
                "model": model.display().to_string(),
                "query": query,
                "num_passages": passages.len(),
                "returned": scored.len(),
                "sorted": do_sort,
                "results": array,
            });
            println!(
                "{}",
                serde_json::to_string_pretty(&payload).unwrap_or_default()
            );
        }
        return Ok(());
    }

    if raw_logit {
        for (i, logit, _score) in scored {
            println!("logit[{i}] = {logit:.6}");
        }
    } else {
        for (i, _logit, score) in scored {
            println!("score[{i}] = {score:.6}");
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to `vocab.txt` inside a fresh temporary directory.
    ///
    /// The directory handle is returned alongside the path so the caller keeps it alive for
    /// as long as the file is needed.
    fn write_vocab(contents: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("vocab.txt");
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    /// Plain and whitespace-padded comma lists both parse.
    #[test]
    fn parse_id_list_accepts_commas_and_spaces() {
        let compact = parse_id_list("1,2,3", "input-ids").unwrap();
        assert_eq!(compact, vec![1u32, 2, 3]);

        let padded = parse_id_list(" 101, 2024, 102 ", "input-ids").unwrap();
        assert_eq!(padded, vec![101u32, 2024, 102]);
    }

    /// A non-numeric piece is rejected and the message names both the flag and the piece.
    #[test]
    fn parse_id_list_rejects_invalid_token() {
        let err = parse_id_list("1,xx,3", "input-ids").expect_err("xx must reject");
        let CliError::ValidationFailed(msg) = err else {
            panic!("expected ValidationFailed");
        };
        assert!(msg.contains("input-ids"));
        assert!(msg.contains("xx"));
    }

    /// The empty piece after a trailing comma is ignored.
    #[test]
    fn parse_id_list_skips_empty_tokens_from_trailing_comma() {
        let parsed = parse_id_list("1,2,3,", "input-ids").unwrap();
        assert_eq!(parsed, vec![1u32, 2, 3]);
    }

    /// Every token's id equals its zero-based line number.
    #[test]
    fn load_vocab_txt_assigns_line_index_as_id() {
        let (_dir, path) = write_vocab("[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\n##world\n");
        let map = load_vocab_txt(&path).expect("load");
        let expected = [
            ("[PAD]", 0u32),
            ("[UNK]", 1),
            ("[CLS]", 2),
            ("[SEP]", 3),
            ("hello", 4),
            ("##world", 5),
        ];
        for (token, id) in expected {
            assert_eq!(map.get(token).copied(), Some(id));
        }
    }

    /// The pair is laid out as `[CLS] q… [SEP] p… [SEP]` with matching segment ids.
    #[test]
    fn tokenize_query_passage_builds_correct_segment_pair() {
        let (_dir, path) = write_vocab("[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\nworld\nfoo\nbar\n");
        let (input_ids, token_type_ids) =
            tokenize_query_passage("hello world", "foo bar", &path).expect("tokenize");
        assert_eq!(input_ids, vec![2u32, 4, 5, 3, 6, 7, 3]);
        assert_eq!(token_type_ids, vec![0u32, 0, 0, 0, 1, 1, 1]);
    }

    /// A vocabulary without `[CLS]` is rejected with a message naming the token.
    #[test]
    fn tokenize_query_passage_rejects_missing_cls() {
        let (_dir, path) = write_vocab("[PAD]\n[UNK]\n[SEP]\nhello\n");
        let err =
            tokenize_query_passage("hello", "world", &path).expect_err("missing [CLS] must reject");
        let CliError::ValidationFailed(msg) = err else {
            panic!("expected ValidationFailed");
        };
        assert!(msg.contains("[CLS]"), "{msg}");
    }

    /// A vocabulary without `[SEP]` is rejected with a message naming the token.
    #[test]
    fn tokenize_query_passage_rejects_missing_sep() {
        let (_dir, path) = write_vocab("[PAD]\n[UNK]\n[CLS]\nhello\n");
        let err =
            tokenize_query_passage("hello", "world", &path).expect_err("missing [SEP] must reject");
        let CliError::ValidationFailed(msg) = err else {
            panic!("expected ValidationFailed");
        };
        assert!(msg.contains("[SEP]"), "{msg}");
    }
}