import hashlib
import json
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
import logging

# OpenTelemetry is an optional dependency; HAS_OTEL gates any tracing usage.
try:
    from opentelemetry import trace
    HAS_OTEL = True
except ImportError:
    HAS_OTEL = False

# Project-local semantic conventions for RAG telemetry.
# NOTE(review): `trace` and `rag_conventions` are not referenced in this
# chunk — presumably used elsewhere in the file; confirm before removing.
import briefcase.semantic_conventions.rag as rag_conventions

logger = logging.getLogger(__name__)
@dataclass
class Document:
    """A source document to be embedded.

    Attributes:
        id: Stable identifier for the document.
        content: Raw text content.
        metadata: Arbitrary key/value annotations.
        path: Optional source path for the document.
    """

    id: str
    content: str
    metadata: dict = field(default_factory=dict)
    path: str = ""

    @property
    def content_hash(self) -> str:
        """SHA-256 hex digest of the UTF-8 encoded content."""
        encoded = self.content.encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()
@dataclass
class EmbeddingRecord:
    """One embedding vector tied to a specific document version and model.

    Fix: the original fused `document_hash` and `embedding` declarations
    onto a single line, which is a syntax error; they are now separate fields.
    """

    document_id: str
    document_hash: str          # content hash of the document at embed time
    embedding: List[float]
    model: str
    model_version: str
    created_at: str             # ISO-8601 timestamp string
@dataclass
class EmbeddingBatch:
    """A set of embeddings produced in one run against one model version."""

    batch_id: str
    model: str
    model_version: str
    dimensions: int
    embeddings: List[List[float]]
    document_ids: List[str]
    document_hashes: List[str]
    created_at: datetime
    source_commit: str

    def to_dict(self) -> Dict[str, Any]:
        """Serializable view of the batch: vectors are dropped (they are
        stored separately) and the timestamp becomes an ISO-8601 string."""
        payload = {
            key: value
            for key, value in asdict(self).items()
            if key != "embeddings"
        }
        payload["created_at"] = self.created_at.isoformat()
        return payload
class ManifestStatus(Enum):
    """Freshness state of an embedding manifest relative to current inputs."""

    CURRENT = "current"                  # documents and model both match
    STALE_DOCUMENTS = "stale_documents"  # document set or content changed
    STALE_MODEL = "stale_model"          # embedding model/version changed
    STALE_BOTH = "stale_both"            # both documents and model changed
    REBUILDING = "rebuilding"            # a rebuild is in progress
@dataclass
class EmbeddingManifest:
    """Versioned record of a built embedding index.

    Captures which documents (by content hash), which model version, and
    which source commit an index was built from, so staleness can be
    detected later.

    Fix: the original fused three pairs of field declarations onto single
    lines (`source_commit`/`document_count`, `document_hashes`/`batch_ids`,
    `created_at`/`status`), which were syntax errors; each is now its own
    declaration.
    """

    manifest_id: str
    index_name: str
    model: str
    model_version: str
    dimensions: int
    source_commit: str
    document_count: int
    document_hashes: Dict[str, str]  # doc_id -> sha256 of content
    batch_ids: List[str]
    created_at: str                  # ISO-8601 timestamp string
    status: str = ManifestStatus.CURRENT.value
    parent_manifest_id: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the manifest as a plain dict."""
        return asdict(self)

    def to_json(self) -> str:
        """Return a deterministic (sorted-keys) pretty-printed JSON string."""
        return json.dumps(self.to_dict(), sort_keys=True, indent=2)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "EmbeddingManifest":
        """Build a manifest from a dict, ignoring unknown keys for
        forward compatibility."""
        return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})

    @classmethod
    def from_json(cls, s: str) -> "EmbeddingManifest":
        """Build a manifest from a JSON string produced by `to_json`."""
        return cls.from_dict(json.loads(s))

    @property
    def manifest_hash(self) -> str:
        """Content-addressed identity of the manifest.

        Only identity-defining fields participate (index name, model,
        model version, source commit, document hashes) so timestamps and
        lineage metadata do not perturb the hash.
        """
        content = json.dumps({
            "index_name": self.index_name,
            "model": self.model,
            "model_version": self.model_version,
            "source_commit": self.source_commit,
            "document_hashes": self.document_hashes,
        }, sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()
@dataclass
class InvalidationReport:
    """Result of comparing a manifest against the current documents/model.

    Fix: the original fused `status` and `added_documents` declarations
    onto a single line, which was a syntax error; they are now separate
    fields.
    """

    manifest_id: str
    index_name: str
    is_valid: bool
    status: str  # a ManifestStatus value string
    added_documents: List[str] = field(default_factory=list)
    removed_documents: List[str] = field(default_factory=list)
    changed_documents: List[str] = field(default_factory=list)
    model_changed: bool = False
    old_model: Optional[str] = None
    new_model: Optional[str] = None
    old_model_version: Optional[str] = None
    new_model_version: Optional[str] = None
    timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return the report as a dict, omitting keys whose value is None
        (the old/new model fields when the model did not change)."""
        return {k: v for k, v in asdict(self).items() if v is not None}
