use crate::api::ModelTask;
use crate::error::{Result, RuntimeError};
use serde_json::Value;
/// Validates the free-form `options` JSON supplied for a provider.
///
/// Dispatches on `provider_id` to the validator that knows which option keys
/// that provider accepts. Provider ids that are not listed here are accepted
/// without inspecting `options` at all.
///
/// # Errors
///
/// Propagates the error of the selected validator (a `RuntimeError::Config`
/// describing the offending option).
pub fn validate_provider_options(
    provider_id: &str,
    task: ModelTask,
    options: &Value,
) -> Result<()> {
    // String-valued option keys accepted by each family of providers.
    const API_KEY_ONLY: &[&str] = &["api_key_env"];
    const GEMINI_KEYS: &[&str] = &["api_key_env", "api_version"];
    const COHERE_KEYS: &[&str] = &["api_key_env", "input_type"];
    const ANTHROPIC_KEYS: &[&str] = &["api_key_env", "anthropic_version"];
    const LOCAL_CACHE_KEYS: &[&str] = &["cache_dir"];

    match provider_id {
        // String options plus an optional `embedding_dimensions`.
        "remote/openai" | "remote/mistral" | "remote/voyageai" => {
            validate_with_embedding_dimensions(provider_id, task, options, API_KEY_ONLY)
        }
        "remote/gemini" => {
            validate_with_embedding_dimensions(provider_id, task, options, GEMINI_KEYS)
        }
        "remote/cohere" => {
            validate_with_embedding_dimensions(provider_id, task, options, COHERE_KEYS)
        }

        // String options only.
        "remote/anthropic" => validate_string_keys_only(provider_id, options, ANTHROPIC_KEYS),
        "local/candle" | "local/fastembed" => {
            validate_string_keys_only(provider_id, options, LOCAL_CACHE_KEYS)
        }

        // Providers with a dedicated validator.
        "remote/azure-openai" => validate_azure_openai_options(provider_id, task, options),
        "remote/vertexai" => validate_vertexai_options(provider_id, task, options),
        "local/mistralrs" => validate_mistralrs_options(provider_id, task, options),

        // Any other provider id: nothing is checked.
        _ => Ok(()),
    }
}
/// Interprets `options` as an optional JSON object.
///
/// Returns `Ok(None)` for JSON `null`, `Ok(Some(map))` for a JSON object, and
/// a `RuntimeError::Config` naming `provider_id` for every other JSON value.
fn as_object<'a>(
    provider_id: &str,
    options: &'a Value,
) -> Result<Option<&'a serde_json::Map<String, Value>>> {
    // `null` means "no options supplied", which is always acceptable.
    if options.is_null() {
        return Ok(None);
    }

    options.as_object().map(Some).ok_or_else(|| {
        RuntimeError::Config(format!(
            "Options for provider '{}' must be a JSON object or null",
            provider_id
        ))
    })
}
/// Fails if `map` contains a key that is not listed in `allowed`.
///
/// The error names the first such key in the map's iteration order together
/// with `provider_id`.
fn reject_unknown_keys(
    provider_id: &str,
    map: &serde_json::Map<String, Value>,
    allowed: &[&str],
) -> Result<()> {
    let first_unknown = map.keys().find(|name| !allowed.contains(&name.as_str()));

    match first_unknown {
        Some(name) => Err(RuntimeError::Config(format!(
            "Unknown option '{}' for provider '{}'",
            name, provider_id
        ))),
        None => Ok(()),
    }
}
/// Ensures each of `keys` that is present in `map` holds a JSON string.
///
/// Keys missing from `map` are fine. The error names the first key (in the
/// order of `keys`) whose value is present but not a string.
fn require_string_keys(
    provider_id: &str,
    map: &serde_json::Map<String, Value>,
    keys: &[&str],
) -> Result<()> {
    let non_string = keys
        .iter()
        .find(|name| map.get(**name).is_some_and(|value| !value.is_string()));

    if let Some(name) = non_string {
        return Err(RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be a string",
            name, provider_id
        )));
    }
    Ok(())
}
/// Ensures `key`, when present in `map`, holds an integer greater than zero.
///
/// A missing key is accepted. A value for which `Value::as_u64` yields nothing
/// is reported as "must be a positive integer"; an explicit `0` is reported as
/// "must be greater than 0".
fn require_positive_u64(
    provider_id: &str,
    map: &serde_json::Map<String, Value>,
    key: &str,
) -> Result<()> {
    let Some(raw) = map.get(key) else {
        return Ok(());
    };

    match raw.as_u64() {
        None => Err(RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be a positive integer",
            key, provider_id
        ))),
        Some(0) => Err(RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be greater than 0",
            key, provider_id
        ))),
        Some(_) => Ok(()),
    }
}
/// Validates the optional `embedding_dimensions` option.
///
/// When the key is absent nothing is checked. When present, the value must be
/// a positive integer (checked first) and `task` must be `ModelTask::Embed`.
fn require_embedding_dimensions(
    provider_id: &str,
    task: ModelTask,
    map: &serde_json::Map<String, Value>,
) -> Result<()> {
    if !map.contains_key("embedding_dimensions") {
        return Ok(());
    }

    // The value check runs before the task check, so a malformed value is
    // reported even when the task is also wrong.
    require_positive_u64(provider_id, map, "embedding_dimensions")?;

    if task == ModelTask::Embed {
        Ok(())
    } else {
        Err(RuntimeError::Config(
            "Option 'embedding_dimensions' is only valid for embed tasks".to_string(),
        ))
    }
}
/// Validates options for providers whose every option is a string.
///
/// `options` may be `null` (nothing to check) or an object containing only
/// keys from `allowed_keys`, each holding a string when present.
fn validate_string_keys_only(
    provider_id: &str,
    options: &Value,
    allowed_keys: &[&str],
) -> Result<()> {
    match as_object(provider_id, options)? {
        None => Ok(()),
        Some(entries) => {
            reject_unknown_keys(provider_id, entries, allowed_keys)?;
            require_string_keys(provider_id, entries, allowed_keys)
        }
    }
}
/// Validates options made of string keys plus an optional
/// `embedding_dimensions`.
///
/// `options` may be `null` (nothing to check). Otherwise it must be an object
/// whose keys are drawn from `string_keys` or are `embedding_dimensions`; the
/// former must hold strings and the latter is checked by
/// `require_embedding_dimensions`.
fn validate_with_embedding_dimensions(
    provider_id: &str,
    task: ModelTask,
    options: &Value,
    string_keys: &[&str],
) -> Result<()> {
    match as_object(provider_id, options)? {
        None => Ok(()),
        Some(entries) => {
            // Everything in `string_keys` is allowed, plus the one numeric option.
            let permitted: Vec<&str> = string_keys
                .iter()
                .copied()
                .chain(std::iter::once("embedding_dimensions"))
                .collect();

            reject_unknown_keys(provider_id, entries, &permitted)?;
            require_string_keys(provider_id, entries, string_keys)?;
            require_embedding_dimensions(provider_id, task, entries)
        }
    }
}
fn validate_azure_openai_options(
provider_id: &str,
task: ModelTask,
options: &Value,
) -> Result<()> {
let Some(map) = as_object(provider_id, options)? else {
return Ok(());
};
reject_unknown_keys(
provider_id,
map,
&[
"api_key_env",
"resource_name",
"api_version",
"embedding_dimensions",
],
)?;
require_string_keys(
provider_id,
map,
&["api_key_env", "resource_name", "api_version"],
)?;
require_embedding_dimensions(provider_id, task, map)
}
fn validate_vertexai_options(provider_id: &str, task: ModelTask, options: &Value) -> Result<()> {
let Some(map) = as_object(provider_id, options)? else {
return Ok(());
};
reject_unknown_keys(
provider_id,
map,
&[
"api_token_env",
"project_id",
"location",
"publisher",
"embedding_dimensions",
],
)?;
require_string_keys(
provider_id,
map,
&["api_token_env", "project_id", "location", "publisher"],
)?;
require_embedding_dimensions(provider_id, task, map)
}
fn validate_mistralrs_options(provider_id: &str, task: ModelTask, options: &Value) -> Result<()> {
let Some(map) = as_object(provider_id, options)? else {
return Ok(());
};
reject_unknown_keys(
provider_id,
map,
&[
"isq",
"force_cpu",
"paged_attention",
"max_num_seqs",
"chat_template",
"tokenizer_json",
"embedding_dimensions",
"gguf_files",
"dtype",
"pipeline",
"diffusion_loader_type",
"speech_loader_type",
],
)?;
let pipeline = if let Some(value) = map.get("pipeline") {
let s = value.as_str().ok_or_else(|| {
RuntimeError::Config(format!(
"Option 'pipeline' for provider '{}' must be a string",
provider_id
))
})?;
let valid = ["text", "vision", "diffusion", "speech"];
if !valid.contains(&s) {
return Err(RuntimeError::Config(format!(
"Option 'pipeline' for provider '{}' must be one of: text, vision, diffusion, speech",
provider_id
)));
}
s
} else {
"text"
};
require_string_keys(provider_id, map, &["dtype"])?;
if let Some(value) = map.get("dtype") {
if let Some(s) = value.as_str() {
let valid = ["auto", "f16", "bf16", "f32"];
if !valid.contains(&s.to_lowercase().as_str()) {
return Err(RuntimeError::Config(format!(
"Option 'dtype' for provider '{}' must be one of: auto, f16, bf16, f32",
provider_id
)));
}
}
}
if let Some(value) = map.get("force_cpu")
&& !value.is_boolean()
{
return Err(RuntimeError::Config(format!(
"Option 'force_cpu' for provider '{}' must be a boolean",
provider_id
)));
}
match pipeline {
"vision" => {
if map.contains_key("gguf_files") {
return Err(RuntimeError::Config(
"Option 'gguf_files' is not supported for the vision pipeline".to_string(),
));
}
if map.contains_key("embedding_dimensions") {
return Err(RuntimeError::Config(
"Option 'embedding_dimensions' is not supported for the vision pipeline"
.to_string(),
));
}
require_string_keys(
provider_id,
map,
&["isq", "chat_template", "tokenizer_json"],
)?;
if let Some(value) = map.get("paged_attention")
&& !value.is_boolean()
{
return Err(RuntimeError::Config(format!(
"Option 'paged_attention' for provider '{}' must be a boolean",
provider_id
)));
}
require_positive_u64(provider_id, map, "max_num_seqs")?;
}
"diffusion" => {
if let Some(value) = map.get("diffusion_loader_type") {
let s = value.as_str().ok_or_else(|| {
RuntimeError::Config(format!(
"Option 'diffusion_loader_type' for provider '{}' must be a string",
provider_id
))
})?;
let valid = ["flux", "flux_offloaded"];
if !valid.contains(&s) {
return Err(RuntimeError::Config(format!(
"Option 'diffusion_loader_type' for provider '{}' must be one of: flux, flux_offloaded",
provider_id
)));
}
}
for key in [
"isq",
"paged_attention",
"max_num_seqs",
"chat_template",
"tokenizer_json",
"embedding_dimensions",
"gguf_files",
"speech_loader_type",
] {
if map.contains_key(key) {
return Err(RuntimeError::Config(format!(
"Option '{}' is not supported for the diffusion pipeline",
key
)));
}
}
}
"speech" => {
if let Some(value) = map.get("speech_loader_type") {
let s = value.as_str().ok_or_else(|| {
RuntimeError::Config(format!(
"Option 'speech_loader_type' for provider '{}' must be a string",
provider_id
))
})?;
let valid = ["dia"];
if !valid.contains(&s) {
return Err(RuntimeError::Config(format!(
"Option 'speech_loader_type' for provider '{}' must be one of: dia",
provider_id
)));
}
}
for key in [
"isq",
"paged_attention",
"max_num_seqs",
"chat_template",
"tokenizer_json",
"embedding_dimensions",
"gguf_files",
"diffusion_loader_type",
] {
if map.contains_key(key) {
return Err(RuntimeError::Config(format!(
"Option '{}' is not supported for the speech pipeline",
key
)));
}
}
}
_ => {
require_string_keys(
provider_id,
map,
&["isq", "chat_template", "tokenizer_json"],
)?;
if let Some(value) = map.get("paged_attention")
&& !value.is_boolean()
{
return Err(RuntimeError::Config(format!(
"Option 'paged_attention' for provider '{}' must be a boolean",
provider_id
)));
}
require_positive_u64(provider_id, map, "max_num_seqs")?;
require_embedding_dimensions(provider_id, task, map)?;
if let Some(value) = map.get("gguf_files") {
let Some(items) = value.as_array() else {
return Err(RuntimeError::Config(format!(
"Option 'gguf_files' for provider '{}' must be an array of strings",
provider_id
)));
};
if items.iter().any(|item| !item.is_string()) {
return Err(RuntimeError::Config(format!(
"Option 'gguf_files' for provider '{}' must be an array of strings",
provider_id
)));
}
}
}
}
Ok(())
}