use crate::config::ShardexConfig;
use crate::error::ShardexError;
use crate::identifiers::DocumentId;
use crate::structures::Posting;
use std::path::PathBuf;
/// Parameters for creating a new Shardex index.
///
/// Prefer [`CreateIndexParams::builder`] or one of the presets
/// (`high_performance`, `memory_optimized`) over filling this in by hand,
/// and call `validate()` before use.
#[derive(Debug, Clone)]
pub struct CreateIndexParams {
    /// Directory where the index files will be stored.
    pub directory_path: PathBuf,
    /// Dimensionality of the embedding vectors stored in the index.
    pub vector_size: usize,
    /// Per-shard capacity knob; controls memory usage (see `validate`).
    pub shard_size: usize,
    /// Interval between batched WAL writes, in milliseconds.
    pub batch_write_interval_ms: u64,
    /// Size of each write-ahead-log segment, in bytes (min 1024).
    pub wal_segment_size: usize,
    /// Bloom filter sizing parameter (memory vs false-positive trade-off).
    pub bloom_filter_size: usize,
    /// Default slop factor controlling search breadth.
    pub default_slop_factor: u32,
}
impl CreateIndexParams {
    /// Returns a builder for assembling parameters field by field.
    pub fn builder() -> CreateIndexParamsBuilder {
        CreateIndexParamsBuilder::default()
    }

    /// Preset favouring throughput: bigger shards, WAL segments and bloom
    /// filters plus a shorter batch interval and wider slop factor.
    pub fn high_performance(directory: PathBuf) -> Self {
        Self {
            directory_path: directory,
            vector_size: 256,
            shard_size: 15000,
            batch_write_interval_ms: 75,
            wal_segment_size: 2 * 1024 * 1024,
            bloom_filter_size: 2048,
            default_slop_factor: 4,
        }
    }

    /// Preset favouring a small footprint: smaller shards, WAL segments and
    /// bloom filters with a longer batch interval and narrower slop factor.
    pub fn memory_optimized(directory: PathBuf) -> Self {
        Self {
            directory_path: directory,
            vector_size: 128,
            shard_size: 5000,
            batch_write_interval_ms: 200,
            wal_segment_size: 256 * 1024,
            bloom_filter_size: 512,
            default_slop_factor: 2,
        }
    }

    /// Convenience alias for the `From<ShardexConfig>` conversion.
    pub fn from_shardex_config(config: ShardexConfig) -> Self {
        Self::from(config)
    }

    /// Checks every field against its supported range and reports the first
    /// violation found (checks run in field-declaration order, path last).
    pub fn validate(&self) -> Result<(), ShardexError> {
        match self.vector_size {
            0 => {
                return Err(ShardexError::config_error(
                    "vector_size",
                    "must be greater than 0",
                    "Set vector_size to match your embedding model dimensions (e.g., 384, 768, 1536)",
                ));
            }
            n if n > 100_000 => {
                return Err(ShardexError::config_error(
                    "vector_size",
                    "exceeds maximum supported size",
                    "Use vector_size <= 100,000 dimensions for optimal performance",
                ));
            }
            _ => {}
        }
        match self.shard_size {
            0 => {
                return Err(ShardexError::config_error(
                    "shard_size",
                    "must be greater than 0",
                    "Set shard_size to control memory usage (recommended: 10,000-100,000)",
                ));
            }
            n if n > 10_000_000 => {
                return Err(ShardexError::config_error(
                    "shard_size",
                    "exceeds recommended maximum",
                    "Use shard_size <= 10,000,000 to avoid excessive memory usage",
                ));
            }
            _ => {}
        }
        if self.batch_write_interval_ms == 0 {
            return Err(ShardexError::config_error(
                "batch_write_interval_ms",
                "must be greater than 0",
                "Set batch_write_interval_ms to balance throughput and latency (recommended: 50-500ms)",
            ));
        }
        if self.wal_segment_size < 1024 {
            return Err(ShardexError::config_error(
                "wal_segment_size",
                "must be at least 1KB",
                "Set wal_segment_size to at least 1024 bytes for proper WAL operation",
            ));
        }
        if self.bloom_filter_size == 0 {
            return Err(ShardexError::config_error(
                "bloom_filter_size",
                "must be greater than 0",
                "Set bloom_filter_size to optimize memory vs false positive trade-off (recommended: 1024-65536)",
            ));
        }
        match self.default_slop_factor {
            0 => {
                return Err(ShardexError::config_error(
                    "default_slop_factor",
                    "must be greater than 0",
                    "Set default_slop_factor to control search breadth (recommended: 3-10)",
                ));
            }
            n if n > 1000 => {
                return Err(ShardexError::config_error(
                    "default_slop_factor",
                    "exceeds reasonable maximum",
                    "Use default_slop_factor <= 1000 to maintain search performance",
                ));
            }
            _ => {}
        }
        if self.directory_path.as_os_str().is_empty() {
            return Err(ShardexError::config_error(
                "directory_path",
                "cannot be empty",
                "Provide a valid directory path where the index will be stored",
            ));
        }
        Ok(())
    }
}
impl From<ShardexConfig> for CreateIndexParams {
    /// Flattens a `ShardexConfig` into creation parameters; the slop factor
    /// is lifted out of the nested slop-factor configuration.
    fn from(config: ShardexConfig) -> Self {
        let default_slop_factor = config.slop_factor_config.default_factor as u32;
        Self {
            directory_path: config.directory_path,
            vector_size: config.vector_size,
            shard_size: config.shard_size,
            batch_write_interval_ms: config.batch_write_interval_ms,
            wal_segment_size: config.wal_segment_size,
            bloom_filter_size: config.bloom_filter_size,
            default_slop_factor,
        }
    }
}
/// Builder for [`CreateIndexParams`]; unset fields fall back to defaults
/// when `build()` is called.
#[derive(Debug, Clone, Default)]
pub struct CreateIndexParamsBuilder {
    // Each field mirrors CreateIndexParams; None means "use the default".
    directory_path: Option<PathBuf>,
    vector_size: Option<usize>,
    shard_size: Option<usize>,
    batch_write_interval_ms: Option<u64>,
    wal_segment_size: Option<usize>,
    bloom_filter_size: Option<usize>,
    default_slop_factor: Option<u32>,
}
impl CreateIndexParamsBuilder {
    /// Sets the directory where the index will live.
    pub fn directory_path(mut self, path: PathBuf) -> Self {
        self.directory_path = Some(path);
        self
    }

