// anno/backends/streaming.rs
1//! Streaming NER API for incremental entity extraction.
2//!
3//! Provides iterator-based entity extraction for large documents,
4//! real-time text streams, or memory-constrained environments.
5//!
6//! # Overview
7//!
8//! Standard NER processes entire documents at once, which can be slow and
9//! memory-intensive for large texts. The streaming API offers:
10//!
11//! - **Chunked processing**: Split text into manageable chunks
12//! - **Iterator interface**: Lazily yield entities as they're found
13//! - **Backpressure**: Consumer controls the pace of extraction
14//! - **Stateful context**: Maintain context across chunk boundaries
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use anno::backends::streaming::{StreamingExtractor, ChunkConfig};
20//! use anno::StackedNER;
21//!
22//! let backend = StackedNER::default();
23//! let config = ChunkConfig::default();
24//!
25//! // Process large text in chunks
26//! let extractor = StreamingExtractor::new(&backend, config);
27//! for entity in extractor.extract("Very long text...") {
28//!     println!("Found: {} at {}-{}", entity.text, entity.start, entity.end);
29//! }
30//! ```
31//!
32//! # Pipeline Integration
33//!
34//! The streaming API integrates with async pipelines:
35//!
36//! ```rust,ignore
37//! use futures::StreamExt;
38//!
39//! let stream = extractor.extract_stream(text);
40//! while let Some(entity) = stream.next().await {
41//!     process(entity);
42//! }
43//! ```
44
45use crate::{Entity, Model, Result};
46
47// Semantic chunking integration pending
48// #[cfg(feature = "semantic-chunking")]
49// use crate::backends::semantic_chunking::{SemanticChunkConfig, SemanticChunker};
50
/// Tuning knobs for chunked text processing.
///
/// All sizes are measured in *characters*, not bytes.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Desired chunk length; actual chunks may end earlier at a sentence or
    /// word boundary.
    pub chunk_size: usize,
    /// How many trailing characters of one chunk are re-scanned at the start
    /// of the next, so entities straddling a cut are still caught.
    pub overlap: usize,
    /// When true, prefer ending chunks at sentence boundaries.
    pub respect_sentences: bool,
    /// Maximum entities buffered before they are yielded.
    pub buffer_size: usize,
}

impl Default for ChunkConfig {
    /// Balanced defaults for typical documents: 10k-char chunks, 100-char
    /// overlap, sentence-aware splitting, 1000-entity buffer.
    fn default() -> Self {
        Self {
            chunk_size: 10_000,
            overlap: 100,
            respect_sentences: true,
            buffer_size: 1000,
        }
    }
}

impl ChunkConfig {
    /// Process the whole input as a single chunk (no splitting, no overlap).
    pub fn no_chunking() -> Self {
        Self {
            chunk_size: usize::MAX,
            overlap: 0,
            respect_sentences: false,
            buffer_size: usize::MAX,
        }
    }

    /// Larger chunks and overlap, tuned for very long documents.
    pub fn long_document() -> Self {
        Self {
            chunk_size: 50_000,
            overlap: 200,
            buffer_size: 5000,
            ..Self::default()
        }
    }

