import json as json_module
import re
from dataclasses import dataclass
from typing import Any, List, Optional
# Optional tracing: OpenTelemetry may be absent in some deployments, so the
# import is attempted and the module degrades to no-op sentinels on failure.
try:
    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)
    HAS_OTEL = True
except ImportError:
    # Sentinels let callers guard with `if HAS_OTEL and tracer:`.
    tracer = None
    HAS_OTEL = False
@dataclass
class Reference:
    """A single knowledge-base reference extracted from a prompt."""

    text: str                         # exact text span that matched
    type: str                         # pattern key or LLM-reported kind
    path: Optional[str] = None        # file path / identifier, when captured
    version: Optional[str] = None     # version hash for versioned files
    metadata: Optional[dict] = None   # provenance info (e.g. which pattern)
class ReferenceExtractor:
    """Extract knowledge-base references from free-form prompt text.

    Regex patterns (PATTERNS) are tried first; when they find nothing and
    ``use_llm_fallback`` is enabled, the configured LLM is asked to extract
    references and its JSON answer is parsed best-effort.
    """

    # Compiled once at class-definition time. Keys double as the
    # Reference.type value assigned to each match.
    PATTERNS = {
        'document_id': re.compile(r'\bdoc[_-]?id[:\s]+([a-zA-Z0-9_-]+)'),
        'policy_name': re.compile(r'\bpolicy[:\s]+["\']?([^"\'}\n]+?)["\']?(?=\s+(?:for|and|to|or|\.|,|$))'),
        'section_ref': re.compile(r'\bSection\s+(\d+(?:\.\d+)*)'),
        'file_path': re.compile(r'\b([a-zA-Z0-9_/-]+\.(?:pdf|md|txt|docx))\b'),
        'versioned_file': re.compile(r'\b([a-zA-Z0-9_/-]+\.(?:pdf|md|txt|docx))@([a-f0-9]{6,40})\b'),
    }

    def __init__(self, use_llm_fallback: bool = False,
                 llm_model: Optional[Any] = None):
        """Configure the extractor.

        Args:
            use_llm_fallback: when True and no regex matches, query the LLM.
            llm_model: an object exposing ``complete(prompt)`` or a plain
                callable taking the prompt string; anything else silently
                disables the fallback. (Annotation fixed: the original said
                ``Optional[str]``, which contradicted how the value is used.)
        """
        self.use_llm_fallback = use_llm_fallback
        self.llm_model = llm_model

    def extract(self, prompt: str) -> List[Reference]:
        """Return all references found in ``prompt``, deduplicated.

        Wraps the work in an OpenTelemetry span when tracing is available.
        """
        if HAS_OTEL and tracer:
            with tracer.start_as_current_span("validation.extract_references") as span:
                return self._extract_with_telemetry(prompt, span)
        return self._extract_internal(prompt)

    def _extract_with_telemetry(self, prompt: str, span) -> List[Reference]:
        """Run extraction and record result metrics on ``span``."""
        references = self._extract_internal(prompt)
        span.set_attribute("validation.reference.count", len(references))
        if references:
            # Span attribute values must be primitives; stringify the list.
            span.set_attribute(
                "validation.reference.extracted",
                str([r.text for r in references])
            )
        return references

    def _extract_internal(self, prompt: str) -> List[Reference]:
        """Regex pass over ``prompt``; LLM fallback only when nothing matched."""
        references: List[Reference] = []
        # versioned_file runs first so _deduplicate can drop the bare
        # file_path matches the same 'path@sha' text also produces.
        pattern_order = ['versioned_file', 'document_id', 'policy_name',
                         'section_ref', 'file_path']
        for ref_type in pattern_order:
            for match in self.PATTERNS[ref_type].finditer(prompt):
                ref = Reference(
                    text=match.group(0),
                    type=ref_type,
                    metadata={'regex_pattern': ref_type},
                )
                # Every pattern captures its key token in group 1; only the
                # versioned form also carries a commit hash in group 2.
                ref.path = match.group(1)
                if ref_type == 'versioned_file':
                    ref.version = match.group(2)
                references.append(ref)
        references = self._deduplicate(references)
        if not references and self.use_llm_fallback:
            references = self._llm_extract(prompt)
        return references

    def _deduplicate(self, references: List[Reference]) -> List[Reference]:
        """Drop repeats and bare file paths shadowed by a versioned match.

        Order is preserved; the uniqueness key is (text, type).
        """
        versioned_paths = {ref.path for ref in references
                           if ref.type == 'versioned_file'}
        seen = set()
        unique = []
        for ref in references:
            # 'report.pdf' is redundant when 'report.pdf@abc123' matched.
            if ref.type == 'file_path' and ref.path in versioned_paths:
                continue
            key = (ref.text, ref.type)
            if key not in seen:
                seen.add(key)
                unique.append(ref)
        return unique

    def _llm_extract(self, prompt: str) -> List[Reference]:
        """Ask the configured LLM for references; best-effort, never raises."""
        if not self.llm_model:
            return []
        extraction_prompt = (
            "Extract all knowledge base references from this prompt.\n"
            "Look for: document IDs, policy names, section numbers, file paths.\n\n"
            f"Prompt:\n{prompt}\n\n"
            'Return JSON array of references:\n'
            '[{"text": "...", "type": "document|section|policy", "path": "..."}]'
        )

        def _call_llm(ep: str) -> List[Reference]:
            # Support either an object with .complete() or a bare callable.
            try:
                if hasattr(self.llm_model, 'complete'):
                    response = self.llm_model.complete(ep)
                elif callable(self.llm_model):
                    response = self.llm_model(ep)
                else:
                    return []
                return self._parse_llm_response(str(response))
            except Exception:
                # The fallback is advisory; an LLM failure must not break extract().
                return []

        if HAS_OTEL and tracer:
            with tracer.start_as_current_span("validation.llm_extract"):
                return _call_llm(extraction_prompt)
        return _call_llm(extraction_prompt)

    def _parse_llm_response(self, response: str) -> List[Reference]:
        """Parse a JSON array of reference dicts out of an LLM response.

        Returns [] on any malformed input instead of raising.
        """
        try:
            json_str = response
            # BUGFIX: the array was previously extracted only when the
            # response contained ``` code fences, so prose-wrapped JSON
            # without fences always failed to parse. Pull out the outermost
            # [...] span whenever one exists.
            start = response.find("[")
            end = response.rfind("]") + 1
            if start != -1 and end > start:
                json_str = response[start:end]
            data = json_module.loads(json_str)
            if not isinstance(data, list):
                return []
            references = []
            for item in data:
                # Ignore non-dict entries rather than failing the whole batch.
                if isinstance(item, dict):
                    references.append(Reference(
                        text=item.get("text", ""),
                        type=item.get("type", "file_path"),
                        path=item.get("path"),
                        metadata={"source": "llm"},
                    ))
            return references
        except Exception:
            return []