1use crate::{Document, RagError, Result, SearchResult};
58use std::sync::Arc;
59
/// Point-in-time snapshot of a batch indexing run.
///
/// [`BatchBuilder::progress`] fills this in, and the configured
/// [`ProgressCallback`] receives one after every completed batch.
#[derive(Debug, Clone)]
pub struct BatchProgress {
    /// Documents successfully added so far.
    pub completed: usize,
    /// Expected number of documents, if the caller supplied one.
    pub total: Option<usize>,
    /// Number of batches completed so far.
    pub batch_number: usize,
    /// Configured number of documents per batch.
    pub batch_size: usize,
    /// Milliseconds elapsed since the run started.
    pub elapsed_ms: u64,
    /// Average throughput in documents per second (`0.0` when not yet measurable).
    pub docs_per_sec: f64,
}

impl BatchProgress {
    /// Completion percentage, or `None` when the total is unknown.
    ///
    /// A known total of zero is reported as `100.0`: there is nothing left to
    /// do, and dividing by zero would otherwise produce `NaN` or infinity.
    pub fn percent(&self) -> Option<f64> {
        self.total.map(|total| {
            if total == 0 {
                100.0
            } else {
                (self.completed as f64 / total as f64) * 100.0
            }
        })
    }

    /// Estimated milliseconds until completion.
    ///
    /// Returns `None` when the total is unknown or no positive throughput has
    /// been measured yet. Once `completed` reaches or exceeds the total the
    /// estimate is `0`.
    pub fn eta_ms(&self) -> Option<u64> {
        let total = self.total?;
        if self.docs_per_sec > 0.0 {
            let remaining = total.saturating_sub(self.completed);
            // Float-to-int `as` saturates, so a huge estimate cannot wrap.
            Some(((remaining as f64) / self.docs_per_sec * 1000.0) as u64)
        } else {
            None
        }
    }
}
100
/// Shared, thread-safe callback that receives a [`BatchProgress`] snapshot.
///
/// Stored in [`BatchConfig::progress_callback`]; the `Arc` lets a cloned
/// configuration share the same closure.
pub type ProgressCallback = Arc<dyn Fn(&BatchProgress) + Send + Sync>;
103
104#[derive(Clone)]
106pub struct BatchConfig {
107 pub batch_size: usize,
109 pub progress_callback: Option<ProgressCallback>,
111 pub total_documents: Option<usize>,
113 pub validate_dimensions: bool,
115 pub continue_on_error: bool,
117}
118
119impl Default for BatchConfig {
120 fn default() -> Self {
121 Self {
122 batch_size: 1000,
123 progress_callback: None,
124 total_documents: None,
125 validate_dimensions: true,
126 continue_on_error: false,
127 }
128 }
129}
130
131impl BatchConfig {
132 pub fn with_batch_size(mut self, size: usize) -> Self {
134 self.batch_size = size;
135 self
136 }
137
138 pub fn with_progress<F>(mut self, callback: F) -> Self
140 where
141 F: Fn(&BatchProgress) + Send + Sync + 'static,
142 {
143 self.progress_callback = Some(Arc::new(callback));
144 self
145 }
146
147 pub fn with_total(mut self, total: usize) -> Self {
149 self.total_documents = Some(total);
150 self
151 }
152
153 pub fn with_validation(mut self, validate: bool) -> Self {
155 self.validate_dimensions = validate;
156 self
157 }
158
159 pub fn continue_on_error(mut self, continue_: bool) -> Self {
161 self.continue_on_error = continue_;
162 self
163 }
164}
165
/// Minimal index interface that [`BatchBuilder`] needs in order to ingest
/// documents.
pub trait BatchIndex {
    /// Inserts a single document into the index.
    fn add_document(&mut self, doc: Document) -> Result<()>;

    /// Embedding dimension reported by the index; `BatchBuilder::add`
    /// compares each document's embedding length against it when dimension
    /// validation is enabled.
    fn embedding_dim(&self) -> usize;

    /// Number of entries the index reports.
    fn len(&self) -> usize;

