1use crate::{Entity, Model, Result};
46
/// Controls how input text is split into chunks for streaming extraction.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Maximum number of characters per chunk.
    pub chunk_size: usize,
    /// Number of characters shared between consecutive chunks.
    pub overlap: usize,
    /// When true, chunk boundaries are pulled back to sentence ends.
    pub respect_sentences: bool,
    /// Number of entities buffered before being yielded.
    pub buffer_size: usize,
}

impl Default for ChunkConfig {
    /// Balanced settings suitable for most documents.
    fn default() -> Self {
        ChunkConfig {
            chunk_size: 10_000,
            overlap: 100,
            respect_sentences: true,
            buffer_size: 1000,
        }
    }
}

impl ChunkConfig {
    /// Treats the entire input as one chunk (no splitting at all).
    pub fn no_chunking() -> Self {
        ChunkConfig {
            chunk_size: usize::MAX,
            overlap: 0,
            respect_sentences: false,
            buffer_size: usize::MAX,
        }
    }

    /// Larger chunks and overlap, tuned for long documents.
    pub fn long_document() -> Self {
        ChunkConfig {
            chunk_size: 50_000,
            overlap: 200,
            buffer_size: 5000,
            // respect_sentences keeps the default (true).
            ..Self::default()
        }
    }

    /// Small, fast chunks for latency-sensitive use.
    pub fn realtime() -> Self {
        ChunkConfig {
            chunk_size: 1000,
            overlap: 50,
            respect_sentences: false,
            buffer_size: 100,
        }
    }
}
106
/// Chunk-by-chunk entity extractor that borrows a model for lifetime `'m`.
#[derive(Debug)]
pub struct StreamingExtractor<'m, M: Model> {
    // Model used to extract entities from each chunk.
    model: &'m M,
    // Chunk size, overlap and boundary-handling parameters.
    config: ChunkConfig,
}
113
114impl<'m, M: Model> StreamingExtractor<'m, M> {
115 pub fn new(model: &'m M, config: ChunkConfig) -> Self {
117 Self { model, config }
118 }
119
120 pub fn with_model(model: &'m M) -> Self {
122 Self::new(model, ChunkConfig::default())
123 }
124
125 pub fn extract<'t>(&'m self, text: &'t str) -> EntityIterator<'m, 't, M> {
127 EntityIterator::new(self, text)
128 }
129
130 fn process_chunk(&self, chunk: &str, offset: usize) -> Result<Vec<Entity>> {
132 let entities = self.model.extract_entities(chunk, None)?;
133
134 Ok(entities
136 .into_iter()
137 .map(|mut e| {
138 e.start += offset;
139 e.end += offset;
140 e
141 })
142 .collect())
143 }
144}
145
146pub struct EntityIterator<'m, 't, M: Model> {
148 extractor: &'m StreamingExtractor<'m, M>,
149 text: &'t str,
150 position: usize,
152 buffer: Vec<Entity>,
154 buffer_idx: usize,
156 seen: std::collections::HashSet<(usize, usize)>,
158 done: bool,
160}
161
162impl<'m, 't, M: Model> EntityIterator<'m, 't, M> {
163 fn new(extractor: &'m StreamingExtractor<'m, M>, text: &'t str) -> Self {
164 Self {
165 extractor,
166 text,
167 position: 0,
168 buffer: Vec::new(),
169 buffer_idx: 0,
170 seen: std::collections::HashSet::new(),
171 done: false,
172 }
173 }
174
175 fn fill_buffer(&mut self) -> Result<()> {
177 if self.done {
178 return Ok(());
179 }
180
181 let text_chars: Vec<char> = self.text.chars().collect();
182 let text_len = text_chars.len();
183
184 if self.position >= text_len {
185 self.done = true;
186 return Ok(());
187 }
188
189 let chunk_end = (self.position + self.extractor.config.chunk_size).min(text_len);
191
192 let actual_end = if self.extractor.config.respect_sentences {
194 find_sentence_boundary(&text_chars, self.position, chunk_end)
195 } else {
196 find_word_boundary(&text_chars, chunk_end)
197 };
198
199 let chunk: String = text_chars[self.position..actual_end].iter().collect();
201
202 let entities = self.extractor.process_chunk(&chunk, self.position)?;
204
205 self.buffer = entities
207 .into_iter()
208 .filter(|e| !self.seen.contains(&(e.start, e.end)))
209 .collect();
210
211 for e in &self.buffer {
213 self.seen.insert((e.start, e.end));
214 }
215
216 self.buffer_idx = 0;
217
218 let overlap = self.extractor.config.overlap;
221 let new_position = if actual_end >= text_len {
222 text_len
223 } else {
224 let overlap_position = actual_end.saturating_sub(overlap);
226 if overlap_position <= self.position {
228 self.position + 1
229 } else {
230 overlap_position
231 }
232 };
233
234 self.position = new_position;
235
236 if actual_end >= text_len || self.position >= text_len {
237 self.done = true;
238 }
239
240 Ok(())
241 }
242}
243
244impl<'m, 't, M: Model> Iterator for EntityIterator<'m, 't, M> {
245 type Item = Entity;
246
247 fn next(&mut self) -> Option<Self::Item> {
248 loop {
249 if self.buffer_idx < self.buffer.len() {
251 let entity = self.buffer[self.buffer_idx].clone();
252 self.buffer_idx += 1;
253 return Some(entity);
254 }
255
256 if self.done {
258 return None;
259 }
260
261 if self.fill_buffer().is_err() {
262 self.done = true;
263 return None;
264 }
265
266 if self.buffer.is_empty() && self.done {
268 return None;
269 }
270 }
271 }
272}
273
/// Finds a chunk end at a sentence boundary, scanning backwards from
/// `target` (at most 200 characters back).
///
/// A boundary is a sentence-terminating punctuation mark. ASCII terminators
/// (`.` `!` `?`) must be followed by whitespace (or end of text) to avoid
/// splitting on abbreviations and decimals. The fullwidth CJK terminators
/// (`。` `！` `？`) stand on their own: CJK text puts no space between
/// sentences, so requiring trailing whitespace (as this function previously
/// did) meant CJK sentence boundaries were never found. The returned index
/// is positioned after any whitespace following the terminator. Falls back
/// to [`find_word_boundary`] when no sentence boundary past `start` exists.
fn find_sentence_boundary(chars: &[char], start: usize, target: usize) -> usize {
    let search_start = target.saturating_sub(200);
    for i in (search_start..target).rev() {
        if i >= chars.len() {
            continue;
        }
        let c = chars[i];
        let ascii_terminator = matches!(c, '.' | '!' | '?');
        let cjk_terminator = matches!(c, '。' | '！' | '？');
        let is_boundary = cjk_terminator
            || (ascii_terminator && (i + 1 >= chars.len() || chars[i + 1].is_whitespace()));
        if is_boundary {
            // Advance past the terminator and any inter-sentence whitespace.
            let mut end = i + 1;
            while end < chars.len() && chars[end].is_whitespace() {
                end += 1;
            }
            if end > start {
                return end;
            }
        }
    }
    find_word_boundary(chars, target)
}

