from __future__ import annotations
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional, Sequence
from briefcase.cowork.receiver import CoworkEvent
from briefcase.semantic_conventions import cowork as conv
# Module-level logger; failures in alert rules and alert handlers are reported through it.
logger = logging.getLogger(__name__)
@dataclass
class Alert:
    """A single alert raised for an event.

    Attributes:
        alert_type: Identifier of the alert kind.
        severity: Severity label attached to the alert.
        message: Human-readable alert text.
        event: The event that caused the alert.
        timestamp: Creation time; defaults to the current time in UTC.
        metadata: Free-form extra data; defaults to a fresh empty dict per alert.
    """

    alert_type: str
    severity: str
    message: str
    event: CoworkEvent
    # default_factory so each alert gets its own creation time / its own dict
    # instead of a value shared across instances.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: Dict[str, Any] = field(default_factory=dict)
# Signature of an alert handler: a callable that receives one Alert and returns nothing.
AlertHandler = Callable[[Alert], None]
@dataclass
class AlertRule:
    """Declarative description of when an alert should be raised.

    Attributes:
        name: Identifier of the rule.
        event_types: Event type names the rule applies to.
        condition: Predicate called with an event; a truthy result means the
            rule matches.
        severity: Severity label for alerts produced by this rule.
        message_template: ``str.format``-style template for the alert text.
    """

    # Required fields (no defaults) come first, as dataclasses demand.
    name: str
    event_types: List[str]
    condition: Callable[[CoworkEvent], bool]

    # Optional fields.
    severity: str = "warning"
    message_template: str = "Alert: {name} triggered"
class CoworkAlertManager:
    """Evaluate events against alert rules and dispatch the resulting alerts.

    The manager holds an ordered list of rules, a list of handler callables
    and a history of every alert it has fired. Unless disabled, a set of
    built-in rules is registered at construction time.
    """

    class _TemplateValues(dict):
        """Mapping for ``str.format_map`` that tolerates unknown placeholders."""

        def __missing__(self, key: str) -> str:
            # Keep the placeholder visible in the rendered text instead of
            # raising KeyError, so a missing attribute never suppresses an alert.
            return "{" + str(key) + "}"

    def __init__(
        self,
        *,
        handlers: Optional[List[AlertHandler]] = None,
        enable_builtin_rules: bool = True,
        cost_threshold_usd: Optional[float] = None,
    ) -> None:
        """Create a manager.

        Args:
            handlers: Initial alert handlers; the list is copied.
            enable_builtin_rules: Register the built-in rules when true.
            cost_threshold_usd: When not ``None`` (and built-in rules are
                enabled), also register the cost-threshold rule with this limit.
        """
        self._handlers: List[AlertHandler] = list(handlers or [])
        self._rules: List[AlertRule] = []
        self._alerts: List[Alert] = []
        if enable_builtin_rules:
            self._register_builtin_rules(cost_threshold_usd)

    @property
    def alerts(self) -> List[Alert]:
        """Return a copy of the alerts fired so far."""
        return list(self._alerts)

    @property
    def rules(self) -> List[AlertRule]:
        """Return a copy of the registered rules."""
        return list(self._rules)

    def add_handler(self, handler: AlertHandler) -> None:
        """Register an additional alert handler."""
        self._handlers.append(handler)

    def add_rule(self, rule: AlertRule) -> None:
        """Register an additional rule; rules are evaluated in insertion order."""
        self._rules.append(rule)

    def evaluate(self, event: CoworkEvent) -> List[Alert]:
        """Run every applicable rule against *event*.

        A rule applies when the event's type is listed in the rule's
        ``event_types``. Each matching rule produces an alert that is stored
        in the history and passed to every handler.

        A rule whose condition raises is logged and skipped. A message
        template that cannot be rendered does not suppress the alert; see
        ``_render_message``.

        Returns:
            The alerts fired for this event, in rule order.
        """
        fired: List[Alert] = []
        for rule in self._rules:
            if event.event_type not in rule.event_types:
                continue
            # Guard only the user-supplied predicate; rendering problems are
            # handled separately so they cannot drop a matching alert.
            try:
                matched = rule.condition(event)
            except Exception:
                logger.exception(
                    "Alert rule %r failed for event %s", rule.name, event.event_type
                )
                continue
            if not matched:
                continue
            alert = Alert(
                alert_type=rule.name,
                severity=rule.severity,
                message=self._render_message(rule, event),
                event=event,
                metadata={"rule": rule.name},
            )
            fired.append(alert)
            self._alerts.append(alert)
            self._dispatch(alert)
        return fired

    def evaluate_many(self, events: Sequence[CoworkEvent]) -> List[Alert]:
        """Evaluate each event in order and return all fired alerts."""
        all_alerts: List[Alert] = []
        for event in events:
            all_alerts.extend(self.evaluate(event))
        return all_alerts

    def clear(self) -> None:
        """Discard the alert history; rules and handlers are kept."""
        self._alerts.clear()

    def _render_message(self, rule: AlertRule, event: CoworkEvent) -> str:
        """Render the rule's message template for *event*.

        Placeholders are filled from the event attributes plus ``name`` and
        ``event_type``. Unknown placeholders are left as-is, and if rendering
        fails for any other reason the raw template is returned.
        """
        try:
            values = self._TemplateValues(event.attributes)
            # Assigned after copying the attributes so the rule's own fields
            # win over an attribute that happens to use the same key.
            values["name"] = rule.name
            values["event_type"] = event.event_type
            return rule.message_template.format_map(values)
        except Exception:
            logger.exception(
                "Alert rule %r could not render its message for event %s",
                rule.name,
                event.event_type,
            )
            return rule.message_template

    def _register_builtin_rules(self, cost_threshold: Optional[float]) -> None:
        """Register the built-in rules.

        Always adds the ``api_error`` and ``tool_failure`` rules; adds the
        ``cost_threshold`` rule only when *cost_threshold* is not ``None``.
        """
        # Every API error event raises a critical alert.
        self._rules.append(
            AlertRule(
                name="api_error",
                event_types=[conv.EVENT_API_ERROR],
                condition=lambda _: True,
                severity="critical",
                message_template="API error: model={model} status={status_code} error={error}",
            )
        )
        # A tool result counts as a failure only when its success attribute
        # stringifies to "false" (case-insensitive); a missing attribute is
        # treated as success.
        self._rules.append(
            AlertRule(
                name="tool_failure",
                event_types=[conv.EVENT_TOOL_RESULT],
                condition=lambda e: str(
                    e.attributes.get(conv.TOOL_SUCCESS, "true")
                ).lower()
                == "false",
                severity="warning",
                message_template="Tool failure: {tool_name} error={error}",
            )
        )
        if cost_threshold is not None:
            threshold = cost_threshold
            # The threshold is bound as a lambda default so the rule keeps the
            # value it was created with.
            self._rules.append(
                AlertRule(
                    name="cost_threshold",
                    event_types=[conv.EVENT_API_REQUEST],
                    condition=lambda e, t=threshold: _float(
                        e.attributes.get(conv.API_COST_USD, 0)
                    )
                    > t,
                    severity="warning",
                    message_template="Cost threshold exceeded: model={model} cost=${cost_usd}",
                )
            )

    def _dispatch(self, alert: Alert) -> None:
        """Pass *alert* to every handler; a failing handler is logged and skipped."""
        for handler in self._handlers:
            try:
                handler(alert)
            except Exception:
                logger.exception("Alert handler failed for %s", alert.alert_type)
def _float(v: Any) -> float:
try:
return float(v)
except (TypeError, ValueError):
return 0.0