    /// Returns `true` when [`len`](Self::len) is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
186
187impl<T: crate::index::VectorIndex> BatchIndex for T {
190 fn add_document(&mut self, doc: Document) -> Result<()> {
191 crate::index::VectorIndex::add(self, doc)
192 }
193
194 fn embedding_dim(&self) -> usize {
195 crate::index::VectorIndex::embedding_dim(self)
196 }
197
198 fn len(&self) -> usize {
199 crate::index::VectorIndex::len(self)
200 }
201}
202
203pub struct BatchBuilder<'a, I: BatchIndex> {
211 index: &'a mut I,
212 config: BatchConfig,
213 completed: usize,
214 batch_count: usize,
215 errors: Vec<(String, RagError)>,
216 start_time: std::time::Instant,
217}
218
219impl<'a, I: BatchIndex> BatchBuilder<'a, I> {
220 pub fn new(index: &'a mut I, config: BatchConfig) -> Self {
222 Self {
223 index,
224 config,
225 completed: 0,
226 batch_count: 0,
227 errors: Vec::new(),
228 start_time: std::time::Instant::now(),
229 }
230 }
231
232 pub fn add(&mut self, doc: Document) -> Result<()> {
234 let doc_id = doc.id.clone();
235
236 if self.config.validate_dimensions && doc.embedding.len() != self.index.embedding_dim() {
238 let err = RagError::DimensionMismatch {
239 expected: self.index.embedding_dim(),
240 actual: doc.embedding.len(),
241 };
242
243 if self.config.continue_on_error {
244 self.errors.push((doc_id, err));
245 return Ok(());
246 } else {
247 return Err(err);
248 }
249 }
250
251 match self.index.add_document(doc) {
253 Ok(()) => {
254 self.completed += 1;
255
256 if self.completed % self.config.batch_size == 0 {
258 self.batch_count += 1;
259 self.report_progress();
260 }
261
262 Ok(())
263 }
264 Err(e) => {
265 if self.config.continue_on_error {
266 self.errors.push((doc_id, e));
267 Ok(())
268 } else {
269 Err(e)
270 }
271 }
272 }
273 }
274
275 pub fn add_all<T: IntoIterator<Item = Document>>(&mut self, docs: T) -> Result<()> {
277 for doc in docs {
278 self.add(doc)?;
279 }
280 Ok(())
281 }
282
283 pub fn finish(mut self) -> BatchResult {
285 if self.completed % self.config.batch_size != 0 {
287 self.batch_count += 1;
288 self.report_progress();
289 }
290
291 BatchResult {
292 documents_indexed: self.completed,
293 errors: self.errors,
294 elapsed_ms: self.start_time.elapsed().as_millis() as u64,
295 batches_processed: self.batch_count,
296 }
297 }
298
299 pub fn progress(&self) -> BatchProgress {
301 let elapsed_ms = self.start_time.elapsed().as_millis() as u64;
302 let docs_per_sec = if elapsed_ms > 0 {
303 (self.completed as f64) / (elapsed_ms as f64 / 1000.0)
304 } else {
305 0.0
306 };
307
308 BatchProgress {
309 completed: self.completed,
310 total: self.config.total_documents,
311 batch_number: self.batch_count,
312 batch_size: self.config.batch_size,
313 elapsed_ms,
314 docs_per_sec,
315 }
316 }
317
318 pub fn errors(&self) -> &[(String, RagError)] {
320 &self.errors
321 }
322
323 fn report_progress(&self) {
324 if let Some(ref callback) = self.config.progress_callback {
325 callback(&self.progress());
326 }
327 }
328}
329
330#[derive(Debug)]
332pub struct BatchResult {
333 pub documents_indexed: usize,
335 pub errors: Vec<(String, RagError)>,
337 pub elapsed_ms: u64,
339 pub batches_processed: usize,
341}
342
343impl BatchResult {
344 pub fn has_errors(&self) -> bool {
346 !self.errors.is_empty()
347 }
348
349 pub fn throughput(&self) -> f64 {
351 if self.elapsed_ms > 0 {
352 (self.documents_indexed as f64) / (self.elapsed_ms as f64 / 1000.0)
353 } else {
354 0.0
355 }
356 }
357}
358
/// Low-level search interface: raw hits first, documents resolved on demand.
pub trait StreamingSearchIndex {
    /// Searches with `query` and returns `(usize, f32)` pairs.
    ///
    /// NOTE(review): presumably at most `k` hits, each an internal document
    /// index (usable with [`get_document`](Self::get_document)) plus its
    /// score or distance; confirm against the implementors.
    fn search_raw(&self, query: &[f32], k: usize) -> Result<Vec<(usize, f32)>>;

