from datetime import datetime
from typing import List
from briefcase.compliance.reports.base import (
ComplianceReportGenerator,
ComplianceReport,
ComplianceStatus,
ControlResult,
Violation,
ViolationSeverity
)
class SOC2ReportGenerator(ComplianceReportGenerator):
    """Generates SOC2 Type II compliance reports from telemetry data.

    Evaluates the Security (CC6.x), Availability (A1.x), and Processing
    Integrity (PI1.x) control families against decision/span telemetry
    for a given engagement and reporting period.
    """

    # Thresholds applied by the individual control checks.
    MAX_UNAUTHORIZED_ATTEMPTS = 0       # CC6.1: any unauthorized access is a violation
    MIN_ENCRYPTION_PERCENT = 100.0      # encryption target (not referenced by visible checks — TODO confirm)
    MIN_TRACE_COMPLETENESS = 95.0       # CC6.7: minimum % of decisions with spans
    MAX_ERROR_RATE_PERCENT = 5.0        # A1.1: maximum tolerated decision error rate (%)
    MIN_DECISION_AUDIT_PERCENT = 99.0   # PI1.1: minimum % of fully audited decisions
    MAX_UNENCRYPTED_OBJECTS = 0         # CC6.2: any unencrypted transfer is a violation
def evaluate(
self,
engagement_id: str,
workstream_id: str,
start_date: datetime,
end_date: datetime
) -> ComplianceReport:
report = ComplianceReport(
framework="SOC2 Type II",
organization=engagement_id,
report_period_start=start_date,
report_period_end=end_date,
evaluation_date=datetime.now(),
overall_status=ComplianceStatus.COMPLIANT,
overall_score=0.0
)
all_decisions = self._query_telemetry(
engagement_id, workstream_id, start_date, end_date
)
all_spans = self._query_spans(
engagement_id, workstream_id, start_date, end_date
)
report.total_decisions = len(all_decisions)
report.total_spans = len(all_spans)
results = []
results.extend(self._evaluate_security(
engagement_id, workstream_id, start_date, end_date
))
results.extend(self._evaluate_availability(
engagement_id, workstream_id, start_date, end_date
))
results.extend(self._evaluate_processing_integrity(
engagement_id, workstream_id, start_date, end_date
))
report.control_results = results
report.total_controls_evaluated = len(results)
report.controls_passed = sum(
1 for r in results if r.status == ComplianceStatus.COMPLIANT
)
report.controls_failed = sum(
1 for r in results if r.status == ComplianceStatus.NON_COMPLIANT
)
report.controls_partial = sum(
1 for r in results if r.status == ComplianceStatus.PARTIAL
)
report.violations = []
for result in results:
report.violations.extend(result.violations)
report.overall_score = self._calculate_score(results)
if report.total_decisions > 0:
decisions_with_trail = sum(
1 for d in all_decisions
if d.get("has_audit_trail", d.get("outputs") is not None)
)
report.telemetry_completeness = (
(decisions_with_trail / report.total_decisions) * 100
)
if report.controls_failed > 0:
report.overall_status = ComplianceStatus.NON_COMPLIANT
elif report.controls_partial > 0:
report.overall_status = ComplianceStatus.PARTIAL
else:
report.overall_status = ComplianceStatus.COMPLIANT
return report
def _evaluate_security(
self, engagement_id, workstream_id, start_date, end_date
) -> List[ControlResult]:
results = []
results.append(self._check_access_controls(
engagement_id, workstream_id, start_date, end_date
))
results.append(self._check_encryption(
engagement_id, workstream_id, start_date, end_date
))
results.append(self._check_logging(
engagement_id, workstream_id, start_date, end_date
))
return results
def _check_access_controls(
self, engagement_id, workstream_id, start_date, end_date
) -> ControlResult:
unauthorized_attempts = self._query_access_logs(
engagement_id, workstream_id, start_date, end_date,
filters={"authorized": False}
)
total_access = self._query_access_logs(
engagement_id, workstream_id, start_date, end_date
)
if len(unauthorized_attempts) <= self.MAX_UNAUTHORIZED_ATTEMPTS:
return ControlResult(
control_id="CC6.1",
control_name="Logical and Physical Access Controls",
status=ComplianceStatus.COMPLIANT,
score=100.0,
evidence=[
f"{len(unauthorized_attempts)} unauthorized access attempts "
f"in period (threshold: {self.MAX_UNAUTHORIZED_ATTEMPTS})",
f"{len(total_access)} total access events audited",
]
)
else:
violation = Violation(
control_id="CC6.1",
severity=ViolationSeverity.CRITICAL,
message=(
f"{len(unauthorized_attempts)} unauthorized access attempts "
f"detected (max allowed: {self.MAX_UNAUTHORIZED_ATTEMPTS})"
),
affected_items=[
str(a.get("id", a.get("timestamp", "unknown")))
for a in unauthorized_attempts[:10]
],
remediation=(
"Review access logs and revoke compromised credentials. "
"Enforce MFA for all admin access. Rotate API keys."
)
)
if len(unauthorized_attempts) < 5:
score = max(0.0, 100.0 - len(unauthorized_attempts) * 20)
status = ComplianceStatus.PARTIAL
else:
score = 0.0
status = ComplianceStatus.NON_COMPLIANT
return ControlResult(
control_id="CC6.1",
control_name="Logical and Physical Access Controls",
status=status,
score=score,
violations=[violation],
recommendations=[
"Enable MFA for all admin access",
"Rotate API keys every 90 days",
"Implement IP allowlisting"
]
)
def _check_encryption(
self, engagement_id, workstream_id, start_date, end_date
) -> ControlResult:
unencrypted = self._query_spans(
engagement_id, workstream_id, start_date, end_date,
filters={"encrypted": False}
)
all_transfers = self._query_spans(
engagement_id, workstream_id, start_date, end_date,
filters={"type": "data_transfer"}
)
total_transfers = max(len(all_transfers), 1) encrypted_percent = (
((total_transfers - len(unencrypted)) / total_transfers) * 100
)
if len(unencrypted) <= self.MAX_UNENCRYPTED_OBJECTS:
return ControlResult(
control_id="CC6.2",
control_name="Encryption of Data",
status=ComplianceStatus.COMPLIANT,
score=100.0,
evidence=[
f"{encrypted_percent:.1f}% of transfers encrypted (TLS 1.3)",
"Data at rest encrypted (AES-256)",
f"{len(unencrypted)} unencrypted objects (max: "
f"{self.MAX_UNENCRYPTED_OBJECTS})"
]
)
else:
violation = Violation(
control_id="CC6.2",
severity=ViolationSeverity.MAJOR,
message=(
f"{len(unencrypted)} unencrypted data transfers detected"
),
affected_items=[
str(u.get("id", u.get("path", "unknown")))
for u in unencrypted[:10]
],
remediation="Enable TLS for all data in transit. "
"Enable S3 server-side encryption for data at rest."
)
return ControlResult(
control_id="CC6.2",
control_name="Encryption of Data",
status=ComplianceStatus.NON_COMPLIANT,
score=encrypted_percent,
violations=[violation],
recommendations=[
"Enforce TLS 1.3 for all API endpoints",
"Enable server-side encryption for LakeFS storage"
]
)
def _check_logging(
self, engagement_id, workstream_id, start_date, end_date
) -> ControlResult:
all_spans = self._query_spans(
engagement_id, workstream_id, start_date, end_date
)
all_decisions = self._query_telemetry(
engagement_id, workstream_id, start_date, end_date
)
decisions_with_spans = set()
for span in all_spans:
decision_id = span.get("decision_id", span.get("trace_id"))
if decision_id:
decisions_with_spans.add(decision_id)
total_decisions = max(len(all_decisions), 1)
completeness = (len(decisions_with_spans) / total_decisions) * 100
if completeness >= self.MIN_TRACE_COMPLETENESS:
return ControlResult(
control_id="CC6.7",
control_name="Logging and Monitoring",
status=ComplianceStatus.COMPLIANT,
score=min(completeness, 100.0),
evidence=[
f"Trace completeness: {completeness:.1f}% "
f"(threshold: {self.MIN_TRACE_COMPLETENESS}%)",
f"{len(all_decisions)} decisions audited",
f"{len(all_spans)} spans recorded"
]
)
elif completeness >= self.MIN_TRACE_COMPLETENESS * 0.9:
return ControlResult(
control_id="CC6.7",
control_name="Logging and Monitoring",
status=ComplianceStatus.PARTIAL,
score=completeness,
evidence=[
f"Trace completeness: {completeness:.1f}% "
f"(threshold: {self.MIN_TRACE_COMPLETENESS}%)"
],
recommendations=[
"Increase OpenTelemetry instrumentation coverage",
"Add span recording for all decision points"
]
)
else:
violation = Violation(
control_id="CC6.7",
severity=ViolationSeverity.MAJOR,
message=(
f"Trace completeness {completeness:.1f}% is below "
f"threshold of {self.MIN_TRACE_COMPLETENESS}%"
),
remediation=(
"Enable comprehensive OpenTelemetry instrumentation. "
"Ensure every decision point emits at least one span."
)
)
return ControlResult(
control_id="CC6.7",
control_name="Logging and Monitoring",
status=ComplianceStatus.NON_COMPLIANT,
score=completeness,
violations=[violation]
)
def _evaluate_availability(
self, engagement_id, workstream_id, start_date, end_date
) -> List[ControlResult]:
results = []
results.append(self._check_error_rate(
engagement_id, workstream_id, start_date, end_date
))
return results
def _check_error_rate(
self, engagement_id, workstream_id, start_date, end_date
) -> ControlResult:
all_decisions = self._query_telemetry(
engagement_id, workstream_id, start_date, end_date
)
errored = [d for d in all_decisions if d.get("error") is not None]
total = max(len(all_decisions), 1)
error_rate = (len(errored) / total) * 100
if error_rate <= self.MAX_ERROR_RATE_PERCENT:
return ControlResult(
control_id="A1.1",
control_name="System Availability",
status=ComplianceStatus.COMPLIANT,
score=100.0 - error_rate,
evidence=[
f"Error rate: {error_rate:.2f}% "
f"(max: {self.MAX_ERROR_RATE_PERCENT}%)",
f"{len(all_decisions)} total decisions, "
f"{len(errored)} errors"
]
)
else:
violation = Violation(
control_id="A1.1",
severity=ViolationSeverity.MAJOR,
message=(
f"Error rate {error_rate:.2f}% exceeds threshold "
f"of {self.MAX_ERROR_RATE_PERCENT}%"
),
affected_items=[
str(e.get("id", e.get("function_name", "unknown")))
for e in errored[:10]
],
remediation="Investigate error root causes. "
"Implement retry logic and circuit breakers."
)
return ControlResult(
control_id="A1.1",
control_name="System Availability",
status=ComplianceStatus.NON_COMPLIANT,
score=max(0.0, 100.0 - error_rate),
violations=[violation]
)
def _evaluate_processing_integrity(
self, engagement_id, workstream_id, start_date, end_date
) -> List[ControlResult]:
results = []
results.append(self._check_decision_audit_trail(
engagement_id, workstream_id, start_date, end_date
))
return results
def _check_decision_audit_trail(
self, engagement_id, workstream_id, start_date, end_date
) -> ControlResult:
all_decisions = self._query_telemetry(
engagement_id, workstream_id, start_date, end_date
)
if not all_decisions:
return ControlResult(
control_id="PI1.1",
control_name="Decision Audit Trail",
status=ComplianceStatus.NEEDS_REVIEW,
score=0.0,
evidence=["No decision data available for evaluation"],
recommendations=[
"Ensure Briefcase SDK is capturing decisions",
"Verify telemetry pipeline is operational"
]
)
complete = 0
incomplete_ids = []
for d in all_decisions:
has_inputs = bool(d.get("inputs"))
has_outputs = bool(d.get("outputs"))
has_model = bool(d.get("model_parameters") or d.get("model_name"))
if has_inputs and has_outputs and has_model:
complete += 1
else:
incomplete_ids.append(
str(d.get("id", d.get("function_name", "unknown")))
)
total = len(all_decisions)
completeness = (complete / total) * 100
if completeness >= self.MIN_DECISION_AUDIT_PERCENT:
return ControlResult(
control_id="PI1.1",
control_name="Decision Audit Trail",
status=ComplianceStatus.COMPLIANT,
score=completeness,
evidence=[
f"{completeness:.1f}% of decisions have complete "
f"audit trails (threshold: "
f"{self.MIN_DECISION_AUDIT_PERCENT}%)",
f"{complete}/{total} decisions fully audited"
]
)
elif completeness >= self.MIN_DECISION_AUDIT_PERCENT * 0.9:
return ControlResult(
control_id="PI1.1",
control_name="Decision Audit Trail",
status=ComplianceStatus.PARTIAL,
score=completeness,
evidence=[
f"{completeness:.1f}% completeness "
f"(threshold: {self.MIN_DECISION_AUDIT_PERCENT}%)"
],
recommendations=[
"Ensure all decision points capture inputs, outputs, "
"and model parameters",
f"{len(incomplete_ids)} decisions missing data"
]
)
else:
violation = Violation(
control_id="PI1.1",
severity=ViolationSeverity.CRITICAL,
message=(
f"Decision audit completeness {completeness:.1f}% "
f"is below threshold of "
f"{self.MIN_DECISION_AUDIT_PERCENT}%"
),
affected_items=incomplete_ids[:10],
remediation=(
"Add Briefcase decorators to all AI decision functions. "
"Ensure inputs, outputs, and model_parameters are captured."
)
)
return ControlResult(
control_id="PI1.1",
control_name="Decision Audit Trail",
status=ComplianceStatus.NON_COMPLIANT,
score=completeness,
violations=[violation]
)