use std::fs;
use std::path::Path;
use tempfile::tempdir;
use outrig::config::{Config, ConfigValidationError};
use outrig::error::OutrigError;
/// Parses `s` as an outrig [`Config`], panicking with "config parses" if the
/// TOML is rejected. Test fixtures are expected to be syntactically valid.
fn parse(s: &str) -> Config {
    let parsed = Config::load_from_str(s);
    parsed.expect("config parses")
}
/// Runs `cfg.validate(repo_root)` and returns the inner
/// [`ConfigValidationError`].
///
/// # Panics
///
/// Panics if validation succeeds, or if it fails with any `OutrigError`
/// other than `OutrigError::ConfigValidation`.
fn expect_validation_err(cfg: &Config, repo_root: Option<&Path>) -> ConfigValidationError {
    let outcome = cfg.validate(repo_root);
    match outcome {
        Ok(()) => panic!("expected validation error, got Ok"),
        Err(OutrigError::ConfigValidation(inner)) => inner,
        Err(other) => panic!("expected ConfigValidation, got: {other:?}"),
    }
}
/// A mistralrs model declared with `model-id` plus `model-file` validates,
/// and serializing then re-parsing it yields an equal `Config`.
#[test]
fn mistralrs_with_model_id_and_file_parses_and_validates() {
    let original = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.qwen]
provider = "local"
model-id = "Qwen/Qwen2.5-7B-Instruct"
model-file = "qwen2.5-7b-instruct-q4_k_m.gguf"
"#,
    );
    original.validate(None).expect("validates");

    // Round trip: TOML text -> Config -> TOML text -> Config must be lossless.
    let text = toml::to_string(&original).expect("serializes");
    let reparsed = Config::load_from_str(&text).expect("reserialized parses");
    assert_eq!(original, reparsed);
}
/// Every `device` spelling exercised here (`cpu`, `cuda`, `cuda:<index>`,
/// `metal`) validates, is stored verbatim on the model, and survives a
/// serialize/parse round trip.
#[test]
fn mistralrs_device_forms_parse_validate_and_round_trip() {
    let cfg = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.cpu]
provider = "local"
model-id = "Qwen/Qwen2.5-7B-Instruct"
model-file = "qwen2.5-7b-instruct-q4_k_m.gguf"
device = "cpu"
[models.cuda_default]
provider = "local"
model-id = "Qwen/Qwen2.5-7B-Instruct"
model-file = "qwen2.5-7b-instruct-q4_k_m.gguf"
device = "cuda"
[models.cuda_indexed]
provider = "local"
model-id = "Qwen/Qwen2.5-7B-Instruct"
model-file = "qwen2.5-7b-instruct-q4_k_m.gguf"
device = "cuda:1"
[models.metal]
provider = "local"
model-id = "Qwen/Qwen2.5-7B-Instruct"
model-file = "qwen2.5-7b-instruct-q4_k_m.gguf"
device = "metal"
"#,
    );
    cfg.validate(None).expect("all documented forms validate");

    // (model table name, device string it declares), checked in file order.
    let expected_devices = [
        ("cpu", "cpu"),
        ("cuda_default", "cuda"),
        ("cuda_indexed", "cuda:1"),
        ("metal", "metal"),
    ];
    for (model_name, device) in expected_devices {
        assert_eq!(cfg.models[model_name].device.as_deref(), Some(device));
    }

    let text = toml::to_string(&cfg).expect("serializes");
    let reparsed = Config::load_from_str(&text).expect("reserialized parses");
    assert_eq!(cfg, reparsed);
}
/// Malformed `device` values parse as TOML but are rejected by validation
/// with `MistralrsDeviceInvalid`, which echoes the model name and the bad value.
#[test]
fn mistralrs_invalid_device_fails_validate() {
    const BAD_DEVICES: [&str; 4] = ["gpu", "cuda:", "cuda:abc", "metal:0"];
    for device in BAD_DEVICES {
        let src = format!(
            r#"
[providers.local]
style = "mistralrs"
[models.qwen]
provider = "local"
model-id = "Qwen/Qwen2.5-7B-Instruct"
model-file = "qwen2.5-7b-instruct-q4_k_m.gguf"
device = "{device}"
"#,
        );
        let cfg = parse(&src);
        let err = expect_validation_err(&cfg, None);
        let is_expected = matches!(
            &err,
            ConfigValidationError::MistralrsDeviceInvalid { model, device: got }
                if model == "qwen" && got.as_str() == device
        );
        assert!(is_expected, "device {device:?} got: {err:?}");
    }
}
/// `model-id` on its own is not enough for mistralrs: validation demands a
/// matching `model-file` and reports `MistralrsModelIdMissingFile`.
#[test]
fn mistralrs_model_id_without_model_file_fails_validate() {
    let cfg = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.qwen]
