import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# OpenTelemetry is an optional dependency: HAS_OTEL records whether the
# tracing API could be imported, so callers can skip tracing when it is absent.
try:
    from opentelemetry import trace
except ImportError:
    HAS_OTEL = False
else:
    HAS_OTEL = True
class VersionedChromaStore:
    """Chroma vector store whose records are tagged with, and queried by, a lakeFS commit.

    Records written through ``add_batch`` get ``lakefs_repository`` and
    ``lakefs_commit`` metadata. ``query`` only returns records whose
    ``lakefs_commit`` equals this store's commit.

    If ``chromadb`` cannot be imported, or the client/collection fails to
    initialise, the store runs in "mock mode". In that mode writes are only
    logged and queries return placeholder results.
    """

    def __init__(
        self,
        collection_name: str,
        lakefs_repository: str,
        lakefs_commit: str,
        persist_directory: Optional[str] = None
    ):
        """Open (or create) the Chroma collection and set up optional tracing.

        Args:
            collection_name: Name of the Chroma collection to get or create.
            lakefs_repository: lakeFS repository. It is stored in the
                collection metadata and on every record added.
            lakefs_commit: lakeFS commit used to tag writes and filter queries.
            persist_directory: If truthy, a ``chromadb.PersistentClient`` is
                used at this path; otherwise an in-process ``chromadb.Client()``.
        """
        self.collection_name = collection_name
        self.lakefs_repository = lakefs_repository
        self.lakefs_commit = lakefs_commit
        try:
            import chromadb
            if persist_directory:
                self.client = chromadb.PersistentClient(path=persist_directory)
            else:
                self.client = chromadb.Client()
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                metadata={"lakefs_repository": lakefs_repository}
            )
            self._has_chroma = True
        except ImportError:
            logger.warning("Chroma not installed, using mock mode")
            self.client = None
            self.collection = None
            self._has_chroma = False
        except Exception as e:
            # Deliberate best-effort fallback to mock mode. Log at ERROR with the
            # traceback so the underlying failure stays diagnosable.
            logger.exception("Failed to initialize Chroma: %s", e)
            self.client = None
            self.collection = None
            self._has_chroma = False
        # A None tracer means "tracing off"; see _span.
        self._tracer = trace.get_tracer(__name__) if HAS_OTEL else None

    def _span(self, name: str) -> Any:
        """Return a context manager for a span called *name*, or a no-op one.

        ``start_as_current_span`` makes the span the active context, so nested
        spans attach to it. By default it also records an exception raised in
        the ``with`` body and marks the span as errored. When tracing is off,
        ``nullcontext()`` yields ``None`` as the span.
        """
        if self._tracer is None:
            from contextlib import nullcontext
            return nullcontext()
        return self._tracer.start_as_current_span(name)

    def _is_mock(self) -> bool:
        """True when no usable Chroma collection is available (mock mode)."""
        # Compare with None explicitly so collection truthiness never matters.
        return not self._has_chroma or self.collection is None

    def add_batch(
        self,
        embeddings: List[List[float]],
        document_ids: List[str],
        documents: List[str],
        metadata: List[Dict[str, Any]]
    ) -> Dict[str, int]:
        """Add a batch of records tagged with this store's lakeFS repository and commit.

        The four lists are parallel: item *i* of each describes one record.
        The caller's metadata dicts are not modified; enriched copies are sent.

        Returns:
            ``{"added_count": n}``, where *n* is the number of records submitted.

        Raises:
            ValueError: if the four input lists do not all have the same length.
        """
        with self._span("rag.vector_store.add_batch") as span:
            if span is not None:
                span.set_attribute("rag.vector_store.type", "chroma")
                span.set_attribute("rag.vector_store.add_count", len(embeddings))
                span.set_attribute("rag.index.version", self.lakefs_commit)

            count = len(embeddings)
            if not (len(document_ids) == len(documents) == len(metadata) == count):
                raise ValueError(
                    "add_batch inputs must have equal lengths: "
                    f"embeddings={count}, document_ids={len(document_ids)}, "
                    f"documents={len(documents)}, metadata={len(metadata)}"
                )
            if count == 0:
                # Nothing to write. Chroma rejects an empty id list, so skip the call.
                return {"added_count": 0}

            enriched_metadata = [
                {
                    **(meta or {}),
                    "lakefs_repository": self.lakefs_repository,
                    "lakefs_commit": self.lakefs_commit,
                }
                for meta in metadata
            ]

            if self._is_mock():
                logger.info("Mock mode: Would add %s vectors to Chroma", count)
                return {"added_count": count}

            # NOTE(review): Chroma ids are unique per collection. If several
            # commits share this collection, the same document id cannot hold one
            # record per commit. Confirm that callers' ids encode the commit.
            self.collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=enriched_metadata,
                ids=document_ids
            )
            return {"added_count": count}

    def query(
        self,
        query_embeddings: List[List[float]],
        n_results: int = 5,
        where: Optional[Dict[str, Any]] = None,
        include: Optional[List[str]] = None
    ) -> Dict[str, List[Any]]:
        """Query nearest neighbours, restricted to records of this store's lakeFS commit.

        Args:
            query_embeddings: One or more query vectors.
            n_results: Number of results requested per query.
            where: Optional extra metadata filter. It is AND-ed with the commit filter.
            include: Fields to return. Defaults to documents, metadatas and distances.

        Returns:
            In live mode, the result of ``collection.query``. In mock mode, up to
            three placeholder hits for a single query.
        """
        with self._span("rag.vector_store.query") as span:
            if span is not None:
                span.set_attribute("rag.vector_store.type", "chroma")
                span.set_attribute("rag.vector_store.query_top_k", n_results)

            version_where = {"lakefs_commit": self.lakefs_commit}
            combined_where = {"$and": [version_where, where]} if where else version_where

            if self._is_mock():
                logger.info("Mock mode: Would query Chroma with n_results=%s", n_results)
                k = min(n_results, 3)
                return {
                    "ids": [[f"mock_doc_{i}" for i in range(k)]],
                    "distances": [[0.1 + (i * 0.1) for i in range(k)]],
                    "documents": [[f"Mock document {i}" for i in range(k)]],
                    "metadatas": [[{"lakefs_commit": self.lakefs_commit} for _ in range(k)]]
                }

            results = self.collection.query(
                query_embeddings=query_embeddings,
                n_results=n_results,
                where=combined_where,
                include=include or ["documents", "metadatas", "distances"]
            )
            if span is not None:
                # Count the hits of the first query. Guard against an empty or None
                # "ids" entry, which would make [0] raise.
                ids = results.get("ids") or []
                first = (ids[0] if ids else None) or []
                span.set_attribute("rag.retrieval.count", len(first))
            return results

    def delete(self, ids: List[str]) -> Dict[str, int]:
        """Delete records by id.

        Returns:
            ``{"deleted_count": n}``, where *n* is the number of ids submitted.
            It is not a confirmed count of records that existed.
        """
        if not ids:
            # Some Chroma versions treat an empty id list like "no id filter".
            # Never forward an empty list to collection.delete.
            return {"deleted_count": 0}
        if self._is_mock():
            logger.info("Mock mode: Would delete %s vectors from Chroma", len(ids))
            return {"deleted_count": len(ids)}
        # NOTE(review): not scoped to self.lakefs_commit. Any record with these ids
        # is deleted, whatever commit it was added under. Confirm this is intended.
        self.collection.delete(ids=ids)
        return {"deleted_count": len(ids)}

    def count(self) -> int:
        """Return the number of records in the whole collection, or 0 in mock mode.

        No commit filter is applied.
        """
        if self._is_mock():
            return 0
        return self.collection.count()

    def peek(self, limit: int = 10) -> Dict[str, List[Any]]:
        """Return up to *limit* records via ``collection.peek``, with no commit filter.

        In mock mode, return empty id, document and metadata lists.
        """
        if self._is_mock():
            return {"ids": [], "documents": [], "metadatas": []}
        return self.collection.peek(limit=limit)