# briefcase-python 2.4.1
# Python bindings for Briefcase AI
"""
Chroma vector database adapter with lakeFS versioning.
"""

from typing import List, Dict, Any, Optional
import logging

try:
    from opentelemetry import trace
    HAS_OTEL = True
except ImportError:
    # Tracing is optional; spans are simply skipped when otel is absent.
    HAS_OTEL = False

logger = logging.getLogger(__name__)


class VersionedChromaStore:
    """
    Chroma vector store whose contents are pinned to a lakeFS commit.

    Every vector added is tagged with the lakeFS repository and commit, and
    every query is filtered to the pinned commit, so retrieval results always
    correspond to one immutable version of the source data.

    Falls back to a no-op "mock mode" when the ``chromadb`` package is not
    installed or the client fails to initialize.
    """

    def __init__(
        self,
        collection_name: str,
        lakefs_repository: str,
        lakefs_commit: str,
        persist_directory: Optional[str] = None
    ):
        """
        Initialize the store and, if possible, the underlying Chroma client.

        Args:
            collection_name: Name of the Chroma collection to get or create.
            lakefs_repository: lakeFS repository the vectors derive from.
            lakefs_commit: lakeFS commit hash used to tag and filter vectors.
            persist_directory: If given, use an on-disk PersistentClient;
                otherwise an in-memory client.
        """
        self.collection_name = collection_name
        self.lakefs_repository = lakefs_repository
        self.lakefs_commit = lakefs_commit

        # chromadb is an optional dependency: degrade to mock mode when it is
        # missing or fails to start instead of making construction fatal.
        try:
            import chromadb
            if persist_directory:
                self.client = chromadb.PersistentClient(path=persist_directory)
            else:
                self.client = chromadb.Client()

            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                metadata={"lakefs_repository": lakefs_repository}
            )
            self._has_chroma = True
        except ImportError:
            logger.warning("Chroma not installed, using mock mode")
            self.client = None
            self.collection = None
            self._has_chroma = False
        except Exception as e:
            logger.error("Failed to initialize Chroma: %s", e)
            self.client = None
            self.collection = None
            self._has_chroma = False

        self._tracer = trace.get_tracer(__name__) if HAS_OTEL else None

    def _start_span(self, name: str):
        """Start an OTel span if tracing is available; otherwise return None."""
        return self._tracer.start_span(name) if self._tracer else None

    def add_batch(
        self,
        embeddings: List[List[float]],
        document_ids: List[str],
        documents: List[str],
        metadata: List[Dict[str, Any]]
    ) -> Dict[str, int]:
        """
        Add embeddings in batch, tagging each entry with version metadata.

        Args:
            embeddings: One embedding vector per document.
            document_ids: Unique IDs, parallel to ``embeddings``.
            documents: Raw document texts, parallel to ``embeddings``.
            metadata: Per-document metadata dicts; each is copied and
                enriched with ``lakefs_repository`` / ``lakefs_commit``.

        Returns:
            ``{"added_count": N}`` where N is the number of vectors added.

        Raises:
            ValueError: If the four input lists have differing lengths.
        """
        # The four lists must line up one-to-one; catch mismatches up front
        # instead of letting Chroma fail (or mock mode silently "succeed").
        if not (len(embeddings) == len(document_ids)
                == len(documents) == len(metadata)):
            raise ValueError(
                "embeddings, document_ids, documents and metadata "
                "must all have the same length"
            )

        span = self._start_span("rag.vector_store.add_batch")
        try:
            if span:
                span.set_attribute("rag.vector_store.type", "chroma")
                span.set_attribute("rag.vector_store.add_count", len(embeddings))
                span.set_attribute("rag.index.version", self.lakefs_commit)

            if not self._has_chroma or not self.collection:
                logger.info("Mock mode: Would add %d vectors to Chroma",
                            len(embeddings))
                return {"added_count": len(embeddings)}

            # Tag each record with version info so queries can filter by
            # commit later; copy so callers' dicts are not mutated.
            enriched_metadata = []
            for meta in metadata:
                meta_copy = meta.copy()
                meta_copy['lakefs_repository'] = self.lakefs_repository
                meta_copy['lakefs_commit'] = self.lakefs_commit
                enriched_metadata.append(meta_copy)

            self.collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=enriched_metadata,
                ids=document_ids
            )

            return {"added_count": len(embeddings)}

        finally:
            if span:
                span.end()

    def query(
        self,
        query_embeddings: List[List[float]],
        n_results: int = 5,
        where: Optional[Dict[str, Any]] = None,
        include: Optional[List[str]] = None
    ) -> Dict[str, List[Any]]:
        """
        Query the collection, always filtered to the pinned lakeFS commit.

        Args:
            query_embeddings: Query vectors (one result set per vector).
            n_results: Maximum results per query vector.
            where: Optional extra metadata filter; AND-ed with the version
                filter via Chroma's ``$and`` operator.
            include: Fields to include in results; defaults to documents,
                metadatas, and distances.

        Returns:
            Chroma-style results dict (``ids``, ``distances``, ``documents``,
            ``metadatas``), or canned mock results when Chroma is unavailable.
        """
        span = self._start_span("rag.vector_store.query")
        try:
            if span:
                span.set_attribute("rag.vector_store.type", "chroma")
                span.set_attribute("rag.vector_store.query_top_k", n_results)

            # Always constrain results to the pinned index version.
            version_where = {"lakefs_commit": self.lakefs_commit}
            if where:
                combined_where = {"$and": [version_where, where]}
            else:
                combined_where = version_where

            if not self._has_chroma or not self.collection:
                logger.info("Mock mode: Would query Chroma with n_results=%d",
                            n_results)
                mock_count = min(n_results, 3)
                return {
                    "ids": [[f"mock_doc_{i}" for i in range(mock_count)]],
                    "distances": [[0.1 + (i * 0.1) for i in range(mock_count)]],
                    "documents": [[f"Mock document {i}" for i in range(mock_count)]],
                    "metadatas": [[{"lakefs_commit": self.lakefs_commit}
                                   for _ in range(mock_count)]]
                }

            results = self.collection.query(
                query_embeddings=query_embeddings,
                n_results=n_results,
                where=combined_where,
                include=include or ["documents", "metadatas", "distances"]
            )

            if span:
                # Chroma nests results per query vector; count the first set.
                result_count = len(results.get("ids", [[]])[0])
                span.set_attribute("rag.retrieval.count", result_count)

            return results

        finally:
            if span:
                span.end()

    def delete(self, ids: List[str]) -> Dict[str, int]:
        """Delete vectors by IDs; returns ``{"deleted_count": N}``."""
        if not self._has_chroma or not self.collection:
            logger.info("Mock mode: Would delete %d vectors from Chroma", len(ids))
            return {"deleted_count": len(ids)}

        self.collection.delete(ids=ids)
        return {"deleted_count": len(ids)}

    def count(self) -> int:
        """Return the number of vectors in the collection (0 in mock mode)."""
        if not self._has_chroma or not self.collection:
            return 0

        return self.collection.count()

    def peek(self, limit: int = 10) -> Dict[str, List[Any]]:
        """Return up to ``limit`` entries (empty lists in mock mode)."""
        if not self._has_chroma or not self.collection:
            return {"ids": [], "documents": [], "metadatas": []}

        return self.collection.peek(limit=limit)