Skip to main content

hirn_exec/operators/
iterative_retrieval.rs

1//! `IterativeRetrievalExec` — multi-hop retrieval with query reformulation.
2//!
//! Loop: retrieve → extract salient lexical terms from the results (pseudo-relevance feedback) → check coverage against a threshold → if short, reformulate the query → retrieve again.
4//! Maximum configurable rounds (default: 3). Results deduplicated by memory ID.
5
6use std::any::Any;
7use std::collections::{BTreeSet, HashMap, HashSet};
8use std::fmt;
9use std::sync::Arc;
10
11use arrow_array::{
12    Array, ArrayRef, Float32Array, Int64Array, RecordBatch, StringArray, UInt32Array, UInt64Array,
13};
14use arrow_schema::{DataType, Field, Schema, SchemaRef};
15use datafusion_common::Result;
16use datafusion_execution::{SendableRecordBatchStream, TaskContext};
17use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
18use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
19use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
20use hirn_core::embed::Embedder;
21
22use crate::extensions::HirnSessionExt;
23use crate::operators::lance_hybrid_search::{
24    HybridSearchParams, LanceHybridSearchExec, RecallRow, resolved_search_params, search_rows,
25};
26
/// Tuning knobs for iterative, multi-round retrieval.
///
/// Construct with [`Default::default`] and override individual fields via
/// struct-update syntax, e.g. `IterativeConfig { max_rounds: 2, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct IterativeConfig {
    /// Upper bound on retrieval rounds, counting the initial child-plan round
    /// (default: 3; expected to be validated to 1–5 at plan-compile time).
    pub max_rounds: u32,
    /// Stop issuing rounds once `retrieved / target >= coverage_threshold` (default: 0.7).
    pub coverage_threshold: f32,
    /// How many rows from the preceding round feed pseudo-relevance-feedback
    /// query expansion (default: 8).
    pub expansion_prior_rows: usize,
    /// How many gap-filling terms may be appended to a reformulated query (default: 4).
    pub expansion_terms: usize,
}

