import yaml, glob, json, re, sys, collections
# Root of the checked-out ATR corpus. Rule files are globbed one directory
# level down, i.e. <ATR_RULES>/<category>/<rule>.yaml.
ATR_RULES = "/tmp/atr/rules"
# Destination of the generated rule pack (YAML).
OUT_PACK = "config/shieldset-atr.yaml"
# Destination of the JSON fixture built from the rules' own test cases.
OUT_FIXTURE = "tests/fixtures/atr_cases.json"

# ATR condition field -> scope name written to the pack's `where:` key.
# A rule whose conditions use a field missing from this table is skipped.
FIELD_SCOPE = {
    "tool_response": "tool_result",
    "content": "tool_result",
    "tool_description": "tool_description",
    "agent_output": "llm_response",
}

# Lower-cased ATR severity -> severity label written to the pack. The two
# informational spellings are folded into "Low".
SEV_MAP = {
    "critical": "Critical",
    "high": "High",
    "medium": "Medium",
    "low": "Low",
    "informational": "Low",
    "info": "Low",
}
# Look-around group openers: (?=  (?!  and their look-behind forms.
LOOKAROUND = re.compile(r"\(\?<?[=!]")
# Numbered backreferences \1 .. \9.
BACKREF = re.compile(r"\\[1-9]")
# Four-hex-digit escape \uXXXX; group 1 holds the hex digits.
UNICODE_4 = re.compile(r"\\u([0-9a-fA-F]{4})")
# A bracketed character class. Backslash escapes inside it are consumed as
# pairs so that an escaped "]" does not terminate the class.
CHAR_CLASS = re.compile(r"\[(?:[^\]\\]|\\.)*\]")

# One backslash escape; group 1 is set only for the \uXXXX form.
_ESCAPE = re.compile(r"\\(?:u([0-9a-fA-F]{4})|.)")
# Tokenizer used by rust_safe(): escapes are tried BEFORE character classes
# so that an escaped "\[" is never mistaken for the start of a class.
_TOKEN = re.compile(r"\\u[0-9a-fA-F]{4}|\\.|" + CHAR_CLASS.pattern)


def rust_safe(pattern: str):
    """Rewrite *pattern* into a form intended for a Rust regex engine.

    Rewrites applied:

    * ``\\uXXXX`` becomes ``\\u{XXXX}``, inside and outside character classes.
    * ``\\b`` is dropped when it appears inside a character class; a word
      boundary outside a class is left untouched.

    Escaped backslashes are honoured, so ``\\\\b`` and ``\\\\u0041`` (a literal
    backslash followed by plain text) pass through unchanged.

    Returns the rewritten pattern, or ``None`` when the pattern cannot be
    translated: it contains a look-around or numbered backreference, or
    dropping ``\\b`` would leave a character class with no members.
    """
    # These two checks are purely textual, so they can also reject a pattern
    # that merely contains the escaped text (e.g. "\\\\1"); that only errs
    # toward skipping a rule.
    if LOOKAROUND.search(pattern) or BACKREF.search(pattern):
        return None

    emptied = []  # classes that lost every member when "\b" was dropped

    def rewrite(token_match):
        token = token_match.group(0)
        in_class = token.startswith("[")

        def fix_escape(esc):
            if esc.group(1) is not None:
                return "\\u{%s}" % esc.group(1)
            if in_class and esc.group(0) == "\\b":
                return ""
            return esc.group(0)

        out = _ESCAPE.sub(fix_escape, token)
        # Only flag classes that WE emptied; a pre-existing "[]" prefix
        # (as in "[]abc]") is passed through as-is.
        if in_class and out != token and out in ("[]", "[^]"):
            emptied.append(token)
        return out

    translated = _TOKEN.sub(rewrite, pattern)
    return None if emptied else translated
# Keys probed, in priority order, for the sample text of an ATR test case.
_CASE_TEXT_KEYS = ("tool_response", "input", "content", "tool_description")

# Fixed comment header and document root of the generated pack.
_PACK_HEADER = (
    "# Aperion Shield rule pack: ATR community rules",
    "#",
    "# Curated, machine-translated subset of the Agent Threat Rules",
    "# corpus (https://github.com/Agent-Threat-Rule/agent-threat-rules,",
    "# MIT license, 'ATR Community'). Selection criteria: regex-only",
    "# detections on fields that map to Shield text scopes, OR ('any')",
    "# condition semantics, confidence >= 75, zero observed wild false-",
    "# positive rate. See NOTICE for attribution.",
    "#",
    "# Load IN ADDITION to the bundled defaults:",
    "# aperion-shield --rules-extra shieldset-atr.yaml -- <upstream...>",
    "#",
    "# Generated by scripts/atr-import.py -- do not hand-edit; regenerate.",
    "",
    "shieldset:",
    "  rules:",
)


