use std::collections::{BTreeMap, BTreeSet};
use std::sync::Arc;
use parking_lot::{Mutex, RwLock};
use thiserror::Error;
use crate::vector::schema::{IndexAlgorithm, MetadataFieldType, VectorSchema};
use dyntext::TextIndex;
use dynvec::Engine;
/// Errors returned by [`VectorRegistry`] operations.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RegistryError {
/// A table with this name is already registered; carries the name.
#[error("index already exists: {0}")]
AlreadyExists(String),
/// No table with this name is registered; carries the name.
#[error("index not found: {0}")]
NotFound(String),
/// The schema asked for an index algorithm the registry refuses to create.
#[error("unsupported index algorithm: {0:?}")]
UnsupportedAlgorithm(IndexAlgorithm),
/// A storage error from the vector engine; `#[from]` lets `?` convert it.
#[error("engine: {0}")]
Engine(#[from] dynvec::storage::StoreError),
}
/// Text index for a single field, plus the two-way mapping between the
/// index's document ids and the table keys those documents belong to.
#[derive(Debug, Default)]
pub struct TextFieldIndex {
/// The text index holding the documents.
pub index: TextIndex,
/// Document id -> table key.
pub doc_to_key: BTreeMap<u32, Vec<u8>>,
/// Table key -> document id (the inverse of `doc_to_key`).
pub key_to_doc: BTreeMap<Vec<u8>, u32>,
}
/// One text-search hit: `(key, text)` — the table key and the text of the matching document.
pub type TextHit = (Vec<u8>, Vec<u8>);
/// Result of a regex text search: `None` when the field has no text index,
/// `Some(Err(_))` when the regex search itself failed, `Some(Ok(hits))` otherwise.
pub type TextRegexResult = Option<Result<Vec<TextHit>, dyntext::regex_ast::RegexError>>;
/// Result of an approximate regex text search: `None` when the field has no text index,
/// `Some(Err(_))` when the approximate search failed, `Some(Ok(hits))` otherwise.
pub type TextRegexApproxResult = Option<Result<Vec<TextHit>, dyntext::TreError>>;
/// A named vector table: its schema, its vector engine, and side state
/// (recorded keys and per-field text indexes) guarded by mutexes.
#[derive(Debug)]
pub struct VectorTable {
/// Name the table was registered under.
pub name: String,
/// Schema the table was created with.
pub schema: VectorSchema,
/// Vector engine backing this table.
pub engine: Engine,
// Keys recorded through `record_indexed_key`.
indexed_keys: Mutex<BTreeSet<Vec<u8>>>,
// Text indexes keyed by field name.
text_indexes: Mutex<BTreeMap<String, TextFieldIndex>>,
}
impl VectorTable {
pub fn record_indexed_key(&self, key: Vec<u8>) {
self.indexed_keys.lock().insert(key);
}
#[must_use]
pub fn indexed_keys(&self) -> Vec<Vec<u8>> {
self.indexed_keys.lock().iter().cloned().collect()
}
#[must_use]
pub fn has_text_field(&self, field: &str) -> bool {
let in_schema = self
.schema
.metadata_fields
.iter()
.any(|f| f.field_type == MetadataFieldType::Text && f.name == field);
if in_schema {
return true;
}
self.text_indexes.lock().contains_key(field)
}
pub fn add_text_field(&self, field: &str) -> bool {
let mut guard = self.text_indexes.lock();
if guard.contains_key(field) {
return false;
}
guard.insert(field.to_string(), TextFieldIndex::default());
true
}
#[must_use]
pub fn text_field_names(&self) -> Vec<String> {
let mut names: BTreeSet<String> = BTreeSet::new();
for f in &self.schema.metadata_fields {
if f.field_type == MetadataFieldType::Text {
names.insert(f.name.clone());
}
}
for k in self.text_indexes.lock().keys() {
names.insert(k.clone());
}
names.into_iter().collect()
}
#[must_use]
pub fn has_text_index(&self, field: &str) -> bool {
self.text_indexes.lock().contains_key(field)
}
#[must_use]
pub fn text_index_doc_count(&self, field: &str) -> Option<usize> {
self.text_indexes
.lock()
.get(field)
.map(|state| state.index.doc_count())
}
pub fn upsert_text_field(&self, field: &str, key: &[u8], text: &[u8]) {
let mut guard = self.text_indexes.lock();
let Some(state) = guard.get_mut(field) else {
return;
};
if let Some(prev_id) = state.key_to_doc.remove(key) {
state.doc_to_key.remove(&prev_id);
state.index.remove(prev_id);
}
let doc_id = state.index.insert(text.to_vec());
state.doc_to_key.insert(doc_id, key.to_vec());
state.key_to_doc.insert(key.to_vec(), doc_id);
}
#[must_use]
pub fn search_text_substring(&self, field: &str, query: &[u8]) -> Option<Vec<TextHit>> {
let guard = self.text_indexes.lock();
let state = guard.get(field)?;
let mut hits: Vec<TextHit> = Vec::new();
for doc_id in state.index.search_substring(query) {
let Some(key) = state.doc_to_key.get(&doc_id) else {
continue;
};
let Some(doc) = state.index.docs().get(&doc_id) else {
continue;
};
hits.push((key.clone(), doc.text.clone()));
}
Some(hits)
}
pub fn search_text_regex(&self, field: &str, pattern: &str) -> TextRegexResult {
let guard = self.text_indexes.lock();
let state = guard.get(field)?;
let result = state.index.search_regex(pattern).map(|ids| {
let mut out: Vec<TextHit> = Vec::new();
for doc_id in ids {
let Some(key) = state.doc_to_key.get(&doc_id) else {
continue;
};
let Some(doc) = state.index.docs().get(&doc_id) else {
continue;
};
out.push((key.clone(), doc.text.clone()));
}
out
});
Some(result)
}
pub fn search_text_regex_approx(
&self,
field: &str,
pattern: &str,
max_errors: u16,
) -> TextRegexApproxResult {
let guard = self.text_indexes.lock();
let state = guard.get(field)?;
let result = state
.index
.search_regex_approx(pattern, max_errors)
.map(|ids| {
let mut out: Vec<TextHit> = Vec::new();
for doc_id in ids {
let Some(key) = state.doc_to_key.get(&doc_id) else {
continue;
};
let Some(doc) = state.index.docs().get(&doc_id) else {
continue;
};
out.push((key.clone(), doc.text.clone()));
}
out
});
Some(result)
}
}
/// Summary of one table, as assembled by [`VectorRegistry::info`].
#[derive(Clone, Debug, PartialEq)]
pub struct VectorTableInfo {
/// Table name.
pub name: String,
/// Vector dimension from the table's schema.
pub dim: u16,
/// Distance metric from the table's schema.
pub distance: crate::vector::schema::DistanceMetric,
/// Index algorithm from the table's schema.
pub algorithm: IndexAlgorithm,
/// `live_rows` as reported by the engine's stats.
pub live_rows: usize,
/// `tracked_rows` as reported by the engine's stats.
pub tracked_rows: usize,
}
/// Registry of vector tables keyed by name.
///
/// The map sits behind an `Arc`, so clones of a registry share the same tables.
#[derive(Clone, Default)]
pub struct VectorRegistry {
// Name -> table, behind a read/write lock.
inner: Arc<RwLock<BTreeMap<String, Arc<VectorTable>>>>,
}
/// Debug output lists only the registered table names, under the `indexes` field.
impl std::fmt::Debug for VectorRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Copy the names out in a narrow scope so the read lock is released
        // before anything is written to the formatter.
        let names: Vec<String> = {
            let tables = self.inner.read();
            tables.keys().cloned().collect()
        };
        let mut out = f.debug_struct("VectorRegistry");
        out.field("indexes", &names);
        out.finish()
    }
}
impl VectorRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a new in-memory table called `name` built from `schema`.
    ///
    /// A text index is created up front for every `Text` metadata field in the schema.
    ///
    /// # Errors
    ///
    /// * [`RegistryError::UnsupportedAlgorithm`] when the schema's algorithm is not `Hnsw`.
    /// * [`RegistryError::AlreadyExists`] when `name` is already registered.
    /// * [`RegistryError::Engine`] when the engine cannot be constructed.
    pub fn create(&self, name: String, schema: VectorSchema) -> Result<(), RegistryError> {
        // Reject unsupported algorithms before taking the write lock.
        match schema.algorithm {
            IndexAlgorithm::Hnsw => {}
            other => return Err(RegistryError::UnsupportedAlgorithm(other)),
        }

        let mut tables = self.inner.write();
        if tables.contains_key(&name) {
            return Err(RegistryError::AlreadyExists(name));
        }

        let engine = Engine::in_memory(schema.to_engine_schema(&name))?;

        let text_indexes: BTreeMap<String, TextFieldIndex> = schema
            .metadata_fields
            .iter()
            .filter(|f| f.field_type == MetadataFieldType::Text)
            .map(|f| (f.name.clone(), TextFieldIndex::default()))
            .collect();

        let table = Arc::new(VectorTable {
            name: name.clone(),
            schema,
            engine,
            indexed_keys: Mutex::new(BTreeSet::new()),
            text_indexes: Mutex::new(text_indexes),
        });
        tables.insert(name, table);
        Ok(())
    }

    /// Removes the table called `name` and hands it back to the caller.
    ///
    /// # Errors
    ///
    /// [`RegistryError::NotFound`] when no such table is registered.
    pub fn drop(&self, name: &str) -> Result<Arc<VectorTable>, RegistryError> {
        let mut tables = self.inner.write();
        match tables.remove(name) {
            Some(table) => Ok(table),
            None => Err(RegistryError::NotFound(name.to_string())),
        }
    }

    /// Removes the table called `name` and returns the keys it had recorded.
    ///
    /// NOTE(review): presumably the caller uses these keys for follow-up
    /// cleanup (the "dd" in the name) — confirm against the call sites.
    ///
    /// # Errors
    ///
    /// [`RegistryError::NotFound`] when no such table is registered.
    pub fn drop_with_dd(&self, name: &str) -> Result<Vec<Vec<u8>>, RegistryError> {
        self.drop(name).map(|table| table.indexed_keys())
    }

    /// Returns a shared handle to the table called `name`, if registered.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<Arc<VectorTable>> {
        let tables = self.inner.read();
        tables.get(name).map(Arc::clone)
    }

    /// Returns the names of all registered tables in ascending order.
    #[must_use]
    pub fn list(&self) -> Vec<String> {
        let tables = self.inner.read();
        let mut names = Vec::with_capacity(tables.len());
        names.extend(tables.keys().cloned());
        names
    }

    /// Builds a summary of the table called `name`.
    ///
    /// Returns `None` when the table is not registered or when the engine
    /// fails to report its stats.
    #[must_use]
    pub fn info(&self, name: &str) -> Option<VectorTableInfo> {
        let table = self.get(name)?;
        let Ok(stats) = table.engine.stats() else {
            return None;
        };
        let info = VectorTableInfo {
            name: table.name.clone(),
            dim: table.schema.dim,
            distance: table.schema.distance,
            algorithm: table.schema.algorithm,
            live_rows: stats.live_rows,
            tracked_rows: stats.tracked_rows,
        };
        Some(info)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::schema::{DistanceMetric, IndexAlgorithm, VectorType};

    /// Builds a minimal 4-dimensional cosine schema using `algorithm`,
    /// with no prefixes and no metadata fields.
    fn schema(algorithm: IndexAlgorithm) -> VectorSchema {
        VectorSchema {
            vector_field: String::from("vec"),
            vector_type: VectorType::Float32,
            dim: 4,
            distance: DistanceMetric::Cosine,
            algorithm,
            prefixes: vec![],
            metadata_fields: vec![],
        }
    }

    /// A created table can be fetched back with its name and schema intact.
    #[test]
    fn create_and_get_returns_table() {
        let registry = VectorRegistry::new();
        let created = registry.create(String::from("idx"), schema(IndexAlgorithm::Hnsw));
        created.unwrap();

        let fetched = registry.get("idx").expect("table present");
        assert_eq!(fetched.name, "idx");
        assert_eq!(fetched.schema.dim, 4);
    }

    /// Creating the same name twice fails with `AlreadyExists`.
    #[test]
    fn duplicate_name_errors() {
        let registry = VectorRegistry::new();
        let first = registry.create(String::from("idx"), schema(IndexAlgorithm::Hnsw));
        first.unwrap();

        let second = registry.create(String::from("idx"), schema(IndexAlgorithm::Hnsw));
        assert!(matches!(second, Err(RegistryError::AlreadyExists(_))));
    }

    /// A non-HNSW algorithm is rejected with `UnsupportedAlgorithm`.
    #[test]
    fn unsupported_algorithm_errors() {
        let registry = VectorRegistry::new();
        let outcome = registry.create(String::from("idx"), schema(IndexAlgorithm::Flat));
        assert!(matches!(
            outcome,
            Err(RegistryError::UnsupportedAlgorithm(_))
        ));
    }
}