from __future__ import annotations
import argparse
import json
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@dataclass
class TokenTotals:
    """Summed token counts for the two sides of a comparison.

    ``baseline`` and ``compressed`` are the summed counts for each side, and
    ``rows`` is how many row pairs were included in those sums.
    """

    baseline: int
    compressed: int
    rows: int

    @property
    def saved(self) -> int:
        """Return the baseline total minus the compressed total.

        The result is negative when the compressed side used more tokens.
        """
        difference = self.baseline - self.compressed
        return difference
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSON Lines file into a list of row dictionaries.

    Blank (whitespace-only) lines are skipped. The file is always decoded as
    UTF-8 so results do not depend on the platform's default encoding.

    Args:
        path: Location of the JSONL file to read.

    Returns:
        One dictionary per non-blank line, in file order.

    Raises:
        SystemExit: If a line is not valid JSON, or is valid JSON but not an
            object. The message carries the path and the 1-based line number.
    """
    rows: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as handle:
        for lineno, line in enumerate(handle, 1):
            stripped = line.strip()
            if not stripped:
                continue
            try:
                row = json.loads(stripped)
            except json.JSONDecodeError as exc:
                raise SystemExit(f"{path}:{lineno}: invalid JSON: {exc}") from exc
            # Callers use row.get(...), so anything other than an object would
            # only fail later with a far less helpful AttributeError.
            if not isinstance(row, dict):
                raise SystemExit(
                    f"{path}:{lineno}: expected a JSON object, "
                    f"got {type(row).__name__}"
                )
            rows.append(row)
    return rows
def row_key(row: dict[str, Any]) -> tuple[str, str, str, str]:
    """Build the identity tuple used to match a row with its counterpart.

    The key is made of the ``scenario``, ``run``, ``stream`` and
    ``budget_tokens`` values, each passed through ``str()``. A missing field
    contributes an empty string.
    """
    scenario, run, stream, budget = (
        str(row.get(name, ""))
        for name in ("scenario", "run", "stream", "budget_tokens")
    )
    return scenario, run, stream, budget
def pair_rows(
    baseline: list[dict[str, Any]],
    compressed: list[dict[str, Any]],
) -> tuple[list[tuple[dict[str, Any], dict[str, Any]]], int, int]:
    """Match baseline rows with compressed rows that share the same ``row_key``.

    Each side is grouped by key. Keys are visited in sorted order, and within
    a key the rows are paired positionally in their input order.

    Returns:
        ``(pairs, baseline_unpaired, compressed_unpaired)``, where the two
        counters are the rows left over on each side after pairing.
    """
    grouped_baseline: dict[
        tuple[str, str, str, str], list[dict[str, Any]]
    ] = defaultdict(list)
    grouped_compressed: dict[
        tuple[str, str, str, str], list[dict[str, Any]]
    ] = defaultdict(list)
    for groups, rows in (
        (grouped_baseline, baseline),
        (grouped_compressed, compressed),
    ):
        for row in rows:
            groups[row_key(row)].append(row)

    matched: list[tuple[dict[str, Any], dict[str, Any]]] = []
    leftover_baseline = 0
    leftover_compressed = 0
    for key in sorted(grouped_baseline.keys() | grouped_compressed.keys()):
        left = grouped_baseline.get(key, [])
        right = grouped_compressed.get(key, [])
        shared = min(len(left), len(right))
        # zip() stops at the shorter side, so exactly `shared` pairs are added.
        matched.extend(zip(left, right))
        leftover_baseline += len(left) - shared
        leftover_compressed += len(right) - shared
    return matched, leftover_baseline, leftover_compressed
def as_int(value: Any) -> int | None:
    """Return ``value`` as an integer, or ``None`` if it is not integral.

    Integers are returned unchanged and floats with no fractional part are
    converted with ``int()``. Booleans and every other type yield ``None``.
    """
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return None
def token_totals(
    pairs: list[tuple[dict[str, Any], dict[str, Any]]],
    *fields: str,
) -> TokenTotals:
    """Sum the named token fields across paired rows.

    A pair contributes only when every requested field converts to an integer
    via ``as_int`` on both the baseline and the compressed row. Otherwise the
    whole pair is skipped.

    Args:
        pairs: ``(baseline_row, compressed_row)`` tuples.
        *fields: Names of the row fields to add together.

    Returns:
        The summed baseline and compressed counts, plus the number of pairs
        that contributed.
    """
    sum_baseline = 0
    sum_compressed = 0
    counted = 0
    for base_row, comp_row in pairs:
        base_counts = [as_int(base_row.get(name)) for name in fields]
        comp_counts = [as_int(comp_row.get(name)) for name in fields]
        usable_base = [count for count in base_counts if count is not None]
        usable_comp = [count for count in comp_counts if count is not None]
        # A length mismatch means at least one field was missing or non-integral.
        if len(usable_base) != len(base_counts) or len(usable_comp) != len(comp_counts):
            continue
        sum_baseline += sum(usable_base)
        sum_compressed += sum(usable_comp)
        counted += 1
    return TokenTotals(sum_baseline, sum_compressed, counted)
def count_true(rows: list[dict[str, Any]], field: str) -> int:
    """Count the rows whose ``field`` value is truthy; a missing field counts as false."""
    truthy_rows = [row for row in rows if row.get(field)]
    return len(truthy_rows)
def count_accuracy_false(rows: list[dict[str, Any]]) -> int:
    """Count the rows whose ``accuracy`` value is exactly ``False``.

    The check is by identity, so missing keys, ``None`` and falsy non-bool
    values such as ``0`` are not counted.
    """
    total = 0
    for row in rows:
        if row.get("accuracy") is False:
            total += 1
    return total
def pct(saved: int, baseline: int) -> str:
    """Format ``saved`` as a percentage of ``baseline`` with one decimal place.

    Returns ``"n/a"`` when ``baseline`` is zero or negative.
    """
    if baseline > 0:
        percentage = saved / baseline * 100
        return f"{percentage:.1f}%"
    return "n/a"
def print_token_line(label: str, totals: TokenTotals) -> None:
    """Print one summary line for a token category.

    When no paired rows contributed (``totals.rows == 0``) an "unavailable"
    line is printed. Otherwise the line shows the baseline, compressed and
    saved counts, the savings percentage, and the contributing row count.
    """
    if totals.rows == 0:
        message = f" {label}: unavailable; no paired rows reported usage"
    else:
        share = pct(totals.saved, totals.baseline)
        message = (
            f" {label}: baseline {totals.baseline}, compressed {totals.compressed}, "
            f"saved {totals.saved} ({share}), rows {totals.rows}"
        )
    print(message)
def scenario_savings(
    pairs: list[tuple[dict[str, Any], dict[str, Any]]],
) -> list[tuple[str, TokenTotals, int, int, int]]:
    """Summarise input-token savings and success counts for each scenario.

    Pairs are grouped by the baseline row's ``scenario`` value, with
    ``"<unknown>"`` used when the key is absent. Scenarios are reported in
    sorted name order.

    Returns:
        One ``(scenario, input_token_totals, baseline_success_count,
        compressed_success_count, pair_count)`` tuple per scenario.
    """
    grouped: dict[
        str, list[tuple[dict[str, Any], dict[str, Any]]]
    ] = defaultdict(list)
    for base_row, comp_row in pairs:
        name = str(base_row.get("scenario", "<unknown>"))
        grouped[name].append((base_row, comp_row))

    summary = []
    for name in sorted(grouped):
        members = grouped[name]
        base_side = [base_row for base_row, _ in members]
        comp_side = [comp_row for _, comp_row in members]
        entry = (
            name,
            token_totals(members, "input_tokens"),
            count_true(base_side, "success"),
            count_true(comp_side, "success"),
            len(members),
        )
        summary.append(entry)
    return summary
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the comparison command line."""
    parser = argparse.ArgumentParser(
        description="Compare disabled and compressed local proxy eval JSONL outputs"
    )
    parser.add_argument("baseline_jsonl", type=Path)
    parser.add_argument("compressed_jsonl", type=Path)
    parser.add_argument(
        "--min-input-token-savings",
        type=int,
        default=0,
        help="Minimum aggregate prompt-token savings required for success",
    )
    parser.add_argument(
        "--allow-behavior-regression",
        action="store_true",
        help="Report behavior regressions without failing",
    )
    return parser


