from __future__ import annotations
import argparse, datetime, json, sys, urllib.request, pathlib
# Upstream source of truth: LiteLLM's bundled backup copy of its per-model
# pricing + context-window registry (raw JSON on GitHub).
LITELLM_URL = (
    "https://raw.githubusercontent.com/BerriAI/litellm/main/"
    "litellm/model_prices_and_context_window_backup.json"
)
# Provider routing prefixes removed from raw model keys by strip_provider().
# strip_provider tries these in order and does NOT stop at the first match,
# so the more specific entries (e.g. "bedrock/anthropic.") must be listed
# before their catch-all (e.g. "bedrock/").
PROVIDER_PREFIXES = (
    "openrouter/", "vertex_ai/", "vertex_ai-language-models/",
    "vertex_ai-anthropic_models/", "vertex_ai-mistral_models/",
    "bedrock/anthropic.", "bedrock/amazon.", "bedrock/cohere.",
    "bedrock/meta.", "bedrock/", "azure/",
    "anthropic/", "openai/", "gemini/", "groq/", "mistral/",
    "deepseek/", "fireworks_ai/", "together_ai/", "perplexity/",
    "cohere/", "replicate/", "ai21/", "xai/",
)
# Entry "mode" values that are kept; parse_litellm also keeps entries that
# have no "mode" field at all.
KEEP_MODES = {"chat", "completion", "responses"}
# Keys with these prefixes are skipped entirely in parse_litellm —
# presumably local/self-hosted runtimes without meaningful API pricing.
LOCAL_PREFIXES = ("ollama/", "ollama_chat/", "huggingface/", "vllm/")
def strip_provider(key: str) -> str:
    """Return *key* with every matching PROVIDER_PREFIXES entry removed.

    Prefixes are tried in declaration order and stripping does not stop at
    the first hit, so a doubly-routed key such as "openrouter/openai/gpt-4"
    is reduced all the way to the bare model name.
    """
    name = key
    for prefix in PROVIDER_PREFIXES:
        if name.startswith(prefix):
            name = name[len(prefix):]
    return name
def parse_litellm(raw: dict) -> list[tuple[str, float, float, int | None]]:
    """Convert LiteLLM's raw registry into sorted pricing rows.

    Returns ``(key, input_$_per_mtok, output_$_per_mtok, max_ctx_or_None)``
    tuples, sorted by key.  Per-token costs are scaled to per-million-token
    dollars.  Keys are stripped of provider prefixes; when several raw
    entries collapse onto one key, the most expensive pricing wins and the
    largest known context window is kept.  Skipped entries: the
    "sample_spec" placeholder, non-dict values, local-runtime prefixes,
    unknown modes, unpriced entries, and entries free on both sides.
    """
    best: dict[str, tuple[float, float, int | None]] = {}
    for raw_key, entry in raw.items():
        if raw_key == "sample_spec" or not isinstance(entry, dict):
            continue
        if raw_key.startswith(LOCAL_PREFIXES):
            continue
        # Entries with no "mode" field pass through; known modes only otherwise.
        mode = entry.get("mode", "")
        if mode and mode not in KEEP_MODES:
            continue
        in_cost = entry.get("input_cost_per_token")
        out_cost = entry.get("output_cost_per_token")
        if in_cost is None or out_cost is None:
            continue
        try:
            price_in = float(in_cost) * 1_000_000.0
            price_out = float(out_cost) * 1_000_000.0
        except (TypeError, ValueError):
            continue
        if price_in <= 0.0 and price_out <= 0.0:
            continue  # free on both sides — nothing worth recording
        # Context window: first non-falsy of three possible field names.
        raw_ctx = (
            entry.get("max_input_tokens")
            or entry.get("max_tokens")
            or entry.get("max_context_tokens")
        )
        try:
            ctx = int(raw_ctx) if raw_ctx else None
        except (TypeError, ValueError):
            ctx = None
        else:
            if ctx is not None and ctx <= 0:
                ctx = None
        key = strip_provider(raw_key)
        prev = best.get(key)
        if prev is None:
            best[key] = (price_in, price_out, ctx)
        elif price_in + price_out > prev[0] + prev[1]:
            # Pricier duplicate replaces the old prices; keep the larger of
            # the two known contexts (or None when neither is known).
            best[key] = (price_in, price_out, max(ctx or 0, prev[2] or 0) or None)
        elif ctx and (prev[2] is None or ctx > prev[2]):
            # Cheaper-or-equal duplicate with a bigger context: upgrade only
            # the context, keep the existing prices.
            best[key] = (prev[0], prev[1], ctx)
    return [(k, i, o, c) for k, (i, o, c) in sorted(best.items())]
def render_rust(rows: list[tuple[str, float, float, int | None]], date: str) -> str:
lines = [
"// AUTO-GENERATED by scripts/sync_prices.py — do not edit by hand.",
"// Source: LiteLLM model_prices_and_context_window_backup.json",
"// Regenerate via: python3 scripts/sync_prices.py",
"",
"use crate::pricing::ModelPrice;",
"",
f'pub const PRICES_UPDATED: &str = "{date}";',
f'pub const PRICES_SOURCE: &str = "litellm community registry";',
"",
f"pub const GENERATED: &[(&str, ModelPrice)] = &[",
]
for k, i, o, ctx in rows:
safe = k.replace("\\", "\\\\").replace('"', '\\"')
ctx_lit = f"Some({ctx})" if ctx else "None"
lines.append(
f' ("{safe}", ModelPrice {{ '
f'input_per_mtok: {i:.6}, output_per_mtok: {o:.6}, '
f'max_input_tokens: {ctx_lit} }}),'
)
lines.append("];")
lines.append("")
return "\n".join(lines)
def main() -> int:
    """Fetch LiteLLM's registry and (re)generate src/pricing_data.rs.

    Exit codes: 0 = file written (or --check found no drift), 1 = --check
    detected drift, 2 = upstream parsed to zero rows (refused to overwrite).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--check", action="store_true",
                    help="Exit 1 if the generated file would change.")
    args = ap.parse_args()

    here = pathlib.Path(__file__).resolve().parent
    out = here.parent / "src" / "pricing_data.rs"

    with urllib.request.urlopen(LITELLM_URL, timeout=30) as r:
        raw = json.load(r)

    rows = parse_litellm(raw)
    if not rows:
        # Zero rows almost certainly means the upstream schema moved;
        # never clobber the existing generated file with an empty table.
        print("ERROR: parsed zero rows from LiteLLM — refusing to overwrite",
              file=sys.stderr)
        return 2

    today = datetime.date.today().isoformat()
    rendered = render_rust(rows, today)

    if args.check:
        # Explicit UTF-8: the generated header contains non-ASCII (em dash)
        # and the platform locale encoding is not guaranteed to be UTF-8,
        # so relying on the default could mis-read or fail on some systems.
        existing = out.read_text(encoding="utf-8") if out.exists() else ""

        def strip_date(t: str) -> str:
            # The PRICES_UPDATED line changes every run; ignore it when
            # diffing so a date-only change does not count as drift.
            return "\n".join(
                l for l in t.splitlines()
                if not l.startswith("pub const PRICES_UPDATED:")
            )

        if strip_date(existing) != strip_date(rendered):
            print(f"prices drifted: {len(rows)} models in upstream, "
                  f"{out} would change")
            return 1
        print(f"prices fresh: {len(rows)} models, no change")
        return 0

    # Same encoding rationale as the --check read above.
    out.write_text(rendered, encoding="utf-8")
    print(f"wrote {out} — {len(rows)} models @ {today}")
    return 0
if __name__ == "__main__":
sys.exit(main())