    /// Small, low-latency chunks for real-time / streaming input.
    pub fn realtime() -> Self {
        Self {
            chunk_size: 1000,
            overlap: 50,
            respect_sentences: false,
            buffer_size: 100,
        }
    }
}
106
/// A streaming entity extractor that processes text in chunks.
///
/// Holds a borrowed NER model (`'m`) plus the chunking configuration; the
/// actual work happens lazily inside the iterator returned by `extract`.
#[derive(Debug)]
pub struct StreamingExtractor<'m, M: Model> {
    /// Borrowed backend that performs per-chunk extraction.
    model: &'m M,
    /// Chunk sizing / overlap / buffering settings.
    config: ChunkConfig,
}
113
114impl<'m, M: Model> StreamingExtractor<'m, M> {
115    /// Create a new streaming extractor with the given model and config.
116    pub fn new(model: &'m M, config: ChunkConfig) -> Self {
117        Self { model, config }
118    }
119
120    /// Create with default config.
121    pub fn with_model(model: &'m M) -> Self {
122        Self::new(model, ChunkConfig::default())
123    }
124
125    /// Extract entities from text, yielding them as an iterator.
126    pub fn extract<'t>(&'m self, text: &'t str) -> EntityIterator<'m, 't, M> {
127        EntityIterator::new(self, text)
128    }
129
130    /// Process a single chunk and return entities with adjusted offsets.
131    fn process_chunk(&self, chunk: &str, offset: usize) -> Result<Vec<Entity>> {
132        let entities = self.model.extract_entities(chunk, None)?;
133
134        // Adjust offsets to be relative to original text
135        Ok(entities
136            .into_iter()
137            .map(|mut e| {
138                e.start += offset;
139                e.end += offset;
140                e
141            })
142            .collect())
143    }
144}
145
146/// Iterator over entities extracted from text.
147pub struct EntityIterator<'m, 't, M: Model> {
148    extractor: &'m StreamingExtractor<'m, M>,
149    text: &'t str,
150    /// Current position in text (character offset)
151    position: usize,
152    /// Buffer of entities from current chunk
153    buffer: Vec<Entity>,
154    /// Index into buffer
155    buffer_idx: usize,
156    /// Set of (start, end) pairs already yielded (for deduplication)
157    seen: std::collections::HashSet<(usize, usize)>,
158    /// Whether we've finished processing
159    done: bool,
160}
161
162impl<'m, 't, M: Model> EntityIterator<'m, 't, M> {
163    fn new(extractor: &'m StreamingExtractor<'m, M>, text: &'t str) -> Self {
164        Self {
165            extractor,
166            text,
167            position: 0,
168            buffer: Vec::new(),
169            buffer_idx: 0,
170            seen: std::collections::HashSet::new(),
171            done: false,
172        }
173    }
174
175    /// Fill the buffer with entities from the next chunk.
176    fn fill_buffer(&mut self) -> Result<()> {
177        if self.done {
178            return Ok(());
179        }
180
181        let text_chars: Vec<char> = self.text.chars().collect();
182        let text_len = text_chars.len();
183
184        if self.position >= text_len {
185            self.done = true;
186            return Ok(());
187        }
188
189        // Calculate chunk boundaries
190        let chunk_end = (self.position + self.extractor.config.chunk_size).min(text_len);
191
192        // Find a good break point (sentence boundary or word boundary)
193        let actual_end = if self.extractor.config.respect_sentences {
194            find_sentence_boundary(&text_chars, self.position, chunk_end)
195        } else {
196            find_word_boundary(&text_chars, chunk_end)
197        };
198
199        // Extract the chunk
200        let chunk: String = text_chars[self.position..actual_end].iter().collect();
201
202        // Process chunk
203        let entities = self.extractor.process_chunk(&chunk, self.position)?;
204
205        // Filter out entities we've already seen (from overlap regions)
206        self.buffer = entities
207            .into_iter()
208            .filter(|e| !self.seen.contains(&(e.start, e.end)))
209            .collect();
210
211        // Mark these entities as seen
212        for e in &self.buffer {
213            self.seen.insert((e.start, e.end));
214        }
215
216        self.buffer_idx = 0;
217
218        // Move position forward (with overlap for next chunk)
219        // CRITICAL: Always ensure we make forward progress to avoid infinite loops
220        let overlap = self.extractor.config.overlap;
221        let new_position = if actual_end >= text_len {
222            text_len
223        } else {
224            // Ensure we always advance by at least 1 character
225            let overlap_position = actual_end.saturating_sub(overlap);
226            // If overlap would cause us to not advance, force forward progress
227            if overlap_position <= self.position {
228                self.position + 1
229            } else {
230                overlap_position
231            }
232        };
233
234        self.position = new_position;
235
236        if actual_end >= text_len || self.position >= text_len {
237            self.done = true;
238        }
239
240        Ok(())
241    }
242}
243
244impl<'m, 't, M: Model> Iterator for EntityIterator<'m, 't, M> {
245    type Item = Entity;
246
247    fn next(&mut self) -> Option<Self::Item> {
248        loop {
249            // Return from buffer if available
250            if self.buffer_idx < self.buffer.len() {
251                let entity = self.buffer[self.buffer_idx].clone();
252                self.buffer_idx += 1;
253                return Some(entity);
254            }
255
256            // Buffer empty, try to fill it
257            if self.done {
258                return None;
259            }
260
261            if self.fill_buffer().is_err() {
262                self.done = true;
263                return None;
264            }
265
266            // If buffer is still empty after fill, we're done
267            if self.buffer.is_empty() && self.done {
268                return None;
269            }
270        }
271    }
272}
273
/// Find a sentence boundary at or before `target`, searching backwards.
///
/// Scans up to 200 characters back from `target` for sentence-ending
/// punctuation and returns the position just past the terminator and any
/// following whitespace. ASCII terminators (`.`, `!`, `?`) must be followed
/// by whitespace or end-of-text so that "3.14" or "e.g." don't split; CJK
/// terminators (`。`, `!`, `?`) end a sentence on their own, since CJK
/// prose does not put spaces after them. Falls back to a word boundary when
/// no sentence end past `start` is found.
fn find_sentence_boundary(chars: &[char], start: usize, target: usize) -> usize {
    let search_start = target.saturating_sub(200);
    for i in (search_start..target).rev() {
        if i >= chars.len() {
            continue;
        }
        let c = chars[i];
        // CJK full stops terminate a sentence with no trailing space;
        // requiring whitespace after them (as the previous code did) made
        // them undetectable in normal CJK text.
        let is_boundary = matches!(c, '。' | '!' | '?')
            || (matches!(c, '.' | '!' | '?')
                && (i + 1 >= chars.len() || chars[i + 1].is_whitespace()));
        if is_boundary {
            // Step past the terminator and any whitespace that follows it.
            let mut end = i + 1;
            while end < chars.len() && chars[end].is_whitespace() {
                end += 1;
            }
            // Only accept a boundary that actually advances past `start`.
            if end > start {
                return end;
            }
        }
    }
    // No usable sentence boundary; settle for a word boundary.
    find_word_boundary(chars, target)
}