provider = "local"
model-id = "Qwen/Qwen2.5-7B-Instruct"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    let is_expected = matches!(
        &err,
        ConfigValidationError::MistralrsModelIdMissingFile { model, model_id }
            if model == "qwen" && model_id == "Qwen/Qwen2.5-7B-Instruct"
    );
    assert!(is_expected, "got: {err:?}");
}
/// A mistralrs model with neither `model-id` nor `model-path` fails with
/// `MistralrsMissingModelSource`.
#[test]
fn mistralrs_missing_both_fails_validate() {
    let cfg = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.qwen]
provider = "local"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    let is_expected = matches!(
        &err,
        ConfigValidationError::MistralrsMissingModelSource { model } if model == "qwen"
    );
    assert!(is_expected, "got: {err:?}");
}
/// Setting both `model-id` and `model-path` on one mistralrs model is
/// rejected with `MistralrsBothModelSources`.
#[test]
fn mistralrs_with_both_fails_validate() {
    let cfg = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.qwen]
provider = "local"
model-id = "Qwen/Qwen2.5-7B-Instruct"
model-path = "/tmp/model.gguf"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    let is_expected = matches!(
        &err,
        ConfigValidationError::MistralrsBothModelSources { model } if model == "qwen"
    );
    assert!(is_expected, "got: {err:?}");
}
/// `model-file` is only meaningful alongside `model-id`. Pairing it with
/// `model-path` yields `MistralrsExtraFieldRequiresModelId` naming the field.
#[test]
fn mistralrs_extra_field_without_model_id_fails_validate() {
    let cfg = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.qwen]
provider = "local"
model-path = "/tmp/model.gguf"
model-file = "weights.gguf"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    let is_expected = matches!(
        &err,
        ConfigValidationError::MistralrsExtraFieldRequiresModelId { model, field }
            if model == "qwen" && *field == "model-file"
    );
    assert!(is_expected, "got: {err:?}");
}
/// A mistralrs model must not set the OpenAI-only `identifier` field.
///
/// Apart from `identifier`, the model is fully valid (`model-id` paired with
/// `model-file`), so the `identifier` rejection is the only error validation
/// can report. Without `model-file`, the config would also trip the
/// `MistralrsModelIdMissingFile` check. The assertion would then silently
/// depend on the order in which `validate` runs its checks.
#[test]
fn mistralrs_model_with_identifier_fails_validate() {
    let cfg = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.qwen]
