# briefcase-python 2.4.1
#
# Python bindings for Briefcase AI
# Documentation
"""
Prompt-level correlation engine for Cowork events.

All events originating from a single user prompt share a ``prompt.id``
(UUID v4).  This module groups events by that key and provides
end-to-end prompt trace views.
"""

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence

from briefcase.cowork.receiver import CoworkEvent
from briefcase.semantic_conventions import cowork as conv


@dataclass
class PromptTrace:
    """Events sharing one ``prompt.id`` plus metrics aggregated from them.

    Events are stored in per-type buckets.  The ``total_*`` and ``*_count``
    fields keep their zero defaults until :meth:`finalize` is called.
    """

    prompt_id: str
    session_id: str = ""
    organization_id: str = ""
    user_id: str = ""
    user_email: str = ""

    # Per-type event buckets
    user_prompt: Optional[CoworkEvent] = None
    tool_results: List[CoworkEvent] = field(default_factory=list)
    api_requests: List[CoworkEvent] = field(default_factory=list)
    api_errors: List[CoworkEvent] = field(default_factory=list)
    tool_decisions: List[CoworkEvent] = field(default_factory=list)

    # Aggregates filled in by finalize()
    total_cost_usd: float = 0.0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cache_read_tokens: int = 0
    total_cache_creation_tokens: int = 0
    tool_count: int = 0
    tool_failure_count: int = 0
    api_error_count: int = 0
    total_duration_ms: float = 0.0

    @property
    def event_count(self) -> int:
        """Number of events held across every bucket, user prompt included."""
        prompt_events = 0 if self.user_prompt is None else 1
        return (
            prompt_events
            + len(self.tool_results)
            + len(self.api_requests)
            + len(self.api_errors)
            + len(self.tool_decisions)
        )

    @property
    def events(self) -> List[CoworkEvent]:
        """Return a new list of every event, ordered by ``sequence``.

        The sort is stable, so events with equal sequence numbers keep the
        bucket order: user prompt, tool results, API requests, API errors,
        tool decisions.
        """
        combined: List[CoworkEvent] = (
            [] if self.user_prompt is None else [self.user_prompt]
        )
        for bucket in (
            self.tool_results,
            self.api_requests,
            self.api_errors,
            self.tool_decisions,
        ):
            combined += bucket
        return sorted(combined, key=lambda evt: evt.sequence)

    def finalize(self) -> None:
        """Recompute every derived metric from the events currently stored.

        Cost and token totals are read from the API-request events.  A tool
        result counts as a failure when its success attribute, stringified
        and lower-cased, equals ``"false"``; a missing attribute counts as
        success.  ``total_duration_ms`` is the sum of the durations reported
        by API requests and tool results, not a wall-clock span.
        """
        costs: List[float] = []
        durations: List[float] = []
        tokens_in = tokens_out = cache_read = cache_created = 0

        for request in self.api_requests:
            attrs = request.attributes
            costs.append(_float(attrs.get(conv.API_COST_USD, 0)))
            tokens_in += _int(attrs.get(conv.API_INPUT_TOKENS, 0))
            tokens_out += _int(attrs.get(conv.API_OUTPUT_TOKENS, 0))
            cache_read += _int(attrs.get(conv.API_CACHE_READ_TOKENS, 0))
            cache_created += _int(attrs.get(conv.API_CACHE_CREATION_TOKENS, 0))
            durations.append(_float(attrs.get(conv.API_DURATION_MS, 0)))

        failures = 0
        for result in self.tool_results:
            attrs = result.attributes
            if str(attrs.get(conv.TOOL_SUCCESS, "true")).lower() == "false":
                failures += 1
            # Tool durations follow the API durations in the summed list.
            durations.append(_float(attrs.get(conv.TOOL_DURATION_MS, 0)))

        self.total_cost_usd = sum(costs)
        self.total_input_tokens = tokens_in
        self.total_output_tokens = tokens_out
        self.total_cache_read_tokens = cache_read
        self.total_cache_creation_tokens = cache_created
        self.tool_count = len(self.tool_results)
        self.tool_failure_count = failures
        self.api_error_count = len(self.api_errors)
        self.total_duration_ms = sum(durations)