    /// Resolves `idx` to a full [`SearchResult`], or `None` if there is
    /// nothing to return for that index.
    fn get_document(&self, idx: usize) -> Option<SearchResult>;
}
371
372pub struct SearchResultIterator {
379 results: Vec<SearchResult>,
380 position: usize,
381}
382
383impl SearchResultIterator {
384 pub fn new(results: Vec<SearchResult>) -> Self {
386 Self {
387 results,
388 position: 0,
389 }
390 }
391
392 pub fn total(&self) -> usize {
394 self.results.len()
395 }
396
397 pub fn peek(&self) -> Option<&SearchResult> {
399 self.results.get(self.position)
400 }
401
402 pub fn skip_n(&mut self, n: usize) {
404 self.position = (self.position + n).min(self.results.len());
405 }
406
407 pub fn collect_remaining(self) -> Vec<SearchResult> {
409 self.results.into_iter().skip(self.position).collect()
410 }
411}
412
413impl Iterator for SearchResultIterator {
414 type Item = SearchResult;
415
416 fn next(&mut self) -> Option<Self::Item> {
417 if self.position < self.results.len() {
418 let result = self.results[self.position].clone();
419 self.position += 1;
420 Some(result)
421 } else {
422 None
423 }
424 }
425
426 fn size_hint(&self) -> (usize, Option<usize>) {
427 let remaining = self.results.len() - self.position;
428 (remaining, Some(remaining))
429 }
430}
431
432impl ExactSizeIterator for SearchResultIterator {}
433
/// Pagination settings for search results.
#[derive(Debug, Clone)]
pub struct PaginationConfig {
    /// Number of results per page.
    pub page_size: usize,
    /// Oversampling factor.
    ///
    /// NOTE(review): not read anywhere in this file; presumably a multiplier
    /// on how many candidates are fetched. Confirm with the consumer.
    pub oversample: f32,
}

impl Default for PaginationConfig {
    /// Ten results per page with an oversampling factor of `2.0`.
    fn default() -> Self {
        PaginationConfig {
            oversample: 2.0,
            page_size: 10,
        }
    }
}

