Skip to main content

hirn_engine/operators/
mod.rs

//! Composable cognitive operators for query pipelines.
//!
//! An [`Operator`] takes zero or more input `RecordBatch`es and produces output
//! `RecordBatch`es. Operators compose into a [`Pipeline`] — a linear chain of
//! stages where the output of stage N becomes the input of stage N+1.
//!
//! The first stage always receives an empty input, so it is typically a source
//! operator such as [`VectorRecall`] that produces data from the store.
//! Subsequent stages filter, expand, or rerank the data flowing through the
//! pipeline.
10
11mod narrative;
12mod policy;
13mod recall;
14mod rerank;
15mod temporal;
16
17pub use narrative::NarrativeAssemble;
18pub use policy::PolicyFilter;
19pub use recall::{HybridRecall, MultivectorRecall, VectorRecall};
20pub use rerank::RerankOp;
21pub use temporal::TemporalExpand;
22
23use std::sync::Arc;
24
25use arrow_array::RecordBatch;
26use async_trait::async_trait;
27
28use hirn_core::error::HirnResult;
29use hirn_storage::PhysicalStore;
30
31use crate::persistent_graph::PersistentGraph;
32
33// ── Operator Trait ──────────────────────────────────────────────────────
34
35/// A composable query-plan stage that transforms `RecordBatch` streams.
36#[async_trait]
37pub trait Operator: Send + Sync {
38    /// Execute this operator.
39    ///
40    /// * `input` — batches from the previous stage (empty for source operators).
41    /// * `ctx`   — shared execution context (store, graph, principal).
42    async fn execute(
43        &self,
44        input: Vec<RecordBatch>,
45        ctx: &OpContext,
46    ) -> HirnResult<Vec<RecordBatch>>;
47}
48
49// ── Execution Context ───────────────────────────────────────────────────
50
51/// Shared context available to every operator in a pipeline.
52pub struct OpContext {
53    /// Physical store for data access.
54    pub store: Arc<dyn PhysicalStore>,
55    /// Optional persistent graph for graph-based operators.
56    pub graph: Option<Arc<PersistentGraph>>,
57    /// The current principal (for policy filtering). `None` = permissive.
58    pub principal: Option<String>,
59}
60
61impl OpContext {
62    pub fn new(store: Arc<dyn PhysicalStore>) -> Self {
63        Self {
64            store,
65            graph: None,
66            principal: None,
67        }
68    }
69
70    pub fn with_graph(mut self, graph: Arc<PersistentGraph>) -> Self {
71        self.graph = Some(graph);
72        self
73    }
74
75    pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
76        self.principal = Some(principal.into());
77        self
78    }
79}
80
81// ── Pipeline ────────────────────────────────────────────────────────────
82
83/// A linear chain of [`Operator`] stages.
84///
85/// ```text
86/// Pipeline::new()
87///     .stage(VectorRecall { ... })
88///     .stage(PolicyFilter)
89///     .stage(Rerank { ... })
90///     .execute(&ctx)
91///     .await
92/// ```
93pub struct Pipeline {
94    stages: Vec<Box<dyn Operator>>,
95}
96
97impl Pipeline {
98    pub fn new() -> Self {
99        Self { stages: Vec::new() }
100    }
101
102    /// Append an operator stage. Stages execute in insertion order.
103    #[must_use]
104    pub fn stage(mut self, op: impl Operator + 'static) -> Self {
105        self.stages.push(Box::new(op));
106        self
107    }
108
109    /// Execute the pipeline, threading batches through each stage.
110    pub async fn execute(&self, ctx: &OpContext) -> HirnResult<Vec<RecordBatch>> {
111        let mut batches: Vec<RecordBatch> = Vec::new();
112        for stage in &self.stages {
113            batches = stage.execute(batches, ctx).await?;
114        }
115        Ok(batches)
116    }
117}
118
119impl Default for Pipeline {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125// ── Tests ───────────────────────────────────────────────────────────────
126
127#[cfg(test)]
128mod tests {
129    use super::*;
130    use arrow_array::StringArray;
131    use arrow_schema::{DataType, Field, Schema};
132
133    /// Identity operator — passes input through unchanged.
134    struct Identity;
135
136    #[async_trait]
137    impl Operator for Identity {
138        async fn execute(
139            &self,
140            input: Vec<RecordBatch>,
141            _ctx: &OpContext,
142        ) -> HirnResult<Vec<RecordBatch>> {
143            Ok(input)
144        }
145    }
146
147    /// Filter operator — keeps only batches with > 0 rows.
148    struct NonEmpty;
149
150    #[async_trait]
151    impl Operator for NonEmpty {
152        async fn execute(
153            &self,
154            input: Vec<RecordBatch>,
155            _ctx: &OpContext,
156        ) -> HirnResult<Vec<RecordBatch>> {
157            Ok(input.into_iter().filter(|b| b.num_rows() > 0).collect())
158        }
159    }
160
161    fn test_ctx() -> OpContext {
162        let store = hirn_storage::HirnDb::open_memory();
163        OpContext::new(store.store_arc())
164    }
165
166    fn make_batch(values: &[&str]) -> RecordBatch {
167        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));
168        RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(values.to_vec()))]).unwrap()
169    }
170
171    #[tokio::test(flavor = "multi_thread")]
172    async fn pipeline_three_identity_passthrough() {
173        let ctx = test_ctx();
174        let input_batch = make_batch(&["a", "b", "c"]);
175
176        struct Source(Vec<RecordBatch>);
177        #[async_trait]
178        impl Operator for Source {
179            async fn execute(
180                &self,
181                _input: Vec<RecordBatch>,
182                _ctx: &OpContext,
183            ) -> HirnResult<Vec<RecordBatch>> {
184                Ok(self.0.clone())
185            }
186        }
187
188        let pipeline = Pipeline::new()
189            .stage(Source(vec![input_batch.clone()]))
190            .stage(Identity)
191            .stage(Identity);
192
193        let result = pipeline.execute(&ctx).await.unwrap();
194        assert_eq!(result.len(), 1);
195        assert_eq!(result[0].num_rows(), 3);
196    }
197
198    #[tokio::test(flavor = "multi_thread")]
199    async fn pipeline_filter_transform() {
200        let ctx = test_ctx();
201        let empty_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));
202        let empty = RecordBatch::new_empty(empty_schema);
203        let non_empty = make_batch(&["x"]);
204
205        struct Source(Vec<RecordBatch>);
206        #[async_trait]
207        impl Operator for Source {
208            async fn execute(
209                &self,
210                _input: Vec<RecordBatch>,
211                _ctx: &OpContext,
212            ) -> HirnResult<Vec<RecordBatch>> {
213                Ok(self.0.clone())
214            }
215        }
216
217        let pipeline = Pipeline::new()
218            .stage(Source(vec![empty, non_empty]))
219            .stage(NonEmpty);
220
221        let result = pipeline.execute(&ctx).await.unwrap();
222        assert_eq!(result.len(), 1);
223        assert_eq!(result[0].num_rows(), 1);
224    }
225}