    /// Sets the embedding dimensionality.
    pub fn vector_size(mut self, size: usize) -> Self {
        self.vector_size = Some(size);
        self
    }

    /// Sets the per-shard capacity.
    pub fn shard_size(mut self, size: usize) -> Self {
        self.shard_size = Some(size);
        self
    }

    /// Sets the batched WAL write interval in milliseconds.
    pub fn batch_write_interval_ms(mut self, ms: u64) -> Self {
        self.batch_write_interval_ms = Some(ms);
        self
    }

    /// Sets the WAL segment size in bytes.
    pub fn wal_segment_size(mut self, size: usize) -> Self {
        self.wal_segment_size = Some(size);
        self
    }

    /// Sets the bloom filter sizing parameter.
    pub fn bloom_filter_size(mut self, size: usize) -> Self {
        self.bloom_filter_size = Some(size);
        self
    }

    /// Sets the default slop factor used by searches.
    pub fn default_slop_factor(mut self, factor: u32) -> Self {
        self.default_slop_factor = Some(factor);
        self
    }

    /// Fills unset fields with defaults, then validates the assembled params.
    pub fn build(self) -> Result<CreateIndexParams, ShardexError> {
        let directory_path = self
            .directory_path
            .unwrap_or_else(|| PathBuf::from("./shardex_index"));
        let params = CreateIndexParams {
            directory_path,
            vector_size: self.vector_size.unwrap_or(384),
            shard_size: self.shard_size.unwrap_or(10000),
            batch_write_interval_ms: self.batch_write_interval_ms.unwrap_or(100),
            wal_segment_size: self.wal_segment_size.unwrap_or(1024 * 1024),
            bloom_filter_size: self.bloom_filter_size.unwrap_or(1024),
            default_slop_factor: self.default_slop_factor.unwrap_or(3),
        };
        params.validate()?;
        Ok(params)
    }
}
/// Parameters for adding a batch of postings to the index.
#[derive(Debug, Clone)]
pub struct AddPostingsParams {
    /// Postings to add; must be non-empty with uniform, finite vectors.
    pub postings: Vec<Posting>,
}
impl AddPostingsParams {
    /// Wraps a batch of postings, rejecting an empty batch up front.
    pub fn new(postings: Vec<Posting>) -> Result<Self, ShardexError> {
        if postings.is_empty() {
            return Err(ShardexError::config_error(
                "postings",
                "cannot be empty",
                "Provide at least one posting to add to the index",
            ));
        }
        Ok(Self { postings })
    }

    /// Validates that the batch is non-empty, that every vector matches the
    /// first posting's dimensionality, and that all components are finite.
    pub fn validate(&self) -> Result<(), ShardexError> {
        if self.postings.is_empty() {
            return Err(ShardexError::config_error(
                "postings",
                "cannot be empty",
                "Provide at least one posting to add to the index",
            ));
        }
        if let [first, ..] = self.postings.as_slice() {
            let expected_dim = first.vector.len();
            for (idx, posting) in self.postings.iter().enumerate() {
                let dim = posting.vector.len();
                if dim != expected_dim {
                    return Err(ShardexError::config_error(
                        format!("postings[{}].vector", idx),
                        format!("dimension mismatch: expected {}, got {}", expected_dim, dim),
                        "Ensure all postings have vectors with the same dimensions",
                    ));
                }
                for (component, &value) in posting.vector.iter().enumerate() {
                    if !value.is_finite() {
                        return Err(ShardexError::config_error(
                            format!("postings[{}].vector[{}]", idx, component),
                            "contains non-finite value",
                            "Ensure all vector components are finite numbers (not NaN or infinity)",
                        ));
                    }
                }
            }
        }
        Ok(())
    }

    /// Number of postings in the batch.
    pub fn len(&self) -> usize {
        self.postings.len()
    }

    /// True when the batch holds no postings (unreachable via `new`).
    pub fn is_empty(&self) -> bool {
        self.postings.is_empty()
    }
}
/// Parameters for a k-nearest-neighbour search.
#[derive(Debug, Clone)]
pub struct SearchParams {
    /// Query embedding; must be non-empty with finite components.
    pub query_vector: Vec<f32>,
    /// Number of results to return (1..=10_000).
    pub k: usize,
    /// Optional per-query slop-factor override; `None` uses the index default.
    pub slop_factor: Option<u32>,
}
impl SearchParams {
    /// Returns a builder for assembling search parameters.
    pub fn builder() -> SearchParamsBuilder {
        SearchParamsBuilder::default()
    }

    /// Validated search using the index's default slop factor.
    pub fn new(query_vector: Vec<f32>, k: usize) -> Result<Self, ShardexError> {
        let params = Self {
            query_vector,
            k,
            slop_factor: None,
        };
        params.validate()?;
        Ok(params)
    }

    /// Validated search with an explicit slop-factor override.
    pub fn with_slop_factor(query_vector: Vec<f32>, k: usize, slop_factor: u32) -> Result<Self, ShardexError> {
        let params = Self {
            query_vector,
            k,
            slop_factor: Some(slop_factor),
        };
        params.validate()?;
        Ok(params)
    }

    /// Checks the query vector (non-empty, finite), `k` (1..=10_000) and the
    /// optional slop factor (1..=1000), reporting the first violation.
    pub fn validate(&self) -> Result<(), ShardexError> {
        if self.query_vector.is_empty() {
            return Err(ShardexError::config_error(
                "query_vector",
                "cannot be empty",
                "Provide a query vector with at least one dimension",
            ));
        }
        // First non-finite component (if any), reported by index.
        if let Some(bad) = self.query_vector.iter().position(|v| !v.is_finite()) {
            return Err(ShardexError::config_error(
                format!("query_vector[{}]", bad),
                "contains non-finite value",
                "Ensure all query vector components are finite numbers (not NaN or infinity)",
            ));
        }
        match self.k {
            0 => {
                return Err(ShardexError::config_error(
                    "k",
                    "must be greater than 0",
                    "Set k to the number of top results you want to retrieve",
                ));
            }
            k if k > 10_000 => {
                return Err(ShardexError::config_error(
                    "k",
                    "exceeds reasonable maximum",
                    "Use k <= 10,000 to maintain reasonable response times and memory usage",
                ));
            }
            _ => {}
        }
        match self.slop_factor {
            Some(0) => Err(ShardexError::config_error(
                "slop_factor",
                "must be greater than 0 when provided",
                "Set slop_factor to a positive value or use None for default",
            )),
            Some(slop) if slop > 1000 => Err(ShardexError::config_error(
                "slop_factor",
                "exceeds reasonable maximum",
                "Use slop_factor <= 1000 to maintain search performance",
            )),
            _ => Ok(()),
        }
    }