/// Finds the nearest word boundary at or before `target`: the index just
/// after the last whitespace character preceding `target`, or `target`
/// itself when the prefix contains no whitespace at all.
///
/// NOTE(review): the search runs from the very start of `chars`, so the
/// result can precede a caller's chunk-start position; callers must clamp
/// the returned index against their own lower bound.
fn find_word_boundary(chars: &[char], target: usize) -> usize {
    let target = target.min(chars.len());

    if target >= chars.len() {
        return chars.len();
    }

    for i in (0..target).rev() {
        if chars[i].is_whitespace() {
            return i + 1;
        }
    }
    target
}
318
#[cfg(feature = "production")]
pub mod async_stream {
    //! Async adapter exposing the streaming extractor as a `futures` Stream.
    use super::*;
    use futures::stream::{self, Stream};

    impl<'m, M: Model + Sync> StreamingExtractor<'m, M> {
        /// Wraps the synchronous [`StreamingExtractor::extract`] iterator in
        /// a `Stream`.
        ///
        /// NOTE(review): each chunk is still processed synchronously when
        /// the stream is polled — confirm that blocking the executor during
        /// model inference is acceptable for async callers.
        pub fn extract_stream<'t>(&'m self, text: &'t str) -> impl Stream<Item = Entity> + 'm
        where
            't: 'm,
        {
            let iter = self.extract(text);
            stream::iter(iter)
        }
    }
}
340
/// A post-processing step applied to the entities produced by a model.
pub trait PipelineStage: Send + Sync {
    /// Transforms `entities` (extracted from `text`) and returns the result.
    fn process(&self, entities: Vec<Entity>, text: &str) -> Vec<Entity>;

