Skip to main content

hirn_exec/operators/
prospective_indexing.rs

//! `ProspectiveIndexingExec` — generates future queries at write time (Kumiho).
//!
//! For each incoming memory (slow-path only), generates future questions this
//! memory could answer from the configured heuristic templates, embeds them with
//! the session embedder, and writes them to the `prospective_implications`
//! dataset. LLM-based generation (`ProspectiveConfig::prompt_template`) is
//! configurable but not yet used by this operator.
//!
//! Pass-through operator: input batch is emitted unchanged plus a
//! `prospective_count (Int32)` column indicating how many implications
//! were generated per row.
10
11use std::any::Any;
12use std::fmt;
13use std::sync::Arc;
14
15use arrow_array::{Array, Int32Array, RecordBatch, StringArray};
16use arrow_schema::{DataType, Field, Schema, SchemaRef};
17use datafusion_common::Result;
18use datafusion_execution::{SendableRecordBatchStream, TaskContext};
19use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
20use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
21use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
22
23use crate::extensions::HirnSessionExt;
24
/// Configuration for prospective indexing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProspectiveConfig {
    /// Maximum number of future questions to generate per memory (default: 5).
    ///
    /// The heuristic generator emits at most
    /// `min(num_questions, heuristic_templates.len())` questions.
    pub num_questions: usize,
    /// Timeout in seconds for the batched embedding call (default: 5).
    /// On timeout the batch is skipped and nothing is written.
    pub timeout_secs: u64,
    /// LLM prompt template for question generation.
    ///
    /// The default template uses the `{content}` (memory content) and
    /// `{num_questions}` placeholders.
    /// NOTE: `ProspectiveIndexingExec` does not currently read this field;
    /// questions are produced from `heuristic_templates` instead.
    pub prompt_template: String,
    /// Whether prospective indexing is enabled (default: true). When false,
    /// every row gets `prospective_count = 0` and nothing is written.
    pub enabled: bool,
    /// Heuristic question templates used to generate questions.
    /// Every `{content}` is replaced with the memory content truncated at a
    /// word boundary.
    pub heuristic_templates: Vec<String>,
}
40
41impl Default for ProspectiveConfig {
42    fn default() -> Self {
43        Self {
44            num_questions: 5,
45            timeout_secs: 5,
46            prompt_template: concat!(
47                "Given the following information, generate exactly {num_questions} ",
48                "future questions that this information could answer. ",
49                "Return only the questions, one per line.\n\n",
50                "Information: {content}"
51            )
52            .to_string(),
53            enabled: true,
54            heuristic_templates: vec![
55                "What is known about {content}?".into(),
56                "When did {content} happen?".into(),
57                "Who was involved in {content}?".into(),
58                "What was the outcome of {content}?".into(),
59                "Why is {content} important?".into(),
60            ],
61        }
62    }
63}
64
65/// DataFusion operator for prospective indexing of incoming memories.
66///
67/// Passes through input batches, appending `prospective_count` column.
68/// Uses LLM from `HirnSessionExt` to generate future queries, embeds
69/// them, and writes to `prospective_implications` via storage.
70#[derive(Debug)]
71pub struct ProspectiveIndexingExec {
72    input: Arc<dyn ExecutionPlan>,
73    config: ProspectiveConfig,
74    schema: SchemaRef,
75    properties: PlanProperties,
76}
77
78impl ProspectiveIndexingExec {
79    pub fn new(input: Arc<dyn ExecutionPlan>, config: ProspectiveConfig) -> Self {
80        let mut fields: Vec<Arc<Field>> = input.schema().fields().iter().cloned().collect();
81        fields.push(Arc::new(Field::new(
82            "prospective_count",
83            DataType::Int32,
84            false,
85        )));
86        let schema = Arc::new(Schema::new(fields));
87
88        let properties = PlanProperties::new(
89            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
90            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
91            EmissionType::Final,
92            Boundedness::Bounded,
93        );
94
95        Self {
96            input,
97            config,
98            schema,
99            properties,
100        }
101    }
102
103    pub fn config(&self) -> &ProspectiveConfig {
104        &self.config
105    }
106}
107
108impl DisplayAs for ProspectiveIndexingExec {
109    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        write!(
111            f,
112            "ProspectiveIndexingExec: questions={}, timeout={}s, enabled={}",
113            self.config.num_questions, self.config.timeout_secs, self.config.enabled
114        )
115    }
116}
117
118impl ExecutionPlan for ProspectiveIndexingExec {
119    fn name(&self) -> &str {
120        "ProspectiveIndexingExec"
121    }
122
123    fn as_any(&self) -> &dyn Any {
124        self
125    }
126
127    fn schema(&self) -> SchemaRef {
128        self.schema.clone()
129    }
130
131    fn properties(&self) -> &PlanProperties {
132        &self.properties
133    }
134
135    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
136        vec![&self.input]
137    }
138
139    fn with_new_children(
140        self: Arc<Self>,
141        children: Vec<Arc<dyn ExecutionPlan>>,
142    ) -> Result<Arc<dyn ExecutionPlan>> {
143        Ok(Arc::new(Self::new(
144            children[0].clone(),
145            self.config.clone(),
146        )))
147    }
148
149    fn execute(
150        &self,
151        partition: usize,
152        context: Arc<TaskContext>,
153    ) -> Result<SendableRecordBatchStream> {
154        let input_stream = self.input.execute(partition, context.clone())?;
155        let schema = self.schema.clone();
156        let config = self.config.clone();
157
158        let session_ctx = context
159            .session_config()
160            .options()
161            .extensions
162            .get::<HirnSessionExt>();
163        let embedder = session_ctx.as_ref().and_then(|ext| ext.embedder_arc());
164        let storage = session_ctx.and_then(|ext| ext.storage_arc());
165
166        let stream = futures::stream::once(async move {
167            use futures::StreamExt;
168
169            let mut batches = Vec::new();
170            let mut input_stream = input_stream;
171            while let Some(batch_result) = input_stream.next().await {
172                batches.push(batch_result?);
173            }
174
175            if batches.is_empty() {
176                let columns: Vec<Arc<dyn Array>> = schema
177                    .fields()
178                    .iter()
179                    .map(|f| arrow_array::new_empty_array(f.data_type()))
180                    .collect();
181                return RecordBatch::try_new(schema, columns).map_err(Into::into);
182            }
183
184            let merged =
185                arrow_select::concat::concat_batches(&batches[0].schema(), batches.iter())?;
186            let n = merged.num_rows();
187
188            let content_col = merged.column_by_name("content");
189            let contents = content_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
190
191            let id_col = merged.column_by_name("id");
192            let ids = id_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
193
194            let mut counts = Vec::with_capacity(n);
195
196            if !config.enabled {
197                // Prospective indexing disabled — output 0 for all rows.
198                counts.resize(n, 0i32);
199            } else {
200                // ── Batch processing: collect all questions, embed once, write once ──
201                struct RowQuestions {
202                    row_idx: usize,
203                    source_id: String,
204                    questions: Vec<String>,
205                }
206
207                let mut all_row_questions: Vec<RowQuestions> = Vec::new();
208                let mut row_question_counts: Vec<i32> = vec![0; n];
209
210                for i in 0..n {
211                    let content =
212                        contents.and_then(|c| if c.is_null(i) { None } else { Some(c.value(i)) });
213                    let source_id = ids
214                        .and_then(|c| if c.is_null(i) { None } else { Some(c.value(i)) })
215                        .unwrap_or("");
216
217                    if let Some(text) = content {
218                        let questions = generate_heuristic_questions(
219                            text,
220                            config.num_questions,
221                            &config.heuristic_templates,
222                        );
223                        if !questions.is_empty() {
224                            row_question_counts[i] = questions.len() as i32;
225                            all_row_questions.push(RowQuestions {
226                                row_idx: i,
227                                source_id: source_id.to_string(),
228                                questions,
229                            });
230                        }
231                    }
232                }
233
234                // Flatten all questions for batch embedding.
235                let all_questions: Vec<&str> = all_row_questions
236                    .iter()
237                    .flat_map(|rq| rq.questions.iter().map(|q| q.as_str()))
238                    .collect();
239
240                if !all_questions.is_empty() {
241                    if let (Some(emb), Some(storage)) = (&embedder, &storage) {
242                        let emb_result = tokio::time::timeout(
243                            std::time::Duration::from_secs(config.timeout_secs),
244                            emb.embed(&all_questions),
245                        )
246                        .await;
247
248                        match emb_result {
249                            Ok(Ok(embeddings)) if !embeddings.is_empty() => {
250                                // Validate embedding count matches question count.
251                                if embeddings.len() != all_questions.len() {
252                                    tracing::warn!(
253                                        expected = all_questions.len(),
254                                        actual = embeddings.len(),
255                                        "Embedding count mismatch, skipping prospective storage"
256                                    );
257                                    for rq in &all_row_questions {
258                                        row_question_counts[rq.row_idx] = 0;
259                                    }
260                                } else {
261                                    // Map embeddings back to rows and write in single batch.
262                                    let dims = embeddings[0].vector.len();
263                                    let total = embeddings.len();
264                                    let mut source_ids_vec = Vec::with_capacity(total);
265                                    let mut question_strs = Vec::with_capacity(total);
266
267                                    for rq in &all_row_questions {
268                                        for q in &rq.questions {
269                                            source_ids_vec.push(rq.source_id.as_str());
270                                            question_strs.push(q.as_str());
271                                        }
272                                    }
273
274                                    // Build FixedSizeList embedding column.
275                                    let flat_values: Vec<f32> = embeddings
276                                        .iter()
277                                        .flat_map(|e| e.vector.iter().copied())
278                                        .collect();
279                                    let values_arr = arrow_array::Float32Array::from(flat_values);
280                                    let emb_field =
281                                        Arc::new(Field::new("item", DataType::Float32, true));
282
283                                    if let Ok(embedding_col) =
284                                        arrow_array::FixedSizeListArray::try_new(
285                                            emb_field,
286                                            dims as i32,
287                                            Arc::new(values_arr),
288                                            None,
289                                        )
290                                    {
291                                        let batch_schema = Arc::new(Schema::new(vec![
292                                            Field::new("source_memory_id", DataType::Utf8, false),
293                                            Field::new("question", DataType::Utf8, false),
294                                            Field::new(
295                                                "embedding",
296                                                DataType::FixedSizeList(
297                                                    Arc::new(Field::new(
298                                                        "item",
299                                                        DataType::Float32,
300                                                        true,
301                                                    )),
302                                                    dims as i32,
303                                                ),
304                                                false,
305                                            ),
306                                        ]));
307
308                                        if let Ok(batch) = RecordBatch::try_new(
309                                            batch_schema,
310                                            vec![
311                                                Arc::new(StringArray::from(source_ids_vec)),
312                                                Arc::new(StringArray::from(question_strs)),
313                                                Arc::new(embedding_col),
314                                            ],
315                                        ) {
316                                            if let Err(e) = storage
317                                                .append("prospective_implications", batch)
318                                                .await
319                                            {
320                                                tracing::warn!(error = %e, "Failed to write prospective implications");
321                                                // Zero out counts on write failure.
322                                                for rq in &all_row_questions {
323                                                    row_question_counts[rq.row_idx] = 0;
324                                                }
325                                            }
326                                        } else {
327                                            tracing::warn!("Failed to build prospective batch");
328                                            for rq in &all_row_questions {
329                                                row_question_counts[rq.row_idx] = 0;
330                                            }
331                                        }
332                                    } else {
333                                        tracing::warn!("Failed to build embedding column");
334                                        for rq in &all_row_questions {
335                                            row_question_counts[rq.row_idx] = 0;
336                                        }
337                                    }
338                                } // end else (embedding count match)
339                            }
340                            Ok(Ok(_)) => {
341                                tracing::debug!("Embedding returned empty results");
342                            }
343                            Ok(Err(e)) => {
344                                tracing::warn!(
345                                    error = %e,
346                                    questions = all_questions.len(),
347                                    "Prospective batch embedding failed"
348                                );
349                            }
350                            Err(_) => {
351                                tracing::warn!(
352                                    timeout_secs = config.timeout_secs,
353                                    questions = all_questions.len(),
354                                    "Prospective embedding timed out"
355                                );
356                            }
357                        }
358                    }
359                }
360
361                counts = row_question_counts;
362            }
363
364            let count_col = Int32Array::from(counts);
365            let mut columns: Vec<Arc<dyn Array>> = merged.columns().to_vec();
366            columns.push(Arc::new(count_col));
367
368            RecordBatch::try_new(schema, columns).map_err(Into::into)
369        });
370
371        Ok(Box::pin(RecordBatchStreamAdapter::new(
372            self.schema.clone(),
373            stream,
374        )))
375    }
376}
377
378/// Generate heuristic questions from content using configurable templates.
379fn generate_heuristic_questions(content: &str, num: usize, templates: &[String]) -> Vec<String> {
380    let words: Vec<&str> = content.split_whitespace().collect();
381    if words.len() < 3 {
382        return vec![];
383    }
384
385    // Truncate to a reasonable prefix for question templates.
386    let truncated = hirn_core::text_util::truncate_at_word_boundary(content, 80);
387
388    templates
389        .iter()
390        .take(num)
391        .map(|t| t.replace("{content}", &truncated))
392        .collect()
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Heuristic templates shipped with the default configuration.
    fn default_templates() -> Vec<String> {
        ProspectiveConfig::default().heuristic_templates
    }

    /// Two-column (`id`, `content`) schema used by the execution tests.
    fn memory_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
        ]))
    }

    #[test]
    fn default_config() {
        let ProspectiveConfig {
            num_questions,
            timeout_secs,
            enabled,
            heuristic_templates,
            ..
        } = ProspectiveConfig::default();
        assert_eq!(num_questions, 5);
        assert_eq!(timeout_secs, 5);
        assert!(enabled);
        assert_eq!(heuristic_templates.len(), 5);
    }

    #[test]
    fn heuristic_questions_short_content() {
        assert!(generate_heuristic_questions("hi", 5, &default_templates()).is_empty());
    }

    #[test]
    fn heuristic_questions_normal_content() {
        let questions = generate_heuristic_questions(
            "Alice deployed version 2.3 on staging",
            5,
            &default_templates(),
        );
        assert_eq!(questions.len(), 5);
        let first = &questions[0];
        assert!(first.contains("Alice deployed"));
        // The whole content is shorter than 80 chars, so it is not truncated.
        assert!(!first.contains("..."));
    }

    #[test]
    fn heuristic_questions_truncates_long_content() {
        let long_content = "A ".repeat(100); // 200 chars
        let questions = generate_heuristic_questions(&long_content, 3, &default_templates());
        assert_eq!(questions.len(), 3);
        // Truncated content is marked with "...".
        assert!(questions[0].contains("..."));
    }

    #[test]
    fn heuristic_questions_custom_templates() {
        let templates: Vec<String> = ["Tell me about {content}", "Summarize {content}"]
            .iter()
            .map(|t| t.to_string())
            .collect();
        let questions =
            generate_heuristic_questions("Alice deployed version 2.3 on staging", 5, &templates);
        assert_eq!(questions.len(), 2);
        assert!(questions[0].starts_with("Tell me about"));
        assert!(questions[1].starts_with("Summarize"));
    }

    #[test]
    fn truncate_at_word_boundary_short() {
        let out = hirn_core::text_util::truncate_at_word_boundary("short", 80);
        assert_eq!(out, "short");
    }

    #[test]
    fn truncate_at_word_boundary_long() {
        let out =
            hirn_core::text_util::truncate_at_word_boundary("hello world this is a long text", 15);
        assert!(out.ends_with("..."));
        assert!(out.len() <= 18); // 15 + "..."
    }

    #[tokio::test]
    async fn execute_empty_input() {
        use futures::StreamExt;

        let input = Arc::new(datafusion_physical_plan::empty::EmptyExec::new(
            memory_schema(),
        ));
        let exec = ProspectiveIndexingExec::new(input, ProspectiveConfig::default());
        let mut stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
        let output = stream.next().await.unwrap().unwrap();
        assert_eq!(output.num_rows(), 0);
        assert!(output.schema().field_with_name("prospective_count").is_ok());
    }

    #[tokio::test]
    async fn execute_disabled_produces_zero_counts() {
        use crate::test_utils::MemoryBatchExec;
        use futures::StreamExt;

        let schema = memory_schema();
        let input_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec!["id1"])),
                Arc::new(StringArray::from(vec!["test memory content"])),
            ],
        )
        .unwrap();

        let config = ProspectiveConfig {
            enabled: false,
            ..Default::default()
        };
        let input = Arc::new(MemoryBatchExec::new(schema, vec![input_batch]));
        let exec = ProspectiveIndexingExec::new(input, config);
        let mut stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
        let output = stream.next().await.unwrap().unwrap();

        let counts = output
            .column_by_name("prospective_count")
            .and_then(|col| col.as_any().downcast_ref::<Int32Array>())
            .expect("prospective_count should be an Int32 column");
        assert_eq!(counts.value(0), 0);
    }
}