use std::collections::HashSet;
use tracing::warn;
use crate::types::tools::ToolChoice;
/// Warns about configuration keys that the target config does not recognise.
///
/// Every key in `provided_keys` that is absent from `valid_keys` is reported in a single
/// `warn!` event, together with the full list of valid keys and `config_name`. Unknown
/// keys are only logged, never rejected; when there are none, nothing is emitted.
///
/// # Arguments
///
/// * `provided_keys` - keys the caller actually supplied.
/// * `valid_keys` - keys accepted by the configuration named `config_name`.
/// * `config_name` - name recorded in the `config` field of the warning event.
pub fn validate_config_keys(
    provided_keys: &HashSet<&str>,
    valid_keys: &HashSet<&str>,
    config_name: &str,
) {
    // `difference` yields `&&str`; `copied()` flattens to `&str` so we allocate only once.
    let mut invalid_sorted: Vec<&str> = provided_keys.difference(valid_keys).copied().collect();
    if invalid_sorted.is_empty() {
        return;
    }

    // HashSet iteration order is unspecified, so sort both lists to make the log output
    // deterministic. The elements come from sets (no duplicates), so an unstable sort
    // produces the same order a stable one would.
    invalid_sorted.sort_unstable();
    let mut valid_sorted: Vec<&str> = valid_keys.iter().copied().collect();
    valid_sorted.sort_unstable();

    warn!(
        config = config_name,
        invalid_keys = ?invalid_sorted,
        valid_keys = ?valid_sorted,
        "Invalid configuration parameters provided"
    );
}
/// Emits a warning when a `ToolChoice` is handed to a provider that cannot honour it.
///
/// Only the presence of `tool_choice` matters; its contents are not inspected. Passing
/// `None` is a no-op. `provider_name` is recorded in the `provider` field of the event.
pub fn warn_on_tool_choice_not_supported(tool_choice: Option<&ToolChoice>, provider_name: &str) {
    // Nothing to report when the caller left the tool choice unset.
    if tool_choice.is_none() {
        return;
    }

    warn!(
        provider = provider_name,
        "A ToolChoice was provided to this provider but is not supported and will be ignored"
    );
}
/// Allow-lists of configuration keys recognised for each model provider.
///
/// Every provider set is the shared base set (see `base_config_keys`) extended with that
/// provider's own keys. Each function builds and returns a fresh, owned `HashSet`.
pub mod config_keys {
    use std::collections::HashSet;

    /// Keys shared by every provider's configuration.
    const BASE_KEYS: &[&str] = &[
        "model_id",
        "max_tokens",
        "temperature",
        "top_p",
        "stop_sequences",
    ];

    /// Returns the shared base keys combined with `extra` as a new set.
    fn with_base_keys(extra: &[&'static str]) -> HashSet<&'static str> {
        BASE_KEYS.iter().chain(extra).copied().collect()
    }

    /// Keys accepted by every provider.
    pub fn base_config_keys() -> HashSet<&'static str> {
        with_base_keys(&[])
    }

    /// Keys accepted by the Bedrock provider: the base keys plus the Bedrock-specific
    /// keys listed below.
    pub fn bedrock_config_keys() -> HashSet<&'static str> {
        with_base_keys(&[
            "guardrail_id",
            "guardrail_version",
            "guardrail_trace",
            "cache_prompt",
            "cache_tools",
            "include_tool_result_status",
            "additional_request_fields",
            "additional_response_field_paths",
        ])
    }

    /// Keys accepted by the OpenAI provider: the base keys plus the OpenAI-specific
    /// keys listed below.
    pub fn openai_config_keys() -> HashSet<&'static str> {
        with_base_keys(&[
            "api_key",
            "base_url",
            "organization",
            "frequency_penalty",
            "presence_penalty",
            "seed",
            "response_format",
        ])
    }

    /// Keys accepted by the Anthropic provider: the base keys plus `api_key` and `base_url`.
    pub fn anthropic_config_keys() -> HashSet<&'static str> {
        with_base_keys(&["api_key", "base_url"])
    }

    /// Keys accepted by the Ollama provider: the base keys plus `host` and `timeout`.
    pub fn ollama_config_keys() -> HashSet<&'static str> {
        with_base_keys(&["host", "timeout"])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // `validate_config_keys` and `warn_on_tool_choice_not_supported` only emit `tracing`
    // events. No subscriber is installed here, so these tests can only check that the
    // calls complete without panicking for each input shape.

    #[test]
    fn test_validate_config_keys_no_invalid() {
        let provided = HashSet::from(["model_id", "max_tokens"]);
        let valid = HashSet::from(["model_id", "max_tokens", "temperature"]);
        validate_config_keys(&provided, &valid, "TestConfig");
    }

    #[test]
    fn test_validate_config_keys_with_invalid() {
        let provided = HashSet::from(["model_id", "invalid_key", "another_invalid"]);
        let valid = HashSet::from(["model_id", "max_tokens"]);
        validate_config_keys(&provided, &valid, "TestConfig");
    }

    #[test]
    fn test_validate_config_keys_empty_inputs() {
        let empty: HashSet<&str> = HashSet::new();
        validate_config_keys(&empty, &empty, "TestConfig");
    }

    #[test]
    fn test_warn_on_tool_choice_not_supported_none() {
        warn_on_tool_choice_not_supported(None, "TestProvider");
    }

    #[test]
    fn test_warn_on_tool_choice_not_supported_some() {
        let tool_choice = ToolChoice::auto();
        warn_on_tool_choice_not_supported(Some(&tool_choice), "TestProvider");
    }

    #[test]
    fn test_bedrock_config_keys() {
        let keys = config_keys::bedrock_config_keys();
        assert!(keys.contains("model_id"));
        assert!(keys.contains("guardrail_id"));
        assert!(keys.contains("cache_prompt"));
    }

    #[test]
    fn test_openai_config_keys() {
        let keys = config_keys::openai_config_keys();
        assert!(keys.contains("model_id"));
        assert!(keys.contains("api_key"));
        assert!(keys.contains("response_format"));
        assert!(!keys.contains("guardrail_id"));
    }

    #[test]
    fn test_anthropic_config_keys() {
        let keys = config_keys::anthropic_config_keys();
        assert!(keys.contains("model_id"));
        assert!(keys.contains("api_key"));
        assert!(keys.contains("base_url"));
        assert!(!keys.contains("organization"));
    }

    #[test]
    fn test_ollama_config_keys() {
        let keys = config_keys::ollama_config_keys();
        assert!(keys.contains("model_id"));
        assert!(keys.contains("host"));
        assert!(keys.contains("timeout"));
        assert!(!keys.contains("api_key"));
    }

    #[test]
    fn test_provider_config_keys_include_base_keys() {
        let base = config_keys::base_config_keys();
        let providers = vec![
            ("bedrock", config_keys::bedrock_config_keys()),
            ("openai", config_keys::openai_config_keys()),
            ("anthropic", config_keys::anthropic_config_keys()),
            ("ollama", config_keys::ollama_config_keys()),
        ];
        for (name, keys) in providers {
            assert!(
                keys.is_superset(&base),
                "{} config keys must include every base key",
                name
            );
        }
    }
}