# briefcase-python 2.4.1
#
# Python bindings for Briefcase AI — documentation header.
"""
Reference extraction from prompts (Layer 1).
"""

import json as json_module
import re
from typing import List, Optional
from dataclasses import dataclass

try:
    from opentelemetry import trace
    HAS_OTEL = True
    tracer = trace.get_tracer(__name__)
except ImportError:
    HAS_OTEL = False
    tracer = None


@dataclass
class Reference:
    """A single knowledge-base reference extracted from a prompt.

    Produced by ReferenceExtractor. ``path``/``version`` are populated when
    the matching pattern captures them; otherwise they stay ``None``.
    """

    text: str  # Original reference text exactly as it appeared in the prompt
    type: str  # "document_id" | "policy_name" | "section_ref" | "file_path" | "versioned_file"
    path: Optional[str] = None  # Resolved lakeFS path (raw captured path/id until resolved)
    version: Optional[str] = None  # Explicit version if specified (commit-ish after '@')
    metadata: Optional[dict] = None  # Provenance, e.g. {'regex_pattern': ...} or {'source': 'llm'}


class ReferenceExtractor:
    """
    Extracts knowledge base references from prompts.

    A fast regex pass (see PATTERNS) runs first; when it finds nothing and
    ``use_llm_fallback`` is enabled, the configured LLM is asked to extract
    references instead.  When opentelemetry is importable (HAS_OTEL),
    extraction is wrapped in tracing spans.
    """

    # Compiled once at class-definition time; reused by every extract() call.
    PATTERNS = {
        'document_id': re.compile(r'\bdoc[_-]?id[:\s]+([a-zA-Z0-9_-]+)'),
        # A policy name ends at a connective word, at punctuation, or at the
        # end of input.  The previous lookahead demanded whitespace before
        # '.', ',' and end-of-string, so "policy: GDPR" at the end of a
        # prompt (or "policy: GDPR.") was silently missed; it also matched
        # "for" without a word boundary, so "forensics" acted as a stop word.
        'policy_name': re.compile(
            r'\bpolicy[:\s]+["\']?([^"\'}\n]+?)["\']?'
            r'(?=\s+(?:for|and|to|or)\b|\s*[.,]|\s*$)'
        ),
        'section_ref': re.compile(r'\bSection\s+(\d+(?:\.\d+)*)'),
        'file_path': re.compile(r'\b([a-zA-Z0-9_/-]+\.(?:pdf|md|txt|docx))\b'),
        'versioned_file': re.compile(r'\b([a-zA-Z0-9_/-]+\.(?:pdf|md|txt|docx))@([a-f0-9]{6,40})\b'),
    }

    def __init__(self, use_llm_fallback: bool = False, llm_model: Optional[str] = None):
        """
        Args:
            use_llm_fallback: If True, ask the LLM for references when the
                regex pass finds none.
            llm_model: Despite the annotation, _llm_extract expects either an
                object exposing ``.complete(prompt)`` or a plain callable;
                anything else makes the fallback a no-op.
        """
        self.use_llm_fallback = use_llm_fallback
        self.llm_model = llm_model

    def extract(self, prompt: str) -> List[Reference]:
        """
        Extract all knowledge base references from ``prompt``.

        Returns:
            Deduplicated Reference objects in first-seen order; may be empty.
        """
        if HAS_OTEL and tracer:
            with tracer.start_as_current_span("validation.extract_references") as span:
                return self._extract_with_telemetry(prompt, span)
        return self._extract_internal(prompt)

    def _extract_with_telemetry(self, prompt: str, span) -> List[Reference]:
        """Run extraction and record the count and texts on ``span``."""
        references = self._extract_internal(prompt)

        span.set_attribute("validation.reference.count", len(references))
        if references:
            span.set_attribute(
                "validation.reference.extracted",
                str([r.text for r in references])
            )

        return references

    def _extract_internal(self, prompt: str) -> List[Reference]:
        """Regex extraction, deduplication, then optional LLM fallback."""
        references = []

        # 'versioned_file' runs first so _deduplicate() already knows which
        # paths are version-pinned and can drop their bare file_path twins.
        pattern_order = ['versioned_file', 'document_id', 'policy_name', 'section_ref', 'file_path']

        for ref_type in pattern_order:
            for match in self.PATTERNS[ref_type].finditer(prompt):
                ref = Reference(
                    text=match.group(0),
                    type=ref_type,
                    # For non-file types group(1) is the raw captured id/name,
                    # to be resolved to a lakeFS path by a later layer.
                    path=match.group(1),
                    metadata={'regex_pattern': ref_type},
                )
                if ref_type == 'versioned_file':
                    ref.version = match.group(2)  # explicit commit-ish after '@'
                references.append(ref)

        references = self._deduplicate(references)

        # Fall back to the LLM only when regex found nothing and it's enabled.
        if not references and self.use_llm_fallback:
            references = self._llm_extract(prompt)

        return references

    def _deduplicate(self, references: List[Reference]) -> List[Reference]:
        """
        Remove duplicate references, preserving first-seen order.
        Priority: versioned_file > file_path (if same base path).
        """
        # Paths that already appear version-pinned; their unversioned
        # file_path matches are redundant and dropped below.
        versioned_paths = {ref.path for ref in references if ref.type == 'versioned_file'}

        seen = set()
        unique = []
        for ref in references:
            # Skip unversioned file_path if we have a versioned version of it
            if ref.type == 'file_path' and ref.path in versioned_paths:
                continue

            key = (ref.text, ref.type)
            if key not in seen:
                seen.add(key)
                unique.append(ref)
        return unique

    def _llm_extract(self, prompt: str) -> List[Reference]:
        """
        Use the configured LLM to extract references (fallback path).

        Returns [] when no model is configured, the model object is neither
        completable nor callable, or the call/parse fails — the fallback is
        strictly best-effort and must never raise.
        """
        if not self.llm_model:
            return []

        extraction_prompt = (
            "Extract all knowledge base references from this prompt.\n"
            "Look for: document IDs, policy names, section numbers, file paths.\n\n"
            f"Prompt:\n{prompt}\n\n"
            'Return JSON array of references:\n'
            '[{"text": "...", "type": "document|section|policy", "path": "..."}]'
        )

        def _call_llm(ep: str) -> List[Reference]:
            # Duck-typed model access: client object first, bare callable next.
            try:
                if hasattr(self.llm_model, 'complete'):
                    response = self.llm_model.complete(ep)
                elif callable(self.llm_model):
                    response = self.llm_model(ep)
                else:
                    return []
                return self._parse_llm_response(str(response))
            except Exception:
                # Best-effort: a failing model call must not break extraction.
                return []

        if HAS_OTEL and tracer:
            with tracer.start_as_current_span("validation.llm_extract"):
                return _call_llm(extraction_prompt)

        return _call_llm(extraction_prompt)

    def _parse_llm_response(self, response: str) -> List[Reference]:
        """
        Parse a JSON array in ``response`` into Reference objects.

        Tries the raw response first, then falls back to the outermost
        ``[...]`` slice so markdown-fenced or prose-wrapped answers still
        parse.  Anything unparseable yields [].
        """
        try:
            try:
                data = json_module.loads(response)
            except ValueError:
                # Salvage the bracketed slice.  (Previously this only ran
                # when a ``` fence was present, so answers with plain
                # leading prose but no fence were discarded.)
                start = response.find("[")
                end = response.rfind("]") + 1
                if start == -1 or end <= start:
                    return []
                data = json_module.loads(response[start:end])

            if not isinstance(data, list):
                return []

            references = []
            for item in data:
                if isinstance(item, dict):
                    references.append(Reference(
                        text=item.get("text", ""),
                        type=item.get("type", "file_path"),
                        path=item.get("path"),
                        metadata={"source": "llm"},
                    ))
            return references
        except Exception:
            # Defensive: malformed model output means "no references".
            return []