use std::sync::Arc;
use arrow_array::RecordBatch;
use async_trait::async_trait;
use hirn_core::error::{HirnError, HirnResult};
use hirn_storage::Reranker;
use super::{OpContext, Operator};
/// Pipeline operator that applies a [`Reranker`] to incoming record batches.
///
/// See the `Operator` implementation for which batches are reranked and which
/// are passed through unchanged.
pub struct RerankOp {
    /// Shared reranker backend; its `rerank_vector` method is invoked once per
    /// eligible batch.
    pub reranker: Arc<dyn Reranker>,
    /// Query text forwarded to the reranker alongside each batch.
    pub query: String,
}
/// Applies the configured reranker to eligible batches of the operator input.
#[async_trait]
impl Operator for RerankOp {
    /// Reranks every non-empty batch that carries a `_relevance_score` column.
    ///
    /// Batches are handled sequentially and in input order. The output contains
    /// exactly one batch per input batch:
    /// - A batch with zero rows, or one without a `_relevance_score` column, is
    ///   passed through unchanged.
    /// - Every other batch is replaced by the result of
    ///   `self.reranker.rerank_vector(&self.query, &batch)`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the reranker, converted with
    /// `HirnError::storage`. Batches after the failing one are not processed.
    async fn execute(
        &self,
        input: Vec<RecordBatch>,
        _ctx: &OpContext,
    ) -> HirnResult<Vec<RecordBatch>> {
        // A batch must carry this column to be eligible for reranking.
        const SCORE_COLUMN: &str = "_relevance_score";

        let mut output = Vec::with_capacity(input.len());
        // `input` is owned, so consume it: pass-through batches are moved into
        // `output` rather than cloned.
        for batch in input {
            if batch.num_rows() == 0 || batch.column_by_name(SCORE_COLUMN).is_none() {
                output.push(batch);
                continue;
            }
            let reranked = self
                .reranker
                .rerank_vector(&self.query, &batch)
                .await
                .map_err(HirnError::storage)?;
            output.push(reranked);
        }
        Ok(output)
    }
}