1use crate::{Entity, Model, Result};
46
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Maximum number of characters per chunk.
    pub chunk_size: usize,
    /// Number of characters re-processed at the start of the next chunk so
    /// entities straddling a chunk edge are not missed.
    pub overlap: usize,
    /// When true, chunk edges are snapped to sentence boundaries instead of
    /// plain word boundaries.
    pub respect_sentences: bool,
    /// NOTE(review): not referenced anywhere in this file — presumably a
    /// buffering hint for callers; confirm before relying on it.
    pub buffer_size: usize,
}
63
impl Default for ChunkConfig {
    /// Balanced defaults: 10k-char chunks with a 100-char overlap,
    /// sentence-aware splitting, and a 1000-entry buffer hint.
    fn default() -> Self {
        Self {
            chunk_size: 10_000,
            overlap: 100,
            respect_sentences: true,
            buffer_size: 1000,
        }
    }
}
74
75impl ChunkConfig {
76 pub fn no_chunking() -> Self {
78 Self {
79 chunk_size: usize::MAX,
80 overlap: 0,
81 respect_sentences: false,
82 buffer_size: usize::MAX,
83 }
84 }
85
86 pub fn long_document() -> Self {
88 Self {
89 chunk_size: 50_000,
90 overlap: 200,
91 respect_sentences: true,
92 buffer_size: 5000,
93 }
94 }
95
96 pub fn realtime() -> Self {
98 Self {
99 chunk_size: 1000,
100 overlap: 50,
101 respect_sentences: false,
102 buffer_size: 100,
103 }
104 }
105}
106
/// Chunked entity extraction over a borrowed model.
///
/// Splits input text according to `config`, runs the model on each chunk, and
/// re-bases entity offsets onto the full text.
#[derive(Debug)]
pub struct StreamingExtractor<'m, M: Model> {
    model: &'m M,
    config: ChunkConfig,
}
113
114impl<'m, M: Model> StreamingExtractor<'m, M> {
115 pub fn new(model: &'m M, config: ChunkConfig) -> Self {
117 Self { model, config }
118 }
119
120 pub fn with_model(model: &'m M) -> Self {
122 Self::new(model, ChunkConfig::default())
123 }
124
125 pub fn extract<'t>(&'m self, text: &'t str) -> EntityIterator<'m, 't, M> {
127 EntityIterator::new(self, text)
128 }
129
130 fn process_chunk(&self, chunk: &str, offset: usize) -> Result<Vec<Entity>> {
132 let entities = self.model.extract_entities(chunk, None)?;
133
134 Ok(entities
136 .into_iter()
137 .map(|mut e| {
138 e.start += offset;
139 e.end += offset;
140 e
141 })
142 .collect())
143 }
144}
145
/// Lazy iterator over entities produced chunk-by-chunk by a
/// [`StreamingExtractor`].
pub struct EntityIterator<'m, 't, M: Model> {
    extractor: &'m StreamingExtractor<'m, M>,
    text: &'t str,
    // Character (not byte) index where the next chunk starts in `text`.
    position: usize,
    // Entities from the most recently processed chunk, not yet yielded.
    buffer: Vec<Entity>,
    // Index of the next buffered entity to yield.
    buffer_idx: usize,
    // (start, end) spans already yielded; suppresses duplicates re-detected in
    // overlap regions. Grows with the number of distinct entities seen.
    seen: std::collections::HashSet<(usize, usize)>,
    done: bool,
}
161
162impl<'m, 't, M: Model> EntityIterator<'m, 't, M> {
163 fn new(extractor: &'m StreamingExtractor<'m, M>, text: &'t str) -> Self {
164 Self {
165 extractor,
166 text,
167 position: 0,
168 buffer: Vec::new(),
169 buffer_idx: 0,
170 seen: std::collections::HashSet::new(),
171 done: false,
172 }
173 }
174
175 fn fill_buffer(&mut self) -> Result<()> {
177 if self.done {
178 return Ok(());
179 }
180
181 let text_chars: Vec<char> = self.text.chars().collect();
182 let text_len = text_chars.len();
183
184 if self.position >= text_len {
185 self.done = true;
186 return Ok(());
187 }
188
189 let chunk_end = (self.position + self.extractor.config.chunk_size).min(text_len);
191
192 let actual_end = if self.extractor.config.respect_sentences {
194 find_sentence_boundary(&text_chars, self.position, chunk_end)
195 } else {
196 find_word_boundary(&text_chars, chunk_end)
197 };
198
199 let chunk: String = text_chars[self.position..actual_end].iter().collect();
201
202 let entities = self.extractor.process_chunk(&chunk, self.position)?;
204
205 self.buffer = entities
207 .into_iter()
208 .filter(|e| !self.seen.contains(&(e.start, e.end)))
209 .collect();
210
211 for e in &self.buffer {
213 self.seen.insert((e.start, e.end));
214 }
215
216 self.buffer_idx = 0;
217
218 let overlap = self.extractor.config.overlap;
221 let new_position = if actual_end >= text_len {
222 text_len
223 } else {
224 let overlap_position = actual_end.saturating_sub(overlap);
226 if overlap_position <= self.position {
228 self.position + 1
229 } else {
230 overlap_position
231 }
232 };
233
234 self.position = new_position;
235
236 if actual_end >= text_len || self.position >= text_len {
237 self.done = true;
238 }
239
240 Ok(())
241 }
242}
243
244impl<'m, 't, M: Model> Iterator for EntityIterator<'m, 't, M> {
245 type Item = Entity;
246
247 fn next(&mut self) -> Option<Self::Item> {
248 loop {
249 if self.buffer_idx < self.buffer.len() {
251 let entity = self.buffer[self.buffer_idx].clone();
252 self.buffer_idx += 1;
253 return Some(entity);
254 }
255
256 if self.done {
258 return None;
259 }
260
261 if self.fill_buffer().is_err() {
262 self.done = true;
263 return None;
264 }
265
266 if self.buffer.is_empty() && self.done {
268 return None;
269 }
270 }
271 }
272}
273
/// Searches backwards from `target` (looking back at most 200 characters) for
/// a sentence-ending punctuation mark followed by whitespace, and returns the
/// index just past that punctuation and its trailing whitespace.
///
/// Only boundaries strictly after `start` are accepted, so the caller's chunk
/// is never empty. Falls back to a word boundary, and finally to `target`
/// itself, under the same constraint.
fn find_sentence_boundary(chars: &[char], start: usize, target: usize) -> usize {
    // Bounded look-back keeps the scan cheap on huge chunks.
    let search_start = target.saturating_sub(200);
    for i in (search_start..target).rev() {
        if i >= chars.len() {
            continue;
        }
        let c = chars[i];
        // ASCII and fullwidth (CJK) sentence terminators; requiring
        // end-of-text or trailing whitespace avoids splitting e.g. "3.14".
        if (c == '.' || c == '!' || c == '?' || c == '。' || c == '!' || c == '?')
            && (i + 1 >= chars.len() || chars[i + 1].is_whitespace())
        {
            // Consume trailing whitespace so the next chunk starts on text.
            let mut end = i + 1;
            while end < chars.len() && chars[end].is_whitespace() {
                end += 1;
            }
            if end > start {
                return end;
            }
        }
    }
    // Bug fix: the word-boundary fallback searches from the very beginning of
    // the text and can return a position at or before `start` when
    // `start..target` contains no whitespace, which previously made callers
    // slice backwards. Clamp to `target` in that case.
    let word_boundary = find_word_boundary(chars, target);
    if word_boundary > start {
        word_boundary
    } else {
        target
    }
}

