1use std::any::Any;
24use std::fmt;
25use std::sync::Arc;
26
27use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, StringArray};
28use arrow_schema::{DataType, Field, Schema, SchemaRef};
29use datafusion_common::Result;
30use datafusion_execution::{SendableRecordBatchStream, TaskContext};
31use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
32use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
33use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
34use hirn_storage::PhysicalStore;
35use hirn_storage::store::{DistanceMetric, VectorSearchOptions};
36
37use crate::extensions::HirnSessionExt;
38use crate::operators::nli_contradiction::{HeuristicNliClassifier, NliClassifier, NliLabel};
39
40#[derive(Debug, Clone)]
42pub struct InterferenceConfig {
43 pub duplicate_threshold: f32,
45 pub consolidation_trigger: f32,
47 pub search_datasets: Vec<String>,
49 pub distance_metric: DistanceMetric,
51 pub near_dup_search_limit: usize,
53 pub nli_contradiction_threshold: f32,
55 pub nli_max_pairs: usize,
58}
59
60impl Default for InterferenceConfig {
61 fn default() -> Self {
62 Self {
63 duplicate_threshold: 0.95,
64 consolidation_trigger: 0.3,
65 search_datasets: vec![
66 "episodic".to_string(),
67 "semantic".to_string(),
68 "procedural".to_string(),
69 ],
70 distance_metric: DistanceMetric::L2,
71 near_dup_search_limit: 3,
72 nli_contradiction_threshold: 0.7,
73 nli_max_pairs: 32,
74 }
75 }
76}
77
/// Per-row outcome of interference detection.
// The four booleans are independent flags that may be set together, so the
// clippy suggestion to replace them with an enum does not apply.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Default)]
pub struct InterferenceFlags {
    /// Row content is identical to an earlier row of the same input.
    pub is_duplicate: bool,
    /// Row embedding is at least `duplicate_threshold` similar to a stored vector.
    pub is_near_duplicate: bool,
    /// Row shares entities with an older row in the same namespace.
    pub is_supersession: bool,
    /// The NLI classifier reported a contradiction with an earlier row.
    pub has_conflict: bool,
    /// Highest score produced by any of the checks that fired; `0.0` if none did.
    pub score: f32,
}

