Skip to main content

hirn_engine/operators/
rerank.rs

//! Rerank operator.
//!
//! Wraps the hirn-storage [`Reranker`](hirn_storage::Reranker) trait to re-score and
//! re-order search results within a pipeline.
5
6use std::sync::Arc;
7
8use arrow_array::RecordBatch;
9use async_trait::async_trait;
10
11use hirn_core::error::{HirnError, HirnResult};
12use hirn_storage::Reranker;
13
14use super::{OpContext, Operator};
15
16/// Pipeline operator that re-ranks input batches using a [`Reranker`].
17///
18/// Uses `rerank_vector` on each input batch, sorted by the
19/// `_relevance_score` column. Batches without a `_relevance_score` column
20/// are passed through unchanged.
21pub struct RerankOp {
22    /// The reranker implementation.
23    pub reranker: Arc<dyn Reranker>,
24    /// The query text used for re-ranking.
25    pub query: String,
26}
27
28#[async_trait]
29impl Operator for RerankOp {
30    async fn execute(
31        &self,
32        input: Vec<RecordBatch>,
33        _ctx: &OpContext,
34    ) -> HirnResult<Vec<RecordBatch>> {
35        let mut output = Vec::with_capacity(input.len());
36        for batch in &input {
37            if batch.num_rows() == 0 {
38                output.push(batch.clone());
39                continue;
40            }
41            // Only rerank batches that have a relevance score column.
42            if batch.column_by_name("_relevance_score").is_none() {
43                output.push(batch.clone());
44                continue;
45            }
46            let reranked = self
47                .reranker
48                .rerank_vector(&self.query, batch)
49                .await
50                .map_err(HirnError::storage)?;
51            output.push(reranked);
52        }
53        Ok(output)
54    }
55}