// anno/backends/streaming.rs
1//! Streaming NER API for incremental entity extraction.
2//!
3//! Provides iterator-based entity extraction for large documents,
4//! real-time text streams, or memory-constrained environments.
5//!
6//! # Overview
7//!
8//! Standard NER processes entire documents at once, which can be slow and
9//! memory-intensive for large texts. The streaming API offers:
10//!
11//! - **Chunked processing**: Split text into manageable chunks
12//! - **Iterator interface**: Lazily yield entities as they're found
13//! - **Backpressure**: Consumer controls the pace of extraction
14//! - **Stateful context**: Maintain context across chunk boundaries
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use anno::backends::streaming::{StreamingExtractor, ChunkConfig};
20//! use anno::StackedNER;
21//!
22//! let backend = StackedNER::default();
23//! let config = ChunkConfig::default();
24//!
25//! // Process large text in chunks
26//! let extractor = StreamingExtractor::new(&backend, config);
27//! for entity in extractor.extract("Very long text...") {
28//!     println!("Found: {} at {}-{}", entity.text, entity.start, entity.end);
29//! }
30//! ```
31//!
32//! # Pipeline Integration
33//!
34//! The streaming API integrates with async pipelines:
35//!
36//! ```rust,ignore
37//! use futures::StreamExt;
38//!
//! let mut stream = extractor.extract_stream(text);
40//! while let Some(entity) = stream.next().await {
41//!     process(entity);
42//! }
43//! ```
44
45use crate::{Entity, Model, Result};
46
47// Semantic chunking integration pending
48// #[cfg(feature = "semantic-chunking")]
49// use crate::backends::semantic_chunking::{SemanticChunkConfig, SemanticChunker};
50
/// Configuration for chunked text processing.
///
/// Controls how input text is split before extraction: target chunk
/// length, inter-chunk overlap (so entities straddling a boundary get a
/// second chance in the next chunk), sentence-aware splitting, and how
/// many entities to buffer per chunk.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Target chunk size in characters (actual may vary to avoid splitting words)
    pub chunk_size: usize,
    /// Overlap between chunks (characters) to catch entities at boundaries
    pub overlap: usize,
    /// Sentence boundary detection (if true, chunks end at sentence boundaries)
    pub respect_sentences: bool,
    /// Maximum entities to buffer before yielding
    pub buffer_size: usize,
}
63
64impl Default for ChunkConfig {
65    fn default() -> Self {
66        Self {
67            chunk_size: 10_000,
68            overlap: 100,
69            respect_sentences: true,
70            buffer_size: 1000,
71        }
72    }
73}
74
75impl ChunkConfig {
76    /// Create a config for small documents (no chunking).
77    pub fn no_chunking() -> Self {
78        Self {
79            chunk_size: usize::MAX,
80            overlap: 0,
81            respect_sentences: false,
82            buffer_size: usize::MAX,
83        }
84    }
85
86    /// Create a config optimized for long documents.
87    pub fn long_document() -> Self {
88        Self {
89            chunk_size: 50_000,
90            overlap: 200,
91            respect_sentences: true,
92            buffer_size: 5000,
93        }
94    }
95
96    /// Create a config for real-time/streaming input.
97    pub fn realtime() -> Self {
98        Self {
99            chunk_size: 1000,
100            overlap: 50,
101            respect_sentences: false,
102            buffer_size: 100,
103        }
104    }
105}
106
/// A streaming entity extractor that processes text in chunks.
///
/// Borrows a [`Model`] and applies it chunk-by-chunk according to the
/// configured [`ChunkConfig`], yielding entities lazily through
/// [`EntityIterator`].
#[derive(Debug)]
pub struct StreamingExtractor<'m, M: Model> {
    // Borrowed NER backend; must outlive the extractor.
    model: &'m M,
    // Chunking parameters applied when splitting input text.
    config: ChunkConfig,
}
113
114impl<'m, M: Model> StreamingExtractor<'m, M> {
115    /// Create a new streaming extractor with the given model and config.
116    pub fn new(model: &'m M, config: ChunkConfig) -> Self {
117        Self { model, config }
118    }
119
120    /// Create with default config.
121    pub fn with_model(model: &'m M) -> Self {
122        Self::new(model, ChunkConfig::default())
123    }
124
125    /// Extract entities from text, yielding them as an iterator.
126    pub fn extract<'t>(&'m self, text: &'t str) -> EntityIterator<'m, 't, M> {
127        EntityIterator::new(self, text)
128    }
129
130    /// Process a single chunk and return entities with adjusted offsets.
131    fn process_chunk(&self, chunk: &str, offset: usize) -> Result<Vec<Entity>> {
132        let entities = self.model.extract_entities(chunk, None)?;
133
134        // Adjust offsets to be relative to original text
135        Ok(entities
136            .into_iter()
137            .map(|mut e| {
138                e.start += offset;
139                e.end += offset;
140                e
141            })
142            .collect())
143    }
144}
145
/// Iterator over entities extracted from text.
///
/// Produced by [`StreamingExtractor::extract`]. Processes one chunk at a
/// time, buffering that chunk's entities and yielding them one by one;
/// spans already yielded from an overlap region are suppressed.
pub struct EntityIterator<'m, 't, M: Model> {
    extractor: &'m StreamingExtractor<'m, M>,
    text: &'t str,
    /// Current position in text (character offset)
    position: usize,
    /// Buffer of entities from current chunk
    buffer: Vec<Entity>,
    /// Index into buffer
    buffer_idx: usize,
    /// Set of (start, end) pairs already yielded (for deduplication)
    /// NOTE: keyed by span only, so two entities sharing a span but
    /// differing in type would be deduplicated together.
    seen: std::collections::HashSet<(usize, usize)>,
    /// Whether we've finished processing
    done: bool,
}
161
162impl<'m, 't, M: Model> EntityIterator<'m, 't, M> {
163    fn new(extractor: &'m StreamingExtractor<'m, M>, text: &'t str) -> Self {
164        Self {
165            extractor,
166            text,
167            position: 0,
168            buffer: Vec::new(),
169            buffer_idx: 0,
170            seen: std::collections::HashSet::new(),
171            done: false,
172        }
173    }
174
175    /// Fill the buffer with entities from the next chunk.
176    fn fill_buffer(&mut self) -> Result<()> {
177        if self.done {
178            return Ok(());
179        }
180
181        let text_chars: Vec<char> = self.text.chars().collect();
182        let text_len = text_chars.len();
183
184        if self.position >= text_len {
185            self.done = true;
186            return Ok(());
187        }
188
189        // Calculate chunk boundaries
190        let chunk_end = (self.position + self.extractor.config.chunk_size).min(text_len);
191
192        // Find a good break point (sentence boundary or word boundary)
193        let actual_end = if self.extractor.config.respect_sentences {
194            find_sentence_boundary(&text_chars, self.position, chunk_end)
195        } else {
196            find_word_boundary(&text_chars, chunk_end)
197        };
198
199        // Extract the chunk
200        let chunk: String = text_chars[self.position..actual_end].iter().collect();
201
202        // Process chunk
203        let entities = self.extractor.process_chunk(&chunk, self.position)?;
204
205        // Filter out entities we've already seen (from overlap regions)
206        self.buffer = entities
207            .into_iter()
208            .filter(|e| !self.seen.contains(&(e.start, e.end)))
209            .collect();
210
211        // Mark these entities as seen
212        for e in &self.buffer {
213            self.seen.insert((e.start, e.end));
214        }
215
216        self.buffer_idx = 0;
217
218        // Move position forward (with overlap for next chunk)
219        // CRITICAL: Always ensure we make forward progress to avoid infinite loops
220        let overlap = self.extractor.config.overlap;
221        let new_position = if actual_end >= text_len {
222            text_len
223        } else {
224            // Ensure we always advance by at least 1 character
225            let overlap_position = actual_end.saturating_sub(overlap);
226            // If overlap would cause us to not advance, force forward progress
227            if overlap_position <= self.position {
228                self.position + 1
229            } else {
230                overlap_position
231            }
232        };
233
234        self.position = new_position;
235
236        if actual_end >= text_len || self.position >= text_len {
237            self.done = true;
238        }
239
240        Ok(())
241    }
242}
243
244impl<'m, 't, M: Model> Iterator for EntityIterator<'m, 't, M> {
245    type Item = Entity;
246
247    fn next(&mut self) -> Option<Self::Item> {
248        loop {
249            // Return from buffer if available
250            if self.buffer_idx < self.buffer.len() {
251                let entity = self.buffer[self.buffer_idx].clone();
252                self.buffer_idx += 1;
253                return Some(entity);
254            }
255
256            // Buffer empty, try to fill it
257            if self.done {
258                return None;
259            }
260
261            if self.fill_buffer().is_err() {
262                self.done = true;
263                return None;
264            }
265
266            // If buffer is still empty after fill, we're done
267            if self.buffer.is_empty() && self.done {
268                return None;
269            }
270        }
271    }
272}
273
/// Find a sentence boundary at or before `target`.
///
/// Scans backwards (at most 200 characters) for sentence-ending
/// punctuation followed by whitespace or end-of-text, and returns the
/// index just past that punctuation and any trailing whitespace — but only
/// if it lies strictly after `start`. Falls back to
/// [`find_word_boundary`] when no usable sentence boundary exists.
fn find_sentence_boundary(chars: &[char], start: usize, target: usize) -> usize {
    // Look backwards from target for sentence-ending punctuation.
    let search_start = target.saturating_sub(200);
    for i in (search_start..target).rev() {
        if i >= chars.len() {
            continue;
        }
        let c = chars[i];
        // Sentence enders: ASCII . ! ? plus fullwidth 。 ！ ？.
        // FIX: the original listed ASCII '!' and '?' twice where the
        // fullwidth forms were clearly intended (encoding loss), so CJK
        // text punctuated with ！/？ never matched.
        if matches!(c, '.' | '!' | '?' | '。' | '！' | '？')
            && (i + 1 >= chars.len() || chars[i + 1].is_whitespace())
        {
            // Return position after the punctuation and any whitespace.
            let mut end = i + 1;
            while end < chars.len() && chars[end].is_whitespace() {
                end += 1;
            }
            // Only accept a boundary that actually advances past `start`.
            if end > start {
                return end;
            }
        }
    }
    // No sentence boundary found, fall back to word boundary.
    find_word_boundary(chars, target)
}