provider = "local"
identifier = "qwen-on-the-wire"
model-id = "Qwen/Qwen2.5-7B-Instruct"
model-file = "qwen2.5-7b-instruct-q4_k_m.gguf"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    assert!(
        matches!(
            err,
            ConfigValidationError::MistralrsModelHasOpenAiField { ref model, field }
                if model == "qwen" && field == "identifier"
        ),
        "got: {err:?}",
    );
}
/// Models served by an openai-style provider need an `identifier`. Omitting it
/// yields `OpenAiModelMissingIdentifier`.
#[test]
fn openai_model_missing_identifier_fails_validate() {
    let cfg = parse(
        r#"
[providers.openai]
style = "openai"
base-url = "https://api.openai.com/v1"
api-key = "${OPENAI_API_KEY}"
[models.fast]
provider = "openai"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    let is_expected = matches!(
        &err,
        ConfigValidationError::OpenAiModelMissingIdentifier { model } if model == "fast"
    );
    assert!(is_expected, "got: {err:?}");
}
/// mistralrs-only fields such as `model-id` are rejected on an openai-backed
/// model with `OpenAiModelHasMistralrsField`.
#[test]
fn openai_model_with_weight_field_fails_validate() {
    let cfg = parse(
        r#"
[providers.openai]
style = "openai"
base-url = "https://api.openai.com/v1"
api-key = "${OPENAI_API_KEY}"
[models.fast]
provider = "openai"
identifier = "gpt-4o-mini"
model-id = "should-not-be-here"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    let is_expected = matches!(
        &err,
        ConfigValidationError::OpenAiModelHasMistralrsField { model, field }
            if model == "fast" && *field == "model-id"
    );
    assert!(is_expected, "got: {err:?}");
}
/// `device` is a mistralrs-only field. On an openai-backed model it is
/// reported via `OpenAiModelHasMistralrsField`.
#[test]
fn openai_model_with_device_fails_validate() {
    let cfg = parse(
        r#"
[providers.openai]
style = "openai"
base-url = "https://api.openai.com/v1"
api-key = "${OPENAI_API_KEY}"
[models.fast]
provider = "openai"
identifier = "gpt-4o-mini"
device = "cuda"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    let is_expected = matches!(
        &err,
        ConfigValidationError::OpenAiModelHasMistralrsField { model, field }
            if model == "fast" && *field == "device"
    );
    assert!(is_expected, "got: {err:?}");
}
/// A misspelled provider `style` fails to load. The error message must quote
/// the offending value and mention at least one accepted style.
#[test]
fn unknown_style_typo_useful_error() {
    // Named `src` rather than `toml` so it does not shadow the `toml` crate.
    let src = r#"
[providers.local]
style = "mistral-rs"
base-url = "https://localhost:1234/v1"
api-key = "${KEY}"
"#;
    let msg = Config::load_from_str(src).unwrap_err().to_string();
    assert!(
        msg.contains("mistral-rs"),
        "error should quote the offending value, got: {msg}",
    );
    let names_legal_style = ["openai", "mistralrs"].iter().any(|&v| msg.contains(v));
    assert!(
        names_legal_style,
        "error should name at least one legal variant, got: {msg}",
    );
}
/// A relative `model-cache-root` is rejected with `ModelCacheRootNotAbsolute`,
/// which carries the path exactly as written.
#[test]
fn model_cache_root_relative_fails_validate() {
    let cfg = parse(
        r#"
model-cache-root = "models"
"#,
    );
    let err = expect_validation_err(&cfg, None);
    let is_expected = matches!(
        &err,
        ConfigValidationError::ModelCacheRootNotAbsolute { path }
            if path.as_os_str() == "models"
    );
    assert!(is_expected, "got: {err:?}");
}
/// An absolute `model-cache-root` passes validation.
#[test]
fn model_cache_root_absolute_validates() {
    let absolute_root = parse(
        r#"
model-cache-root = "/var/cache/outrig/models"
"#,
    );
    let outcome = absolute_root.validate(None);
    outcome.expect("absolute path is fine");
}
/// A relative `model-path` is looked up beneath the `repo_root` passed to
/// `validate`, so a file created there makes validation succeed.
#[test]
fn mistralrs_relative_model_path_resolves_against_repo_root() {
    let repo = tempdir().unwrap();
    let weights = repo.path().join("models").join("local.gguf");
    // Create `<repo>/models/` and a one-byte placeholder weights file in it.
    fs::create_dir_all(weights.parent().unwrap()).unwrap();
    fs::write(&weights, b"\0").unwrap();

    let cfg = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.local]
provider = "local"
model-path = "models/local.gguf"
"#,
    );
    cfg.validate(Some(repo.path()))
        .expect("relative model-path under repo_root resolves");
}
/// A relative `model-path` that does not exist beneath `repo_root` is
/// rejected with `MistralrsModelPathMissing` for that model.
#[test]
fn mistralrs_relative_model_path_missing_under_repo_root_errors() {
    // Fresh, empty temp dir: `models/missing.gguf` cannot exist beneath it.
    let tmp = tempdir().unwrap();
    let cfg = parse(
        r#"
[providers.local]
style = "mistralrs"
[models.local]
provider = "local"
model-path = "models/missing.gguf"
"#,
    );
    // Use the shared helper instead of a hand-rolled match, so an unexpected
    // `Ok` or a different `OutrigError` variant is diagnosed exactly as in
    // every other validation test in this file.
    let err = expect_validation_err(&cfg, Some(tmp.path()));
    assert!(
        matches!(
            err,
            ConfigValidationError::MistralrsModelPathMissing { ref model, .. }
                if model == "local"
        ),
        "got: {err:?}",
    );
}