import json
import os
import requests
from server import (
create_repo,
create_user,
get_api_key,
spawn_ragit_server,
)
from typing import Optional
from utils import cargo_run, goto_root, mk_and_cd_tmp_dir
def server_chat(test_model: str):
    """End-to-end check of the server's chat endpoints using two repositories.

    A local ragit server is spawned and two repositories are pushed to it:
    ``sample1`` (cloned from a remote sample) and ``sample2`` (created with
    ``init`` only). One chat is opened on each, the same three queries are
    sent to both, and the returned responses are compared with the history
    the server reports afterwards.

    Args:
        test_model: Model name passed to ``config --set model`` in both
            repositories.
    """
    goto_root()
    server_process = None

    try:
        server_process = spawn_ragit_server()
        mk_and_cd_tmp_dir()

        # --- sample1: cloned repository, also used to register the model ---
        cargo_run(["clone", "https://ragit.baehyunsol.com/sample/rustc", "sample1"])
        os.chdir("sample1")
        cargo_run(["config", "--set", "model", test_model])
        exact_model_name = cargo_run(["ls-models", "--selected", "--name-only"], stdout=True).strip()

        create_user(id="test-user", password="secure-password")
        server_api_key = get_api_key(id="test-user", password="secure-password")
        # Must run while the cwd is still `sample1`: it reads `.ragit/models.json`.
        model_api_key = get_model_api_key()
        create_repo(user="test-user", repo="sample1", api_key=server_api_key)
        register_model(
            user="test-user",
            model_name=exact_model_name,
            model_api_key=model_api_key,
            server_api_key=server_api_key,
        )
        cargo_run(["push", "--configs", "--remote=http://127.0.0.1:41127/test-user/sample1"])
        os.chdir("..")

        # --- sample2: initialized here and pushed without adding anything ---
        os.mkdir("sample2")
        os.chdir("sample2")
        cargo_run(["init"])
        cargo_run(["config", "--set", "model", test_model])
        create_repo(user="test-user", repo="sample2", api_key=server_api_key)
        cargo_run(["push", "--configs", "--remote=http://127.0.0.1:41127/test-user/sample2"])
        os.chdir("..")

        # --- open one chat per repository ---
        chat_list_url1 = "http://127.0.0.1:41127/test-user/sample1/chat-list"
        chat_list_url2 = "http://127.0.0.1:41127/test-user/sample2/chat-list"
        chat_id1 = requests.post(chat_list_url1).text
        chat_id2 = requests.post(chat_list_url2).text
        chat_url1 = f"http://127.0.0.1:41127/test-user/sample1/chat/{chat_id1}"
        chat_url2 = f"http://127.0.0.1:41127/test-user/sample2/chat/{chat_id2}"
        responses1, responses2 = [], []

        def assert_single_chat_in_sample1():
            # sample1 must list exactly the chat created above.
            listed = requests.get(chat_list_url1).json()
            assert len(listed) == 1
            assert str(listed[0]["id"]) == chat_id1

        def turn_pairs(turns):
            # Project each turn onto the two fields that are compared.
            return [(turn["response"], turn["multi_turn_schema"]) for turn in turns]

        assert_single_chat_in_sample1()

        # Each turn is sent to sample1 first, then sample2. The request
        # keyword alternates between `files=` and `data=` -- presumably to
        # cover both multipart and form-encoded bodies (TODO confirm).
        conversation = (
            ("files", "How does the rust compiler implement type system?"),
            ("data", "What do you mean by MIR?"),
            ("files", "Thanks!"),
        )

        for body_kwarg, query in conversation:
            request_body = {body_kwarg: {"query": query}}
            responses1.append(requests.post(chat_url1, **request_body).json())
            responses2.append(requests.post(chat_url2, **request_body).json())

        history1 = requests.get(chat_url1).json()["history"]
        history2 = requests.get(chat_url2).json()["history"]

        # Chatting must not have created additional chats.
        assert_single_chat_in_sample1()

        # No sample2 response may reference any chunk.
        assert all(len(reply["chunk_uids"]) == 0 for reply in responses2)

        # The stored history must match what each request returned, in order.
        assert turn_pairs(history1) == turn_pairs(responses1)
        assert turn_pairs(history2) == turn_pairs(responses2)

        # Building the search index is requested three times per repository;
        # presumably this checks that repeated builds succeed (TODO confirm).
        index_urls = (
            "http://127.0.0.1:41127/test-user/sample1/build-search-index",
            "http://127.0.0.1:41127/test-user/sample2/build-search-index",
        )

        for _ in range(3):
            for index_url in index_urls:
                assert requests.post(index_url).status_code == 200

    finally:
        # Always stop the server, even when an assertion above failed.
        if server_process is not None:
            server_process.kill()
