import hashlib
import json
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Callable
import logging
try:
from opentelemetry import trace
HAS_OTEL = True
except ImportError:
HAS_OTEL = False
from briefcase.semantic_conventions.external_data import *
logger = logging.getLogger(__name__)
class SnapshotFrequency(Enum):
    """How often a snapshot of an external data source should be captured."""

    EVERY_CALL = "every_call"  # snapshot on every tracked call
    ON_CHANGE = "on_change"    # snapshot only when the payload hash changes
    HOURLY = "hourly"
    DAILY = "daily"
    WEEKLY = "weekly"
@dataclass
class SnapshotPolicy:
    """Per-source rules controlling when snapshots are taken and how long
    they are retained.

    Attributes:
        frequency: When to capture a new snapshot (default: on change).
        retention_days: Drop snapshots older than this many days; 0 disables.
        change_threshold: Minimum custom-detector drift score required to
            store an ON_CHANGE snapshot; 0.0 stores on any hash change.
        max_snapshots: Hard cap on snapshots kept per source; 0 disables.
        compress: Whether stored snapshot payloads should be compressed.
    """

    frequency: SnapshotFrequency = SnapshotFrequency.ON_CHANGE
    retention_days: int = 90
    change_threshold: float = 0.0
    max_snapshots: int = 0
    compress: bool = False
@dataclass
class Snapshot:
    """An immutable record of one captured external-data payload.

    Note: in the original source the ``source_type``/``data_hash`` and
    ``timestamp``/``size_bytes`` declarations were fused onto single lines
    (a syntax error); they are split into one field per line here.
    """

    snapshot_id: str                 # unique id: "<source>_<hash12>_<YYYYmmddHHMMSS>"
    source_name: str                 # logical source (api name, "db.name", file source)
    source_type: str                 # "api" | "db" | "file"
    data_hash: str                   # sha256 hex digest of the payload
    timestamp: str                   # ISO-8601 capture time
    size_bytes: int                  # payload size in bytes
    record_count: Optional[int] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    lakefs_path: Optional[str] = None        # set only when upload succeeded
    parent_snapshot_id: Optional[str] = None  # previous snapshot of same source

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict, omitting fields whose value is None."""
        return {k: v for k, v in asdict(self).items() if v is not None}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "Snapshot":
        """Build a Snapshot from a dict, silently ignoring unknown keys."""
        return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
@dataclass
class DriftReport:
    """Result of comparing a current payload (or snapshot) against a baseline
    snapshot of the same source.

    Note: in the original source the ``size_delta``/``record_count_delta``
    and ``drift_score``/``details`` declarations were fused onto single lines
    (a syntax error); they are split into one field per line here.
    """

    source_name: str
    baseline_snapshot_id: str
    current_snapshot_id: str          # "pending" when current side is not stored yet
    baseline_hash: str
    current_hash: str
    has_changed: bool                 # True when the two hashes differ
    size_delta: int                   # current size minus baseline size, in bytes
    record_count_delta: Optional[int] = None  # only when both sides have counts
    drift_score: float = 0.0          # 0.0 = identical; 1.0 = changed (default scorer)
    details: Dict[str, Any] = field(default_factory=dict)
    # NOTE(review): naive UTC timestamp; datetime.utcnow() is deprecated in
    # newer Pythons — kept for behavioral compatibility with the rest of the file.
    timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the full report (including None fields) to a dict."""
        return asdict(self)