/// Returns the index just past the last whitespace character before `target`,
/// or `target` itself if the prefix contains no whitespace.
fn find_word_boundary(chars: &[char], target: usize) -> usize {
    let target = target.min(chars.len());

    if target >= chars.len() {
        return chars.len();
    }

    match chars[..target].iter().rposition(|c| c.is_whitespace()) {
        Some(i) => i + 1,
        None => target,
    }
}
318
#[cfg(feature = "async-inference")]
pub mod async_stream {
    //! Adapters exposing streaming extraction as a `futures` `Stream`.
    use super::*;
    use futures::stream::{self, Stream};

    impl<'m, M: Model + Sync> StreamingExtractor<'m, M> {
        /// Wraps the synchronous [`StreamingExtractor::extract`] iterator in a
        /// `Stream`. Extraction itself still runs synchronously as the stream
        /// is polled — no work is offloaded to another task here.
        pub fn extract_stream<'t>(&'m self, text: &'t str) -> impl Stream<Item = Entity> + 'm
        where
            't: 'm,
        {
            let iter = self.extract(text);
            stream::iter(iter)
        }
    }
}
339
/// A post-processing step applied to entities produced by a model.
pub trait PipelineStage: Send + Sync {
    /// Transforms `entities` (extracted from `text`) and returns the result.
    fn process(&self, entities: Vec<Entity>, text: &str) -> Vec<Entity>;

