from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Protocol, runtime_checkable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
import json
import logging
# Module-level logger named after this module; used to report telemetry
# provider failures before falling back to locally ingested data.
logger: logging.Logger = logging.getLogger(__name__)
class ComplianceStatus(Enum):
    """Outcome of a compliance evaluation, at control or report level.

    Member values are the lowercase strings emitted in serialised reports.
    """

    COMPLIANT = "compliant"          # requirement fully met
    NON_COMPLIANT = "non_compliant"  # requirement not met
    PARTIAL = "partial"              # requirement partly met
    NEEDS_REVIEW = "needs_review"    # outcome requires human review
class ViolationSeverity(Enum):
    """Severity level attached to a compliance violation.

    Members are declared from most to least severe; values are the lowercase
    strings emitted in serialised reports.
    """

    CRITICAL = "critical"
    MAJOR = "major"
    MINOR = "minor"
    ADVISORY = "advisory"
@dataclass
class Violation:
    """A single compliance finding tied to one control.

    ``to_dict`` converts the enum and datetime fields to strings; every other
    field is copied as-is via ``dataclasses.asdict``.
    """

    control_id: str  # identifier of the control that was violated
    severity: ViolationSeverity
    message: str  # human-readable description of the finding
    affected_items: List[str] = field(default_factory=list)
    evidence: Dict[str, Any] = field(default_factory=dict)
    remediation: str = ""  # empty string when no remediation text was supplied
    target_date: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a dict copy of this violation.

        ``severity`` is replaced by its string value, and ``target_date`` by its
        ISO-8601 string when set (it stays ``None`` otherwise). Key order follows
        the field declaration order.
        """
        payload = asdict(self)
        when = self.target_date
        # Overriding existing keys keeps their original positions in the dict.
        return {
            **payload,
            'severity': self.severity.value,
            'target_date': when.isoformat() if when else payload['target_date'],
        }
@dataclass
class ControlResult:
    """Evaluation outcome for a single control."""

    control_id: str
    control_name: str
    status: ComplianceStatus
    score: float  # rendered with a trailing "%" in markdown reports
    evidence: List[str] = field(default_factory=list)
    violations: List[Violation] = field(default_factory=list)
    recommendations: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this result.

        ``status`` is replaced by its string value, and each violation is
        serialised with ``Violation.to_dict``. The list fields are shallow-copied
        so that mutating the returned dict cannot alter this instance.
        """
        return {
            'control_id': self.control_id,
            'control_name': self.control_name,
            'status': self.status.value,
            'score': self.score,
            'evidence': list(self.evidence),
            'violations': [v.to_dict() for v in self.violations],
            'recommendations': list(self.recommendations),
        }
