from __future__ import annotations
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence
from briefcase.cowork.redaction import CoworkRedactionFilter
from briefcase.semantic_conventions import cowork as conv
# Module-scoped logger; used below to report failures raised by on_event callbacks.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class CoworkEvent:
    """Immutable, normalized Cowork telemetry event.

    The standard identity/context values are stored in dedicated fields; all
    remaining event-specific attributes are kept twice: ``attributes`` holds
    the redacted form and ``raw_attributes`` the unredacted form.

    Instances are hashable. The two attribute dicts are excluded from the
    generated ``__hash__`` (dicts are unhashable, so including them made
    ``hash(event)`` raise ``TypeError``), but they still take part in ``==``,
    so equal events always hash equally.
    """

    event_type: str
    timestamp: str
    sequence: int
    session_id: str
    organization_id: str
    user_account_uuid: str
    user_id: str
    user_email: str
    terminal_type: str
    prompt_id: str
    # Redacted event-specific attributes.
    attributes: Dict[str, Any] = field(hash=False)
    # Unredacted event-specific attributes; repr=False keeps these values out
    # of repr() output.
    raw_attributes: Dict[str, Any] = field(repr=False, hash=False)
@dataclass()
class ValidationError:
    """Record describing why one telemetry record was rejected.

    This is a plain data record, not an exception: the receiver appends
    instances to its error list instead of raising them.
    """

    # Event type of the rejected record ("" when none could be determined).
    event_type: str
    # Name of the attribute that failed validation (e.g. "event_type").
    attribute: str
    # Human-readable description of the problem.
    message: str
