use crate::batch::LlamaBatch;
use crate::error::Result;
use crate::Llama;
impl Llama {
    /// Scores every document in `documents` against `query`.
    ///
    /// For each document, the query tokens followed by the document tokens are
    /// placed in one batch and passed to `encode`. The first component of that
    /// sequence's embedding is used as the document's score.
    ///
    /// Returns one score per document, in the same order as `documents`. An empty
    /// `documents` slice yields an empty vector without tokenizing `query`. If the
    /// embedding returned for a document is empty, that document's score is `0.0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from tokenizing the query or a document, from adding
    /// tokens to the batch, from `encode`, or from reading the sequence embedding.
    pub fn rerank(&mut self, query: &str, documents: &[&str]) -> Result<Vec<f32>> {
        if documents.is_empty() {
            return Ok(Vec::new());
        }

        // The query is the same for every document, so tokenize it only once.
        let query_tokens = self.model().tokenize(query, true, false)?;

        // Each document is encoded in a separate `encode` call, so one fixed
        // sequence id is enough. Using the document index instead would exceed the
        // context's sequence-id limit for long document lists (and `usize as i32`
        // could truncate).
        // NOTE(review): assumes `encode` keeps no per-sequence state between calls
        // (true for llama.cpp encoder passes) — confirm if the backend changes.
        const SEQ_ID: i32 = 0;

        let mut scores = Vec::with_capacity(documents.len());
        for doc in documents {
            let doc_tokens = self.model().tokenize(doc, false, false)?;
            let n_tokens = query_tokens.len() + doc_tokens.len();

            let mut batch = LlamaBatch::new(n_tokens, 1);
            for (pos, &token) in query_tokens.iter().chain(&doc_tokens).enumerate() {
                // Only the final token of the concatenated sequence is flagged for output.
                let is_last = pos + 1 == n_tokens;
                batch
                    .add(token, pos as i32, &[SEQ_ID], is_last)
                    .map_err(crate::error::LlamaError::from)?;
            }

            self.context_mut().encode(&batch)?;
            let embedding = self.context().embeddings_seq(SEQ_ID)?;
            // An empty embedding is scored as 0.0 rather than reported as an error.
            scores.push(embedding.first().copied().unwrap_or(0.0));
        }
        Ok(scores)
    }
}