from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence
from briefcase.cowork.receiver import CoworkEvent
from briefcase.semantic_conventions import cowork as conv
@dataclass
class PromptTrace:
    """Events that share one ``prompt_id`` plus totals derived from them.

    The event containers are populated from outside this class. The
    ``total_*`` and ``*_count`` fields keep their defaults until
    :meth:`finalize` is called, and are overwritten on every call.
    """

    prompt_id: str

    # Identity fields; empty strings unless set by whoever builds the trace.
    session_id: str = ""
    organization_id: str = ""
    user_id: str = ""
    user_email: str = ""

    # Collected events, one container per event category.
    user_prompt: Optional[CoworkEvent] = None
    tool_results: List[CoworkEvent] = field(default_factory=list)
    api_requests: List[CoworkEvent] = field(default_factory=list)
    api_errors: List[CoworkEvent] = field(default_factory=list)
    tool_decisions: List[CoworkEvent] = field(default_factory=list)

    # Aggregates written by finalize().
    total_cost_usd: float = 0.0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cache_read_tokens: int = 0
    total_cache_creation_tokens: int = 0
    tool_count: int = 0
    tool_failure_count: int = 0
    api_error_count: int = 0
    total_duration_ms: float = 0.0

    @property
    def event_count(self) -> int:
        """Return how many events this trace holds across all categories."""
        containers = (
            self.tool_results,
            self.api_requests,
            self.api_errors,
            self.tool_decisions,
        )
        prompt_events = 0 if self.user_prompt is None else 1
        return sum(len(bucket) for bucket in containers) + prompt_events

    @property
    def events(self) -> List[CoworkEvent]:
        """Return a new list of every held event, ordered by ``sequence``.

        The sort is stable, so events with equal ``sequence`` keep the
        order: user prompt, tool results, API requests, API errors,
        tool decisions.
        """
        leading = [] if self.user_prompt is None else [self.user_prompt]
        combined = [
            *leading,
            *self.tool_results,
            *self.api_requests,
            *self.api_errors,
            *self.tool_decisions,
        ]
        return sorted(combined, key=lambda evt: evt.sequence)

    def finalize(self) -> None:
        """Recompute every aggregate field from the collected events.

        Cost and token totals are read from the attributes of
        ``api_requests``; a missing attribute counts as 0. A tool result
        is a failure only when its success attribute, stringified and
        lower-cased, equals ``"false"`` (a missing attribute is treated
        as ``"true"``). ``total_duration_ms`` adds the API request
        durations and the tool durations together.
        """
        requests = self.api_requests
        results = self.tool_results

        def request_total(key, cast):
            # Sum one attribute over all API requests, coercing each value.
            return sum(cast(evt.attributes.get(key, 0)) for evt in requests)

        self.total_cost_usd = request_total(conv.API_COST_USD, _float)
        self.total_input_tokens = request_total(conv.API_INPUT_TOKENS, _int)
        self.total_output_tokens = request_total(conv.API_OUTPUT_TOKENS, _int)
        self.total_cache_read_tokens = request_total(
            conv.API_CACHE_READ_TOKENS, _int
        )
        self.total_cache_creation_tokens = request_total(
            conv.API_CACHE_CREATION_TOKENS, _int
        )

        self.tool_count = len(results)
        failed = [
            evt
            for evt in results
            if str(evt.attributes.get(conv.TOOL_SUCCESS, "true")).lower()
            == "false"
        ]
        self.tool_failure_count = len(failed)
        self.api_error_count = len(self.api_errors)

        request_ms = [
            _float(evt.attributes.get(conv.API_DURATION_MS, 0))
            for evt in requests
        ]
        tool_ms = [
            _float(evt.attributes.get(conv.TOOL_DURATION_MS, 0))
            for evt in results
        ]
        # One sum over API durations followed by tool durations, so the
        # floating-point result does not depend on how the lists are split.
        self.total_duration_ms = sum(request_ms + tool_ms)
class PromptCorrelationEngine:
    """Group events by ``prompt_id`` and build a ``PromptTrace`` per group."""

    def __init__(self) -> None:
        """Start with no buffered events."""
        # prompt_id -> events in the order they were added.
        self._buckets: Dict[str, List[CoworkEvent]] = defaultdict(list)

    def add(self, event: CoworkEvent) -> None:
        """Store *event* in the bucket keyed by its ``prompt_id``."""
        bucket = self._buckets[event.prompt_id]
        bucket.append(event)

    def add_many(self, events: Sequence[CoworkEvent]) -> None:
        """Store each of *events*, in order, via :meth:`add`."""
        for item in events:
            self.add(item)

    def get_trace(self, prompt_id: str) -> Optional[PromptTrace]:
        """Return a freshly built trace for *prompt_id*, or ``None``.

        ``None`` is returned when no events are stored for that id.
        The lookup uses ``.get`` so an unknown id does not create an
        empty bucket.
        """
        bucket = self._buckets.get(prompt_id)
        return self._build_trace(prompt_id, bucket) if bucket else None

    def traces(self) -> List[PromptTrace]:
        """Return one freshly built trace per stored ``prompt_id``."""
        return [
            self._build_trace(key, bucket)
            for key, bucket in self._buckets.items()
        ]

    @property
    def prompt_ids(self) -> List[str]:
        """Return the stored prompt ids as a new list."""
        return list(self._buckets)

    def clear(self) -> None:
        """Drop every stored event."""
        self._buckets.clear()

    @staticmethod
    def _build_trace(prompt_id: str, events: List[CoworkEvent]) -> PromptTrace:
        """Sort *events* into a new ``PromptTrace`` and finalize it.

        Events whose ``event_type`` matches none of the known kinds are
        ignored. If several user-prompt events are present, the last one
        is kept.
        """
        trace = PromptTrace(prompt_id=prompt_id)

        # Checked in this order with ``==``; the first match receives the event.
        routes = (
            (conv.EVENT_TOOL_RESULT, trace.tool_results),
            (conv.EVENT_API_REQUEST, trace.api_requests),
            (conv.EVENT_API_ERROR, trace.api_errors),
            (conv.EVENT_TOOL_DECISION, trace.tool_decisions),
        )

        for event in events:
            # Identity fields are copied from each event until one has
            # supplied a non-empty session_id.
            if not trace.session_id:
                trace.session_id = event.session_id
                trace.organization_id = event.organization_id
                trace.user_id = event.user_id
                trace.user_email = event.user_email

            kind = event.event_type
            if kind == conv.EVENT_USER_PROMPT:
                trace.user_prompt = event
                continue
            for wanted, sink in routes:
                if kind == wanted:
                    sink.append(event)
                    break

        trace.finalize()
        return trace
def _float(v: Any) -> float:
    """Coerce *v* to ``float``, returning ``0.0`` when it cannot be converted.

    Args:
        v: Any value; typically a number or a numeric string.

    Returns:
        ``float(v)``, or ``0.0`` if the conversion raises ``TypeError``,
        ``ValueError`` or ``OverflowError``.
    """
    try:
        return float(v)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an int too large for a float, e.g. 10**400.
        return 0.0
def _int(v: Any) -> int:
    """Coerce *v* to ``int``, returning ``0`` when it cannot be converted.

    Args:
        v: Any value; typically an integer or an integer string.

    Returns:
        ``int(v)``, or ``0`` if the conversion raises ``TypeError``,
        ``ValueError`` or ``OverflowError``.
    """
    try:
        return int(v)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int() of a float infinity.
        return 0