use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
#[cfg(feature = "vec")]
use memvid_core::{LocalTextEmbedder, TextEmbedConfig};
use memvid_core::{Memvid, PutOptions, SearchHit, SearchRequest};
use rig::{
Embed, OneOrMany,
embeddings::Embedding,
vector_store::{
InsertDocuments, VectorSearchRequest, VectorStoreError, VectorStoreIndex,
request::SearchFilter,
},
wasm_compat::WasmCompatSend,
};
use serde::{Deserialize, Serialize};
use crate::error::MemvidError;
/// Cloneable handle to a memvid file that implements rig's vector-store traits.
///
/// Clones share the same underlying [`Memvid`] through an `Arc<Mutex<_>>`.
#[derive(Clone)]
pub struct MemvidStore {
    /// Shared, mutex-guarded memvid handle.
    inner: Arc<Mutex<Memvid>>,
    /// Optional local embedder used to embed text on write and queries on search.
    #[cfg(feature = "vec")]
    embedder: Option<Arc<LocalTextEmbedder>>,
}
impl std::fmt::Debug for MemvidStore {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Neither the memvid handle nor the embedder is printed; the output is
        // always `MemvidStore { .. }`.
        let mut repr = formatter.debug_struct("MemvidStore");
        repr.finish_non_exhaustive()
    }
}
impl MemvidStore {
    /// Wraps an already-opened [`Memvid`] handle.
    ///
    /// The returned store has no local embedder attached; use
    /// [`MemvidStore::builder`] to configure one.
    pub fn from_memvid(memvid: Memvid) -> Self {
        Self {
            inner: Arc::new(Mutex::new(memvid)),
            #[cfg(feature = "vec")]
            embedder: None,
        }
    }

    /// Starts a [`MemvidStoreBuilder`] with every option at its default.
    pub fn builder() -> MemvidStoreBuilder {
        MemvidStoreBuilder::default()
    }