class ExternalDataTracker:
def __init__(
self,
lakefs_client: Optional[Any] = None,
repository: Optional[str] = None,
branch: str = "main",
default_policy: Optional[SnapshotPolicy] = None,
):
self.lakefs = lakefs_client
self.repository = repository
self.branch = branch
self._default_policy = default_policy or SnapshotPolicy()
self._policies: Dict[str, SnapshotPolicy] = {}
self._snapshots: Dict[str, List[Snapshot]] = {}
self._last_snapshot_times: Dict[str, datetime] = {}
self._change_detectors: Dict[str, Callable] = {}
def set_policy(self, source_name: str, policy: SnapshotPolicy) -> None:
self._policies[source_name] = policy
def get_policy(self, source_name: str) -> SnapshotPolicy:
return self._policies.get(source_name, self._default_policy)
def set_change_detector(
self, source_name: str, detector: Callable[[Any, Any], float]
) -> None:
self._change_detectors[source_name] = detector
def track_api_call(
self,
api_name: str,
endpoint: str,
method: str,
response_data: Any,
version: Optional[str] = None,
status_code: int = 200,
record_count: Optional[int] = None,
store_snapshot: bool = True,
) -> Dict[str, Any]:
span = self._start_span("external_data.track_api_call", {
EXTERNAL_API_NAME: api_name,
EXTERNAL_API_ENDPOINT: endpoint,
EXTERNAL_API_METHOD: method,
EXTERNAL_API_STATUS_CODE: status_code,
})
try:
response_str = json.dumps(response_data, sort_keys=True, default=str)
data_hash = hashlib.sha256(response_str.encode()).hexdigest()
size_bytes = len(response_str.encode())
timestamp = datetime.utcnow()
metadata = {
"api_name": api_name,
"endpoint": endpoint,
"method": method,
"status_code": status_code,
}
if version:
metadata["api_version"] = version
result = {
"data_hash": data_hash,
"timestamp": timestamp.isoformat(),
"size_bytes": size_bytes,
"snapshot_id": None,
"snapshot_stored": False,
"drift_detected": False,
}
if store_snapshot:
should_store, drift = self._evaluate_snapshot_policy(
api_name, data_hash, size_bytes, response_data, timestamp
)
if should_store:
snapshot = self._create_snapshot(
source_name=api_name,
source_type="api",
data_hash=data_hash,
size_bytes=size_bytes,
record_count=record_count,
metadata=metadata,
timestamp=timestamp,
data=response_str if self.lakefs else None,
)
result["snapshot_id"] = snapshot.snapshot_id
result["snapshot_stored"] = True
result["drift_detected"] = drift
self._set_span_attributes(span, {
EXTERNAL_DATA_HASH: data_hash,
EXTERNAL_DATA_SIZE: size_bytes,
EXTERNAL_DATA_SOURCE: api_name,
})
if result.get("snapshot_id"):
self._set_span_attributes(span, {
EXTERNAL_SNAPSHOT_ID: result["snapshot_id"],
})
return result
except Exception as e:
self._record_span_exception(span, e)
raise
finally:
self._end_span(span)
def track_db_query(
self,
db_system: str,
db_name: str,
query: str,
result_data: Any = None,
result_count: int = 0,
store_snapshot: bool = False,
) -> Dict[str, Any]:
source_name = f"{db_system}.{db_name}"
span = self._start_span("external_data.track_db_query", {
EXTERNAL_DB_SYSTEM: db_system,
EXTERNAL_DB_NAME: db_name,
EXTERNAL_DB_QUERY_HASH: hashlib.sha256(query.encode()).hexdigest()[:16],
EXTERNAL_DB_RESULT_COUNT: result_count,
})
try:
query_hash = hashlib.sha256(query.encode()).hexdigest()[:16]
timestamp = datetime.utcnow()
result = {
"query_hash": query_hash,
"result_count": result_count,
"timestamp": timestamp.isoformat(),
"data_hash": None,
"snapshot_id": None,
"snapshot_stored": False,
"drift_detected": False,
}
if result_data is not None:
data_str = json.dumps(result_data, sort_keys=True, default=str)
data_hash = hashlib.sha256(data_str.encode()).hexdigest()
size_bytes = len(data_str.encode())
result["data_hash"] = data_hash
if store_snapshot:
should_store, drift = self._evaluate_snapshot_policy(
source_name, data_hash, size_bytes, result_data, timestamp
)
if should_store:
snapshot = self._create_snapshot(
source_name=source_name,
source_type="db",
data_hash=data_hash,
size_bytes=size_bytes,
record_count=result_count,
metadata={
"db_system": db_system,
"db_name": db_name,
"query_hash": query_hash,
},
timestamp=timestamp,
data=data_str if self.lakefs else None,
)
result["snapshot_id"] = snapshot.snapshot_id
result["snapshot_stored"] = True
result["drift_detected"] = drift
self._set_span_attributes(span, {
EXTERNAL_DATA_SOURCE: source_name,
})
if result.get("data_hash"):
self._set_span_attributes(span, {
EXTERNAL_DATA_HASH: result["data_hash"],
})
return result
except Exception as e:
self._record_span_exception(span, e)
raise
finally:
self._end_span(span)
def track_file_fetch(
self,
source_name: str,
file_data: bytes,
file_path: Optional[str] = None,
record_count: Optional[int] = None,
store_snapshot: bool = True,
) -> Dict[str, Any]:
span = self._start_span("external_data.track_file_fetch", {
EXTERNAL_DATA_SOURCE: source_name,
})
try:
data_hash = hashlib.sha256(file_data).hexdigest()
size_bytes = len(file_data)
timestamp = datetime.utcnow()
metadata = {"source_name": source_name}
if file_path:
metadata["file_path"] = file_path
result = {
"data_hash": data_hash,
"size_bytes": size_bytes,
"timestamp": timestamp.isoformat(),
"snapshot_id": None,
"snapshot_stored": False,
"drift_detected": False,
}
if store_snapshot:
should_store, drift = self._evaluate_snapshot_policy(
source_name, data_hash, size_bytes, file_data, timestamp
)
if should_store:
snapshot = self._create_snapshot(
source_name=source_name,
source_type="file",
data_hash=data_hash,
size_bytes=size_bytes,
record_count=record_count,
metadata=metadata,
timestamp=timestamp,
data=None, )
result["snapshot_id"] = snapshot.snapshot_id
result["snapshot_stored"] = True
result["drift_detected"] = drift
self._set_span_attributes(span, {
EXTERNAL_DATA_HASH: data_hash,
EXTERNAL_DATA_SIZE: size_bytes,
})
return result
except Exception as e:
self._record_span_exception(span, e)
raise
finally:
self._end_span(span)
def detect_drift(
self,
source_name: str,
current_data: Any = None,
current_hash: Optional[str] = None,
current_size: Optional[int] = None,
current_record_count: Optional[int] = None,
) -> Optional[DriftReport]:
snapshots = self._snapshots.get(source_name)
if not snapshots:
return None
baseline = snapshots[-1]
if current_hash is None and current_data is not None:
if isinstance(current_data, bytes):
current_hash = hashlib.sha256(current_data).hexdigest()
current_size = current_size or len(current_data)
else:
data_str = json.dumps(current_data, sort_keys=True, default=str)
current_hash = hashlib.sha256(data_str.encode()).hexdigest()
current_size = current_size or len(data_str.encode())
if current_hash is None:
raise ValueError("Must provide either current_data or current_hash")
has_changed = current_hash != baseline.data_hash
size_delta = (current_size or 0) - baseline.size_bytes
record_delta = None
if current_record_count is not None and baseline.record_count is not None:
record_delta = current_record_count - baseline.record_count
drift_score = 0.0
if has_changed:
if source_name in self._change_detectors and current_data is not None:
try:
drift_score = self._change_detectors[source_name](
None, current_data )
except Exception as e:
logger.warning(f"Custom change detector failed for {source_name}: {e}")
drift_score = 1.0
else:
drift_score = 1.0
return DriftReport(
source_name=source_name,
baseline_snapshot_id=baseline.snapshot_id,
current_snapshot_id="pending",
baseline_hash=baseline.data_hash,
current_hash=current_hash,
has_changed=has_changed,
size_delta=size_delta,
record_count_delta=record_delta,
drift_score=drift_score,
)
def compare_snapshots(
self, snapshot_a_id: str, snapshot_b_id: str
) -> Optional[DriftReport]:
snap_a = self._find_snapshot(snapshot_a_id)
snap_b = self._find_snapshot(snapshot_b_id)
if snap_a is None or snap_b is None:
return None
has_changed = snap_a.data_hash != snap_b.data_hash
size_delta = snap_b.size_bytes - snap_a.size_bytes
record_delta = None
if snap_a.record_count is not None and snap_b.record_count is not None:
record_delta = snap_b.record_count - snap_a.record_count
return DriftReport(
source_name=snap_a.source_name,
baseline_snapshot_id=snap_a.snapshot_id,
current_snapshot_id=snap_b.snapshot_id,
baseline_hash=snap_a.data_hash,
current_hash=snap_b.data_hash,
has_changed=has_changed,
size_delta=size_delta,
record_count_delta=record_delta,
drift_score=1.0 if has_changed else 0.0,
)
def get_snapshots(
self,
source_name: str,
since: Optional[datetime] = None,
until: Optional[datetime] = None,
limit: Optional[int] = None,
) -> List[Snapshot]:
snaps = self._snapshots.get(source_name, [])
if since:
since_str = since.isoformat()
snaps = [s for s in snaps if s.timestamp >= since_str]
if until:
until_str = until.isoformat()
snaps = [s for s in snaps if s.timestamp <= until_str]
if limit:
snaps = snaps[-limit:]
return snaps
def get_latest_snapshot(self, source_name: str) -> Optional[Snapshot]:
snaps = self._snapshots.get(source_name, [])
return snaps[-1] if snaps else None
def get_all_sources(self) -> List[str]:
return list(self._snapshots.keys())
def get_snapshot_count(self, source_name: Optional[str] = None) -> int:
if source_name:
return len(self._snapshots.get(source_name, []))
return sum(len(v) for v in self._snapshots.values())
def enforce_retention(self, source_name: Optional[str] = None) -> int:
removed = 0
sources = [source_name] if source_name else list(self._snapshots.keys())
for src in sources:
policy = self.get_policy(src)
snaps = self._snapshots.get(src, [])
if not snaps:
continue
original_count = len(snaps)
if policy.retention_days > 0:
cutoff = (datetime.utcnow() - timedelta(days=policy.retention_days)).isoformat()
snaps = [s for s in snaps if s.timestamp >= cutoff]
if policy.max_snapshots > 0 and len(snaps) > policy.max_snapshots:
snaps = snaps[-policy.max_snapshots:]
self._snapshots[src] = snaps
removed += original_count - len(snaps)
return removed
def _evaluate_snapshot_policy(
self,
source_name: str,
data_hash: str,
size_bytes: int,
data: Any,
timestamp: datetime,
) -> tuple:
policy = self.get_policy(source_name)
last_snap = self.get_latest_snapshot(source_name)
if last_snap is None:
return True, False
drift_detected = data_hash != last_snap.data_hash
if policy.frequency == SnapshotFrequency.EVERY_CALL:
return True, drift_detected
if policy.frequency == SnapshotFrequency.ON_CHANGE:
if not drift_detected:
return False, False
if policy.change_threshold > 0.0 and source_name in self._change_detectors:
try:
score = self._change_detectors[source_name](None, data)
if score < policy.change_threshold:
return False, True except Exception:
pass
return True, drift_detected
last_time = self._last_snapshot_times.get(source_name)
if last_time is None:
return True, drift_detected
interval_map = {
SnapshotFrequency.HOURLY: timedelta(hours=1),
SnapshotFrequency.DAILY: timedelta(days=1),
SnapshotFrequency.WEEKLY: timedelta(weeks=1),
}
interval = interval_map.get(policy.frequency, timedelta(hours=1))
if timestamp - last_time >= interval:
return True, drift_detected
return False, drift_detected
def _create_snapshot(
self,
source_name: str,
source_type: str,
data_hash: str,
size_bytes: int,
record_count: Optional[int],
metadata: Dict[str, Any],
timestamp: datetime,
data: Optional[str] = None,
) -> Snapshot:
parent = self.get_latest_snapshot(source_name)
snapshot_id = f"{source_name}_{data_hash[:12]}_{timestamp.strftime('%Y%m%d%H%M%S')}"
lakefs_path = None
if self.lakefs and self.repository:
lakefs_path = f"snapshots/{source_name}/{snapshot_id}.json"
try:
self.lakefs.upload_object(
self.repository,
self.branch,
lakefs_path,
data or json.dumps(metadata),
)
except Exception as e:
logger.warning(f"Failed to upload snapshot to lakeFS: {e}")
lakefs_path = None
snapshot = Snapshot(
snapshot_id=snapshot_id,
source_name=source_name,
source_type=source_type,
data_hash=data_hash,
timestamp=timestamp.isoformat(),
size_bytes=size_bytes,
record_count=record_count,
metadata=metadata,
lakefs_path=lakefs_path,
parent_snapshot_id=parent.snapshot_id if parent else None,
)
self._snapshots.setdefault(source_name, []).append(snapshot)
self._last_snapshot_times[source_name] = timestamp
logger.info(
f"Stored snapshot {snapshot_id} for {source_name} "
f"(hash={data_hash[:12]}, size={size_bytes})"
)
return snapshot
def _find_snapshot(self, snapshot_id: str) -> Optional[Snapshot]:
for snaps in self._snapshots.values():
for snap in snaps:
if snap.snapshot_id == snapshot_id:
return snap
return None
def _start_span(self, name: str, attributes: Dict[str, Any] = None):
if not HAS_OTEL:
return None
try:
tracer = trace.get_tracer(__name__)
span = tracer.start_span(name, attributes=attributes or {})
return span
except Exception as e:
logger.debug(f"Failed to start OTel span: {e}")
return None
def _set_span_attributes(self, span, attributes: Dict[str, Any]) -> None:
if span is None:
return
try:
for k, v in attributes.items():
span.set_attribute(k, v)
except Exception:
pass
def _record_span_exception(self, span, exception: Exception) -> None:
if span is None:
return
try:
span.set_status(trace.StatusCode.ERROR, str(exception))
span.record_exception(exception)
except Exception:
pass
def _end_span(self, span) -> None:
if span is None:
return
try:
span.end()
except Exception:
pass