anno/backends/middleware/mod.rs
1//! Pipeline middleware for composing and extending NER backends.
2//!
3//! This module provides a flexible middleware architecture for:
4//! - Pre-processing text before entity extraction
5//! - Post-processing entities after extraction
6//! - Filtering, enriching, or transforming entities
7//! - Integrating with external systems
8//!
9//! # Design
10//!
11//! The middleware pattern follows a "chain of responsibility" approach where
12//! each middleware can:
13//! 1. Transform input before passing to the next stage
14//! 2. Transform output after receiving from the previous stage
15//! 3. Short-circuit the chain by returning early
16//!
17//! # Example: Basic Pipeline
18//!
19//! ```rust,ignore
20//! use anno::backends::middleware::{Pipeline, Middleware, NormalizeWhitespace, FilterByConfidence};
21//! use anno::StackedNER;
22//!
23//! let pipeline = Pipeline::new(Box::new(StackedNER::default()))
24//! .with(NormalizeWhitespace)
25//! .with(FilterByConfidence(0.5));
26//!
27//! let entities = pipeline.extract_entities("Hello world", None)?;
28//! ```
29//!
30//! # Example: Custom Middleware
31//!
32//! ```rust,ignore
33//! use anno::backends::middleware::{Middleware, MiddlewareContext};
34//! use anno::{Entity, Result};
35//!
36//! struct LogEntities;
37//!
38//! impl Middleware for LogEntities {
39//! fn post_process(&self, ctx: &mut MiddlewareContext, entities: Vec<Entity>) -> Result<Vec<Entity>> {
40//! eprintln!("Found {} entities", entities.len());
41//! Ok(entities)
42//! }
43//! }
44//! ```
45//!
46//! # Example: Entity Enrichment
47//!
48//! ```rust,ignore
49//! use anno::backends::middleware::{Middleware, MiddlewareContext};
50//! use anno::{Entity, Result};
51//!
52//! struct AddKnowledgeBaseLinks {
53//! kb_lookup: HashMap<String, String>,
54//! }
55//!
56//! impl Middleware for AddKnowledgeBaseLinks {
57//! fn post_process(&self, ctx: &mut MiddlewareContext, mut entities: Vec<Entity>) -> Result<Vec<Entity>> {
58//! for entity in &mut entities {
59//! if let Some(kb_id) = self.kb_lookup.get(&entity.text.to_lowercase()) {
60//! entity.kb_id = Some(kb_id.clone());
61//! }
62//! }
63//! Ok(entities)
64//! }
65//! }
66//! ```
67
68use crate::{Entity, EntityType, Model, Result};
69use std::borrow::Cow;
70use std::collections::HashMap;
71use std::sync::Arc;
72
73// =============================================================================
74// Middleware Context
75// =============================================================================
76
77/// Context passed through the middleware chain.
78///
79/// Provides access to:
80/// - Original and transformed text
81/// - Metadata and configuration
82/// - Extracted entity types filter
83#[derive(Debug, Clone)]
84pub struct MiddlewareContext {
85 /// Original input text (before any preprocessing).
86 pub original_text: String,
87 /// Current text (may be transformed by preprocessors).
88 pub current_text: String,
89 /// Requested entity types filter.
90 pub entity_types: Option<Vec<EntityType>>,
91 /// Language hint for the text.
92 pub language: Option<String>,
93 /// Arbitrary metadata (for custom middleware).
94 pub metadata: HashMap<String, String>,
95}
96
97impl MiddlewareContext {
98 /// Create a new context from input text.
99 #[must_use]
100 pub fn new(text: impl Into<String>) -> Self {
101 let text = text.into();
102 Self {
103 original_text: text.clone(),
104 current_text: text,
105 entity_types: None,
106 language: None,
107 metadata: HashMap::new(),
108 }
109 }
110
111 /// Set language hint.
112 #[must_use]
113 pub fn with_language(mut self, lang: impl Into<String>) -> Self {
114 self.language = Some(lang.into());
115 self
116 }
117
118 /// Set entity types filter.
119 #[must_use]
120 pub fn with_entity_types(mut self, types: Vec<EntityType>) -> Self {
121 self.entity_types = Some(types);
122 self
123 }
124
125 /// Set metadata value.
126 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
127 self.metadata.insert(key.into(), value.into());
128 }
129}
130
131// =============================================================================
132// Middleware Trait
133// =============================================================================
134
135/// Middleware for transforming input/output in the NER pipeline.
136///
137/// Implement this trait to create custom middleware that can:
138/// - Preprocess text before extraction
139/// - Postprocess entities after extraction
140/// - Add metadata or enrich entities
141pub trait Middleware: Send + Sync {
142 /// Preprocess text before entity extraction.
143 ///
144 /// Returns the (possibly transformed) text to pass to the next stage.
145 /// The default implementation passes through unchanged.
146 fn pre_process<'a>(&self, ctx: &mut MiddlewareContext, text: &'a str) -> Result<Cow<'a, str>> {
147 let _ = ctx;
148 Ok(Cow::Borrowed(text))
149 }
150
151 /// Postprocess entities after extraction.
152 ///
153 /// Returns the (possibly transformed) entities to pass to the next stage.
154 /// The default implementation passes through unchanged.
155 fn post_process(
156 &self,
157 ctx: &mut MiddlewareContext,
158 entities: Vec<Entity>,
159 ) -> Result<Vec<Entity>> {
160 let _ = ctx;
161 Ok(entities)
162 }
163
164 /// Name of this middleware (for debugging/logging).
165 fn name(&self) -> &'static str {
166 "unnamed"
167 }
168}
169
170// =============================================================================
171// Pipeline
172// =============================================================================
173
174/// A composable NER pipeline with middleware support.
175///
176/// The pipeline executes middleware in order:
177/// 1. Pre-process: Each middleware transforms the input text
178/// 2. Extract: The core backend extracts entities
179/// 3. Post-process: Each middleware transforms the entities (in reverse order)
180///
181/// # Example
182///
183/// ```rust,ignore
184/// use anno::backends::middleware::{Pipeline, NormalizeWhitespace, FilterByConfidence};
185/// use anno::StackedNER;
186///
187/// let pipeline = Pipeline::new(Box::new(StackedNER::default()))
188/// .with(NormalizeWhitespace)
189/// .with(FilterByConfidence(0.5));
190///
191/// // Extract entities through the pipeline
192/// let entities = pipeline.extract("Hello world")?;
193/// ```
194pub struct Pipeline {
195 backend: Arc<dyn Model>,
196 middleware: Vec<Box<dyn Middleware>>,
197}
198
199impl Pipeline {
200 /// Create a new pipeline with a backend.
201 #[must_use]
202 pub fn new(backend: Box<dyn Model>) -> Self {
203 Self {
204 backend: Arc::from(backend),
205 middleware: Vec::new(),
206 }
207 }
208
209 /// Add middleware to the pipeline.
210 #[must_use]
211 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
212 self.middleware.push(Box::new(middleware));
213 self
214 }
215
216 /// Add middleware conditionally.
217 #[must_use]
218 pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
219 if condition {
220 self.with(middleware)
221 } else {
222 self
223 }
224 }
225
226 /// Extract entities through the pipeline.
227 pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
228 self.extract_with_context(&mut MiddlewareContext::new(text))
229 }
230
231 /// Extract entities with explicit context.
232 pub fn extract_with_context(&self, ctx: &mut MiddlewareContext) -> Result<Vec<Entity>> {
233 // Pre-process: each middleware transforms the text
234 // Clone the text to avoid borrow issues
235 let mut current_text = ctx.current_text.clone();
236 for mw in &self.middleware {
237 let result = mw.pre_process(ctx, ¤t_text)?;
238 current_text = result.into_owned();
239 }
240
241 // Update context with final preprocessed text
242 ctx.current_text = current_text;
243
244 // Extract entities from the backend
245 let mut entities = self
246 .backend
247 .extract_entities(&ctx.current_text, ctx.language.as_deref())?;
248
249 // Post-process: each middleware transforms entities (reverse order)
250 for mw in self.middleware.iter().rev() {
251 entities = mw.post_process(ctx, entities)?;
252 }
253
254 Ok(entities)
255 }
256
257 /// Get the underlying backend.
258 #[must_use]
259 pub fn backend(&self) -> &dyn Model {
260 &*self.backend
261 }
262
263 /// List middleware names.
264 #[must_use]
265 pub fn middleware_names(&self) -> Vec<&'static str> {
266 self.middleware.iter().map(|m| m.name()).collect()
267 }
268}
269
270impl std::fmt::Debug for Pipeline {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 f.debug_struct("Pipeline")
273 .field("middleware", &self.middleware_names())
274 .finish()
275 }
276}
277
278// =============================================================================
279// Built-in Middleware
280// =============================================================================
281
282/// Normalize whitespace in input text.
283///
284/// - Collapses multiple spaces into single spaces
285/// - Trims leading/trailing whitespace
286#[derive(Debug, Clone, Copy, Default)]
287pub struct NormalizeWhitespace;
288
289impl Middleware for NormalizeWhitespace {
290 fn pre_process<'a>(&self, _ctx: &mut MiddlewareContext, text: &'a str) -> Result<Cow<'a, str>> {
291 // Check if normalization is needed
292 let needs_normalization = text.contains(" ")
293 || text.starts_with(char::is_whitespace)
294 || text.ends_with(char::is_whitespace);
295
296 if needs_normalization {
297 let normalized: String = text.split_whitespace().collect::<Vec<_>>().join(" ");
298 Ok(Cow::Owned(normalized))
299 } else {
300 Ok(Cow::Borrowed(text))
301 }
302 }
303
304 fn name(&self) -> &'static str {
305 "normalize_whitespace"
306 }
307}
308
309/// Filter entities by minimum confidence threshold.
310#[derive(Debug, Clone, Copy)]
311pub struct FilterByConfidence(pub f64);
312
313impl Middleware for FilterByConfidence {
314 fn post_process(
315 &self,
316 _ctx: &mut MiddlewareContext,
317 entities: Vec<Entity>,
318 ) -> Result<Vec<Entity>> {
319 let threshold = self.0;
320 Ok(entities
321 .into_iter()
322 .filter(|e| e.confidence >= threshold)
323 .collect())
324 }
325
326 fn name(&self) -> &'static str {
327 "filter_by_confidence"
328 }
329}
330
331/// Filter entities by entity type.
332#[derive(Debug, Clone)]
333pub struct FilterByType(pub Vec<EntityType>);
334
335impl Middleware for FilterByType {
336 fn post_process(
337 &self,
338 _ctx: &mut MiddlewareContext,
339 entities: Vec<Entity>,
340 ) -> Result<Vec<Entity>> {
341 Ok(entities
342 .into_iter()
343 .filter(|e| self.0.contains(&e.entity_type))
344 .collect())
345 }
346
347 fn name(&self) -> &'static str {
348 "filter_by_type"
349 }
350}
351
352/// Remove overlapping entities, keeping the highest confidence one.
353#[derive(Debug, Clone, Copy, Default)]
354pub struct RemoveOverlaps;
355
356impl Middleware for RemoveOverlaps {
357 fn post_process(
358 &self,
359 _ctx: &mut MiddlewareContext,
360 mut entities: Vec<Entity>,
361 ) -> Result<Vec<Entity>> {
362 // Sort by confidence descending
363 entities.sort_by(|a, b| {
364 b.confidence
365 .partial_cmp(&a.confidence)
366 .unwrap_or(std::cmp::Ordering::Equal)
367 });
368
369 let mut result = Vec::with_capacity(entities.len());
370 for entity in entities {
371 let overlaps = result
372 .iter()
373 .any(|e: &Entity| entity.start < e.end && entity.end > e.start);
374 if !overlaps {
375 result.push(entity);
376 }
377 }
378
379 // Sort back by position
380 result.sort_by_key(|e| e.start);
381 Ok(result)
382 }
383
384 fn name(&self) -> &'static str {
385 "remove_overlaps"
386 }
387}
388
/// Stamps provenance onto entities that do not have any yet.
#[derive(Debug, Clone)]
pub struct AddProvenance {
    /// Backend name recorded in the provenance.
    pub backend: String,
    /// Free-text description of the extraction method.
    pub method: String,
}

