use alloc::{
boxed::Box,
collections::BTreeMap,
string::{String, ToString},
sync::Arc,
vec,
vec::Vec,
};
use half::f16;
use crate::codec::{Codebook, Codec, CodecConfig, CompressedVector};
use crate::corpus::compression_policy::CompressionPolicy;
use crate::corpus::entry_meta_value::EntryMetaValue;
use crate::corpus::events::CorpusEvent;
use crate::corpus::vector_entry::VectorEntry;
use crate::errors::CorpusError;
use crate::types::{CorpusId, Timestamp, VectorId};
use crate::corpus::vector_id_map::VectorIdMap;
/// Outcome summary of a batch insertion ([`Corpus::insert_batch`]).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BatchReport {
    // Number of vectors committed to the corpus.
    pub inserted: usize,
    // Number of vectors skipped; the current batch path is atomic and never
    // skips (always 0 here) — kept for forward compatibility.
    pub skipped: usize,
    // First error with its batch index; the current atomic path returns the
    // error instead of reporting it, so this is always `None` here.
    pub first_error: Option<(usize, CorpusError)>,
}
/// A collection of compressed vectors sharing one codec configuration,
/// codebook, and compression policy, with an event queue recording mutations.
pub struct Corpus {
    // Identifier for this corpus; cheaply cloned (`Arc`) into emitted events.
    corpus_id: CorpusId,
    // Codec configuration; also defines the expected vector dimension.
    config: CodecConfig,
    // Codebook used when the policy is `Compress`.
    codebook: Codebook,
    // How vectors are stored: codec-compressed, raw f32, or f16.
    compression_policy: CompressionPolicy,
    // Stored entries keyed by vector id.
    vectors: VectorIdMap<VectorEntry>,
    // Events accumulated since the last `drain_events` call.
    pending_events: Vec<CorpusEvent>,
    // Corpus-level (not per-entry) metadata.
    metadata: BTreeMap<String, EntryMetaValue>,
}
impl Corpus {
#[must_use]
pub fn new(
corpus_id: CorpusId,
config: CodecConfig,
codebook: Codebook,
compression_policy: CompressionPolicy,
metadata: BTreeMap<String, EntryMetaValue>,
) -> Self {
Self::new_at(corpus_id, config, codebook, compression_policy, metadata, 0)
}
#[must_use]
pub fn new_at(
corpus_id: CorpusId,
config: CodecConfig,
codebook: Codebook,
compression_policy: CompressionPolicy,
metadata: BTreeMap<String, EntryMetaValue>,
timestamp: Timestamp,
) -> Self {
let event = CorpusEvent::Created {
corpus_id: Arc::clone(&corpus_id),
codec_config: config.clone(),
compression_policy,
timestamp,
};
Self {
corpus_id,
config,
codebook,
compression_policy,
vectors: VectorIdMap::new(),
pending_events: vec![event],
metadata,
}
}
#[must_use]
pub const fn corpus_id(&self) -> &CorpusId {
&self.corpus_id
}
#[must_use]
pub const fn config(&self) -> &CodecConfig {
&self.config
}
#[must_use]
pub const fn compression_policy(&self) -> CompressionPolicy {
self.compression_policy
}
#[must_use]
pub fn vector_count(&self) -> usize {
self.vectors.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
#[must_use]
pub fn contains(&self, id: &VectorId) -> bool {
self.vectors.contains_key(id.as_ref())
}
pub fn iter(&self) -> impl Iterator<Item = (&VectorId, &VectorEntry)> {
self.vectors.iter()
}
#[must_use]
pub const fn metadata(&self) -> &BTreeMap<String, EntryMetaValue> {
&self.metadata
}
pub fn drain_events(&mut self) -> Vec<CorpusEvent> {
core::mem::take(&mut self.pending_events)
}
pub fn insert(
&mut self,
id: VectorId,
vector: &[f32],
entry_metadata: Option<EntryMetaValue>,
timestamp: Timestamp,
) -> Result<(), CorpusError> {
self.validate_dimension(vector)?;
if self.vectors.contains_key(id.as_ref()) {
return Err(CorpusError::DuplicateVectorId { id });
}
let compressed = self.compress_vector(vector)?;
let meta = Self::unwrap_meta(entry_metadata);
#[allow(clippy::cast_possible_truncation)]
let declared_dim = vector.len() as u32;
let entry = VectorEntry::new(Arc::clone(&id), compressed, declared_dim, timestamp, meta);
self.vectors.insert(Arc::clone(&id), entry);
self.pending_events.push(CorpusEvent::VectorsInserted {
corpus_id: Arc::clone(&self.corpus_id),
vector_ids: Arc::from([id]),
count: 1,
timestamp,
});
Ok(())
}
pub fn insert_batch(
&mut self,
vectors: &[(VectorId, &[f32], Option<EntryMetaValue>)],
timestamp: Timestamp,
) -> Result<BatchReport, CorpusError> {
let mut staged: Vec<(VectorId, VectorEntry)> = Vec::with_capacity(vectors.len());
for (index, (id, vector, meta)) in vectors.iter().enumerate() {
if let Err(e) = self.validate_dimension(vector) {
return Err(CorpusError::BatchAtomicityFailure {
index,
source: Box::new(e),
});
}
if self.vectors.contains_key(id.as_ref()) {
return Err(CorpusError::BatchAtomicityFailure {
index,
source: Box::new(CorpusError::DuplicateVectorId { id: Arc::clone(id) }),
});
}
let already_staged = staged
.iter()
.any(|(staged_id, _)| staged_id.as_ref() == id.as_ref());
if already_staged {
return Err(CorpusError::BatchAtomicityFailure {
index,
source: Box::new(CorpusError::DuplicateVectorId { id: Arc::clone(id) }),
});
}
let compressed = match self.compress_vector(vector) {
Ok(cv) => cv,
Err(e) => {
return Err(CorpusError::BatchAtomicityFailure {
index,
source: Box::new(CorpusError::Codec(e)),
});
}
};
let entry_meta = Self::unwrap_meta(meta.clone());
#[allow(clippy::cast_possible_truncation)]
let declared_dim = vector.len() as u32;
staged.push((
Arc::clone(id),
VectorEntry::new(
Arc::clone(id),
compressed,
declared_dim,
timestamp,
entry_meta,
),
));
}
let mut ids: Vec<VectorId> = Vec::with_capacity(staged.len());
let inserted_count = staged.len();
for (id, entry) in staged {
ids.push(Arc::clone(&id));
self.vectors.insert(id, entry);
}
if !ids.is_empty() {
#[allow(clippy::cast_possible_truncation)]
let count = ids.len() as u32;
self.pending_events.push(CorpusEvent::VectorsInserted {
corpus_id: Arc::clone(&self.corpus_id),
vector_ids: Arc::from(ids.into_boxed_slice()),
count,
timestamp,
});
}
Ok(BatchReport {
inserted: inserted_count,
skipped: 0,
first_error: None,
})
}
pub fn decompress(&self, id: &VectorId) -> Result<Vec<f32>, CorpusError> {
let entry = self
.vectors
.get(id.as_ref())
.ok_or_else(|| CorpusError::UnknownVectorId { id: Arc::clone(id) })?;
self.decompress_entry(entry)
}
pub fn decompress_all_at(
&mut self,
timestamp: Timestamp,
) -> Result<BTreeMap<VectorId, Vec<f32>>, CorpusError> {
if self.vectors.is_empty() {
return Ok(BTreeMap::new());
}
let mut out = BTreeMap::new();
for (id, entry) in self.vectors.iter() {
let decompressed = self.decompress_entry(entry)?;
out.insert(Arc::clone(id), decompressed);
}
#[allow(clippy::cast_possible_truncation)]
let vector_count = out.len() as u32;
self.pending_events.push(CorpusEvent::Decompressed {
corpus_id: Arc::clone(&self.corpus_id),
vector_count,
timestamp,
});
Ok(out)
}
pub fn remove(&mut self, id: &VectorId) -> Option<VectorEntry> {
self.vectors.remove(id.as_ref())
}
#[allow(clippy::missing_const_for_fn)] fn validate_dimension(&self, vector: &[f32]) -> Result<(), CorpusError> {
let expected = self.config.dimension();
#[allow(clippy::cast_possible_truncation)]
let got = vector.len() as u32;
if got != expected {
return Err(CorpusError::DimensionMismatch { expected, got });
}
Ok(())
}
fn compress_vector(
&self,
vector: &[f32],
) -> Result<CompressedVector, crate::errors::CodecError> {
match self.compression_policy {
CompressionPolicy::Compress => {
Codec::new().compress(vector, &self.config, &self.codebook)
}
CompressionPolicy::Passthrough => {
let mut indices: Vec<u8> = Vec::with_capacity(vector.len() * 4);
for &v in vector {
indices.extend_from_slice(&v.to_le_bytes());
}
let dim = vector.len();
#[allow(clippy::cast_possible_truncation)]
let dim_u32 = dim as u32;
let byte_len = indices.len();
#[allow(clippy::cast_possible_truncation)]
let byte_len_u32 = byte_len as u32;
let _ = dim_u32; CompressedVector::new(
indices.into_boxed_slice(),
None,
self.config.config_hash().clone(),
byte_len_u32,
8,
)
}
CompressionPolicy::Fp16 => {
let mut indices: Vec<u8> = Vec::with_capacity(vector.len() * 2);
for &v in vector {
let h = f16::from_f32(v);
indices.extend_from_slice(&h.to_le_bytes());
}
let byte_len = indices.len();
#[allow(clippy::cast_possible_truncation)]
let byte_len_u32 = byte_len as u32;
CompressedVector::new(
indices.into_boxed_slice(),
None,
self.config.config_hash().clone(),
byte_len_u32,
8,
)
}
}
}
fn decompress_entry(&self, entry: &VectorEntry) -> Result<Vec<f32>, CorpusError> {
match self.compression_policy {
CompressionPolicy::Compress => Codec::new()
.decompress(entry.compressed(), &self.config, &self.codebook)
.map_err(CorpusError::Codec),
CompressionPolicy::Passthrough => {
let bytes = entry.compressed().indices();
if bytes.len() % 4 != 0 {
return Err(CorpusError::Codec(
crate::errors::CodecError::LengthMismatch {
left: bytes.len(),
right: (bytes.len() / 4) * 4,
},
));
}
let floats: Vec<f32> = bytes
.chunks_exact(4)
.map(|chunk| {
#[allow(clippy::indexing_slicing)]
f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
})
.collect();
Ok(floats)
}
CompressionPolicy::Fp16 => {
let bytes = entry.compressed().indices();
if bytes.len() % 2 != 0 {
return Err(CorpusError::Codec(
crate::errors::CodecError::LengthMismatch {
left: bytes.len(),
right: (bytes.len() / 2) * 2,
},
));
}
let floats: Vec<f32> = bytes
.chunks_exact(2)
.map(|chunk| {
#[allow(clippy::indexing_slicing)]
let h = f16::from_le_bytes([chunk[0], chunk[1]]);
f32::from(h)
})
.collect();
Ok(floats)
}
}
}
fn unwrap_meta(meta: Option<EntryMetaValue>) -> BTreeMap<String, EntryMetaValue> {
match meta {
None => BTreeMap::new(),
Some(EntryMetaValue::Object(arc_map)) => {
arc_map
.iter()
.map(|(k, v)| (k.as_ref().to_string(), v.clone()))
.collect()
}
Some(other) => {
let mut m = BTreeMap::new();
m.insert("_value".to_string(), other);
m
}
}
}
}