from typing import List, Dict, Any, Optional
import logging
# OpenTelemetry is an optional dependency: HAS_OTEL records whether the import
# succeeded so the rest of the module can skip tracing when it is unavailable.
try:
    from opentelemetry import trace
    HAS_OTEL = True
except ImportError:
    HAS_OTEL = False
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class VersionedWeaviateStore:
    """Weaviate-backed vector store whose objects are tagged with a lakeFS version.

    Every object written through ``add_batch`` carries the ``lakefs_repository``
    and ``lakefs_commit`` given at construction, and ``query`` always filters on
    ``lakefs_commit`` so only objects tagged with this store's commit are
    requested.

    If the ``weaviate`` package cannot be imported, or creating the client
    raises, the store falls back to "mock mode": writes and deletes are only
    logged, and queries return synthetic results.
    """

    # Properties the store writes itself; caller metadata must never replace
    # them, otherwise objects could be tagged with a different version than the
    # one ``query`` filters on.
    _RESERVED_PROPERTIES = (
        "document_id",
        "text",
        "lakefs_repository",
        "lakefs_commit",
    )

    def __init__(
        self,
        url: str,
        class_name: str,
        lakefs_repository: str,
        lakefs_commit: str,
        api_key: Optional[str] = None
    ):
        """Create the store and try to connect to Weaviate.

        Args:
            url: Address passed to ``weaviate.Client``.
            class_name: Weaviate class that objects are written to and read from.
            lakefs_repository: Repository name stored on every object.
            lakefs_commit: Commit id stored on every object and used as the
                query filter.
            api_key: Optional API key; when given it is wrapped in
                ``weaviate.AuthApiKey``.

        No exception from the Weaviate import or client construction escapes:
        both failure kinds are logged and switch the store to mock mode.
        """
        self.url = url
        self.class_name = class_name
        self.lakefs_repository = lakefs_repository
        self.lakefs_commit = lakefs_commit

        try:
            # Imported lazily so the module stays importable without weaviate.
            import weaviate

            auth_config = weaviate.AuthApiKey(api_key=api_key) if api_key else None
            self.client = weaviate.Client(url=url, auth_client_secret=auth_config)
            self._has_weaviate = True
        except ImportError:
            logger.warning("Weaviate not installed, using mock mode")
            self.client = None
            self._has_weaviate = False
        except Exception as e:
            # Deliberate best-effort: a failed client setup degrades to mock mode.
            logger.error("Failed to initialize Weaviate: %s", e)
            self.client = None
            self._has_weaviate = False

        self._tracer = trace.get_tracer(__name__) if HAS_OTEL else None

    @property
    def _mock_mode(self) -> bool:
        """Return True when no usable Weaviate client is available."""
        return not self._has_weaviate or not self.client

    def _start_span(self, name: str):
        """Start a tracing span called *name*, or return None without tracing."""
        if self._tracer and HAS_OTEL:
            return self._tracer.start_span(name)
        return None

    @staticmethod
    def _record_failure(span, exc: BaseException) -> None:
        """Attach *exc* to *span* and mark the span as failed (no-op without a span)."""
        if span:
            # A span only exists when the OpenTelemetry import succeeded, so
            # ``trace`` is defined whenever this branch runs.
            span.record_exception(exc)
            span.set_status(trace.Status(trace.StatusCode.ERROR, str(exc)))

    def _build_properties(
        self,
        doc_id: str,
        text: str,
        meta: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Merge caller metadata with the store-managed properties.

        Store-managed keys always win over keys of the same name in *meta*;
        such collisions are logged as a warning. ``None`` metadata is treated
        as empty.
        """
        meta = meta or {}
        clashes = [key for key in self._RESERVED_PROPERTIES if key in meta]
        if clashes:
            logger.warning(
                "Ignoring reserved metadata keys %s for document %s",
                clashes,
                doc_id,
            )
        return {
            **meta,
            "document_id": doc_id,
            "text": text,
            "lakefs_repository": self.lakefs_repository,
            "lakefs_commit": self.lakefs_commit,
        }

    def _build_filter(self, where_filter: Optional[Dict]) -> Dict[str, Any]:
        """Return the where-filter restricting results to this store's commit.

        A non-empty *where_filter* is combined with the version filter using
        an ``And`` operator; otherwise the version filter is returned alone.
        """
        version_filter = {
            "path": ["lakefs_commit"],
            "operator": "Equal",
            "valueText": self.lakefs_commit
        }
        if where_filter:
            return {
                "operator": "And",
                "operands": [version_filter, where_filter]
            }
        return version_filter

    def _parse_results(self, result: Any) -> List[Dict[str, Any]]:
        """Convert a raw query response into the list returned by ``query``.

        Missing or null sections (``data``, ``Get``, the class entry,
        ``_additional``) are treated as empty instead of raising. Any
        ``errors`` entry in the response is logged.
        """
        if not isinstance(result, dict):
            return []
        errors = result.get("errors")
        if errors:
            logger.error("Weaviate query returned errors: %s", errors)

        data = result.get("data") or {}
        get_section = data.get("Get") or {}
        objects = get_section.get(self.class_name) or []

        hits = []
        for obj in objects:
            additional = obj.get("_additional") or {}
            hits.append(
                {
                    "id": additional.get("id"),
                    "distance": additional.get("distance"),
                    "text": obj.get("text"),
                    "metadata": {
                        "document_id": obj.get("document_id"),
                        "lakefs_commit": obj.get("lakefs_commit"),
                        "lakefs_repository": obj.get("lakefs_repository")
                    }
                }
            )
        return hits

    def add_batch(
        self,
        embeddings: List[List[float]],
        document_ids: List[str],
        texts: List[str],
        metadata: List[Dict[str, Any]]
    ) -> Dict[str, int]:
        """Write a batch of vectors, tagged with this store's lakeFS version.

        Args:
            embeddings: One vector per document.
            document_ids: Id per document; stored as ``document_id`` and also
                passed as the object's ``uuid``.
            texts: Text per document.
            metadata: Extra properties per document. Keys that collide with
                store-managed properties are ignored.

        Returns:
            ``{"added_count": n}`` where ``n`` is the number of embeddings.

        Raises:
            ValueError: If the four input lists do not have the same length.
        """
        span = self._start_span("rag.vector_store.add_batch")
        try:
            expected = len(embeddings)
            if span:
                span.set_attribute("rag.vector_store.type", "weaviate")
                span.set_attribute("rag.vector_store.add_count", expected)
                span.set_attribute("rag.index.version", self.lakefs_commit)

            # zip() would silently drop the surplus items, so a mismatch must
            # be rejected rather than reported as a full success.
            lengths = {
                "document_ids": len(document_ids),
                "texts": len(texts),
                "metadata": len(metadata),
            }
            if any(size != expected for size in lengths.values()):
                raise ValueError(
                    "add_batch inputs must have equal lengths: "
                    f"embeddings={expected}, "
                    + ", ".join(f"{key}={size}" for key, size in lengths.items())
                )

            if self._mock_mode:
                logger.info("Mock mode: Would add %s vectors to Weaviate", expected)
                return {"added_count": expected}

            with self.client.batch as batch:
                for doc_id, embedding, text, meta in zip(
                    document_ids, embeddings, texts, metadata
                ):
                    # NOTE(review): doc_id doubles as the object uuid; presumably
                    # callers supply UUID-formatted ids - confirm.
                    batch.add_data_object(
                        data_object=self._build_properties(doc_id, text, meta),
                        class_name=self.class_name,
                        vector=embedding,
                        uuid=doc_id
                    )
            return {"added_count": expected}
        except Exception as exc:
            self._record_failure(span, exc)
            raise
        finally:
            if span:
                span.end()

    def query(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        where_filter: Optional[Dict] = None
    ) -> List[Dict[str, Any]]:
        """Run a nearest-vector search restricted to this store's commit.

        Args:
            query_embedding: Vector to search with.
            top_k: Maximum number of results requested.
            where_filter: Optional extra filter, AND-ed with the version filter.

        Returns:
            A list of dicts with keys ``id``, ``distance``, ``text`` and
            ``metadata`` (``document_id``, ``lakefs_commit``,
            ``lakefs_repository``). In mock mode at most three synthetic
            results are returned.
        """
        span = self._start_span("rag.vector_store.query")
        try:
            if span:
                span.set_attribute("rag.vector_store.type", "weaviate")
                span.set_attribute("rag.vector_store.query_top_k", top_k)

            if self._mock_mode:
                logger.info("Mock mode: Would query Weaviate with top_k=%s", top_k)
                return [
                    {
                        "id": f"mock_doc_{i}",
                        "distance": 0.05 + (i * 0.05),
                        "text": f"Mock document {i}",
                        "metadata": {"lakefs_commit": self.lakefs_commit}
                    }
                    for i in range(min(top_k, 3))
                ]

            result = (
                self.client.query
                .get(
                    self.class_name,
                    ["document_id", "text", "lakefs_commit", "lakefs_repository"],
                )
                .with_near_vector({"vector": query_embedding})
                .with_limit(top_k)
                .with_where(self._build_filter(where_filter))
                .with_additional(["distance", "id"])
                .do()
            )
            hits = self._parse_results(result)
            if span:
                span.set_attribute("rag.retrieval.count", len(hits))
            return hits
        except Exception as exc:
            self._record_failure(span, exc)
            raise
        finally:
            if span:
                span.end()

    def delete_by_id(self, document_id: str) -> bool:
        """Delete the object whose id is *document_id* from the store's class.

        The delete is addressed by id and class name only; it is not
        restricted to this store's commit.

        Returns:
            True on success (always True in mock mode), False if the client
            raised; the failure is logged.
        """
        if self._mock_mode:
            logger.info("Mock mode: Would delete %s from Weaviate", document_id)
            return True
        try:
            self.client.data_object.delete(document_id, class_name=self.class_name)
            return True
        except Exception as e:
            logger.error("Failed to delete %s: %s", document_id, e)
            return False

    def get_schema(self) -> Dict[str, Any]:
        """Return the schema of the store's class.

        Returns:
            The schema as returned by the client; a stub with an empty
            ``properties`` list in mock mode; an empty dict if the client
            raised (the failure is logged).
        """
        if self._mock_mode:
            return {"name": self.class_name, "properties": []}
        try:
            return self.client.schema.get(self.class_name)
        except Exception as e:
            logger.error("Failed to get schema: %s", e)
            return {}