    /// Short human-readable stage name, for logging and diagnostics.
    fn name(&self) -> &'static str;
}
353
354pub struct Pipeline<M: Model> {
356 model: M,
357 post_stages: Vec<Box<dyn PipelineStage>>,
359 chunk_config: ChunkConfig,
361}
362
363impl<M: Model> Pipeline<M> {
364 pub fn new(model: M) -> Self {
366 Self {
367 model,
368 post_stages: Vec::new(),
369 chunk_config: ChunkConfig::default(),
370 }
371 }
372
373 pub fn add_stage(mut self, stage: Box<dyn PipelineStage>) -> Self {
375 self.post_stages.push(stage);
376 self
377 }
378
379 pub fn with_chunk_config(mut self, config: ChunkConfig) -> Self {
381 self.chunk_config = config;
382 self
383 }
384
385 pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
387 let mut entities = self.model.extract_entities(text, None)?;
388
389 for stage in &self.post_stages {
390 entities = stage.process(entities, text);
391 }
392
393 Ok(entities)
394 }
395
396 pub fn model(&self) -> &M {
398 &self.model
399 }
400}
401
402pub struct ConfidenceFilter {
408 threshold: f64,
409}
410
411impl ConfidenceFilter {
412 pub fn new(threshold: f64) -> Self {
414 Self { threshold }
415 }
416}
417
418impl PipelineStage for ConfidenceFilter {
419 fn process(&self, entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
420 entities
421 .into_iter()
422 .filter(|e| e.confidence >= self.threshold)
423 .collect()
424 }
425
426 fn name(&self) -> &'static str {
427 "ConfidenceFilter"
428 }
429}
430
431pub struct DeduplicateOverlapping;
433
434impl PipelineStage for DeduplicateOverlapping {
435 fn process(&self, mut entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
436 entities.sort_by(|a, b| {
438 a.start.cmp(&b.start).then(
439 b.confidence
440 .partial_cmp(&a.confidence)
441 .expect("confidence values should be comparable"),
442 )
443 });
444
445 let mut result = Vec::new();
446 let mut last_end = 0;
447
448 for entity in entities {
449 if entity.start >= last_end {
450 last_end = entity.end;
451 result.push(entity);
452 }
453 }
455
456 result
457 }
458
459 fn name(&self) -> &'static str {
460 "DeduplicateOverlapping"
461 }
462}
463
464pub struct NormalizeText {
466 lowercase: bool,
467}
468
469impl NormalizeText {
470 pub fn new(lowercase: bool) -> Self {
472 Self { lowercase }
473 }
474}
475
476impl PipelineStage for NormalizeText {
477 fn process(&self, entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
478 entities
479 .into_iter()
480 .map(|mut e| {
481 e.text = e.text.trim().to_string();
482 if self.lowercase {
483 e.text = e.text.to_lowercase();
484 }
485 e
486 })
487 .collect()
488 }
489
490 fn name(&self) -> &'static str {
491 "NormalizeText"
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HeuristicNER;

    // Smoke test: default streaming config finds at least one entity.
    #[test]
    fn test_streaming_basic() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let text = "John Smith works at Google Inc. in New York.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        assert!(!entities.is_empty());
    }

    // Text longer than chunk_size must still yield entities across chunks.
    #[test]
    fn test_streaming_long_text() {
        let model = HeuristicNER::new();
        let config = ChunkConfig {
            chunk_size: 50,
            overlap: 10,
            respect_sentences: false,
            buffer_size: 100,
        };
        let extractor = StreamingExtractor::new(&model, config);

        let text =
            "John Smith works at Google. Mary Johnson is at Apple. Bob Williams joined Microsoft.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        assert!(!entities.is_empty());
    }

    // Stages run after extraction: nothing below the confidence threshold
    // should survive the ConfidenceFilter stage.
    #[test]
    fn test_pipeline() {
        let model = HeuristicNER::new();
        let pipeline = Pipeline::new(model)
            .add_stage(Box::new(ConfidenceFilter::new(0.5)))
            .add_stage(Box::new(DeduplicateOverlapping));

        let text = "John Smith works at Google Inc.";
        let entities = pipeline.extract(text).unwrap();

        for entity in &entities {
            assert!(entity.confidence >= 0.5);
        }
    }

    // The preset constructors should at least build without panicking.
    #[test]
    fn test_chunk_config_presets() {
        let _no_chunk = ChunkConfig::no_chunking();
        let _long = ChunkConfig::long_document();
        let _realtime = ChunkConfig::realtime();
    }

    // A sentence boundary must land inside (0, target].
    #[test]
    fn test_find_sentence_boundary() {
        let text: Vec<char> = "Hello world. This is a test.".chars().collect();
        let boundary = find_sentence_boundary(&text, 0, 20);
        assert!(boundary > 0);
        assert!(boundary <= 20);
    }

    // With overlapping chunks, repeated spans must be deduplicated and the
    // iterator must terminate; the count bound guards against infinite loops.
    #[test]
    fn test_entity_deduplication_across_chunks() {
        let model = HeuristicNER::new();

        let config = ChunkConfig {
            chunk_size: 100,
            overlap: 20,
            respect_sentences: false,
            buffer_size: 100,
        };
        let extractor = StreamingExtractor::new(&model, config);

        let text = "I work at Google Inc in California. Then I visited Google headquarters.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        assert!(
            entities.len() < 100,
            "Possible infinite loop: too many entities"
        );
    }

    // Empty input should produce an empty stream, not an error or a hang.
    #[test]
    fn test_empty_text_streaming() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let entities: Vec<Entity> = extractor.extract("").collect();
        assert!(entities.is_empty());
    }

    // Spans are character indices, so multi-byte text must stay in range.
    #[test]
    fn test_unicode_text_streaming() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let text = "東京 is the capital of 日本. Paris is in France.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        let char_count = text.chars().count();
        for entity in &entities {
            assert!(entity.start <= entity.end, "Invalid span");
            assert!(entity.end <= char_count, "Offset exceeds text length");
        }
    }

    // overlap close to chunk_size would stall position updates without the
    // position + 1 fallback in fill_buffer; collecting must terminate.
    #[test]
    fn test_forward_progress_guaranteed() {
        let model = HeuristicNER::new();

        let config = ChunkConfig {
            chunk_size: 5, overlap: 3, respect_sentences: false,
            buffer_size: 10,
        };
        let extractor = StreamingExtractor::new(&model, config);

        let text = "abc def";

        let entities: Vec<Entity> = extractor.extract(text).collect();
        let _ = entities;
    }
}