use crate::{Document, RagError, Result, SearchResult};
use std::sync::Arc;
/// Point-in-time snapshot of a batch-indexing run, handed to a
/// [`ProgressCallback`] after each completed batch.
#[derive(Debug, Clone)]
pub struct BatchProgress {
    /// Documents successfully indexed so far.
    pub completed: usize,
    /// Expected total document count, if the caller declared one
    /// via [`BatchConfig::with_total`]; `None` when unknown.
    pub total: Option<usize>,
    /// 1-based ordinal of the batch that just finished.
    pub batch_number: usize,
    /// Configured documents-per-batch size.
    pub batch_size: usize,
    /// Wall-clock time since the builder was created, in milliseconds.
    pub elapsed_ms: u64,
    /// Average indexing throughput over the whole run so far.
    pub docs_per_sec: f64,
}
impl BatchProgress {
    /// Percentage complete in `[0.0, 100.0]`.
    ///
    /// Returns `None` when the total is unknown *or* declared as zero — the
    /// previous behavior divided by zero and yielded `Some(NaN)` (or
    /// `Some(inf)` once `completed > 0`) for a zero total.
    pub fn percent(&self) -> Option<f64> {
        match self.total {
            Some(t) if t > 0 => Some((self.completed as f64 / t as f64) * 100.0),
            _ => None,
        }
    }

    /// Estimated milliseconds remaining, extrapolated from the average
    /// throughput so far. `None` when the total is unknown or no throughput
    /// has been observed yet (`docs_per_sec <= 0.0`).
    pub fn eta_ms(&self) -> Option<u64> {
        if self.docs_per_sec > 0.0 {
            self.total.map(|t| {
                // saturating_sub: `completed` may legitimately exceed a
                // caller-declared (estimated) total.
                let remaining = t.saturating_sub(self.completed);
                ((remaining as f64) / self.docs_per_sec * 1000.0) as u64
            })
        } else {
            None
        }
    }
}
/// Callback fired after each completed batch. `Arc` + `Send + Sync` so the
/// callback survives `BatchConfig: Clone` and can cross thread boundaries.
pub type ProgressCallback = Arc<dyn Fn(&BatchProgress) + Send + Sync>;
/// Tuning knobs for a [`BatchBuilder`] run. Construct with
/// [`BatchConfig::default`] and chain the `with_*` builder methods.
#[derive(Clone)]
pub struct BatchConfig {
    /// Documents per batch; progress is reported at batch boundaries.
    pub batch_size: usize,
    /// Optional per-batch progress hook.
    pub progress_callback: Option<ProgressCallback>,
    /// Declared total, used only for percent/ETA in progress reports.
    pub total_documents: Option<usize>,
    /// Reject documents whose embedding length differs from the index's.
    pub validate_dimensions: bool,
    /// Record per-document errors and keep going instead of failing fast.
    pub continue_on_error: bool,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
batch_size: 1000,
progress_callback: None,
total_documents: None,
validate_dimensions: true,
continue_on_error: false,
}
}
}
impl BatchConfig {
    /// Sets how many documents make up one batch.
    pub fn with_batch_size(self, size: usize) -> Self {
        Self { batch_size: size, ..self }
    }

    /// Installs a progress hook, invoked after every completed batch.
    pub fn with_progress<F>(self, callback: F) -> Self
    where
        F: Fn(&BatchProgress) + Send + Sync + 'static,
    {
        Self {
            progress_callback: Some(Arc::new(callback)),
            ..self
        }
    }

    /// Declares the expected total, enabling percent/ETA in progress reports.
    pub fn with_total(self, total: usize) -> Self {
        Self {
            total_documents: Some(total),
            ..self
        }
    }

    /// Enables or disables embedding-dimension validation.
    pub fn with_validation(self, validate: bool) -> Self {
        Self {
            validate_dimensions: validate,
            ..self
        }
    }