    /// Acquires the shared handle.
    ///
    /// # Errors
    /// [`MemvidError::Poisoned`] if the mutex was poisoned by a panicking holder.
    fn lock(&self) -> Result<std::sync::MutexGuard<'_, Memvid>, MemvidError> {
        self.inner.lock().map_err(|_| MemvidError::Poisoned)
    }

    /// Whether a local embedder is attached to this store.
    #[cfg(feature = "vec")]
    #[must_use]
    pub fn has_embedder(&self) -> bool {
        self.embedder.is_some()
    }

    /// Embeds `text` with the attached embedder, or returns `Ok(None)` when none is attached.
    #[cfg(feature = "vec")]
    fn encode(&self, text: &str) -> Result<Option<Vec<f32>>, MemvidError> {
        match &self.embedder {
            Some(embedder) => Ok(Some(embedder.encode_text(text)?)),
            None => Ok(None),
        }
    }

    /// Computes the vector to store alongside `text`, if any.
    ///
    /// Always `Ok(None)` when the `vec` feature is disabled. Callers invoke this
    /// before taking the lock so the mutex is not held while the embedder runs.
    fn embedding_for(&self, text: &str) -> Result<Option<Vec<f32>>, MemvidError> {
        #[cfg(feature = "vec")]
        let embedding = self.encode(text)?;
        #[cfg(not(feature = "vec"))]
        let embedding = {
            let _ = text;
            None
        };
        Ok(embedding)
    }

    /// Writes one frame through an already-held handle, attaching `embedding` when present.
    ///
    /// Returns the id memvid reports for the new frame.
    fn write_frame(
        memvid: &mut Memvid,
        bytes: &[u8],
        embedding: Option<Vec<f32>>,
        options: PutOptions,
    ) -> Result<u64, MemvidError> {
        let id = match embedding {
            Some(embedding) => memvid.put_with_embedding_and_options(bytes, embedding, options)?,
            None => memvid.put_bytes_with_options(bytes, options)?,
        };
        Ok(id)
    }

    /// Stores `text` as a new frame and commits immediately.
    ///
    /// With the `vec` feature and an attached embedder, the text is embedded and
    /// the vector stored with the frame. The put and the commit happen under a
    /// single lock acquisition.
    ///
    /// # Errors
    /// Embedding, lock, put, or commit failures.
    pub fn put_text(&self, text: &str, options: PutOptions) -> Result<u64, MemvidError> {
        let embedding = self.embedding_for(text)?;
        let mut guard = self.lock()?;
        let id = Self::write_frame(&mut guard, text.as_bytes(), embedding, options)?;
        guard.commit()?;
        Ok(id)
    }

    /// Like [`MemvidStore::put_text`], but without committing.
    ///
    /// Call [`MemvidStore::commit`] afterwards.
    ///
    /// # Errors
    /// Embedding, lock, or put failures.
    pub fn put_text_uncommitted(
        &self,
        text: &str,
        options: PutOptions,
    ) -> Result<u64, MemvidError> {
        let embedding = self.embedding_for(text)?;
        let mut guard = self.lock()?;
        Self::write_frame(&mut guard, text.as_bytes(), embedding, options)
    }

    /// Commits pending writes on the underlying handle.
    ///
    /// # Errors
    /// Lock or commit failures.
    pub fn commit(&self) -> Result<(), MemvidError> {
        self.lock()?.commit()?;
        Ok(())
    }

    /// Runs a raw memvid [`SearchRequest`], bypassing the rig filter translation.
    ///
    /// # Errors
    /// Lock or search failures.
    pub fn search(
        &self,
        request: SearchRequest,
    ) -> Result<memvid_core::SearchResponse, MemvidError> {
        let response = self.lock()?.search(request)?;
        Ok(response)
    }
}
/// Builder for [`MemvidStore`]: path, index toggles, and an optional local embedder.
#[derive(Default)]
pub struct MemvidStoreBuilder {
    /// Memvid file location; required by every terminal method.
    path: Option<PathBuf>,
    /// Whether the lexical index is enabled when the store is built.
    enable_lex: bool,
    /// Whether the vector index is enabled when the store is built.
    #[cfg(feature = "vec")]
    enable_vec: bool,
    /// Vector model name applied when the store is built.
    #[cfg(feature = "vec")]
    vec_model: Option<String>,
    /// Local embedder handed to the built store.
    #[cfg(feature = "vec")]
    embedder: Option<Arc<LocalTextEmbedder>>,
}
impl std::fmt::Debug for MemvidStoreBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemvidStoreBuilder");
        out.field("path", &self.path);
        out.field("enable_lex", &self.enable_lex);
        #[cfg(feature = "vec")]
        {
            // Report only whether an embedder is present, not the embedder itself.
            let embedder_marker = self.embedder.is_some().then_some("<embedder>");
            out.field("enable_vec", &self.enable_vec);
            out.field("vec_model", &self.vec_model);
            out.field("embedder", &embedder_marker);
        }
        out.finish()
    }
}
impl MemvidStoreBuilder {
    /// Sets the memvid file backing the store.
    ///
    /// Required by [`open`](Self::open), [`create`](Self::create),
    /// [`open_or_create`](Self::open_or_create) and
    /// [`open_read_only`](Self::open_read_only).
    #[must_use]
    pub fn path<P: Into<PathBuf>>(mut self, path: P) -> Self {
        self.path = Some(path.into());
        self
    }

    /// Requests that `Memvid::enable_lex` be called when the store is built.
    #[must_use]
    pub fn enable_lex(mut self) -> Self {
        self.enable_lex = true;
        self
    }

    /// Requests that `Memvid::enable_vec` be called when the store is built.
    #[cfg(feature = "vec")]
    #[must_use]
    pub fn enable_vec(mut self) -> Self {
        self.enable_vec = true;
        self
    }