class PromptCorrelationEngine:
    """Collects Cowork events per ``prompt.id`` and turns them into traces.

    Usage::

        engine = PromptCorrelationEngine()
        for event in receiver.events:
            engine.add(event)
        for trace in engine.traces():
            print(trace.prompt_id, trace.total_cost_usd)
    """

    def __init__(self) -> None:
        # prompt_id -> events in insertion order
        self._buckets: Dict[str, List[CoworkEvent]] = defaultdict(list)

    def add(self, event: CoworkEvent) -> None:
        """File *event* under its ``prompt_id``."""
        bucket = self._buckets[event.prompt_id]
        bucket.append(event)

    def add_many(self, events: Sequence[CoworkEvent]) -> None:
        """Call :meth:`add` for each event, in order."""
        for item in events:
            self.add(item)

    def get_trace(self, prompt_id: str) -> Optional[PromptTrace]:
        """Return a finalized trace for *prompt_id*, or ``None`` if unknown."""
        # .get() avoids creating an empty bucket in the defaultdict.
        bucket = self._buckets.get(prompt_id)
        return self._build_trace(prompt_id, bucket) if bucket else None

    def traces(self) -> List[PromptTrace]:
        """Return a freshly built, finalized trace for every known prompt."""
        return [
            self._build_trace(pid, bucket)
            for pid, bucket in self._buckets.items()
        ]

    @property
    def prompt_ids(self) -> List[str]:
        """Prompt ids seen so far, in first-seen order."""
        return list(self._buckets)

    def clear(self) -> None:
        """Forget every indexed event."""
        self._buckets.clear()

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _build_trace(prompt_id: str, events: List[CoworkEvent]) -> PromptTrace:
        """Sort *events* into a new ``PromptTrace`` and finalize it.

        Events whose type matches none of the known kinds are ignored.  If
        several user-prompt events are present, the last one is kept.
        """
        trace = PromptTrace(prompt_id=prompt_id)

        # Ordered (event type, destination list) pairs; first match wins.
        routes = (
            (conv.EVENT_TOOL_RESULT, trace.tool_results),
            (conv.EVENT_API_REQUEST, trace.api_requests),
            (conv.EVENT_API_ERROR, trace.api_errors),
            (conv.EVENT_TOOL_DECISION, trace.tool_decisions),
        )

        for evt in events:
            # Identity fields are copied until an event supplies a session id.
            if not trace.session_id:
                trace.session_id = evt.session_id
                trace.organization_id = evt.organization_id
                trace.user_id = evt.user_id
                trace.user_email = evt.user_email

            kind = evt.event_type
            if kind == conv.EVENT_USER_PROMPT:
                trace.user_prompt = evt
                continue
            for route_kind, destination in routes:
                if kind == route_kind:
                    destination.append(evt)
                    break

        trace.finalize()
        return trace


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _float(v: Any) -> float:
    """Coerce *v* to ``float``, falling back to ``0.0`` when that fails.

    Args:
        v: Any value, typically a raw event attribute.

    Returns:
        ``float(v)``, or ``0.0`` if the conversion raises ``TypeError``,
        ``ValueError`` or ``OverflowError``.
    """
    try:
        return float(v)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an int too large to represent as a double
        # (e.g. 10**400) must not abort the aggregation.
        return 0.0


def _int(v: Any) -> int:
    """Coerce *v* to ``int``, falling back to ``0`` when that fails.

    A direct ``int(v)`` is tried first.  If it fails, the value is parsed
    as a float and truncated, so numeric strings such as ``"12.0"`` or
    ``"1e3"`` are handled the same way a float value already is.

    Args:
        v: Any value, typically a raw event attribute.

    Returns:
        The integer value, or ``0`` for ``None``, non-numeric input, NaN
        and infinities.
    """
    try:
        return int(v)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")).  Fall through to the float parse.
        pass
    try:
        return int(float(v))
    except (TypeError, ValueError, OverflowError):
        return 0