def _select_rule(rule, cat):
    """Apply the selection filters to one parsed ATR rule.

    Returns ``(entry, None)`` when the rule is importable, otherwise
    ``(None, reason)`` where ``reason`` is the key to bump in the skip
    counter. Malformed shapes are reported as a skip instead of raising.
    """
    det = rule.get("detection")
    if not isinstance(det, dict):
        det = {}
    conds = det.get("conditions") or []
    if not conds or det.get("condition", "any") != "any":
        return None, "not_any"
    if not isinstance(conds, list) or any(
        not isinstance(c, dict) or c.get("operator") != "regex" for c in conds
    ):
        return None, "non_regex"

    # Every condition must map to the same single scope.
    scopes = set()
    for c in conds:
        field = c.get("field")
        scopes.add(FIELD_SCOPE.get(field) if isinstance(field, str) else None)
    if None in scopes or len(scopes) != 1:
        return None, "unmappable_field"
    scope = scopes.pop()

    conf = rule.get("confidence")
    if not isinstance(conf, (int, float)) or conf < 75:
        return None, "low_confidence"
    fp = rule.get("wild_fp_rate")
    if fp is not None and fp != 0:
        return None, "nonzero_fp"

    patterns = []
    for c in conds:
        raw = c.get("value")
        text = "" if raw is None else str(raw).strip()
        # A missing/blank value would otherwise become the literal pattern
        # "None" or an empty regex that matches every input.
        if not text:
            return None, "empty_pattern"
        translated = rust_safe(text)
        if translated is None:
            return None, "rust_incompatible_regex"
        patterns.append(translated)

    sev = SEV_MAP.get(str(rule.get("severity", "medium")).lower(), "Medium")
    atr_id = str(rule.get("id") or "ATR-UNKNOWN")
    num = atr_id.split("-")[-1]
    tests = rule.get("test_cases")
    return {
        "rule_id": f"atr.{cat.replace('-', '_')}.{num}",
        "atr_id": atr_id,
        "category": cat,
        "scope": scope,
        "severity": sev,
        "confidence": conf,
        "title": rule.get("title") or atr_id,
        "patterns": patterns,
        "tests": tests if isinstance(tests, dict) else {},
    }, None


def _render_pack(kept):
    """Render the selected rules as the text of the YAML rule pack."""
    # NOTE(review): the nesting below (two-space steps, with `reason` and
    # `safer_alternative` as siblings of `match`) is a reconstruction --
    # confirm against the Shield rule schema.
    lines = list(_PACK_HEADER)
    for k in kept:
        # Collapse all whitespace so a newline in the title or id cannot end
        # the comment and inject content into the document.
        note = f"{k['atr_id']}: {k['title']} (confidence {k['confidence']})"
        lines.append("    # " + " ".join(note.split()))
        # json.dumps yields a double-quoted scalar, keeping ids and patterns
        # with special characters intact.
        lines.append(f"    - id: {json.dumps(k['rule_id'])}")
        lines.append(f"      severity: {k['severity']}")
        lines.append(f"      where: {k['scope']}")
        lines.append("      match:")
        lines.append("        text_matches:")
        lines.extend(f"          - {json.dumps(p)}" for p in k["patterns"])
        reason = f"{k['title']} (ATR community rule {k['atr_id']})."
        lines.append(f"      reason: {json.dumps(reason)}")
        lines.append("      safer_alternative: \"Review the flagged content; if this server/tool is trusted, add an allow for this rule id.\"")
        lines.append("")
    return "\n".join(lines)


def _case_text(case):
    """Return the first non-empty sample text of a test case, else None."""
    if not isinstance(case, dict):
        return None
    for key in _CASE_TEXT_KEYS:
        text = case.get(key)
        if text:
            return text
    return None


def _collect_cases(kept):
    """Flatten the kept rules' test cases into fixture records."""
    cases = []
    for k in kept:
        for bucket, expect in (("true_positives", "triggered"),
                               ("true_negatives", "not_triggered")):
            for case in k["tests"].get(bucket) or []:
                text = _case_text(case)
                if text:
                    cases.append({"rule_id": k["rule_id"], "scope": k["scope"],
                                  "text": text, "expect": expect})
    return cases


def main():
    """Import the ATR corpus into a rule pack and a test fixture.

    Reads every ``<ATR_RULES>/<category>/*.yaml`` file, keeps the rules that
    pass the selection filters, writes the pack to ``OUT_PACK`` and the test
    cases to ``OUT_FIXTURE``, then prints a summary of kept/skipped counts.
    """
    kept, skipped = [], collections.Counter()
    for path in sorted(glob.glob(f"{ATR_RULES}/*/*.yaml")):
        cat = path.split("/")[-2]
        try:
            with open(path, encoding="utf-8") as fh:
                rule = yaml.safe_load(fh)
        except Exception:
            # Best-effort import: an unreadable file is counted, not fatal.
            skipped["parse_error"] += 1
            continue
        # An empty file parses to None and a list document to a list;
        # neither is a rule.
        if not isinstance(rule, dict):
            skipped["parse_error"] += 1
            continue
        entry, reason = _select_rule(rule, cat)
        if entry is None:
            skipped[reason] += 1
            continue
        kept.append(entry)

    with open(OUT_PACK, "w", encoding="utf-8") as fh:
        fh.write(_render_pack(kept))

    cases = _collect_cases(kept)
    with open(OUT_FIXTURE, "w", encoding="utf-8") as fh:
        json.dump(cases, fh, indent=1)

    bycat = collections.Counter(k["category"] for k in kept)
    byscope = collections.Counter(k["scope"] for k in kept)
    print(f"kept {len(kept)} rules; {sum(len(k['patterns']) for k in kept)} patterns; {len(cases)} test cases")
    print("by category:", dict(bycat))
    print("by scope:", dict(byscope))
    print("skipped:", dict(skipped))
# Run the import only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()