    /// Sets the model name passed to `Memvid::set_vec_model` when the store is built.
    #[cfg(feature = "vec")]
    #[must_use]
    pub fn vec_model(mut self, model: impl Into<String>) -> Self {
        self.vec_model = Some(model.into());
        self
    }

    /// Attaches a local embedder and enables the vector index.
    ///
    /// If no model name was set yet, the embedder's `model_info().name` is used.
    #[cfg(feature = "vec")]
    #[must_use]
    pub fn embedder(mut self, embedder: LocalTextEmbedder) -> Self {
        if self.vec_model.is_none() {
            self.vec_model = Some(embedder.model_info().name.to_string());
        }
        self.embedder = Some(Arc::new(embedder));
        self.enable_vec = true;
        self
    }

    /// Attaches a [`LocalTextEmbedder`] built from `TextEmbedConfig::bge_small()`.
    ///
    /// # Errors
    /// Fails if the embedder cannot be constructed.
    #[cfg(feature = "vec")]
    pub fn with_default_embedder(self) -> Result<Self, MemvidError> {
        let embedder = LocalTextEmbedder::new(TextEmbedConfig::bge_small())?;
        Ok(self.embedder(embedder))
    }

    /// Attaches a [`LocalTextEmbedder`] built from `config`.
    ///
    /// # Errors
    /// Fails if the embedder cannot be constructed.
    #[cfg(feature = "vec")]
    pub fn with_embedder_config(self, config: TextEmbedConfig) -> Result<Self, MemvidError> {
        let embedder = LocalTextEmbedder::new(config)?;
        Ok(self.embedder(embedder))
    }

    /// Returns the configured path, or an `InvalidInput` I/O error if none was set.
    fn require_path(&self) -> Result<&Path, MemvidError> {
        self.path.as_deref().ok_or_else(|| {
            MemvidError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "MemvidStoreBuilder requires a path",
            ))
        })
    }

    /// Applies the requested index and model settings to `memvid` and wraps it in a store.
    fn finish(self, mut memvid: Memvid) -> Result<MemvidStore, MemvidError> {
        if self.enable_lex {
            memvid.enable_lex()?;
        }
        #[cfg(feature = "vec")]
        {
            if self.enable_vec {
                memvid.enable_vec()?;
            }
            if let Some(model) = self.vec_model.as_deref() {
                memvid.set_vec_model(model)?;
            }
        }
        let store = MemvidStore::from_memvid(memvid);
        // `from_memvid` never attaches an embedder; carry over the configured one.
        #[cfg(feature = "vec")]
        let store = MemvidStore {
            embedder: self.embedder,
            ..store
        };
        Ok(store)
    }

    /// Opens an existing memvid file at the configured path.
    ///
    /// # Errors
    /// Missing path, `Memvid::open` failure, or failure applying index settings.
    pub fn open(self) -> Result<MemvidStore, MemvidError> {
        let path = self.require_path()?.to_path_buf();
        let memvid = Memvid::open(&path)?;
        self.finish(memvid)
    }

    /// Creates a new memvid file at the configured path.
    ///
    /// # Errors
    /// Missing path, `Memvid::create` failure, or failure applying index settings.
    pub fn create(self) -> Result<MemvidStore, MemvidError> {
        let path = self.require_path()?.to_path_buf();
        let memvid = Memvid::create(&path)?;
        self.finish(memvid)
    }

    /// Opens the file if it exists, otherwise creates it.
    ///
    /// The existence check and the subsequent open/create are not atomic; a
    /// concurrent creator or deleter can race between them.
    ///
    /// # Errors
    /// Missing path, open/create failure, or failure applying index settings.
    pub fn open_or_create(self) -> Result<MemvidStore, MemvidError> {
        let path = self.require_path()?.to_path_buf();
        let memvid = if path.exists() {
            Memvid::open(&path)?
        } else {
            Memvid::create(&path)?
        };
        self.finish(memvid)
    }

    /// Opens the file through `Memvid::open_read_only`.
    ///
    /// NOTE(review): the configured index/model settings are still applied by
    /// `finish`; confirm these calls are permitted on a read-only handle.
    ///
    /// # Errors
    /// Missing path, open failure, or failure applying index settings.
    pub fn open_read_only(self) -> Result<MemvidStore, MemvidError> {
        let path = self.require_path()?.to_path_buf();
        let memvid = Memvid::open_read_only(&path)?;
        self.finish(memvid)
    }
}
/// Filter type for [`MemvidStore`]'s [`VectorStoreIndex`] implementation.
///
/// Only equality on `uri`, `scope`, `as_of_frame` and `as_of_ts` can be
/// expressed. Constraints memvid cannot honour are not dropped. They are
/// recorded in the private `invalid` list and surface as
/// [`MemvidError::UnsupportedFilter`] when the filter is validated.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MemvidFilter {
    /// Value copied into `SearchRequest::uri`.
    pub uri: Option<String>,
    /// Value copied into `SearchRequest::scope`.
    pub scope: Option<String>,
    /// Value copied into `SearchRequest::as_of_frame`.
    pub as_of_frame: Option<u64>,
    /// Value copied into `SearchRequest::as_of_ts`.
    pub as_of_ts: Option<i64>,
    /// Human-readable reasons this filter cannot be executed; empty when valid.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    invalid: Vec<String>,
}
impl MemvidFilter {
    /// A filter with no constraints that fails validation with `reason`.
    fn unsupported(reason: impl Into<String>) -> Self {
        Self {
            invalid: vec![reason.into()],
            ..Self::default()
        }
    }

