use std::path::Path;
use std::thread;
use crate::Result;
use crate::calibrate::{Calibrator, SigmoidCalibrator};
use crate::engine::Scorer;
use crate::tokenize::Tokenizer;
use crate::types::{ModelConfig, RerankResult};
use ort::execution_providers::CPU;
use ort::session::Session;
use ort::value::Tensor;
pub struct ParallelScorer {
config: ModelConfig,
workers: Vec<WorkerHandle>,
tokenizer: std::sync::Mutex<Tokenizer>,
num_workers: usize,
}
struct WorkerHandle {
tx: std::sync::mpsc::Sender<WorkerRequest>,
}
struct WorkerRequest {
input_ids: Vec<i64>,
attention_mask: Vec<i64>,
token_type_ids: Vec<i64>,
batch_size: usize,
seq_len: usize,
reply: std::sync::mpsc::Sender<WorkerResult>,
}
type WorkerResult = std::result::Result<Vec<f32>, String>;
impl ParallelScorer {
pub fn new(config: ModelConfig, model_dir: &Path, num_workers: Option<usize>) -> Result<Self> {
let total_threads = thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(8);
let n_workers = num_workers.unwrap_or_else(|| (total_threads / 2).clamp(2, 8));
let threads_per_worker = (total_threads / n_workers).max(1);
let model_path = super::ort_backend::OrtScorer::resolve_model_path(model_dir, &config);
let tokenizer_path = model_dir.join("tokenizer.json");
let tokenizer = Tokenizer::from_file(&tokenizer_path)?;
let mut workers = Vec::with_capacity(n_workers);
for worker_id in 0..n_workers {
let (tx, rx) = std::sync::mpsc::channel::<WorkerRequest>();
let model_path = model_path.clone();
let tpw = threads_per_worker;
thread::Builder::new()
.name(format!("flash-rerank-worker-{worker_id}"))
.spawn(move || {
worker_loop(rx, &model_path, tpw);
})
.map_err(|e| {
crate::Error::Inference(format!("Failed to spawn worker {worker_id}: {e}"))
})?;
workers.push(WorkerHandle { tx });
}
tracing::info!(
num_workers = n_workers,
threads_per_worker,
model = %model_path.display(),
"ParallelScorer initialized"
);
Ok(Self {
config,
workers,
tokenizer: std::sync::Mutex::new(tokenizer),
num_workers: n_workers,
})
}
}
fn worker_loop(rx: std::sync::mpsc::Receiver<WorkerRequest>, model_path: &Path, threads: usize) {
let build_result = (|| -> std::result::Result<Session, ort::Error> {
let mut builder = Session::builder()?;
builder = builder.with_intra_threads(threads)?;
builder = builder.with_execution_providers([CPU::default().build()])?;
builder.commit_from_file(model_path)
})();
let mut session = match build_result {
Ok(s) => s,
Err(e) => {
tracing::error!("Worker failed to create session: {e}");
while let Ok(req) = rx.recv() {
let _ = req.reply.send(Err(format!("Session init failed: {e}")));
}
return;
}
};
while let Ok(req) = rx.recv() {
let result = run_inference(
&mut session,
&req.input_ids,
&req.attention_mask,
&req.token_type_ids,
req.batch_size,
req.seq_len,
);
let _ = req.reply.send(result);
}
}
fn run_inference(
session: &mut Session,
input_ids: &[i64],
attention_mask: &[i64],
token_type_ids: &[i64],
batch_size: usize,
seq_len: usize,
) -> WorkerResult {
let shape = [batch_size as i64, seq_len as i64];
let ids =
Tensor::from_array((shape, input_ids.to_vec())).map_err(|e| format!("tensor: {e}"))?;
let mask =
Tensor::from_array((shape, attention_mask.to_vec())).map_err(|e| format!("tensor: {e}"))?;
let tids =
Tensor::from_array((shape, token_type_ids.to_vec())).map_err(|e| format!("tensor: {e}"))?;
let outputs = session
.run(ort::inputs![
"input_ids" => ids,
"attention_mask" => mask,
"token_type_ids" => tids,
])
.map_err(|e| format!("inference: {e}"))?;
let (_shape, logits) = outputs[0]
.try_extract_tensor::<f32>()
.map_err(|e| format!("extract: {e}"))?;
Ok(logits.to_vec())
}
impl Scorer for ParallelScorer {
fn score(&self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
if query.is_empty() {
return Err(crate::Error::Inference("Empty query".to_string()));
}
if documents.is_empty() {
return Ok(vec![]);
}
let encodings = {
let mut tokenizer = self
.tokenizer
.lock()
.map_err(|e| crate::Error::Tokenizer(format!("lock: {e}")))?;
tokenizer.tokenize_pairs(query, documents, self.config.max_length)?
};
let total = encodings.len();
let seq_len = encodings[0].get_ids().len();
let chunk_size = total.div_ceil(self.num_workers);
let mut receivers = Vec::new();
for (worker_idx, chunk) in encodings.chunks(chunk_size).enumerate() {
let batch_size = chunk.len();
let input_ids: Vec<i64> = chunk
.iter()
.flat_map(|e| e.get_ids().iter().map(|&id| id as i64))
.collect();
let attention_mask: Vec<i64> = chunk
.iter()
.flat_map(|e| e.get_attention_mask().iter().map(|&m| m as i64))
.collect();
let token_type_ids: Vec<i64> = chunk
.iter()
.flat_map(|e| e.get_type_ids().iter().map(|&t| t as i64))
.collect();
let (reply_tx, reply_rx) = std::sync::mpsc::channel();
let worker = &self.workers[worker_idx % self.num_workers];
worker
.tx
.send(WorkerRequest {
input_ids,
attention_mask,
token_type_ids,
batch_size,
seq_len,
reply: reply_tx,
})
.map_err(|e| crate::Error::Inference(format!("Worker send failed: {e}")))?;
receivers.push(reply_rx);
}
let calibrator = SigmoidCalibrator;
let mut all_results = Vec::with_capacity(total);
let mut global_idx = 0;
for rx in receivers {
let logits = rx
.recv()
.map_err(|e| crate::Error::Inference(format!("Worker recv failed: {e}")))?
.map_err(crate::Error::Inference)?;
for logit in logits {
all_results.push(RerankResult {
index: global_idx,
score: calibrator.calibrate(logit),
document: None,
});
global_idx += 1;
}
}
all_results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(all_results)
}
}