/// Find a word boundary at or before `target`.
///
/// Returns the index just after the last whitespace character before
/// `target`, or `target` itself when no whitespace precedes it (callers
/// must tolerate a mid-word split in that case). A `target` past the end
/// of `chars` is clamped to `chars.len()`.
fn find_word_boundary(chars: &[char], target: usize) -> usize {
    let target = target.min(chars.len());

    // Already at the end: nothing to adjust.
    if target >= chars.len() {
        return chars.len();
    }

    // Step back to just after the nearest preceding whitespace.
    match chars[..target].iter().rposition(|c| c.is_whitespace()) {
        Some(ws) => ws + 1,
        None => target,
    }
}
318
319// =============================================================================
320// Async Stream Support (requires tokio/async-std)
321// =============================================================================
322
/// Async streaming adapters for `StreamingExtractor`.
///
/// Only compiled with the `production` feature, which brings in `futures`.
#[cfg(feature = "production")]
pub mod async_stream {
    use super::*;
    use futures::stream::{self, Stream};

    impl<'m, M: Model + Sync> StreamingExtractor<'m, M> {
        /// Create an async stream of entities.
        ///
        /// Wraps the synchronous [`EntityIterator`] in `stream::iter`, so
        /// chunk processing still runs synchronously on the polling task —
        /// it is not offloaded to a blocking thread.
        pub fn extract_stream<'t>(&'m self, text: &'t str) -> impl Stream<Item = Entity> + 'm
        where
            't: 'm,
        {
            let iter = self.extract(text);
            stream::iter(iter)
        }
    }
}
340
341// =============================================================================
342// Pipeline Integration Hooks
343// =============================================================================
344
/// A processing stage in an NER pipeline.
///
/// Stages run after extraction and may filter, rewrite, or reorder the
/// entity list. `text` is the full original input, available for context.
pub trait PipelineStage: Send + Sync {
    /// Process entities before they're returned.
    fn process(&self, entities: Vec<Entity>, text: &str) -> Vec<Entity>;