    /// Conjunction of two filters; also the first step of `or`, which then marks
    /// the result invalid.
    ///
    /// A key set on only one side is taken as-is. A key set on both sides with
    /// different values (e.g. `uri == "a" AND uri == "b"`) can match nothing. It
    /// is recorded as invalid rather than letting `rhs` silently overwrite `self`
    /// and returning hits for only one of the two values.
    fn merge(mut self, rhs: Self) -> Self {
        Self::merge_key(&mut self.uri, rhs.uri, "uri", &mut self.invalid);
        Self::merge_key(&mut self.scope, rhs.scope, "scope", &mut self.invalid);
        Self::merge_key(&mut self.as_of_frame, rhs.as_of_frame, "as_of_frame", &mut self.invalid);
        Self::merge_key(&mut self.as_of_ts, rhs.as_of_ts, "as_of_ts", &mut self.invalid);
        self.invalid.extend(rhs.invalid);
        self
    }

    /// Folds one optional key into `slot`, recording a conflict in `invalid`.
    fn merge_key<T: PartialEq + std::fmt::Debug>(
        slot: &mut Option<T>,
        incoming: Option<T>,
        key: &str,
        invalid: &mut Vec<String>,
    ) {
        let Some(value) = incoming else {
            return;
        };
        if let Some(existing) = slot.as_ref() {
            if *existing != value {
                invalid.push(format!(
                    "conflicting values for '{key}': {existing:?} vs {value:?}"
                ));
                return;
            }
        }
        *slot = Some(value);
    }

    /// Returns the filter unchanged if valid.
    ///
    /// # Errors
    /// [`MemvidError::UnsupportedFilter`] carrying every recorded reason joined by `"; "`.
    fn into_validated(self) -> Result<Self, MemvidError> {
        if self.invalid.is_empty() {
            Ok(self)
        } else {
            Err(MemvidError::UnsupportedFilter(self.invalid.join("; ")))
        }
    }

    /// Copies every key into `request`, overwriting existing values (including with `None`).
    fn apply_to(self, request: &mut SearchRequest) {
        request.uri = self.uri;
        request.scope = self.scope;
        request.as_of_frame = self.as_of_frame;
        request.as_of_ts = self.as_of_ts;
    }
}
fn json_as_string(value: &serde_json::Value) -> Option<String> {
match value {
serde_json::Value::String(s) => Some(s.clone()),
other => Some(other.to_string()),
}
}
impl SearchFilter for MemvidFilter {
    type Value = serde_json::Value;