    /// Human-readable stage name, for logging and diagnostics.
    fn name(&self) -> &'static str;
}
352
/// A model plus an ordered list of post-processing stages.
pub struct Pipeline<M: Model> {
    model: M,
    // Stages run in insertion order after the model extracts entities.
    post_stages: Vec<Box<dyn PipelineStage>>,
    // NOTE(review): stored but never read by any method visible in this file;
    // `extract` calls the model directly without chunking. Confirm intent.
    chunk_config: ChunkConfig,
}
361
362impl<M: Model> Pipeline<M> {
363 pub fn new(model: M) -> Self {
365 Self {
366 model,
367 post_stages: Vec::new(),
368 chunk_config: ChunkConfig::default(),
369 }
370 }
371
372 pub fn add_stage(mut self, stage: Box<dyn PipelineStage>) -> Self {
374 self.post_stages.push(stage);
375 self
376 }
377
378 pub fn with_chunk_config(mut self, config: ChunkConfig) -> Self {
380 self.chunk_config = config;
381 self
382 }
383
384 pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
386 let mut entities = self.model.extract_entities(text, None)?;
387
388 for stage in &self.post_stages {
389 entities = stage.process(entities, text);
390 }
391
392 Ok(entities)
393 }
394
395 pub fn model(&self) -> &M {
397 &self.model
398 }
399}
400
/// Drops entities whose confidence is below a fixed threshold.
pub struct ConfidenceFilter {
    threshold: f64,
}

impl ConfidenceFilter {
    /// Creates a filter keeping entities with `confidence >= threshold`.
    pub fn new(threshold: f64) -> Self {
        Self { threshold }
    }
}
416
417impl PipelineStage for ConfidenceFilter {
418 fn process(&self, entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
419 entities
420 .into_iter()
421 .filter(|e| e.confidence >= self.threshold)
422 .collect()
423 }
424
425 fn name(&self) -> &'static str {
426 "ConfidenceFilter"
427 }
428}
429
430pub struct DeduplicateOverlapping;
432
433impl PipelineStage for DeduplicateOverlapping {
434 fn process(&self, mut entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
435 entities.sort_by(|a, b| {
437 a.start.cmp(&b.start).then(
438 b.confidence
439 .partial_cmp(&a.confidence)
440 .expect("confidence values should be comparable"),
441 )
442 });
443
444 let mut result = Vec::new();
445 let mut last_end = 0;
446
447 for entity in entities {
448 if entity.start >= last_end {
449 last_end = entity.end;
450 result.push(entity);
451 }
452 }
454
455 result
456 }
457
458 fn name(&self) -> &'static str {
459 "DeduplicateOverlapping"
460 }
461}
462
/// Normalizes entity text: trims surrounding whitespace and optionally
/// lowercases.
pub struct NormalizeText {
    lowercase: bool,
}

