Skip to main content

anno/backends/
middleware.rs

1//! Pipeline middleware for composing and extending NER backends.
2//!
3//! This module provides a flexible middleware architecture for:
4//! - Pre-processing text before entity extraction
5//! - Post-processing entities after extraction
6//! - Filtering, enriching, or transforming entities
7//! - Integrating with external systems
8//!
9//! # Design
10//!
11//! The middleware pattern follows a "chain of responsibility" approach where
12//! each middleware can:
13//! 1. Transform input before passing to the next stage
14//! 2. Transform output after receiving from the previous stage
15//! 3. Short-circuit the chain by returning early
16//!
17//! # Example: Basic Pipeline
18//!
19//! ```rust,ignore
20//! use anno::backends::middleware::{Pipeline, Middleware, NormalizeWhitespace, FilterByConfidence};
21//! use anno::StackedNER;
22//!
23//! let pipeline = Pipeline::new(Box::new(StackedNER::default()))
24//!     .with(NormalizeWhitespace)
25//!     .with(FilterByConfidence(0.5));
26//!
//! let entities = pipeline.extract("Hello  world")?;
28//! ```
29//!
30//! # Example: Custom Middleware
31//!
32//! ```rust,ignore
33//! use anno::backends::middleware::{Middleware, MiddlewareContext};
34//! use anno::{Entity, Result};
35//!
36//! struct LogEntities;
37//!
38//! impl Middleware for LogEntities {
39//!     fn post_process(&self, ctx: &mut MiddlewareContext, entities: Vec<Entity>) -> Result<Vec<Entity>> {
40//!         eprintln!("Found {} entities", entities.len());
41//!         Ok(entities)
42//!     }
43//! }
44//! ```
45//!
46//! # Example: Entity Enrichment
47//!
48//! ```rust,ignore
//! use anno::backends::middleware::{Middleware, MiddlewareContext};
//! use anno::{Entity, Result};
//! use std::collections::HashMap;
51//!
52//! struct AddKnowledgeBaseLinks {
53//!     kb_lookup: HashMap<String, String>,
54//! }
55//!
56//! impl Middleware for AddKnowledgeBaseLinks {
57//!     fn post_process(&self, ctx: &mut MiddlewareContext, mut entities: Vec<Entity>) -> Result<Vec<Entity>> {
58//!         for entity in &mut entities {
59//!             if let Some(kb_id) = self.kb_lookup.get(&entity.text.to_lowercase()) {
60//!                 entity.kb_id = Some(kb_id.clone());
61//!             }
62//!         }
63//!         Ok(entities)
64//!     }
65//! }
66//! ```
67
68use crate::{Entity, EntityType, Model, Result};
69use std::borrow::Cow;
70use std::collections::HashMap;
71use std::sync::Arc;
72
73// =============================================================================
74// Middleware Context
75// =============================================================================
76
77/// Context passed through the middleware chain.
78///
79/// Provides access to:
80/// - Original and transformed text
81/// - Metadata and configuration
82/// - Extracted entity types filter
83#[derive(Debug, Clone)]
84pub struct MiddlewareContext {
85    /// Original input text (before any preprocessing).
86    pub original_text: String,
87    /// Current text (may be transformed by preprocessors).
88    pub current_text: String,
89    /// Requested entity types filter.
90    pub entity_types: Option<Vec<EntityType>>,
91    /// Language hint for the text.
92    pub language: Option<String>,
93    /// Arbitrary metadata (for custom middleware).
94    pub metadata: HashMap<String, String>,
95}
96
97impl MiddlewareContext {
98    /// Create a new context from input text.
99    #[must_use]
100    pub fn new(text: impl Into<String>) -> Self {
101        let text = text.into();
102        Self {
103            original_text: text.clone(),
104            current_text: text,
105            entity_types: None,
106            language: None,
107            metadata: HashMap::new(),
108        }
109    }
110
111    /// Set language hint.
112    #[must_use]
113    pub fn with_language(mut self, lang: impl Into<String>) -> Self {
114        self.language = Some(lang.into());
115        self
116    }
117
118    /// Set entity types filter.
119    #[must_use]
120    pub fn with_entity_types(mut self, types: Vec<EntityType>) -> Self {
121        self.entity_types = Some(types);
122        self
123    }
124
125    /// Set metadata value.
126    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
127        self.metadata.insert(key.into(), value.into());
128    }
129}
130
131// =============================================================================
132// Middleware Trait
133// =============================================================================
134
135/// Middleware for transforming input/output in the NER pipeline.
136///
137/// Implement this trait to create custom middleware that can:
138/// - Preprocess text before extraction
139/// - Postprocess entities after extraction
140/// - Add metadata or enrich entities
141pub trait Middleware: Send + Sync {
142    /// Preprocess text before entity extraction.
143    ///
144    /// Returns the (possibly transformed) text to pass to the next stage.
145    /// The default implementation passes through unchanged.
146    fn pre_process<'a>(&self, ctx: &mut MiddlewareContext, text: &'a str) -> Result<Cow<'a, str>> {
147        let _ = ctx;
148        Ok(Cow::Borrowed(text))
149    }
150
151    /// Postprocess entities after extraction.
152    ///
153    /// Returns the (possibly transformed) entities to pass to the next stage.
154    /// The default implementation passes through unchanged.
155    fn post_process(
156        &self,
157        ctx: &mut MiddlewareContext,
158        entities: Vec<Entity>,
159    ) -> Result<Vec<Entity>> {
160        let _ = ctx;
161        Ok(entities)
162    }
163
164    /// Name of this middleware (for debugging/logging).
165    fn name(&self) -> &'static str {
166        "unnamed"
167    }
168}
169
170// =============================================================================
171// Pipeline
172// =============================================================================
173
174/// A composable NER pipeline with middleware support.
175///
176/// The pipeline executes middleware in order:
177/// 1. Pre-process: Each middleware transforms the input text
178/// 2. Extract: The core backend extracts entities
179/// 3. Post-process: Each middleware transforms the entities (in reverse order)
180///
181/// # Example
182///
183/// ```rust,ignore
184/// use anno::backends::middleware::{Pipeline, NormalizeWhitespace, FilterByConfidence};
185/// use anno::StackedNER;
186///
187/// let pipeline = Pipeline::new(Box::new(StackedNER::default()))
188///     .with(NormalizeWhitespace)
189///     .with(FilterByConfidence(0.5));
190///
191/// // Extract entities through the pipeline
192/// let entities = pipeline.extract("Hello  world")?;
193/// ```
194pub struct Pipeline {
195    backend: Arc<dyn Model>,
196    middleware: Vec<Box<dyn Middleware>>,
197}
198
199impl Pipeline {
200    /// Create a new pipeline with a backend.
201    #[must_use]
202    pub fn new(backend: Box<dyn Model>) -> Self {
203        Self {
204            backend: Arc::from(backend),
205            middleware: Vec::new(),
206        }
207    }
208
209    /// Add middleware to the pipeline.
210    #[must_use]
211    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
212        self.middleware.push(Box::new(middleware));
213        self
214    }
215
216    /// Add middleware conditionally.
217    #[must_use]
218    pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
219        if condition {
220            self.with(middleware)
221        } else {
222            self
223        }
224    }
225
226    /// Extract entities through the pipeline.
227    pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
228        self.extract_with_context(&mut MiddlewareContext::new(text))
229    }
230
231    /// Extract entities with explicit context.
232    pub fn extract_with_context(&self, ctx: &mut MiddlewareContext) -> Result<Vec<Entity>> {
233        // Pre-process: each middleware transforms the text
234        // Clone the text to avoid borrow issues
235        let mut current_text = ctx.current_text.clone();
236        for mw in &self.middleware {
237            let result = mw.pre_process(ctx, &current_text)?;
238            current_text = result.into_owned();
239        }
240
241        // Update context with final preprocessed text
242        ctx.current_text = current_text;
243
244        // Extract entities from the backend
245        let mut entities = self
246            .backend
247            .extract_entities(&ctx.current_text, ctx.language.as_deref())?;
248
249        // Post-process: each middleware transforms entities (reverse order)
250        for mw in self.middleware.iter().rev() {
251            entities = mw.post_process(ctx, entities)?;
252        }
253
254        Ok(entities)
255    }
256
257    /// Get the underlying backend.
258    #[must_use]
259    pub fn backend(&self) -> &dyn Model {
260        &*self.backend
261    }
262
263    /// List middleware names.
264    #[must_use]
265    pub fn middleware_names(&self) -> Vec<&'static str> {
266        self.middleware.iter().map(|m| m.name()).collect()
267    }
268}
269
270impl std::fmt::Debug for Pipeline {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        f.debug_struct("Pipeline")
273            .field("middleware", &self.middleware_names())
274            .finish()
275    }
276}
277
278// =============================================================================
279// Built-in Middleware
280// =============================================================================
281
282/// Normalize whitespace in input text.
283///
284/// - Collapses multiple spaces into single spaces
285/// - Trims leading/trailing whitespace
286#[derive(Debug, Clone, Copy, Default)]
287pub struct NormalizeWhitespace;
288
289impl Middleware for NormalizeWhitespace {
290    fn pre_process<'a>(&self, _ctx: &mut MiddlewareContext, text: &'a str) -> Result<Cow<'a, str>> {
291        // Check if normalization is needed
292        let needs_normalization = text.contains("  ")
293            || text.starts_with(char::is_whitespace)
294            || text.ends_with(char::is_whitespace);
295
296        if needs_normalization {
297            let normalized: String = text.split_whitespace().collect::<Vec<_>>().join(" ");
298            Ok(Cow::Owned(normalized))
299        } else {
300            Ok(Cow::Borrowed(text))
301        }
302    }
303
304    fn name(&self) -> &'static str {
305        "normalize_whitespace"
306    }
307}
308
309/// Filter entities by minimum confidence threshold.
310#[derive(Debug, Clone, Copy)]
311pub struct FilterByConfidence(pub f64);
312
313impl Middleware for FilterByConfidence {
314    fn post_process(
315        &self,
316        _ctx: &mut MiddlewareContext,
317        entities: Vec<Entity>,
318    ) -> Result<Vec<Entity>> {
319        let threshold = self.0;
320        Ok(entities
321            .into_iter()
322            .filter(|e| e.confidence >= threshold)
323            .collect())
324    }
325
326    fn name(&self) -> &'static str {
327        "filter_by_confidence"
328    }
329}
330
331/// Filter entities by entity type.
332#[derive(Debug, Clone)]
333pub struct FilterByType(pub Vec<EntityType>);
334
335impl Middleware for FilterByType {
336    fn post_process(
337        &self,
338        _ctx: &mut MiddlewareContext,
339        entities: Vec<Entity>,
340    ) -> Result<Vec<Entity>> {
341        Ok(entities
342            .into_iter()
343            .filter(|e| self.0.contains(&e.entity_type))
344            .collect())
345    }
346
347    fn name(&self) -> &'static str {
348        "filter_by_type"
349    }
350}
351
352/// Remove overlapping entities, keeping the highest confidence one.
353#[derive(Debug, Clone, Copy, Default)]
354pub struct RemoveOverlaps;
355
356impl Middleware for RemoveOverlaps {
357    fn post_process(
358        &self,
359        _ctx: &mut MiddlewareContext,
360        mut entities: Vec<Entity>,
361    ) -> Result<Vec<Entity>> {
362        // Sort by confidence descending
363        entities.sort_by(|a, b| {
364            b.confidence
365                .partial_cmp(&a.confidence)
366                .unwrap_or(std::cmp::Ordering::Equal)
367        });
368
369        let mut result = Vec::with_capacity(entities.len());
370        for entity in entities {
371            let overlaps = result
372                .iter()
373                .any(|e: &Entity| entity.start < e.end && entity.end > e.start);
374            if !overlaps {
375                result.push(entity);
376            }
377        }
378
379        // Sort back by position
380        result.sort_by_key(|e| e.start);
381        Ok(result)
382    }
383
384    fn name(&self) -> &'static str {
385        "remove_overlaps"
386    }
387}
388
389/// Add provenance information to all entities.
390#[derive(Debug, Clone)]
391pub struct AddProvenance {
392    /// Backend name to set.
393    pub backend: String,
394    /// Method description.
395    pub method: String,
396}
397
398impl AddProvenance {
399    /// Create a new AddProvenance middleware.
400    #[must_use]
401    pub fn new(backend: impl Into<String>, method: impl Into<String>) -> Self {
402        Self {
403            backend: backend.into(),
404            method: method.into(),
405        }
406    }
407}
408
409impl Middleware for AddProvenance {
410    fn post_process(
411        &self,
412        _ctx: &mut MiddlewareContext,
413        mut entities: Vec<Entity>,
414    ) -> Result<Vec<Entity>> {
415        use anno_core::Provenance;
416        for entity in &mut entities {
417            if entity.provenance.is_none() {
418                entity.provenance = Some(Provenance::ml(self.backend.clone(), entity.confidence));
419            }
420        }
421        Ok(entities)
422    }
423
424    fn name(&self) -> &'static str {
425        "add_provenance"
426    }
427}
428
429/// Merge adjacent entities of the same type.
430///
431/// Useful for combining split entities like "New" + "York" → "New York".
432#[derive(Debug, Clone, Copy)]
433pub struct MergeAdjacent {
434    /// Maximum gap (in characters) between entities to merge.
435    pub max_gap: usize,
436}
437
438impl Default for MergeAdjacent {
439    fn default() -> Self {
440        Self { max_gap: 1 }
441    }
442}
443
444impl Middleware for MergeAdjacent {
445    fn post_process(
446        &self,
447        ctx: &mut MiddlewareContext,
448        mut entities: Vec<Entity>,
449    ) -> Result<Vec<Entity>> {
450        if entities.len() < 2 {
451            return Ok(entities);
452        }
453
454        // Sort by position
455        entities.sort_by_key(|e| e.start);
456
457        let text = &ctx.current_text;
458        let mut merged = Vec::with_capacity(entities.len());
459        let mut current: Option<Entity> = None;
460
461        for entity in entities {
462            if let Some(prev) = current.take() {
463                // Check if should merge
464                let gap = entity.start.saturating_sub(prev.end);
465                let same_type = prev.entity_type == entity.entity_type;
466
467                if same_type && gap <= self.max_gap {
468                    // Merge entities
469                    let merged_text = text
470                        .chars()
471                        .skip(prev.start)
472                        .take(entity.end - prev.start)
473                        .collect::<String>();
474                    let merged_confidence = (prev.confidence + entity.confidence) / 2.0;
475
476                    current = Some(Entity::new(
477                        merged_text,
478                        prev.entity_type,
479                        prev.start,
480                        entity.end,
481                        merged_confidence,
482                    ));
483                } else {
484                    merged.push(prev);
485                    current = Some(entity);
486                }
487            } else {
488                current = Some(entity);
489            }
490        }
491
492        if let Some(last) = current {
493            merged.push(last);
494        }
495
496        Ok(merged)
497    }
498
499    fn name(&self) -> &'static str {
500        "merge_adjacent"
501    }
502}
503
504/// Callback middleware for custom processing.
505///
506/// Wraps a closure for simple one-off transformations.
507pub struct Callback<F> {
508    name: &'static str,
509    func: F,
510}
511
512impl<F> Callback<F>
513where
514    F: Fn(&mut MiddlewareContext, Vec<Entity>) -> Result<Vec<Entity>> + Send + Sync,
515{
516    /// Create a new callback middleware.
517    #[must_use]
518    pub fn new(name: &'static str, func: F) -> Self {
519        Self { name, func }
520    }
521}
522
523impl<F> Middleware for Callback<F>
524where
525    F: Fn(&mut MiddlewareContext, Vec<Entity>) -> Result<Vec<Entity>> + Send + Sync,
526{
527    fn post_process(
528        &self,
529        ctx: &mut MiddlewareContext,
530        entities: Vec<Entity>,
531    ) -> Result<Vec<Entity>> {
532        (self.func)(ctx, entities)
533    }
534
535    fn name(&self) -> &'static str {
536        self.name
537    }
538}
539
540impl<F> std::fmt::Debug for Callback<F> {
541    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
542        f.debug_struct("Callback")
543            .field("name", &self.name)
544            .finish()
545    }
546}
547
548// =============================================================================
549// Hook System
550// =============================================================================
551
552/// Event types for hooks.
553#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
554pub enum HookEvent {
555    /// Before extraction starts.
556    BeforeExtraction,
557    /// After extraction completes.
558    AfterExtraction,
559    /// When an entity is found.
560    EntityFound,
561    /// When an error occurs.
562    OnError,
563}
564
565/// Hook function signature.
566pub type HookFn = Box<dyn Fn(HookEvent, &MiddlewareContext, Option<&[Entity]>) + Send + Sync>;
567
568/// Hook registry for pipeline events.
569pub struct HookRegistry {
570    hooks: HashMap<HookEvent, Vec<HookFn>>,
571}
572
573impl HookRegistry {
574    /// Create a new hook registry.
575    #[must_use]
576    pub fn new() -> Self {
577        Self {
578            hooks: HashMap::new(),
579        }
580    }
581
582    /// Register a hook for an event.
583    pub fn register(&mut self, event: HookEvent, hook: HookFn) {
584        self.hooks.entry(event).or_default().push(hook);
585    }
586
587    /// Trigger hooks for an event.
588    pub fn trigger(&self, event: HookEvent, ctx: &MiddlewareContext, entities: Option<&[Entity]>) {
589        if let Some(hooks) = self.hooks.get(&event) {
590            for hook in hooks {
591                hook(event, ctx, entities);
592            }
593        }
594    }
595}
596
597impl Default for HookRegistry {
598    fn default() -> Self {
599        Self::new()
600    }
601}
602
603impl std::fmt::Debug for HookRegistry {
604    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605        f.debug_struct("HookRegistry")
606            .field("events", &self.hooks.keys().collect::<Vec<_>>())
607            .finish()
608    }
609}
610
611// =============================================================================
612// Hooked Pipeline (with interior mutability)
613// =============================================================================
614
615use std::cell::RefCell;
616
617/// A pipeline with hook support using interior mutability.
618///
619/// This addresses the borrow checker issues with the standard `HookRegistry` by:
620/// 1. Using `RefCell` for interior mutability of hook-related state
621/// 2. Separating "before" and "after" contexts to avoid simultaneous borrows
622/// 3. Using owned data in hook invocations to avoid reference conflicts
623///
624/// # Design
625///
626/// The `HookedPipeline` solves the problem where:
627/// - `extract_with_context` needs `&mut MiddlewareContext`
628/// - Hooks need to be called with context data during extraction
629/// - Multiple hooks may need to read entity data while context is borrowed
630///
631/// By cloning context data before passing to hooks, we avoid borrow conflicts.
632///
633/// # Example
634///
635/// ```rust,ignore
636/// use anno::backends::middleware::{HookedPipeline, HookEvent};
637/// use anno::HeuristicNER;
638///
639/// let mut pipeline = HookedPipeline::new(Box::new(HeuristicNER::new()));
640///
641/// pipeline.on(HookEvent::AfterExtraction, |event, text, entities| {
642///     if let Some(entities) = entities {
643///         println!("Found {} entities in: {}", entities.len(), text);
644///     }
645/// });
646///
647/// let entities = pipeline.extract("Hello World")?;
648/// ```
649pub struct HookedPipeline {
650    backend: Arc<dyn Model>,
651    middleware: Vec<Box<dyn Middleware>>,
652    /// Hook registry with interior mutability for safe hook registration during extraction
653    hooks: RefCell<HookRegistry>,
654}
655
656impl HookedPipeline {
657    /// Create a new hooked pipeline with a backend.
658    #[must_use]
659    pub fn new(backend: Box<dyn Model>) -> Self {
660        Self {
661            backend: Arc::from(backend),
662            middleware: Vec::new(),
663            hooks: RefCell::new(HookRegistry::new()),
664        }
665    }
666
667    /// Add middleware to the pipeline.
668    #[must_use]
669    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
670        self.middleware.push(Box::new(middleware));
671        self
672    }
673
674    /// Register a hook for an event.
675    ///
676    /// Hooks receive:
677    /// - `event`: The event type
678    /// - `text`: The current text being processed (owned copy)
679    /// - `entities`: Optional entities (owned copy)
680    ///
681    /// # Example
682    ///
683    /// ```rust,ignore
684    /// pipeline.on(HookEvent::EntityFound, |event, text, entities| {
685    ///     if let Some(entities) = entities {
686    ///         for entity in entities {
687    ///             println!("Found: {} ({})", entity.text, entity.entity_type);
688    ///         }
689    ///     }
690    /// });
691    /// ```
692    pub fn on<F>(&self, event: HookEvent, handler: F)
693    where
694        F: Fn(HookEvent, &str, Option<&[Entity]>) + Send + Sync + 'static,
695    {
696        // Wrap the simpler handler in the full hook signature
697        let wrapper = Box::new(
698            move |evt: HookEvent, ctx: &MiddlewareContext, entities: Option<&[Entity]>| {
699                handler(evt, &ctx.current_text, entities);
700            },
701        );
702        self.hooks.borrow_mut().register(event, wrapper);
703    }
704
705    /// Register a hook with full context access.
706    pub fn on_with_context(&self, event: HookEvent, hook: HookFn) {
707        self.hooks.borrow_mut().register(event, hook);
708    }
709
710    /// Extract entities through the pipeline with hook support.
711    pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
712        let mut ctx = MiddlewareContext::new(text);
713
714        // Trigger before extraction hooks (with cloned context to avoid borrow issues)
715        {
716            let hooks = self.hooks.borrow();
717            hooks.trigger(HookEvent::BeforeExtraction, &ctx, None);
718        }
719
720        // Pre-process: each middleware transforms the text
721        let mut current_text = ctx.current_text.clone();
722        for mw in &self.middleware {
723            let result = mw.pre_process(&mut ctx, &current_text)?;
724            current_text = result.into_owned();
725        }
726        ctx.current_text = current_text;
727
728        // Extract entities from the backend
729        let entities = match self
730            .backend
731            .extract_entities(&ctx.current_text, ctx.language.as_deref())
732        {
733            Ok(entities) => entities,
734            Err(e) => {
735                // Trigger error hooks
736                let hooks = self.hooks.borrow();
737                hooks.trigger(HookEvent::OnError, &ctx, None);
738                return Err(e);
739            }
740        };
741
742        // Trigger entity found hooks for each entity
743        {
744            let hooks = self.hooks.borrow();
745            for entity in &entities {
746                hooks.trigger(
747                    HookEvent::EntityFound,
748                    &ctx,
749                    Some(std::slice::from_ref(entity)),
750                );
751            }
752        }
753
754        // Post-process: each middleware transforms entities (reverse order)
755        let mut entities = entities;
756        for mw in self.middleware.iter().rev() {
757            entities = mw.post_process(&mut ctx, entities)?;
758        }
759
760        // Trigger after extraction hooks
761        {
762            let hooks = self.hooks.borrow();
763            hooks.trigger(HookEvent::AfterExtraction, &ctx, Some(&entities));
764        }
765
766        Ok(entities)
767    }
768
769    /// Get the underlying backend.
770    #[must_use]
771    pub fn backend(&self) -> &dyn Model {
772        &*self.backend
773    }
774
775    /// List middleware names.
776    #[must_use]
777    pub fn middleware_names(&self) -> Vec<&'static str> {
778        self.middleware.iter().map(|m| m.name()).collect()
779    }
780
781    /// Get the number of registered hooks.
782    #[must_use]
783    pub fn hook_count(&self) -> usize {
784        self.hooks.borrow().hooks.values().map(|v| v.len()).sum()
785    }
786}
787
788impl std::fmt::Debug for HookedPipeline {
789    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
790        f.debug_struct("HookedPipeline")
791            .field("middleware", &self.middleware_names())
792            .field("hooks", &*self.hooks.borrow())
793            .finish()
794    }
795}
796
797// =============================================================================
798// Tests
799// =============================================================================
800
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HeuristicNER;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Registers a hook on `pipeline` that counts how often `event` fires.
    fn count_hook(pipeline: &HookedPipeline, event: HookEvent) -> Arc<AtomicUsize> {
        let count = Arc::new(AtomicUsize::new(0));
        let handle = Arc::clone(&count);
        pipeline.on(event, move |_, _, _| {
            handle.fetch_add(1, Ordering::SeqCst);
        });
        count
    }

    #[test]
    fn test_normalize_whitespace() {
        let mw = NormalizeWhitespace;
        let mut ctx = MiddlewareContext::new("  hello   world  ");
        let text = ctx.original_text.clone();
        let result = mw
            .pre_process(&mut ctx, &text)
            .expect("pre_process should succeed");
        assert_eq!(result, "hello world");
    }

    #[test]
    fn test_normalize_whitespace_borrows_clean_text() {
        let mw = NormalizeWhitespace;
        let mut ctx = MiddlewareContext::new("hello world");
        let text = ctx.original_text.clone();
        let result = mw
            .pre_process(&mut ctx, &text)
            .expect("pre_process should succeed");
        assert!(matches!(result, Cow::Borrowed("hello world")));
    }

    #[test]
    fn test_filter_by_confidence() {
        let mw = FilterByConfidence(0.5);
        let mut ctx = MiddlewareContext::new("test");
        let entities = vec![
            Entity::new("high", EntityType::Person, 0, 4, 0.8),
            Entity::new("low", EntityType::Person, 5, 8, 0.3),
        ];
        let result = mw
            .post_process(&mut ctx, entities)
            .expect("post_process should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, "high");
    }

    #[test]
    fn test_filter_by_type() {
        let mw = FilterByType(vec![EntityType::Location]);
        let mut ctx = MiddlewareContext::new("Alice Paris");
        let entities = vec![
            Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
            Entity::new("Paris", EntityType::Location, 6, 11, 0.9),
        ];
        let result = mw
            .post_process(&mut ctx, entities)
            .expect("post_process should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, "Paris");
    }

    #[test]
    fn test_remove_overlaps() {
        let mw = RemoveOverlaps;
        let mut ctx = MiddlewareContext::new("New York City");
        let entities = vec![
            Entity::new("New York", EntityType::Location, 0, 8, 0.9),
            Entity::new("York City", EntityType::Location, 4, 13, 0.7),
        ];
        let result = mw
            .post_process(&mut ctx, entities)
            .expect("post_process should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, "New York"); // Higher confidence wins
    }

    #[test]
    fn test_remove_overlaps_keeps_disjoint_sorted() {
        let mw = RemoveOverlaps;
        let mut ctx = MiddlewareContext::new("Alice went to Boston");
        let entities = vec![
            Entity::new("Boston", EntityType::Location, 14, 20, 0.5),
            Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
        ];
        let result = mw
            .post_process(&mut ctx, entities)
            .expect("post_process should succeed");
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].text, "Alice");
        assert_eq!(result[1].text, "Boston");
    }

    #[test]
    fn test_merge_adjacent() {
        let mw = MergeAdjacent::default();
        let mut ctx = MiddlewareContext::new("New York");
        let entities = vec![
            Entity::new("York", EntityType::Location, 4, 8, 0.6),
            Entity::new("New", EntityType::Location, 0, 3, 0.8),
        ];
        let result = mw
            .post_process(&mut ctx, entities)
            .expect("post_process should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, "New York");
        assert_eq!(result[0].start, 0);
        assert_eq!(result[0].end, 8);
    }

    #[test]
    fn test_merge_adjacent_requires_same_type() {
        let mw = MergeAdjacent::default();
        let mut ctx = MiddlewareContext::new("New York");
        let entities = vec![
            Entity::new("New", EntityType::Location, 0, 3, 0.8),
            Entity::new("York", EntityType::Person, 4, 8, 0.6),
        ];
        let result = mw
            .post_process(&mut ctx, entities)
            .expect("post_process should succeed");
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_add_provenance() {
        let mw = AddProvenance::new("test_backend", "unit_test");
        let mut ctx = MiddlewareContext::new("John");
        let entities = vec![Entity::new("John", EntityType::Person, 0, 4, 0.9)];
        let result = mw
            .post_process(&mut ctx, entities)
            .expect("post_process should succeed");
        assert!(result[0].provenance.is_some());
    }

    #[test]
    fn test_callback() {
        let mw = Callback::new(
            "keep_first",
            |_ctx: &mut MiddlewareContext, entities: Vec<Entity>| -> Result<Vec<Entity>> {
                Ok(entities.into_iter().take(1).collect())
            },
        );
        assert_eq!(mw.name(), "keep_first");
        let mut ctx = MiddlewareContext::new("a b");
        let entities = vec![
            Entity::new("a", EntityType::Person, 0, 1, 0.9),
            Entity::new("b", EntityType::Person, 2, 3, 0.9),
        ];
        let result = mw
            .post_process(&mut ctx, entities)
            .expect("post_process should succeed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, "a");
    }

    #[test]
    fn test_pipeline_basic() {
        let pipeline = Pipeline::new(Box::new(HeuristicNER::new()))
            .with(NormalizeWhitespace)
            .with(FilterByConfidence(0.3));

        let entities = pipeline
            .extract("Hello  World")
            .expect("extraction should succeed");
        assert!(entities.iter().all(|e| e.confidence >= 0.3));
    }

    #[test]
    fn test_pipeline_middleware_names() {
        let pipeline = Pipeline::new(Box::new(HeuristicNER::new()))
            .with(NormalizeWhitespace)
            .with_if(false, RemoveOverlaps)
            .with(FilterByConfidence(0.3));
        assert_eq!(
            pipeline.middleware_names(),
            vec!["normalize_whitespace", "filter_by_confidence"]
        );
    }

    #[test]
    fn test_pipeline_updates_context_text() {
        let pipeline = Pipeline::new(Box::new(HeuristicNER::new())).with(NormalizeWhitespace);
        let mut ctx = MiddlewareContext::new("  John   Smith  ");
        pipeline
            .extract_with_context(&mut ctx)
            .expect("extraction should succeed");
        assert_eq!(ctx.original_text, "  John   Smith  ");
        assert_eq!(ctx.current_text, "John Smith");
    }

    #[test]
    fn test_hook_registry_trigger() {
        let mut registry = HookRegistry::new();
        let count = Arc::new(AtomicUsize::new(0));
        for _ in 0..2 {
            let handle = Arc::clone(&count);
            registry.register(
                HookEvent::BeforeExtraction,
                Box::new(
                    move |_: HookEvent, _: &MiddlewareContext, _: Option<&[Entity]>| {
                        handle.fetch_add(1, Ordering::SeqCst);
                    },
                ),
            );
        }
        let ctx = MiddlewareContext::new("text");
        registry.trigger(HookEvent::BeforeExtraction, &ctx, None);
        registry.trigger(HookEvent::AfterExtraction, &ctx, None);
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_hooked_pipeline_basic() {
        let pipeline = HookedPipeline::new(Box::new(HeuristicNER::new())).with(NormalizeWhitespace);

        let before = count_hook(&pipeline, HookEvent::BeforeExtraction);
        let after = count_hook(&pipeline, HookEvent::AfterExtraction);
        let errors = count_hook(&pipeline, HookEvent::OnError);

        pipeline
            .extract("Hello World")
            .expect("extraction should succeed");

        assert_eq!(before.load(Ordering::SeqCst), 1);
        assert_eq!(after.load(Ordering::SeqCst), 1);
        assert_eq!(errors.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_hooked_pipeline_entity_found_hook() {
        let pipeline = HookedPipeline::new(Box::new(HeuristicNER::new()));

        let entity_count = Arc::new(AtomicUsize::new(0));
        let entity_count_clone = Arc::clone(&entity_count);

        pipeline.on(HookEvent::EntityFound, move |_, _, entities| {
            if entities.is_some() {
                entity_count_clone.fetch_add(1, Ordering::SeqCst);
            }
        });

        // HeuristicNER should find capitalized words
        pipeline
            .extract("John Smith went to New York")
            .expect("extraction should succeed");

        // EntityFound should be called for each entity
        assert!(entity_count.load(Ordering::SeqCst) > 0);
    }

    #[test]
    fn test_hooked_pipeline_with_middleware() {
        let pipeline = HookedPipeline::new(Box::new(HeuristicNER::new()))
            .with(NormalizeWhitespace)
            .with(FilterByConfidence(0.3));

        let entities = pipeline
            .extract("  John   Smith  ")
            .expect("extraction should succeed");
        // The confidence filter must have been applied to the final output.
        assert!(entities.iter().all(|e| e.confidence >= 0.3));
    }

    #[test]
    fn test_hooked_pipeline_hook_count() {
        let pipeline = HookedPipeline::new(Box::new(HeuristicNER::new()));

        assert_eq!(pipeline.hook_count(), 0);

        pipeline.on(HookEvent::BeforeExtraction, |_, _, _| {});
        pipeline.on(HookEvent::AfterExtraction, |_, _, _| {});
        pipeline.on(HookEvent::EntityFound, |_, _, _| {});

        assert_eq!(pipeline.hook_count(), 3);
    }
}