impl AddProvenance {
    /// Build the middleware from a backend name and a method description.
    #[must_use]
    pub fn new(backend: impl Into<String>, method: impl Into<String>) -> Self {
        let (backend, method) = (backend.into(), method.into());
        AddProvenance { backend, method }
    }
}
408
409impl Middleware for AddProvenance {
410 fn post_process(
411 &self,
412 _ctx: &mut MiddlewareContext,
413 mut entities: Vec<Entity>,
414 ) -> Result<Vec<Entity>> {
415 use anno_core::Provenance;
416 for entity in &mut entities {
417 if entity.provenance.is_none() {
418 entity.provenance = Some(Provenance::ml(self.backend.clone(), entity.confidence));
419 }
420 }
421 Ok(entities)
422 }
423
424 fn name(&self) -> &'static str {
425 "add_provenance"
426 }
427}
428
/// Fuses neighbouring entities that share a type into one entity.
///
/// Intended for backends that split a single mention into pieces, e.g.
/// "New" + "York" → "New York".
#[derive(Debug, Clone, Copy)]
pub struct MergeAdjacent {
    /// Largest distance, in characters, between the end of one entity and the
    /// start of the next for the two to be merged.
    pub max_gap: usize,
}

impl Default for MergeAdjacent {
    /// Allows a gap of one character between merged entities.
    fn default() -> Self {
        MergeAdjacent { max_gap: 1 }
    }
}
443
444impl Middleware for MergeAdjacent {
445 fn post_process(
446 &self,
447 ctx: &mut MiddlewareContext,
448 mut entities: Vec<Entity>,
449 ) -> Result<Vec<Entity>> {
450 if entities.len() < 2 {
451 return Ok(entities);
452 }
453
454 // Sort by position
455 entities.sort_by_key(|e| e.start);
456
457 let text = &ctx.current_text;
458 let mut merged = Vec::with_capacity(entities.len());
459 let mut current: Option<Entity> = None;
460
461 for entity in entities {
462 if let Some(prev) = current.take() {
463 // Check if should merge
464 let gap = entity.start.saturating_sub(prev.end);
465 let same_type = prev.entity_type == entity.entity_type;
466
467 if same_type && gap <= self.max_gap {
468 // Merge entities
469 let merged_text = text
470 .chars()
471 .skip(prev.start)
472 .take(entity.end - prev.start)
473 .collect::<String>();
474 let merged_confidence = (prev.confidence + entity.confidence) / 2.0;
475
476 current = Some(Entity::new(
477 merged_text,
478 prev.entity_type,
479 prev.start,
480 entity.end,
481 merged_confidence,
482 ));
483 } else {
484 merged.push(prev);
485 current = Some(entity);
486 }
487 } else {
488 current = Some(entity);
489 }
490 }
491
492 if let Some(last) = current {
493 merged.push(last);
494 }
495
496 Ok(merged)
497 }
498
499 fn name(&self) -> &'static str {
500 "merge_adjacent"
501 }
502}
503
504/// Callback middleware for custom processing.
505///
506/// Wraps a closure for simple one-off transformations.
507pub struct Callback<F> {
508 name: &'static str,
509 func: F,
510}
511
512impl<F> Callback<F>
513where
514 F: Fn(&mut MiddlewareContext, Vec<Entity>) -> Result<Vec<Entity>> + Send + Sync,
515{
516 /// Create a new callback middleware.
517 #[must_use]
518 pub fn new(name: &'static str, func: F) -> Self {
519 Self { name, func }
520 }
521}
522
523impl<F> Middleware for Callback<F>
524where
525 F: Fn(&mut MiddlewareContext, Vec<Entity>) -> Result<Vec<Entity>> + Send + Sync,
526{
527 fn post_process(
528 &self,
529 ctx: &mut MiddlewareContext,
530 entities: Vec<Entity>,
531 ) -> Result<Vec<Entity>> {
532 (self.func)(ctx, entities)
533 }
534
535 fn name(&self) -> &'static str {
536 self.name
537 }
538}
539
540impl<F> std::fmt::Debug for Callback<F> {
541 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
542 f.debug_struct("Callback")
543 .field("name", &self.name)
544 .finish()
545 }
546}
547
548// =============================================================================
549// Hook System
550// =============================================================================
551
/// Points in a pipeline run at which hooks can be invoked.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookEvent {
    /// Fired before any stage has run.
    BeforeExtraction,
    /// Fired once the run has finished and the final entities are known.
    AfterExtraction,
    /// Fired for an individual entity produced by the backend.
    EntityFound,
    /// Fired when the run fails.
    OnError,
}
564
565/// Hook function signature.
566pub type HookFn = Box<dyn Fn(HookEvent, &MiddlewareContext, Option<&[Entity]>) + Send + Sync>;
567
568/// Hook registry for pipeline events.
569pub struct HookRegistry {
570 hooks: HashMap<HookEvent, Vec<HookFn>>,
571}
572
573impl HookRegistry {
574 /// Create a new hook registry.
575 #[must_use]
576 pub fn new() -> Self {
577 Self {
578 hooks: HashMap::new(),
579 }
580 }
581
582 /// Register a hook for an event.
583 pub fn register(&mut self, event: HookEvent, hook: HookFn) {
584 self.hooks.entry(event).or_default().push(hook);
585 }
586
587 /// Trigger hooks for an event.
588 pub fn trigger(&self, event: HookEvent, ctx: &MiddlewareContext, entities: Option<&[Entity]>) {
589 if let Some(hooks) = self.hooks.get(&event) {
590 for hook in hooks {
591 hook(event, ctx, entities);
592 }
593 }
594 }
595}
596
597impl Default for HookRegistry {
598 fn default() -> Self {
599 Self::new()
600 }
601}
602
603impl std::fmt::Debug for HookRegistry {
604 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605 f.debug_struct("HookRegistry")
606 .field("events", &self.hooks.keys().collect::<Vec<_>>())
607 .finish()
608 }
609}
610
611// =============================================================================
612// Hooked Pipeline (with interior mutability)
613// =============================================================================
614
615use std::cell::RefCell;
616
617/// A pipeline with hook support using interior mutability.
618///
619/// This addresses the borrow checker issues with the standard `HookRegistry` by:
620/// 1. Using `RefCell` for interior mutability of hook-related state
621/// 2. Separating "before" and "after" contexts to avoid simultaneous borrows
622/// 3. Using owned data in hook invocations to avoid reference conflicts
623///
624/// # Design
625///
626/// The `HookedPipeline` solves the problem where:
627/// - `extract_with_context` needs `&mut MiddlewareContext`
628/// - Hooks need to be called with context data during extraction
629/// - Multiple hooks may need to read entity data while context is borrowed
630///
631/// By cloning context data before passing to hooks, we avoid borrow conflicts.
632///
633/// # Example
634///
635/// ```rust,ignore
636/// use anno::backends::middleware::{HookedPipeline, HookEvent};
637/// use anno::HeuristicNER;
638///
639/// let mut pipeline = HookedPipeline::new(Box::new(HeuristicNER::new()));
640///
641/// pipeline.on(HookEvent::AfterExtraction, |event, text, entities| {
642/// if let Some(entities) = entities {
643/// println!("Found {} entities in: {}", entities.len(), text);
644/// }
645/// });
646///
647/// let entities = pipeline.extract("Hello World")?;
648/// ```
649pub struct HookedPipeline {
650 backend: Arc<dyn Model>,
651 middleware: Vec<Box<dyn Middleware>>,
652 /// Hook registry with interior mutability for safe hook registration during extraction
653 hooks: RefCell<HookRegistry>,
654}
655
656impl HookedPipeline {
657 /// Create a new hooked pipeline with a backend.
658 #[must_use]
659 pub fn new(backend: Box<dyn Model>) -> Self {
660 Self {
661 backend: Arc::from(backend),
662 middleware: Vec::new(),
663 hooks: RefCell::new(HookRegistry::new()),
664 }
665 }
666
667 /// Add middleware to the pipeline.
668 #[must_use]
669 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
670 self.middleware.push(Box::new(middleware));
671 self
672 }
673
674 /// Register a hook for an event.
675 ///
676 /// Hooks receive:
677 /// - `event`: The event type
678 /// - `text`: The current text being processed (owned copy)
679 /// - `entities`: Optional entities (owned copy)
680 ///
681 /// # Example
682 ///
683 /// ```rust,ignore
684 /// pipeline.on(HookEvent::EntityFound, |event, text, entities| {
685 /// if let Some(entities) = entities {
686 /// for entity in entities {
687 /// println!("Found: {} ({})", entity.text, entity.entity_type);
688 /// }
689 /// }
690 /// });
691 /// ```
692 pub fn on<F>(&self, event: HookEvent, handler: F)
693 where
694 F: Fn(HookEvent, &str, Option<&[Entity]>) + Send + Sync + 'static,
695 {
696 // Wrap the simpler handler in the full hook signature
697 let wrapper = Box::new(
698 move |evt: HookEvent, ctx: &MiddlewareContext, entities: Option<&[Entity]>| {
699 handler(evt, &ctx.current_text, entities);
700 },
701 );
702 self.hooks.borrow_mut().register(event, wrapper);
703 }
704
705 /// Register a hook with full context access.
706 pub fn on_with_context(&self, event: HookEvent, hook: HookFn) {
707 self.hooks.borrow_mut().register(event, hook);
708 }
709
710 /// Extract entities through the pipeline with hook support.
711 pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
712 let mut ctx = MiddlewareContext::new(text);
713
714 // Trigger before extraction hooks (with cloned context to avoid borrow issues)
715 {
716 let hooks = self.hooks.borrow();
717 hooks.trigger(HookEvent::BeforeExtraction, &ctx, None);
718 }
719
720 // Pre-process: each middleware transforms the text
721 let mut current_text = ctx.current_text.clone();
722 for mw in &self.middleware {
723 let result = mw.pre_process(&mut ctx, ¤t_text)?;
724 current_text = result.into_owned();
725 }
726 ctx.current_text = current_text;
727
728 // Extract entities from the backend
729 let entities = match self
730 .backend
731 .extract_entities(&ctx.current_text, ctx.language.as_deref())
732 {
733 Ok(entities) => entities,
734 Err(e) => {
735 // Trigger error hooks
736 let hooks = self.hooks.borrow();
737 hooks.trigger(HookEvent::OnError, &ctx, None);
738 return Err(e);
739 }
740 };
741
742 // Trigger entity found hooks for each entity
743 {
744 let hooks = self.hooks.borrow();
745 for entity in &entities {
746 hooks.trigger(
747 HookEvent::EntityFound,
748 &ctx,
749 Some(std::slice::from_ref(entity)),
750 );
751 }
752 }
753
754 // Post-process: each middleware transforms entities (reverse order)
755 let mut entities = entities;
756 for mw in self.middleware.iter().rev() {
757 entities = mw.post_process(&mut ctx, entities)?;
758 }
759
760 // Trigger after extraction hooks
761 {
762 let hooks = self.hooks.borrow();
763 hooks.trigger(HookEvent::AfterExtraction, &ctx, Some(&entities));
764 }
765
766 Ok(entities)
767 }
768
769 /// Get the underlying backend.
770 #[must_use]
771 pub fn backend(&self) -> &dyn Model {
772 &*self.backend
773 }
774
775 /// List middleware names.
776 #[must_use]
777 pub fn middleware_names(&self) -> Vec<&'static str> {
778 self.middleware.iter().map(|m| m.name()).collect()
779 }
780
781 /// Get the number of registered hooks.
782 #[must_use]
783 pub fn hook_count(&self) -> usize {
784 self.hooks.borrow().hooks.values().map(|v| v.len()).sum()
785 }
786}
787
788impl std::fmt::Debug for HookedPipeline {
789 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
790 f.debug_struct("HookedPipeline")
791 .field("middleware", &self.middleware_names())
792 .field("hooks", &*self.hooks.borrow())
793 .finish()
794 }
795}
796
797// =============================================================================
798// Tests
799// =============================================================================
800
801#[cfg(test)]
802mod tests;