1use std::any::Any;
7use std::collections::{BTreeSet, HashMap, HashSet};
8use std::fmt;
9use std::sync::Arc;
10
11use arrow_array::{
12 Array, ArrayRef, Float32Array, Int64Array, RecordBatch, StringArray, UInt32Array, UInt64Array,
13};
14use arrow_schema::{DataType, Field, Schema, SchemaRef};
15use datafusion_common::Result;
16use datafusion_execution::{SendableRecordBatchStream, TaskContext};
17use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
18use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
19use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
20use hirn_core::embed::Embedder;
21
22use crate::extensions::HirnSessionExt;
23use crate::operators::lance_hybrid_search::{
24 HybridSearchParams, LanceHybridSearchExec, RecallRow, resolved_search_params, search_rows,
25};
26
/// Tuning knobs for iterative (multi-round) retrieval.
#[derive(Clone, Debug)]
pub struct IterativeConfig {
    /// Upper bound on the number of rounds, counting the initial input round.
    pub max_rounds: u32,
    /// Iteration continues only while `rows / max(search limit, 5)` is below this value.
    pub coverage_threshold: f32,
    /// How many rows of the previous round are mined for expansion terms.
    pub expansion_prior_rows: usize,
    /// How many mined terms are appended to the original query.
    pub expansion_terms: usize,
}