    /// When `true`, failures are collected instead of aborting the run.
    pub fn continue_on_error(self, enabled: bool) -> Self {
        Self {
            continue_on_error: enabled,
            ..self
        }
    }
}
/// Minimal index surface required by [`BatchBuilder`]. A blanket impl below
/// covers every `crate::index::VectorIndex`, so this trait rarely needs a
/// manual implementation.
pub trait BatchIndex {
    /// Inserts one document into the index.
    fn add_document(&mut self, doc: Document) -> Result<()>;
    /// Embedding dimensionality the index expects.
    fn embedding_dim(&self) -> usize;
    /// Number of documents currently stored.
    fn len(&self) -> usize;
    /// `true` when the index holds no documents.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
// Blanket impl: any `VectorIndex` automatically satisfies `BatchIndex`, so
// `BatchBuilder` works against every index type without per-type glue.
impl<T: crate::index::VectorIndex> BatchIndex for T {
    fn add_document(&mut self, doc: Document) -> Result<()> {
        // Fully-qualified path disambiguates from `BatchIndex::add_document`.
        // NOTE(review): assumes `VectorIndex::add` is the insertion method —
        // confirm against crate::index.
        crate::index::VectorIndex::add(self, doc)
    }
    fn embedding_dim(&self) -> usize {
        crate::index::VectorIndex::embedding_dim(self)
    }
    fn len(&self) -> usize {
        crate::index::VectorIndex::len(self)
    }
}
/// Incrementally feeds documents into an index in configurable batches,
/// tracking throughput, per-document errors, and progress callbacks.
/// Call [`BatchBuilder::finish`] to obtain the final [`BatchResult`].
pub struct BatchBuilder<'a, I: BatchIndex> {
    // Mutably borrowed for the builder's lifetime; released by `finish`.
    index: &'a mut I,
    config: BatchConfig,
    // Successfully indexed documents.
    completed: usize,
    // Batches reported so far (including the trailing partial one).
    batch_count: usize,
    // (document id, error) pairs collected in continue-on-error mode.
    errors: Vec<(String, RagError)>,
    // Basis for elapsed/throughput calculations.
    start_time: std::time::Instant,
}
impl<'a, I: BatchIndex> BatchBuilder<'a, I> {
    /// Creates a builder that indexes into `index` according to `config`.
    /// The throughput clock starts immediately.
    pub fn new(index: &'a mut I, config: BatchConfig) -> Self {
        Self {
            index,
            config,
            completed: 0,
            batch_count: 0,
            errors: Vec::new(),
            start_time: std::time::Instant::now(),
        }
    }

    /// Effective batch size. Clamps a configured size of 0 to 1 so the
    /// `completed % batch_size` arithmetic below cannot panic with a
    /// division by zero (previously `with_batch_size(0)` made `add` panic).
    fn effective_batch_size(&self) -> usize {
        self.config.batch_size.max(1)
    }

    /// Adds one document.
    ///
    /// On a dimension mismatch or index error: records the failure and
    /// returns `Ok(())` in continue-on-error mode, otherwise propagates the
    /// error. Fires the progress callback at each batch boundary.
    pub fn add(&mut self, doc: Document) -> Result<()> {
        // Captured up front: `doc` is moved into the index below.
        let doc_id = doc.id.clone();
        if self.config.validate_dimensions {
            let expected = self.index.embedding_dim();
            if doc.embedding.len() != expected {
                let err = RagError::DimensionMismatch {
                    expected,
                    actual: doc.embedding.len(),
                };
                if self.config.continue_on_error {
                    self.errors.push((doc_id, err));
                    return Ok(());
                }
                return Err(err);
            }
        }
        match self.index.add_document(doc) {
            Ok(()) => {
                self.completed += 1;
                if self.completed % self.effective_batch_size() == 0 {
                    self.batch_count += 1;
                    self.report_progress();
                }
                Ok(())
            }
            Err(e) => {
                if self.config.continue_on_error {
                    self.errors.push((doc_id, e));
                    Ok(())
                } else {
                    Err(e)
                }
            }
        }
    }

    /// Adds every document from `docs`, stopping at the first error unless
    /// continue-on-error is enabled.
    pub fn add_all<T: IntoIterator<Item = Document>>(&mut self, docs: T) -> Result<()> {
        for doc in docs {
            self.add(doc)?;
        }
        Ok(())
    }

    /// Finalizes the run and returns a summary.
    ///
    /// Emits one last progress report for a trailing partial batch; a run
    /// whose final batch was exactly full already reported it inside `add`.
    pub fn finish(mut self) -> BatchResult {
        if self.completed % self.effective_batch_size() != 0 {
            self.batch_count += 1;
            self.report_progress();
        }
        BatchResult {
            documents_indexed: self.completed,
            errors: self.errors,
            elapsed_ms: self.start_time.elapsed().as_millis() as u64,
            batches_processed: self.batch_count,
        }
    }