def _print_scenario_savings(
    pairs: list[tuple[dict[str, Any], dict[str, Any]]],
) -> None:
    """Print one input-savings line per scenario; print nothing without pairs."""
    scenario_rows = scenario_savings(pairs)
    if not scenario_rows:
        return
    print(" Per-scenario input savings:")
    for scenario, totals, baseline_s, compressed_s, pair_count in scenario_rows:
        if totals.rows == 0:
            savings = "usage unavailable"
        else:
            savings = f"saved {totals.saved} ({pct(totals.saved, totals.baseline)})"
        print(
            f" {scenario}: {savings}, "
            f"success {baseline_s}/{pair_count} -> "
            f"{compressed_s}/{pair_count}"
        )


def main(argv: list[str] | None = None) -> int:
    """Compare a baseline and a compressed eval run and print a summary.

    Rows from the two JSONL files are paired, then token usage and the
    ``success`` / ``completeness`` / ``accuracy`` counts are compared across
    the paired rows. The summary and any warnings go to stdout; failures go
    to stderr.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        ``0`` when the comparison passes, ``1`` when any failure was recorded.

    Raises:
        SystemExit: Raised by argparse for invalid arguments, or by
            ``load_jsonl`` when an input file cannot be parsed.
    """
    args = _build_parser().parse_args(argv)

    baseline_rows = load_jsonl(args.baseline_jsonl)
    compressed_rows = load_jsonl(args.compressed_jsonl)
    pairs, baseline_unpaired, compressed_unpaired = pair_rows(
        baseline_rows,
        compressed_rows,
    )
    pair_total = len(pairs)
    baseline_paired = [baseline for baseline, _ in pairs]
    compressed_paired = [compressed for _, compressed in pairs]

    input_totals = token_totals(pairs, "input_tokens")
    output_totals = token_totals(pairs, "output_tokens")
    total_totals = token_totals(pairs, "input_tokens", "output_tokens")

    baseline_success = count_true(baseline_paired, "success")
    compressed_success = count_true(compressed_paired, "success")
    baseline_complete = count_true(baseline_paired, "completeness")
    compressed_complete = count_true(compressed_paired, "completeness")
    baseline_accuracy_false = count_accuracy_false(baseline_paired)
    compressed_accuracy_false = count_accuracy_false(compressed_paired)

    failures: list[str] = []
    warnings: list[str] = []

    if not pairs:
        failures.append("no comparable rows found")
    if baseline_unpaired:
        warnings.append(f"{baseline_unpaired} baseline rows were not paired")
    if compressed_unpaired:
        warnings.append(f"{compressed_unpaired} compressed rows were not paired")

    # Regressions are always detected; the flag only decides whether they fail
    # the run or are merely reported, as the --allow-behavior-regression help
    # text promises.
    regressions: list[str] = []
    if compressed_success < baseline_success:
        regressions.append(
            f"success regressed from {baseline_success}/{pair_total} "
            f"to {compressed_success}/{pair_total}"
        )
    if compressed_complete < baseline_complete:
        regressions.append(
            f"completeness regressed from {baseline_complete}/{pair_total} "
            f"to {compressed_complete}/{pair_total}"
        )
    if compressed_accuracy_false > baseline_accuracy_false:
        regressions.append(
            f"accuracy_false increased from {baseline_accuracy_false} "
            f"to {compressed_accuracy_false}"
        )
    if args.allow_behavior_regression:
        warnings.extend(
            f"{regression} (allowed by --allow-behavior-regression)"
            for regression in regressions
        )
    else:
        failures.extend(regressions)

    if input_totals.rows == 0 and args.min_input_token_savings > 0:
        failures.append(
            "input token savings unavailable because no paired rows reported usage"
        )
    elif input_totals.saved < args.min_input_token_savings:
        failures.append(
            f"input token savings {input_totals.saved} below required "
            f"{args.min_input_token_savings}"
        )

    print("Compression Eval Summary")
    print(f" Baseline: {args.baseline_jsonl}")
    print(f" Compressed: {args.compressed_jsonl}")
    print(
        f" Rows: baseline {len(baseline_rows)}, compressed {len(compressed_rows)}, "
        f"paired {pair_total}"
    )
    print(
        f" Success: baseline {baseline_success}/{pair_total}, "
        f"compressed {compressed_success}/{pair_total}"
    )
    print(
        f" Completeness: baseline {baseline_complete}/{pair_total}, "
        f"compressed {compressed_complete}/{pair_total}"
    )
    print(
        f" Accuracy false: baseline {baseline_accuracy_false}, "
        f"compressed {compressed_accuracy_false}"
    )
    print_token_line("Input tokens", input_totals)
    print_token_line("Output tokens", output_totals)
    print_token_line("Total tokens", total_totals)
    _print_scenario_savings(pairs)

    if warnings:
        print("\nWarnings:")
        for warning in warnings:
            print(f" - {warning}")

    if failures:
        # Flush stdout before writing to stderr so the summary is emitted
        # ahead of the failure list.
        sys.stdout.flush()
        print("\nFailures:", file=sys.stderr)
        for failure in failures:
            print(f" - {failure}", file=sys.stderr)
        return 1

    print("\nCompression comparison passed.")
    return 0
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())