impl InterferenceFlags {
    /// Renders the set flags as a comma-separated list in the fixed order
    /// `duplicate`, `near_duplicate`, `supersession`, `conflict`.
    ///
    /// Returns `"none"` when no flag is set.
    pub fn flag_string(&self) -> String {
        let labelled = [
            (self.is_duplicate, "duplicate"),
            (self.is_near_duplicate, "near_duplicate"),
            (self.is_supersession, "supersession"),
            (self.has_conflict, "conflict"),
        ];

        let active: Vec<&str> = labelled
            .iter()
            .filter(|(is_set, _)| *is_set)
            .map(|&(_, label)| label)
            .collect();

        if active.is_empty() {
            return String::from("none");
        }
        active.join(",")
    }
}
115
/// Physical operator that appends `interference_flags` and
/// `interference_score` columns to the rows produced by its input.
#[derive(Debug)]
pub struct InterferenceDetectorExec {
    /// Child plan whose rows are inspected.
    input: Arc<dyn ExecutionPlan>,
    /// Thresholds and search settings used during execution.
    config: InterferenceConfig,
    /// Classifier used for contradiction checks when the session extension
    /// does not supply one.
    nli_classifier: Arc<dyn NliClassifier>,
    /// Input schema followed by the two appended columns.
    schema: SchemaRef,
    /// Cached plan properties (single output partition, bounded, final emission).
    properties: PlanProperties,
}
143
144impl InterferenceDetectorExec {
145 pub fn new(input: Arc<dyn ExecutionPlan>, config: InterferenceConfig) -> Self {
147 Self::with_nli_classifier(input, config, Arc::new(HeuristicNliClassifier))
148 }
149
150 pub fn with_nli_classifier(
152 input: Arc<dyn ExecutionPlan>,
153 config: InterferenceConfig,
154 nli_classifier: Arc<dyn NliClassifier>,
155 ) -> Self {
156 let mut fields: Vec<Arc<Field>> = input.schema().fields().iter().cloned().collect();
157 fields.push(Arc::new(Field::new(
158 "interference_flags",
159 DataType::Utf8,
160 false,
161 )));
162 fields.push(Arc::new(Field::new(
163 "interference_score",
164 DataType::Float32,
165 false,
166 )));
167 let schema = Arc::new(Schema::new(fields));
168
169 let properties = PlanProperties::new(
170 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
171 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
172 EmissionType::Final,
173 Boundedness::Bounded,
174 );
175
176 Self {
177 input,
178 config,
179 nli_classifier,
180 schema,
181 properties,
182 }
183 }
184
185 pub fn config(&self) -> &InterferenceConfig {
186 &self.config
187 }
188
189 pub fn nli_classifier(&self) -> &Arc<dyn NliClassifier> {
190 &self.nli_classifier
191 }
192}
193
194impl DisplayAs for InterferenceDetectorExec {
195 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 write!(
197 f,
198 "InterferenceDetectorExec: dup_threshold={}, consolidation_trigger={}, near_dup_limit={}",
199 self.config.duplicate_threshold,
200 self.config.consolidation_trigger,
201 self.config.near_dup_search_limit,
202 )
203 }
204}
205
206impl ExecutionPlan for InterferenceDetectorExec {
207 fn name(&self) -> &str {
208 "InterferenceDetectorExec"
209 }
210
211 fn as_any(&self) -> &dyn Any {
212 self
213 }
214
215 fn schema(&self) -> SchemaRef {
216 self.schema.clone()
217 }
218
219 fn properties(&self) -> &PlanProperties {
220 &self.properties
221 }
222
223 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
224 vec![&self.input]
225 }
226
227 fn with_new_children(
228 self: Arc<Self>,
229 children: Vec<Arc<dyn ExecutionPlan>>,
230 ) -> Result<Arc<dyn ExecutionPlan>> {
231 let [child]: [Arc<dyn ExecutionPlan>; 1] = children.try_into().map_err(|v: Vec<_>| {
232 datafusion_common::DataFusionError::Plan(format!(
233 "InterferenceDetectorExec requires exactly 1 child, got {}",
234 v.len()
235 ))
236 })?;
237 Ok(Arc::new(Self::with_nli_classifier(
238 child,
239 self.config.clone(),
240 Arc::clone(&self.nli_classifier),
241 )))
242 }
243
244 fn execute(
245 &self,
246 partition: usize,
247 context: Arc<TaskContext>,
248 ) -> Result<SendableRecordBatchStream> {
249 let input_stream = self.input.execute(partition, context.clone())?;
250 let schema = self.schema.clone();
251 let dup_threshold = self.config.duplicate_threshold;
252 let config = self.config.clone();
253
254 let session_ext = context
256 .session_config()
257 .options()
258 .extensions
259 .get::<HirnSessionExt>();
260
261 let storage = session_ext.as_ref().and_then(|ext| ext.storage_arc());
263
264 let nli_classifier: Arc<dyn NliClassifier> = session_ext
266 .and_then(|ext| ext.nli_classifier())
267 .unwrap_or_else(|| Arc::clone(&self.nli_classifier));
268
269 let stream = futures::stream::once(async move {
270 use futures::StreamExt;
271 use std::collections::HashMap;
272
273 #[inline]
279 fn fnv1a_64(bytes: &[u8]) -> u64 {
280 const OFFSET: u64 = 14_695_981_039_346_656_037;
281 const PRIME: u64 = 1_099_511_628_211;
282 let mut h = OFFSET;
283 for &b in bytes {
284 h ^= b as u64;
285 h = h.wrapping_mul(PRIME);
286 }
287 h
288 }
289
290 let mut batches = Vec::new();
291 let mut input_stream = input_stream;
292 while let Some(batch_result) = input_stream.next().await {
293 batches.push(batch_result?);
294 }
295
296 if batches.is_empty() {
297 let columns: Vec<Arc<dyn Array>> = schema
298 .fields()
299 .iter()
300 .map(|f| arrow_array::new_empty_array(f.data_type()))
301 .collect();
302 return RecordBatch::try_new(schema, columns).map_err(Into::into);
303 }
304
305 let merged =
306 arrow_select::concat::concat_batches(&batches[0].schema(), batches.iter())?;
307
308 let n = merged.num_rows();
309
310 let content_col = merged.column_by_name("content");
313 let contents = content_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
314
315 let mut content_hashes: HashMap<u64, usize> = HashMap::new();
316 let mut all_flags: Vec<InterferenceFlags> = Vec::with_capacity(n);
317
318 for i in 0..n {
319 let mut flags = InterferenceFlags::default();
320
321 if let Some(contents) = contents {
323 if !contents.is_null(i) {
324 let content = contents.value(i);
325 let h = fnv1a_64(content.as_bytes());
326 if content_hashes.contains_key(&h) {
327 flags.is_duplicate = true;
328 flags.score = dup_threshold;
329 }
330 content_hashes.insert(h, i);
331 }
332 }
333
334 if !flags.is_duplicate {
337 let entities_col = merged.column_by_name("entities_json");
338 let ts_col = merged.column_by_name("timestamp_ms");
339 let ns_col = merged.column_by_name("namespace");
340
341 let entities =
342 entities_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
343 let timestamps =
344 ts_col.and_then(|c| c.as_any().downcast_ref::<arrow_array::Int64Array>());
345 let namespaces = ns_col.and_then(|c| c.as_any().downcast_ref::<StringArray>());
346
347 if let (Some(ents), Some(tss), Some(nss)) = (entities, timestamps, namespaces) {
348 if !ents.is_null(i) && !tss.is_null(i) && !nss.is_null(i) {
349 let ns_i = nss.value(i);
350 let ts_i = tss.value(i);
351 let ents_i: std::collections::HashSet<String> =
355 match serde_json::from_str(ents.value(i)) {
356 Ok(v) => v,
357 Err(e) => {
358 tracing::warn!(
359 row = i,
360 error = %e,
361 "interference_detector: malformed entities_json \
362 at row {i} — treating as empty set (no supersession)"
363 );
364 std::collections::HashSet::new()
365 }
366 };
367
368 for j in 0..i {
369 if nss.is_null(j)
370 || tss.is_null(j)
371 || ents.is_null(j)
372 || nss.value(j) != ns_i
373 {
374 continue;
375 }
376 let ts_j = tss.value(j);
377 if ts_i <= ts_j {
378 continue;
380 }
381 let ents_j: std::collections::HashSet<String> =
382 match serde_json::from_str(ents.value(j)) {
383 Ok(v) => v,
384 Err(e) => {
385 tracing::warn!(
386 row = j,
387 error = %e,
388 "interference_detector: malformed entities_json \
389 at row {j} — treating as empty set (no supersession)"
390 );
391 std::collections::HashSet::new()
392 }
393 };
394 let overlap = ents_i.intersection(&ents_j).count();
395 if overlap > 0 {
396 flags.is_supersession = true;
397 let union_sz = ents_i.union(&ents_j).count().max(1) as f32;
398 let jaccard = overlap as f32 / union_sz;
399 flags.score = flags.score.max(jaccard * 0.8);
400 break;
401 }
402 }
403 }
404 }
405 }
406
407 if !flags.is_duplicate
413 && !flags.is_supersession
414 && config.nli_max_pairs > 0
415 && i > 0
416 {
418 if let Some(contents) = contents {
419 if !contents.is_null(i) {
420 let text_i = contents.value(i);
421 let mut pairs_checked = 0usize;
422 let mut j = i.saturating_sub(1);
423 loop {
424 if pairs_checked >= config.nli_max_pairs {
425 break;
426 }
427 if !contents.is_null(j) {
428 let text_j = contents.value(j);
429 let (label, score) = nli_classifier.classify(text_j, text_i);
430 if label == NliLabel::Contradiction
431 && score >= config.nli_contradiction_threshold
432 {
433 flags.has_conflict = true;
434 flags.score = flags.score.max(score * 0.9);
436 tracing::debug!(
437 row = i,
438 against_row = j,
439 score,
440 "InterferenceDetectorExec: NLI contradiction detected"
441 );
442 break;
443 }
444 }
445 pairs_checked += 1;
446 if j == 0 {
447 break;
448 }
449 j -= 1;
450 }
451 }
452 }
453 }
454
455 all_flags.push(flags);
456 }
457
458 if let Some(ref storage) = storage {
465 let fsl = merged
466 .column_by_name("embedding")
467 .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());
468
469 if let Some(fsl) = fsl {
470 let row_embeddings: Vec<(usize, Vec<f32>)> = (0..n)
472 .filter(|&i| !all_flags[i].is_duplicate && !all_flags[i].is_supersession)
473 .filter_map(|i| {
474 if fsl.is_null(i) {
475 return None;
476 }
477 let values = fsl.value(i);
478 let f32_arr = values.as_any().downcast_ref::<Float32Array>()?;
479 Some((i, f32_arr.values().to_vec()))
480 })
481 .collect();
482
483 if !row_embeddings.is_empty() {
484 let emb_slices: Vec<&[f32]> =
485 row_embeddings.iter().map(|(_, e)| e.as_slice()).collect();
486
487 let max_sims = find_max_similarities(&emb_slices, storage, &config).await;
488
489 for (q_idx, &(row_idx, _)) in row_embeddings.iter().enumerate() {
490 let sim = max_sims.get(q_idx).copied().unwrap_or(0.0);
491 if sim >= dup_threshold {
492 all_flags[row_idx].is_near_duplicate = true;
493 all_flags[row_idx].score = all_flags[row_idx].score.max(sim);
494 tracing::debug!(
495 row = row_idx,
496 similarity = sim,
497 "InterferenceDetectorExec: near-duplicate detected"
498 );
499 }
500 }
501 }
502 }
503 }
504
505 let flags_col: StringArray = all_flags
507 .iter()
508 .map(|f| f.flag_string())
509 .collect::<Vec<_>>()
510 .into();
511 let score_col: Float32Array =
512 all_flags.iter().map(|f| f.score).collect::<Vec<_>>().into();
513
514 let mut columns: Vec<Arc<dyn Array>> = merged.columns().to_vec();
515 columns.push(Arc::new(flags_col));
516 columns.push(Arc::new(score_col));
517
518 RecordBatch::try_new(schema, columns).map_err(Into::into)
519 });
520
521 Ok(Box::pin(RecordBatchStreamAdapter::new(
522 self.schema.clone(),
523 stream,
524 )))
525 }
526}
527
528async fn find_max_similarities(
536 embeddings: &[&[f32]],
537 storage: &Arc<dyn PhysicalStore>,
538 config: &InterferenceConfig,
539) -> Vec<f32> {
540 if embeddings.is_empty() {
541 return Vec::new();
542 }
543
544 let metric = config.distance_metric;
545 let limit = config.near_dup_search_limit;
546 let n_queries = embeddings.len();
547
548 let queries: Vec<VectorSearchOptions> = embeddings
549 .iter()
550 .map(|emb| VectorSearchOptions {
551 query: emb.to_vec(),
552 column: "embedding".into(),
553 limit,
554 metric,
555 ..Default::default()
556 })
557 .collect();
558
559 let search_futures = config.search_datasets.iter().map(|dataset| {
561 let storage = Arc::clone(storage);
562 let dataset = dataset.clone();
563 let queries = queries.clone();
564 async move {
565 let exists = storage.exists(&dataset).await.unwrap_or(false);
566 let n_q = queries.len();
567 if !exists {
568 return vec![0.0_f32; n_q];
569 }
570 match storage.vector_search_many(&dataset, queries).await {
571 Ok(per_query_results) => per_query_results
572 .iter()
573 .map(|batches| {
574 batches
576 .iter()
577 .map(|b| {
578 b.column_by_name("_distance")
579 .and_then(|c| c.as_any().downcast_ref::<Float32Array>())
580 .map(|dists| {
581 (0..dists.len())
582 .filter(|&j| !dists.is_null(j))
583 .map(|j| dist_to_sim(metric, dists.value(j)))
584 .fold(0.0_f32, f32::max)
585 })
586 .unwrap_or(0.0)
587 })
588 .fold(0.0_f32, f32::max)
589 })
590 .collect(),
591 Err(e) => {
592 tracing::warn!(
593 dataset,
594 error = %e,
595 "InterferenceDetectorExec: near-dup search failed, skipping dataset"
596 );
597 vec![0.0_f32; n_q]
598 }
599 }
600 }
601 });
602
603 let per_dataset_sims: Vec<Vec<f32>> = futures::future::join_all(search_futures).await;
604
605 (0..n_queries)
607 .map(|q_idx| {
608 per_dataset_sims
609 .iter()
610 .map(|sims| sims.get(q_idx).copied().unwrap_or(0.0))
611 .fold(0.0_f32, f32::max)
612 })
613 .collect()
614}
615
616fn dist_to_sim(metric: DistanceMetric, dist: f32) -> f32 {
620 match metric {
621 DistanceMetric::Cosine => (1.0 - dist).clamp(0.0, 1.0),
623 DistanceMetric::DotProduct => (1.0 - dist).clamp(0.0, 1.0),
625 DistanceMetric::L2 => 1.0 / (1.0 + dist),
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    /// Width of every embedding used by the fixtures below.
    const DIM: i32 = 3;

    /// Schema with non-nullable Utf8 `id` and `content` columns.
    fn id_content_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
        ]))
    }

    /// Nullable fixed-size-list `embedding` field of `dim` nullable Float32 items.
    fn embedding_field(dim: i32) -> Field {
        Field::new(
            "embedding",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
            true,
        )
    }

    /// Builds an embedding column holding one valid list per entry of `rows`.
    fn embedding_column(dim: i32, rows: &[&[f32]]) -> FixedSizeListArray {
        use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};

        let mut builder = FixedSizeListBuilder::new(Float32Builder::new(), dim);
        for row in rows {
            for &v in row.iter() {
                builder.values().append_value(v);
            }
            builder.append(true);
        }
        builder.finish()
    }

    /// Single-batch in-memory input with `id` and `content` columns.
    fn text_input(ids: &[&str], contents: &[&str]) -> Arc<dyn ExecutionPlan> {
        let schema = id_content_schema();
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(ids.to_vec())),
                Arc::new(StringArray::from(contents.to_vec())),
            ],
        )
        .unwrap();
        Arc::new(crate::test_utils::MemoryBatchExec::new(
            schema,
            vec![batch],
        ))
    }

    /// Single-batch in-memory input with `id`, `content` and `embedding` columns.
    fn embedded_input(
        ids: &[&str],
        contents: &[&str],
        embeddings: &[&[f32]],
    ) -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
            embedding_field(DIM),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(ids.to_vec())),
                Arc::new(StringArray::from(contents.to_vec())),
                Arc::new(embedding_column(DIM, embeddings)),
            ],
        )
        .unwrap();
        Arc::new(crate::test_utils::MemoryBatchExec::new(
            schema,
            vec![batch],
        ))
    }

    /// Executes partition 0 of `exec` and returns the first emitted batch.
    async fn first_batch(exec: &InterferenceDetectorExec, ctx: Arc<TaskContext>) -> RecordBatch {
        use futures::StreamExt;

        let mut stream = exec.execute(0, ctx).unwrap();
        stream.next().await.unwrap().unwrap()
    }

    /// Runs `exec` with a default task context (no session extension).
    async fn run_default(exec: &InterferenceDetectorExec) -> RecordBatch {
        first_batch(exec, Arc::new(TaskContext::default())).await
    }

    /// The `interference_flags` column of `batch`.
    fn flags_of(batch: &RecordBatch) -> &StringArray {
        batch
            .column_by_name("interference_flags")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
    }

    #[test]
    fn default_config() {
        let config = InterferenceConfig::default();
        assert!((config.duplicate_threshold - 0.95).abs() < f32::EPSILON);
        assert!((config.consolidation_trigger - 0.3).abs() < f32::EPSILON);
        assert_eq!(config.search_datasets.len(), 3);
        assert_eq!(config.near_dup_search_limit, 3);
    }

    #[test]
    fn flag_string_none() {
        assert_eq!(InterferenceFlags::default().flag_string(), "none");
    }

    #[test]
    fn flag_string_near_duplicate() {
        let flags = InterferenceFlags {
            is_near_duplicate: true,
            score: 0.97,
            ..Default::default()
        };
        assert_eq!(flags.flag_string(), "near_duplicate");
    }

    #[test]
    fn flag_string_multiple() {
        let flags = InterferenceFlags {
            is_duplicate: true,
            has_conflict: true,
            ..Default::default()
        };
        assert_eq!(flags.flag_string(), "duplicate,conflict");
    }

    #[test]
    fn dist_to_sim_l2() {
        let at_zero = dist_to_sim(DistanceMetric::L2, 0.0);
        let at_one = dist_to_sim(DistanceMetric::L2, 1.0);
        assert!((at_zero - 1.0).abs() < f32::EPSILON);
        assert!((at_one - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn dist_to_sim_cosine() {
        let identical = dist_to_sim(DistanceMetric::Cosine, 0.0);
        let close = dist_to_sim(DistanceMetric::Cosine, 0.1);
        assert!((identical - 1.0).abs() < f32::EPSILON);
        assert!((close - 0.9).abs() < f32::EPSILON);
    }

    /// An input without batches still yields one (empty) output batch.
    #[tokio::test]
    async fn execute_empty_input() {
        let empty = Arc::new(datafusion_physical_plan::empty::EmptyExec::new(
            id_content_schema(),
        ));
        let exec = InterferenceDetectorExec::new(empty, InterferenceConfig::default());
        let batch = run_default(&exec).await;
        assert_eq!(batch.num_rows(), 0);
    }

    /// The second occurrence of identical content is flagged, the first is not.
    #[tokio::test]
    async fn detects_exact_content_duplicate() {
        let input = text_input(
            &["a", "b", "c"],
            &["hello world", "unique text", "hello world"],
        );
        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
        let result = run_default(&exec).await;

        assert_eq!(result.num_rows(), 3);
        let flags = flags_of(&result);
        assert_eq!(flags.value(0), "none");
        assert_eq!(flags.value(1), "none");
        assert_eq!(flags.value(2), "duplicate");
    }

    /// Rows that trigger nothing keep a score of zero.
    #[tokio::test]
    async fn no_duplicates_all_unique() {
        let input = text_input(&["a", "b"], &["first content", "second content"]);
        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
        let result = run_default(&exec).await;

        assert_eq!(result.num_rows(), 2);
        let scores = result
            .column_by_name("interference_score")
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert!((scores.value(0) - 0.0).abs() < f32::EPSILON);
        assert!((scores.value(1) - 0.0).abs() < f32::EPSILON);
    }

    /// A row whose embedding is close to a stored vector is flagged
    /// `near_duplicate`; a distant one is left alone.
    #[tokio::test(flavor = "multi_thread")]
    async fn detects_near_duplicate_via_vector_search() {
        use datafusion::prelude::SessionContext;
        use hirn_core::config::HirnConfig;
        use hirn_storage::memory_store::MemoryStore;

        // Seed the store with one existing embedding in the `episodic` dataset.
        let store: Arc<MemoryStore> = Arc::new(MemoryStore::new());
        let existing_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            embedding_field(DIM),
        ]));
        let existing_batch = RecordBatch::try_new(
            existing_schema,
            vec![
                Arc::new(StringArray::from(vec!["existing-1"])),
                Arc::new(embedding_column(DIM, &[&[1.0_f32, 0.0, 0.0]])),
            ],
        )
        .unwrap();
        store.append("episodic", existing_batch).await.unwrap();

        // Expose the store to the operator through the session extension.
        let ctx = SessionContext::new();
        let hirn_config = Arc::new(HirnConfig::default());
        let ext = crate::extensions::HirnSessionExt::new(Arc::new(42_u32), hirn_config, None)
            .with_storage(store as Arc<dyn hirn_storage::PhysicalStore>);
        ext.register(&ctx).unwrap();

        let input_exec = embedded_input(
            &["new-1", "new-2"],
            &["near text", "novel text"],
            &[&[0.99_f32, 0.01, 0.0], &[0.0_f32, 1.0, 0.0]],
        );

        let config = InterferenceConfig {
            duplicate_threshold: 0.5,
            search_datasets: vec!["episodic".to_string()],
            ..Default::default()
        };
        let exec = InterferenceDetectorExec::new(input_exec, config);

        let result = first_batch(&exec, ctx.task_ctx()).await;
        assert_eq!(result.num_rows(), 2);

        let flags = flags_of(&result);
        assert_eq!(
            flags.value(0),
            "near_duplicate",
            "expected near_duplicate, got: {}",
            flags.value(0)
        );
        assert_eq!(flags.value(1), "none");
    }

    /// Without a store in the task context the near-duplicate phase is a no-op.
    #[tokio::test]
    async fn near_dup_silently_skipped_without_storage() {
        let input_exec = embedded_input(&["a"], &["some content"], &[&[1.0_f32, 0.0, 0.0]]);
        let exec = InterferenceDetectorExec::new(input_exec, InterferenceConfig::default());

        let result = run_default(&exec).await;
        assert_eq!(result.num_rows(), 1);
        assert_eq!(flags_of(&result).value(0), "none");
    }

    /// The default classifier flags a negated restatement of the previous row.
    #[tokio::test]
    async fn detects_nli_contradiction_within_batch() {
        let input = text_input(
            &["r0", "r1"],
            &[
                "The cat is alive and healthy.",
                "The cat is not alive and not healthy.",
            ],
        );
        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
        let result = run_default(&exec).await;

        assert_eq!(result.num_rows(), 2);
        let flags = flags_of(&result);
        assert_eq!(flags.value(0), "none", "row 0 should have no flag");
        assert_eq!(
            flags.value(1),
            "conflict",
            "row 1 should be flagged as conflict"
        );
    }

    /// Unrelated statements must not be reported as conflicts.
    #[tokio::test]
    async fn nli_no_false_positive_on_unrelated_content() {
        let input = text_input(
            &["r0", "r1", "r2"],
            &[
                "Paris is the capital of France.",
                "The boiling point of water is 100 degrees.",
                "Jupiter is the largest planet in the solar system.",
            ],
        );
        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
        let result = run_default(&exec).await;

        let flags = flags_of(&result);
        for i in 0..3 {
            assert_eq!(flags.value(i), "none", "row {i} should not be flagged");
        }
    }

    /// `nli_max_pairs = 0` switches the contradiction check off.
    #[tokio::test]
    async fn nli_disabled_when_max_pairs_zero() {
        let input = text_input(
            &["r0", "r1"],
            &["The cat is alive.", "The cat is not alive."],
        );
        let config = InterferenceConfig {
            nli_max_pairs: 0,
            ..Default::default()
        };
        let exec = InterferenceDetectorExec::new(input, config);
        let result = run_default(&exec).await;

        assert_eq!(
            flags_of(&result).value(1),
            "none",
            "NLI should be skipped when nli_max_pairs=0"
        );
    }

    /// Exact duplicates are never passed to the classifier, but later rows are
    /// still compared against earlier ones.
    #[tokio::test]
    async fn nli_skipped_for_already_flagged_duplicate_rows() {
        let input = text_input(
            &["r0", "r1", "r2"],
            &[
                "The sky is blue.",
                "The sky is blue.",
                "The sky is not blue.",
            ],
        );
        let exec = InterferenceDetectorExec::new(input, InterferenceConfig::default());
        let result = run_default(&exec).await;

        let flags = flags_of(&result);
        assert_eq!(flags.value(0), "none", "row 0: first occurrence");
        assert_eq!(
            flags.value(1),
            "duplicate",
            "row 1: exact dup, not conflict"
        );
        assert_eq!(
            flags.value(2),
            "conflict",
            "row 2: contradiction with row 0"
        );
    }

    /// A classifier passed to `with_nli_classifier` is the one that gets called.
    #[tokio::test]
    async fn nli_respects_injected_classifier() {
        /// Reports a high-confidence contradiction for every pair.
        #[derive(Debug)]
        struct AlwaysContradiction;
        impl NliClassifier for AlwaysContradiction {
            fn classify(
                &self,
                _text_a: &str,
                _text_b: &str,
            ) -> (crate::operators::nli_contradiction::NliLabel, f32) {
                (NliLabel::Contradiction, 0.99)
            }
            fn backend_name(&self) -> &'static str {
                "always_contradiction"
            }
        }

        let input = text_input(&["r0", "r1"], &["anything", "anything else"]);
        let exec = InterferenceDetectorExec::with_nli_classifier(
            input,
            InterferenceConfig::default(),
            Arc::new(AlwaysContradiction),
        );
        let result = run_default(&exec).await;

        let flags = flags_of(&result);
        assert_eq!(flags.value(0), "none", "row 0: no prior rows");
        assert_eq!(
            flags.value(1),
            "conflict",
            "row 1: injected classifier fires"
        );
    }
}