    /// Current progress snapshot. Throughput is 0.0 until at least one
    /// millisecond has elapsed (avoids dividing by zero).
    pub fn progress(&self) -> BatchProgress {
        let elapsed_ms = self.start_time.elapsed().as_millis() as u64;
        let docs_per_sec = if elapsed_ms > 0 {
            (self.completed as f64) / (elapsed_ms as f64 / 1000.0)
        } else {
            0.0
        };
        BatchProgress {
            completed: self.completed,
            total: self.config.total_documents,
            batch_number: self.batch_count,
            batch_size: self.config.batch_size,
            elapsed_ms,
            docs_per_sec,
        }
    }

    /// Errors collected so far (non-empty only in continue-on-error mode).
    pub fn errors(&self) -> &[(String, RagError)] {
        &self.errors
    }

    /// Invokes the configured progress callback, if any.
    fn report_progress(&self) {
        if let Some(ref callback) = self.config.progress_callback {
            callback(&self.progress());
        }
    }
}
/// Summary returned by [`BatchBuilder::finish`].
#[derive(Debug)]
pub struct BatchResult {
    /// Documents successfully written to the index.
    pub documents_indexed: usize,
    /// (document id, error) pairs collected in continue-on-error mode.
    pub errors: Vec<(String, RagError)>,
    /// Total wall-clock duration of the run, in milliseconds.
    pub elapsed_ms: u64,
    /// Batches reported, including a trailing partial batch.
    pub batches_processed: usize,
}
impl BatchResult {
    /// `true` when at least one document failed to index.
    pub fn has_errors(&self) -> bool {
        !self.errors.is_empty()
    }

    /// Average indexing rate in documents per second; 0.0 for a run that
    /// finished in under a millisecond.
    pub fn throughput(&self) -> f64 {
        match self.elapsed_ms {
            0 => 0.0,
            ms => (self.documents_indexed as f64) / (ms as f64 / 1000.0),
        }
    }
}
/// Lower-level search surface for indexes that can return raw (internal id,
/// score) pairs and hydrate them into full results on demand.
pub trait StreamingSearchIndex {
    /// Returns up to `k` `(internal index, score)` pairs for `query`.
    fn search_raw(&self, query: &[f32], k: usize) -> Result<Vec<(usize, f32)>>;
    /// Hydrates an internal index into a full result; `None` if it is stale
    /// or out of range.
    fn get_document(&self, idx: usize) -> Option<SearchResult>;
}
/// Cursor over an already-ranked result list, with peek/skip helpers on top
/// of the standard [`Iterator`] interface.
pub struct SearchResultIterator {
    // Full ranked result set; never mutated, only read through `position`.
    results: Vec<SearchResult>,
    // Index of the next result to yield; invariant: `position <= results.len()`.
    position: usize,
}
impl SearchResultIterator {
    /// Wraps an already-ranked result list; the cursor starts at the front.
    pub fn new(results: Vec<SearchResult>) -> Self {
        Self { position: 0, results }
    }

    /// Total result count, including items already consumed.
    pub fn total(&self) -> usize {
        self.results.len()
    }

    /// Borrows the next result without advancing the cursor.
    pub fn peek(&self) -> Option<&SearchResult> {
        self.results.get(self.position)
    }

    /// Advances the cursor by up to `n`, clamped to the end of the list.
    pub fn skip_n(&mut self, n: usize) {
        let target = self.position + n;
        self.position = target.min(self.results.len());
    }

    /// Consumes the iterator and returns every result not yet yielded.
    pub fn collect_remaining(mut self) -> Vec<SearchResult> {
        let at = self.position.min(self.results.len());
        self.results.split_off(at)
    }
}
impl Iterator for SearchResultIterator {
    type Item = SearchResult;