/// Find a word boundary at or before `target`.
///
/// Returns the index just past the last whitespace character before
/// `target`, or `target` itself (clamped to the text length) when the
/// prefix contains no whitespace at all.
fn find_word_boundary(chars: &[char], target: usize) -> usize {
    let target = target.min(chars.len());

    // Already at (or past) the end: the end of text is the boundary.
    if target == chars.len() {
        return target;
    }

    // Look backwards for the last whitespace before `target`.
    match chars[..target].iter().rposition(|c| c.is_whitespace()) {
        Some(ws) => ws + 1,
        None => target,
    }
}
318
319// =============================================================================
320// Async Stream Support (requires tokio/async-std)
321// =============================================================================
322
#[cfg(feature = "async-inference")]
pub mod async_stream {
    // Async adapter: wraps the synchronous EntityIterator in a
    // `futures::Stream`. Note that `stream::iter` polls the underlying
    // iterator inline — each `next().await` performs the chunk work on the
    // current task rather than on a background executor.
    use super::*;
    use futures::stream::{self, Stream};

    impl<'m, M: Model + Sync> StreamingExtractor<'m, M> {
        /// Create an async stream of entities.
        ///
        /// The text borrow must outlive the extractor borrow (`'t: 'm`) so
        /// the returned stream can carry both for `'m`.
        pub fn extract_stream<'t>(&'m self, text: &'t str) -> impl Stream<Item = Entity> + 'm
        where
            't: 'm,
        {
            let iter = self.extract(text);
            stream::iter(iter)
        }
    }
}
339
340// =============================================================================
341// Pipeline Integration Hooks
342// =============================================================================
343
/// A processing stage in an NER pipeline.
///
/// Implementors transform the entity list after extraction (filtering,
/// deduplication, normalization, ...). Stages run in the order they were
/// added to the pipeline; `Send + Sync` allows pipelines to be shared
/// across threads.
pub trait PipelineStage: Send + Sync {
    /// Process entities before they're returned.
    ///
    /// `text` is the full original input, available for context.
    fn process(&self, entities: Vec<Entity>, text: &str) -> Vec<Entity>;

