import logging
import time
from typing import Optional

try:
    from opentelemetry import trace

    HAS_OTEL = True
    tracer = trace.get_tracer(__name__)
except ImportError:
    HAS_OTEL = False
    tracer = None

from briefcase.validation.extractors import ReferenceExtractor
from briefcase.validation.resolvers import ReferenceResolver
from briefcase.validation.semantic import SemanticValidator
from briefcase.validation.errors import ValidationReport, ValidationError
from briefcase.semantic_conventions.validation import *
class PromptValidationEngine:
    """Validate the references contained in a prompt against a lakeFS branch.

    The engine chains three collaborators:

    * a ``ReferenceExtractor`` that pulls references out of the prompt text,
    * a ``ReferenceResolver`` that checks those references against the
      configured repository/branch, and
    * an optional ``SemanticValidator`` that is only run when resolution
      produced no errors.

    The outcome is returned as a ``ValidationReport``.  When OpenTelemetry is
    importable, each validation is additionally wrapped in a span.
    """

    # Shared logger; used only to report best-effort failures that the
    # engine deliberately does not propagate.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        lakefs_client,
        repository: str,
        branch: str = "main",
        mode: str = "strict",
        enable_semantic: bool = False,
        llm_client=None
    ) -> None:
        """Wire up the extractor, resolver and optional semantic validator.

        Args:
            lakefs_client: Client handed to the resolver and semantic
                validator; also used here to look up the branch commit.
            repository: Repository the references are resolved against.
            branch: Branch the references are resolved against.
            mode: Status policy, see ``_determine_status``.
            enable_semantic: Enables the extractor's LLM fallback and, when
                ``llm_client`` is also given, the semantic validation pass.
            llm_client: Optional LLM client.  Its ``model_name`` attribute,
                if present, is forwarded to the extractor.
        """
        model_name = getattr(llm_client, "model_name", None) if llm_client else None
        self.extractor = ReferenceExtractor(
            use_llm_fallback=enable_semantic,
            llm_model=model_name,
        )
        self.resolver = ReferenceResolver(lakefs_client, repository, branch)

        # Semantic validation needs both the feature flag and a client.
        self.semantic = None
        if enable_semantic and llm_client:
            self.semantic = SemanticValidator(
                llm_client, lakefs_client, repository, branch
            )

        self.mode = mode
        self.repository = repository
        self.branch = branch
        self.lakefs = lakefs_client

    def validate(self, prompt: str) -> ValidationReport:
        """Validate ``prompt`` and return the resulting report.

        The work is wrapped in a ``validation.validate_prompt`` span when
        OpenTelemetry was importable and a tracer is available; otherwise the
        validation runs without telemetry.
        """
        if HAS_OTEL and tracer:
            with tracer.start_as_current_span("validation.validate_prompt") as span:
                return self._validate_with_telemetry(prompt, span)
        return self._validate_internal(prompt)

    def _validate_with_telemetry(self, prompt: str, span) -> ValidationReport:
        """Run the validation and mirror its outcome onto ``span``.

        The mode is recorded before validating; status, error count and
        elapsed time are recorded afterwards, followed by one
        ``validation.error`` event per reported error.
        """
        span.set_attribute(VALIDATION_MODE, self.mode)
        report = self._validate_internal(prompt)

        span.set_attribute(VALIDATION_STATUS, report.status)
        span.set_attribute(VALIDATION_ERROR_COUNT, len(report.errors))
        span.set_attribute(VALIDATION_RESOLUTION_TIME_MS, report.validation_time_ms)
        for error in report.errors:
            span.add_event(
                "validation.error",
                attributes={
                    VALIDATION_ERROR_CODE: error.code.value,
                    VALIDATION_ERROR_MESSAGE: error.message,
                    VALIDATION_ERROR_REFERENCE: error.reference,
                },
            )
        return report

    def _validate_internal(self, prompt: str) -> ValidationReport:
        """Extract, resolve and (optionally) semantically check references.

        A prompt without references skips resolution and semantic checks and
        is reported as ``"passed"`` with zero references checked.

        ``validation_time_ms`` covers extraction, resolution and semantic
        validation; the trailing commit lookup is excluded on every path.
        """
        # perf_counter is monotonic, so the interval cannot go negative when
        # the wall clock is adjusted mid-validation.
        started = time.perf_counter()
        errors = []
        warnings = []

        references = self.extractor.extract(prompt)
        if references:
            # Resolver findings are split by severity; anything that is not
            # an "error" is treated as a warning.
            for finding in self.resolver.resolve_all(references):
                if finding.severity == "error":
                    errors.append(finding)
                else:
                    warnings.append(finding)

            # Semantic findings are only gathered when resolution was clean,
            # and they are always recorded as warnings.
            if self.semantic and not errors:
                warnings.extend(
                    self.semantic.validate_semantic(prompt, references)
                )

        elapsed_ms = (time.perf_counter() - started) * 1000
        return ValidationReport(
            status=self._determine_status(errors, warnings),
            errors=errors,
            warnings=warnings,
            references_checked=len(references),
            validation_time_ms=elapsed_ms,
            lakefs_commit=self._lookup_commit(),
        )

    def _lookup_commit(self):
        """Return the commit for the configured branch, or ``"unknown"``.

        The lookup is best-effort: a failure must not fail the validation,
        so it is logged and replaced by the ``"unknown"`` placeholder.
        """
        try:
            return self.lakefs.get_commit(self.repository, self.branch)
        except Exception:
            self._logger.warning(
                "Could not look up commit for %s@%s; reporting 'unknown'",
                self.repository,
                self.branch,
                exc_info=True,
            )
            return "unknown"

    def _determine_status(
        self,
        errors: list,
        warnings: list
    ) -> str:
        """Map collected errors/warnings to a report status for ``self.mode``.

        * ``"strict"``: errors -> ``"failed"``, else warnings ->
          ``"warning"``, else ``"passed"``.
        * ``"tolerant"``: errors -> ``"failed"``, else ``"passed"``.
        * any other mode: always ``"passed"``.
        """
        if self.mode in ("strict", "tolerant") and errors:
            return "failed"
        if self.mode == "strict" and warnings:
            return "warning"
        return "passed"