    /// Dimensionality of the query vector.
    pub fn vector_dimension(&self) -> usize {
        self.query_vector.len()
    }
}
/// Builder for [`SearchParams`]; `query_vector` and `k` are mandatory.
#[derive(Debug, Clone, Default)]
pub struct SearchParamsBuilder {
    query_vector: Option<Vec<f32>>,
    k: Option<usize>,
    slop_factor: Option<u32>,
}
impl SearchParamsBuilder {
    /// Sets the query embedding (required).
    pub fn query_vector(mut self, vector: Vec<f32>) -> Self {
        self.query_vector = Some(vector);
        self
    }

    /// Sets the result count (required).
    pub fn k(mut self, k: usize) -> Self {
        self.k = Some(k);
        self
    }

    /// Sets or clears the slop-factor override.
    pub fn slop_factor(mut self, factor: Option<u32>) -> Self {
        self.slop_factor = factor;
        self
    }

    /// Checks the required fields, assembles the params, and validates them.
    pub fn build(self) -> Result<SearchParams, ShardexError> {
        let query_vector = match self.query_vector {
            Some(vector) => vector,
            None => {
                return Err(ShardexError::config_error(
                    "query_vector",
                    "is required",
                    "Provide a query vector using query_vector()",
                ));
            }
        };
        let k = match self.k {
            Some(k) => k,
            None => {
                return Err(ShardexError::config_error("k", "is required", "Provide the number of results using k()"));
            }
        };
        let params = SearchParams {
            query_vector,
            k,
            slop_factor: self.slop_factor,
        };
        params.validate()?;
        Ok(params)
    }
}
/// Parameters for flushing pending writes.
#[derive(Debug, Clone, Default)]
pub struct FlushParams {
    /// When true, the flush also collects statistics.
    pub with_stats: bool,
}
impl FlushParams {
    /// Flush without statistics collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Flush that also gathers statistics.
    pub fn with_stats() -> Self {
        Self { with_stats: true }
    }

    /// Chainable setter for the statistics flag.
    pub fn set_with_stats(mut self, with_stats: bool) -> Self {
        self.with_stats = with_stats;
        self
    }

    /// Always succeeds; a flush request has no invalid states.
    pub fn validate(&self) -> Result<(), ShardexError> {
        Ok(())
    }
}
/// Parameters for requesting index statistics (no configuration needed).
#[derive(Debug, Clone, Default)]
pub struct GetStatsParams;
impl GetStatsParams {
    /// Creates the (stateless) statistics-request parameters.
    pub fn new() -> Self {
        Self
    }

    /// Always succeeds; there is nothing to validate on a unit struct.
    pub fn validate(&self) -> Result<(), ShardexError> {
        Ok(())
    }
}
/// Parameters for opening an existing index.
#[derive(Debug, Clone)]
pub struct OpenIndexParams {
    /// Directory where the existing index is stored; must be non-empty.
    pub directory_path: PathBuf,
}
impl OpenIndexParams {
    /// Targets the index stored under `directory_path`.
    pub fn new(directory_path: PathBuf) -> Self {
        Self { directory_path }
    }

    /// Rejects an empty directory path; any non-empty path is accepted.
    pub fn validate(&self) -> Result<(), ShardexError> {
        if !self.directory_path.as_os_str().is_empty() {
            return Ok(());
        }
        Err(ShardexError::config_error(
            "directory_path",
            "cannot be empty",
            "Provide a valid directory path where the existing index is stored",
        ))
    }
}
/// Parameters wrapping a [`ShardexConfig`] for a validate-config operation.
#[derive(Debug, Clone)]
pub struct ValidateConfigParams {
    /// The configuration to be validated.
    pub config: ShardexConfig,
}
impl ValidateConfigParams {
    /// Wraps a configuration to be handed to the validation operation.
    pub fn new(config: ShardexConfig) -> Self {
        Self { config }
    }

    /// Always succeeds; no checks are performed on the wrapped config here.
    pub fn validate(&self) -> Result<(), ShardexError> {
        Ok(())
    }

    /// Borrows the wrapped configuration.
    pub fn get_config(&self) -> &ShardexConfig {
        &self.config
    }
}
/// Parameters for a batch posting insert with optional flush and tracking.
#[derive(Debug, Clone)]
pub struct BatchAddPostingsParams {
    /// Postings to add; must be non-empty with uniform, finite vectors.
    pub postings: Vec<Posting>,
    /// Flush the index immediately after applying the batch.
    pub flush_immediately: bool,
    /// Collect performance metrics for this batch.
    pub track_performance: bool,
}
impl BatchAddPostingsParams {
    /// Builds batch-add parameters; rejects an empty batch up front.
    pub fn new(postings: Vec<Posting>, flush_immediately: bool, track_performance: bool) -> Result<Self, ShardexError> {
        if postings.is_empty() {
            return Err(ShardexError::config_error(
                "postings",
                "cannot be empty",
                "Provide at least one posting to add to the index",
            ));
        }
        Ok(Self {
            postings,
            flush_immediately,
            track_performance,
        })
    }

    /// No immediate flush, no performance tracking.
    pub fn simple(postings: Vec<Posting>) -> Result<Self, ShardexError> {
        Self::new(postings, false, false)
    }

    /// Flush immediately, no performance tracking.
    pub fn with_immediate_flush(postings: Vec<Posting>) -> Result<Self, ShardexError> {
        Self::new(postings, true, false)
    }

    /// Performance tracking only, no immediate flush.
    pub fn with_performance_tracking(postings: Vec<Posting>) -> Result<Self, ShardexError> {
        Self::new(postings, false, true)
    }

    /// Both immediate flush and performance tracking.
    pub fn with_flush_and_tracking(postings: Vec<Posting>) -> Result<Self, ShardexError> {
        Self::new(postings, true, true)
    }

    /// Validates that the batch is non-empty, dimensionally uniform, and
    /// contains only finite vector components.
    pub fn validate(&self) -> Result<(), ShardexError> {
        if self.postings.is_empty() {
            return Err(ShardexError::config_error(
                "postings",
                "cannot be empty",
                "Provide at least one posting to add to the index",
            ));
        }
        if let [first, ..] = self.postings.as_slice() {
            let expected_dim = first.vector.len();
            for (idx, posting) in self.postings.iter().enumerate() {
                let dim = posting.vector.len();
                if dim != expected_dim {
                    return Err(ShardexError::config_error(
                        format!("postings[{}].vector", idx),
                        format!("dimension mismatch: expected {}, got {}", expected_dim, dim),
                        "Ensure all postings have vectors with the same dimensions",
                    ));
                }
                for (component, &value) in posting.vector.iter().enumerate() {
                    if !value.is_finite() {
                        return Err(ShardexError::config_error(
                            format!("postings[{}].vector[{}]", idx, component),
                            "contains non-finite value",
                            "Ensure all vector components are finite numbers (not NaN or infinity)",
                        ));
                    }
                }
            }
        }
        Ok(())
    }

