from __future__ import annotations
import json
import os
import statistics
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Callable, Iterable
# --- Filesystem layout -------------------------------------------------------
# Everything the benchmark downloads or builds lives under WORK_DIR, next to
# this script.
THIS_DIR = Path(__file__).parent
WORK_DIR = THIS_DIR.joinpath(".simstring_cpp")
SRC_DIR = WORK_DIR.joinpath("src")
INSTALL_DIR = WORK_DIR.joinpath("install")
SIMSTRING_BIN = INSTALL_DIR.joinpath("bin", "simstring")

# Upstream repository cloned when the binary has not been built yet.
REPO_URL = "https://github.com/chokkan/simstring.git"

# Input corpus: one company name per line.
COMPANY_FILE = THIS_DIR.joinpath("data", "company_names.txt")

# --- Benchmark tuning (overridable through the environment) ------------------
# Time budget, in seconds, for each benchmark configuration.
MEASUREMENT_TIME = int(os.getenv("SIMSTRING_BENCH_DURATION", "20"))
# Upper bound on the number of timed runs for each benchmark configuration.
MAX_ITERATIONS = int(os.getenv("SIMSTRING_BENCH_MAX_ITERATIONS", "100"))

# Parameter grid swept by the benchmarks.
NGRAM_SIZES = (2, 3, 4)
THRESHOLDS = (0.6, 0.7, 0.8, 0.9)

# Any non-empty value of SIMSTRING_BENCH_DEBUG enables debug mode, in which
# helper commands do not have their output captured.
DEBUG = bool(os.getenv("SIMSTRING_BENCH_DEBUG"))
def run_command(cmd: list[str], cwd: Path | None = None) -> None:
    """Run an external command and raise if it exits with a non-zero status.

    Output is captured unless ``DEBUG`` is set; in debug mode the child
    inherits this process's stdout/stderr so its output is shown live.

    Args:
        cmd: Program and arguments, passed to ``subprocess.run`` as a list
            (no shell is involved).
        cwd: Working directory for the child process, or ``None`` to inherit
            the current one.

    Raises:
        RuntimeError: If the command exits with a non-zero return code. The
            message contains the captured stdout/stderr; both are empty in
            debug mode because nothing was captured.
    """
    proc = subprocess.run(
        cmd, cwd=cwd, check=False, text=True, capture_output=not DEBUG
    )
    if proc.returncode == 0:
        return

    # ``proc.stdout`` / ``proc.stderr`` are None when output was not captured.
    stdout = proc.stdout or ""
    stderr = proc.stderr or ""
    raise RuntimeError(
        f"Command {' '.join(cmd)} failed with code {proc.returncode}\n"
        f"stdout:\n{stdout}\n\nstderr:\n{stderr}"
    )
def ensure_simstring_binary() -> Path:
    """Return the path to the simstring CLI, building it first if needed.

    When ``SIMSTRING_BIN`` already exists nothing is done. Otherwise the
    sources are cloned (if absent), ``autogen.sh`` and ``./configure`` are run
    when their outputs are missing, and ``make`` / ``make install`` are
    invoked with ``INSTALL_DIR`` as the prefix.

    Returns:
        ``SIMSTRING_BIN``.

    Raises:
        RuntimeError: Propagated from ``run_command`` when any build step
            fails.
    """
    if SIMSTRING_BIN.exists():
        return SIMSTRING_BIN

    WORK_DIR.mkdir(parents=True, exist_ok=True)

    # Fetch the sources only once; a shallow clone is enough for building.
    source_missing = not SRC_DIR.exists()
    if source_missing:
        clone_cmd = ["git", "clone", "--depth", "1", REPO_URL, str(SRC_DIR)]
        run_command(clone_cmd)

    # Generate ``configure`` only if the repository ships an autogen script
    # and ``configure`` has not been generated yet.
    autogen_script = SRC_DIR / "autogen.sh"
    configure_script = SRC_DIR / "configure"
    if autogen_script.exists() and not configure_script.exists():
        run_command(["/bin/sh", str(autogen_script)], cwd=SRC_DIR)

    # A Makefile already present means configure has been run before.
    makefile = SRC_DIR / "Makefile"
    if not makefile.exists():
        run_command(["./configure", f"--prefix={INSTALL_DIR}"], cwd=SRC_DIR)

    parallel_jobs = os.cpu_count() or 2
    for make_args in ([f"-j{parallel_jobs}"], ["install"]):
        run_command(["make", *make_args], cwd=SRC_DIR)

    return SIMSTRING_BIN
