hirn_engine/operators/
rerank.rs1use std::sync::Arc;
7
8use arrow_array::RecordBatch;
9use async_trait::async_trait;
10
11use hirn_core::error::{HirnError, HirnResult};
12use hirn_storage::Reranker;
13
14use super::{OpContext, Operator};
15
16pub struct RerankOp {
22 pub reranker: Arc<dyn Reranker>,
24 pub query: String,
26}
27
28#[async_trait]
29impl Operator for RerankOp {
30 async fn execute(
31 &self,
32 input: Vec<RecordBatch>,
33 _ctx: &OpContext,
34 ) -> HirnResult<Vec<RecordBatch>> {
35 let mut output = Vec::with_capacity(input.len());
36 for batch in &input {
37 if batch.num_rows() == 0 {
38 output.push(batch.clone());
39 continue;
40 }
41 if batch.column_by_name("_relevance_score").is_none() {
43 output.push(batch.clone());
44 continue;
45 }
46 let reranked = self
47 .reranker
48 .rerank_vector(&self.query, batch)
49 .await
50 .map_err(HirnError::storage)?;
51 output.push(reranked);
52 }
53 Ok(output)
54 }
55}