import json
import os
import shutil
from utils import (
cargo_run,
goto_root,
mk_and_cd_tmp_dir,
read_string,
write_string,
)
def _single_log_file(log_dir: str) -> str:
    """Return the path of the single ``.pdl`` log file written into *log_dir*.

    Fails the test (``AssertionError``) unless *log_dir* contains exactly one
    file whose name ends in ``.pdl``.
    """
    logs = [name for name in os.listdir(log_dir) if name.endswith(".pdl")]
    assert len(logs) == 1, f"expected exactly one .pdl log in {log_dir!r}, found {logs}"
    return os.path.join(log_dir, logs[0])


def pdl(test_model: str):
    """End-to-end test of the ``pdl`` subcommand.

    Exercises, in order:

    1. schema-constrained output (``<|schema|>``): ``pdl`` fails before a model
       is configured, prints JSON ``null`` with the ``dummy`` model, and
       extracts ``{"name": "ragit", "age": 26}`` with *test_model*;
    2. image input (``<|media(...)|>``) after the local ``.ragit`` directory
       has been removed, with the model list (``models.json``) and default
       model (``api.json``) supplied through ``~/.config/ragit``;
    3. ``--log``: each run writes exactly one ``.pdl`` log, and a schema
       requiring at least 100 characters produces more than one assistant
       turn in that log;
    4. ``--context``: template variables are filled from a JSON file, and a
       malformed context fails by default and with ``--strict`` but succeeds
       with ``--no-strict``.

    Args:
        test_model: name of a real model to query. It must not be ``"dummy"``,
            because several checks need a genuine answer.

    Side effects: moves into a scratch directory (via ``goto_root`` and
    ``mk_and_cd_tmp_dir``), and creates ``~/.config/ragit``, removing it again
    when the test finishes. The test refuses to run if that directory already
    exists, so it never clobbers a real user configuration.
    """
    assert test_model != "dummy", "test_model must be a real model, not 'dummy'"
    home_dir = os.path.expanduser("~")
    global_config_dir = os.path.join(home_dir, ".config", "ragit")
    assert not os.path.exists(global_config_dir), "~/.config/ragit is found. Please run this test in an isolated environment."

    # Create the directory *before* entering `try`: if creation fails there is
    # nothing to clean up, whereas inside `try` the `finally` would raise a
    # second FileNotFoundError from `rmtree` and obscure the real error.
    # `makedirs` also creates a missing `~/.config`, which a fresh isolated
    # environment may not have.
    os.makedirs(global_config_dir)

    try:
        goto_root()
        mk_and_cd_tmp_dir()

        # --- 1. structured output constrained by a schema ---
        write_string("test1.pdl", """
<|schema|>
{ name: str, age: int }
<|system|>
You'll be given a short story with a character. Your job is to extract a name and age of the character.
Your response must be a json object, with 2 keys: "name" and "age". "name" is a string and "age" is an integer.
<|user|>
There lives a man named ragit. He is 26 years old.
""")

        # No model is configured yet, so `pdl` must fail.
        # (cargo_run(..., check=False) appears to return the exit code.)
        assert cargo_run(["pdl", "test1.pdl"], check=False) != 0

        cargo_run(["init"])
        cargo_run(["config", "--set", "model", "dummy"])
        # With the dummy model the printed result is JSON `null`
        # (presumably because it cannot satisfy the schema).
        assert json.loads(cargo_run(["pdl", "test1.pdl"], stdout=True)) is None
        # `--model` overrides the configured model.
        result = json.loads(cargo_run(["pdl", "test1.pdl", "--model", test_model], stdout=True))
        assert result == { "name": "ragit", "age": 26 }

        # --- 2. image input, configured via ~/.config/ragit ---
        # Keep a copy of the model list, then drop the local `.ragit`
        # directory so the following runs cannot rely on it.
        shutil.copyfile(".ragit/models.json", "models.json")
        shutil.rmtree(".ragit")
        shutil.copyfile("../tests/images/hello_world.webp", "sample.webp")
        write_string("test2.pdl", """
<|user|>
I have an image of a wooden plank. There's something written on it... What does it say?
<|media(sample.webp)|>
""")

        # No local `.ragit` and an empty `~/.config/ragit`: must fail.
        assert cargo_run(["pdl", "test2.pdl"], check=False) != 0
        # A global model list alone is still not enough
        # (presumably because no default model is selected yet).
        shutil.copyfile("models.json", os.path.join(global_config_dir, "models.json"))
        assert cargo_run(["pdl", "test2.pdl"], check=False) != 0
        # Selecting a default model in the global `api.json` makes it work.
        write_string(
            os.path.join(global_config_dir, "api.json"),
            "{ \"model\": \"dummy\" }",
        )
        assert cargo_run(["pdl", "test2.pdl"], stdout=True).strip() == "dummy"
        result = cargo_run(["pdl", "test2.pdl", "--model", test_model], stdout=True).lower()
        assert "hello" in result
        assert "world" in result

        # --- 3. logging ---
        write_string("test3.pdl", """
<|user|>
Say something
""")

        # Without a schema, the log records a single assistant turn.
        cargo_run(["pdl", "test3.pdl", "--model=dummy", "--log=log1"])
        assert read_string(_single_log_file("log1")).count("<|Assistant|>") == 1

        # The dummy answer is shorter than this schema's minimum length, so
        # the log must show more than one assistant turn and mention the
        # length requirement.
        cargo_run(["pdl", "test3.pdl", "--model=dummy", "--schema=str { min: 100 }", "--log=log2"])
        log = read_string(_single_log_file("log2"))
        assert log.count("<|Assistant|>") > 1
        assert "at least 100 characters" in log

        # --- 4. template context ---
        write_string("test4.pdl", """
<|schema|>
integer
<|user|>
Below is the list of the customers
- {{customer1.name}}: {{customer1.age}} years old
- {{customer2.name}}: {{customer2.age}} years old
- {{customer3.name}}: {{customer3.age}} years old
How old is {{customer1.name}}?
<|assistant|>
""")
        context = {
            "customer1": { "name": "Bae", "age": 29 },
            "customer2": { "name": "Hyun", "age": 35 },
            "customer3": { "name": "Sol", "age": 16 },
        }
        # `customer1` is a list rather than an object with `name`/`age`, and
        # the other customers are missing, so the template cannot be filled.
        broken_context = {
            "customer1": [],
        }
        write_string("context.json", json.dumps(context))
        write_string("broken_context.json", json.dumps(broken_context))

        answer = cargo_run(["pdl", "test4.pdl", "--model", test_model, "--context=context.json"], stdout=True)
        assert int(answer) == context["customer1"]["age"]

        # A broken context is tolerated only with `--no-strict`; the default
        # behaves like `--strict`.
        cargo_run(["pdl", "test4.pdl", "--no-strict", "--model", "dummy", "--context=broken_context.json"])
        assert cargo_run(["pdl", "test4.pdl", "--model", "dummy", "--context=broken_context.json"], check=False) != 0
        assert cargo_run(["pdl", "test4.pdl", "--strict", "--model", "dummy", "--context=broken_context.json"], check=False) != 0

    finally:
        shutil.rmtree(global_config_dir)