class CoworkEventReceiver:
    """Turn Cowork telemetry log records into validated, redacted events.

    Each record passed to :meth:`receive` is checked against the known Cowork
    event types and, in strict mode, against the per-type required attributes.
    Accepted records become :class:`CoworkEvent` instances: they are stored on
    the receiver and forwarded to the optional ``on_event`` callback. Rejected
    records produce :class:`ValidationError` entries instead of raising.

    Accepted events and validation errors accumulate in memory until
    :meth:`clear` is called.
    """

    # Keys that are lifted into dedicated CoworkEvent fields, plus the
    # "event_type" fallback key. They are excluded from CoworkEvent.attributes.
    # Built once here rather than on every receive() call.
    _STANDARD_KEYS = frozenset(
        {
            conv.SESSION_ID,
            conv.ORGANIZATION_ID,
            conv.USER_ACCOUNT_UUID,
            conv.USER_ID,
            conv.USER_EMAIL,
            conv.TERMINAL_TYPE,
            conv.EVENT_TIMESTAMP,
            conv.EVENT_SEQUENCE,
            conv.PROMPT_ID,
            "event_type",
        }
    )

    def __init__(
        self,
        *,
        redaction_filter: Optional[CoworkRedactionFilter] = None,
        on_event: Optional[Callable[[CoworkEvent], None]] = None,
        validate_strict: bool = False,
    ) -> None:
        """Create a receiver.

        Args:
            redaction_filter: Filter applied to each event's non-standard
                attributes. A default ``CoworkRedactionFilter()`` is created
                only when this is ``None``.
            on_event: Optional callback invoked with every accepted event.
                Exceptions it raises are logged and suppressed.
            validate_strict: When true, reject events that lack any attribute
                listed for their type in ``conv.REQUIRED_ATTRS``.
        """
        # Explicit None check: a caller-supplied filter that happens to be
        # falsy must not be silently replaced by the default one.
        self._filter = (
            redaction_filter
            if redaction_filter is not None
            else CoworkRedactionFilter()
        )
        self._on_event = on_event
        self._strict = validate_strict
        self._events: List[CoworkEvent] = []
        self._errors: List[ValidationError] = []

    @property
    def events(self) -> List[CoworkEvent]:
        """Return a copy of the accepted events, in the order received."""
        return list(self._events)

    @property
    def validation_errors(self) -> List[ValidationError]:
        """Return a copy of the recorded validation errors, in the order recorded."""
        return list(self._errors)

    def receive(self, log_record: Dict[str, Any]) -> Optional[CoworkEvent]:
        """Validate one log record and convert it into a CoworkEvent.

        The event type is taken from ``body["event_type"]``, then
        ``log_record["name"]``, then ``attributes["event_type"]``. Standard
        fields are read from the record's ``attributes``. A missing or ``None``
        value falls back to ``resource["attributes"]``.

        Args:
            log_record: Mapping with optional ``attributes``, ``resource``,
                ``body`` and ``name`` entries. Missing or ``None`` entries are
                treated as empty. A ``body`` without a ``.get`` method (for
                example a plain string) is ignored for the event-type lookup.

        Returns:
            The accepted event. Returns ``None`` if the record was rejected;
            the reason is appended to :attr:`validation_errors`.
        """
        attrs = dict(log_record.get("attributes") or {})
        resource_attrs = (
            self._get_field(log_record.get("resource"), "attributes") or {}
        )

        event_type = (
            self._get_field(log_record.get("body"), "event_type")
            or log_record.get("name")
            or attrs.pop("event_type", None)
            or ""
        )
        # The isinstance guard keeps non-string (possibly unhashable) values
        # from raising inside the membership test.
        if not isinstance(event_type, str) or event_type not in conv.ALL_EVENT_TYPES:
            self._errors.append(
                ValidationError(
                    event_type=event_type,
                    attribute="event_type",
                    message=f"Unknown event type: {event_type!r}",
                )
            )
            return None

        if self._strict:
            # NOTE: only record-level attributes count here. Values supplied
            # solely through resource attributes do not satisfy this check.
            missing = conv.REQUIRED_ATTRS.get(event_type, set()) - set(attrs)
            if missing:
                for attr in sorted(missing):
                    self._errors.append(
                        ValidationError(
                            event_type=event_type,
                            attribute=attr,
                            message=f"Missing required attribute: {attr}",
                        )
                    )
                return None

        try:
            std = self._extract_standard_attrs(attrs, resource_attrs)
        except (TypeError, ValueError) as exc:
            # The expected failure is int() on a malformed sequence value.
            # Reject this record rather than letting the exception abort
            # receive_batch() for every record that follows it.
            self._errors.append(
                ValidationError(
                    event_type=event_type,
                    attribute=conv.EVENT_SEQUENCE,
                    message=f"Invalid event sequence: {exc}",
                )
            )
            return None

        event_attrs = {
            key: value
            for key, value in attrs.items()
            if key not in self._STANDARD_KEYS
        }
        # Hand the filter a copy so that raw_attributes stays unredacted even
        # if redact_event modifies its argument in place. Whether it does
        # cannot be seen from this module.
        redacted_attrs = self._filter.redact_event(dict(event_attrs))

        event = CoworkEvent(
            event_type=event_type,
            timestamp=std["timestamp"],
            sequence=std["sequence"],
            session_id=std["session_id"],
            organization_id=std["organization_id"],
            user_account_uuid=std["user_account_uuid"],
            user_id=std["user_id"],
            user_email=std["user_email"],
            terminal_type=std["terminal_type"],
            prompt_id=std["prompt_id"],
            attributes=redacted_attrs,
            raw_attributes=event_attrs,
        )
        self._events.append(event)

        if self._on_event is not None:
            # Best-effort notification: a failing callback must not make an
            # already-accepted event look rejected.
            try:
                self._on_event(event)
            except Exception:
                logger.exception("on_event callback failed for %s", event_type)
        return event

    def receive_batch(
        self, log_records: Sequence[Dict[str, Any]]
    ) -> List[CoworkEvent]:
        """Run :meth:`receive` on each record in order.

        Returns:
            Only the accepted events. Rejections are recorded in
            :attr:`validation_errors`.
        """
        results = (self.receive(record) for record in log_records)
        return [event for event in results if event is not None]

    def clear(self) -> None:
        """Discard all stored events and validation errors."""
        self._events.clear()
        self._errors.clear()

    @staticmethod
    def _get_field(container: Any, key: str) -> Any:
        """Return ``container.get(key)`` when *container* has a callable ``get``, else ``None``."""
        getter = getattr(container, "get", None)
        return getter(key) if callable(getter) else None

    @staticmethod
    def _extract_standard_attrs(
        attrs: Dict[str, Any], resource_attrs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Collect the standard CoworkEvent fields from a record.

        Each key is looked up in *attrs* first. When the value is missing or
        ``None``, it is looked up in *resource_attrs*. Values still missing
        become ``""`` (or ``0`` for the sequence number).

        Raises:
            ValueError, TypeError: if the sequence value cannot be converted
                with ``int()``.
        """

        def lookup(key: str) -> Any:
            value = attrs.get(key)
            if value is None:
                # Duck-typed access, so a non-mapping resource payload yields None.
                getter = getattr(resource_attrs, "get", None)
                value = getter(key) if callable(getter) else None
            return value

        def as_text(key: str) -> str:
            # Map None to "" rather than the literal string "None".
            value = lookup(key)
            return "" if value is None else str(value)

        sequence = lookup(conv.EVENT_SEQUENCE)
        return {
            "session_id": as_text(conv.SESSION_ID),
            "organization_id": as_text(conv.ORGANIZATION_ID),
            "user_account_uuid": as_text(conv.USER_ACCOUNT_UUID),
            "user_id": as_text(conv.USER_ID),
            "user_email": as_text(conv.USER_EMAIL),
            "terminal_type": as_text(conv.TERMINAL_TYPE),
            "timestamp": as_text(conv.EVENT_TIMESTAMP),
            "sequence": 0 if sequence is None else int(sequence),
            "prompt_id": as_text(conv.PROMPT_ID),
        }