@dataclass
class ComplianceReport:
    """Aggregated outcome of evaluating one compliance framework.

    ``to_json`` returns a nested plain dict and ``to_markdown`` returns a
    human-readable Markdown document. Neither method mutates the report.
    """

    framework: str
    organization: str
    report_period_start: datetime
    report_period_end: datetime
    evaluation_date: datetime
    overall_status: ComplianceStatus
    overall_score: float  # rendered with a trailing "%" in markdown
    control_results: List[ControlResult] = field(default_factory=list)
    violations: List[Violation] = field(default_factory=list)
    # Summary counters, emitted under "statistics" in to_json. Presumably
    # populated by the generator — verify against callers.
    total_controls_evaluated: int = 0
    controls_passed: int = 0
    controls_failed: int = 0
    controls_partial: int = 0
    total_decisions: int = 0
    total_spans: int = 0
    telemetry_completeness: float = 0.0
    auditor: str = "Briefcase AI Compliance Engine"
    auditor_version: str = "2.1"
    report_id: str = ""  # when empty, to_json derives an id from evaluation_date

    def to_json(self) -> dict:
        """Return the report as a nested dict.

        Datetimes are rendered as ISO-8601 strings and enums as their string
        values. If ``report_id`` is empty, an id of the form
        ``rpt_YYYYmmdd_HHMMSS`` is derived from ``evaluation_date``. Because it
        does not depend on the wall clock, serialising the same report twice
        yields the same id.
        """
        report_id = self.report_id or f"rpt_{self.evaluation_date.strftime('%Y%m%d_%H%M%S')}"
        return {
            'framework': self.framework,
            'organization': self.organization,
            'report_period': {
                'start': self.report_period_start.isoformat(),
                'end': self.report_period_end.isoformat()
            },
            'evaluation_date': self.evaluation_date.isoformat(),
            'overall_status': self.overall_status.value,
            'overall_score': self.overall_score,
            'control_results': [cr.to_dict() for cr in self.control_results],
            'violations': [v.to_dict() for v in self.violations],
            'statistics': {
                'total_controls_evaluated': self.total_controls_evaluated,
                'controls_passed': self.controls_passed,
                'controls_failed': self.controls_failed,
                'controls_partial': self.controls_partial,
                'total_decisions': self.total_decisions,
                'total_spans': self.total_spans,
                'telemetry_completeness': self.telemetry_completeness
            },
            'metadata': {
                'auditor': self.auditor,
                'auditor_version': self.auditor_version,
                'report_id': report_id
            }
        }

    def to_markdown(self) -> str:
        """Render the report as a Markdown document.

        The document contains a header block, an executive summary, one section
        per control result, and (when present) a section listing
        report-level violations with their remediation text.
        """
        md = []
        md.append(f"# {self.framework} Compliance Report\n")
        md.append(f"**Organization**: {self.organization}\n")
        md.append(f"**Report Period**: {self.report_period_start.date()} to {self.report_period_end.date()}\n")
        md.append(f"**Evaluation Date**: {self.evaluation_date.date()}\n")
        md.append(f"**Overall Status**: {self.overall_status.value.upper()}\n")
        md.append(f"**Overall Score**: {self.overall_score:.1f}%\n")
        md.append("\n---\n\n")
        md.append("## Executive Summary\n\n")
        md.append(f"- **Controls Evaluated**: {self.total_controls_evaluated}\n")
        md.append(f"- **Controls Passed**: {self.controls_passed}\n")
        md.append(f"- **Controls Failed**: {self.controls_failed}\n")
        md.append(f"- **Controls Partially Compliant**: {self.controls_partial}\n")
        md.append(f"- **Total Violations**: {len(self.violations)}\n\n")
        md.append("## Control Results\n\n")
        for cr in self.control_results:
            md.append(f"### {cr.control_id}: {cr.control_name}\n\n")
            md.append(f"- **Status**: {cr.status.value.upper()}\n")
            md.append(f"- **Score**: {cr.score:.1f}%\n")
            if cr.evidence:
                md.append("- **Evidence**:\n")
                for e in cr.evidence:
                    # Two spaces align with the parent item's content ("- "),
                    # which CommonMark requires for a nested list.
                    md.append(f"  - {e}\n")
            if cr.violations:
                md.append(f"- **Violations**: {len(cr.violations)}\n")
            md.append("\n")
        if self.violations:
            md.append("## Violations\n\n")
            for v in self.violations:
                md.append(f"### {v.severity.value.upper()}: {v.control_id}\n\n")
                md.append(f"{v.message}\n\n")
                if v.remediation:
                    md.append(f"**Remediation**: {v.remediation}\n\n")
        return "".join(md)
@runtime_checkable
class TelemetryProvider(Protocol):
    """Structural interface for a telemetry backend.

    Any object that defines all three query methods passes
    ``isinstance(obj, TelemetryProvider)``. Each method takes an engagement and
    workstream scope, a start/end datetime window, and optional filters, and
    returns a list of record dicts.
    """

    def query_decisions(
        self,
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime,
        filters: Optional[Dict] = None,
    ) -> List[Dict[str, Any]]:
        """Return decision records for the given scope and time window."""

    def query_spans(
        self,
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime,
        filters: Optional[Dict] = None,
    ) -> List[Dict[str, Any]]:
        """Return span records for the given scope and time window."""

    def query_access_logs(
        self,
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime,
        filters: Optional[Dict] = None,
    ) -> List[Dict[str, Any]]:
        """Return access-log records for the given scope and time window."""