    /// Name of this stage (for debugging/logging).
    fn name(&self) -> &'static str;
}
353
/// A complete NER pipeline with preprocessing and postprocessing stages.
pub struct Pipeline<M: Model> {
    model: M,
    /// Stages that run after entity extraction
    post_stages: Vec<Box<dyn PipelineStage>>,
    /// Chunk configuration for streaming
    /// NOTE(review): set via `with_chunk_config` but never read by
    /// `extract`, which runs the model over the whole text in one pass —
    /// confirm whether a streaming extraction path was intended.
    chunk_config: ChunkConfig,
}
362
363impl<M: Model> Pipeline<M> {
364    /// Create a new pipeline with the given model.
365    pub fn new(model: M) -> Self {
366        Self {
367            model,
368            post_stages: Vec::new(),
369            chunk_config: ChunkConfig::default(),
370        }
371    }
372
373    /// Add a post-processing stage.
374    pub fn add_stage(mut self, stage: Box<dyn PipelineStage>) -> Self {
375        self.post_stages.push(stage);
376        self
377    }
378
379    /// Set chunk configuration.
380    pub fn with_chunk_config(mut self, config: ChunkConfig) -> Self {
381        self.chunk_config = config;
382        self
383    }
384
385    /// Extract entities with all pipeline stages applied.
386    pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
387        let mut entities = self.model.extract_entities(text, None)?;
388
389        for stage in &self.post_stages {
390            entities = stage.process(entities, text);
391        }
392
393        Ok(entities)
394    }
395
396    /// Get a reference to the underlying model.
397    pub fn model(&self) -> &M {
398        &self.model
399    }
400}
401
402// =============================================================================
403// Common Pipeline Stages
404// =============================================================================
405
/// Filter entities by confidence threshold.
///
/// Entities whose `confidence` falls below the threshold are dropped.
pub struct ConfidenceFilter {
    // Minimum confidence (inclusive) an entity needs to survive.
    threshold: f64,
}

