import argparse
import json
import os
import sys
def count_openai(model: str, text: str) -> int:
    """Count the tokens in *text* using a tiktoken encoding chosen for *model*.

    The model name is lower-cased before any matching. Names that start
    with ``gpt-5``, ``o1`` or ``o3`` go straight to the ``o200k_base``
    encoding; every other name is resolved via
    ``tiktoken.encoding_for_model``. If that step raises, a substring
    heuristic selects ``o200k_base`` for names containing one of the known
    markers and ``cl100k_base`` for everything else.

    Args:
        model: Model name; matching is case-insensitive.
        text: Text to tokenize.

    Returns:
        The number of token ids produced by encoding *text*.
    """
    # Function-scope import: tiktoken is only required when this path runs.
    import tiktoken

    direct_prefixes = ("gpt-5", "o1", "o3")
    fallback_markers = (
        "gpt-4", "gpt-3", "gpt-4o", "o1", "o2", "o3", "gpt-5", "gpt-oss",
    )

    name = model.lower()
    try:
        if name.startswith(direct_prefixes):
            enc = tiktoken.get_encoding("o200k_base")
        else:
            enc = tiktoken.encoding_for_model(name)
    except Exception:
        # Best-effort fallback: whatever went wrong above, pick an encoding
        # from the name alone instead of failing the whole count.
        matched = any(marker in name for marker in fallback_markers)
        enc = tiktoken.get_encoding("o200k_base" if matched else "cl100k_base")

    return len(enc.encode(text))
def count_hf(repo: str, text: str) -> int:
    """Count the tokens in *text* using the Hugging Face tokenizer for *repo*.

    When the environment variable ``COUNT_TOKENS_HF_OFFLINE`` is exactly
    ``"1"``, the tokenizer is loaded with ``local_files_only=True``; any
    other value (or an unset variable) leaves it ``False``.

    Args:
        repo: Identifier passed to ``AutoTokenizer.from_pretrained``.
        text: Text to tokenize.

    Returns:
        The number of ids produced by encoding *text* with
        ``add_special_tokens=False``.
    """
    # Function-scope import: transformers is only required for this mode.
    from transformers import AutoTokenizer

    offline = os.environ.get("COUNT_TOKENS_HF_OFFLINE", "0") == "1"

    # NOTE(review): trust_remote_code=True lets the tokenizer loader run
    # code shipped with the repo; `repo` comes from the caller, so confirm
    # this is acceptable for every source of that value.
    tokenizer = AutoTokenizer.from_pretrained(
        repo,
        trust_remote_code=True,
        local_files_only=offline,
    )
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return len(token_ids)
def main() -> int:
    """Parse the command line, count tokens, and print the result as JSON.

    Required options are ``--mode`` (``tiktoken`` or ``hf``), ``--model``
    and ``--text``. A single JSON object with the keys ``tokens``, ``mode``
    and ``model`` is written to standard output.

    Returns:
        ``0`` once the result has been printed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", required=True, choices=["tiktoken", "hf"])
    for flag in ("--model", "--text"):
        cli.add_argument(flag, required=True)
    opts = cli.parse_args()

    # Pick the backend by mode; both counters share the (model, text) signature.
    counter = count_openai if opts.mode == "tiktoken" else count_hf
    token_total = counter(opts.model, opts.text)

    payload = {
        "tokens": int(token_total),
        "mode": opts.mode,
        "model": opts.model,
    }
    print(json.dumps(payload))
    return 0
if __name__ == "__main__":
    # Script entry point: exit with main()'s return value on success. Any
    # Exception is reported as a one-line JSON object on stdout and turned
    # into exit status 2. SystemExit (e.g. raised by argparse) is not an
    # Exception subclass, so it passes through untouched.
    try:
        exit_status = main()
    except Exception as err:
        print(json.dumps({"error": str(err)}))
        raise SystemExit(2)
    else:
        raise SystemExit(exit_status)