def load_company_names() -> list[str]:
    """Read ``COMPANY_FILE`` and return its non-blank lines, stripped.

    Returns:
        One entry per line of the file, with surrounding whitespace removed;
        lines that are empty after stripping are skipped.
    """
    names: list[str] = []
    with COMPANY_FILE.open() as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                names.append(cleaned)
    return names
def measure(func: Callable[[], None]) -> float:
    """Call ``func`` once and return how long it took, in milliseconds.

    Args:
        func: Zero-argument callable to time.

    Returns:
        Elapsed time of the single call, measured with
        ``time.perf_counter`` and converted from seconds to milliseconds.
    """
    began = time.perf_counter()
    func()
    elapsed_seconds = time.perf_counter() - began
    return elapsed_seconds * 1000.0
def run_simstring(args: Iterable[str], stdin_data: str) -> None:
    """Invoke the simstring binary with ``args``, feeding ``stdin_data``.

    The child's stdout and stderr are captured and discarded.

    Args:
        args: Command-line arguments appended after the binary path.
        stdin_data: Text written to the child's standard input.

    Raises:
        subprocess.CalledProcessError: If simstring exits with a non-zero
            status.
    """
    command = [str(SIMSTRING_BIN)]
    command.extend(args)
    subprocess.run(
        command,
        input=stdin_data,
        text=True,
        capture_output=True,
        check=True,
    )
