1use crate::{Entity, EntityType, Model, Result};
69use std::borrow::Cow;
70use std::collections::HashMap;
71use std::sync::Arc;
72
73#[derive(Debug, Clone)]
84pub struct MiddlewareContext {
85 pub original_text: String,
87 pub current_text: String,
89 pub entity_types: Option<Vec<EntityType>>,
91 pub language: Option<String>,
93 pub metadata: HashMap<String, String>,
95}
96
97impl MiddlewareContext {
98 #[must_use]
100 pub fn new(text: impl Into<String>) -> Self {
101 let text = text.into();
102 Self {
103 original_text: text.clone(),
104 current_text: text,
105 entity_types: None,
106 language: None,
107 metadata: HashMap::new(),
108 }
109 }
110
111 #[must_use]
113 pub fn with_language(mut self, lang: impl Into<String>) -> Self {
114 self.language = Some(lang.into());
115 self
116 }
117
118 #[must_use]
120 pub fn with_entity_types(mut self, types: Vec<EntityType>) -> Self {
121 self.entity_types = Some(types);
122 self
123 }
124
125 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
127 self.metadata.insert(key.into(), value.into());
128 }
129}
130
131pub trait Middleware: Send + Sync {
142 fn pre_process<'a>(&self, ctx: &mut MiddlewareContext, text: &'a str) -> Result<Cow<'a, str>> {
147 let _ = ctx;
148 Ok(Cow::Borrowed(text))
149 }
150
151 fn post_process(
156 &self,
157 ctx: &mut MiddlewareContext,
158 entities: Vec<Entity>,
159 ) -> Result<Vec<Entity>> {
160 let _ = ctx;
161 Ok(entities)
162 }
163
164 fn name(&self) -> &'static str {
166 "unnamed"
167 }
168}
169
170pub struct Pipeline {
195 backend: Arc<dyn Model>,
196 middleware: Vec<Box<dyn Middleware>>,
197}
198
199impl Pipeline {
200 #[must_use]
202 pub fn new(backend: Box<dyn Model>) -> Self {
203 Self {
204 backend: Arc::from(backend),
205 middleware: Vec::new(),
206 }
207 }
208
209 #[must_use]
211 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
212 self.middleware.push(Box::new(middleware));
213 self
214 }
215
216 #[must_use]
218 pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
219 if condition {
220 self.with(middleware)
221 } else {
222 self
223 }
224 }
225
226 pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
228 self.extract_with_context(&mut MiddlewareContext::new(text))
229 }
230
231 pub fn extract_with_context(&self, ctx: &mut MiddlewareContext) -> Result<Vec<Entity>> {
233 let mut current_text = ctx.current_text.clone();
236 for mw in &self.middleware {
237 let result = mw.pre_process(ctx, ¤t_text)?;
238 current_text = result.into_owned();
239 }
240
241 ctx.current_text = current_text;
243
244 let mut entities = self
246 .backend
247 .extract_entities(&ctx.current_text, ctx.language.as_deref())?;
248
249 for mw in self.middleware.iter().rev() {
251 entities = mw.post_process(ctx, entities)?;
252 }
253
254 Ok(entities)
255 }
256
257 #[must_use]
259 pub fn backend(&self) -> &dyn Model {
260 &*self.backend
261 }
262
263 #[must_use]
265 pub fn middleware_names(&self) -> Vec<&'static str> {
266 self.middleware.iter().map(|m| m.name()).collect()
267 }
268}
269
270impl std::fmt::Debug for Pipeline {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 f.debug_struct("Pipeline")
273 .field("middleware", &self.middleware_names())
274 .finish()
275 }
276}
277
278#[derive(Debug, Clone, Copy, Default)]
287pub struct NormalizeWhitespace;
288
289impl Middleware for NormalizeWhitespace {
290 fn pre_process<'a>(&self, _ctx: &mut MiddlewareContext, text: &'a str) -> Result<Cow<'a, str>> {
291 let needs_normalization = text.contains(" ")
293 || text.starts_with(char::is_whitespace)
294 || text.ends_with(char::is_whitespace);
295
296 if needs_normalization {
297 let normalized: String = text.split_whitespace().collect::<Vec<_>>().join(" ");
298 Ok(Cow::Owned(normalized))
299 } else {
300 Ok(Cow::Borrowed(text))
301 }
302 }
303
304 fn name(&self) -> &'static str {
305 "normalize_whitespace"
306 }
307}
308
309#[derive(Debug, Clone, Copy)]
311pub struct FilterByConfidence(pub f64);
312
313impl Middleware for FilterByConfidence {
314 fn post_process(
315 &self,
316 _ctx: &mut MiddlewareContext,
317 entities: Vec<Entity>,
318 ) -> Result<Vec<Entity>> {
319 let threshold = self.0;
320 Ok(entities
321 .into_iter()
322 .filter(|e| e.confidence >= threshold)
323 .collect())
324 }
325
326 fn name(&self) -> &'static str {
327 "filter_by_confidence"
328 }
329}
330
331#[derive(Debug, Clone)]
333pub struct FilterByType(pub Vec<EntityType>);
334
335impl Middleware for FilterByType {
336 fn post_process(
337 &self,
338 _ctx: &mut MiddlewareContext,
339 entities: Vec<Entity>,
340 ) -> Result<Vec<Entity>> {
341 Ok(entities
342 .into_iter()
343 .filter(|e| self.0.contains(&e.entity_type))
344 .collect())
345 }
346
347 fn name(&self) -> &'static str {
348 "filter_by_type"
349 }
350}
351
352#[derive(Debug, Clone, Copy, Default)]
354pub struct RemoveOverlaps;
355
356impl Middleware for RemoveOverlaps {
357 fn post_process(
358 &self,
359 _ctx: &mut MiddlewareContext,
360 mut entities: Vec<Entity>,
361 ) -> Result<Vec<Entity>> {
362 entities.sort_by(|a, b| {
364 b.confidence
365 .partial_cmp(&a.confidence)
366 .unwrap_or(std::cmp::Ordering::Equal)
367 });
368
369 let mut result = Vec::with_capacity(entities.len());
370 for entity in entities {
371 let overlaps = result
372 .iter()
373 .any(|e: &Entity| entity.start < e.end && entity.end > e.start);
374 if !overlaps {
375 result.push(entity);
376 }
377 }
378
379 result.sort_by_key(|e| e.start);
381 Ok(result)
382 }
383
384 fn name(&self) -> &'static str {
385 "remove_overlaps"
386 }
387}
388
389#[derive(Debug, Clone)]
391pub struct AddProvenance {
392 pub backend: String,
394 pub method: String,
396}
397
398impl AddProvenance {
399 #[must_use]
401 pub fn new(backend: impl Into<String>, method: impl Into<String>) -> Self {
402 Self {
403 backend: backend.into(),
404 method: method.into(),
405 }
406 }
407}
408
409impl Middleware for AddProvenance {
410 fn post_process(
411 &self,
412 _ctx: &mut MiddlewareContext,
413 mut entities: Vec<Entity>,
414 ) -> Result<Vec<Entity>> {
415 use anno_core::Provenance;
416 for entity in &mut entities {
417 if entity.provenance.is_none() {
418 entity.provenance = Some(Provenance::ml(self.backend.clone(), entity.confidence));
419 }
420 }
421 Ok(entities)
422 }
423
424 fn name(&self) -> &'static str {
425 "add_provenance"
426 }
427}
428
429#[derive(Debug, Clone, Copy)]
433pub struct MergeAdjacent {
434 pub max_gap: usize,
436}
437
438impl Default for MergeAdjacent {
439 fn default() -> Self {
440 Self { max_gap: 1 }
441 }
442}
443
444impl Middleware for MergeAdjacent {
445 fn post_process(
446 &self,
447 ctx: &mut MiddlewareContext,
448 mut entities: Vec<Entity>,
449 ) -> Result<Vec<Entity>> {
450 if entities.len() < 2 {
451 return Ok(entities);
452 }
453
454 entities.sort_by_key(|e| e.start);
456
457 let text = &ctx.current_text;
458 let mut merged = Vec::with_capacity(entities.len());
459 let mut current: Option<Entity> = None;
460
461 for entity in entities {
462 if let Some(prev) = current.take() {
463 let gap = entity.start.saturating_sub(prev.end);
465 let same_type = prev.entity_type == entity.entity_type;
466
467 if same_type && gap <= self.max_gap {
468 let merged_text = text
470 .chars()
471 .skip(prev.start)
472 .take(entity.end - prev.start)
473 .collect::<String>();
474 let merged_confidence = (prev.confidence + entity.confidence) / 2.0;
475
476 current = Some(Entity::new(
477 merged_text,
478 prev.entity_type,
479 prev.start,
480 entity.end,
481 merged_confidence,
482 ));
483 } else {
484 merged.push(prev);
485 current = Some(entity);
486 }
487 } else {
488 current = Some(entity);
489 }
490 }
491
492 if let Some(last) = current {
493 merged.push(last);
494 }
495
496 Ok(merged)
497 }
498
499 fn name(&self) -> &'static str {
500 "merge_adjacent"
501 }
502}
503
504pub struct Callback<F> {
508 name: &'static str,
509 func: F,
510}
511
512impl<F> Callback<F>
513where
514 F: Fn(&mut MiddlewareContext, Vec<Entity>) -> Result<Vec<Entity>> + Send + Sync,
515{
516 #[must_use]
518 pub fn new(name: &'static str, func: F) -> Self {
519 Self { name, func }
520 }
521}
522
523impl<F> Middleware for Callback<F>
524where
525 F: Fn(&mut MiddlewareContext, Vec<Entity>) -> Result<Vec<Entity>> + Send + Sync,
526{
527 fn post_process(
528 &self,
529 ctx: &mut MiddlewareContext,
530 entities: Vec<Entity>,
531 ) -> Result<Vec<Entity>> {
532 (self.func)(ctx, entities)
533 }
534
535 fn name(&self) -> &'static str {
536 self.name
537 }
538}
539
540impl<F> std::fmt::Debug for Callback<F> {
541 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
542 f.debug_struct("Callback")
543 .field("name", &self.name)
544 .finish()
545 }
546}
547
/// Events for which hooks can be registered in a [`HookRegistry`].
///
/// The variants are used as hash-map keys by the registry; the points at which
/// each one is fired are decided by `HookedPipeline::extract`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookEvent {
    /// Fired at the start of an extraction, before any pre-processing.
    BeforeExtraction,
    /// Fired after post-processing, with the final entity list.
    AfterExtraction,
    /// Fired once per entity returned by the backend, with a one-element slice.
    EntityFound,
    /// Fired when extraction fails; see `HookedPipeline::extract` for exactly
    /// which failures are reported.
    OnError,
}
564
/// Boxed hook callback: receives the event being fired, the current
/// [`MiddlewareContext`], and an optional slice of entities relevant to the
/// event. Hooks return nothing.
pub type HookFn = Box<dyn Fn(HookEvent, &MiddlewareContext, Option<&[Entity]>) + Send + Sync>;
567
568pub struct HookRegistry {
570 hooks: HashMap<HookEvent, Vec<HookFn>>,
571}
572
573impl HookRegistry {
574 #[must_use]
576 pub fn new() -> Self {
577 Self {
578 hooks: HashMap::new(),
579 }
580 }
581
582 pub fn register(&mut self, event: HookEvent, hook: HookFn) {
584 self.hooks.entry(event).or_default().push(hook);
585 }
586
587 pub fn trigger(&self, event: HookEvent, ctx: &MiddlewareContext, entities: Option<&[Entity]>) {
589 if let Some(hooks) = self.hooks.get(&event) {
590 for hook in hooks {
591 hook(event, ctx, entities);
592 }
593 }
594 }
595}
596
597impl Default for HookRegistry {
598 fn default() -> Self {
599 Self::new()
600 }
601}
602
603impl std::fmt::Debug for HookRegistry {
604 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605 f.debug_struct("HookRegistry")
606 .field("events", &self.hooks.keys().collect::<Vec<_>>())
607 .finish()
608 }
609}
610
611use std::cell::RefCell;
616
617pub struct HookedPipeline {
650 backend: Arc<dyn Model>,
651 middleware: Vec<Box<dyn Middleware>>,
652 hooks: RefCell<HookRegistry>,
654}
655
656impl HookedPipeline {
657 #[must_use]
659 pub fn new(backend: Box<dyn Model>) -> Self {
660 Self {
661 backend: Arc::from(backend),
662 middleware: Vec::new(),
663 hooks: RefCell::new(HookRegistry::new()),
664 }
665 }
666
667 #[must_use]
669 pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
670 self.middleware.push(Box::new(middleware));
671 self
672 }
673
674 pub fn on<F>(&self, event: HookEvent, handler: F)
693 where
694 F: Fn(HookEvent, &str, Option<&[Entity]>) + Send + Sync + 'static,
695 {
696 let wrapper = Box::new(
698 move |evt: HookEvent, ctx: &MiddlewareContext, entities: Option<&[Entity]>| {
699 handler(evt, &ctx.current_text, entities);
700 },
701 );
702 self.hooks.borrow_mut().register(event, wrapper);
703 }
704
705 pub fn on_with_context(&self, event: HookEvent, hook: HookFn) {
707 self.hooks.borrow_mut().register(event, hook);
708 }
709
710 pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
712 let mut ctx = MiddlewareContext::new(text);
713
714 {
716 let hooks = self.hooks.borrow();
717 hooks.trigger(HookEvent::BeforeExtraction, &ctx, None);
718 }
719
720 let mut current_text = ctx.current_text.clone();
722 for mw in &self.middleware {
723 let result = mw.pre_process(&mut ctx, ¤t_text)?;
724 current_text = result.into_owned();
725 }
726 ctx.current_text = current_text;
727
728 let entities = match self
730 .backend
731 .extract_entities(&ctx.current_text, ctx.language.as_deref())
732 {
733 Ok(entities) => entities,
734 Err(e) => {
735 let hooks = self.hooks.borrow();
737 hooks.trigger(HookEvent::OnError, &ctx, None);
738 return Err(e);
739 }
740 };
741
742 {
744 let hooks = self.hooks.borrow();
745 for entity in &entities {
746 hooks.trigger(
747 HookEvent::EntityFound,
748 &ctx,
749 Some(std::slice::from_ref(entity)),
750 );
751 }
752 }
753
754 let mut entities = entities;
756 for mw in self.middleware.iter().rev() {
757 entities = mw.post_process(&mut ctx, entities)?;
758 }
759
760 {
762 let hooks = self.hooks.borrow();
763 hooks.trigger(HookEvent::AfterExtraction, &ctx, Some(&entities));
764 }
765
766 Ok(entities)
767 }
768
769 #[must_use]
771 pub fn backend(&self) -> &dyn Model {
772 &*self.backend
773 }
774
775 #[must_use]
777 pub fn middleware_names(&self) -> Vec<&'static str> {
778 self.middleware.iter().map(|m| m.name()).collect()
779 }
780
781 #[must_use]
783 pub fn hook_count(&self) -> usize {
784 self.hooks.borrow().hooks.values().map(|v| v.len()).sum()
785 }
786}
787
788impl std::fmt::Debug for HookedPipeline {
789 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
790 f.debug_struct("HookedPipeline")
791 .field("middleware", &self.middleware_names())
792 .field("hooks", &*self.hooks.borrow())
793 .finish()
794 }
795}
796
/// Unit tests for the built-in middleware and both pipeline flavours.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HeuristicNER;

    /// Whitespace normalization trims the input and returns "hello world".
    #[test]
    fn test_normalize_whitespace() {
        let mut ctx = MiddlewareContext::new(" hello world ");
        let input = ctx.original_text.clone();
        let normalized = NormalizeWhitespace
            .pre_process(&mut ctx, &input)
            .expect("pre_process should succeed");
        assert_eq!(normalized, "hello world");
    }

    /// Entities below the threshold are dropped; the one above it survives.
    #[test]
    fn test_filter_by_confidence() {
        let mut ctx = MiddlewareContext::new("test");
        let candidates = vec![
            Entity::new("high", EntityType::Person, 0, 4, 0.8),
            Entity::new("low", EntityType::Person, 5, 8, 0.3),
        ];
        let kept = FilterByConfidence(0.5)
            .post_process(&mut ctx, candidates)
            .expect("post_process should succeed");
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].text, "high");
    }

    /// A pipeline with two middleware runs end to end without error.
    #[test]
    fn test_pipeline_basic() {
        let backend = Box::new(HeuristicNER::new());
        let pipeline = Pipeline::new(backend)
            .with(NormalizeWhitespace)
            .with(FilterByConfidence(0.3));

        let _entities = pipeline
            .extract("Hello World")
            .expect("extraction should succeed");
    }

    /// Of two overlapping spans, only the higher-confidence one is kept.
    #[test]
    fn test_remove_overlaps() {
        let mut ctx = MiddlewareContext::new("New York City");
        let candidates = vec![
            Entity::new("New York", EntityType::Location, 0, 8, 0.9),
            Entity::new("York City", EntityType::Location, 4, 13, 0.7),
        ];
        let kept = RemoveOverlaps
            .post_process(&mut ctx, candidates)
            .expect("post_process should succeed");
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].text, "New York");
    }

    /// Before/after hooks each fire exactly once per extraction.
    #[test]
    fn test_hooked_pipeline_basic() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let pipeline = HookedPipeline::new(Box::new(HeuristicNER::new())).with(NormalizeWhitespace);

        let before_hits = Arc::new(AtomicUsize::new(0));
        let after_hits = Arc::new(AtomicUsize::new(0));

        {
            let hits = Arc::clone(&before_hits);
            pipeline.on(HookEvent::BeforeExtraction, move |_, _, _| {
                hits.fetch_add(1, Ordering::SeqCst);
            });
        }
        {
            let hits = Arc::clone(&after_hits);
            pipeline.on(HookEvent::AfterExtraction, move |_, _, _| {
                hits.fetch_add(1, Ordering::SeqCst);
            });
        }

        let _entities = pipeline.extract("Hello World").unwrap();

        assert_eq!(before_hits.load(Ordering::SeqCst), 1);
        assert_eq!(after_hits.load(Ordering::SeqCst), 1);
    }

    /// The entity-found hook fires at least once for text containing names.
    #[test]
    fn test_hooked_pipeline_entity_found_hook() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let pipeline = HookedPipeline::new(Box::new(HeuristicNER::new()));

        let found_hits = Arc::new(AtomicUsize::new(0));
        let hits = Arc::clone(&found_hits);
        pipeline.on(HookEvent::EntityFound, move |_, _, entities| {
            if let Some(_found) = entities {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });

        let _entities = pipeline.extract("John Smith went to New York").unwrap();

        assert!(found_hits.load(Ordering::SeqCst) > 0);
    }

    /// A hooked pipeline with middleware runs end to end without error.
    #[test]
    fn test_hooked_pipeline_with_middleware() {
        let backend = Box::new(HeuristicNER::new());
        let pipeline = HookedPipeline::new(backend)
            .with(NormalizeWhitespace)
            .with(FilterByConfidence(0.3));

        let _entities = pipeline
            .extract(" John Smith ")
            .expect("extraction should succeed");
    }

    /// `hook_count` reflects the number of registered hooks across events.
    #[test]
    fn test_hooked_pipeline_hook_count() {
        let pipeline = HookedPipeline::new(Box::new(HeuristicNER::new()));
        assert_eq!(pipeline.hook_count(), 0);

        for event in [
            HookEvent::BeforeExtraction,
            HookEvent::AfterExtraction,
            HookEvent::EntityFound,
        ] {
            pipeline.on(event, |_, _, _| {});
        }

        assert_eq!(pipeline.hook_count(), 3);
    }
}