class VersionedEmbeddingPipeline:
    """Builds embedding batches and versioned manifests for RAG indexes.

    Embeddings are produced in batches (via an injected model when present,
    otherwise deterministic zero-vector mocks), grouped into per-index
    manifests, and later checked for staleness against the current document
    set and model. Manifests are best-effort persisted to lakeFS when a
    client and repository are configured; all state is also kept in memory.
    """

    def __init__(
        self,
        embedding_model: Any = None,
        lakefs_client: Any = None,
        repository: Optional[str] = None,
        branch: str = "main",
    ):
        """Create a pipeline.

        Args:
            embedding_model: Optional object exposing `embed(texts)` and
                `name`/`version` attributes; mock embeddings are used when
                absent or when `embed` raises.
            lakefs_client: Optional lakeFS client used for commit lookup
                and manifest upload (both best-effort).
            repository: lakeFS repository name; uploads are skipped if unset.
            branch: lakeFS branch manifests are written to.
        """
        self.model = embedding_model
        self.lakefs = lakefs_client
        self.repository = repository
        self.branch = branch
        # In-memory stores: manifests per index (append-only, newest last)
        # and batches keyed by batch_id.
        self._manifests: Dict[str, List[EmbeddingManifest]] = {}
        self._batches: Dict[str, EmbeddingBatch] = {}

    def create_embedding_batch(
        self,
        documents: List[Document],
        batch_id: Optional[str] = None,
        source_commit: Optional[str] = None,
    ) -> EmbeddingBatch:
        """Embed `documents` and record the result as an EmbeddingBatch.

        Args:
            documents: Documents to embed.
            batch_id: Optional explicit id; defaults to a timestamped name.
            source_commit: Optional commit id; resolved from lakeFS when
                available, otherwise "unknown".

        Returns:
            The stored EmbeddingBatch.
        """
        if batch_id is None:
            batch_id = f"batch_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
        if source_commit is None:
            if self.lakefs and hasattr(self.lakefs, 'get_commit'):
                try:
                    source_commit = self.lakefs.get_commit()
                except Exception:
                    # Commit lookup is best-effort; never fail the batch.
                    source_commit = "unknown"
            else:
                source_commit = "unknown"
        model_name = getattr(self.model, 'name', 'mock-model') if self.model else 'mock-model'
        model_version = getattr(self.model, 'version', '1.0') if self.model else '1.0'
        texts = [doc.content for doc in documents]
        if self.model and hasattr(self.model, 'embed'):
            try:
                raw_embeddings = self.model.embed(texts)
                embeddings = [list(e) for e in raw_embeddings]
            except Exception as e:
                # Degrade to mocks rather than aborting the pipeline.
                logger.warning(f"Embedding model failed, using mock: {e}")
                embeddings = self._mock_embeddings(len(texts))
        else:
            embeddings = self._mock_embeddings(len(texts))
        dimensions = len(embeddings[0]) if embeddings else 0
        document_ids = [doc.id for doc in documents]
        document_hashes = [doc.content_hash for doc in documents]
        batch = EmbeddingBatch(
            batch_id=batch_id,
            model=model_name,
            model_version=model_version,
            dimensions=dimensions,
            embeddings=embeddings,
            document_ids=document_ids,
            document_hashes=document_hashes,
            created_at=datetime.utcnow(),
            source_commit=source_commit,
        )
        self._batches[batch_id] = batch
        logger.info(f"Created embedding batch {batch_id}: {len(documents)} docs, {dimensions}d")
        return batch

    def create_manifest(
        self,
        index_name: str,
        batches: List[EmbeddingBatch],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> EmbeddingManifest:
        """Combine `batches` into a new manifest for `index_name`.

        The model, dimensions, and source commit are taken from the first
        batch; if the same document id appears in several batches, the last
        batch's hash wins. The previous manifest (if any) becomes the parent.

        Raises:
            ValueError: if `batches` is empty.
        """
        if not batches:
            raise ValueError("At least one batch is required to create a manifest")
        doc_hashes: Dict[str, str] = {}
        batch_ids = []
        for batch in batches:
            batch_ids.append(batch.batch_id)
            for doc_id, doc_hash in zip(batch.document_ids, batch.document_hashes):
                doc_hashes[doc_id] = doc_hash
        first = batches[0]
        source_commit = first.source_commit
        parent = self.get_latest_manifest(index_name)
        manifest_id = f"{index_name}_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{source_commit[:8]}"
        manifest = EmbeddingManifest(
            manifest_id=manifest_id,
            index_name=index_name,
            model=first.model,
            model_version=first.model_version,
            dimensions=first.dimensions,
            source_commit=source_commit,
            document_count=len(doc_hashes),
            document_hashes=doc_hashes,
            batch_ids=batch_ids,
            created_at=datetime.utcnow().isoformat(),
            status=ManifestStatus.CURRENT.value,
            parent_manifest_id=parent.manifest_id if parent else None,
            metadata=metadata or {},
        )
        self._manifests.setdefault(index_name, []).append(manifest)
        # Best-effort persistence: an upload failure must not lose the
        # in-memory manifest.
        if self.lakefs and self.repository:
            path = f"manifests/{index_name}/{manifest_id}.json"
            try:
                self.lakefs.upload_object(
                    self.repository, self.branch, path, manifest.to_json()
                )
            except Exception as e:
                logger.warning(f"Failed to upload manifest to lakeFS: {e}")
        logger.info(
            f"Created manifest {manifest_id}: {len(doc_hashes)} docs, "
            f"model={first.model}@{first.model_version}"
        )
        return manifest

    def get_latest_manifest(self, index_name: str) -> Optional[EmbeddingManifest]:
        """Return the most recent manifest for `index_name`, or None."""
        manifests = self._manifests.get(index_name, [])
        return manifests[-1] if manifests else None

    def get_manifests(
        self,
        index_name: str,
        limit: Optional[int] = None,
    ) -> List[EmbeddingManifest]:
        """Return manifests for `index_name`, oldest first.

        When `limit` is a positive int, only the most recent `limit`
        manifests are returned.
        """
        manifests = self._manifests.get(index_name, [])
        if limit:
            manifests = manifests[-limit:]
        return manifests

    def get_all_index_names(self) -> List[str]:
        """Return every index name that has at least one manifest."""
        return list(self._manifests.keys())

    def check_invalidation(
        self,
        index_name: str,
        current_documents: List[Document],
        current_model: Optional[str] = None,
        current_model_version: Optional[str] = None,
    ) -> InvalidationReport:
        """Compare the latest manifest against current documents and model.

        With no manifest, every current document counts as added and the
        index is invalid. Otherwise documents are diffed by id and content
        hash, and the model name/version (explicit args take precedence
        over the configured model's attributes) is compared. If the index
        is stale, the manifest's `status` field is updated in place.
        """
        manifest = self.get_latest_manifest(index_name)
        if manifest is None:
            return InvalidationReport(
                manifest_id="none",
                index_name=index_name,
                is_valid=False,
                status=ManifestStatus.STALE_DOCUMENTS.value,
                added_documents=[doc.id for doc in current_documents],
            )
        current_hashes = {doc.id: doc.content_hash for doc in current_documents}
        manifest_hashes = manifest.document_hashes
        current_ids = set(current_hashes.keys())
        manifest_ids = set(manifest_hashes.keys())
        added = sorted(current_ids - manifest_ids)
        removed = sorted(manifest_ids - current_ids)
        changed = sorted([
            doc_id for doc_id in current_ids & manifest_ids
            if current_hashes[doc_id] != manifest_hashes[doc_id]
        ])
        docs_changed = bool(added or removed or changed)
        model_changed = False
        effective_model = current_model or (
            getattr(self.model, 'name', None) if self.model else None
        )
        effective_version = current_model_version or (
            getattr(self.model, 'version', None) if self.model else None
        )
        # An unknown (None) model is treated as unchanged.
        if effective_model and effective_model != manifest.model:
            model_changed = True
        if effective_version and effective_version != manifest.model_version:
            model_changed = True
        if docs_changed and model_changed:
            status = ManifestStatus.STALE_BOTH
        elif docs_changed:
            status = ManifestStatus.STALE_DOCUMENTS
        elif model_changed:
            status = ManifestStatus.STALE_MODEL
        else:
            status = ManifestStatus.CURRENT
        is_valid = status == ManifestStatus.CURRENT
        if not is_valid:
            manifest.status = status.value
        return InvalidationReport(
            manifest_id=manifest.manifest_id,
            index_name=index_name,
            is_valid=is_valid,
            status=status.value,
            added_documents=added,
            removed_documents=removed,
            changed_documents=changed,
            model_changed=model_changed,
            old_model=manifest.model if model_changed else None,
            new_model=effective_model if model_changed else None,
            old_model_version=manifest.model_version if model_changed else None,
            new_model_version=effective_version if model_changed else None,
        )

    def rebuild_index(
        self,
        index_name: str,
        documents: List[Document],
        source_commit: Optional[str] = None,
        batch_id: Optional[str] = None,
    ) -> EmbeddingManifest:
        """Re-embed `documents` as one batch and publish a new manifest."""
        batch = self.create_embedding_batch(
            documents, batch_id=batch_id, source_commit=source_commit
        )
        manifest = self.create_manifest(index_name, [batch])
        return manifest

    @staticmethod
    def _mock_embeddings(count: int, dimensions: int = 128) -> List[List[float]]:
        """Deterministic placeholder embeddings: `count` zero vectors."""
        return [[0.0] * dimensions for _ in range(count)]