    /// Clones the result under the cursor and advances; `None` once past the
    /// end. (Cloning keeps the backing list intact for `peek`/`total`.)
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.results.get(self.position).cloned();
        if item.is_some() {
            self.position += 1;
        }
        item
    }

    /// Exact remaining count — both bounds agree, enabling `ExactSizeIterator`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.results.len().saturating_sub(self.position);
        (left, Some(left))
    }
}
// Sound because `size_hint` returns identical lower/upper bounds.
impl ExactSizeIterator for SearchResultIterator {}
/// Settings for paginated search.
#[derive(Debug, Clone)]
pub struct PaginationConfig {
    /// Results per page.
    pub page_size: usize,
    /// Over-fetch multiplier — presumably used by callers to request more
    /// candidates than one page needs; not referenced in this module. TODO confirm.
    pub oversample: f32,
}
impl Default for PaginationConfig {
fn default() -> Self {
Self {
page_size: 10,
oversample: 2.0,
}
}
}
impl PaginationConfig {
    /// Sets the number of results per page.
    pub fn with_page_size(self, size: usize) -> Self {
        Self { page_size: size, ..self }
    }
}
/// One page of search results plus pagination metadata.
#[derive(Debug, Clone)]
pub struct SearchPage {
    /// Results on this page (may be empty for an out-of-range page).
    pub results: Vec<SearchResult>,
    /// 0-based page index this slice corresponds to.
    pub page: usize,
    /// Total number of pages in the full result set.
    pub total_pages: usize,
    /// Total number of results across all pages.
    pub total_results: usize,
    /// `true` when a later page exists.
    pub has_next: bool,
    /// `true` when an earlier page exists (i.e. `page > 0`).
    pub has_prev: bool,
}
impl SearchPage {
    /// Slices `all_results` into 0-based page `page` of `page_size` items.
    ///
    /// A `page_size` of 0 is clamped to 1: previously it caused a
    /// divide-by-zero in the page-count ceiling division (and, with an empty
    /// result set, a debug-mode underflow in `0 + 0 - 1`). An out-of-range
    /// `page` yields an empty page rather than panicking; `saturating_mul`
    /// also protects the start offset from overflow on absurd page numbers.
    pub fn from_results(all_results: Vec<SearchResult>, page: usize, page_size: usize) -> Self {
        let page_size = page_size.max(1);
        let total_results = all_results.len();
        // Ceiling division: e.g. 25 results / 10 per page -> 3 pages.
        let total_pages = (total_results + page_size - 1) / page_size;
        let start = page.saturating_mul(page_size);
        let end = start.saturating_add(page_size).min(total_results);
        let results = if start < total_results {
            all_results[start..end].to_vec()
        } else {
            Vec::new()
        };
        Self {
            results,
            page,
            total_pages,
            total_results,
            has_next: page + 1 < total_pages,
            has_prev: page > 0,
        }
    }
}
/// Boxed predicate over a search result; `Send + Sync` so filter sets can be
/// shared across threads.
pub type SearchFilter = Box<dyn Fn(&SearchResult) -> bool + Send + Sync>;
/// Composable post-search filter: predicates, a score floor, and a result
/// cap, applied together via [`FilteredSearchBuilder::apply`].
pub struct FilteredSearchBuilder {
    // All predicates must pass for a result to be kept (AND semantics).
    filters: Vec<SearchFilter>,
    // Results scoring strictly below this are dropped.
    min_score: Option<f32>,
    // Kept results are truncated to this count after filtering.
    max_results: Option<usize>,
}
impl FilteredSearchBuilder {
    /// Starts an empty builder: no predicates, no score floor, no cap.
    pub fn new() -> Self {
        Self {
            min_score: None,
            max_results: None,
            filters: Vec::new(),
        }
    }

    /// Registers an arbitrary predicate; results failing it are dropped.
    pub fn filter<F>(mut self, f: F) -> Self
    where
        F: Fn(&SearchResult) -> bool + Send + Sync + 'static,
    {
        self.filters.push(Box::new(f));
        self
    }

    /// Drops results whose score is strictly below `score`.
    pub fn min_score(mut self, score: f32) -> Self {
        self.min_score = Some(score);
        self
    }

    /// Caps how many results [`Self::apply`] returns.
    pub fn max_results(mut self, max: usize) -> Self {
        self.max_results = Some(max);
        self
    }

    /// Keeps only results whose metadata object contains `field`.
    pub fn has_metadata_field(self, field: &'static str) -> Self {
        self.filter(move |r| match r.metadata.as_ref() {
            Some(m) => m.get(field).is_some(),
            None => false,
        })
    }