impl ConfidenceFilter {
    /// Create a new confidence filter with the given threshold.
    pub fn new(threshold: f64) -> Self {
        ConfidenceFilter { threshold }
    }
}
417
418impl PipelineStage for ConfidenceFilter {
419    fn process(&self, entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
420        entities
421            .into_iter()
422            .filter(|e| e.confidence >= self.threshold)
423            .collect()
424    }
425
426    fn name(&self) -> &'static str {
427        "ConfidenceFilter"
428    }
429}
430
431/// Deduplicate overlapping entities, keeping highest confidence.
432pub struct DeduplicateOverlapping;
433
434impl PipelineStage for DeduplicateOverlapping {
435    fn process(&self, mut entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
436        // Sort by start, then by confidence (desc)
437        entities.sort_by(|a, b| {
438            a.start.cmp(&b.start).then(
439                b.confidence
440                    .partial_cmp(&a.confidence)
441                    .expect("confidence values should be comparable"),
442            )
443        });
444
445        let mut result = Vec::new();
446        let mut last_end = 0;
447
448        for entity in entities {
449            if entity.start >= last_end {
450                last_end = entity.end;
451                result.push(entity);
452            }
453            // Skip overlapping entities (we already have a higher-confidence one)
454        }
455
456        result
457    }
458
459    fn name(&self) -> &'static str {
460        "DeduplicateOverlapping"
461    }
462}
463
/// Normalize entity text (trim whitespace, optionally lowercase).
pub struct NormalizeText {
    // When true, entity text is also lowercased after trimming.
    lowercase: bool,
}

