import json
import os
from utils import cargo_run, goto_root, mk_and_cd_tmp_dir
def models_init():
goto_root()
print("\n--- Testing with default model ---")
mk_and_cd_tmp_dir()
print("Running 'rag init'...")
cargo_run(["init"])
models_json_path = os.path.join(".ragit", "models.json")
assert os.path.exists(models_json_path), f"models.json file does not exist at: {models_json_path}"
print(f"models.json file exists at: {models_json_path}")
with open(models_json_path, 'r') as f:
models_json = json.load(f)
print(f"Number of models in models.json: {len(models_json)}")
model_names = [model['name'] for model in models_json]
print(f"Model names: {model_names}")
api_json_path = os.path.join(".ragit", "configs", "api.json")
assert os.path.exists(api_json_path), f"api.json file does not exist at: {api_json_path}"
print(f"api.json file exists at: {api_json_path}")
with open(api_json_path, 'r') as f:
api_json = json.load(f)
selected_model = api_json.get('model', 'Not found')
print(f"Selected model in api.json: {selected_model}")
assert selected_model in model_names, f"Selected model '{selected_model}' does NOT exist in models.json"
print(f"Selected model '{selected_model}' exists in models.json")
os.chdir("..")
print("\n--- Testing with custom models.json without default model ---")
mk_and_cd_tmp_dir()
custom_models = [
{
"name": "gpt-4o",
"api_name": "gpt-4o",
"can_read_images": True,
"api_provider": "openai",
"api_url": "https://api.openai.com/v1/chat/completions",
"input_price": 2.5,
"output_price": 10.0,
"api_timeout": None,
"explanation": None,
"api_key": None,
"api_env_var": "OPENAI_API_KEY"
},
{
"name": "gpt-4o-mini",
"api_name": "gpt-4o-mini",
"can_read_images": True,
"api_provider": "openai",
"api_url": "https://api.openai.com/v1/chat/completions",
"input_price": 0.15,
"output_price": 0.6,
"api_timeout": None,
"explanation": None,
"api_key": None,
"api_env_var": "OPENAI_API_KEY"
},
{
"name": "phi-4-14b-ollama",
"api_name": "phi4:14b",
"can_read_images": True,
"api_provider": "openai",
"api_url": "http://127.0.0.1:11434/v1/chat/completions",
"input_price": 0.0,
"output_price": 0.0,
"api_timeout": None,
"explanation": None,
"api_key": None,
"api_env_var": None
}
]
custom_models_path = os.path.join(os.getcwd(), "custom_models.json")
with open(custom_models_path, 'w') as f:
json.dump(custom_models, f, indent=2)
print("Running 'rag init' with custom models.json...")
os.environ["RAGIT_MODEL_CONFIG"] = custom_models_path
cargo_run(["init"])
models_json_path = os.path.join(".ragit", "models.json")
assert os.path.exists(models_json_path), f"models.json file does not exist at: {models_json_path}"
print(f"models.json file exists at: {models_json_path}")
with open(models_json_path, 'r') as f:
models_json = json.load(f)
print(f"Number of models in models.json: {len(models_json)}")
model_names = [model['name'] for model in models_json]
print(f"Model names: {model_names}")
assert len(models_json) == 3, f"Expected 3 models, got {len(models_json)}"
api_json_path = os.path.join(".ragit", "configs", "api.json")
assert os.path.exists(api_json_path), f"api.json file does not exist at: {api_json_path}"
print(f"api.json file exists at: {api_json_path}")
with open(api_json_path, 'r') as f:
api_json = json.load(f)
selected_model = api_json.get('model', 'Not found')
print(f"Selected model in api.json: {selected_model}")
assert selected_model in model_names, f"Selected model '{selected_model}' does NOT exist in models.json"
print(f"Selected model '{selected_model}' exists in models.json")
assert selected_model == 'phi-4-14b-ollama', f"Expected lowest-cost model 'phi-4-14b-ollama', got '{selected_model}'"
print("SUCCESS: The lowest-cost model was correctly selected!")
if "RAGIT_MODEL_CONFIG" in os.environ:
del os.environ["RAGIT_MODEL_CONFIG"]
def test_home_config_override():
goto_root()
home_dir = os.path.expanduser("~")
config_dir = os.path.join(home_dir, ".config", "ragit")
os.makedirs(config_dir, exist_ok=True)
home_api_json = {
"api_key": None,
"model": "gpt-4o", "timeout": 240000, "sleep_between_retries": 30000, "max_retry": 10, "sleep_after_llm_call": 1000, "dump_log": True, "dump_api_usage": True
}
home_build_json = {
"chunk_size": 2000, "slide_len": 500, }
home_query_json = {
"max_titles": 16, "enable_ii": False }
home_api_path = os.path.join(config_dir, "api.json")
home_build_path = os.path.join(config_dir, "build.json")
home_query_path = os.path.join(config_dir, "query.json")
if os.path.exists(p := home_api_path) or os.path.exists(p := home_build_path) or os.path.exists(p := home_query_path):
raise Exception(f"{p} already exists! This test might overwrite your config file. Please run this test in an isolated environment.")
with open(home_api_path, 'w') as f:
json.dump(home_api_json, f, indent=2)
with open(home_build_path, 'w') as f:
json.dump(home_build_json, f, indent=2)
with open(home_query_path, 'w') as f:
json.dump(home_query_json, f, indent=2)
print("\n--- Testing with ~/.config/ragit/*.json override ---")
mk_and_cd_tmp_dir()
print("Running 'rag init'...")
cargo_run(["init"])
api_json_path = os.path.join(".ragit", "configs", "api.json")
build_json_path = os.path.join(".ragit", "configs", "build.json")
query_json_path = os.path.join(".ragit", "configs", "query.json")
assert os.path.exists(api_json_path), f"api.json file does not exist at: {api_json_path}"
assert os.path.exists(build_json_path), f"build.json file does not exist at: {build_json_path}"
assert os.path.exists(query_json_path), f"query.json file does not exist at: {query_json_path}"
print(f"Config files exist at: {os.path.join('.ragit', 'configs')}")
with open(api_json_path, 'r') as f:
api_json = json.load(f)
with open(build_json_path, 'r') as f:
build_json = json.load(f)
with open(query_json_path, 'r') as f:
query_json = json.load(f)
assert api_json.get('model') == home_api_json['model'], f"Expected model '{home_api_json['model']}', got '{api_json.get('model')}'"
assert api_json.get('timeout') == home_api_json['timeout'], f"Expected timeout {home_api_json['timeout']}, got {api_json.get('timeout')}"
assert api_json.get('sleep_between_retries') == home_api_json['sleep_between_retries'], f"Expected sleep_between_retries {home_api_json['sleep_between_retries']}, got {api_json.get('sleep_between_retries')}"
assert api_json.get('max_retry') == home_api_json['max_retry'], f"Expected max_retry {home_api_json['max_retry']}, got {api_json.get('max_retry')}"
assert api_json.get('sleep_after_llm_call') == home_api_json['sleep_after_llm_call'], f"Expected sleep_after_llm_call {home_api_json['sleep_after_llm_call']}, got {api_json.get('sleep_after_llm_call')}"
assert api_json.get('dump_log') == home_api_json['dump_log'], f"Expected dump_log {home_api_json['dump_log']}, got {api_json.get('dump_log')}"
assert build_json.get('chunk_size') == home_build_json['chunk_size'], f"Expected chunk_size {home_build_json['chunk_size']}, got {build_json.get('chunk_size')}"
assert build_json.get('slide_len') == home_build_json['slide_len'], f"Expected slide_len {home_build_json['slide_len']}, got {build_json.get('slide_len')}"
default_build = {
"image_size": 2000,
"min_summary_len": 200,
"max_summary_len": 1000,
"strict_file_reader": False,
"compression_threshold": 2048,
"compression_level": 3,
}
assert build_json.get('image_size') == default_build['image_size'], f"Expected default image_size {default_build['image_size']}, got {build_json.get('image_size')}"
assert build_json.get('min_summary_len') == default_build['min_summary_len'], f"Expected default min_summary_len {default_build['min_summary_len']}, got {build_json.get('min_summary_len')}"
assert build_json.get('max_summary_len') == default_build['max_summary_len'], f"Expected default max_summary_len {default_build['max_summary_len']}, got {build_json.get('max_summary_len')}"
assert build_json.get('strict_file_reader') == default_build['strict_file_reader'], f"Expected default strict_file_reader {default_build['strict_file_reader']}, got {build_json.get('strict_file_reader')}"
assert build_json.get('compression_threshold') == default_build['compression_threshold'], f"Expected default compression_threshold {default_build['compression_threshold']}, got {build_json.get('compression_threshold')}"
assert build_json.get('compression_level') == default_build['compression_level'], f"Expected default compression_level {default_build['compression_level']}, got {build_json.get('compression_level')}"
assert query_json.get('max_titles') == home_query_json['max_titles'], f"Expected max_titles {home_query_json['max_titles']}, got {query_json.get('max_titles')}"
assert query_json.get('enable_ii') == home_query_json['enable_ii'], f"Expected enable_ii {home_query_json['enable_ii']}, got {query_json.get('enable_ii')}"
default_query = {
"max_summaries": 10,
"max_retrieval": 3,
}
assert query_json.get('max_summaries') == default_query['max_summaries'], f"Expected default max_summaries {default_query['max_summaries']}, got {query_json.get('max_summaries')}"
assert query_json.get('max_retrieval') == default_query['max_retrieval'], f"Expected default max_retrieval {default_query['max_retrieval']}, got {query_json.get('max_retrieval')}"
print("SUCCESS: Values from ~/.config/ragit/*.json were correctly used to override defaults!")
os.remove(home_api_path)
os.remove(home_build_path)
os.remove(home_query_path)
if not os.listdir(config_dir):
os.rmdir(config_dir)
parent_dir = os.path.dirname(config_dir)
if not os.listdir(parent_dir):
os.rmdir(parent_dir)
if __name__ == "__main__":
models_init()
test_home_config_override()