# briefcase-python 2.4.1
#
# Python bindings for Briefcase AI
# Documentation
"""
Weaviate vector database adapter with lakeFS versioning.
"""

from typing import List, Dict, Any, Optional
import logging

# OpenTelemetry is an optional dependency; HAS_OTEL records whether the
# import succeeded so the rest of the module can check before using `trace`.
try:
    from opentelemetry import trace
except ImportError:
    HAS_OTEL = False
else:
    HAS_OTEL = True

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class VersionedWeaviateStore:
    """
    Weaviate vector store with lakeFS version tracking.

    Every object written through ``add_batch`` is stamped with the configured
    lakeFS repository and commit, and ``query`` only returns objects whose
    ``lakefs_commit`` equals the configured commit.

    If the ``weaviate`` package cannot be imported, or the client cannot be
    constructed, the store runs in "mock mode": no client calls are made and
    each method returns a placeholder result instead.
    """

    def __init__(
        self,
        url: str,
        class_name: str,
        lakefs_repository: str,
        lakefs_commit: str,
        api_key: Optional[str] = None
    ):
        """
        Create the store and try to build a Weaviate client.

        Args:
            url: Endpoint URL passed to ``weaviate.Client``.
            class_name: Weaviate class that objects are written to and read from.
            lakefs_repository: lakeFS repository name stamped on every object.
            lakefs_commit: lakeFS commit stamped on every object and used as
                the version filter in ``query``.
            api_key: Optional API key; when given it is wrapped in
                ``weaviate.AuthApiKey``.

        Any failure to import or initialise Weaviate is logged and leaves the
        store in mock mode rather than raising.
        """
        self.url = url
        self.class_name = class_name
        self.lakefs_repository = lakefs_repository
        self.lakefs_commit = lakefs_commit

        # Try to import and initialize Weaviate
        try:
            import weaviate
            auth_config = weaviate.AuthApiKey(api_key=api_key) if api_key else None
            self.client = weaviate.Client(url=url, auth_client_secret=auth_config)
            self._has_weaviate = True
        except ImportError:
            logger.warning("Weaviate not installed, using mock mode")
            self.client = None
            self._has_weaviate = False
        except Exception as e:
            logger.error("Failed to initialize Weaviate: %s", e)
            self.client = None
            self._has_weaviate = False

        if HAS_OTEL:
            self._tracer = trace.get_tracer(__name__)
        else:
            self._tracer = None

    def _start_span(self, name: str):
        """Start a tracing span called ``name``, or return None when tracing is unavailable."""
        if self._tracer and HAS_OTEL:
            return self._tracer.start_span(name)
        return None

    def add_batch(
        self,
        embeddings: List[List[float]],
        document_ids: List[str],
        texts: List[str],
        metadata: List[Dict[str, Any]]
    ) -> Dict[str, int]:
        """
        Add embeddings in batch with version metadata.

        The four lists are consumed in lockstep. If their lengths differ, only
        the leading records present in all of them are written and a warning
        is logged.

        Args:
            embeddings: One vector per record.
            document_ids: One id per record; also used as the object UUID.
            texts: One text per record.
            metadata: One dict of extra properties per record. A ``None``
                entry is treated as an empty dict. Keys that collide with
                ``document_id``, ``text``, ``lakefs_repository`` or
                ``lakefs_commit`` are overridden by the store's own values.

        Returns:
            ``{"added_count": n}`` where ``n`` is the number of records
            submitted to the batch (or that would have been, in mock mode).
        """
        span = self._start_span("rag.vector_store.add_batch")

        try:
            lengths = (len(embeddings), len(document_ids), len(texts), len(metadata))
            # zip() stops at the shortest input, so this is the number of
            # records that can actually be written.
            record_count = min(lengths)
            if record_count != max(lengths):
                logger.warning(
                    "add_batch received inputs of different lengths "
                    "(embeddings=%d, document_ids=%d, texts=%d, metadata=%d); "
                    "only the first %d records will be added",
                    *lengths,
                    record_count,
                )

            if span:
                span.set_attribute("rag.vector_store.type", "weaviate")
                span.set_attribute("rag.vector_store.add_count", record_count)
                span.set_attribute("rag.index.version", self.lakefs_commit)

            if not self._has_weaviate or not self.client:
                logger.info("Mock mode: Would add %d vectors to Weaviate", record_count)
                return {"added_count": record_count}

            # Add with version metadata
            with self.client.batch as batch:
                for doc_id, embedding, text, meta in zip(document_ids, embeddings, texts, metadata):
                    # Caller metadata goes first so the store-managed fields
                    # below always win. Otherwise metadata carrying a stale
                    # "lakefs_commit" (e.g. copied from an earlier query
                    # result) would mis-version the object and hide it from
                    # the version filter used by query().
                    properties = {
                        **(meta or {}),
                        "document_id": doc_id,
                        "text": text,
                        "lakefs_repository": self.lakefs_repository,
                        "lakefs_commit": self.lakefs_commit,
                    }

                    # NOTE(review): doc_id is used directly as the object
                    # UUID; presumably callers supply UUID-formatted ids --
                    # confirm against the Weaviate client's requirements.
                    batch.add_data_object(
                        data_object=properties,
                        class_name=self.class_name,
                        vector=embedding,
                        uuid=doc_id
                    )

            return {"added_count": record_count}

        finally:
            if span:
                span.end()

    def query(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        where_filter: Optional[Dict] = None
    ) -> List[Dict[str, Any]]:
        """
        Query with version filtering.

        Args:
            query_embedding: Vector to search near.
            top_k: Maximum number of results requested.
            where_filter: Optional extra Weaviate where-filter dict; it is
                combined with the ``lakefs_commit`` filter using ``And``.

        Returns:
            A list of dicts with keys ``id``, ``distance``, ``text`` and
            ``metadata`` (``document_id``, ``lakefs_commit``,
            ``lakefs_repository``). In mock mode up to three placeholder
            results are returned. An empty list is returned when the response
            contains no objects for the class; any ``errors`` entry in the
            response is logged.
        """
        span = self._start_span("rag.vector_store.query")

        try:
            if span:
                span.set_attribute("rag.vector_store.type", "weaviate")
                span.set_attribute("rag.vector_store.query_top_k", top_k)

            if not self._has_weaviate or not self.client:
                logger.info("Mock mode: Would query Weaviate with top_k=%s", top_k)
                return [
                    {
                        "id": f"mock_doc_{i}",
                        "distance": 0.05 + (i * 0.05),
                        "text": f"Mock document {i}",
                        "metadata": {"lakefs_commit": self.lakefs_commit}
                    }
                    for i in range(min(top_k, 3))
                ]

            # Build version filter
            version_filter = {
                "path": ["lakefs_commit"],
                "operator": "Equal",
                "valueText": self.lakefs_commit
            }

            if where_filter:
                combined_filter = {
                    "operator": "And",
                    "operands": [version_filter, where_filter]
                }
            else:
                combined_filter = version_filter

            # Query Weaviate
            result = (
                self.client.query
                .get(self.class_name, ["document_id", "text", "lakefs_commit", "lakefs_repository"])
                .with_near_vector({"vector": query_embedding})
                .with_limit(top_k)
                .with_where(combined_filter)
                .with_additional(["distance", "id"])
                .do()
            ) or {}

            errors = result.get("errors")
            if errors:
                logger.error("Weaviate query returned errors: %s", errors)

            # Any level of the response may be present but null (notably when
            # the response carries errors), so fall back at every step instead
            # of relying on .get() defaults, which only cover missing keys.
            data = result.get("data") or {}
            get_section = data.get("Get") or {}
            objects = get_section.get(self.class_name) or []

            if span:
                span.set_attribute("rag.retrieval.count", len(objects))

            hits = []
            for obj in objects:
                additional = obj.get("_additional") or {}
                hits.append(
                    {
                        "id": additional.get("id"),
                        "distance": additional.get("distance"),
                        "text": obj.get("text"),
                        "metadata": {
                            "document_id": obj.get("document_id"),
                            "lakefs_commit": obj.get("lakefs_commit"),
                            "lakefs_repository": obj.get("lakefs_repository")
                        }
                    }
                )
            return hits

        finally:
            if span:
                span.end()

    def delete_by_id(self, document_id: str) -> bool:
        """
        Delete vector by ID.

        Returns:
            True when the delete call completes (always True in mock mode);
            False when the client raises, in which case the error is logged.
        """
        if not self._has_weaviate or not self.client:
            logger.info("Mock mode: Would delete %s from Weaviate", document_id)
            return True

        try:
            self.client.data_object.delete(document_id, class_name=self.class_name)
            return True
        except Exception as e:
            logger.error("Failed to delete %s: %s", document_id, e)
            return False

    def get_schema(self) -> Dict[str, Any]:
        """
        Get class schema.

        Returns:
            The schema returned by the client for ``class_name``. In mock mode
            a stub ``{"name": class_name, "properties": []}`` is returned, and
            an empty dict when the client raises (the error is logged).
        """
        if not self._has_weaviate or not self.client:
            return {"name": self.class_name, "properties": []}

        try:
            return self.client.schema.get(self.class_name)
        except Exception as e:
            logger.error("Failed to get schema: %s", e)
            return {}