mod narrative;
mod policy;
mod recall;
mod rerank;
mod temporal;
pub use narrative::NarrativeAssemble;
pub use policy::PolicyFilter;
pub use recall::{HybridRecall, MultivectorRecall, VectorRecall};
pub use rerank::RerankOp;
pub use temporal::TemporalExpand;
use std::sync::Arc;
use arrow_array::RecordBatch;
use async_trait::async_trait;
use hirn_core::error::HirnResult;
use hirn_storage::PhysicalStore;
use crate::persistent_graph::PersistentGraph;
/// A single stage of a [`Pipeline`].
///
/// [`Pipeline::execute`] calls each operator in turn, handing it the batches
/// returned by the previous stage (an empty `Vec` for the first stage) and
/// forwarding whatever it returns to the next stage.
#[async_trait]
pub trait Operator: Send + Sync {
/// Transforms `input` into this stage's output batches.
///
/// `ctx` carries the shared store plus the optional graph and principal
/// configured on the [`OpContext`].
///
/// # Errors
///
/// An `Err` returned here makes [`Pipeline::execute`] return that error
/// immediately; subsequent stages are not run.
async fn execute(
&self,
input: Vec<RecordBatch>,
ctx: &OpContext,
) -> HirnResult<Vec<RecordBatch>>;
}
/// Context handed by reference to every [`Operator`] while a [`Pipeline`] runs.
///
/// Cloning only bumps `Arc` reference counts and copies the optional principal
/// string, so a context can cheaply be duplicated (e.g. to vary the principal
/// per request) without reopening the store.
#[derive(Clone)]
pub struct OpContext {
    /// Physical storage backend available to operators.
    pub store: Arc<dyn PhysicalStore>,
    /// Optional persistent graph; `None` unless set via `OpContext::with_graph`.
    pub graph: Option<Arc<PersistentGraph>>,
    /// Optional principal (caller identity); `None` unless set via
    /// `OpContext::with_principal`.
    pub principal: Option<String>,
}
impl OpContext {
    /// Creates a context backed by `store`, with no graph and no principal.
    #[must_use]
    pub fn new(store: Arc<dyn PhysicalStore>) -> Self {
        Self {
            store,
            graph: None,
            principal: None,
        }
    }

    /// Attaches a persistent graph, replacing any previously attached one.
    ///
    /// Consumes and returns `self` so calls can be chained; discarding the
    /// result drops the context, hence `#[must_use]`.
    #[must_use]
    pub fn with_graph(mut self, graph: Arc<PersistentGraph>) -> Self {
        self.graph = Some(graph);
        self
    }

    /// Sets the principal, replacing any previously set one.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...).
    #[must_use]
    pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
        self.principal = Some(principal.into());
        self
    }
}
/// An ordered chain of [`Operator`]s executed one after another.
///
/// Create one with [`Pipeline::new`] (or `Default`) and append stages with
/// [`Pipeline::stage`]; stages run in the order they were added.
#[derive(Default)]
pub struct Pipeline {
    stages: Vec<Box<dyn Operator>>,
}

impl Pipeline {
    /// Creates a pipeline with no stages.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `op` as the last stage and returns the pipeline for chaining.
    #[must_use]
    pub fn stage(mut self, op: impl Operator + 'static) -> Self {
        let boxed: Box<dyn Operator> = Box::new(op);
        self.stages.push(boxed);
        self
    }

    /// Runs every stage in order, feeding each stage's output into the next.
    ///
    /// The first stage receives an empty input; a pipeline with no stages
    /// yields an empty `Vec`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any stage; the remaining stages
    /// are skipped.
    pub async fn execute(&self, ctx: &OpContext) -> HirnResult<Vec<RecordBatch>> {
        let mut current = Vec::new();
        for op in self.stages.iter() {
            current = op.execute(current, ctx).await?;
        }
        Ok(current)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};

    /// Passes its input through unchanged.
    struct Identity;

    #[async_trait]
    impl Operator for Identity {
        async fn execute(
            &self,
            input: Vec<RecordBatch>,
            _ctx: &OpContext,
        ) -> HirnResult<Vec<RecordBatch>> {
            Ok(input)
        }
    }

    /// Drops every batch that has zero rows.
    struct NonEmpty;

    #[async_trait]
    impl Operator for NonEmpty {
        async fn execute(
            &self,
            input: Vec<RecordBatch>,
            _ctx: &OpContext,
        ) -> HirnResult<Vec<RecordBatch>> {
            Ok(input.into_iter().filter(|b| b.num_rows() > 0).collect())
        }
    }

    /// Ignores its input and emits a fixed set of batches; used as a first stage.
    struct Source(Vec<RecordBatch>);

    #[async_trait]
    impl Operator for Source {
        async fn execute(
            &self,
            _input: Vec<RecordBatch>,
            _ctx: &OpContext,
        ) -> HirnResult<Vec<RecordBatch>> {
            Ok(self.0.clone())
        }
    }

    fn test_ctx() -> OpContext {
        let store = hirn_storage::HirnDb::open_memory();
        OpContext::new(store.store_arc())
    }

    /// Single non-nullable `id: Utf8` column shared by all test batches.
    fn id_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]))
    }

    fn make_batch(values: &[&str]) -> RecordBatch {
        RecordBatch::try_new(id_schema(), vec![Arc::new(StringArray::from(values.to_vec()))])
            .unwrap()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pipeline_without_stages_yields_nothing() {
        let ctx = test_ctx();
        let result = Pipeline::new().execute(&ctx).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pipeline_three_identity_passthrough() {
        let ctx = test_ctx();
        let input_batch = make_batch(&["a", "b", "c"]);
        let pipeline = Pipeline::new()
            .stage(Source(vec![input_batch.clone()]))
            .stage(Identity)
            .stage(Identity);
        let result = pipeline.execute(&ctx).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 3);
        // The batch must come out of the identity stages untouched, not just
        // with the same row count.
        assert_eq!(result[0], input_batch);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pipeline_filter_transform() {
        let ctx = test_ctx();
        let empty = RecordBatch::new_empty(id_schema());
        let non_empty = make_batch(&["x"]);
        let pipeline = Pipeline::new()
            .stage(Source(vec![empty, non_empty.clone()]))
            .stage(NonEmpty);
        let result = pipeline.execute(&ctx).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
        assert_eq!(result[0], non_empty);
    }
}