hirn_engine/operators/
mod.rs1mod narrative;
12mod policy;
13mod recall;
14mod rerank;
15mod temporal;
16
17pub use narrative::NarrativeAssemble;
18pub use policy::PolicyFilter;
19pub use recall::{HybridRecall, MultivectorRecall, VectorRecall};
20pub use rerank::RerankOp;
21pub use temporal::TemporalExpand;
22
23use std::sync::Arc;
24
25use arrow_array::RecordBatch;
26use async_trait::async_trait;
27
28use hirn_core::error::HirnResult;
29use hirn_storage::PhysicalStore;
30
31use crate::persistent_graph::PersistentGraph;
32
33#[async_trait]
37pub trait Operator: Send + Sync {
38 async fn execute(
43 &self,
44 input: Vec<RecordBatch>,
45 ctx: &OpContext,
46 ) -> HirnResult<Vec<RecordBatch>>;
47}
48
49pub struct OpContext {
53 pub store: Arc<dyn PhysicalStore>,
55 pub graph: Option<Arc<PersistentGraph>>,
57 pub principal: Option<String>,
59}
60
61impl OpContext {
62 pub fn new(store: Arc<dyn PhysicalStore>) -> Self {
63 Self {
64 store,
65 graph: None,
66 principal: None,
67 }
68 }
69
70 pub fn with_graph(mut self, graph: Arc<PersistentGraph>) -> Self {
71 self.graph = Some(graph);
72 self
73 }
74
75 pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
76 self.principal = Some(principal.into());
77 self
78 }
79}
80
81pub struct Pipeline {
94 stages: Vec<Box<dyn Operator>>,
95}
96
97impl Pipeline {
98 pub fn new() -> Self {
99 Self { stages: Vec::new() }
100 }
101
102 #[must_use]
104 pub fn stage(mut self, op: impl Operator + 'static) -> Self {
105 self.stages.push(Box::new(op));
106 self
107 }
108
109 pub async fn execute(&self, ctx: &OpContext) -> HirnResult<Vec<RecordBatch>> {
111 let mut batches: Vec<RecordBatch> = Vec::new();
112 for stage in &self.stages {
113 batches = stage.execute(batches, ctx).await?;
114 }
115 Ok(batches)
116 }
117}
118
119impl Default for Pipeline {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};

    /// Returns its input unchanged.
    struct Identity;

    #[async_trait]
    impl Operator for Identity {
        async fn execute(
            &self,
            input: Vec<RecordBatch>,
            _ctx: &OpContext,
        ) -> HirnResult<Vec<RecordBatch>> {
            Ok(input)
        }
    }

    /// Drops batches that have zero rows.
    struct NonEmpty;

    #[async_trait]
    impl Operator for NonEmpty {
        async fn execute(
            &self,
            input: Vec<RecordBatch>,
            _ctx: &OpContext,
        ) -> HirnResult<Vec<RecordBatch>> {
            Ok(input.into_iter().filter(|b| b.num_rows() > 0).collect())
        }
    }

    /// Ignores its input and emits a clone of the batches it was built with.
    struct Source(Vec<RecordBatch>);

    #[async_trait]
    impl Operator for Source {
        async fn execute(
            &self,
            _input: Vec<RecordBatch>,
            _ctx: &OpContext,
        ) -> HirnResult<Vec<RecordBatch>> {
            Ok(self.0.clone())
        }
    }

    fn test_ctx() -> OpContext {
        let store = hirn_storage::HirnDb::open_memory();
        OpContext::new(store.store_arc())
    }

    /// Schema shared by all test batches: a single non-nullable `id: Utf8` column.
    fn id_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]))
    }

    fn make_batch(values: &[&str]) -> RecordBatch {
        RecordBatch::try_new(id_schema(), vec![Arc::new(StringArray::from(values.to_vec()))])
            .unwrap()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pipeline_three_identity_passthrough() {
        let ctx = test_ctx();
        let input_batch = make_batch(&["a", "b", "c"]);

        let pipeline = Pipeline::new()
            .stage(Source(vec![input_batch]))
            .stage(Identity)
            .stage(Identity);

        let result = pipeline.execute(&ctx).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 3);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pipeline_filter_transform() {
        let ctx = test_ctx();
        let empty = RecordBatch::new_empty(id_schema());
        let non_empty = make_batch(&["x"]);

        let pipeline = Pipeline::new()
            .stage(Source(vec![empty, non_empty]))
            .stage(NonEmpty);

        let result = pipeline.execute(&ctx).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn empty_pipeline_yields_no_batches() {
        let ctx = test_ctx();
        let result = Pipeline::new().execute(&ctx).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn first_stage_receives_empty_input() {
        let ctx = test_ctx();
        // Identity echoes its input, so a lone Identity stage exposes what the
        // pipeline fed into the first stage.
        let result = Pipeline::new().stage(Identity).execute(&ctx).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn last_stage_output_is_returned() {
        let ctx = test_ctx();
        let pipeline = Pipeline::new()
            .stage(Source(vec![make_batch(&["a"])]))
            .stage(Source(vec![make_batch(&["b", "c"])]));

        let result = pipeline.execute(&ctx).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 2);
    }
}