Skip to main content

hirn_exec/operators/
interference_detector.rs

1//! `InterferenceDetectorExec` — per-write interference detection.
2//!
3//! Four checks per write:
4//! 1. **Hash deduplication** — FNV-1a 64-bit deterministic hash on `content`.
5//!    Exact duplicate within the current batch → `interference_flags = "duplicate"`.
//!    **Check 1b** — Near-duplicate detection via vector similarity against persisted
//!    memories: requires a `PhysicalStore` on the session's `HirnSessionExt` and an
//!    `embedding` column (FixedSizeList<Float32>) on the incoming batch. Only rows not
//!    already flagged by Check 1 or Check 2 are searched.
//!    Near-duplicate (similarity ≥ `duplicate_threshold`) → `interference_flags = "near_duplicate"`.
//!    Silently skipped when storage or embedding column is absent (graceful degradation).
11//! 2. **Supersession** — same namespace + overlapping `entities_json` + newer
12//!    `timestamp_ms` within the current batch → `interference_flags = "supersession"`.
//! 3. **NLI contradiction** — pairwise NLI classification of new content against earlier
//!    rows in the current batch, nearest first, up to `InterferenceConfig::nli_max_pairs`
//!    pairs per row (0 disables the check). Skipped for rows already flagged by Check 1 or
//!    Check 2. Backed by [`HeuristicNliClassifier`] by default; upgrade to DeBERTa-MNLI ONNX
//!    by injecting an [`NliClassifier`] via `HirnSessionExt::with_nli_classifier()`.
//!    Contradiction → `interference_flags = "conflict"`.
18//!
19//! # Upgrade path for ONNX NLI
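//! Inject an `Arc<dyn NliClassifier>` backed by a loaded DeBERTa-MNLI ONNX session into
//! `HirnSessionExt` at database open time. The planner picks it up automatically: at
//! execution time the operator reads it from the session config, where it takes precedence
//! over the classifier the operator was constructed with.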
22
23use std::any::Any;
24use std::fmt;
25use std::sync::Arc;
26
27use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, StringArray};
28use arrow_schema::{DataType, Field, Schema, SchemaRef};
29use datafusion_common::Result;
30use datafusion_execution::{SendableRecordBatchStream, TaskContext};
31use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
32use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
33use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
34use hirn_storage::PhysicalStore;
35use hirn_storage::store::{DistanceMetric, VectorSearchOptions};
36
37use crate::extensions::HirnSessionExt;
38use crate::operators::nli_contradiction::{HeuristicNliClassifier, NliClassifier, NliLabel};
39
40/// Configuration for interference detection.
41#[derive(Debug, Clone)]
42pub struct InterferenceConfig {
43    /// Similarity threshold above which a record is a near-duplicate (default: 0.95).
44    pub duplicate_threshold: f32,
45    /// Cumulative interference score threshold to trigger consolidation (default: 0.3).
46    pub consolidation_trigger: f32,
47    /// Datasets to search for Check 1b vector-similarity near-dup detection.
48    pub search_datasets: Vec<String>,
49    /// Distance metric for vector similarity search (must match index build metric).
50    pub distance_metric: DistanceMetric,
51    /// Number of nearest neighbours per query in near-dup search (default: 3).
52    pub near_dup_search_limit: usize,
53    /// Contradiction probability threshold for Check 3 NLI (default: 0.7).
54    pub nli_contradiction_threshold: f32,
55    /// Maximum number of (earlier_row, new_row) content pairs classified per new row
56    /// in Check 3. Caps O(n²) cost; 0 disables Check 3 entirely (default: 32).
57    pub nli_max_pairs: usize,
58}
59
60impl Default for InterferenceConfig {
61    fn default() -> Self {
62        Self {
63            duplicate_threshold: 0.95,
64            consolidation_trigger: 0.3,
65            search_datasets: vec![
66                "episodic".to_string(),
67                "semantic".to_string(),
68                "procedural".to_string(),
69            ],
70            distance_metric: DistanceMetric::L2,
71            near_dup_search_limit: 3,
72            nli_contradiction_threshold: 0.7,
73            nli_max_pairs: 32,
74        }
75    }
76}
77
/// Per-row outcome of the interference checks.
#[allow(clippy::struct_excessive_bools)] // 4 independent check flags — a bitfield would be less clear
#[derive(Debug, Clone, Default)]
pub struct InterferenceFlags {
    /// Check 1: the row's content exactly repeats an earlier row in the same batch.
    pub is_duplicate: bool,
    /// Check 1b: the row's embedding is near-identical to an already persisted memory.
    pub is_near_duplicate: bool,
    /// Check 2: the row supersedes an earlier row in the same batch.
    pub is_supersession: bool,
    /// Check 3: NLI found the row contradicting an earlier row at write time.
    pub has_conflict: bool,
    /// Highest interference score contributed by any check.
    pub score: f32,
}

impl InterferenceFlags {
    /// Comma-separated names of the raised flags, in check order
    /// (`duplicate`, `near_duplicate`, `supersession`, `conflict`).
    ///
    /// Returns `"none"` when no flag is set. `score` is not part of the output.
    pub fn flag_string(&self) -> String {
        let table = [
            (self.is_duplicate, "duplicate"),
            (self.is_near_duplicate, "near_duplicate"),
            (self.is_supersession, "supersession"),
            (self.has_conflict, "conflict"),
        ];
        let raised: Vec<&str> = table
            .iter()
            .filter(|(set, _)| *set)
            .map(|&(_, name)| name)
            .collect();
        match raised.as_slice() {
            [] => String::from("none"),
            names => names.join(","),
        }
    }
}
115
116/// DataFusion operator for write-path interference detection.
117///
118/// Passes through input batches, appending `interference_flags` and
119/// `interference_score` columns.
120///
121/// **Check 1 (implemented):** FNV-1a hash deduplication within the batch.
122///
123/// **Check 1b (implemented):** Vector-similarity near-duplicate detection against
124/// persisted memories via `HirnSessionExt` `PhysicalStore`. Requires an `embedding`
125/// (FixedSizeList<Float32>) column on the incoming batch. Silently skipped when
126/// either is absent.
127///
128/// **Check 2 (implemented):** Batch-local supersession by namespace + entity overlap.
129///
130/// **Check 3 (implemented):** Pairwise NLI contradiction detection against earlier rows in
131/// the current write batch using an injectable [`NliClassifier`]. Defaults to the
132/// deterministic [`HeuristicNliClassifier`]; upgrade to DeBERTa-MNLI by injecting via
133/// `HirnSessionExt::with_nli_classifier()` at database open time.
134#[derive(Debug)]
135pub struct InterferenceDetectorExec {
136    input: Arc<dyn ExecutionPlan>,
137    config: InterferenceConfig,
138    /// NLI classifier for Check 3. Defaults to heuristic; injectable for ONNX upgrade.
139    nli_classifier: Arc<dyn NliClassifier>,
140    schema: SchemaRef,
141    properties: PlanProperties,
142}
143
144impl InterferenceDetectorExec {
145    /// Create with the default [`HeuristicNliClassifier`] for Check 3.
146    pub fn new(input: Arc<dyn ExecutionPlan>, config: InterferenceConfig) -> Self {
147        Self::with_nli_classifier(input, config, Arc::new(HeuristicNliClassifier))
148    }
149
150    /// Create with a custom NLI classifier (e.g. DeBERTa-MNLI via ONNX).
151    pub fn with_nli_classifier(
152        input: Arc<dyn ExecutionPlan>,
153        config: InterferenceConfig,
154        nli_classifier: Arc<dyn NliClassifier>,
155    ) -> Self {
156        let mut fields: Vec<Arc<Field>> = input.schema().fields().iter().cloned().collect();
157        fields.push(Arc::new(Field::new(
158            "interference_flags",
159            DataType::Utf8,
160            false,
161        )));
162        fields.push(Arc::new(Field::new(
163            "interference_score",
164            DataType::Float32,
165            false,
166        )));
167        let schema = Arc::new(Schema::new(fields));
168
169        let properties = PlanProperties::new(
170            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
171            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
172            EmissionType::Final,
173            Boundedness::Bounded,
174        );
175
176        Self {
177            input,
178            config,
179            nli_classifier,
180            schema,
181            properties,
182        }
183    }
184
185    pub fn config(&self) -> &InterferenceConfig {
186        &self.config
187    }
188
189    pub fn nli_classifier(&self) -> &Arc<dyn NliClassifier> {
190        &self.nli_classifier
191    }
192}
193
194impl DisplayAs for InterferenceDetectorExec {
195    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        write!(
197            f,
198            "InterferenceDetectorExec: dup_threshold={}, consolidation_trigger={}, near_dup_limit={}",
199            self.config.duplicate_threshold,
200            self.config.consolidation_trigger,
201            self.config.near_dup_search_limit,
202        )
203    }
204}
205
206impl ExecutionPlan for InterferenceDetectorExec {
207    fn name(&self) -> &str {
208        "InterferenceDetectorExec"
209    }
210
211    fn as_any(&self) -> &dyn Any {
212        self
213    }
214
215    fn schema(&self) -> SchemaRef {
216        self.schema.clone()
217    }
218
219    fn properties(&self) -> &PlanProperties {
220        &self.properties
221    }
222
223    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
224        vec![&self.input]
225    }
226
227    fn with_new_children(
228        self: Arc<Self>,
229        children: Vec<Arc<dyn ExecutionPlan>>,
230    ) -> Result<Arc<dyn ExecutionPlan>> {
231        let [child]: [Arc<dyn ExecutionPlan>; 1] = children.try_into().map_err(|v: Vec<_>| {
232            datafusion_common::DataFusionError::Plan(format!(
233                "InterferenceDetectorExec requires exactly 1 child, got {}",
234                v.len()
235            ))
236        })?;
237        Ok(Arc::new(Self::with_nli_classifier(
238            child,
239            self.config.clone(),
240            Arc::clone(&self.nli_classifier),
241        )))
242    }
243
244    fn execute(
245        &self,
246        partition: usize,
247        context: Arc<TaskContext>,
248    ) -> Result<SendableRecordBatchStream> {
249        let input_stream = self.input.execute(partition, context.clone())?;
250        let schema = self.schema.clone();
251        let dup_threshold = self.config.duplicate_threshold;
252        let config = self.config.clone();
253
254        // Extract per-session overrides before the `async move` closure.
255        let session_ext = context
256            .session_config()
257            .options()
258            .extensions
259            .get::<HirnSessionExt>();
260
261        // Check 1b: storage for vector-similarity near-dup detection.
262        let storage = session_ext.as_ref().and_then(|ext| ext.storage_arc());
263
264        // Check 3: NLI classifier — session-injected ONNX model or heuristic fallback.
265        let nli_classifier: Arc<dyn NliClassifier> = session_ext
266            .and_then(|ext| ext.nli_classifier())
267            .unwrap_or_else(|| Arc::clone(&self.nli_classifier));
268
269        let stream = futures::stream::once(async move {
270            use futures::StreamExt;
271            use std::collections::HashMap;
272
273            /// Deterministic FNV-1a 64-bit hash (N-L05).
274            ///
275            /// `std::hash::DefaultHasher` is intentionally NOT used here because
276            /// its output is randomised per-process (HashDoS protection), making
277            /// duplicate detection non-repeatable across restarts.
278            #[inline]
279            fn fnv1a_64(bytes: &[u8]) -> u64 {
280                const OFFSET: u64 = 14_695_981_039_346_656_037;
281                const PRIME: u64 = 1_099_511_628_211;
282                let mut h = OFFSET;
283                for &b in bytes {
284                    h ^= b as u64;
285                    h = h.wrapping_mul(PRIME);
286                }
287                h
288            }
289
290            let mut batches = Vec::new();
291            let mut input_stream = input_stream;
292            while let Some(batch_result) = input_stream.next().await {
293                batches.push(batch_result?);
294            }
295
296            if batches.is_empty() {
297                let columns: Vec<Arc<dyn Array>> = schema
298                    .fields()
299                    .iter()
300                    .map(|f| arrow_array::new_empty_array(f.data_type()))
301                    .collect();
302                return RecordBatch::try_new(schema, columns).map_err(Into::into);
303            }
304
305            let merged =
306                arrow_select::concat::concat_batches(&batches[0].schema(), batches.iter())?;
307
308            let n = merged.num_rows();
309
310            // ── Check 1 + Check 2 pass ──
311            // Collect per-row InterferenceFlags so we can post-process with Check 1b.
312            let content_col = merged.column_by_name("content");
313            let contents = content_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
314
315            let mut content_hashes: HashMap<u64, usize> = HashMap::new();
316            let mut all_flags: Vec<InterferenceFlags> = Vec::with_capacity(n);
317
318            for i in 0..n {
319                let mut flags = InterferenceFlags::default();
320
321                // Check 1: Exact content duplicate (hash-based).
322                if let Some(contents) = contents {
323                    if !contents.is_null(i) {
324                        let content = contents.value(i);
325                        let h = fnv1a_64(content.as_bytes());
326                        if content_hashes.contains_key(&h) {
327                            flags.is_duplicate = true;
328                            flags.score = dup_threshold;
329                        }
330                        content_hashes.insert(h, i);
331                    }
332                }
333
334                // Check 2: Supersession — same namespace + overlapping entities + newer
335                // timestamp means this record supersedes an earlier one in the batch.
336                if !flags.is_duplicate {
337                    let entities_col = merged.column_by_name("entities_json");
338                    let ts_col = merged.column_by_name("timestamp_ms");
339                    let ns_col = merged.column_by_name("namespace");
340
341                    let entities =
342                        entities_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
343                    let timestamps =
344                        ts_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::Int64Array>());
345                    let namespaces = ns_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
346
347                    if let (Some(ents), Some(tss), Some(nss)) = (entities, timestamps, namespaces) {
348                        if !ents.is_null(i) && !tss.is_null(i) && !nss.is_null(i) {
349                            let ns_i = nss.value(i);
350                            let ts_i = tss.value(i);
351                            // B-M02: explicit warning on malformed JSON so operational
352                            // visibility is preserved. Conservative: treat as empty set
353                            // (no supersession flagged) rather than silently swallowing.
354                            let ents_i: std::collections::HashSet<String> =
355                                match serde_json::from_str(ents.value(i)) {
356                                    Ok(v) => v,
357                                    Err(e) => {
358                                        tracing::warn!(
359                                            row = i,
360                                            error = %e,
361                                            "interference_detector: malformed entities_json \
362                                             at row {i} — treating as empty set (no supersession)"
363                                        );
364                                        std::collections::HashSet::new()
365                                    }
366                                };
367
368                            for j in 0..i {
369                                if nss.is_null(j)
370                                    || tss.is_null(j)
371                                    || ents.is_null(j)
372                                    || nss.value(j) != ns_i
373                                {
374                                    continue;
375                                }
376                                let ts_j = tss.value(j);
377                                if ts_i <= ts_j {
378                                    // Not newer — no supersession.
379                                    continue;
380                                }
381                                let ents_j: std::collections::HashSet<String> =
382                                    match serde_json::from_str(ents.value(j)) {
383                                        Ok(v) => v,
384                                        Err(e) => {
385                                            tracing::warn!(
386                                                row = j,
387                                                error = %e,
388                                                "interference_detector: malformed entities_json \
389                                                 at row {j} — treating as empty set (no supersession)"
390                                            );
391                                            std::collections::HashSet::new()
392                                        }
393                                    };
394                                let overlap = ents_i.intersection(&ents_j).count();
395                                if overlap > 0 {
396                                    flags.is_supersession = true;
397                                    let union_sz = ents_i.union(&ents_j).count().max(1) as f32;
398                                    let jaccard = overlap as f32 / union_sz;
399                                    flags.score = flags.score.max(jaccard * 0.8);
400                                    break;
401                                }
402                            }
403                        }
404                    }
405                }
406
407                // Check 3: NLI contradiction detection.
408                //
409                // Compare the new row's content against all earlier rows in this batch.
410                // Uses the injected NliClassifier (heuristic by default; ONNX-upgradeable).
411                // Capped by `config.nli_max_pairs` to bound O(n²) cost.
412                if !flags.is_duplicate
413                    && !flags.is_supersession
414                    && config.nli_max_pairs > 0
415                    && i > 0
416                // no earlier rows to compare against for the first row
417                {
418                    if let Some(contents) = contents {
419                        if !contents.is_null(i) {
420                            let text_i = contents.value(i);
421                            let mut pairs_checked = 0usize;
422                            let mut j = i.saturating_sub(1);
423                            loop {
424                                if pairs_checked >= config.nli_max_pairs {
425                                    break;
426                                }
427                                if !contents.is_null(j) {
428                                    let text_j = contents.value(j);
429                                    let (label, score) = nli_classifier.classify(text_j, text_i);
430                                    if label == NliLabel::Contradiction
431                                        && score >= config.nli_contradiction_threshold
432                                    {
433                                        flags.has_conflict = true;
434                                        // Weight contradiction slightly below exact dup.
435                                        flags.score = flags.score.max(score * 0.9);
436                                        tracing::debug!(
437                                            row = i,
438                                            against_row = j,
439                                            score,
440                                            "InterferenceDetectorExec: NLI contradiction detected"
441                                        );
442                                        break;
443                                    }
444                                }
445                                pairs_checked += 1;
446                                if j == 0 {
447                                    break;
448                                }
449                                j -= 1;
450                            }
451                        }
452                    }
453                }
454
455                all_flags.push(flags);
456            }
457
458            // ── Check 1b: Near-duplicate detection via vector similarity ──
459            //
460            // For rows not already flagged by Check 1 or Check 2, search persisted
461            // memories using the `embedding` (FixedSizeList<Float32>) column.
462            // Queries are batched across datasets and executed in parallel for
463            // minimum latency. Silently skipped when storage or column is absent.
464            if let Some(ref storage) = storage {
465                let fsl = merged
466                    .column_by_name("embedding")
467                    .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());
468
469                if let Some(fsl) = fsl {
470                    // Gather unflagged rows that have a non-null embedding.
471                    let row_embeddings: Vec<(usize, Vec<f32>)> = (0..n)
472                        .filter(|&i| !all_flags[i].is_duplicate && !all_flags[i].is_supersession)
473                        .filter_map(|i| {
474                            if fsl.is_null(i) {
475                                return None;
476                            }
477                            let values = fsl.value(i);
478                            let f32_arr = values.as_any().downcast_ref::<Float32Array>()?;
479                            Some((i, f32_arr.values().to_vec()))
480                        })
481                        .collect();
482
483                    if !row_embeddings.is_empty() {
484                        let emb_slices: Vec<&[f32]> =
485                            row_embeddings.iter().map(|(_, e)| e.as_slice()).collect();
486
487                        let max_sims = find_max_similarities(&emb_slices, storage, &config).await;
488
489                        for (q_idx, &(row_idx, _)) in row_embeddings.iter().enumerate() {
490                            let sim = max_sims.get(q_idx).copied().unwrap_or(0.0);
491                            if sim >= dup_threshold {
492                                all_flags[row_idx].is_near_duplicate = true;
493                                all_flags[row_idx].score = all_flags[row_idx].score.max(sim);
494                                tracing::debug!(
495                                    row = row_idx,
496                                    similarity = sim,
497                                    "InterferenceDetectorExec: near-duplicate detected"
498                                );
499                            }
500                        }
501                    }
502                }
503            }
504
505            // Convert accumulated flags to columnar output.
506            let flags_col: StringArray = all_flags
507                .iter()
508                .map(|f| f.flag_string())
509                .collect::<Vec<_>>()
510                .into();
511            let score_col: Float32Array =
512                all_flags.iter().map(|f| f.score).collect::<Vec<_>>().into();
513
514            let mut columns: Vec<Arc<dyn Array>> = merged.columns().to_vec();
515            columns.push(Arc::new(flags_col));
516            columns.push(Arc::new(score_col));
517
518            RecordBatch::try_new(schema, columns).map_err(Into::into)
519        });
520
521        Ok(Box::pin(RecordBatchStreamAdapter::new(
522            self.schema.clone(),
523            stream,
524        )))
525    }
526}
527
528/// Search `storage` for the most similar persisted memory for each query embedding.
529///
530/// Queries are batched per dataset and executed in parallel (one `vector_search_many`
531/// call per dataset, all datasets launched concurrently via `join_all`).
532///
533/// Returns one `f32` max-similarity per query in the same order as `embeddings`.
534/// Returns 0.0 for any query whose searches fail or return no results.
535async fn find_max_similarities(
536    embeddings: &[&[f32]],
537    storage: &Arc<dyn PhysicalStore>,
538    config: &InterferenceConfig,
539) -> Vec<f32> {
540    if embeddings.is_empty() {
541        return Vec::new();
542    }
543
544    let metric = config.distance_metric;
545    let limit = config.near_dup_search_limit;
546    let n_queries = embeddings.len();
547
548    let queries: Vec<VectorSearchOptions> = embeddings
549        .iter()
550        .map(|emb| VectorSearchOptions {
551            query: emb.to_vec(),
552            column: "embedding".into(),
553            limit,
554            metric,
555            ..Default::default()
556        })
557        .collect();
558
559    // All datasets searched in parallel — max(D) instead of N×D serial calls.
560    let search_futures = config.search_datasets.iter().map(|dataset| {
561        let storage = Arc::clone(storage);
562        let dataset = dataset.clone();
563        let queries = queries.clone();
564        async move {
565            let exists = storage.exists(&dataset).await.unwrap_or(false);
566            let n_q = queries.len();
567            if !exists {
568                return vec![0.0_f32; n_q];
569            }
570            match storage.vector_search_many(&dataset, queries).await {
571                Ok(per_query_results) => per_query_results
572                    .iter()
573                    .map(|batches| {
574                        // Find the top similarity across all returned result batches.
575                        batches
576                            .iter()
577                            .map(|b| {
578                                b.column_by_name("_distance")
579                                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>())
580                                    .map(|dists| {
581                                        (0..dists.len())
582                                            .filter(|&j| !dists.is_null(j))
583                                            .map(|j| dist_to_sim(metric, dists.value(j)))
584                                            .fold(0.0_f32, f32::max)
585                                    })
586                                    .unwrap_or(0.0)
587                            })
588                            .fold(0.0_f32, f32::max)
589                    })
590                    .collect(),
591                Err(e) => {
592                    tracing::warn!(
593                        dataset,
594                        error = %e,
595                        "InterferenceDetectorExec: near-dup search failed, skipping dataset"
596                    );
597                    vec![0.0_f32; n_q]
598                }
599            }
600        }
601    });
602
603    let per_dataset_sims: Vec<Vec<f32>> = futures::future::join_all(search_futures).await;
604
605    // For each query, find the maximum similarity across all datasets.
606    (0..n_queries)
607        .map(|q_idx| {
608            per_dataset_sims
609                .iter()
610                .map(|sims| sims.get(q_idx).copied().unwrap_or(0.0))
611                .fold(0.0_f32, f32::max)
612        })
613        .collect()
614}
615
616/// Convert a Lance `_distance` value to a [0, 1] similarity score.
617///
618/// The formula depends on the distance metric (must match the index build metric).
619fn dist_to_sim(metric: DistanceMetric, dist: f32) -> f32 {
620    match metric {
621        // Cosine distance = 1 - cosine_similarity, so similarity = 1 - dist.
622        DistanceMetric::Cosine => (1.0 - dist).clamp(0.0, 1.0),
623        // Dot-product distance = 1 - dot_product for unit-normalized vectors.
624        DistanceMetric::DotProduct => (1.0 - dist).clamp(0.0, 1.0),
625        // L2 distance: map to (0, 1] via 1/(1+d²) — matches RPE scoring.
626        DistanceMetric::L2 => 1.0 / (1.0 + dist),
627    }
628}
629
630#[cfg(test)]
631mod tests {
632    use super::*;
633
634    #[test]
635    fn default_config() {
636        let config = InterferenceConfig::default();
637        assert!((config.duplicate_threshold - 0.95).abs() < f32::EPSILON);
638        assert!((config.consolidation_trigger - 0.3).abs() < f32::EPSILON);
639        assert_eq!(config.search_datasets.len(), 3);
640        assert_eq!(config.near_dup_search_limit, 3);
641    }
642
643    #[test]
644    fn flag_string_none() {
645        let flags = InterferenceFlags::default();
646        assert_eq!(flags.flag_string(), "none");
647    }
648
649    #[test]
650    fn flag_string_near_duplicate() {
651        let flags = InterferenceFlags {
652            is_near_duplicate: true,
653            score: 0.97,
654            ..Default::default()
655        };
656        assert_eq!(flags.flag_string(), "near_duplicate");
657    }
658
659    #[test]
660    fn flag_string_multiple() {
661        let flags = InterferenceFlags {
662            is_duplicate: true,
663            has_conflict: true,
664            ..Default::default()
665        };
666        assert_eq!(flags.flag_string(), "duplicate,conflict");
667    }
668
669    #[test]
670    fn dist_to_sim_l2() {
671        // Distance 0 → similarity 1.0
672        assert!((dist_to_sim(DistanceMetric::L2, 0.0) - 1.0).abs() < f32::EPSILON);
673        // Distance 1 → similarity 0.5
674        assert!((dist_to_sim(DistanceMetric::L2, 1.0) - 0.5).abs() < f32::EPSILON);
675    }
676
677    #[test]
678    fn dist_to_sim_cosine() {
679        // Cosine distance 0 (identical) → similarity 1.0
680        assert!((dist_to_sim(DistanceMetric::Cosine, 0.0) - 1.0).abs() < f32::EPSILON);
681        // Cosine distance 0.1 → similarity 0.9
682        assert!((dist_to_sim(DistanceMetric::Cosine, 0.1) - 0.9).abs() < f32::EPSILON);
683    }
684
685    #[tokio::test]
686    async fn execute_empty_input() {
687        use futures::StreamExt;
688
689        let empty_schema = Arc::new(Schema::new(vec![
690            Field::new("id", DataType::Utf8, false),
691            Field::new("content", DataType::Utf8, false),
692        ]));
693        let empty = Arc::new(datafusion_physical_plan::empty::EmptyExec::new(
694            empty_schema,
695        ));
696        let exec = InterferenceDetectorExec::new(empty, InterferenceConfig::default());
697        let ctx = Arc::new(TaskContext::default());
698        let mut stream = exec.execute(0, ctx).unwrap();
699        let batch = stream.next().await.unwrap().unwrap();
700        assert_eq!(batch.num_rows(), 0);
701    }
702
703    #[tokio::test]
704    async fn detects_exact_content_duplicate() {
705        use futures::StreamExt;
706
707        let schema = Arc::new(Schema::new(vec![
708            Field::new("id", DataType::Utf8, false),
709            Field::new("content", DataType::Utf8, false),
710        ]));
711
712        let batch = RecordBatch::try_new(
713            schema.clone(),
714            vec![
715                Arc::new(StringArray::from(vec!["a", "b", "c"])),
716                Arc::new(StringArray::from(vec![
717                    "hello world",
718                    "unique text",
719                    "hello world",
720                ])),
721            ],
722        )
723        .unwrap();
724
725        let input = Arc::new(crate::test_utils::MemoryBatchExec::new(
726            schema.clone(),
727            vec![batch],
728        ));
729        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
730        let ctx = Arc::new(TaskContext::default());
731        let mut stream = exec.execute(0, ctx).unwrap();
732        let result = stream.next().await.unwrap().unwrap();
733
734        assert_eq!(result.num_rows(), 3);
735        let flags = result
736            .column_by_name("interference_flags")
737            .unwrap()
738            .as_any()
739            .downcast_ref::<StringArray>()
740            .unwrap();
741        // Row 0: first occurrence → no flag.
742        assert_eq!(flags.value(0), "none");
743        // Row 1: unique → no flag.
744        assert_eq!(flags.value(1), "none");
745        // Row 2: duplicate of row 0 → flagged.
746        assert_eq!(flags.value(2), "duplicate");
747    }
748
749    #[tokio::test]
750    async fn no_duplicates_all_unique() {
751        use futures::StreamExt;
752
753        let schema = Arc::new(Schema::new(vec![
754            Field::new("id", DataType::Utf8, false),
755            Field::new("content", DataType::Utf8, false),
756        ]));
757
758        let batch = RecordBatch::try_new(
759            schema.clone(),
760            vec![
761                Arc::new(StringArray::from(vec!["a", "b"])),
762                Arc::new(StringArray::from(vec!["first content", "second content"])),
763            ],
764        )
765        .unwrap();
766
767        let input = Arc::new(crate::test_utils::MemoryBatchExec::new(
768            schema.clone(),
769            vec![batch],
770        ));
771        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
772        let ctx = Arc::new(TaskContext::default());
773        let mut stream = exec.execute(0, ctx).unwrap();
774        let result = stream.next().await.unwrap().unwrap();
775
776        assert_eq!(result.num_rows(), 2);
777        let scores = result
778            .column_by_name("interference_score")
779            .unwrap()
780            .as_any()
781            .downcast_ref::<Float32Array>()
782            .unwrap();
783        assert!((scores.value(0) - 0.0).abs() < f32::EPSILON);
784        assert!((scores.value(1) - 0.0).abs() < f32::EPSILON);
785    }
786
787    /// Check 1b: near-duplicate detected via vector similarity against persisted memories.
788    ///
789    /// Writes a nearly-identical embedding to MemoryStore first, then runs
790    /// InterferenceDetectorExec with HirnSessionExt wired to that store.
791    #[tokio::test(flavor = "multi_thread")]
792    async fn detects_near_duplicate_via_vector_search() {
793        use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
794        use datafusion::prelude::SessionContext;
795        use futures::StreamExt;
796        use hirn_core::config::HirnConfig;
797        use hirn_storage::memory_store::MemoryStore;
798        use std::sync::Arc;
799
800        // ── 1. Seed MemoryStore with an existing memory embedding [1.0, 0.0, 0.0] ──
801        let store: Arc<MemoryStore> = Arc::new(MemoryStore::new());
802        let dim = 3_i32;
803        let existing_schema = Arc::new(Schema::new(vec![
804            Field::new("id", DataType::Utf8, false),
805            Field::new(
806                "embedding",
807                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
808                true,
809            ),
810        ]));
811        let mut emb_builder = FixedSizeListBuilder::new(Float32Builder::new(), dim);
812        for &v in &[1.0_f32, 0.0, 0.0] {
813            emb_builder.values().append_value(v);
814        }
815        emb_builder.append(true);
816        let existing_batch = RecordBatch::try_new(
817            existing_schema,
818            vec![
819                Arc::new(StringArray::from(vec!["existing-1"])),
820                Arc::new(emb_builder.finish()),
821            ],
822        )
823        .unwrap();
824        store.append("episodic", existing_batch).await.unwrap();
825
826        // ── 2. Build SessionContext with HirnSessionExt pointing to that store ──
827        let ctx = SessionContext::new();
828        let config = Arc::new(HirnConfig::default());
829        let ext = crate::extensions::HirnSessionExt::new(Arc::new(42_u32), config, None)
830            .with_storage(store as Arc<dyn hirn_storage::PhysicalStore>);
831        ext.register(&ctx).unwrap();
832
833        // ── 3. Incoming batch: one near-duplicate [0.99, 0.01, 0.0], one novel [0.0, 1.0, 0.0] ──
834        let input_schema = Arc::new(Schema::new(vec![
835            Field::new("id", DataType::Utf8, false),
836            Field::new("content", DataType::Utf8, false),
837            Field::new(
838                "embedding",
839                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
840                true,
841            ),
842        ]));
843        let mut b = FixedSizeListBuilder::new(Float32Builder::new(), dim);
844        for &v in &[0.99_f32, 0.01, 0.0] {
845            b.values().append_value(v);
846        }
847        b.append(true);
848        for &v in &[0.0_f32, 1.0, 0.0] {
849            b.values().append_value(v);
850        }
851        b.append(true);
852        let input_batch = RecordBatch::try_new(
853            input_schema.clone(),
854            vec![
855                Arc::new(StringArray::from(vec!["new-1", "new-2"])),
856                Arc::new(StringArray::from(vec!["near text", "novel text"])),
857                Arc::new(b.finish()),
858            ],
859        )
860        .unwrap();
861
862        let input_exec = Arc::new(crate::test_utils::MemoryBatchExec::new(
863            input_schema,
864            vec![input_batch],
865        ));
866
867        // Use a low duplicate_threshold (0.5) so the near-dup is caught by L2 similarity.
868        // L2 distance([1,0,0], [0.99,0.01,0]) ≈ 0.01 → sim = 1/(1+0.01) ≈ 0.99 > 0.5.
869        let config = InterferenceConfig {
870            duplicate_threshold: 0.5,
871            search_datasets: vec!["episodic".to_string()],
872            ..Default::default()
873        };
874        let exec = InterferenceDetectorExec::new(input_exec, config);
875
876        let task_ctx = ctx.task_ctx();
877        let mut stream = exec.execute(0, task_ctx).unwrap();
878        let result = stream.next().await.unwrap().unwrap();
879        assert_eq!(result.num_rows(), 2);
880
881        let flags = result
882            .column_by_name("interference_flags")
883            .unwrap()
884            .as_any()
885            .downcast_ref::<StringArray>()
886            .unwrap();
887        // Row 0: near-duplicate of stored memory → flagged.
888        assert_eq!(
889            flags.value(0),
890            "near_duplicate",
891            "expected near_duplicate, got: {}",
892            flags.value(0)
893        );
894        // Row 1: novel → no flag.
895        assert_eq!(flags.value(1), "none");
896    }
897
898    /// Check 1b: when no storage is configured, near-dup search is silently skipped.
899    #[tokio::test]
900    async fn near_dup_silently_skipped_without_storage() {
901        use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
902        use futures::StreamExt;
903
904        let dim = 3_i32;
905        let input_schema = Arc::new(Schema::new(vec![
906            Field::new("id", DataType::Utf8, false),
907            Field::new("content", DataType::Utf8, false),
908            Field::new(
909                "embedding",
910                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
911                true,
912            ),
913        ]));
914        let mut b = FixedSizeListBuilder::new(Float32Builder::new(), dim);
915        for &v in &[1.0_f32, 0.0, 0.0] {
916            b.values().append_value(v);
917        }
918        b.append(true);
919        let batch = RecordBatch::try_new(
920            input_schema.clone(),
921            vec![
922                Arc::new(StringArray::from(vec!["a"])),
923                Arc::new(StringArray::from(vec!["some content"])),
924                Arc::new(b.finish()),
925            ],
926        )
927        .unwrap();
928
929        let input_exec = Arc::new(crate::test_utils::MemoryBatchExec::new(
930            input_schema,
931            vec![batch],
932        ));
933
934        // No HirnSessionExt → no storage → near-dup silently skipped.
935        let exec = InterferenceDetectorExec::new(input_exec, InterferenceConfig::default());
936        let ctx = Arc::new(TaskContext::default());
937        let mut stream = exec.execute(0, ctx).unwrap();
938        let result = stream.next().await.unwrap().unwrap();
939        assert_eq!(result.num_rows(), 1);
940        let flags = result
941            .column_by_name("interference_flags")
942            .unwrap()
943            .as_any()
944            .downcast_ref::<StringArray>()
945            .unwrap();
946        assert_eq!(flags.value(0), "none");
947    }
948
949    // ── Check 3: NLI contradiction tests ─────────────────────────────────────
950
951    /// Check 3: Heuristic NLI detects a contradiction between two rows.
952    ///
953    /// Row 0: "The cat is alive." Row 1: "The cat is not alive." — negation pair.
954    /// HeuristicNliClassifier should return Contradiction with score ≥ 0.7.
955    #[tokio::test]
956    async fn detects_nli_contradiction_within_batch() {
957        use futures::StreamExt;
958
959        let schema = Arc::new(Schema::new(vec![
960            Field::new("id", DataType::Utf8, false),
961            Field::new("content", DataType::Utf8, false),
962        ]));
963
964        let batch = RecordBatch::try_new(
965            schema.clone(),
966            vec![
967                Arc::new(StringArray::from(vec!["r0", "r1"])),
968                Arc::new(StringArray::from(vec![
969                    "The cat is alive and healthy.",
970                    "The cat is not alive and not healthy.",
971                ])),
972            ],
973        )
974        .unwrap();
975
976        let input = Arc::new(crate::test_utils::MemoryBatchExec::new(
977            schema.clone(),
978            vec![batch],
979        ));
980        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
981        let ctx = Arc::new(TaskContext::default());
982        let mut stream = exec.execute(0, ctx).unwrap();
983        let result = stream.next().await.unwrap().unwrap();
984
985        assert_eq!(result.num_rows(), 2);
986        let flags = result
987            .column_by_name("interference_flags")
988            .unwrap()
989            .as_any()
990            .downcast_ref::<StringArray>()
991            .unwrap();
992        // Row 0: first row, no prior rows to compare — no flag.
993        assert_eq!(flags.value(0), "none", "row 0 should have no flag");
994        // Row 1: contradicts row 0 — should be flagged as conflict.
995        assert_eq!(
996            flags.value(1),
997            "conflict",
998            "row 1 should be flagged as conflict"
999        );
1000    }
1001
1002    /// Check 3: independent rows with no semantic overlap produce no conflict flags.
1003    #[tokio::test]
1004    async fn nli_no_false_positive_on_unrelated_content() {
1005        use futures::StreamExt;
1006
1007        let schema = Arc::new(Schema::new(vec![
1008            Field::new("id", DataType::Utf8, false),
1009            Field::new("content", DataType::Utf8, false),
1010        ]));
1011
1012        let batch = RecordBatch::try_new(
1013            schema.clone(),
1014            vec![
1015                Arc::new(StringArray::from(vec!["r0", "r1", "r2"])),
1016                Arc::new(StringArray::from(vec![
1017                    "Paris is the capital of France.",
1018                    "The boiling point of water is 100 degrees.",
1019                    "Jupiter is the largest planet in the solar system.",
1020                ])),
1021            ],
1022        )
1023        .unwrap();
1024
1025        let input = Arc::new(crate::test_utils::MemoryBatchExec::new(
1026            schema.clone(),
1027            vec![batch],
1028        ));
1029        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
1030        let ctx = Arc::new(TaskContext::default());
1031        let mut stream = exec.execute(0, ctx).unwrap();
1032        let result = stream.next().await.unwrap().unwrap();
1033
1034        let flags = result
1035            .column_by_name("interference_flags")
1036            .unwrap()
1037            .as_any()
1038            .downcast_ref::<StringArray>()
1039            .unwrap();
1040        for i in 0..3 {
1041            assert_eq!(flags.value(i), "none", "row {i} should not be flagged");
1042        }
1043    }
1044
1045    /// Check 3: when `nli_max_pairs` is 0, NLI check is skipped entirely.
1046    #[tokio::test]
1047    async fn nli_disabled_when_max_pairs_zero() {
1048        use futures::StreamExt;
1049
1050        let schema = Arc::new(Schema::new(vec![
1051            Field::new("id", DataType::Utf8, false),
1052            Field::new("content", DataType::Utf8, false),
1053        ]));
1054
1055        let batch = RecordBatch::try_new(
1056            schema.clone(),
1057            vec![
1058                Arc::new(StringArray::from(vec!["r0", "r1"])),
1059                Arc::new(StringArray::from(vec![
1060                    "The cat is alive.",
1061                    "The cat is not alive.",
1062                ])),
1063            ],
1064        )
1065        .unwrap();
1066
1067        let input = Arc::new(crate::test_utils::MemoryBatchExec::new(
1068            schema.clone(),
1069            vec![batch],
1070        ));
1071        let config = InterferenceConfig {
1072            nli_max_pairs: 0, // disable NLI
1073            ..Default::default()
1074        };
1075        let exec = InterferenceDetectorExec::new(input, config);
1076        let ctx = Arc::new(TaskContext::default());
1077        let mut stream = exec.execute(0, ctx).unwrap();
1078        let result = stream.next().await.unwrap().unwrap();
1079
1080        let flags = result
1081            .column_by_name("interference_flags")
1082            .unwrap()
1083            .as_any()
1084            .downcast_ref::<StringArray>()
1085            .unwrap();
1086        // NLI disabled — even the negation pair should be unflagged.
1087        assert_eq!(
1088            flags.value(1),
1089            "none",
1090            "NLI should be skipped when nli_max_pairs=0"
1091        );
1092    }
1093
1094    /// Check 3: already-duplicate rows do not trigger the NLI check.
1095    #[tokio::test]
1096    async fn nli_skipped_for_already_flagged_duplicate_rows() {
1097        use futures::StreamExt;
1098
1099        let schema = Arc::new(Schema::new(vec![
1100            Field::new("id", DataType::Utf8, false),
1101            Field::new("content", DataType::Utf8, false),
1102        ]));
1103
1104        // Row 1 is an exact dup of row 0 AND would read as "not X" if compared with
1105        // row 2 which has "not" prefix — but row 1 is already flagged as duplicate
1106        // so NLI should not fire for it.
1107        let batch = RecordBatch::try_new(
1108            schema.clone(),
1109            vec![
1110                Arc::new(StringArray::from(vec!["r0", "r1", "r2"])),
1111                Arc::new(StringArray::from(vec![
1112                    "The sky is blue.",
1113                    "The sky is blue.", // exact dup of r0
1114                    "The sky is not blue.",
1115                ])),
1116            ],
1117        )
1118        .unwrap();
1119
1120        let input = Arc::new(crate::test_utils::MemoryBatchExec::new(
1121            schema.clone(),
1122            vec![batch],
1123        ));
1124        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
1125        let ctx = Arc::new(TaskContext::default());
1126        let mut stream = exec.execute(0, ctx).unwrap();
1127        let result = stream.next().await.unwrap().unwrap();
1128
1129        let flags = result
1130            .column_by_name("interference_flags")
1131            .unwrap()
1132            .as_any()
1133            .downcast_ref::<StringArray>()
1134            .unwrap();
1135        assert_eq!(flags.value(0), "none", "row 0: first occurrence");
1136        assert_eq!(
1137            flags.value(1),
1138            "duplicate",
1139            "row 1: exact dup, not conflict"
1140        );
1141        // Row 2 contradicts row 0 — NLI should fire here.
1142        assert_eq!(
1143            flags.value(2),
1144            "conflict",
1145            "row 2: contradiction with row 0"
1146        );
1147    }
1148
1149    /// Check 3: `with_nli_classifier()` respects injected classifier.
1150    ///
1151    /// A stub classifier that always returns Contradiction lets us test the wiring
1152    /// without depending on heuristic text analysis.
1153    #[tokio::test]
1154    async fn nli_respects_injected_classifier() {
1155        use futures::StreamExt;
1156
1157        /// Classifier that always returns Contradiction at score 0.99.
1158        #[derive(Debug)]
1159        struct AlwaysContradiction;
1160        impl NliClassifier for AlwaysContradiction {
1161            fn classify(
1162                &self,
1163                _text_a: &str,
1164                _text_b: &str,
1165            ) -> (crate::operators::nli_contradiction::NliLabel, f32) {
1166                (NliLabel::Contradiction, 0.99)
1167            }
1168            fn backend_name(&self) -> &'static str {
1169                "always_contradiction"
1170            }
1171        }
1172
1173        let schema = Arc::new(Schema::new(vec![
1174            Field::new("id", DataType::Utf8, false),
1175            Field::new("content", DataType::Utf8, false),
1176        ]));
1177
1178        let batch = RecordBatch::try_new(
1179            schema.clone(),
1180            vec![
1181                Arc::new(StringArray::from(vec!["r0", "r1"])),
1182                Arc::new(StringArray::from(vec!["anything", "anything else"])),
1183            ],
1184        )
1185        .unwrap();
1186
1187        let input = Arc::new(crate::test_utils::MemoryBatchExec::new(
1188            schema.clone(),
1189            vec![batch],
1190        ));
1191        let exec = InterferenceDetectorExec::with_nli_classifier(
1192            input,
1193            InterferenceConfig::default(),
1194            Arc::new(AlwaysContradiction),
1195        );
1196        let ctx = Arc::new(TaskContext::default());
1197        let mut stream = exec.execute(0, ctx).unwrap();
1198        let result = stream.next().await.unwrap().unwrap();
1199
1200        let flags = result
1201            .column_by_name("interference_flags")
1202            .unwrap()
1203            .as_any()
1204            .downcast_ref::<StringArray>()
1205            .unwrap();
1206        assert_eq!(flags.value(0), "none", "row 0: no prior rows");
1207        assert_eq!(
1208            flags.value(1),
1209            "conflict",
1210            "row 1: injected classifier fires"
1211        );
1212    }
1213}