def get_model_api_key() -> Optional[str]:
    """Look up the API key of the model selected in the current repository.

    Must be called with the repository root as the working directory, since
    it reads ``.ragit/models.json`` relative to the cwd.

    Returns:
        The key stored in the model entry's ``api_key`` field if present,
        otherwise the value of the environment variable named by its
        ``api_env_var`` field. ``None`` when the configured model is
        ``dummy``, ``stdin`` or ``error``, or when the entry has neither field.

    Raises:
        Exception: The entry names an environment variable that is not set.
    """
    configured_model = cargo_run(["config", "--get", "model"], stdout=True).strip()

    # These model names are skipped outright -- presumably pseudo-models that
    # never need a key (TODO confirm).
    if configured_model in ["dummy", "stdin", "error"]:
        return None

    models_path = os.path.join(".ragit", "models.json")

    with open(models_path, "r") as models_file:
        known_models = json.load(models_file)

    selected_listing = cargo_run(["ls-models", "--json", "--selected"], stdout=True).strip()
    selected_name = json.loads(selected_listing)[0]["name"]
    # Indexing with [0] raises IndexError when models.json has no such entry.
    entry = [candidate for candidate in known_models if candidate["name"] == selected_name][0]

    # An explicit key in the entry wins over the environment variable.
    inline_key = entry.get("api_key")

    if inline_key is not None:
        return inline_key

    api_env_var = entry.get("api_env_var")

    if api_env_var is None:
        return None

    env_key = os.environ.get(api_env_var)

    if env_key is None:
        raise Exception(f"API key is not set. Please set the {api_env_var} environment variable.")

    return env_key
def register_model(
    user: str,
    model_name: str,
    model_api_key: Optional[str],
    server_api_key: str,
) -> str:
    """Register a server-side AI model for `user` and make it the default.

    The model is looked up in the server's ``ai-model-list`` by substring
    match against each entry's ``name`` and ``api_name``. When several
    entries match, a unique exact match on either field is preferred, so a
    name that is a prefix of another model's name is not rejected as
    ambiguous.

    Args:
        user: User whose model list is updated.
        model_name: Full or partial name of the model to register.
        model_api_key: API key sent along with the model; may be ``None``.
        server_api_key: Sent as the ``x-api-key`` header.

    Returns:
        `model_name`, unchanged.

    Raises:
        Exception: No model matches, or the match cannot be narrowed down to
            a single model.
        AssertionError: The server answers a registration request with a
            status other than 200.
    """
    available = requests.get("http://127.0.0.1:41127/ai-model-list").json()
    models = [
        model for model in available
        if model_name in model["name"] or model_name in model["api_name"]
    ]

    if len(models) == 0:
        raise Exception(f"No model named {model_name} in the server.")

    if len(models) > 1:
        # Substring matching alone would reject e.g. "foo" when "foo-mini"
        # also exists; fall back to a unique exact match before giving up.
        exact = [
            model for model in models
            if model_name == model["name"] or model_name == model["api_name"]
        ]

        if len(exact) != 1:
            raise Exception(f"Model name {model_name} is ambiguous: {models}")

        models = exact

    model_id = models[0]["id"]

    # The same model is PUT three times and only the last request sets
    # `default_model` -- presumably to check that repeated registration is
    # accepted (TODO confirm). Every response is checked.
    for attempt in range(3):
        response = requests.put(
            f"http://127.0.0.1:41127/user-list/{user}/ai-model-list",
            json={ "model_id": model_id, "default_model": attempt == 2, "api_key": model_api_key },
            headers={ "x-api-key": server_api_key },
        )
        assert response.status_code == 200, f"Failed to set default model: {response.text}"

    return model_name
def put_model_api_key(user: str, model_name: str, model_api_key: str, server_api_key: str):
    """Send an API key for one of `user`'s models to the server.

    Args:
        user: User whose model list is updated.
        model_name: Sent as the ``model`` field of the JSON body.
        model_api_key: Sent as the ``api_key`` field of the JSON body.
        server_api_key: Sent as the ``x-api-key`` header.

    Raises:
        AssertionError: The server responds with a status other than 200.
    """
    endpoint = f"http://127.0.0.1:41127/user-list/{user}/ai-model-list"
    payload = { "model": model_name, "api_key": model_api_key }
    auth_headers = { "x-api-key": server_api_key }

    result = requests.put(endpoint, json=payload, headers=auth_headers)
    assert result.status_code == 200, f"Failed to set API key: {result.text}"