    /// Name of this stage (for debugging/logging).
    fn name(&self) -> &'static str;
}
352
/// A complete NER pipeline with preprocessing and postprocessing stages.
///
/// NOTE(review): `chunk_config` is stored but not consulted by `extract`,
/// which processes the text in one shot — presumably intended for a future
/// streaming entry point; confirm before relying on chunked behavior here.
pub struct Pipeline<M: Model> {
    /// Owned extraction backend.
    model: M,
    /// Stages that run after entity extraction
    post_stages: Vec<Box<dyn PipelineStage>>,
    /// Chunk configuration for streaming
    chunk_config: ChunkConfig,
}
361
362impl<M: Model> Pipeline<M> {
363    /// Create a new pipeline with the given model.
364    pub fn new(model: M) -> Self {
365        Self {
366            model,
367            post_stages: Vec::new(),
368            chunk_config: ChunkConfig::default(),
369        }
370    }
371
372    /// Add a post-processing stage.
373    pub fn add_stage(mut self, stage: Box<dyn PipelineStage>) -> Self {
374        self.post_stages.push(stage);
375        self
376    }
377
378    /// Set chunk configuration.
379    pub fn with_chunk_config(mut self, config: ChunkConfig) -> Self {
380        self.chunk_config = config;
381        self
382    }
383
384    /// Extract entities with all pipeline stages applied.
385    pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
386        let mut entities = self.model.extract_entities(text, None)?;
387
388        for stage in &self.post_stages {
389            entities = stage.process(entities, text);
390        }
391
392        Ok(entities)
393    }
394
395    /// Get a reference to the underlying model.
396    pub fn model(&self) -> &M {
397        &self.model
398    }
399}
400
401// =============================================================================
402// Common Pipeline Stages
403// =============================================================================
404
405/// Filter entities by confidence threshold.
406pub struct ConfidenceFilter {
407    threshold: f64,
408}
409
410impl ConfidenceFilter {
411    /// Create a new confidence filter with the given threshold.
412    pub fn new(threshold: f64) -> Self {
413        Self { threshold }
414    }
415}
416
417impl PipelineStage for ConfidenceFilter {
418    fn process(&self, entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
419        entities
420            .into_iter()
421            .filter(|e| e.confidence >= self.threshold)
422            .collect()
423    }
424
425    fn name(&self) -> &'static str {
426        "ConfidenceFilter"
427    }
428}
429
430/// Deduplicate overlapping entities, keeping highest confidence.
431pub struct DeduplicateOverlapping;
432
433impl PipelineStage for DeduplicateOverlapping {
434    fn process(&self, mut entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
435        // Sort by start, then by confidence (desc)
436        entities.sort_by(|a, b| {
437            a.start.cmp(&b.start).then(
438                b.confidence
439                    .partial_cmp(&a.confidence)
440                    .expect("confidence values should be comparable"),
441            )
442        });
443
444        let mut result = Vec::new();
445        let mut last_end = 0;
446
447        for entity in entities {
448            if entity.start >= last_end {
449                last_end = entity.end;
450                result.push(entity);
451            }
452            // Skip overlapping entities (we already have a higher-confidence one)
453        }
454
455        result
456    }
457
458    fn name(&self) -> &'static str {
459        "DeduplicateOverlapping"
460    }
461}
462
463/// Normalize entity text (trim whitespace, normalize case, etc.).
464pub struct NormalizeText {
465    lowercase: bool,
466}
467
468impl NormalizeText {
469    /// Create a new text normalizer with optional lowercasing.
470    pub fn new(lowercase: bool) -> Self {
471        Self { lowercase }
472    }
473}
474
475impl PipelineStage for NormalizeText {
476    fn process(&self, entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
477        entities
478            .into_iter()
479            .map(|mut e| {
480                e.text = e.text.trim().to_string();
481                if self.lowercase {
482                    e.text = e.text.to_lowercase();
483                }
484                e
485            })
486            .collect()
487    }
488
489    fn name(&self) -> &'static str {
490        "NormalizeText"
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HeuristicNER;

    // Smoke test: default config on a short sentence yields at least one
    // entity from the heuristic backend.
    #[test]
    fn test_streaming_basic() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let text = "John Smith works at Google Inc. in New York.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        assert!(!entities.is_empty());
    }

    // Chunked extraction (50-char chunks, 10-char overlap) should still find
    // entities even though the text is split across several chunks.
    #[test]
    fn test_streaming_long_text() {
        let model = HeuristicNER::new();
        let config = ChunkConfig {
            chunk_size: 50,
            overlap: 10,
            respect_sentences: false,
            buffer_size: 100,
        };
        let extractor = StreamingExtractor::new(&model, config);

        // Create a longer text
        let text =
            "John Smith works at Google. Mary Johnson is at Apple. Bob Williams joined Microsoft.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        // Should find entities across chunks
        assert!(!entities.is_empty());
    }

    // Post-stages run in registration order: confidence filter, then dedup.
    #[test]
    fn test_pipeline() {
        let model = HeuristicNER::new();
        let pipeline = Pipeline::new(model)
            .add_stage(Box::new(ConfidenceFilter::new(0.5)))
            .add_stage(Box::new(DeduplicateOverlapping));

        let text = "John Smith works at Google Inc.";
        let entities = pipeline.extract(text).unwrap();

        // All entities should have confidence >= 0.5
        for entity in &entities {
            assert!(entity.confidence >= 0.5);
        }
    }

    // Presets must at least construct without panicking.
    #[test]
    fn test_chunk_config_presets() {
        let _no_chunk = ChunkConfig::no_chunking();
        let _long = ChunkConfig::long_document();
        let _realtime = ChunkConfig::realtime();
    }

    // Boundary must land inside (0, target] for text containing ". ".
    #[test]
    fn test_find_sentence_boundary() {
        let text: Vec<char> = "Hello world. This is a test.".chars().collect();
        let boundary = find_sentence_boundary(&text, 0, 20);
        // Should find boundary after "Hello world. "
        assert!(boundary > 0);
        assert!(boundary <= 20);
    }

    // Entities in the overlap region must not be yielded twice, and the
    // extractor must terminate (forward-progress guarantee).
    #[test]
    fn test_entity_deduplication_across_chunks() {
        // When an entity appears in the overlap region between chunks,
        // it should be deduplicated (seen set should prevent duplicates)
        let model = HeuristicNER::new();

        // Use reasonable chunks with small overlap (avoid infinite loop edge cases)
        let config = ChunkConfig {
            chunk_size: 100,
            overlap: 20,
            respect_sentences: false,
            buffer_size: 100,
        };
        let extractor = StreamingExtractor::new(&model, config);

        let text = "I work at Google Inc in California. Then I visited Google headquarters.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        // Should find entities without infinite loops
        // (the fix ensures forward progress)
        assert!(
            entities.len() < 100,
            "Possible infinite loop: too many entities"
        );
    }

    // Empty input must yield an empty iterator, not hang or panic.
    #[test]
    fn test_empty_text_streaming() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let entities: Vec<Entity> = extractor.extract("").collect();
        assert!(entities.is_empty());
    }

    // Offsets are character-based, so spans on multi-byte scripts must stay
    // within the character count (not the byte length).
    #[test]
    fn test_unicode_text_streaming() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let text = "東京 is the capital of 日本. Paris is in France.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        // Character offsets should be valid
        let char_count = text.chars().count();
        for entity in &entities {
            assert!(entity.start <= entity.end, "Invalid span");
            assert!(entity.end <= char_count, "Offset exceeds text length");
        }
    }

    // Overlap >= chunk advance would loop forever without the forced
    // position+1 step in fill_buffer; this must terminate.
    #[test]
    fn test_forward_progress_guaranteed() {
        // Test that streaming always makes forward progress even with small chunks
        let model = HeuristicNER::new();

        let config = ChunkConfig {
            chunk_size: 5, // Very small chunks
            overlap: 3,    // Large overlap relative to chunk
            respect_sentences: false,
            buffer_size: 10,
        };
        let extractor = StreamingExtractor::new(&model, config);

        // Short text that could cause infinite loop without the fix
        let text = "abc def";

        // Should complete without hanging (the fix ensures forward progress)
        let entities: Vec<Entity> = extractor.extract(text).collect();
        // We don't care about the results, just that it terminates
        let _ = entities;
    }
}