def build_database(ngram_size: int, db_path: Path, company_blob: str) -> None:
    """Build a simstring database at ``db_path`` from ``company_blob``.

    The parent directory of ``db_path`` is created if it does not exist.

    Args:
        ngram_size: Value passed to simstring's ``-n`` option.
        db_path: Destination passed to simstring's ``-d`` option.
        company_blob: Newline-separated strings sent to simstring on stdin.

    Raises:
        subprocess.CalledProcessError: Propagated from ``run_simstring`` when
            the build fails.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): presumably -b = build mode, -m = add begin/end marks,
    # -q = quiet; confirm against the simstring CLI usage text.
    build_args = ["-b", "-d", str(db_path), "-n", str(ngram_size), "-m", "-q"]
    run_simstring(build_args, company_blob)
def ensure_persistent_db(ngram_size: int, company_blob: str) -> Path:
    """Return the cached database for ``ngram_size``, building it if absent.

    The database lives at ``WORK_DIR / "company_names_<ngram_size>.db"`` and
    is reused as-is whenever that path already exists.

    Args:
        ngram_size: N-gram size the database is built with.
        company_blob: Newline-separated strings used if a build is needed.

    Returns:
        Path to the database file.
    """
    cached_db = WORK_DIR / f"company_names_{ngram_size}.db"
    if not cached_db.exists():
        build_database(ngram_size, cached_db, company_blob)
    return cached_db
def summarize_measurements(measurements: list[float]) -> tuple[float, float]:
    """Return the mean and sample standard deviation of ``measurements``.

    Args:
        measurements: Timing samples; may be empty.

    Returns:
        A ``(mean, stddev)`` tuple. ``stddev`` is ``0.0`` when fewer than two
        samples are available, and both values are ``0.0`` for an empty list.
    """
    # The timing loops can legitimately collect nothing (e.g. a zero time
    # budget or iteration cap); ``statistics.fmean`` would raise
    # StatisticsError on an empty sequence, so report zeros instead.
    if not measurements:
        return 0.0, 0.0

    mean = statistics.fmean(measurements)
    # The sample standard deviation is undefined for a single data point.
    stddev = statistics.stdev(measurements) if len(measurements) > 1 else 0.0
    return mean, stddev
def bench_insert(results: list[dict], company_blob: str) -> None:
    """Benchmark database construction for every n-gram size.

    For each entry of ``NGRAM_SIZES`` a fresh database is built repeatedly in
    a throw-away directory until either ``MEASUREMENT_TIME`` seconds have
    elapsed or ``MAX_ITERATIONS`` runs have been timed. One result record per
    n-gram size is appended to ``results``.

    Args:
        results: List that receives the result dictionaries (mutated in
            place).
        company_blob: Newline-separated strings inserted into each database.
    """
    for ngram_size in NGRAM_SIZES:
        measurements: list[float] = []
        # Use the monotonic clock for the time budget: unlike the wall clock
        # it cannot jump when the system time is adjusted mid-benchmark.
        deadline = time.monotonic() + MEASUREMENT_TIME
        while time.monotonic() < deadline and len(measurements) < MAX_ITERATIONS:
            # A new temporary directory per run ensures every build starts
            # from scratch; it is removed when the ``with`` block exits.
            with tempfile.TemporaryDirectory(dir=WORK_DIR) as tmpdir:
                db_path = Path(tmpdir) / "bench.db"

                def workload() -> None:
                    build_database(ngram_size, db_path, company_blob)

                # ``workload`` is invoked immediately, so closing over the
                # loop variables is safe here.
                measurements.append(measure(workload))

        mean, stddev = summarize_measurements(measurements)
        results.append(
            {
                "language": "c++",
                "backend": "simstring (C++ CLI)",
                "benchmark": "insert",
                "parameters": {"ngram_size": ngram_size},
                "stats": {
                    "mean": mean,
                    "stddev": stddev,
                    "iterations": len(measurements),
                },
            }
        )
def bench_search(results: list[dict], company_blob: str, query_blob: str) -> None:
    """Benchmark cosine-similarity search over the n-gram/threshold grid.

    For each combination of ``NGRAM_SIZES`` and ``THRESHOLDS`` the simstring
    CLI is run against the persistent database repeatedly until either
    ``MEASUREMENT_TIME`` seconds have elapsed or ``MAX_ITERATIONS`` runs have
    been timed. One result record per combination is appended to ``results``.

    Args:
        results: List that receives the result dictionaries (mutated in
            place).
        company_blob: Newline-separated strings used to build a database when
            the persistent one does not exist yet.
        query_blob: Newline-separated query strings sent to simstring on
            stdin for every timed run.
    """
    for ngram_size in NGRAM_SIZES:
        db_path = ensure_persistent_db(ngram_size, company_blob)
        for threshold in THRESHOLDS:
            cmd = [
                "-d",
                str(db_path),
                "-t",
                f"{threshold}",
                "-s",
                "cosine",
                "-m",
                "-p",
                "-q",
            ]

            # Built once per configuration rather than once per timed run;
            # it is only ever called inside the loop below, while ``cmd``
            # still refers to this configuration.
            def workload() -> None:
                run_simstring(cmd, query_blob)

            measurements: list[float] = []
            # Use the monotonic clock for the time budget: unlike the wall
            # clock it cannot jump when the system time is adjusted.
            deadline = time.monotonic() + MEASUREMENT_TIME
            while time.monotonic() < deadline and len(measurements) < MAX_ITERATIONS:
                measurements.append(measure(workload))

            mean, stddev = summarize_measurements(measurements)
            results.append(
                {
                    "language": "c++",
                    "backend": "simstring (C++ CLI)",
                    "benchmark": "search",
                    "parameters": {"ngram_size": ngram_size, "threshold": threshold},
                    "stats": {
                        "mean": mean,
                        "stddev": stddev,
                        "iterations": len(measurements),
                    },
                }
            )
def main() -> None:
    """Run the insert and search benchmarks and print the results as JSON.

    Ensures the simstring binary is available, loads the company names, uses
    the first 100 of them as search queries, and writes the collected result
    records to stdout as indented JSON.
    """

    def as_blob(lines: list[str]) -> str:
        # One entry per line, with a trailing newline after the last entry.
        return "\n".join(lines) + "\n"

    ensure_simstring_binary()

    all_names = load_company_names()
    insert_input = as_blob(all_names)
    search_input = as_blob(all_names[:100])

    report: list[dict] = []
    bench_insert(report, insert_input)
    bench_search(report, insert_input, search_input)

    print(json.dumps(report, indent=2))
# Run the benchmarks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()