    /// Number of postings in the batch.
    pub fn len(&self) -> usize {
        self.postings.len()
    }

    /// True when the batch is empty (unreachable via the constructors).
    pub fn is_empty(&self) -> bool {
        self.postings.is_empty()
    }
}
/// Parameters for fetching performance statistics.
#[derive(Debug, Clone, Default)]
pub struct GetPerformanceStatsParams {
    /// When true, include the detailed statistics breakdown.
    pub include_detailed: bool,
}
impl GetPerformanceStatsParams {
    /// Summary statistics only.
    pub fn new() -> Self {
        Self::default()
    }

    /// Requests the detailed statistics breakdown.
    pub fn detailed() -> Self {
        Self { include_detailed: true }
    }

    /// Chainable setter for the detail flag.
    pub fn set_include_detailed(mut self, include_detailed: bool) -> Self {
        self.include_detailed = include_detailed;
        self
    }

    /// Always succeeds; a stats request has no invalid states.
    pub fn validate(&self) -> Result<(), ShardexError> {
        Ok(())
    }
}
/// Parameters for an incremental posting add, optionally tagged with a
/// batch identifier.
#[derive(Debug, Clone)]
pub struct IncrementalAddParams {
    /// Postings to add; must be non-empty with uniform, finite vectors.
    pub postings: Vec<Posting>,
    /// Optional batch tag; must be non-blank when present.
    pub batch_id: Option<String>,
}
impl IncrementalAddParams {
pub fn new(postings: Vec<Posting>, batch_id: Option<String>) -> Result<Self, ShardexError> {
if postings.is_empty() {
return Err(ShardexError::config_error(
"postings",
"cannot be empty",
"Provide at least one posting to add to the index",
));
}
Ok(Self { postings, batch_id })
}
pub fn simple(postings: Vec<Posting>) -> Result<Self, ShardexError> {
Self::new(postings, None)
}
pub fn with_batch_id(postings: Vec<Posting>, batch_id: String) -> Result<Self, ShardexError> {
Self::new(postings, Some(batch_id))
}
pub fn validate(&self) -> Result<(), ShardexError> {
if self.postings.is_empty() {
return Err(ShardexError::config_error(
"postings",
"cannot be empty",
"Provide at least one posting to add to the index",
));
}
if let Some(first) = self.postings.first() {
let expected_dim = first.vector.len();
for (i, posting) in self.postings.iter().enumerate() {
if posting.vector.len() != expected_dim {
return Err(ShardexError::config_error(
format!("postings[{}].vector", i),
format!(
"dimension mismatch: expected {}, got {}",
expected_dim,
posting.vector.len()
),
"Ensure all postings have vectors with the same dimensions",
));
}
for (j, &value) in posting.vector.iter().enumerate() {
if !value.is_finite() {
return Err(ShardexError::config_error(
format!("postings[{}].vector[{}]", i, j),
"contains non-finite value",
"Ensure all vector components are finite numbers (not NaN or infinity)",
));
}
}
}
}
if let Some(ref batch_id) = self.batch_id {
if batch_id.trim().is_empty() {
return Err(ShardexError::config_error(
"batch_id",
"cannot be empty when provided",
"Provide a non-empty batch identifier or use None",
));
}
}
Ok(())
}
pub fn len(&self) -> usize {
self.postings.len()
}
pub fn is_empty(&self) -> bool {
self.postings.is_empty()
}
}
/// Parameters for removing documents by raw id.
#[derive(Debug, Clone)]
pub struct RemoveDocumentsParams {
    /// Raw document IDs to remove; must be non-empty.
    pub document_ids: Vec<u128>,
}
impl RemoveDocumentsParams {
    /// Builds removal parameters; at least one document id is required.
    pub fn new(document_ids: Vec<u128>) -> Result<Self, ShardexError> {
        if document_ids.is_empty() {
            return Err(ShardexError::config_error(
                "document_ids",
                "cannot be empty",
                "Provide at least one document ID to remove",
            ));
        }
        Ok(Self { document_ids })
    }

    /// Re-checks the non-empty invariant.
    pub fn validate(&self) -> Result<(), ShardexError> {
        if !self.document_ids.is_empty() {
            return Ok(());
        }
        Err(ShardexError::config_error(
            "document_ids",
            "cannot be empty",
            "Provide at least one document ID to remove",
        ))
    }

    /// Number of document ids scheduled for removal.
    pub fn len(&self) -> usize {
        self.document_ids.len()
    }

