# briefcase-python 2.4.1
#
# Python bindings for Briefcase AI — see the project documentation.
"""
Main validation engine coordinating all layers.
"""

import time
from typing import Optional

try:
    from opentelemetry import trace
    HAS_OTEL = True
    tracer = trace.get_tracer(__name__)
except ImportError:
    HAS_OTEL = False
    tracer = None

from briefcase.validation.extractors import ReferenceExtractor
from briefcase.validation.resolvers import ReferenceResolver
from briefcase.validation.semantic import SemanticValidator
from briefcase.validation.errors import ValidationReport, ValidationError
from briefcase.semantic_conventions.validation import *


class PromptValidationEngine:
    """
    Multi-layer validation engine for prompt-knowledge consistency.
    """

    def __init__(
        self,
        lakefs_client,
        repository: str,
        branch: str = "main",
        mode: str = "strict",
        enable_semantic: bool = False,
        llm_client=None
    ):
        self.extractor = ReferenceExtractor(
            use_llm_fallback=enable_semantic,
            llm_model=llm_client.model_name if llm_client and hasattr(llm_client, 'model_name') else None
        )
        self.resolver = ReferenceResolver(lakefs_client, repository, branch)
        self.semantic = None
        if enable_semantic and llm_client:
            self.semantic = SemanticValidator(llm_client, lakefs_client, repository, branch)

        self.mode = mode  # "strict" | "tolerant" | "warn_only"
        self.repository = repository
        self.branch = branch
        self.lakefs = lakefs_client

    def validate(self, prompt: str) -> ValidationReport:
        """
        Validate prompt against knowledge base.
        Returns ValidationReport with all errors and warnings.
        """
        if HAS_OTEL and tracer:
            with tracer.start_as_current_span("validation.validate_prompt") as span:
                return self._validate_with_telemetry(prompt, span)
        else:
            return self._validate_internal(prompt)

    def _validate_with_telemetry(self, prompt: str, span) -> ValidationReport:
        """Validate with telemetry."""
        span.set_attribute(VALIDATION_MODE, self.mode)

        report = self._validate_internal(prompt)

        span.set_attribute(VALIDATION_STATUS, report.status)
        span.set_attribute(VALIDATION_ERROR_COUNT, len(report.errors))
        span.set_attribute(VALIDATION_RESOLUTION_TIME_MS, report.validation_time_ms)

        # Record errors as events
        for error in report.errors:
            span.add_event(
                "validation.error",
                attributes={
                    VALIDATION_ERROR_CODE: error.code.value,
                    VALIDATION_ERROR_MESSAGE: error.message,
                    VALIDATION_ERROR_REFERENCE: error.reference
                }
            )

        return report

    def _validate_internal(self, prompt: str) -> ValidationReport:
        """Internal validation logic."""
        start_time = time.time()
        all_errors = []
        all_warnings = []

        # Layer 1: Extract references
        references = self.extractor.extract(prompt)

        if len(references) == 0:
            # No references to validate
            commit_sha = "unknown"
            try:
                commit_sha = self.lakefs.get_commit(self.repository, self.branch)
            except Exception:
                pass

            return ValidationReport(
                status="passed",
                errors=[],
                warnings=[],
                references_checked=0,
                validation_time_ms=(time.time() - start_time) * 1000,
                lakefs_commit=commit_sha
            )

        # Layer 2: Resolve references
        resolution_errors = self.resolver.resolve_all(references)

        for error in resolution_errors:
            if error.severity == "error":
                all_errors.append(error)
            else:
                all_warnings.append(error)

        # Layer 3: Semantic validation (optional)
        if self.semantic and len(all_errors) == 0:
            semantic_errors = self.semantic.validate_semantic(prompt, references)
            all_warnings.extend(semantic_errors)

        # Determine overall status
        elapsed_ms = (time.time() - start_time) * 1000
        status = self._determine_status(all_errors, all_warnings)

        # Get commit SHA
        commit_sha = "unknown"
        try:
            commit_sha = self.lakefs.get_commit(self.repository, self.branch)
        except Exception:
            pass

        return ValidationReport(
            status=status,
            errors=all_errors,
            warnings=all_warnings,
            references_checked=len(references),
            validation_time_ms=elapsed_ms,
            lakefs_commit=commit_sha
        )

    def _determine_status(
        self,
        errors: list,
        warnings: list
    ) -> str:
        """Determine overall validation status based on mode."""

        if self.mode == "strict":
            if len(errors) > 0:
                return "failed"
            elif len(warnings) > 0:
                return "warning"
            else:
                return "passed"

        elif self.mode == "tolerant":
            # Only hard errors fail
            if len(errors) > 0:
                return "failed"
            else:
                return "passed"

        else:  # warn_only
            # Never fails, just warns
            return "passed"