class ComplianceReportGenerator(ABC):
    """Base class for framework-specific compliance report generators.

    Telemetry is read from ``briefcase_client`` when it exposes the matching
    ``query_*`` method (see ``TelemetryProvider``). Otherwise, or when that call
    raises, records previously supplied through the ``ingest_*`` methods are
    filtered locally instead. Subclasses implement :meth:`evaluate`.
    """

    def __init__(self, briefcase_client=None):
        """Store the optional telemetry client and start with empty local buffers."""
        self.client = briefcase_client
        self._decisions: List[Dict[str, Any]] = []
        self._spans: List[Dict[str, Any]] = []
        self._access_logs: List[Dict[str, Any]] = []

    def ingest_decisions(self, decisions: List[Dict[str, Any]]) -> None:
        """Append decision records to the local fallback buffer."""
        self._decisions.extend(decisions)

    def ingest_spans(self, spans: List[Dict[str, Any]]) -> None:
        """Append span records to the local fallback buffer."""
        self._spans.extend(spans)

    def ingest_access_logs(self, logs: List[Dict[str, Any]]) -> None:
        """Append access-log records to the local fallback buffer."""
        self._access_logs.extend(logs)

    def clear_data(self) -> None:
        """Drop all locally ingested decisions, spans and access logs."""
        self._decisions.clear()
        self._spans.clear()
        self._access_logs.clear()

    @abstractmethod
    def evaluate(
        self,
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime
    ) -> ComplianceReport:
        """Evaluate telemetry for the given scope and window and build a report."""
        pass

    def _query_source(
        self,
        method_name: str,
        local_records: List[Dict[str, Any]],
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime,
        filters: Optional[Dict] = None
    ) -> List[Any]:
        """Call ``self.client.<method_name>`` or fall back to filtering ``local_records``.

        The client's result is returned unchanged. If there is no client, the
        client lacks the method, or the call raises, the locally ingested
        ``local_records`` are filtered with :meth:`_filter_records` instead.
        """
        if self.client is not None and hasattr(self.client, method_name):
            try:
                return getattr(self.client, method_name)(
                    engagement_id, workstream_id, start_date, end_date, filters
                )
            except Exception as exc:
                # Deliberate best effort: a failing provider must not abort
                # report generation, so log and use the local data instead.
                logger.warning("TelemetryProvider.%s failed: %s", method_name, exc)
        return self._filter_records(
            local_records, engagement_id, workstream_id,
            start_date, end_date, filters
        )

    def _query_telemetry(
        self,
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime,
        filters: Optional[Dict] = None
    ) -> List[Any]:
        """Return decision records from the client, or from the local buffer as fallback."""
        return self._query_source(
            'query_decisions', self._decisions, engagement_id, workstream_id,
            start_date, end_date, filters
        )

    def _query_spans(
        self,
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime,
        filters: Optional[Dict] = None
    ) -> List[Any]:
        """Return span records from the client, or from the local buffer as fallback."""
        return self._query_source(
            'query_spans', self._spans, engagement_id, workstream_id,
            start_date, end_date, filters
        )

    def _query_access_logs(
        self,
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime,
        filters: Optional[Dict] = None
    ) -> List[Any]:
        """Return access-log records from the client, or from the local buffer as fallback."""
        return self._query_source(
            'query_access_logs', self._access_logs, engagement_id, workstream_id,
            start_date, end_date, filters
        )

    @staticmethod
    def _parse_timestamp(value: Any) -> Any:
        """Parse an ISO-8601 string into a datetime; return anything else unchanged.

        A trailing ``Z``/``z`` (UTC designator) is rewritten to ``+00:00``,
        because ``datetime.fromisoformat`` rejects it before Python 3.11.
        Strings that still fail to parse are returned as-is.
        """
        if not isinstance(value, str):
            return value
        text = value[:-1] + "+00:00" if value.endswith(("Z", "z")) else value
        try:
            return datetime.fromisoformat(text)
        except ValueError:
            return value

    @staticmethod
    def _to_naive_utc(value: datetime) -> datetime:
        """Return *value* as a naive datetime, shifting aware values to UTC first.

        Naive inputs are returned unchanged. This makes naive and aware values
        comparable, under the assumption that naive values are UTC.
        """
        offset = value.utcoffset()
        if offset is None:
            return value.replace(tzinfo=None)
        return (value - offset).replace(tzinfo=None)

    @staticmethod
    def _filter_records(
        records: List[Dict[str, Any]],
        engagement_id: str,
        workstream_id: str,
        start_date: datetime,
        end_date: datetime,
        filters: Optional[Dict] = None,
    ) -> List[Dict[str, Any]]:
        """Return the subset of *records* matching the scope, window and filters.

        - Engagement: taken from ``engagement_id``, falling back to
          ``organization``. A mismatch drops the record only when both the
          requested and the record's value are non-empty.
        - Workstream: taken from ``workstream_id``, falling back to ``branch``,
          with the same non-empty rule.
        - Time window: a ``timestamp`` that is, or parses to, a datetime must
          satisfy ``start_date <= ts <= end_date``. Missing or unparseable
          timestamps are kept. Naive and aware values are compared in UTC
          (naive values are treated as UTC).
        - Filters: every ``filters`` key must equal ``rec.get(key)``.

        Record order is preserved.
        """
        to_utc = ComplianceReportGenerator._to_naive_utc
        result = []
        for rec in records:
            rec_eng = rec.get("engagement_id", rec.get("organization", ""))
            if engagement_id and rec_eng and rec_eng != engagement_id:
                continue
            rec_ws = rec.get("workstream_id", rec.get("branch", ""))
            if workstream_id and rec_ws and rec_ws != workstream_id:
                continue
            rec_ts = ComplianceReportGenerator._parse_timestamp(rec.get("timestamp"))
            if isinstance(rec_ts, datetime):
                # Normalising avoids TypeError when mixing naive/aware datetimes.
                ts = to_utc(rec_ts)
                if ts < to_utc(start_date) or ts > to_utc(end_date):
                    continue
            if filters and any(rec.get(key) != value for key, value in filters.items()):
                continue
            result.append(rec)
        return result

    def _calculate_score(self, results: List[ControlResult]) -> float:
        """Return the unweighted mean of ``score`` over *results* (0.0 when empty)."""
        if not results:
            return 0.0
        total_score = sum(r.score for r in results)
        return total_score / len(results)