use crate::api::ModelTask;
use crate::error::{Result, RuntimeError};
use serde_json::Value;
pub fn validate_provider_options(
provider_id: &str,
task: ModelTask,
options: &Value,
) -> Result<()> {
match provider_id {
"remote/openai" | "remote/gemini" | "remote/mistral" | "remote/voyageai" => {
validate_string_keys_only(provider_id, options, &["api_key_env"])
}
"remote/anthropic" => {
validate_string_keys_only(provider_id, options, &["api_key_env", "anthropic_version"])
}
"remote/cohere" => {
validate_string_keys_only(provider_id, options, &["api_key_env", "input_type"])
}
"remote/azure-openai" => validate_string_keys_only(
provider_id,
options,
&["api_key_env", "resource_name", "api_version"],
),
"remote/vertexai" => validate_vertexai_options(provider_id, task, options),
"local/candle" | "local/fastembed" => {
validate_string_keys_only(provider_id, options, &["cache_dir"])
}
"local/mistralrs" => validate_mistralrs_options(provider_id, task, options),
_ => Ok(()),
}
}
fn as_object<'a>(
provider_id: &str,
options: &'a Value,
) -> Result<Option<&'a serde_json::Map<String, Value>>> {
match options {
Value::Null => Ok(None),
Value::Object(map) => Ok(Some(map)),
_ => Err(RuntimeError::Config(format!(
"Options for provider '{}' must be a JSON object or null",
provider_id
))),
}
}
fn reject_unknown_keys(
provider_id: &str,
map: &serde_json::Map<String, Value>,
allowed: &[&str],
) -> Result<()> {
for key in map.keys() {
if !allowed.contains(&key.as_str()) {
return Err(RuntimeError::Config(format!(
"Unknown option '{}' for provider '{}'",
key, provider_id
)));
}
}
Ok(())
}
fn require_string_keys(
provider_id: &str,
map: &serde_json::Map<String, Value>,
keys: &[&str],
) -> Result<()> {
for key in keys {
if let Some(value) = map.get(*key)
&& !value.is_string()
{
return Err(RuntimeError::Config(format!(
"Option '{}' for provider '{}' must be a string",
key, provider_id
)));
}
}
Ok(())
}
fn require_positive_u64(
provider_id: &str,
map: &serde_json::Map<String, Value>,
key: &str,
) -> Result<()> {
if let Some(value) = map.get(key) {
let Some(v) = value.as_u64() else {
return Err(RuntimeError::Config(format!(
"Option '{}' for provider '{}' must be a positive integer",
key, provider_id
)));
};
if v == 0 {
return Err(RuntimeError::Config(format!(
"Option '{}' for provider '{}' must be greater than 0",
key, provider_id
)));
}
}
Ok(())
}
fn require_embedding_dimensions(
provider_id: &str,
task: ModelTask,
map: &serde_json::Map<String, Value>,
) -> Result<()> {
if map.contains_key("embedding_dimensions") {
require_positive_u64(provider_id, map, "embedding_dimensions")?;
if task != ModelTask::Embed {
return Err(RuntimeError::Config(
"Option 'embedding_dimensions' is only valid for embed tasks".to_string(),
));
}
}
Ok(())
}
fn validate_string_keys_only(
provider_id: &str,
options: &Value,
allowed_keys: &[&str],
) -> Result<()> {
let Some(map) = as_object(provider_id, options)? else {
return Ok(());
};
reject_unknown_keys(provider_id, map, allowed_keys)?;
require_string_keys(provider_id, map, allowed_keys)
}
fn validate_vertexai_options(provider_id: &str, task: ModelTask, options: &Value) -> Result<()> {
let Some(map) = as_object(provider_id, options)? else {
return Ok(());
};
reject_unknown_keys(
provider_id,
map,
&[
"api_token_env",
"project_id",
"location",
"publisher",
"embedding_dimensions",
],
)?;
require_string_keys(
provider_id,
map,
&["api_token_env", "project_id", "location", "publisher"],
)?;
require_embedding_dimensions(provider_id, task, map)
}
fn validate_mistralrs_options(provider_id: &str, task: ModelTask, options: &Value) -> Result<()> {
let Some(map) = as_object(provider_id, options)? else {
return Ok(());
};
reject_unknown_keys(
provider_id,
map,
&[
"isq",
"force_cpu",
"paged_attention",
"max_num_seqs",
"chat_template",
"tokenizer_json",
"embedding_dimensions",
"gguf_files",
],
)?;
require_string_keys(
provider_id,
map,
&["isq", "chat_template", "tokenizer_json"],
)?;
for key in ["force_cpu", "paged_attention"] {
if let Some(value) = map.get(key)
&& !value.is_boolean()
{
return Err(RuntimeError::Config(format!(
"Option '{}' for provider '{}' must be a boolean",
key, provider_id
)));
}
}
require_positive_u64(provider_id, map, "max_num_seqs")?;
require_embedding_dimensions(provider_id, task, map)?;
if let Some(value) = map.get("gguf_files") {
let Some(items) = value.as_array() else {
return Err(RuntimeError::Config(format!(
"Option 'gguf_files' for provider '{}' must be an array of strings",
provider_id
)));
};
if items.iter().any(|item| !item.is_string()) {
return Err(RuntimeError::Config(format!(
"Option 'gguf_files' for provider '{}' must be an array of strings",
provider_id
)));
}
}
Ok(())
}