impl Default for IterativeConfig {
    /// Three rounds, a 0.7 coverage target, mining 8 prior rows for 4 terms.
    fn default() -> Self {
        let (max_rounds, coverage_threshold) = (3, 0.7);
        let (expansion_prior_rows, expansion_terms) = (8, 4);
        Self {
            max_rounds,
            coverage_threshold,
            expansion_prior_rows,
            expansion_terms,
        }
    }
}
50
/// Physical operator that performs multi-round retrieval on top of a recall plan.
///
/// Round 1 is the child plan's output. While rounds remain and coverage is below
/// the configured threshold, the original FTS query is expanded with terms mined
/// from the previous round, re-embedded, and re-searched through the session's
/// storage. New rows (deduplicated by `id`) are appended and tagged with the round
/// that found them.
#[derive(Debug)]
pub struct IterativeRetrievalExec {
    /// Child plan supplying the first round of results.
    input: Arc<dyn ExecutionPlan>,
    /// Iteration settings.
    config: IterativeConfig,
    /// Input schema plus a trailing non-nullable `retrieval_round` column.
    schema: SchemaRef,
    /// Cached plan properties (one bounded output partition, final emission).
    properties: PlanProperties,
    /// Parameters of the first `LanceHybridSearchExec` found under `input`;
    /// `None` means only round 1 is returned.
    base_search_params: Option<HybridSearchParams>,
}
67
68impl IterativeRetrievalExec {
69 pub fn new(input: Arc<dyn ExecutionPlan>, config: IterativeConfig) -> Self {
70 let mut fields: Vec<Arc<Field>> = input.schema().fields().iter().cloned().collect();
72 fields.push(Arc::new(Field::new(
73 "retrieval_round",
74 DataType::UInt32,
75 false,
76 )));
77 let schema = Arc::new(Schema::new(fields));
78
79 let properties = PlanProperties::new(
80 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
81 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
82 EmissionType::Final,
83 Boundedness::Bounded,
84 );
85
86 Self {
87 base_search_params: find_base_search_params(input.as_ref()),
88 input,
89 config,
90 schema,
91 properties,
92 }
93 }
94
95 pub fn config(&self) -> &IterativeConfig {
96 &self.config
97 }
98}
99
100impl DisplayAs for IterativeRetrievalExec {
101 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 write!(
103 f,
104 "IterativeRetrievalExec: max_rounds={}, coverage_threshold={}, \
105 expansion_prior_rows={}, expansion_terms={}",
106 self.config.max_rounds,
107 self.config.coverage_threshold,
108 self.config.expansion_prior_rows,
109 self.config.expansion_terms,
110 )
111 }
112}
113
114impl ExecutionPlan for IterativeRetrievalExec {
115 fn name(&self) -> &str {
116 "IterativeRetrievalExec"
117 }
118
119 fn as_any(&self) -> &dyn Any {
120 self
121 }
122
123 fn schema(&self) -> SchemaRef {
124 self.schema.clone()
125 }
126
127 fn properties(&self) -> &PlanProperties {
128 &self.properties
129 }
130
131 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
132 vec![&self.input]
133 }
134
135 fn with_new_children(
136 self: Arc<Self>,
137 children: Vec<Arc<dyn ExecutionPlan>>,
138 ) -> Result<Arc<dyn ExecutionPlan>> {
139 let [child]: [Arc<dyn ExecutionPlan>; 1] = children.try_into().map_err(|v: Vec<_>| {
140 datafusion_common::DataFusionError::Plan(format!(
141 "IterativeRetrievalExec requires exactly 1 child, got {}",
142 v.len()
143 ))
144 })?;
145 Ok(Arc::new(Self::new(child, self.config.clone())))
146 }
147
148 fn execute(
149 &self,
150 partition: usize,
151 context: Arc<TaskContext>,
152 ) -> Result<SendableRecordBatchStream> {
153 let input_stream = self.input.execute(partition, context.clone())?;
154 let schema = self.schema.clone();
155 let max_rounds = self.config.max_rounds;
156 let coverage_threshold = self.config.coverage_threshold;
157 let expansion_prior_rows = self.config.expansion_prior_rows;
158 let expansion_terms = self.config.expansion_terms;
159 let base_search_params = self.base_search_params.clone();
160
161 let session_ext = context
162 .session_config()
163 .options()
164 .extensions
165 .get::<HirnSessionExt>()
166 .cloned();
167 let storage = session_ext.as_ref().and_then(HirnSessionExt::storage_arc);
168 let embedder = session_ext.as_ref().and_then(HirnSessionExt::embedder_arc);
169
170 let stream = futures::stream::once(async move {
171 use futures::StreamExt;
172
173 let mut seen_ids: HashSet<String> = HashSet::new();
174 let mut all_rows: Vec<IterativeRecallRow> = Vec::new();
175
176 {
178 let mut input_stream = input_stream;
179 let mut round_batches = Vec::new();
180 while let Some(batch_result) = input_stream.next().await {
181 round_batches.push(batch_result?);
182 }
183 if round_batches.is_empty() {
184 let columns: Vec<Arc<dyn Array>> = schema
185 .fields()
186 .iter()
187 .map(|f| arrow_array::new_empty_array(f.data_type()))
188 .collect();
189 return RecordBatch::try_new(schema, columns).map_err(Into::into);
190 }
191 all_rows.extend(deduplicate_round_batches(&round_batches, &mut seen_ids, 1)?);
192 }
193
194 if all_rows.is_empty() {
195 let columns: Vec<Arc<dyn Array>> = schema
196 .fields()
197 .iter()
198 .map(|f| arrow_array::new_empty_array(f.data_type()))
199 .collect();
200 return RecordBatch::try_new(schema, columns).map_err(Into::into);
201 }
202
203 let Some(storage) = storage else {
204 return build_output_batch(schema, &all_rows);
205 };
206 let Some(embedder) = embedder else {
207 if max_rounds > 1 {
210 tracing::warn!(
211 max_rounds,
212 "IterativeRetrievalExec: embedder absent, falling back to single-round \
213 result; configure an embedder to enable full iterative retrieval"
214 );
215 }
216 return build_output_batch(schema, &all_rows);
217 };
218 let Some(base_search_params) = base_search_params else {
219 return build_output_batch(schema, &all_rows);
220 };
221
222 let params = resolved_search_params(&base_search_params, session_ext.as_ref());
223 let target_count = params.limit.max(5);
224 let mut previous_round = all_rows.clone();
225 let mut current_round = 1u32;
229
230 while current_round < max_rounds
231 && (all_rows.len() as f32 / target_count as f32) < coverage_threshold
232 && !previous_round.is_empty()
233 {
234 current_round += 1;
235 let Some(expanded_query) = build_expanded_query(
236 params.fts_query.as_str(),
237 &previous_round,
238 expansion_prior_rows,
239 expansion_terms,
240 ) else {
241 break;
242 };
243
244 let query_embedding =
245 embedder
246 .embed(&[expanded_query.as_str()])
247 .await
248 .map_err(|error| {
249 datafusion_common::DataFusionError::Execution(error.to_string())
250 })?;
251 let Some(query_embedding) = query_embedding.first() else {
252 break;
253 };
254
255 let mut round_params = params.clone();
256 round_params
257 .query_vector
258 .clone_from(&query_embedding.vector);
259 round_params.fts_query = expanded_query;
260
261 let round_rows =
262 search_rows(storage.as_ref(), &round_params)
263 .await
264 .map_err(|error| {
265 datafusion_common::DataFusionError::Execution(error.to_string())
266 })?;
267 let deduped_rows =
268 deduplicate_search_rows(round_rows, &mut seen_ids, current_round, &schema);
269 if deduped_rows.is_empty() {
270 break;
271 }
272
273 previous_round.clone_from(&deduped_rows);
274 all_rows.extend(deduped_rows);
275 }
276
277 build_output_batch(schema, &all_rows)
278 });
279
280 Ok(Box::pin(RecordBatchStreamAdapter::new(
281 self.schema.clone(),
282 stream,
283 )))
284 }
285}
286
/// A recalled row plus the extra per-row values this operator carries between rounds.
#[derive(Debug, Clone)]
struct IterativeRecallRow {
    /// Core recall fields, in the same shape `search_rows` returns.
    base: RecallRow,
    /// Input `activation_score` value when that column exists. Rows found in later
    /// rounds get `Some(0.0)` if the schema has the column, otherwise `None`.
    activation_score: Option<f32>,
    /// Input `depth` value when that column exists; later-round rows get `Some(0)`
    /// if the schema has `activation_score`.
    activation_depth: Option<u32>,
    /// Input `causal_score` value; later-round rows get `Some(0.0)` if the schema has it.
    causal_score: Option<f32>,
    /// Input `causal_depth` value; later-round rows get `Some(0)` if the schema has
    /// `causal_score`.
    causal_depth: Option<u32>,
    /// 1 for rows from the child plan, 2 and above for rows found by re-searching.
    retrieval_round: u32,
}
296
297fn find_base_search_params(plan: &dyn ExecutionPlan) -> Option<HybridSearchParams> {
298 if let Some(search) = plan.as_any().downcast_ref::<LanceHybridSearchExec>() {
299 return Some(search.params().clone());
300 }
301
302 for child in plan.children() {
303 if let Some(params) = find_base_search_params(child.as_ref()) {
304 return Some(params);
305 }
306 }
307 None
308}
309
310fn deduplicate_round_batches(
311 batches: &[RecordBatch],
312 seen_ids: &mut HashSet<String>,
313 retrieval_round: u32,
314) -> datafusion_common::Result<Vec<IterativeRecallRow>> {
315 let mut result = Vec::new();
316 for batch in batches {
317 for row in recall_rows_from_batch(batch, retrieval_round)? {
318 if seen_ids.insert(row.base.id.clone()) {
319 result.push(row);
320 }
321 }
322 }
323 Ok(result)
324}
325
326fn deduplicate_search_rows(
327 rows: Vec<RecallRow>,
328 seen_ids: &mut HashSet<String>,
329 retrieval_round: u32,
330 schema: &Schema,
331) -> Vec<IterativeRecallRow> {
332 let include_activation = schema.field_with_name("activation_score").is_ok();
333 let include_causal = schema.field_with_name("causal_score").is_ok();
334
335 rows.into_iter()
336 .filter(|row| seen_ids.insert(row.id.clone()))
337 .map(|base| IterativeRecallRow {
338 base,
339 activation_score: include_activation.then_some(0.0),
340 activation_depth: include_activation.then_some(0),
341 causal_score: include_causal.then_some(0.0),
342 causal_depth: include_causal.then_some(0),
343 retrieval_round,
344 })
345 .collect()
346}
347
348fn recall_rows_from_batch(
349 batch: &RecordBatch,
350 retrieval_round: u32,
351) -> datafusion_common::Result<Vec<IterativeRecallRow>> {
352 let ids = required_string_column(batch, "id")?;
353 let contents = required_string_column(batch, "content")?;
354 let full_contents = batch
355 .column_by_name("full_content")
356 .and_then(|column| column.as_any().downcast_ref::<StringArray>());
357 let layers = required_string_column(batch, "layer")?;
358 let namespaces = required_string_column(batch, "namespace")?;
359 let scores = required_f32_column(batch, "score")?;
360 let temporal_ms = required_i64_column(batch, "temporal_ms")?;
361 let created_at_ms = required_i64_column(batch, "created_at_ms")?;
362 let importances = required_f32_column(batch, "importance")?;
363 let access_counts = required_u32_column(batch, "access_count")?;
364 let surprises = optional_f32_column(batch, "surprise");
365 let evidence_counts = optional_u32_column(batch, "evidence_count");
366 let invocation_counts = optional_u64_column(batch, "invocation_count");
367 let activation_scores = optional_f32_column(batch, "activation_score");
368 let activation_depths = optional_u32_column(batch, "depth");
369 let causal_scores = optional_f32_column(batch, "causal_score");
370 let causal_depths = optional_u32_column(batch, "causal_depth");
371
372 let mut rows = Vec::with_capacity(batch.num_rows());
373 for row in 0..batch.num_rows() {
374 rows.push(IterativeRecallRow {
375 base: RecallRow {
376 id: ids.value(row).to_string(),
377 content: contents.value(row).to_string(),
378 full_content: full_contents
379 .map(|fc| fc.value(row).to_string())
380 .unwrap_or_else(|| contents.value(row).to_string()),
381 layer: match layers.value(row) {
382 "episodic" => "episodic",
383 "semantic" => "semantic",
384 "procedural" => "procedural",
385 other => {
386 return Err(datafusion_common::DataFusionError::Execution(format!(
387 "unsupported recall layer `{other}` in iterative retrieval"
388 )));
389 }
390 },
391 namespace: namespaces.value(row).to_string(),
392 score: scores.value(row),
393 temporal_ms: temporal_ms.value(row),
394 created_at_ms: created_at_ms.value(row),
395 importance: importances.value(row),
396 access_count: access_counts.value(row),
397 surprise: optional_f32_value(surprises, row),
398 evidence_count: optional_u32_value(evidence_counts, row),
399 invocation_count: optional_u64_value(invocation_counts, row),
400 },
401 activation_score: optional_f32_value(activation_scores, row),
402 activation_depth: optional_u32_value(activation_depths, row),
403 causal_score: optional_f32_value(causal_scores, row),
404 causal_depth: optional_u32_value(causal_depths, row),
405 retrieval_round,
406 });
407 }
408
409 Ok(rows)
410}
411
412fn build_expanded_query(
420 original_query: &str,
421 prior_rows: &[IterativeRecallRow],
422 prior_rows_limit: usize,
423 expansion_terms: usize,
424) -> Option<String> {
425 let original_terms = lexical_terms(original_query);
426 let candidates: Vec<&IterativeRecallRow> = prior_rows.iter().take(prior_rows_limit).collect();
427
428 let row_terms: Vec<BTreeSet<String>> = candidates
430 .iter()
431 .map(|row| lexical_terms(&row.base.content))
432 .collect();
433
434 let mut doc_freq: HashMap<String, usize> = HashMap::new();
436 for terms in &row_terms {
437 for term in terms {
438 if !original_terms.contains(term) {
439 *doc_freq.entry(term.clone()).or_insert(0) += 1;
440 }
441 }
442 }
443
444 let mut term_scores: HashMap<String, f32> = HashMap::new();
448 for (row, terms) in candidates.iter().zip(&row_terms) {
449 for term in terms {
450 if original_terms.contains(term) {
451 continue;
452 }
453 let df = *doc_freq.get(term).unwrap_or(&1) as f32;
454 let idf_weight = 1.0 / df.sqrt();
455 *term_scores.entry(term.clone()).or_insert(0.0) +=
456 row.base.score.max(0.05) * idf_weight;
457 }
458 }
459
460 let mut ranked: Vec<(String, f32)> = term_scores.into_iter().collect();
461 ranked.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
463
464 let expansion: Vec<String> = ranked
465 .into_iter()
466 .take(expansion_terms)
467 .map(|(term, _)| term)
468 .collect();
469
470 if expansion.is_empty() {
471 return None;
472 }
473
474 Some(format!("{} {}", original_query, expansion.join(" ")))
475}
476
/// Splits `text` into a set of lowercase lexical terms for query expansion.
///
/// Tokens are split on whitespace, stripped of leading/trailing non-alphanumeric
/// characters, and lowercased with full Unicode case folding (so `Über` and
/// `über` collapse to one term). A token is kept only if it is longer than two
/// characters (counted as `char`s, not bytes) and is not in the built-in
/// English stop-word list.
fn lexical_terms(text: &str) -> BTreeSet<String> {
    const STOP_WORDS: &[&str] = &[
        "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "how", "i", "in", "is",
        "it", "of", "on", "or", "that", "the", "to", "was", "what", "when", "where", "which",
        "who", "why", "with",
    ];

    text.split_whitespace()
        .map(|token| {
            token
                .trim_matches(|c: char| !c.is_alphanumeric())
                .to_lowercase()
        })
        // `chars().count()` so multi-byte characters are not over-counted.
        .filter(|token| token.chars().count() > 2 && !STOP_WORDS.contains(&token.as_str()))
        .collect()
}
493
494fn required_string_column<'a>(
495 batch: &'a RecordBatch,
496 name: &str,
497) -> datafusion_common::Result<&'a StringArray> {
498 batch
499 .column_by_name(name)
500 .and_then(|column| column.as_any().downcast_ref::<StringArray>())
501 .ok_or_else(|| {
502 datafusion_common::DataFusionError::Execution(format!(
503 "iterative retrieval batch missing `{name}` string column"
504 ))
505 })
506}
507
508fn required_f32_column<'a>(
509 batch: &'a RecordBatch,
510 name: &str,
511) -> datafusion_common::Result<&'a Float32Array> {
512 batch
513 .column_by_name(name)
514 .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
515 .ok_or_else(|| {
516 datafusion_common::DataFusionError::Execution(format!(
517 "iterative retrieval batch missing `{name}` f32 column"
518 ))
519 })
520}
521
522fn required_i64_column<'a>(
523 batch: &'a RecordBatch,
524 name: &str,
525) -> datafusion_common::Result<&'a Int64Array> {
526 batch
527 .column_by_name(name)
528 .and_then(|column| column.as_any().downcast_ref::<Int64Array>())
529 .ok_or_else(|| {
530 datafusion_common::DataFusionError::Execution(format!(
531 "iterative retrieval batch missing `{name}` i64 column"
532 ))
533 })
534}
535
536fn required_u32_column<'a>(
537 batch: &'a RecordBatch,
538 name: &str,
539) -> datafusion_common::Result<&'a UInt32Array> {
540 batch
541 .column_by_name(name)
542 .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
543 .ok_or_else(|| {
544 datafusion_common::DataFusionError::Execution(format!(
545 "iterative retrieval batch missing `{name}` u32 column"
546 ))
547 })
548}
549
550fn optional_f32_column<'a>(batch: &'a RecordBatch, name: &str) -> Option<&'a Float32Array> {
551 batch
552 .column_by_name(name)
553 .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
554}
555
556fn optional_u32_column<'a>(batch: &'a RecordBatch, name: &str) -> Option<&'a UInt32Array> {
557 batch
558 .column_by_name(name)
559 .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
560}
561
562fn optional_u64_column<'a>(batch: &'a RecordBatch, name: &str) -> Option<&'a UInt64Array> {
563 batch
564 .column_by_name(name)
565 .and_then(|column| column.as_any().downcast_ref::<UInt64Array>())
566}
567
568fn optional_f32_value(array: Option<&Float32Array>, row: usize) -> Option<f32> {
569 array.and_then(|array| (!array.is_null(row)).then(|| array.value(row)))
570}
571
572fn optional_u32_value(array: Option<&UInt32Array>, row: usize) -> Option<u32> {
573 array.and_then(|array| (!array.is_null(row)).then(|| array.value(row)))
574}
575
576fn optional_u64_value(array: Option<&UInt64Array>, row: usize) -> Option<u64> {
577 array.and_then(|array| (!array.is_null(row)).then(|| array.value(row)))
578}
579
580fn build_output_batch(
581 schema: SchemaRef,
582 rows: &[IterativeRecallRow],
583) -> datafusion_common::Result<RecordBatch> {
584 if rows.is_empty() {
585 return Ok(RecordBatch::new_empty(schema));
586 }
587
588 let include_activation = schema.field_with_name("activation_score").is_ok();
589 let include_causal = schema.field_with_name("causal_score").is_ok();
590
591 let ids = rows
592 .iter()
593 .map(|row| row.base.id.as_str())
594 .collect::<Vec<_>>();
595 let contents = rows
596 .iter()
597 .map(|row| row.base.content.as_str())
598 .collect::<Vec<_>>();
599 let full_contents = rows
600 .iter()
601 .map(|row| row.base.full_content.as_str())
602 .collect::<Vec<_>>();
603 let layers = rows.iter().map(|row| row.base.layer).collect::<Vec<_>>();
604 let namespaces = rows
605 .iter()
606 .map(|row| row.base.namespace.as_str())
607 .collect::<Vec<_>>();
608 let scores = rows.iter().map(|row| row.base.score).collect::<Vec<_>>();
609 let temporal = rows
610 .iter()
611 .map(|row| row.base.temporal_ms)
612 .collect::<Vec<_>>();
613 let created_at = rows
614 .iter()
615 .map(|row| row.base.created_at_ms)
616 .collect::<Vec<_>>();
617 let importances = rows
618 .iter()
619 .map(|row| row.base.importance)
620 .collect::<Vec<_>>();
621 let access_counts = rows
622 .iter()
623 .map(|row| row.base.access_count)
624 .collect::<Vec<_>>();
625 let surprises = rows.iter().map(|row| row.base.surprise).collect::<Vec<_>>();
626 let evidence_counts = rows
627 .iter()
628 .map(|row| row.base.evidence_count)
629 .collect::<Vec<_>>();
630 let invocation_counts = rows
631 .iter()
632 .map(|row| row.base.invocation_count)
633 .collect::<Vec<_>>();
634 let retrieval_rounds = rows
635 .iter()
636 .map(|row| row.retrieval_round)
637 .collect::<Vec<_>>();
638
639 let mut columns: Vec<ArrayRef> = vec![
640 Arc::new(StringArray::from(ids)) as ArrayRef,
641 Arc::new(StringArray::from(contents)) as ArrayRef,
642 Arc::new(StringArray::from(full_contents)) as ArrayRef,
643 Arc::new(StringArray::from(layers)) as ArrayRef,
644 Arc::new(StringArray::from(namespaces)) as ArrayRef,
645 Arc::new(Float32Array::from(scores)) as ArrayRef,
646 Arc::new(Int64Array::from(temporal)) as ArrayRef,
647 Arc::new(Int64Array::from(created_at)) as ArrayRef,
648 Arc::new(Float32Array::from(importances)) as ArrayRef,
649 Arc::new(UInt32Array::from(access_counts)) as ArrayRef,
650 Arc::new(Float32Array::from(surprises)) as ArrayRef,
651 Arc::new(UInt32Array::from(evidence_counts)) as ArrayRef,
652 Arc::new(UInt64Array::from(invocation_counts)) as ArrayRef,
653 ];
654
655 if include_activation {
656 columns.push(Arc::new(Float32Array::from(
657 rows.iter()
658 .map(|row| row.activation_score.unwrap_or(0.0))
659 .collect::<Vec<_>>(),
660 )) as ArrayRef);
661 columns.push(Arc::new(UInt32Array::from(
662 rows.iter()
663 .map(|row| row.activation_depth.unwrap_or(0))
664 .collect::<Vec<_>>(),
665 )) as ArrayRef);
666 }
667
668 if include_causal {
669 columns.push(Arc::new(Float32Array::from(
670 rows.iter()
671 .map(|row| row.causal_score.unwrap_or(0.0))
672 .collect::<Vec<_>>(),
673 )) as ArrayRef);
674 columns.push(Arc::new(UInt32Array::from(
675 rows.iter()
676 .map(|row| row.causal_depth.unwrap_or(0))
677 .collect::<Vec<_>>(),
678 )) as ArrayRef);
679 }
680
681 columns.push(Arc::new(UInt32Array::from(retrieval_rounds)) as ArrayRef);
682
683 Ok(RecordBatch::try_new(schema, columns)?)
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use async_trait::async_trait;
    use hirn_core::HirnResult;
    use hirn_core::config::HirnConfig;
    use hirn_core::embed::{Embedding, MultivectorEmbedding};
    use hirn_core::episodic::EpisodicRecord;
    use hirn_core::types::AgentId;
    use hirn_storage::PhysicalStore;
    use hirn_storage::datasets::episodic;
    use hirn_storage::memory_store::MemoryStore;

    use crate::extensions::HirnSessionExt;
    use crate::operators::lance_hybrid_search::LanceHybridSearchExec;

    /// Pins the default values of `IterativeConfig`.
    #[test]
    fn default_config() {
        let config = IterativeConfig::default();
        assert_eq!(config.max_rounds, 3);
        assert!((config.coverage_threshold - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.expansion_prior_rows, 8);
        assert_eq!(config.expansion_terms, 4);
    }

    /// Smoke test: the operator can be built over an empty child and reports its name.
    /// NOTE(review): despite the name, this does not check the `DisplayAs` output.
    #[test]
    fn display_format() {
        let exec = IterativeRetrievalExec::new(
            Arc::new(datafusion_physical_plan::empty::EmptyExec::new(Arc::new(
                Schema::empty(),
            ))),
            IterativeConfig::default(),
        );
        assert_eq!(exec.name(), "IterativeRetrievalExec");
    }

    /// A child producing no rows yields one empty output batch (no session services needed).
    #[tokio::test]
    async fn execute_empty_input() {
        use futures::StreamExt;

        let empty_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
        ]));
        let empty = Arc::new(datafusion_physical_plan::empty::EmptyExec::new(
            empty_schema,
        ));
        let exec = IterativeRetrievalExec::new(empty, IterativeConfig::default());
        let ctx = Arc::new(TaskContext::default());
        let mut stream = exec.execute(0, ctx).unwrap();
        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    /// Deterministic 2-d test embedder: any text mentioning "entanglement" maps to
    /// `[0, 1]`, everything else to `[1, 0]`. An expanded query that gains that term
    /// therefore lands on a different vector than the plain "quantum" query.
    #[derive(Debug)]
    struct KeywordEmbedder;

    #[async_trait]
    impl Embedder for KeywordEmbedder {
        async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            Ok(texts
                .iter()
                .map(|text| Embedding {
                    vector: if text.to_ascii_lowercase().contains("entanglement") {
                        vec![0.0, 1.0]
                    } else {
                        vec![1.0, 0.0]
                    },
                    model_id: "keyword-test".to_string(),
                })
                .collect())
        }

        fn dimensions(&self) -> usize {
            2
        }

        fn model_id(&self) -> &str {
            "keyword-test"
        }

        fn max_input_tokens(&self) -> usize {
            1024
        }

        async fn embed_multivec(&self, _texts: &[&str]) -> HirnResult<Vec<MultivectorEmbedding>> {
            Ok(Vec::new())
        }
    }

    /// Recall schema with the columns `recall_rows_from_batch` reads (no
    /// activation/causal columns).
    fn test_recall_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
            Field::new("full_content", DataType::Utf8, false),
            Field::new("layer", DataType::Utf8, false),
            Field::new("namespace", DataType::Utf8, false),
            Field::new("score", DataType::Float32, true),
            Field::new("temporal_ms", DataType::Int64, false),
            Field::new("created_at_ms", DataType::Int64, false),
            Field::new("importance", DataType::Float32, true),
            Field::new("access_count", DataType::UInt32, true),
            Field::new("surprise", DataType::Float32, true),
            Field::new("evidence_count", DataType::UInt32, true),
            Field::new("invocation_count", DataType::UInt64, true),
        ]))
    }

    /// End-to-end check that a second round really re-searches storage.
    ///
    /// - Round 1 (limit 1, vector `[1, 0]`) is expected to return the "quantum" record.
    /// - The expanded query then contains "entanglement" and embeds to `[0, 1]`.
    /// - Round 2 should therefore surface the other record, tagged `retrieval_round = 2`.
    #[tokio::test]
    async fn iterative_retrieval_exec_runs_real_second_round() {
        use futures::StreamExt;

        let storage: Arc<dyn PhysicalStore> = Arc::new(MemoryStore::new());
        let records = vec![
            EpisodicRecord::builder()
                .content("quantum qubits entanglement")
                .agent_id(AgentId::new("iterative_test").unwrap())
                .embedding(vec![1.0, 0.0])
                .build()
                .unwrap(),
            EpisodicRecord::builder()
                .content("entanglement teleportation bell-states")
                .agent_id(AgentId::new("iterative_test").unwrap())
                .embedding(vec![0.0, 1.0])
                .build()
                .unwrap(),
        ];
        storage
            .append(
                episodic::DATASET_NAME,
                episodic::to_batch(&records, 2).unwrap(),
            )
            .await
            .unwrap();

        // Register storage + embedder on the session so `execute` can run extra rounds.
        let ctx = datafusion::prelude::SessionContext::new();
        HirnSessionExt::new(
            Arc::new(0_u8),
            Arc::new(HirnConfig::default()),
            Some(Arc::new(KeywordEmbedder)),
        )
        .with_storage(Arc::clone(&storage))
        .register(&ctx)
        .unwrap();

        let search = Arc::new(LanceHybridSearchExec::new(
            test_recall_schema(),
            HybridSearchParams {
                datasets: vec![episodic::DATASET_NAME.to_string()],
                vector_column: "embedding".to_string(),
                query_vector: vec![1.0, 0.0],
                hybrid_mode: false,
                fts_columns: vec!["content".to_string()],
                fts_query: "quantum".to_string(),
                limit: 1,
                metric: hirn_storage::store::DistanceMetric::Cosine,
                filter: None,
                numeric_filters: Vec::new(),
                temporal_start_ms: None,
                temporal_end_ms: None,
                temporal_expansion: false,
                temporal_boost: 1.25,
            },
        ));

        // A high threshold keeps coverage (1 of max(1, 5) rows) below target after round 1.
        let exec = IterativeRetrievalExec::new(
            search,
            IterativeConfig {
                max_rounds: 2,
                coverage_threshold: 0.9,
                ..IterativeConfig::default()
            },
        );
        let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
        let batch = stream.next().await.unwrap().unwrap();

        let ids = batch
            .column_by_name("id")
            .and_then(|column| column.as_any().downcast_ref::<StringArray>())
            .unwrap();
        let rounds = batch
            .column_by_name("retrieval_round")
            .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
            .unwrap();

        assert_eq!(batch.num_rows(), 2);
        assert_eq!(rounds.value(0), 1);
        assert_eq!(rounds.value(1), 2);
        assert_ne!(ids.value(0), ids.value(1));
    }
}