impl Default for IterativeConfig {
    fn default() -> Self {
        const MAX_ROUNDS: u32 = 3;
        const COVERAGE_THRESHOLD: f32 = 0.7;
        const EXPANSION_PRIOR_ROWS: usize = 8;
        const EXPANSION_TERMS: usize = 4;

        Self {
            max_rounds: MAX_ROUNDS,
            coverage_threshold: COVERAGE_THRESHOLD,
            expansion_prior_rows: EXPANSION_PRIOR_ROWS,
            expansion_terms: EXPANSION_TERMS,
        }
    }
}
50
51/// DataFusion operator for iterative multi-hop retrieval.
52///
53/// Each round retrieves from the child plan using Pseudo-Relevance Feedback (PRF):
54/// salient entities from prior-round results are extracted and appended to the
55/// query, then a new hybrid search round is issued. Results are deduplicated by
56/// memory ID. Rounds continue until the coverage threshold is met or `max_rounds`
57/// is exhausted. Requires `base_search_params` with an embedder for rounds > 1;
58/// falls back to single-round passthrough when the embedder is unavailable.
59#[derive(Debug)]
60pub struct IterativeRetrievalExec {
61    input: Arc<dyn ExecutionPlan>,
62    config: IterativeConfig,
63    schema: SchemaRef,
64    properties: PlanProperties,
65    base_search_params: Option<HybridSearchParams>,
66}
67
68impl IterativeRetrievalExec {
69    pub fn new(input: Arc<dyn ExecutionPlan>, config: IterativeConfig) -> Self {
70        // Output schema: input schema + retrieval_round column.
71        let mut fields: Vec<Arc<Field>> = input.schema().fields().iter().cloned().collect();
72        fields.push(Arc::new(Field::new(
73            "retrieval_round",
74            DataType::UInt32,
75            false,
76        )));
77        let schema = Arc::new(Schema::new(fields));
78
79        let properties = PlanProperties::new(
80            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
81            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
82            EmissionType::Final,
83            Boundedness::Bounded,
84        );
85
86        Self {
87            base_search_params: find_base_search_params(input.as_ref()),
88            input,
89            config,
90            schema,
91            properties,
92        }
93    }
94
95    pub fn config(&self) -> &IterativeConfig {
96        &self.config
97    }
98}
99
100impl DisplayAs for IterativeRetrievalExec {
101    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        write!(
103            f,
104            "IterativeRetrievalExec: max_rounds={}, coverage_threshold={}, \
105             expansion_prior_rows={}, expansion_terms={}",
106            self.config.max_rounds,
107            self.config.coverage_threshold,
108            self.config.expansion_prior_rows,
109            self.config.expansion_terms,
110        )
111    }
112}
113
114impl ExecutionPlan for IterativeRetrievalExec {
115    fn name(&self) -> &str {
116        "IterativeRetrievalExec"
117    }
118
119    fn as_any(&self) -> &dyn Any {
120        self
121    }
122
123    fn schema(&self) -> SchemaRef {
124        self.schema.clone()
125    }
126
127    fn properties(&self) -> &PlanProperties {
128        &self.properties
129    }
130
131    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
132        vec![&self.input]
133    }
134
135    fn with_new_children(
136        self: Arc<Self>,
137        children: Vec<Arc<dyn ExecutionPlan>>,
138    ) -> Result<Arc<dyn ExecutionPlan>> {
139        let [child]: [Arc<dyn ExecutionPlan>; 1] = children.try_into().map_err(|v: Vec<_>| {
140            datafusion_common::DataFusionError::Plan(format!(
141                "IterativeRetrievalExec requires exactly 1 child, got {}",
142                v.len()
143            ))
144        })?;
145        Ok(Arc::new(Self::new(child, self.config.clone())))
146    }
147
148    fn execute(
149        &self,
150        partition: usize,
151        context: Arc<TaskContext>,
152    ) -> Result<SendableRecordBatchStream> {
153        let input_stream = self.input.execute(partition, context.clone())?;
154        let schema = self.schema.clone();
155        let max_rounds = self.config.max_rounds;
156        let coverage_threshold = self.config.coverage_threshold;
157        let expansion_prior_rows = self.config.expansion_prior_rows;
158        let expansion_terms = self.config.expansion_terms;
159        let base_search_params = self.base_search_params.clone();
160
161        let session_ext = context
162            .session_config()
163            .options()
164            .extensions
165            .get::<HirnSessionExt>()
166            .cloned();
167        let storage = session_ext.as_ref().and_then(HirnSessionExt::storage_arc);
168        let embedder = session_ext.as_ref().and_then(HirnSessionExt::embedder_arc);
169
170        let stream = futures::stream::once(async move {
171            use futures::StreamExt;
172
173            let mut seen_ids: HashSet<String> = HashSet::new();
174            let mut all_rows: Vec<IterativeRecallRow> = Vec::new();
175
176            // ── Round 1: execute the child plan ──
177            {
178                let mut input_stream = input_stream;
179                let mut round_batches = Vec::new();
180                while let Some(batch_result) = input_stream.next().await {
181                    round_batches.push(batch_result?);
182                }
183                if round_batches.is_empty() {
184                    let columns: Vec<Arc<dyn Array>> = schema
185                        .fields()
186                        .iter()
187                        .map(|f| arrow_array::new_empty_array(f.data_type()))
188                        .collect();
189                    return RecordBatch::try_new(schema, columns).map_err(Into::into);
190                }
191                all_rows.extend(deduplicate_round_batches(&round_batches, &mut seen_ids, 1)?);
192            }
193
194            if all_rows.is_empty() {
195                let columns: Vec<Arc<dyn Array>> = schema
196                    .fields()
197                    .iter()
198                    .map(|f| arrow_array::new_empty_array(f.data_type()))
199                    .collect();
200                return RecordBatch::try_new(schema, columns).map_err(Into::into);
201            }
202
203            let Some(storage) = storage else {
204                return build_output_batch(schema, &all_rows);
205            };
206            let Some(embedder) = embedder else {
207                // No embedder configured: multi-round expansion requires query re-embedding,
208                // so fall back to the single-round result already in `all_rows`.
209                if max_rounds > 1 {
210                    tracing::warn!(
211                        max_rounds,
212                        "IterativeRetrievalExec: embedder absent, falling back to single-round \
213                         result; configure an embedder to enable full iterative retrieval"
214                    );
215                }
216                return build_output_batch(schema, &all_rows);
217            };
218            let Some(base_search_params) = base_search_params else {
219                return build_output_batch(schema, &all_rows);
220            };
221
222            let params = resolved_search_params(&base_search_params, session_ext.as_ref());
223            let target_count = params.limit.max(5);
224            let mut previous_round = all_rows.clone();
225            // Explicit round counter avoids reading `all_rows.last()` which is
226            // unreliable when a round produces no new results and `all_rows` is
227            // an aggregate of all previous rounds.
228            let mut current_round = 1u32;
229
230            while current_round < max_rounds
231                && (all_rows.len() as f32 / target_count as f32) < coverage_threshold
232                && !previous_round.is_empty()
233            {
234                current_round += 1;
235                let Some(expanded_query) = build_expanded_query(
236                    params.fts_query.as_str(),
237                    &previous_round,
238                    expansion_prior_rows,
239                    expansion_terms,
240                ) else {
241                    break;
242                };
243
244                let query_embedding =
245                    embedder
246                        .embed(&[expanded_query.as_str()])
247                        .await
248                        .map_err(|error| {
249                            datafusion_common::DataFusionError::Execution(error.to_string())
250                        })?;
251                let Some(query_embedding) = query_embedding.first() else {
252                    break;
253                };
254
255                let mut round_params = params.clone();
256                round_params
257                    .query_vector
258                    .clone_from(&query_embedding.vector);
259                round_params.fts_query = expanded_query;
260
261                let round_rows =
262                    search_rows(storage.as_ref(), &round_params)
263                        .await
264                        .map_err(|error| {
265                            datafusion_common::DataFusionError::Execution(error.to_string())
266                        })?;
267                let deduped_rows =
268                    deduplicate_search_rows(round_rows, &mut seen_ids, current_round, &schema);
269                if deduped_rows.is_empty() {
270                    break;
271                }
272
273                previous_round.clone_from(&deduped_rows);
274                all_rows.extend(deduped_rows);
275            }
276
277            build_output_batch(schema, &all_rows)
278        });
279
280        Ok(Box::pin(RecordBatchStreamAdapter::new(
281            self.schema.clone(),
282            stream,
283        )))
284    }
285}
286
287#[derive(Debug, Clone)]
288struct IterativeRecallRow {
289    base: RecallRow,
290    activation_score: Option<f32>,
291    activation_depth: Option<u32>,
292    causal_score: Option<f32>,
293    causal_depth: Option<u32>,
294    retrieval_round: u32,
295}
296
297fn find_base_search_params(plan: &dyn ExecutionPlan) -> Option<HybridSearchParams> {
298    if let Some(search) = plan.as_any().downcast_ref::<LanceHybridSearchExec>() {
299        return Some(search.params().clone());
300    }
301
302    for child in plan.children() {
303        if let Some(params) = find_base_search_params(child.as_ref()) {
304            return Some(params);
305        }
306    }
307    None
308}
309
310fn deduplicate_round_batches(
311    batches: &[RecordBatch],
312    seen_ids: &mut HashSet<String>,
313    retrieval_round: u32,
314) -> datafusion_common::Result<Vec<IterativeRecallRow>> {
315    let mut result = Vec::new();
316    for batch in batches {
317        for row in recall_rows_from_batch(batch, retrieval_round)? {
318            if seen_ids.insert(row.base.id.clone()) {
319                result.push(row);
320            }
321        }
322    }
323    Ok(result)
324}
325
326fn deduplicate_search_rows(
327    rows: Vec<RecallRow>,
328    seen_ids: &mut HashSet<String>,
329    retrieval_round: u32,
330    schema: &Schema,
331) -> Vec<IterativeRecallRow> {
332    let include_activation = schema.field_with_name("activation_score").is_ok();
333    let include_causal = schema.field_with_name("causal_score").is_ok();
334
335    rows.into_iter()
336        .filter(|row| seen_ids.insert(row.id.clone()))
337        .map(|base| IterativeRecallRow {
338            base,
339            activation_score: include_activation.then_some(0.0),
340            activation_depth: include_activation.then_some(0),
341            causal_score: include_causal.then_some(0.0),
342            causal_depth: include_causal.then_some(0),
343            retrieval_round,
344        })
345        .collect()
346}
347
348fn recall_rows_from_batch(
349    batch: &RecordBatch,
350    retrieval_round: u32,
351) -> datafusion_common::Result<Vec<IterativeRecallRow>> {
352    let ids = required_string_column(batch, "id")?;
353    let contents = required_string_column(batch, "content")?;
354    let full_contents = batch
355        .column_by_name("full_content")
356        .and_then(|column| column.as_any().downcast_ref::<StringArray>());
357    let layers = required_string_column(batch, "layer")?;
358    let namespaces = required_string_column(batch, "namespace")?;
359    let scores = required_f32_column(batch, "score")?;
360    let temporal_ms = required_i64_column(batch, "temporal_ms")?;
361    let created_at_ms = required_i64_column(batch, "created_at_ms")?;
362    let importances = required_f32_column(batch, "importance")?;
363    let access_counts = required_u32_column(batch, "access_count")?;
364    let surprises = optional_f32_column(batch, "surprise");
365    let evidence_counts = optional_u32_column(batch, "evidence_count");
366    let invocation_counts = optional_u64_column(batch, "invocation_count");
367    let activation_scores = optional_f32_column(batch, "activation_score");
368    let activation_depths = optional_u32_column(batch, "depth");
369    let causal_scores = optional_f32_column(batch, "causal_score");
370    let causal_depths = optional_u32_column(batch, "causal_depth");
371
372    let mut rows = Vec::with_capacity(batch.num_rows());
373    for row in 0..batch.num_rows() {
374        rows.push(IterativeRecallRow {
375            base: RecallRow {
376                id: ids.value(row).to_string(),
377                content: contents.value(row).to_string(),
378                full_content: full_contents
379                    .map(|fc| fc.value(row).to_string())
380                    .unwrap_or_else(|| contents.value(row).to_string()),
381                layer: match layers.value(row) {
382                    "episodic" => "episodic",
383                    "semantic" => "semantic",
384                    "procedural" => "procedural",
385                    other => {
386                        return Err(datafusion_common::DataFusionError::Execution(format!(
387                            "unsupported recall layer `{other}` in iterative retrieval"
388                        )));
389                    }
390                },
391                namespace: namespaces.value(row).to_string(),
392                score: scores.value(row),
393                temporal_ms: temporal_ms.value(row),
394                created_at_ms: created_at_ms.value(row),
395                importance: importances.value(row),
396                access_count: access_counts.value(row),
397                surprise: optional_f32_value(surprises, row),
398                evidence_count: optional_u32_value(evidence_counts, row),
399                invocation_count: optional_u64_value(invocation_counts, row),
400            },
401            activation_score: optional_f32_value(activation_scores, row),
402            activation_depth: optional_u32_value(activation_depths, row),
403            causal_score: optional_f32_value(causal_scores, row),
404            causal_depth: optional_u32_value(causal_depths, row),
405            retrieval_round,
406        });
407    }
408
409    Ok(rows)
410}
411
412/// Build a reformulated query by appending gap-filling terms drawn from the
413/// highest-scoring rows of the previous retrieval round.
414///
415/// Uses pseudo-relevance-feedback (PRF) with inverse-sqrt document-frequency
416/// weighting: terms that appear in only a few high-scoring rows receive higher
417/// weight than terms shared across many rows (which are typically generic and
418/// less discriminative).
419fn build_expanded_query(
420    original_query: &str,
421    prior_rows: &[IterativeRecallRow],
422    prior_rows_limit: usize,
423    expansion_terms: usize,
424) -> Option<String> {
425    let original_terms = lexical_terms(original_query);
426    let candidates: Vec<&IterativeRecallRow> = prior_rows.iter().take(prior_rows_limit).collect();
427
428    // Tokenise each candidate row once to avoid redundant work.
429    let row_terms: Vec<BTreeSet<String>> = candidates
430        .iter()
431        .map(|row| lexical_terms(&row.base.content))
432        .collect();
433
434    // Document frequency: how many candidate rows contain each non-query term.
435    let mut doc_freq: HashMap<String, usize> = HashMap::new();
436    for terms in &row_terms {
437        for term in terms {
438            if !original_terms.contains(term) {
439                *doc_freq.entry(term.clone()).or_insert(0) += 1;
440            }
441        }
442    }
443
444    // PRF score: sum of (row_score × 1/√doc_freq) across rows containing the term.
445    // The inverse-sqrt IDF downweights ubiquitous terms, preferring discriminative
446    // terms that appear in only a few high-scoring rows.
447    let mut term_scores: HashMap<String, f32> = HashMap::new();
448    for (row, terms) in candidates.iter().zip(&row_terms) {
449        for term in terms {
450            if original_terms.contains(term) {
451                continue;
452            }
453            let df = *doc_freq.get(term).unwrap_or(&1) as f32;
454            let idf_weight = 1.0 / df.sqrt();
455            *term_scores.entry(term.clone()).or_insert(0.0) +=
456                row.base.score.max(0.05) * idf_weight;
457        }
458    }
459
460    let mut ranked: Vec<(String, f32)> = term_scores.into_iter().collect();
461    // Sort by score descending; break ties alphabetically for determinism.
462    ranked.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
463
464    let expansion: Vec<String> = ranked
465        .into_iter()
466        .take(expansion_terms)
467        .map(|(term, _)| term)
468        .collect();
469
470    if expansion.is_empty() {
471        return None;
472    }
473
474    Some(format!("{} {}", original_query, expansion.join(" ")))
475}
476
/// Split `text` into a set of lowercase, punctuation-trimmed content terms.
///
/// Processing steps:
/// - Tokens are whitespace-separated words.
/// - Leading and trailing non-alphanumeric characters are stripped; inner
///   punctuation such as the hyphen in `bell-states` is kept.
/// - Tokens of two characters or fewer are dropped, as are words on a small
///   English stop-word list.
///
/// Lowercasing and the length check are Unicode-aware: `Über`/`ÜBER` collapse
/// into `über`, and short non-ASCII tokens are not kept merely because they
/// occupy more than two bytes.
fn lexical_terms(text: &str) -> BTreeSet<String> {
    const STOP_WORDS: &[&str] = &[
        "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "how", "i", "in", "is",
        "it", "of", "on", "or", "that", "the", "to", "was", "what", "when", "where", "which",
        "who", "why", "with",
    ];

    text.split_whitespace()
        .map(|token| {
            token
                .trim_matches(|c: char| !c.is_alphanumeric())
                .to_lowercase()
        })
        .filter(|token| token.chars().count() > 2 && !STOP_WORDS.contains(&token.as_str()))
        .collect()
}
493
494fn required_string_column<'a>(
495    batch: &'a RecordBatch,
496    name: &str,
497) -> datafusion_common::Result<&'a StringArray> {
498    batch
499        .column_by_name(name)
500        .and_then(|column| column.as_any().downcast_ref::<StringArray>())
501        .ok_or_else(|| {
502            datafusion_common::DataFusionError::Execution(format!(
503                "iterative retrieval batch missing `{name}` string column"
504            ))
505        })
506}
507
508fn required_f32_column<'a>(
509    batch: &'a RecordBatch,
510    name: &str,
511) -> datafusion_common::Result<&'a Float32Array> {
512    batch
513        .column_by_name(name)
514        .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
515        .ok_or_else(|| {
516            datafusion_common::DataFusionError::Execution(format!(
517                "iterative retrieval batch missing `{name}` f32 column"
518            ))
519        })
520}
521
522fn required_i64_column<'a>(
523    batch: &'a RecordBatch,
524    name: &str,
525) -> datafusion_common::Result<&'a Int64Array> {
526    batch
527        .column_by_name(name)
528        .and_then(|column| column.as_any().downcast_ref::<Int64Array>())
529        .ok_or_else(|| {
530            datafusion_common::DataFusionError::Execution(format!(
531                "iterative retrieval batch missing `{name}` i64 column"
532            ))
533        })
534}
535
536fn required_u32_column<'a>(
537    batch: &'a RecordBatch,
538    name: &str,
539) -> datafusion_common::Result<&'a UInt32Array> {
540    batch
541        .column_by_name(name)
542        .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
543        .ok_or_else(|| {
544            datafusion_common::DataFusionError::Execution(format!(
545                "iterative retrieval batch missing `{name}` u32 column"
546            ))
547        })
548}
549
550fn optional_f32_column<'a>(batch: &'a RecordBatch, name: &str) -> Option<&'a Float32Array> {
551    batch
552        .column_by_name(name)
553        .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
554}
555
556fn optional_u32_column<'a>(batch: &'a RecordBatch, name: &str) -> Option<&'a UInt32Array> {
557    batch
558        .column_by_name(name)
559        .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
560}
561
562fn optional_u64_column<'a>(batch: &'a RecordBatch, name: &str) -> Option<&'a UInt64Array> {
563    batch
564        .column_by_name(name)
565        .and_then(|column| column.as_any().downcast_ref::<UInt64Array>())
566}
567
568fn optional_f32_value(array: Option<&Float32Array>, row: usize) -> Option<f32> {
569    array.and_then(|array| (!array.is_null(row)).then(|| array.value(row)))
570}
571
572fn optional_u32_value(array: Option<&UInt32Array>, row: usize) -> Option<u32> {
573    array.and_then(|array| (!array.is_null(row)).then(|| array.value(row)))
574}
575
576fn optional_u64_value(array: Option<&UInt64Array>, row: usize) -> Option<u64> {
577    array.and_then(|array| (!array.is_null(row)).then(|| array.value(row)))
578}
579
580fn build_output_batch(
581    schema: SchemaRef,
582    rows: &[IterativeRecallRow],
583) -> datafusion_common::Result<RecordBatch> {
584    if rows.is_empty() {
585        return Ok(RecordBatch::new_empty(schema));
586    }
587
588    let include_activation = schema.field_with_name("activation_score").is_ok();
589    let include_causal = schema.field_with_name("causal_score").is_ok();
590
591    let ids = rows
592        .iter()
593        .map(|row| row.base.id.as_str())
594        .collect::<Vec<_>>();
595    let contents = rows
596        .iter()
597        .map(|row| row.base.content.as_str())
598        .collect::<Vec<_>>();
599    let full_contents = rows
600        .iter()
601        .map(|row| row.base.full_content.as_str())
602        .collect::<Vec<_>>();
603    let layers = rows.iter().map(|row| row.base.layer).collect::<Vec<_>>();
604    let namespaces = rows
605        .iter()
606        .map(|row| row.base.namespace.as_str())
607        .collect::<Vec<_>>();
608    let scores = rows.iter().map(|row| row.base.score).collect::<Vec<_>>();
609    let temporal = rows
610        .iter()
611        .map(|row| row.base.temporal_ms)
612        .collect::<Vec<_>>();
613    let created_at = rows
614        .iter()
615        .map(|row| row.base.created_at_ms)
616        .collect::<Vec<_>>();
617    let importances = rows
618        .iter()
619        .map(|row| row.base.importance)
620        .collect::<Vec<_>>();
621    let access_counts = rows
622        .iter()
623        .map(|row| row.base.access_count)
624        .collect::<Vec<_>>();
625    let surprises = rows.iter().map(|row| row.base.surprise).collect::<Vec<_>>();
626    let evidence_counts = rows
627        .iter()
628        .map(|row| row.base.evidence_count)
629        .collect::<Vec<_>>();
630    let invocation_counts = rows
631        .iter()
632        .map(|row| row.base.invocation_count)
633        .collect::<Vec<_>>();
634    let retrieval_rounds = rows
635        .iter()
636        .map(|row| row.retrieval_round)
637        .collect::<Vec<_>>();
638
639    let mut columns: Vec<ArrayRef> = vec![
640        Arc::new(StringArray::from(ids)) as ArrayRef,
641        Arc::new(StringArray::from(contents)) as ArrayRef,
642        Arc::new(StringArray::from(full_contents)) as ArrayRef,
643        Arc::new(StringArray::from(layers)) as ArrayRef,
644        Arc::new(StringArray::from(namespaces)) as ArrayRef,
645        Arc::new(Float32Array::from(scores)) as ArrayRef,
646        Arc::new(Int64Array::from(temporal)) as ArrayRef,
647        Arc::new(Int64Array::from(created_at)) as ArrayRef,
648        Arc::new(Float32Array::from(importances)) as ArrayRef,
649        Arc::new(UInt32Array::from(access_counts)) as ArrayRef,
650        Arc::new(Float32Array::from(surprises)) as ArrayRef,
651        Arc::new(UInt32Array::from(evidence_counts)) as ArrayRef,
652        Arc::new(UInt64Array::from(invocation_counts)) as ArrayRef,
653    ];
654
655    if include_activation {
656        columns.push(Arc::new(Float32Array::from(
657            rows.iter()
658                .map(|row| row.activation_score.unwrap_or(0.0))
659                .collect::<Vec<_>>(),
660        )) as ArrayRef);
661        columns.push(Arc::new(UInt32Array::from(
662            rows.iter()
663                .map(|row| row.activation_depth.unwrap_or(0))
664                .collect::<Vec<_>>(),
665        )) as ArrayRef);
666    }
667
668    if include_causal {
669        columns.push(Arc::new(Float32Array::from(
670            rows.iter()
671                .map(|row| row.causal_score.unwrap_or(0.0))
672                .collect::<Vec<_>>(),
673        )) as ArrayRef);
674        columns.push(Arc::new(UInt32Array::from(
675            rows.iter()
676                .map(|row| row.causal_depth.unwrap_or(0))
677                .collect::<Vec<_>>(),
678        )) as ArrayRef);
679    }
680
681    columns.push(Arc::new(UInt32Array::from(retrieval_rounds)) as ArrayRef);
682
683    Ok(RecordBatch::try_new(schema, columns)?)
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use async_trait::async_trait;
    use futures::StreamExt;
    use hirn_core::HirnResult;
    use hirn_core::config::HirnConfig;
    use hirn_core::embed::{Embedding, MultivectorEmbedding};
    use hirn_core::episodic::EpisodicRecord;
    use hirn_core::types::AgentId;
    use hirn_storage::PhysicalStore;
    use hirn_storage::datasets::episodic;
    use hirn_storage::memory_store::MemoryStore;

    use crate::extensions::HirnSessionExt;
    use crate::operators::lance_hybrid_search::LanceHybridSearchExec;

    /// A child plan with `schema` that produces no rows.
    fn empty_child(schema: SchemaRef) -> Arc<datafusion_physical_plan::empty::EmptyExec> {
        Arc::new(datafusion_physical_plan::empty::EmptyExec::new(schema))
    }

    /// Fetch column `name` from `batch` as array type `A`, panicking if it is absent.
    fn typed_column<'a, A: Array + 'static>(batch: &'a RecordBatch, name: &str) -> &'a A {
        batch
            .column_by_name(name)
            .and_then(|column| column.as_any().downcast_ref::<A>())
            .unwrap()
    }

    #[test]
    fn default_config() {
        let IterativeConfig {
            max_rounds,
            coverage_threshold,
            expansion_prior_rows,
            expansion_terms,
        } = IterativeConfig::default();
        assert_eq!(max_rounds, 3);
        assert!((coverage_threshold - 0.7).abs() < f32::EPSILON);
        assert_eq!(expansion_prior_rows, 8);
        assert_eq!(expansion_terms, 4);
    }

    #[test]
    fn display_format() {
        let exec = IterativeRetrievalExec::new(
            empty_child(Arc::new(Schema::empty())),
            IterativeConfig::default(),
        );
        assert_eq!(exec.name(), "IterativeRetrievalExec");
    }

    #[tokio::test]
    async fn execute_empty_input() {
        let child_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
        ]));
        let exec = IterativeRetrievalExec::new(empty_child(child_schema), IterativeConfig::default());

        let mut output = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
        let batch = output.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    /// Embeds any text mentioning "entanglement" as `[0, 1]` and everything else as `[1, 0]`.
    #[derive(Debug)]
    struct KeywordEmbedder;

    #[async_trait]
    impl Embedder for KeywordEmbedder {
        async fn embed(&self, texts: &[&str]) -> HirnResult<Vec<Embedding>> {
            let embeddings = texts
                .iter()
                .map(|text| {
                    let mentions_entanglement =
                        text.to_ascii_lowercase().contains("entanglement");
                    Embedding {
                        vector: if mentions_entanglement {
                            vec![0.0, 1.0]
                        } else {
                            vec![1.0, 0.0]
                        },
                        model_id: "keyword-test".to_string(),
                    }
                })
                .collect();
            Ok(embeddings)
        }

        fn dimensions(&self) -> usize {
            2
        }

        fn model_id(&self) -> &str {
            "keyword-test"
        }

        fn max_input_tokens(&self) -> usize {
            1024
        }

        async fn embed_multivec(&self, _texts: &[&str]) -> HirnResult<Vec<MultivectorEmbedding>> {
            Ok(Vec::new())
        }
    }

    /// The recall schema emitted by the hybrid-search child in these tests.
    fn test_recall_schema() -> SchemaRef {
        let columns = [
            ("id", DataType::Utf8, false),
            ("content", DataType::Utf8, false),
            ("full_content", DataType::Utf8, false),
            ("layer", DataType::Utf8, false),
            ("namespace", DataType::Utf8, false),
            ("score", DataType::Float32, true),
            ("temporal_ms", DataType::Int64, false),
            ("created_at_ms", DataType::Int64, false),
            ("importance", DataType::Float32, true),
            ("access_count", DataType::UInt32, true),
            ("surprise", DataType::Float32, true),
            ("evidence_count", DataType::UInt32, true),
            ("invocation_count", DataType::UInt64, true),
        ];
        Arc::new(Schema::new(
            columns
                .into_iter()
                .map(|(name, data_type, nullable)| Field::new(name, data_type, nullable))
                .collect::<Vec<_>>(),
        ))
    }

    #[tokio::test]
    async fn iterative_retrieval_exec_runs_real_second_round() {
        let agent = || AgentId::new("iterative_test").unwrap();
        let records = vec![
            EpisodicRecord::builder()
                .content("quantum qubits entanglement")
                .agent_id(agent())
                .embedding(vec![1.0, 0.0])
                .build()
                .unwrap(),
            EpisodicRecord::builder()
                .content("entanglement teleportation bell-states")
                .agent_id(agent())
                .embedding(vec![0.0, 1.0])
                .build()
                .unwrap(),
        ];

        let storage: Arc<dyn PhysicalStore> = Arc::new(MemoryStore::new());
        storage
            .append(
                episodic::DATASET_NAME,
                episodic::to_batch(&records, 2).unwrap(),
            )
            .await
            .unwrap();

        let session = datafusion::prelude::SessionContext::new();
        HirnSessionExt::new(
            Arc::new(0_u8),
            Arc::new(HirnConfig::default()),
            Some(Arc::new(KeywordEmbedder)),
        )
        .with_storage(Arc::clone(&storage))
        .register(&session)
        .unwrap();

        let search_params = HybridSearchParams {
            datasets: vec![episodic::DATASET_NAME.to_string()],
            vector_column: "embedding".to_string(),
            query_vector: vec![1.0, 0.0],
            hybrid_mode: false,
            fts_columns: vec!["content".to_string()],
            fts_query: "quantum".to_string(),
            limit: 1,
            metric: hirn_storage::store::DistanceMetric::Cosine,
            filter: None,
            numeric_filters: Vec::new(),
            temporal_start_ms: None,
            temporal_end_ms: None,
            temporal_expansion: false,
            temporal_boost: 1.25,
        };
        let search = Arc::new(LanceHybridSearchExec::new(test_recall_schema(), search_params));

        let config = IterativeConfig {
            max_rounds: 2,
            coverage_threshold: 0.9,
            ..IterativeConfig::default()
        };
        let exec = IterativeRetrievalExec::new(search, config);

        let mut output = exec.execute(0, session.task_ctx()).unwrap();
        let batch = output.next().await.unwrap().unwrap();

        let ids: &StringArray = typed_column(&batch, "id");
        let rounds: &UInt32Array = typed_column(&batch, "retrieval_round");

        assert_eq!(batch.num_rows(), 2);
        assert_eq!((rounds.value(0), rounds.value(1)), (1, 2));
        assert_ne!(ids.value(0), ids.value(1));
    }
}