impl NormalizeText {
    /// Creates the stage; when `lowercase` is true the trimmed text is also
    /// folded to lowercase.
    pub fn new(lowercase: bool) -> Self {
        Self { lowercase }
    }
}
474
475impl PipelineStage for NormalizeText {
476 fn process(&self, entities: Vec<Entity>, _text: &str) -> Vec<Entity> {
477 entities
478 .into_iter()
479 .map(|mut e| {
480 e.text = e.text.trim().to_string();
481 if self.lowercase {
482 e.text = e.text.to_lowercase();
483 }
484 e
485 })
486 .collect()
487 }
488
489 fn name(&self) -> &'static str {
490 "NormalizeText"
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HeuristicNER;

    // Smoke test: default-configured streaming finds at least one entity.
    #[test]
    fn test_streaming_basic() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let text = "John Smith works at Google Inc. in New York.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        assert!(!entities.is_empty());
    }

    // Small chunks force the text to be split and processed across overlaps.
    #[test]
    fn test_streaming_long_text() {
        let model = HeuristicNER::new();
        let config = ChunkConfig {
            chunk_size: 50,
            overlap: 10,
            respect_sentences: false,
            buffer_size: 100,
        };
        let extractor = StreamingExtractor::new(&model, config);

        let text =
            "John Smith works at Google. Mary Johnson is at Apple. Bob Williams joined Microsoft.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        assert!(!entities.is_empty());
    }

    // Stages run in insertion order: filter by confidence, then drop overlaps.
    #[test]
    fn test_pipeline() {
        let model = HeuristicNER::new();
        let pipeline = Pipeline::new(model)
            .add_stage(Box::new(ConfidenceFilter::new(0.5)))
            .add_stage(Box::new(DeduplicateOverlapping));

        let text = "John Smith works at Google Inc.";
        let entities = pipeline.extract(text).unwrap();

        // Every surviving entity must satisfy the filter's threshold.
        for entity in &entities {
            assert!(entity.confidence >= 0.5);
        }
    }

    // Presets must at least construct without panicking.
    #[test]
    fn test_chunk_config_presets() {
        let _no_chunk = ChunkConfig::no_chunking();
        let _long = ChunkConfig::long_document();
        let _realtime = ChunkConfig::realtime();
    }

    #[test]
    fn test_find_sentence_boundary() {
        let text: Vec<char> = "Hello world. This is a test.".chars().collect();
        let boundary = find_sentence_boundary(&text, 0, 20);
        // The boundary must fall inside (0, target].
        assert!(boundary > 0);
        assert!(boundary <= 20);
    }

    // Overlapping chunks must not re-emit the same (start, end) span forever.
    #[test]
    fn test_entity_deduplication_across_chunks() {
        let model = HeuristicNER::new();

        let config = ChunkConfig {
            chunk_size: 100,
            overlap: 20,
            respect_sentences: false,
            buffer_size: 100,
        };
        let extractor = StreamingExtractor::new(&model, config);

        let text = "I work at Google Inc in California. Then I visited Google headquarters.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        // Loose upper bound: mainly guards against an infinite refill loop.
        assert!(
            entities.len() < 100,
            "Possible infinite loop: too many entities"
        );
    }

    #[test]
    fn test_empty_text_streaming() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let entities: Vec<Entity> = extractor.extract("").collect();
        assert!(entities.is_empty());
    }

    // Offsets are character-based, so multibyte text must stay in bounds.
    #[test]
    fn test_unicode_text_streaming() {
        let model = HeuristicNER::new();
        let extractor = StreamingExtractor::with_model(&model);

        let text = "東京 is the capital of 日本. Paris is in France.";
        let entities: Vec<Entity> = extractor.extract(text).collect();

        let char_count = text.chars().count();
        for entity in &entities {
            assert!(entity.start <= entity.end, "Invalid span");
            assert!(entity.end <= char_count, "Offset exceeds text length");
        }
    }

    // overlap (3) close to chunk_size (5) is the worst case for stalling.
    #[test]
    fn test_forward_progress_guaranteed() {
        let model = HeuristicNER::new();

        let config = ChunkConfig {
            chunk_size: 5, overlap: 3, respect_sentences: false,
            buffer_size: 10,
        };
        let extractor = StreamingExtractor::new(&model, config);

        let text = "abc def";

        // Completing the collect at all proves the iterator terminated.
        let entities: Vec<Entity> = extractor.extract(text).collect();
        let _ = entities;
    }
}