    /// Builds an equality constraint on one of the four supported keys.
    ///
    /// `uri` and `scope` accept any JSON value; non-strings are rendered as JSON
    /// text. `as_of_frame` needs a value representable as `u64`, and `as_of_ts`
    /// one representable as `i64`. Any other key, or a value of the wrong kind,
    /// yields a filter that fails validation.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
        match key.as_ref() {
            "uri" => Self {
                uri: json_as_string(&value),
                ..Self::default()
            },
            "scope" => Self {
                scope: json_as_string(&value),
                ..Self::default()
            },
            "as_of_frame" => value.as_u64().map_or_else(
                || Self::unsupported(format!("as_of_frame must be a u64, got {value}")),
                |frame| Self {
                    as_of_frame: Some(frame),
                    ..Self::default()
                },
            ),
            "as_of_ts" => value.as_i64().map_or_else(
                || Self::unsupported(format!("as_of_ts must be an i64, got {value}")),
                |ts| Self {
                    as_of_ts: Some(ts),
                    ..Self::default()
                },
            ),
            unknown => Self::unsupported(format!(
                "unsupported filter key '{unknown}' (allowed: uri, scope, as_of_frame, as_of_ts)"
            )),
        }
    }

    /// Range filters are not expressible in memvid; always yields an invalid filter.
    fn gt(key: impl AsRef<str>, _value: Self::Value) -> Self {
        let key = key.as_ref();
        Self::unsupported(format!("memvid does not support gt() on '{key}'"))
    }

    /// Range filters are not expressible in memvid; always yields an invalid filter.
    fn lt(key: impl AsRef<str>, _value: Self::Value) -> Self {
        let key = key.as_ref();
        Self::unsupported(format!("memvid does not support lt() on '{key}'"))
    }

    /// Conjunction: combines both sides' constraints.
    fn and(self, rhs: Self) -> Self {
        self.merge(rhs)
    }

    /// Disjunction is not expressible in memvid's request model. Both sides are
    /// combined so their own errors are kept, and the result is always marked invalid.
    fn or(self, rhs: Self) -> Self {
        let mut combined = self.merge(rhs);
        combined
            .invalid
            .push("memvid does not support or() in filters".to_owned());
        combined
    }
}
/// `snippet_chars` value requested for every search issued through this adapter.
const DEFAULT_SNIPPET_CHARS: usize = 400;