    /// Keeps only results whose metadata `field` equals `value` exactly.
    pub fn metadata_equals(self, field: &'static str, value: serde_json::Value) -> Self {
        self.filter(move |r| {
            let found = r.metadata.as_ref().and_then(|m| m.get(field));
            found.map(|v| *v == value).unwrap_or(false)
        })
    }

    /// `true` when `result` clears the score floor and every predicate.
    fn passes(&self, result: &SearchResult) -> bool {
        if let Some(min) = self.min_score {
            if result.score < min {
                return false;
            }
        }
        self.filters.iter().all(|f| f(result))
    }

    /// Runs every configured check over `results`, then truncates to the cap.
    pub fn apply(&self, results: Vec<SearchResult>) -> Vec<SearchResult> {
        let mut kept: Vec<SearchResult> = results
            .into_iter()
            .filter(|r| self.passes(r))
            .collect();
        if let Some(max) = self.max_results {
            kept.truncate(max);
        }
        kept
    }
}
// `Default` simply delegates to `new()` (an empty builder).
impl Default for FilteredSearchBuilder {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::HNSWIndex;

    // Builds a document with a canned content string and test metadata.
    fn create_test_document(id: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: format!("Content for {}", id),
            embedding,
            metadata: Some(serde_json::json!({"category": "test"})),
        }
    }

    // Deterministic pseudo-random embedding: same seed, same vector.
    fn generate_random_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..dim)
            .map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
            .collect()
    }

    // Each builder method should set its field; the rest keep defaults.
    #[test]
    fn test_batch_config_builder() {
        let config = BatchConfig::default()
            .with_batch_size(500)
            .with_total(10000)
            .with_validation(false)
            .continue_on_error(true);
        assert_eq!(config.batch_size, 500);
        assert_eq!(config.total_documents, Some(10000));
        assert!(!config.validate_dimensions);
        assert!(config.continue_on_error);
    }

    // 25 docs at batch size 10 -> 2 full batches + 1 trailing partial = 3.
    #[test]
    fn test_batch_builder_basic() {
        let mut index = HNSWIndex::with_defaults(128);
        let config = BatchConfig::default().with_batch_size(10);
        let mut builder = BatchBuilder::new(&mut index, config);
        for i in 0..25 {
            let doc = create_test_document(&format!("doc{}", i), generate_random_vector(128, i));
            builder.add(doc).unwrap();
        }
        let result = builder.finish();
        assert_eq!(result.documents_indexed, 25);
        assert!(!result.has_errors());
        assert_eq!(result.batches_processed, 3);
        assert_eq!(index.len(), 25);
    }

    // 35 docs at batch size 10 -> callbacks at 10/20/30 plus finish() = 4.
    #[test]
    fn test_batch_builder_with_progress() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let progress_count = Arc::new(AtomicUsize::new(0));
        let progress_count_clone = progress_count.clone();
        let mut index = HNSWIndex::with_defaults(128);
        let config = BatchConfig::default()
            .with_batch_size(10)
            .with_progress(move |_p| {
                progress_count_clone.fetch_add(1, Ordering::SeqCst);
            });
        let mut builder = BatchBuilder::new(&mut index, config);
        for i in 0..35 {
            let doc = create_test_document(&format!("doc{}", i), generate_random_vector(128, i));
            builder.add(doc).unwrap();
        }
        let _result = builder.finish();
        assert_eq!(progress_count.load(Ordering::SeqCst), 4);
    }

    // Default config fails fast on a 64-dim doc against a 128-dim index.
    #[test]
    fn test_batch_builder_dimension_error() {
        let mut index = HNSWIndex::with_defaults(128);
        let config = BatchConfig::default();
        let mut builder = BatchBuilder::new(&mut index, config);
        let doc = create_test_document("doc1", generate_random_vector(128, 1));
        assert!(builder.add(doc).is_ok());
        let doc = create_test_document("doc2", generate_random_vector(64, 2));
        assert!(builder.add(doc).is_err());
    }

    // In continue-on-error mode, the bad doc is recorded but doesn't abort.
    #[test]
    fn test_batch_builder_continue_on_error() {
        let mut index = HNSWIndex::with_defaults(128);
        let config = BatchConfig::default().continue_on_error(true);
        let mut builder = BatchBuilder::new(&mut index, config);
        let doc = create_test_document("doc1", generate_random_vector(128, 1));
        builder.add(doc).unwrap();
        let doc = create_test_document("doc2", generate_random_vector(64, 2));
        builder.add(doc).unwrap();
        let doc = create_test_document("doc3", generate_random_vector(128, 3));
        builder.add(doc).unwrap();
        let result = builder.finish();
        assert_eq!(result.documents_indexed, 2);
        assert!(result.has_errors());
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].0, "doc2");
    }

    // 5000 remaining at 1000 docs/sec -> 5000 ms ETA; 5000/10000 -> 50%.
    #[test]
    fn test_batch_progress_eta() {
        let progress = BatchProgress {
            completed: 5000,
            total: Some(10000),
            batch_number: 5,
            batch_size: 1000,
            elapsed_ms: 5000,
            docs_per_sec: 1000.0,
        };
        assert_eq!(progress.percent(), Some(50.0));
        assert_eq!(progress.eta_ms(), Some(5000));
    }

    // peek doesn't consume; collect_remaining drains what next() didn't.
    #[test]
    fn test_search_result_iterator() {
        let results = vec![
            SearchResult {
                id: "doc1".to_string(),
                content: "Content 1".to_string(),
                score: 0.9,
                metadata: None,
            },
            SearchResult {
                id: "doc2".to_string(),
                content: "Content 2".to_string(),
                score: 0.8,
                metadata: None,
            },
            SearchResult {
                id: "doc3".to_string(),
                content: "Content 3".to_string(),
                score: 0.7,
                metadata: None,
            },
        ];
        let mut iter = SearchResultIterator::new(results);
        assert_eq!(iter.total(), 3);
        assert_eq!(iter.peek().unwrap().id, "doc1");
        let first = iter.next().unwrap();
        assert_eq!(first.id, "doc1");
        let second = iter.next().unwrap();
        assert_eq!(second.id, "doc2");
        let remaining = iter.collect_remaining();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, "doc3");
    }

    // 25 results at 10/page -> pages of 10, 10, 5 with correct nav flags.
    #[test]
    fn test_search_page() {
        let results: Vec<SearchResult> = (0..25)
            .map(|i| SearchResult {
                id: format!("doc{}", i),
                content: format!("Content {}", i),
                score: 1.0 - (i as f32 * 0.01),
                metadata: None,
            })
            .collect();
        let page0 = SearchPage::from_results(results.clone(), 0, 10);
        assert_eq!(page0.results.len(), 10);
        assert_eq!(page0.page, 0);
        assert_eq!(page0.total_pages, 3);
        assert!(page0.has_next);
        assert!(!page0.has_prev);
        let page1 = SearchPage::from_results(results.clone(), 1, 10);
        assert_eq!(page1.results.len(), 10);
        assert!(page1.has_next);
        assert!(page1.has_prev);
        let page2 = SearchPage::from_results(results.clone(), 2, 10);
        assert_eq!(page2.results.len(), 5);
        assert!(!page2.has_next);
        assert!(page2.has_prev);
    }

    // Exercises score floor, metadata presence/equality, and the result cap.
    #[test]
    fn test_filtered_search_builder() {
        let results: Vec<SearchResult> = vec![
            SearchResult {
                id: "doc1".to_string(),
                content: "Content 1".to_string(),
                score: 0.9,
                metadata: Some(serde_json::json!({"category": "A"})),
            },
            SearchResult {
                id: "doc2".to_string(),
                content: "Content 2".to_string(),
                score: 0.5,
                metadata: Some(serde_json::json!({"category": "B"})),
            },
            SearchResult {
                id: "doc3".to_string(),
                content: "Content 3".to_string(),
                score: 0.3,
                metadata: None,
            },
        ];
        let filtered = FilteredSearchBuilder::new()
            .min_score(0.4)
            .apply(results.clone());
        assert_eq!(filtered.len(), 2);
        let filtered = FilteredSearchBuilder::new()
            .has_metadata_field("category")
            .apply(results.clone());
        assert_eq!(filtered.len(), 2);
        let filtered = FilteredSearchBuilder::new()
            .metadata_equals("category", serde_json::json!("A"))
            .apply(results.clone());
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].id, "doc1");
        let filtered = FilteredSearchBuilder::new()
            .min_score(0.4)
            .max_results(1)
            .apply(results);
        assert_eq!(filtered.len(), 1);
    }
}