impl NormalizeText {
    /// Create a new text normalizer with optional lowercasing.
    pub fn new(lowercase: bool) -> Self {
        NormalizeText { lowercase }
    }
}
475
476impl PipelineStage for NormalizeText {
477    fn process(&self, entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
478        entities
479            .into_iter()
480            .map(|mut e| {
481                e.text = e.text.trim().to_string();
482                if self.lowercase {
483                    e.text = e.text.to_lowercase();
484                }
485                e
486            })
487            .collect()
488    }
489
490    fn name(&self) -> &'static str {
491        "NormalizeText"
492    }
493}
494
495#[cfg(test)]
496mod tests {
497    use super::*;
498    use crate::HeuristicNER;
499
    #[test]
    fn test_streaming_basic() {
        // Smoke test: default config on a short text yields some entities.
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let text = "John Smith works at Google Inc. in New York.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        assert!(!entities.is_empty());
    }
510
    #[test]
    fn test_streaming_long_text() {
        // Small chunk size forces the text to be split across several
        // chunks; entities should still be found.
        let model = HeuristicNER::new();
        let config = ChunkConfig {
            chunk_size: 50,
            overlap: 10,
            respect_sentences: false,
            buffer_size: 100,
        };
        let extractor = StreamingExtractor::new(&model, config);

        // Create a longer text
        let text =
            "John Smith works at Google. Mary Johnson is at Apple. Bob Williams joined Microsoft.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        // Should find entities across chunks
        assert!(!entities.is_empty());
    }
530
    #[test]
    fn test_pipeline() {
        // Pipeline stages run in order: filter by confidence, then dedup.
        let model = HeuristicNER::new();
        let pipeline = Pipeline::new(model)
            .add_stage(Box::new(ConfidenceFilter::new(0.5)))
            .add_stage(Box::new(DeduplicateOverlapping));

        let text = "John Smith works at Google Inc.";
        let entities = pipeline.extract(text).unwrap();

        // All entities should have confidence >= 0.5
        for entity in &entities {
            assert!(entity.confidence >= 0.5);
        }
    }
546
    #[test]
    fn test_chunk_config_presets() {
        // Presets should construct without panicking.
        let _no_chunk = ChunkConfig::no_chunking();
        let _long = ChunkConfig::long_document();
        let _realtime = ChunkConfig::realtime();
    }
553
    #[test]
    fn test_find_sentence_boundary() {
        let text: Vec<char> = "Hello world. This is a test.".chars().collect();
        let boundary = find_sentence_boundary(&text, 0, 20);
        // Should find boundary after "Hello world. " (i.e. index 13).
        assert!(boundary > 0);
        assert!(boundary <= 20);
    }
562
    #[test]
    fn test_entity_deduplication_across_chunks() {
        // When an entity appears in the overlap region between chunks,
        // it should be deduplicated (seen set should prevent duplicates)
        let model = HeuristicNER::new();

        // Use reasonable chunks with small overlap (avoid infinite loop edge cases)
        let config = ChunkConfig {
            chunk_size: 100,
            overlap: 20,
            respect_sentences: false,
            buffer_size: 100,
        };
        let extractor = StreamingExtractor::new(&model, config);

        let text = "I work at Google Inc in California. Then I visited Google headquarters.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        // Should find entities without infinite loops
        // (the fix ensures forward progress)
        assert!(
            entities.len() < 100,
            "Possible infinite loop: too many entities"
        );
    }
588
    #[test]
    fn test_empty_text_streaming() {
        // Empty input should terminate immediately with no entities.
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let entities: Vec<Entity> = extractor.extract("").collect();
        assert!(entities.is_empty());
    }
597
    #[test]
    fn test_unicode_text_streaming() {
        // Mixed-script input: spans must stay within the character count
        // (this pins offsets as character offsets, not byte offsets).
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let text = "東京 is the capital of 日本. Paris is in France.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        // Character offsets should be valid
        let char_count = text.chars().count();
        for entity in &entities {
            assert!(entity.start <= entity.end, "Invalid span");
            assert!(entity.end <= char_count, "Offset exceeds text length");
        }
    }
613
    #[test]
    fn test_forward_progress_guaranteed() {
        // Test that streaming always makes forward progress even with small chunks
        let model = HeuristicNER::new();

        let config = ChunkConfig {
            chunk_size: 5, // Very small chunks
            overlap: 3,    // Large overlap relative to chunk
            respect_sentences: false,
            buffer_size: 10,
        };
        let extractor = StreamingExtractor::new(&model, config);

        // Short text that could cause infinite loop without the fix
        let text = "abc def";

        // Should complete without hanging (the fix ensures forward progress)
        let entities: Vec<Entity> = extractor.extract(text).collect();
        // We don't care about the results, just that it terminates
        let _ = entities;
634    }
635}