/// Translates a rig query, sample count and optional filter into a memvid [`SearchRequest`].
///
/// `samples` saturates at `usize::MAX` on narrower targets. Fields the filter
/// does not cover take neutral values: `None`, `false`, and the default ACL
/// enforcement mode.
///
/// # Errors
/// [`MemvidError::UnsupportedFilter`] if the filter recorded unsupported constraints.
fn build_search_request(
    query: String,
    samples: u64,
    filter: Option<MemvidFilter>,
) -> Result<SearchRequest, MemvidError> {
    let validated = filter
        .map(MemvidFilter::into_validated)
        .transpose()?
        .unwrap_or_default();
    let top_k = usize::try_from(samples).unwrap_or(usize::MAX);
    let mut request = SearchRequest {
        query,
        top_k,
        snippet_chars: DEFAULT_SNIPPET_CHARS,
        uri: None,
        scope: None,
        cursor: None,
        #[cfg(feature = "temporal")]
        temporal: None,
        as_of_frame: None,
        as_of_ts: None,
        no_sketch: false,
        acl_context: None,
        acl_enforcement_mode: memvid_core::AclEnforcementMode::default(),
    };
    validated.apply_to(&mut request);
    Ok(request)
}
/// Score reported to rig for a hit.
///
/// Uses memvid's own score when present. Otherwise it falls back to the
/// reciprocal-rank value `1 / (rank + 1)`. Assuming ranks are 0-based, the top
/// hit scores 1.0 — TODO confirm.
fn hit_score(hit: &SearchHit) -> f64 {
    hit.score
        .map_or_else(|| 1.0 / (hit.rank as f64 + 1.0), f64::from)
}
/// Rejects filter keys that the embedder-backed vector search path cannot apply.
///
/// Only `scope` is forwarded to `vec_search_with_embedding`. `uri` and the
/// point-in-time keys are refused rather than silently ignored.
///
/// # Errors
/// [`MemvidError::UnsupportedFilter`] naming the first unsupported key group.
#[cfg(feature = "vec")]
fn ensure_vec_filter_supported(filter: &MemvidFilter) -> Result<(), MemvidError> {
    let has_point_in_time = filter.as_of_frame.is_some() || filter.as_of_ts.is_some();
    let rejection: Option<&str> = if filter.uri.is_some() {
        Some("`uri` filter is not supported when querying through the embedder; use lex search")
    } else if has_point_in_time {
        Some(
            "point-in-time filters (`as_of_frame`, `as_of_ts`) are not supported under vector \
             search; use lex or `MemvidStore::search` directly",
        )
    } else {
        None
    };
    match rejection {
        Some(reason) => Err(MemvidError::UnsupportedFilter(reason.into())),
        None => Ok(()),
    }
}
impl MemvidStore {
    /// Embeds `query` with the attached embedder and runs a vector search.
    ///
    /// Only `filter.scope` is forwarded. Callers are expected to have rejected
    /// other keys beforehand with `ensure_vec_filter_supported`.
    ///
    /// # Errors
    /// `MemvidError::UnsupportedFilter("no embedder configured")` when no embedder
    /// is attached. NOTE(review): this reuses the filter variant for a configuration
    /// error. Also returns embedding, lock, or search failures.
    #[cfg(feature = "vec")]
    fn vec_search(
        &self,
        query: &str,
        samples: u64,
        filter: &MemvidFilter,
    ) -> Result<memvid_core::SearchResponse, MemvidError> {
        let Some(embedder) = self.embedder.as_ref() else {
            return Err(MemvidError::UnsupportedFilter("no embedder configured".into()));
        };
        let query_vector = embedder.encode_text(query)?;
        let limit = usize::try_from(samples).unwrap_or(usize::MAX);
        // The query is embedded above, before the lock is taken.
        let mut memvid = self.lock()?;
        Ok(memvid.vec_search_with_embedding(
            query,
            &query_vector,
            limit,
            DEFAULT_SNIPPET_CHARS,
            filter.scope.as_deref(),
        )?)
    }
}
impl VectorStoreIndex for MemvidStore {
type Filter = MemvidFilter;
async fn top_n<T>(
&self,
req: VectorSearchRequest<Self::Filter>,
) -> Result<Vec<(f64, String, T)>, VectorStoreError>
where
T: for<'a> Deserialize<'a> + WasmCompatSend,
{
let query = req.query().to_owned();
let samples = req.samples();
let filter = req.filter().clone();
#[cfg(feature = "vec")]
let response = if self.embedder.is_some() {
let validated = match filter.clone() {
Some(f) => f.into_validated().map_err(VectorStoreError::from)?,
None => MemvidFilter::default(),
};
ensure_vec_filter_supported(&validated).map_err(VectorStoreError::from)?;
self.vec_search(&query, samples, &validated)
.map_err(VectorStoreError::from)?
} else {
let request =
build_search_request(query, samples, filter).map_err(VectorStoreError::from)?;
let mut guard = self
.inner
.lock()
.map_err(|_| VectorStoreError::from(MemvidError::Poisoned))?;
guard.search(request).map_err(MemvidError::from)?
};
#[cfg(not(feature = "vec"))]
let response = {
let request =
build_search_request(query, samples, filter).map_err(VectorStoreError::from)?;
let mut guard = self
.inner
.lock()
.map_err(|_| VectorStoreError::from(MemvidError::Poisoned))?;
guard.search(request).map_err(MemvidError::from)?
};
let mut out = Vec::with_capacity(response.hits.len());
for hit in response.hits {
let score = hit_score(&hit);
let id = hit.frame_id.to_string();
let value = serde_json::to_value(&hit).map_err(MemvidError::from)?;
let doc: T = serde_json::from_value(value).map_err(MemvidError::from)?;
out.push((score, id, doc));
}
Ok(out)
}
async fn top_n_ids(
&self,
req: VectorSearchRequest<Self::Filter>,
) -> Result<Vec<(f64, String)>, VectorStoreError> {
let query = req.query().to_owned();
let samples = req.samples();
let filter = req.filter().clone();
#[cfg(feature = "vec")]
let response = if self.embedder.is_some() {
let validated = match filter.clone() {
Some(f) => f.into_validated().map_err(VectorStoreError::from)?,
None => MemvidFilter::default(),
};
ensure_vec_filter_supported(&validated).map_err(VectorStoreError::from)?;
self.vec_search(&query, samples, &validated)
.map_err(VectorStoreError::from)?
} else {
let request =
build_search_request(query, samples, filter).map_err(VectorStoreError::from)?;
let mut guard = self
.inner
.lock()
.map_err(|_| VectorStoreError::from(MemvidError::Poisoned))?;
guard.search(request).map_err(MemvidError::from)?
};
#[cfg(not(feature = "vec"))]
let response = {
let request =
build_search_request(query, samples, filter).map_err(VectorStoreError::from)?;
let mut guard = self
.inner
.lock()
.map_err(|_| VectorStoreError::from(MemvidError::Poisoned))?;
guard.search(request).map_err(MemvidError::from)?
};
Ok(response
.hits
.into_iter()
.map(|hit| (hit_score(&hit), hit.frame_id.to_string()))
.collect())
}
}
impl InsertDocuments for MemvidStore {
    /// Stores each document as a frame containing its JSON serialization.
    ///
    /// The `OneOrMany<Embedding>` vectors supplied by rig are not used. With the
    /// `vec` feature and an attached embedder, the JSON text is embedded locally
    /// instead and the vector stored with the frame. All embeddings are computed
    /// before the lock is taken. The frames are then written under one lock
    /// acquisition and committed once at the end.
    ///
    /// # Errors
    /// Serialization, embedding, lock, put, or commit failures. If a put fails
    /// partway through, no commit is attempted. Frames already written in this
    /// call are left uncommitted.
    async fn insert_documents<Doc>(
        &self,
        documents: Vec<(Doc, OneOrMany<Embedding>)>,
    ) -> Result<(), VectorStoreError>
    where
        Doc: Serialize + Embed + WasmCompatSend,
    {
        let mut prepared: Vec<(String, Option<Vec<f32>>)> = Vec::with_capacity(documents.len());
        for (doc, _embeddings) in documents {
            // Serializing straight to `String` yields the same bytes as `to_vec`.
            // The text handed to the embedder is then exactly what gets stored,
            // with no UTF-8 re-check and no silent empty-string fallback.
            let json = serde_json::to_string(&doc).map_err(MemvidError::from)?;
            #[cfg(feature = "vec")]
            let embedding = self.encode(&json)?;
            #[cfg(not(feature = "vec"))]
            let embedding: Option<Vec<f32>> = None;
            prepared.push((json, embedding));
        }

        let mut guard = self.lock()?;
        for (json, embedding) in prepared {
            let bytes = json.as_bytes();
            match embedding {
                Some(embedding) => {
                    guard
                        .put_with_embedding_and_options(bytes, embedding, PutOptions::default())
                        .map_err(MemvidError::from)?;
                }
                None => {
                    guard
                        .put_bytes_with_options(bytes, PutOptions::default())
                        .map_err(MemvidError::from)?;
                }
            }
        }
        guard.commit().map_err(MemvidError::from)?;
        Ok(())
    }
}