from flask import Flask, request
import random
import re
import subprocess
import time
from typing import Callable, Optional, Tuple
from utils import goto_root
app = Flask(__name__)
# NOTE(review): `last_request` looks unused — worker() declares it with
# `global` but never reads or writes it. Candidate for removal; confirm
# against the rest of the project before deleting.
last_request = 0
@app.route("/api/chat", methods=["POST"])
def chat():
    """Dispatch a fake chat-completion request based on the magic model name.

    The model name encodes which failure mode / behavior to simulate:
      rate-limit-<N>   : allow at most N requests per minute (429 beyond that)
      delay-<A>[-<B>]  : sleep A seconds, or a random integer in [A, B]
      fail-<P>         : return HTTP 500 with probability P percent
      dummy-<hex>      : reply with the UTF-8 decoding of the hex payload
      text-only        : reject requests that contain image parts (400)
      repeat-after-me  : echo the last message back
      anything else    : plain "dummy" completion

    Returns the (body, status) tuple produced by worker().
    """
    j = request.get_json()
    model = j["model"]
    if (r := re.match(r"rate\-limit\-(\d+)", model)) is not None:
        return worker(
            request=j,
            rate_limit=int(r.group(1)),
        )
    elif (r := re.match(r"delay\-(\d+)(?:\-(\d+))?", model)) is not None:
        d_from, d_to = r.groups()
        # BUG FIX: regex groups are strings; time.sleep()/random.randint()
        # need numbers, so convert explicitly (the original passed str and
        # raised TypeError at request time).
        if d_to is None:
            delay = lambda: int(d_from)
        else:
            delay = lambda: random.randint(int(d_from), int(d_to))
        return worker(
            request=j,
            delay=delay,
        )
    elif (r := re.match(r"fail-(\d+)", model)) is not None:
        return worker(
            request=j,
            fail=lambda: random.randint(1, 100) <= int(r.group(1)),
        )
    elif (r := re.match(r"dummy-([0-9a-zA-Z]+)", model)) is not None:
        # Decode the hex payload directly; the original eval()-ed each
        # "0xNN" pair, which is both slower and an eval-on-request-input
        # anti-pattern.
        response = bytes.fromhex(r.group(1)).decode("utf-8")
        return worker(
            request=j,
            output_gen=lambda _: response,
        )
    elif model == "text-only":
        return worker(
            request=j,
            can_read_images=False,
        )
    elif model == "repeat-after-me":
        return worker(
            request=j,
            output_gen=get_last_turn,
        )
    else:
        return worker(j)
def worker(
    request: dict,
    can_read_images: bool = True,
    output_gen: Optional[Callable] = None,
    delay: Optional[Callable] = None,
    rate_limit: Optional[int] = None,
    fail: Optional[Callable] = None,
) -> Tuple[dict, int]:
    """Produce a fake OpenAI-style chat-completion response.

    Args:
        request: parsed request body; must contain "model" and "messages".
        can_read_images: when False, any non-text content part yields 400.
        output_gen: callable mapping the request dict to the reply text
            (defaults to a constant "dummy" reply).
        delay: callable returning seconds to sleep before responding.
        rate_limit: max requests per minute; excess requests get 429.
        fail: callable returning True when this request should 500.

    Returns:
        (json_body, http_status) tuple.
    """
    # Randomize whether the simulated failure fires before or after the
    # rate-limit check, so callers see both orderings.
    fail_before_rate_limit = random.randint(0, 1) == 1
    # BUG FIX: don't rebind the `fail` parameter from callable to bool;
    # evaluate it once into a separate flag. (Also removed a dead
    # `global last_request` that was never read or written.)
    should_fail = fail() if fail else False
    if fail_before_rate_limit and should_fail:
        return {}, 500
    if rate_limit:
        if check_rate_limit(rate_limit):
            return {}, 429
        push_rate_limit_queue()
    if not fail_before_rate_limit and should_fail:
        return {}, 500
    output_gen = output_gen or (lambda _: "dummy")
    messages = request["messages"]
    if not can_read_images:
        # A "text-only" model rejects any structured content part that is
        # not plain text (e.g. image parts).
        for message in messages:
            content = message["content"]
            if isinstance(content, list) and any(c["type"] != "text" for c in content):
                return {}, 400
    # Crude token count: whitespace-separated words of all message text.
    input_tokens = 0
    for message in messages:
        content = message["content"]
        if isinstance(content, list):
            content = " ".join(c.get("text", "") for c in content)
        input_tokens += len(content.split(" "))
    output = output_gen(request)
    output_tokens = len(output.split(" "))
    if delay:
        time.sleep(delay())
    return {
        "id": "dummy",
        "object": "dummy",
        "created": int(time.time()),
        "model": request["model"],
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": output,
            },
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
        },
    }, 200
# Per-minute request counters keyed by epoch-minute; backs the simulated
# provider-side rate limit.
rate_limit_queue = {}


def check_rate_limit(limit: int) -> bool:
    """Return True if the current minute already holds `limit` or more requests."""
    now = int(time.time()) // 60
    return rate_limit_queue.get(now, 0) >= limit


def push_rate_limit_queue():
    """Record one request against the current minute bucket."""
    now = int(time.time()) // 60
    # BUG FIX: drop buckets for past minutes — previously the dict grew by
    # one key per minute for the lifetime of the process.
    for stale in [minute for minute in rate_limit_queue if minute < now]:
        del rate_limit_queue[stale]
    rate_limit_queue[now] = rate_limit_queue.get(now, 0) + 1
def host_fake_llm_server():
    """Start the fake LLM server as a background subprocess.

    Changes the working directory to the repository root first (via the
    project `goto_root` helper) so the relative script path resolves, then
    returns the `subprocess.Popen` handle so the caller can terminate it.
    """
    goto_root()
    return subprocess.Popen(["python3", "./tests/fake_llm_server.py"])
def get_last_turn(request: dict) -> str:
    """Return the text of the most recent message in the request.

    Plain-string content is returned as-is. Structured content (a list of
    parts) is flattened: text parts contribute their text, and any part
    without a "text" key contributes the placeholder "(image)".
    """
    last_content = request["messages"][-1]["content"]
    if isinstance(last_content, str):
        return last_content
    return "".join(
        part["text"] if "text" in part else "(image)"
        for part in last_content
    )
if __name__ == "__main__":
    # Serve on all interfaces at the fixed test port 11435.
    app.run(host="0.0.0.0", port=11435)