impl PaginationConfig {
    /// Returns the configuration with `page_size` replaced by `size`.
    pub fn with_page_size(self, size: usize) -> Self {
        Self {
            page_size: size,
            ..self
        }
    }
}
463
464#[derive(Debug, Clone)]
466pub struct SearchPage {
467 pub results: Vec<SearchResult>,
469 pub page: usize,
471 pub total_pages: usize,
473 pub total_results: usize,
475 pub has_next: bool,
477 pub has_prev: bool,
479}
480
481impl SearchPage {
482 pub fn from_results(all_results: Vec<SearchResult>, page: usize, page_size: usize) -> Self {
484 let total_results = all_results.len();
485 let total_pages = (total_results + page_size - 1) / page_size;
486 let start = page * page_size;
487 let end = (start + page_size).min(total_results);
488
489 let results = if start < total_results {
490 all_results[start..end].to_vec()
491 } else {
492 Vec::new()
493 };
494
495 Self {
496 results,
497 page,
498 total_pages,
499 total_results,
500 has_next: page + 1 < total_pages,
501 has_prev: page > 0,
502 }
503 }
504}
505
/// Boxed, thread-safe predicate over a [`SearchResult`]; returning `true`
/// keeps the result.
pub type SearchFilter = Box<dyn Fn(&SearchResult) -> bool + Send + Sync>;
512
513pub struct FilteredSearchBuilder {
515 filters: Vec<SearchFilter>,
516 min_score: Option<f32>,
517 max_results: Option<usize>,
518}
519
520impl FilteredSearchBuilder {
521 pub fn new() -> Self {
523 Self {
524 filters: Vec::new(),
525 min_score: None,
526 max_results: None,
527 }
528 }
529
530 pub fn filter<F>(mut self, f: F) -> Self
532 where
533 F: Fn(&SearchResult) -> bool + Send + Sync + 'static,
534 {
535 self.filters.push(Box::new(f));
536 self
537 }
538
539 pub fn min_score(mut self, score: f32) -> Self {
541 self.min_score = Some(score);
542 self
543 }
544
545 pub fn max_results(mut self, max: usize) -> Self {
547 self.max_results = Some(max);
548 self
549 }
550
551 pub fn has_metadata_field(self, field: &'static str) -> Self {
553 self.filter(move |r| {
554 r.metadata
555 .as_ref()
556 .map(|m| m.get(field).is_some())
557 .unwrap_or(false)
558 })
559 }
560
561 pub fn metadata_equals(self, field: &'static str, value: serde_json::Value) -> Self {
563 self.filter(move |r| {
564 r.metadata
565 .as_ref()
566 .and_then(|m| m.get(field))
567 .map(|v| *v == value)
568 .unwrap_or(false)
569 })
570 }
571
572 pub fn apply(&self, results: Vec<SearchResult>) -> Vec<SearchResult> {
574 let mut filtered: Vec<SearchResult> = results
575 .into_iter()
576 .filter(|r| {
577 if let Some(min) = self.min_score {
579 if r.score < min {
580 return false;
581 }
582 }
583
584 for filter in &self.filters {
586 if !filter(r) {
587 return false;
588 }
589 }
590
591 true
592 })
593 .collect();
594
595 if let Some(max) = self.max_results {
597 filtered.truncate(max);
598 }
599
600 filtered
601 }
602}
603
604impl Default for FilteredSearchBuilder {
605 fn default() -> Self {
606 Self::new()
607 }
608}
609
// Unit tests for batch building, progress maths, result iteration, paging and
// filtering. The batch tests run against a real `HNSWIndex`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::HNSWIndex;

    // Builds a document with the given id and embedding and a fixed
    // `{"category": "test"}` metadata object.
    fn create_test_document(id: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: format!("Content for {}", id),
            embedding,
            metadata: Some(serde_json::json!({"category": "test"})),
        }
    }

    // Vector of `dim` values in [-1, 1) drawn from an RNG seeded with `seed`.
    fn generate_random_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..dim)
            .map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
            .collect()
    }

    // Each builder method sets exactly the field it is named after.
    #[test]
    fn test_batch_config_builder() {
        let config = BatchConfig::default()
            .with_batch_size(500)
            .with_total(10000)
            .with_validation(false)
            .continue_on_error(true);

        assert_eq!(config.batch_size, 500);
        assert_eq!(config.total_documents, Some(10000));
        assert!(!config.validate_dimensions);
        assert!(config.continue_on_error);
    }

    // 25 documents at batch size 10: two full batches plus one partial batch.
    #[test]
    fn test_batch_builder_basic() {
        let mut index = HNSWIndex::with_defaults(128);
        let config = BatchConfig::default().with_batch_size(10);

        let mut builder = BatchBuilder::new(&mut index, config);

        for i in 0..25 {
            let doc = create_test_document(&format!("doc{}", i), generate_random_vector(128, i));
            builder.add(doc).unwrap();
        }

        let result = builder.finish();

        assert_eq!(result.documents_indexed, 25);
        assert!(!result.has_errors());
        assert_eq!(result.batches_processed, 3);
        assert_eq!(index.len(), 25);
    }

    // The callback fires once per full batch and once more from `finish` for
    // the trailing partial batch.
    #[test]
    fn test_batch_builder_with_progress() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let progress_count = Arc::new(AtomicUsize::new(0));
        let progress_count_clone = progress_count.clone();

        let mut index = HNSWIndex::with_defaults(128);
        let config = BatchConfig::default()
            .with_batch_size(10)
            .with_progress(move |_p| {
                progress_count_clone.fetch_add(1, Ordering::SeqCst);
            });

        let mut builder = BatchBuilder::new(&mut index, config);

        for i in 0..35 {
            let doc = create_test_document(&format!("doc{}", i), generate_random_vector(128, i));
            builder.add(doc).unwrap();
        }

        let _result = builder.finish();

        // 35 documents: three full batches (10, 20, 30) plus the final five.
        assert_eq!(progress_count.load(Ordering::SeqCst), 4);
    }

    // With default settings a wrongly sized embedding is rejected.
    #[test]
    fn test_batch_builder_dimension_error() {
        let mut index = HNSWIndex::with_defaults(128);
        let config = BatchConfig::default();

        let mut builder = BatchBuilder::new(&mut index, config);

        // Matching dimension is accepted.
        let doc = create_test_document("doc1", generate_random_vector(128, 1));
        assert!(builder.add(doc).is_ok());

        // 64 != 128 must fail.
        let doc = create_test_document("doc2", generate_random_vector(64, 2));
        assert!(builder.add(doc).is_err());
    }

    // In continue-on-error mode the bad document is recorded and skipped.
    #[test]
    fn test_batch_builder_continue_on_error() {
        let mut index = HNSWIndex::with_defaults(128);
        let config = BatchConfig::default().continue_on_error(true);

        let mut builder = BatchBuilder::new(&mut index, config);

        let doc = create_test_document("doc1", generate_random_vector(128, 1));
        builder.add(doc).unwrap();

        // Wrong dimension: recorded as an error, `add` still returns Ok.
        let doc = create_test_document("doc2", generate_random_vector(64, 2));
        builder.add(doc).unwrap();

        let doc = create_test_document("doc3", generate_random_vector(128, 3));
        builder.add(doc).unwrap();

        let result = builder.finish();

        assert_eq!(result.documents_indexed, 2);
        assert!(result.has_errors());
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].0, "doc2");
    }

    // Halfway through at 1000 docs/s: 50% done, 5000 ms to go.
    #[test]
    fn test_batch_progress_eta() {
        let progress = BatchProgress {
            completed: 5000,
            total: Some(10000),
            batch_number: 5,
            batch_size: 1000,
            elapsed_ms: 5000,
            docs_per_sec: 1000.0,
        };

        assert_eq!(progress.percent(), Some(50.0));
        assert_eq!(progress.eta_ms(), Some(5000));
    }

    // Covers `total`, `peek`, `next` and `collect_remaining`.
    #[test]
    fn test_search_result_iterator() {
        let results = vec![
            SearchResult {
                id: "doc1".to_string(),
                content: "Content 1".to_string(),
                score: 0.9,
                metadata: None,
            },
            SearchResult {
                id: "doc2".to_string(),
                content: "Content 2".to_string(),
                score: 0.8,
                metadata: None,
            },
            SearchResult {
                id: "doc3".to_string(),
                content: "Content 3".to_string(),
                score: 0.7,
                metadata: None,
            },
        ];

        let mut iter = SearchResultIterator::new(results);

        assert_eq!(iter.total(), 3);
        assert_eq!(iter.peek().unwrap().id, "doc1");

        let first = iter.next().unwrap();
        assert_eq!(first.id, "doc1");

        let second = iter.next().unwrap();
        assert_eq!(second.id, "doc2");

        let remaining = iter.collect_remaining();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, "doc3");
    }

    // 25 results at 10 per page: pages of 10, 10 and 5.
    #[test]
    fn test_search_page() {
        let results: Vec<SearchResult> = (0..25)
            .map(|i| SearchResult {
                id: format!("doc{}", i),
                content: format!("Content {}", i),
                score: 1.0 - (i as f32 * 0.01),
                metadata: None,
            })
            .collect();

        // First page.
        let page0 = SearchPage::from_results(results.clone(), 0, 10);
        assert_eq!(page0.results.len(), 10);
        assert_eq!(page0.page, 0);
        assert_eq!(page0.total_pages, 3);
        assert!(page0.has_next);
        assert!(!page0.has_prev);

        // Middle page.
        let page1 = SearchPage::from_results(results.clone(), 1, 10);
        assert_eq!(page1.results.len(), 10);
        assert!(page1.has_next);
        assert!(page1.has_prev);

        // Last, partial page.
        let page2 = SearchPage::from_results(results.clone(), 2, 10);
        assert_eq!(page2.results.len(), 5);
        assert!(!page2.has_next);
        assert!(page2.has_prev);
    }

    // Exercises the score floor, metadata filters and the result cap.
    #[test]
    fn test_filtered_search_builder() {
        let results: Vec<SearchResult> = vec![
            SearchResult {
                id: "doc1".to_string(),
                content: "Content 1".to_string(),
                score: 0.9,
                metadata: Some(serde_json::json!({"category": "A"})),
            },
            SearchResult {
                id: "doc2".to_string(),
                content: "Content 2".to_string(),
                score: 0.5,
                metadata: Some(serde_json::json!({"category": "B"})),
            },
            SearchResult {
                id: "doc3".to_string(),
                content: "Content 3".to_string(),
                score: 0.3,
                metadata: None,
            },
        ];

        // Score floor: doc3 (0.3) is dropped.
        let filtered = FilteredSearchBuilder::new()
            .min_score(0.4)
            .apply(results.clone());
        assert_eq!(filtered.len(), 2);

        // Field presence: doc3 has no metadata.
        let filtered = FilteredSearchBuilder::new()
            .has_metadata_field("category")
            .apply(results.clone());
        assert_eq!(filtered.len(), 2);

        // Field equality: only doc1 is in category "A".
        let filtered = FilteredSearchBuilder::new()
            .metadata_equals("category", serde_json::json!("A"))
            .apply(results.clone());
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].id, "doc1");

        // Score floor combined with a cap of one result.
        let filtered = FilteredSearchBuilder::new()
            .min_score(0.4)
            .max_results(1)
            .apply(results);
        assert_eq!(filtered.len(), 1);
    }
}