use std::sync::RwLock;
use anyhow::{Error as E, Result};
use candle_core::{Device, Tensor};
use hf_hub::{api::sync::Api, Repo};
use ndarray::Array2;
use ort::{
execution_providers::{CUDAExecutionProvider, CoreMLExecutionProvider, ExecutionProvider},
session::{builder::GraphOptimizationLevel, Session},
};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use crate::Dtype;
use crate::{embeddings::local::bert::TokenizerConfig, reranker::qwen3};
use serde::Serialize;
#[derive(Debug, Serialize)]
pub struct RerankerResult {
pub query: String,
pub documents: Vec<DocumentRank>,
}
#[derive(Debug, Serialize, Clone)]
pub struct DocumentRank {
pub document: String,
pub relevance_score: f32,
pub rank: usize,
}
pub struct Reranker {
model: RwLock<Session>,
model_type: Option<String>,
tokenizer: Tokenizer,
}
impl Reranker {
pub fn new(
model_id: &str,
revision: Option<&str>,
dtype: Dtype,
path_in_repo: Option<&str>,
) -> Result<Self, E> {
let (config_filename, tokenizer_filename, weights_filename, tokenizer_config_filename) = {
let api = Api::new().unwrap();
let api = match revision {
Some(rev) => api.repo(Repo::with_revision(
model_id.to_string(),
hf_hub::RepoType::Model,
rev.to_string(),
)),
None => api.repo(hf_hub::Repo::new(
model_id.to_string(),
hf_hub::RepoType::Model,
)),
};
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let tokenizer_config = api.get("tokenizer_config.json")?;
let mut path_in_repo = path_in_repo.unwrap_or_default().to_string();
if !path_in_repo.is_empty() {
path_in_repo.push('/');
}
let weights = match dtype {
Dtype::Q4F16 => api.get(format!("{}model_q4f16.onnx", path_in_repo).as_str())?,
Dtype::F16 => api.get(format!("{}model_fp16.onnx", path_in_repo).as_str())?,
Dtype::INT8 => api.get(format!("{}model_int8.onnx", path_in_repo).as_str())?,
Dtype::Q4 => api.get(format!("{}model_q4.onnx", path_in_repo).as_str())?,
Dtype::UINT8 => api.get(format!("{}model_uint8.onnx", path_in_repo).as_str())?,
Dtype::BNB4 => api.get(format!("{}model_bnb4.onnx", path_in_repo).as_str())?,
Dtype::F32 => api.get(format!("{}model.onnx", path_in_repo).as_str())?,
Dtype::BF16 => api.get(format!("{}model_bf16.onnx", path_in_repo).as_str())?,
Dtype::QUANTIZED => {
api.get(format!("{}model_quantized.onnx", path_in_repo).as_str())?
}
};
(config, tokenizer, weights, tokenizer_config)
};
let tokenizer_config = std::fs::read_to_string(tokenizer_config_filename)?;
let tokenizer_config: TokenizerConfig = serde_json::from_str(&tokenizer_config)?;
let max_length = match (
tokenizer_config.max_length,
tokenizer_config.model_max_length,
) {
(Some(max_len), Some(model_max_len)) => std::cmp::min(max_len, model_max_len),
(Some(max_len), None) => max_len,
(None, Some(model_max_len)) => model_max_len,
(None, None) => 128,
};
let config_file = std::fs::read_to_string(config_filename)?;
let config: serde_json::Value = serde_json::from_str(&config_file)?;
let model_type = config
.get("model_type")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
let trunc = TruncationParams {
max_length,
..Default::default()
};
tokenizer
.with_padding(Some(pp))
.with_truncation(Some(trunc))
.unwrap();
let cuda = CUDAExecutionProvider::default();
if !cuda.is_available()? {
eprintln!("CUDAExecutionProvider is not available");
} else {
println!("Session is using CUDAExecutionProvider");
}
let threads = std::thread::available_parallelism()
.map(|p| p.get())
.unwrap_or(1);
let optimal_threads = std::cmp::max(1, threads / 2);
let model = Session::builder()?
.with_execution_providers([
CUDAExecutionProvider::default().build(),
CoreMLExecutionProvider::default().build(),
])?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(optimal_threads)? .with_inter_threads(1)? .commit_from_file(weights_filename)?;
Ok(Reranker {
model: RwLock::new(model),
model_type,
tokenizer,
})
}
pub fn compute_scores(
&self,
queries: Vec<&str>,
documents: Vec<&str>,
batch_size: usize,
) -> Result<Vec<Vec<f32>>, E> {
let is_qwen3 = self.model_type.as_ref().is_some_and(|t| t == "qwen3");
let prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n";
let suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
let pairs = queries
.iter()
.flat_map(|query| documents.iter().map(move |doc| (*query, *doc)))
.collect::<Vec<_>>();
let pairs = if is_qwen3 {
pairs.iter().map(|(query, doc)| (format!("{}{}", prefix, query), format!("{}{}", doc, suffix))).collect::<Vec<_>>()
} else {
pairs.iter().map(|(query, doc)| (query.to_string(), doc.to_string())).collect::<Vec<_>>()
};
let mut scores = Vec::with_capacity(pairs.len());
let mut model_guard = self.model.write().unwrap();
let false_token_id = self
.tokenizer
.token_to_id("no")
.ok_or(E::msg("no token found"))?;
let true_token_id = self
.tokenizer
.token_to_id("yes")
.ok_or(E::msg("yes token found"))?;
for pair in pairs.chunks(batch_size) {
let input_ids = self.tokenize_batch_ndarray(pair)?;
let attention_mask = self.get_attention_mask_ndarray(pair)?;
let input_ids_tensor = ort::value::TensorRef::from_array_view(&input_ids)?;
let attention_mask_tensor = ort::value::TensorRef::from_array_view(&attention_mask)?;
let model_inputs = model_guard
.inputs
.iter()
.map(|i| i.name.clone())
.collect::<Vec<_>>();
let seq_len = input_ids.shape()[1];
let position_ids_vec: Vec<i64> = (0..seq_len as i64)
.collect::<Vec<i64>>()
.repeat(input_ids.shape()[0]);
let position_ids =
Array2::<i64>::from_shape_vec((input_ids.shape()[0], seq_len), position_ids_vec)?;
let position_ids_tensor = ort::value::TensorRef::from_array_view(&position_ids)?;
let inputs = if model_inputs.contains(&"position_ids".to_string()) {
ort::inputs!["input_ids" => input_ids_tensor, "attention_mask" => attention_mask_tensor, "position_ids" => position_ids_tensor]
} else {
ort::inputs!["input_ids" => input_ids_tensor, "attention_mask" => attention_mask_tensor]
};
if is_qwen3 {
let logits = qwen3::compute_scores(
&mut model_guard,
input_ids,
attention_mask,
position_ids,
false_token_id,
true_token_id,
)?;
scores.extend(logits);
} else {
let outputs = model_guard.run(inputs)?;
let logits = outputs["logits".to_string()]
.try_extract_array::<f32>()?
.to_owned()
.into_dimensionality::<ndarray::Ix2>()?;
scores.extend(
logits
.outer_iter()
.flat_map(|row| row.to_vec())
.collect::<Vec<_>>(),
);
}
}
let scores_tensor = Tensor::from_vec(
scores.clone(),
(queries.len(), documents.len()),
&Device::Cpu,
)?;
if is_qwen3 {
Ok(scores_tensor.to_vec2::<f32>()?)
} else {
let sigmoid_scores = candle_nn::ops::sigmoid(&scores_tensor)?;
Ok(sigmoid_scores.to_vec2::<f32>()?)
}
}
pub fn rerank(
&self,
queries: Vec<&str>,
documents: Vec<&str>,
batch_size: usize,
) -> Result<Vec<RerankerResult>, E> {
let scores = self.compute_scores(queries.clone(), documents.clone(), batch_size)?;
let mut reranker_results = Vec::new();
for (i, query) in queries.iter().enumerate() {
let scores = scores[i].clone();
let mut indices: Vec<usize> = (0..scores.len()).collect();
indices.sort_by(|&j, &k| {
scores[k]
.partial_cmp(&scores[j])
.unwrap_or(std::cmp::Ordering::Equal)
});
let document_ranks = scores
.iter()
.enumerate()
.map(|(p, score)| DocumentRank {
document: documents[p].to_string(),
relevance_score: *score,
rank: indices.iter().position(|&i| i == p).unwrap() + 1,
})
.collect::<Vec<_>>();
reranker_results.push(RerankerResult {
query: query.to_string(),
documents: document_ranks,
});
}
Ok(reranker_results)
}
pub fn tokenize_batch_ndarray(&self, pairs: &[(String, String)]) -> anyhow::Result<Array2<i64>> {
let token_ids = self
.tokenizer
.encode_batch(pairs.to_vec(), true)
.map_err(E::msg)?
.iter()
.map(|tokens| {
tokens
.get_ids()
.iter()
.map(|&id| id as i64)
.collect::<Vec<i64>>()
})
.collect::<Vec<Vec<i64>>>();
let token_ids_array = Array2::from_shape_vec(
(token_ids.len(), token_ids[0].len()),
token_ids.into_iter().flatten().collect::<Vec<i64>>(),
)
.unwrap();
Ok(token_ids_array)
}
pub fn get_attention_mask_ndarray(
&self,
pairs: &[(String, String)],
) -> anyhow::Result<Array2<i64>> {
let attention_mask = self
.tokenizer
.encode_batch(pairs.to_vec(), true)
.map_err(E::msg)?
.iter()
.map(|tokens| {
tokens
.get_attention_mask()
.iter()
.map(|&id| id as i64)
.collect::<Vec<i64>>()
})
.collect::<Vec<Vec<i64>>>();
let attention_mask_array = Array2::from_shape_vec(
(attention_mask.len(), attention_mask[0].len()),
attention_mask.into_iter().flatten().collect::<Vec<i64>>(),
)
.unwrap();
Ok(attention_mask_array)
}
}