    /// True when no ids are present (unreachable via `new`).
    pub fn is_empty(&self) -> bool {
        self.document_ids.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identifiers::DocumentId;

    // --- CreateIndexParams: builder, presets, config conversion ---

    #[test]
    fn test_create_index_params_builder() {
        let params = CreateIndexParams::builder()
            .directory_path(PathBuf::from("./test"))
            .vector_size(768)
            .shard_size(50000)
            .build()
            .unwrap();
        assert_eq!(params.directory_path, PathBuf::from("./test"));
        assert_eq!(params.vector_size, 768);
        assert_eq!(params.shard_size, 50000);
    }

    #[test]
    fn test_create_index_params_validation_vector_size_zero() {
        // build() runs validate(), so a zero vector_size must fail.
        let params = CreateIndexParams::builder().vector_size(0).build();
        assert!(params.is_err());
    }

    #[test]
    fn test_create_index_params_from_config() {
        let config = ShardexConfig::new().vector_size(512).shard_size(25000);
        let params: CreateIndexParams = config.into();
        assert_eq!(params.vector_size, 512);
        assert_eq!(params.shard_size, 25000);
    }

    // --- AddPostingsParams: emptiness and dimension checks ---

    #[test]
    fn test_add_postings_params_empty() {
        let result = AddPostingsParams::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_add_postings_params_valid() {
        let posting = Posting {
            document_id: DocumentId::from_raw(1),
            start: 0,
            length: 100,
            vector: vec![0.1, 0.2, 0.3],
        };
        let params = AddPostingsParams::new(vec![posting]).unwrap();
        assert_eq!(params.len(), 1);
        assert!(!params.is_empty());
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_add_postings_params_dimension_mismatch() {
        let posting1 = Posting {
            document_id: DocumentId::from_raw(1),
            start: 0,
            length: 100,
            vector: vec![0.1, 0.2, 0.3],
        };
        let posting2 = Posting {
            document_id: DocumentId::from_raw(2),
            start: 0,
            length: 100,
            // 2-dimensional on purpose: trips the uniform-dimension check.
            vector: vec![0.1, 0.2],
        };
        let params = AddPostingsParams::new(vec![posting1, posting2]).unwrap();
        assert!(params.validate().is_err());
    }

    // --- SearchParams ---

    #[test]
    fn test_search_params_builder() {
        let params = SearchParams::builder()
            .query_vector(vec![0.1, 0.2, 0.3])
            .k(10)
            .slop_factor(Some(5))
            .build()
            .unwrap();
        assert_eq!(params.vector_dimension(), 3);
        assert_eq!(params.k, 10);
        assert_eq!(params.slop_factor, Some(5));
    }

    #[test]
    fn test_search_params_validation_empty_vector() {
        let result = SearchParams::new(vec![], 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_search_params_validation_k_zero() {
        let result = SearchParams::new(vec![0.1, 0.2, 0.3], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_search_params_with_slop_factor() {
        let params = SearchParams::with_slop_factor(vec![0.1, 0.2, 0.3], 10, 5).unwrap();
        assert_eq!(params.slop_factor, Some(5));
        assert!(params.validate().is_ok());
    }

    // --- Flush / stats parameter objects ---

    #[test]
    fn test_flush_params() {
        let params1 = FlushParams::new();
        assert!(!params1.with_stats);
        assert!(params1.validate().is_ok());
        let params2 = FlushParams::with_stats();
        assert!(params2.with_stats);
        assert!(params2.validate().is_ok());
    }

    #[test]
    fn test_get_stats_params() {
        let params = GetStatsParams::new();
        assert!(params.validate().is_ok());
    }

    // --- CreateIndexParams presets pin their exact default values ---

    #[test]
    fn test_create_index_params_high_performance() {
        let params = CreateIndexParams::high_performance(PathBuf::from("./high_perf"));
        assert_eq!(params.directory_path, PathBuf::from("./high_perf"));
        assert_eq!(params.vector_size, 256);
        assert_eq!(params.shard_size, 15000);
        assert_eq!(params.batch_write_interval_ms, 75);
        assert_eq!(params.wal_segment_size, 2 * 1024 * 1024);
        assert_eq!(params.bloom_filter_size, 2048);
        assert_eq!(params.default_slop_factor, 4);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_create_index_params_memory_optimized() {
        let params = CreateIndexParams::memory_optimized(PathBuf::from("./memory_opt"));
        assert_eq!(params.directory_path, PathBuf::from("./memory_opt"));
        assert_eq!(params.vector_size, 128);
        assert_eq!(params.shard_size, 5000);
        assert_eq!(params.batch_write_interval_ms, 200);
        assert_eq!(params.wal_segment_size, 256 * 1024);
        assert_eq!(params.bloom_filter_size, 512);
        assert_eq!(params.default_slop_factor, 2);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_create_index_params_from_shardex_config() {
        let config = ShardexConfig::new()
            .directory_path("./config_test")
            .vector_size(512)
            .shard_size(25000);
        let params = CreateIndexParams::from_shardex_config(config);
        assert_eq!(params.directory_path, PathBuf::from("./config_test"));
        assert_eq!(params.vector_size, 512);
        assert_eq!(params.shard_size, 25000);
        assert!(params.validate().is_ok());
    }

    // --- Open / validate-config parameter objects ---

    #[test]
    fn test_open_index_params() {
        let params = OpenIndexParams::new(PathBuf::from("./existing_index"));
        assert_eq!(params.directory_path, PathBuf::from("./existing_index"));
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_open_index_params_empty_path() {
        let params = OpenIndexParams::new(PathBuf::from(""));
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validate_config_params() {
        let config = ShardexConfig::new().vector_size(384).shard_size(10000);
        let params = ValidateConfigParams::new(config);
        assert_eq!(params.get_config().vector_size, 384);
        assert_eq!(params.get_config().shard_size, 10000);
        assert!(params.validate().is_ok());
    }

    // --- Batch / incremental add parameter objects ---

    #[test]
    fn test_batch_add_postings_params() {
        let posting = Posting {
            document_id: DocumentId::from_raw(1),
            start: 0,
            length: 100,
            vector: vec![0.1, 0.2, 0.3],
        };
        let params = BatchAddPostingsParams::simple(vec![posting]).unwrap();
        assert_eq!(params.len(), 1);
        assert!(!params.is_empty());
        assert!(!params.flush_immediately);
        assert!(!params.track_performance);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_batch_add_postings_params_with_options() {
        let posting = Posting {
            document_id: DocumentId::from_raw(1),
            start: 0,
            length: 100,
            vector: vec![0.1, 0.2, 0.3],
        };
        let params = BatchAddPostingsParams::with_flush_and_tracking(vec![posting]).unwrap();
        assert!(params.flush_immediately);
        assert!(params.track_performance);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_batch_add_postings_params_empty() {
        let result = BatchAddPostingsParams::simple(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_performance_stats_params() {
        let params1 = GetPerformanceStatsParams::new();
        assert!(!params1.include_detailed);
        assert!(params1.validate().is_ok());
        let params2 = GetPerformanceStatsParams::detailed();
        assert!(params2.include_detailed);
        assert!(params2.validate().is_ok());
    }

    #[test]
    fn test_incremental_add_params() {
        let posting = Posting {
            document_id: DocumentId::from_raw(1),
            start: 0,
            length: 100,
            vector: vec![0.1, 0.2, 0.3],
        };
        let params1 = IncrementalAddParams::simple(vec![posting.clone()]).unwrap();
        assert_eq!(params1.len(), 1);
        assert!(!params1.is_empty());
        assert!(params1.batch_id.is_none());
        assert!(params1.validate().is_ok());
        let params2 = IncrementalAddParams::with_batch_id(vec![posting], "batch_123".to_string()).unwrap();
        assert_eq!(params2.batch_id.as_ref().unwrap(), "batch_123");
        assert!(params2.validate().is_ok());
    }

    #[test]
    fn test_incremental_add_params_empty_batch_id() {
        let posting = Posting {
            document_id: DocumentId::from_raw(1),
            start: 0,
            length: 100,
            vector: vec![0.1, 0.2, 0.3],
        };
        // A whitespace-only batch id passes construction but fails validate().
        let params = IncrementalAddParams::new(vec![posting], Some(" ".to_string()));
        assert!(params.is_ok());
        assert!(params.unwrap().validate().is_err());
    }

    // --- RemoveDocumentsParams ---

    #[test]
    fn test_remove_documents_params() {
        let params = RemoveDocumentsParams::new(vec![1, 2, 3]).unwrap();
        assert_eq!(params.len(), 3);
        assert!(!params.is_empty());
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_remove_documents_params_empty() {
        let result = RemoveDocumentsParams::new(vec![]);
        assert!(result.is_err());
    }
}
/// Parameters for storing a document's full text together with its postings.
#[derive(Debug, Clone)]
pub struct StoreDocumentTextParams {
    /// Document the text and postings belong to.
    pub document_id: DocumentId,
    /// Full document text; must be non-empty.
    pub text: String,
    /// Postings for this document; ranges must stay within `text`.
    pub postings: Vec<Posting>,
}
impl StoreDocumentTextParams {
    /// Builds and validates parameters for storing a document's text
    /// alongside its postings.
    ///
    /// # Errors
    /// Returns a configuration error when `text` is empty, a posting refers
    /// to a different document, a posting range falls outside the text, a
    /// vector component is non-finite, or vector dimensions disagree.
    pub fn new(document_id: DocumentId, text: String, postings: Vec<Posting>) -> Result<Self, ShardexError> {
        let params = Self {
            document_id,
            text,
            postings,
        };
        params.validate()?;
        Ok(params)
    }

    /// Validates the text/posting combination (see [`Self::new`] for the
    /// rules); errors report the first offending posting.
    pub fn validate(&self) -> Result<(), ShardexError> {
        if self.text.is_empty() {
            return Err(ShardexError::config_error(
                "text",
                "cannot be empty",
                "Provide non-empty text content for the document",
            ));
        }
        for (i, posting) in self.postings.iter().enumerate() {
            if posting.document_id != self.document_id {
                return Err(ShardexError::config_error(
                    format!("postings[{}].document_id", i),
                    format!(
                        "mismatch: expected {}, got {}",
                        self.document_id.raw(),
                        posting.document_id.raw()
                    ),
                    "Ensure all postings belong to the same document",
                ));
            }
            // checked_add: `start + length` could overflow u32 (panic in
            // debug; silent wraparound — a bogus "valid" range — in release).
            let end_pos = posting.start.checked_add(posting.length).ok_or_else(|| {
                ShardexError::config_error(
                    format!("postings[{}]", i),
                    format!(
                        "range {}+{} exceeds document length {}",
                        posting.start,
                        posting.length,
                        self.text.len()
                    ),
                    "Ensure all posting ranges are within document bounds",
                )
            })?;
            // Compare in usize so documents longer than u32::MAX bytes don't
            // have their length truncated by an `as u32` cast.
            if end_pos as usize > self.text.len() {
                return Err(ShardexError::config_error(
                    format!("postings[{}]", i),
                    format!(
                        "range {}+{} exceeds document length {}",
                        posting.start,
                        posting.length,
                        self.text.len()
                    ),
                    "Ensure all posting ranges are within document bounds",
                ));
            }
            for (j, &value) in posting.vector.iter().enumerate() {
                if !value.is_finite() {
                    return Err(ShardexError::config_error(
                        format!("postings[{}].vector[{}]", i, j),
                        "contains non-finite value",
                        "Ensure all vector components are finite numbers (not NaN or infinity)",
                    ));
                }
            }
        }
        if let Some(first) = self.postings.first() {
            let expected_dim = first.vector.len();
            for (i, posting) in self.postings.iter().enumerate() {
                if posting.vector.len() != expected_dim {
                    return Err(ShardexError::config_error(
                        format!("postings[{}].vector", i),
                        format!(
                            "dimension mismatch: expected {}, got {}",
                            expected_dim,
                            posting.vector.len()
                        ),
                        "Ensure all postings have vectors with the same dimensions",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Number of postings attached to the document.
    pub fn len(&self) -> usize {
        self.postings.len()
    }

    /// True when no postings are attached.
    pub fn is_empty(&self) -> bool {
        self.postings.is_empty()
    }

    /// Size of the document text in bytes.
    pub fn text_size(&self) -> usize {
        self.text.len()
    }
}
/// Parameters for retrieving a document's stored text.
#[derive(Debug, Clone)]
pub struct GetDocumentTextParams {
    /// Document whose text should be fetched.
    pub document_id: DocumentId,
}
impl GetDocumentTextParams {
    /// Targets the full text of the given document.
    pub fn new(document_id: DocumentId) -> Self {
        Self { document_id }
    }

    /// Always succeeds; any document id is an acceptable lookup key.
    pub fn validate(&self) -> Result<(), ShardexError> {
        Ok(())
    }
}
/// Parameters addressing a snippet (offset range) of a stored document.
#[derive(Debug, Clone)]
pub struct ExtractSnippetParams {
    /// Document to extract from.
    pub document_id: DocumentId,
    /// Start offset of the snippet.
    pub start: u32,
    /// Snippet length; must be greater than 0.
    pub length: u32,
}
impl ExtractSnippetParams {
    /// Builds and validates snippet-extraction coordinates.
    ///
    /// # Errors
    /// Returns a configuration error when `length` is 0 or `start + length`
    /// overflows `u32`.
    pub fn new(document_id: DocumentId, start: u32, length: u32) -> Result<Self, ShardexError> {
        let params = Self {
            document_id,
            start,
            length,
        };
        params.validate()?;
        Ok(params)
    }

    /// Copies the coordinates straight from a posting.
    ///
    /// NOTE: unlike `new`, this performs no validation.
    pub fn from_posting(posting: &Posting) -> Self {
        Self {
            document_id: posting.document_id,
            start: posting.start,
            length: posting.length,
        }
    }

    /// Requires a positive length and a `start + length` that fits in `u32`.
    pub fn validate(&self) -> Result<(), ShardexError> {
        if self.length == 0 {
            return Err(ShardexError::config_error(
                "length",
                "must be greater than 0",
                "Provide a positive length for the snippet to extract",
            ));
        }
        // With length >= 1, a successful checked_add can never be 0, so the
        // only remaining failure mode is u32 overflow. (The previous
        // `end_pos == 0` branch was unreachable and has been removed.)
        if self.start.checked_add(self.length).is_none() {
            return Err(ShardexError::config_error(
                "start + length",
                "results in overflow",
                "Ensure the start position and length don't exceed u32 limits",
            ));
        }
        Ok(())
    }

    /// Exclusive end position of the snippet.
    ///
    /// Uses saturating arithmetic so unvalidated coordinates (e.g. built via
    /// `from_posting`) cannot overflow; validated params are unaffected
    /// because `validate` already rejects overflowing ranges.
    pub fn end_position(&self) -> u32 {
        self.start.saturating_add(self.length)
    }
}
/// One document (id, text, postings) inside a batch text-storage request.
#[derive(Debug, Clone)]
pub struct DocumentTextEntry {
    /// Document this entry describes.
    pub document_id: DocumentId,
    /// Full document text.
    pub text: String,
    /// Postings belonging to this document.
    pub postings: Vec<Posting>,
}
impl DocumentTextEntry {
    /// Bundles a document's id, full text, and postings for batch storage.
    pub fn new(document_id: DocumentId, text: String, postings: Vec<Posting>) -> Self {
        Self {
            document_id,
            text,
            postings,
        }
    }

    /// Size of the text in bytes.
    pub fn text_size(&self) -> usize {
        self.text.len()
    }

    /// Number of postings attached to this document.
    pub fn posting_count(&self) -> usize {
        self.postings.len()
    }
}
/// Parameters for storing several documents' text in a single batch.
#[derive(Debug, Clone)]
pub struct BatchStoreDocumentTextParams {
    /// Entries to store; document IDs must be unique within the batch.
    pub documents: Vec<DocumentTextEntry>,
    /// Flush the index immediately after the batch is applied.
    pub flush_immediately: bool,
    /// Collect performance metrics for this batch.
    pub track_performance: bool,
}
impl BatchStoreDocumentTextParams {
    /// Builds and validates a batch of document-text entries.
    ///
    /// # Errors
    /// Returns a configuration error when the batch is empty or any entry
    /// fails the per-document checks in [`Self::validate`].
    pub fn new(
        documents: Vec<DocumentTextEntry>,
        flush_immediately: bool,
        track_performance: bool,
    ) -> Result<Self, ShardexError> {
        if documents.is_empty() {
            return Err(ShardexError::config_error(
                "documents",
                "cannot be empty",
                "Provide at least one document entry for batch storage",
            ));
        }
        let params = Self {
            documents,
            flush_immediately,
            track_performance,
        };
        params.validate()?;
        Ok(params)
    }

    /// No immediate flush, no performance tracking.
    pub fn simple(documents: Vec<DocumentTextEntry>) -> Result<Self, ShardexError> {
        Self::new(documents, false, false)
    }

    /// Flush immediately, no performance tracking.
    pub fn with_immediate_flush(documents: Vec<DocumentTextEntry>) -> Result<Self, ShardexError> {
        Self::new(documents, true, false)
    }

    /// Performance tracking only, no immediate flush.
    pub fn with_performance_tracking(documents: Vec<DocumentTextEntry>) -> Result<Self, ShardexError> {
        Self::new(documents, false, true)
    }

    /// Both immediate flush and performance tracking.
    pub fn with_flush_and_tracking(documents: Vec<DocumentTextEntry>) -> Result<Self, ShardexError> {
        Self::new(documents, true, true)
    }

    /// Validates every entry (non-empty text, postings that belong to the
    /// entry and stay within its text, finite vectors with uniform
    /// dimensions) and finally that document IDs are unique in the batch.
    pub fn validate(&self) -> Result<(), ShardexError> {
        if self.documents.is_empty() {
            return Err(ShardexError::config_error(
                "documents",
                "cannot be empty",
                "Provide at least one document entry for batch storage",
            ));
        }
        for (i, entry) in self.documents.iter().enumerate() {
            if entry.text.is_empty() {
                return Err(ShardexError::config_error(
                    format!("documents[{}].text", i),
                    "cannot be empty",
                    "Provide non-empty text content for all document entries",
                ));
            }
            for (j, posting) in entry.postings.iter().enumerate() {
                if posting.document_id != entry.document_id {
                    return Err(ShardexError::config_error(
                        format!("documents[{}].postings[{}].document_id", i, j),
                        format!(
                            "mismatch: expected {}, got {}",
                            entry.document_id.raw(),
                            posting.document_id.raw()
                        ),
                        "Ensure all postings belong to their respective documents",
                    ));
                }
                // checked_add: `start + length` could overflow u32 (panic in
                // debug; silent wraparound — a bogus "valid" range — in release).
                let end_pos = posting.start.checked_add(posting.length).ok_or_else(|| {
                    ShardexError::config_error(
                        format!("documents[{}].postings[{}]", i, j),
                        format!(
                            "range {}+{} exceeds document length {}",
                            posting.start,
                            posting.length,
                            entry.text.len()
                        ),
                        "Ensure all posting ranges are within document bounds",
                    )
                })?;
                // Compare in usize so texts longer than u32::MAX bytes don't
                // have their length truncated by an `as u32` cast.
                if end_pos as usize > entry.text.len() {
                    return Err(ShardexError::config_error(
                        format!("documents[{}].postings[{}]", i, j),
                        format!(
                            "range {}+{} exceeds document length {}",
                            posting.start,
                            posting.length,
                            entry.text.len()
                        ),
                        "Ensure all posting ranges are within document bounds",
                    ));
                }
                for (k, &value) in posting.vector.iter().enumerate() {
                    if !value.is_finite() {
                        return Err(ShardexError::config_error(
                            format!("documents[{}].postings[{}].vector[{}]", i, j, k),
                            "contains non-finite value",
                            "Ensure all vector components are finite numbers (not NaN or infinity)",
                        ));
                    }
                }
            }
            if let Some(first) = entry.postings.first() {
                let expected_dim = first.vector.len();
                for (j, posting) in entry.postings.iter().enumerate() {
                    if posting.vector.len() != expected_dim {
                        return Err(ShardexError::config_error(
                            format!("documents[{}].postings[{}].vector", i, j),
                            format!(
                                "dimension mismatch: expected {}, got {}",
                                expected_dim,
                                posting.vector.len()
                            ),
                            "Ensure all postings within a document have vectors with the same dimensions",
                        ));
                    }
                }
            }
        }
        // Uniqueness runs last so content errors are reported first.
        let mut seen_ids = std::collections::HashSet::new();
        for (i, entry) in self.documents.iter().enumerate() {
            if !seen_ids.insert(entry.document_id) {
                return Err(ShardexError::config_error(
                    format!("documents[{}].document_id", i),
                    format!("duplicate document ID: {}", entry.document_id.raw()),
                    "Ensure all documents in the batch have unique document IDs",
                ));
            }
        }
        Ok(())
    }

    /// Number of documents in the batch.
    pub fn len(&self) -> usize {
        self.documents.len()
    }

    /// True when the batch is empty (unreachable via the constructors).
    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    /// Total postings across all documents.
    pub fn total_postings(&self) -> usize {
        self.documents.iter().map(|doc| doc.postings.len()).sum()
    }

    /// Total text size in bytes across all documents.
    pub fn total_text_size(&self) -> usize {
        self.documents.iter().map(|doc| doc.text.len()).sum()
    }

    /// Mean document size in bytes (0 for an empty batch).
    pub fn average_document_size(&self) -> usize {
        if self.documents.is_empty() {
            0
        } else {
            self.total_text_size() / self.documents.len()
        }
    }
}
#[cfg(test)]
mod document_text_tests {
    use super::*;
    use crate::identifiers::DocumentId;

    // --- StoreDocumentTextParams ---

    #[test]
    fn test_store_document_text_params() {
        let doc_id = DocumentId::from_raw(1);
        let text = "Hello world".to_string();
        let posting = Posting {
            document_id: doc_id,
            start: 0,
            length: 5,
            vector: vec![0.1, 0.2, 0.3],
        };
        let params = StoreDocumentTextParams::new(doc_id, text, vec![posting]).unwrap();
        assert_eq!(params.document_id, doc_id);
        assert_eq!(params.text, "Hello world");
        assert_eq!(params.len(), 1);
        assert_eq!(params.text_size(), 11);
        assert!(!params.is_empty());
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_store_document_text_params_empty_text() {
        let doc_id = DocumentId::from_raw(1);
        let result = StoreDocumentTextParams::new(doc_id, String::new(), vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_store_document_text_params_document_id_mismatch() {
        let doc_id1 = DocumentId::from_raw(1);
        let doc_id2 = DocumentId::from_raw(2);
        let text = "Hello world".to_string();
        let posting = Posting {
            // Deliberately references a different document than doc_id1.
            document_id: doc_id2,
            start: 0,
            length: 5,
            vector: vec![0.1, 0.2, 0.3],
        };
        let result = StoreDocumentTextParams::new(doc_id1, text, vec![posting]);
        assert!(result.is_err());
    }

    #[test]
    fn test_store_document_text_params_range_exceeds_document() {
        let doc_id = DocumentId::from_raw(1);
        // Text is 2 bytes but the posting claims a 10-byte range.
        let text = "Hi".to_string();
        let posting = Posting {
            document_id: doc_id,
            start: 0,
            length: 10,
            vector: vec![0.1, 0.2, 0.3],
        };
        let result = StoreDocumentTextParams::new(doc_id, text, vec![posting]);
        assert!(result.is_err());
    }

    // --- GetDocumentTextParams / ExtractSnippetParams ---

    #[test]
    fn test_get_document_text_params() {
        let doc_id = DocumentId::from_raw(1);
        let params = GetDocumentTextParams::new(doc_id);
        assert_eq!(params.document_id, doc_id);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_extract_snippet_params() {
        let doc_id = DocumentId::from_raw(1);
        let params = ExtractSnippetParams::new(doc_id, 10, 5).unwrap();
        assert_eq!(params.document_id, doc_id);
        assert_eq!(params.start, 10);
        assert_eq!(params.length, 5);
        assert_eq!(params.end_position(), 15);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_extract_snippet_params_from_posting() {
        let doc_id = DocumentId::from_raw(1);
        let posting = Posting {
            document_id: doc_id,
            start: 5,
            length: 10,
            vector: vec![0.1, 0.2, 0.3],
        };
        let params = ExtractSnippetParams::from_posting(&posting);
        assert_eq!(params.document_id, doc_id);
        assert_eq!(params.start, 5);
        assert_eq!(params.length, 10);
    }

    #[test]
    fn test_extract_snippet_params_zero_length() {
        let doc_id = DocumentId::from_raw(1);
        let result = ExtractSnippetParams::new(doc_id, 10, 0);
        assert!(result.is_err());
    }

    // --- DocumentTextEntry / BatchStoreDocumentTextParams ---

    #[test]
    fn test_document_text_entry() {
        let doc_id = DocumentId::from_raw(1);
        let text = "Hello world".to_string();
        let posting = Posting {
            document_id: doc_id,
            start: 0,
            length: 5,
            vector: vec![0.1, 0.2, 0.3],
        };
        let entry = DocumentTextEntry::new(doc_id, text.clone(), vec![posting]);
        assert_eq!(entry.document_id, doc_id);
        assert_eq!(entry.text, text);
        assert_eq!(entry.text_size(), 11);
        assert_eq!(entry.posting_count(), 1);
    }

    #[test]
    fn test_batch_store_document_text_params() {
        let doc_id = DocumentId::from_raw(1);
        let entry = DocumentTextEntry::new(
            doc_id,
            "Hello world".to_string(),
            vec![Posting {
                document_id: doc_id,
                start: 0,
                length: 5,
                vector: vec![0.1, 0.2, 0.3],
            }],
        );
        let params = BatchStoreDocumentTextParams::simple(vec![entry]).unwrap();
        assert_eq!(params.len(), 1);
        assert_eq!(params.total_postings(), 1);
        assert_eq!(params.total_text_size(), 11);
        assert_eq!(params.average_document_size(), 11);
        assert!(!params.flush_immediately);
        assert!(!params.track_performance);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_batch_store_document_text_params_with_options() {
        let doc_id = DocumentId::from_raw(1);
        let entry = DocumentTextEntry::new(
            doc_id,
            "Hello world".to_string(),
            vec![Posting {
                document_id: doc_id,
                start: 0,
                length: 5,
                vector: vec![0.1, 0.2, 0.3],
            }],
        );
        let params = BatchStoreDocumentTextParams::with_flush_and_tracking(vec![entry]).unwrap();
        assert!(params.flush_immediately);
        assert!(params.track_performance);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_batch_store_document_text_params_empty() {
        let result = BatchStoreDocumentTextParams::simple(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_store_document_text_params_duplicate_ids() {
        let doc_id = DocumentId::from_raw(1);
        let entry1 = DocumentTextEntry::new(doc_id, "First".to_string(), vec![]);
        let entry2 = DocumentTextEntry::new(doc_id, "Second".to_string(), vec![]);
        let result = BatchStoreDocumentTextParams::simple(vec![entry1, entry2]);
        assert!(result.is_err());
    }
}