1use std::any::Any;
12use std::fmt;
13use std::sync::Arc;
14
15use arrow_array::{Array, Int32Array, RecordBatch, StringArray};
16use arrow_schema::{DataType, Field, Schema, SchemaRef};
17use datafusion_common::Result;
18use datafusion_execution::{SendableRecordBatchStream, TaskContext};
19use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
20use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
21use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
22
23use crate::extensions::HirnSessionExt;
24
/// Settings for prospective indexing: generating questions that a piece of
/// stored content could answer later.
#[derive(Debug, Clone)]
pub struct ProspectiveConfig {
    /// Upper bound on the number of questions generated per row. The heuristic
    /// generator uses at most this many entries of `heuristic_templates`.
    pub num_questions: usize,
    /// Timeout, in seconds, applied to the batched embedding call made by
    /// `ProspectiveIndexingExec`.
    pub timeout_secs: u64,
    /// Prompt text; the default contains `{num_questions}` and `{content}`
    /// placeholders.
    /// NOTE(review): no code visible here reads this field — presumably it is
    /// consumed by an LLM-backed generator elsewhere; confirm.
    pub prompt_template: String,
    /// When `false`, `ProspectiveIndexingExec` generates no questions and
    /// reports a count of 0 for every row.
    pub enabled: bool,
    /// Question templates for the heuristic generator. Every `{content}`
    /// occurrence is replaced with the (possibly truncated) row content.
    pub heuristic_templates: Vec<String>,
}
40
41impl Default for ProspectiveConfig {
42 fn default() -> Self {
43 Self {
44 num_questions: 5,
45 timeout_secs: 5,
46 prompt_template: concat!(
47 "Given the following information, generate exactly {num_questions} ",
48 "future questions that this information could answer. ",
49 "Return only the questions, one per line.\n\n",
50 "Information: {content}"
51 )
52 .to_string(),
53 enabled: true,
54 heuristic_templates: vec![
55 "What is known about {content}?".into(),
56 "When did {content} happen?".into(),
57 "Who was involved in {content}?".into(),
58 "What was the outcome of {content}?".into(),
59 "Why is {content} important?".into(),
60 ],
61 }
62 }
63}
64
/// Physical operator that passes its input rows through and appends a
/// non-nullable `Int32` column named `prospective_count`.
#[derive(Debug)]
pub struct ProspectiveIndexingExec {
    /// Child plan whose rows are read and re-emitted.
    input: Arc<dyn ExecutionPlan>,
    /// Question-generation and embedding settings for this operator.
    config: ProspectiveConfig,
    /// Output schema: the input's fields followed by `prospective_count`.
    schema: SchemaRef,
    /// Plan properties computed once in `new` and returned by `properties()`.
    properties: PlanProperties,
}
77
78impl ProspectiveIndexingExec {
79 pub fn new(input: Arc<dyn ExecutionPlan>, config: ProspectiveConfig) -> Self {
80 let mut fields: Vec<Arc<Field>> = input.schema().fields().iter().cloned().collect();
81 fields.push(Arc::new(Field::new(
82 "prospective_count",
83 DataType::Int32,
84 false,
85 )));
86 let schema = Arc::new(Schema::new(fields));
87
88 let properties = PlanProperties::new(
89 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
90 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
91 EmissionType::Final,
92 Boundedness::Bounded,
93 );
94
95 Self {
96 input,
97 config,
98 schema,
99 properties,
100 }
101 }
102
103 pub fn config(&self) -> &ProspectiveConfig {
104 &self.config
105 }
106}
107
108impl DisplayAs for ProspectiveIndexingExec {
109 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 write!(
111 f,
112 "ProspectiveIndexingExec: questions={}, timeout={}s, enabled={}",
113 self.config.num_questions, self.config.timeout_secs, self.config.enabled
114 )
115 }
116}
117
118impl ExecutionPlan for ProspectiveIndexingExec {
119 fn name(&self) -> &str {
120 "ProspectiveIndexingExec"
121 }
122
123 fn as_any(&self) -> &dyn Any {
124 self
125 }
126
127 fn schema(&self) -> SchemaRef {
128 self.schema.clone()
129 }
130
131 fn properties(&self) -> &PlanProperties {
132 &self.properties
133 }
134
135 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
136 vec![&self.input]
137 }
138
139 fn with_new_children(
140 self: Arc<Self>,
141 children: Vec<Arc<dyn ExecutionPlan>>,
142 ) -> Result<Arc<dyn ExecutionPlan>> {
143 Ok(Arc::new(Self::new(
144 children[0].clone(),
145 self.config.clone(),
146 )))
147 }
148
149 fn execute(
150 &self,
151 partition: usize,
152 context: Arc<TaskContext>,
153 ) -> Result<SendableRecordBatchStream> {
154 let input_stream = self.input.execute(partition, context.clone())?;
155 let schema = self.schema.clone();
156 let config = self.config.clone();
157
158 let session_ctx = context
159 .session_config()
160 .options()
161 .extensions
162 .get::<HirnSessionExt>();
163 let embedder = session_ctx.as_ref().and_then(|ext| ext.embedder_arc());
164 let storage = session_ctx.and_then(|ext| ext.storage_arc());
165
166 let stream = futures::stream::once(async move {
167 use futures::StreamExt;
168
169 let mut batches = Vec::new();
170 let mut input_stream = input_stream;
171 while let Some(batch_result) = input_stream.next().await {
172 batches.push(batch_result?);
173 }
174
175 if batches.is_empty() {
176 let columns: Vec<Arc<dyn Array>> = schema
177 .fields()
178 .iter()
179 .map(|f| arrow_array::new_empty_array(f.data_type()))
180 .collect();
181 return RecordBatch::try_new(schema, columns).map_err(Into::into);
182 }
183
184 let merged =
185 arrow_select::concat::concat_batches(&batches[0].schema(), batches.iter())?;
186 let n = merged.num_rows();
187
188 let content_col = merged.column_by_name("content");
189 let contents = content_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
190
191 let id_col = merged.column_by_name("id");
192 let ids = id_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
193
194 let mut counts = Vec::with_capacity(n);
195
196 if !config.enabled {
197 counts.resize(n, 0i32);
199 } else {
200 struct RowQuestions {
202 row_idx: usize,
203 source_id: String,
204 questions: Vec<String>,
205 }
206
207 let mut all_row_questions: Vec<RowQuestions> = Vec::new();
208 let mut row_question_counts: Vec<i32> = vec![0; n];
209
210 for i in 0..n {
211 let content =
212 contents.and_then(|c| if c.is_null(i) { None } else { Some(c.value(i)) });
213 let source_id = ids
214 .and_then(|c| if c.is_null(i) { None } else { Some(c.value(i)) })
215 .unwrap_or("");
216
217 if let Some(text) = content {
218 let questions = generate_heuristic_questions(
219 text,
220 config.num_questions,
221 &config.heuristic_templates,
222 );
223 if !questions.is_empty() {
224 row_question_counts[i] = questions.len() as i32;
225 all_row_questions.push(RowQuestions {
226 row_idx: i,
227 source_id: source_id.to_string(),
228 questions,
229 });
230 }
231 }
232 }
233
234 let all_questions: Vec<&str> = all_row_questions
236 .iter()
237 .flat_map(|rq| rq.questions.iter().map(|q| q.as_str()))
238 .collect();
239
240 if !all_questions.is_empty() {
241 if let (Some(emb), Some(storage)) = (&embedder, &storage) {
242 let emb_result = tokio::time::timeout(
243 std::time::Duration::from_secs(config.timeout_secs),
244 emb.embed(&all_questions),
245 )
246 .await;
247
248 match emb_result {
249 Ok(Ok(embeddings)) if !embeddings.is_empty() => {
250 if embeddings.len() != all_questions.len() {
252 tracing::warn!(
253 expected = all_questions.len(),
254 actual = embeddings.len(),
255 "Embedding count mismatch, skipping prospective storage"
256 );
257 for rq in &all_row_questions {
258 row_question_counts[rq.row_idx] = 0;
259 }
260 } else {
261 let dims = embeddings[0].vector.len();
263 let total = embeddings.len();
264 let mut source_ids_vec = Vec::with_capacity(total);
265 let mut question_strs = Vec::with_capacity(total);
266
267 for rq in &all_row_questions {
268 for q in &rq.questions {
269 source_ids_vec.push(rq.source_id.as_str());
270 question_strs.push(q.as_str());
271 }
272 }
273
274 let flat_values: Vec<f32> = embeddings
276 .iter()
277 .flat_map(|e| e.vector.iter().copied())
278 .collect();
279 let values_arr = arrow_array::Float32Array::from(flat_values);
280 let emb_field =
281 Arc::new(Field::new("item", DataType::Float32, true));
282
283 if let Ok(embedding_col) =
284 arrow_array::FixedSizeListArray::try_new(
285 emb_field,
286 dims as i32,
287 Arc::new(values_arr),
288 None,
289 )
290 {
291 let batch_schema = Arc::new(Schema::new(vec![
292 Field::new("source_memory_id", DataType::Utf8, false),
293 Field::new("question", DataType::Utf8, false),
294 Field::new(
295 "embedding",
296 DataType::FixedSizeList(
297 Arc::new(Field::new(
298 "item",
299 DataType::Float32,
300 true,
301 )),
302 dims as i32,
303 ),
304 false,
305 ),
306 ]));
307
308 if let Ok(batch) = RecordBatch::try_new(
309 batch_schema,
310 vec![
311 Arc::new(StringArray::from(source_ids_vec)),
312 Arc::new(StringArray::from(question_strs)),
313 Arc::new(embedding_col),
314 ],
315 ) {
316 if let Err(e) = storage
317 .append("prospective_implications", batch)
318 .await
319 {
320 tracing::warn!(error = %e, "Failed to write prospective implications");
321 for rq in &all_row_questions {
323 row_question_counts[rq.row_idx] = 0;
324 }
325 }
326 } else {
327 tracing::warn!("Failed to build prospective batch");
328 for rq in &all_row_questions {
329 row_question_counts[rq.row_idx] = 0;
330 }
331 }
332 } else {
333 tracing::warn!("Failed to build embedding column");
334 for rq in &all_row_questions {
335 row_question_counts[rq.row_idx] = 0;
336 }
337 }
338 } }
340 Ok(Ok(_)) => {
341 tracing::debug!("Embedding returned empty results");
342 }
343 Ok(Err(e)) => {
344 tracing::warn!(
345 error = %e,
346 questions = all_questions.len(),
347 "Prospective batch embedding failed"
348 );
349 }
350 Err(_) => {
351 tracing::warn!(
352 timeout_secs = config.timeout_secs,
353 questions = all_questions.len(),
354 "Prospective embedding timed out"
355 );
356 }
357 }
358 }
359 }
360
361 counts = row_question_counts;
362 }
363
364 let count_col = Int32Array::from(counts);
365 let mut columns: Vec<Arc<dyn Array>> = merged.columns().to_vec();
366 columns.push(Arc::new(count_col));
367
368 RecordBatch::try_new(schema, columns).map_err(Into::into)
369 });
370
371 Ok(Box::pin(RecordBatchStreamAdapter::new(
372 self.schema.clone(),
373 stream,
374 )))
375 }
376}
377
378fn generate_heuristic_questions(content: &str, num: usize, templates: &[String]) -> Vec<String> {
380 let words: Vec<&str> = content.split_whitespace().collect();
381 if words.len() < 3 {
382 return vec![];
383 }
384
385 let truncated = hirn_core::text_util::truncate_at_word_boundary(content, 80);
387
388 templates
389 .iter()
390 .take(num)
391 .map(|t| t.replace("{content}", &truncated))
392 .collect()
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::Field;

    /// Sentence with enough words to trigger question generation.
    const SAMPLE_CONTENT: &str = "Alice deployed version 2.3 on staging";

    /// Heuristic templates taken from the default configuration.
    fn default_templates() -> Vec<String> {
        ProspectiveConfig::default().heuristic_templates
    }

    /// Schema with two non-nullable Utf8 columns, `id` and `content`.
    fn id_content_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
        ]))
    }

    #[test]
    fn default_config() {
        let ProspectiveConfig {
            num_questions,
            timeout_secs,
            enabled,
            heuristic_templates,
            ..
        } = ProspectiveConfig::default();

        assert_eq!(num_questions, 5);
        assert_eq!(timeout_secs, 5);
        assert!(enabled);
        assert_eq!(heuristic_templates.len(), 5);
    }

    #[test]
    fn heuristic_questions_short_content() {
        // A single word is below the three-word minimum.
        let questions = generate_heuristic_questions("hi", 5, &default_templates());
        assert!(questions.is_empty());
    }

    #[test]
    fn heuristic_questions_normal_content() {
        let questions = generate_heuristic_questions(SAMPLE_CONTENT, 5, &default_templates());
        assert_eq!(questions.len(), 5);

        let first = &questions[0];
        assert!(first.contains("Alice deployed"));
        // Short content must not be truncated.
        assert!(!first.contains("..."));
    }

    #[test]
    fn heuristic_questions_truncates_long_content() {
        let long_content = "A ".repeat(100);
        let questions = generate_heuristic_questions(&long_content, 3, &default_templates());
        assert_eq!(questions.len(), 3);
        // Truncated content carries an ellipsis.
        assert!(questions[0].contains("..."));
    }

    #[test]
    fn heuristic_questions_custom_templates() {
        let custom: Vec<String> = vec![
            String::from("Tell me about {content}"),
            String::from("Summarize {content}"),
        ];
        // Asking for more questions than templates yields one per template.
        let questions = generate_heuristic_questions(SAMPLE_CONTENT, 5, &custom);
        assert_eq!(questions.len(), 2);
        assert!(questions[0].starts_with("Tell me about"));
        assert!(questions[1].starts_with("Summarize"));
    }

    #[test]
    fn truncate_at_word_boundary_short() {
        let untouched = hirn_core::text_util::truncate_at_word_boundary("short", 80);
        assert_eq!(untouched, "short");
    }

    #[test]
    fn truncate_at_word_boundary_long() {
        let shortened =
            hirn_core::text_util::truncate_at_word_boundary("hello world this is a long text", 15);
        assert!(shortened.ends_with("..."));
        // At most the 15-character limit plus the three-character ellipsis.
        assert!(shortened.len() <= 18);
    }

    #[tokio::test]
    async fn execute_empty_input() {
        use futures::StreamExt;

        let child = Arc::new(datafusion_physical_plan::empty::EmptyExec::new(
            id_content_schema(),
        ));
        let exec = ProspectiveIndexingExec::new(child, ProspectiveConfig::default());

        let mut stream = exec
            .execute(0, Arc::new(TaskContext::default()))
            .expect("execute should succeed");
        let batch = stream
            .next()
            .await
            .expect("stream should yield one batch")
            .expect("batch should be Ok");

        assert_eq!(batch.num_rows(), 0);
        assert!(batch.schema().field_with_name("prospective_count").is_ok());
    }

    #[tokio::test]
    async fn execute_disabled_produces_zero_counts() {
        use crate::test_utils::MemoryBatchExec;
        use futures::StreamExt;

        let schema = id_content_schema();
        let id_values: Arc<dyn Array> = Arc::new(StringArray::from(vec!["id1"]));
        let content_values: Arc<dyn Array> =
            Arc::new(StringArray::from(vec!["test memory content"]));
        let input_batch =
            RecordBatch::try_new(schema.clone(), vec![id_values, content_values]).unwrap();

        let disabled = ProspectiveConfig {
            enabled: false,
            ..ProspectiveConfig::default()
        };
        let exec = ProspectiveIndexingExec::new(
            Arc::new(MemoryBatchExec::new(schema, vec![input_batch])),
            disabled,
        );

        let mut stream = exec
            .execute(0, Arc::new(TaskContext::default()))
            .expect("execute should succeed");
        let output = stream
            .next()
            .await
            .expect("stream should yield one batch")
            .expect("batch should be Ok");

        let counts = output
            .column_by_name("prospective_count")
            .and_then(|col| col.as_any().downcast_ref::<Int32Array>())
            .expect("prospective_count should be an Int32 column");
        assert_eq!(counts.value(0), 0);
    }
}