use crate::error::LlmError;
use crate::types::{CommonParams, ProviderParams, ProviderType};
/// Extension trait for parameter types that can report the provider they are
/// associated with.
pub trait ProviderParamsExt {
    /// Returns the [`ProviderType`] associated with `self`.
    fn provider_type(&self) -> ProviderType;
}
/// Stateless helpers for validating generation parameters.
///
/// Every function returns `Ok(())` for an acceptable value and
/// `Err(LlmError::InvalidParameter(..))` carrying a human-readable message
/// otherwise.
pub struct ParameterValidator;

impl ParameterValidator {
    /// Validates a sampling temperature.
    ///
    /// Only hard errors are rejected: the value must be a non-negative number.
    /// `_min` and `_suggested_max` are accepted for signature compatibility but
    /// are not enforced. In particular, values above `_suggested_max` pass.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidParameter`] if `temp` is negative or NaN.
    pub fn validate_temperature(
        temp: f64,
        _min: f64,
        _suggested_max: f64,
        provider: &str,
    ) -> Result<(), LlmError> {
        // NaN compares false against everything, so `temp < 0.0` alone would
        // let it through. Reject it explicitly.
        if temp.is_nan() || temp < 0.0 {
            return Err(LlmError::InvalidParameter(format!(
                "temperature must be non-negative for {provider}, got {temp}"
            )));
        }
        Ok(())
    }

    /// Validates `top_p`, which must lie in the inclusive range `[0.0, 1.0]`.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidParameter`] if `top_p` is outside the range.
    /// NaN is also rejected, because `RangeInclusive::contains` is false for it.
    pub fn validate_top_p(top_p: f64) -> Result<(), LlmError> {
        if !(0.0..=1.0).contains(&top_p) {
            return Err(LlmError::InvalidParameter(
                "top_p must be between 0.0 and 1.0".to_string(),
            ));
        }
        Ok(())
    }

    /// Validates a `max_tokens` limit.
    ///
    /// Only zero is rejected. `_min` and `_suggested_max` are accepted for
    /// signature compatibility but are not enforced. In particular, values
    /// above `_suggested_max` pass.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidParameter`] if `max_tokens` is `0`.
    pub fn validate_max_tokens(
        max_tokens: u64,
        _min: u64,
        _suggested_max: u64,
        provider: &str,
    ) -> Result<(), LlmError> {
        if max_tokens == 0 {
            return Err(LlmError::InvalidParameter(format!(
                "max_tokens must be positive for {provider}, got {max_tokens}"
            )));
        }
        Ok(())
    }

    /// Validates that `value` lies in the inclusive range `[min, max]`.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidParameter`] if `value` is outside the range.
    /// This includes values that are unordered relative to the bounds, such as
    /// a NaN float.
    pub fn validate_numeric_range<T: PartialOrd + std::fmt::Display>(
        value: T,
        min: T,
        max: T,
        param_name: &str,
        provider: &str,
    ) -> Result<(), LlmError> {
        // Phrase the check as "inside the range". Incomparable values (NaN)
        // then fail both comparisons and are rejected instead of slipping
        // through an "outside the range" test.
        let in_range = value >= min && value <= max;
        if !in_range {
            return Err(LlmError::InvalidParameter(format!(
                "{param_name} must be between {min} and {max} for {provider}, got {value}"
            )));
        }
        Ok(())
    }
}
/// Converts provider-agnostic [`CommonParams`] into JSON request fragments.
pub struct ParameterMapper;

impl ParameterMapper {
    /// Builds the base JSON object for a request from `params`.
    ///
    /// `model` is always emitted. `temperature`, `max_tokens`, `top_p` and
    /// `seed` are emitted only when they are `Some`. Stop sequences are not
    /// included here; use [`ParameterMapper::map_stop_sequences`] with the
    /// provider's field name instead.
    pub fn map_common_to_json(params: &CommonParams) -> serde_json::Value {
        let mut fields = serde_json::Map::new();
        fields.insert(
            "model".to_string(),
            serde_json::Value::from(params.model.clone()),
        );

        // Inserts `key` only when the optional value is present.
        let mut insert_if_set = |key: &str, value: Option<serde_json::Value>| {
            if let Some(value) = value {
                fields.insert(key.to_string(), value);
            }
        };
        insert_if_set("temperature", params.temperature.map(serde_json::Value::from));
        insert_if_set("max_tokens", params.max_tokens.map(serde_json::Value::from));
        insert_if_set("top_p", params.top_p.map(serde_json::Value::from));
        insert_if_set("seed", params.seed.map(serde_json::Value::from));

        serde_json::Value::Object(fields)
    }

    /// Overlays the provider-specific entries of `provider` onto `base`.
    ///
    /// Entries whose value is JSON `null` are skipped. Every other entry is
    /// inserted, overwriting any existing key of the same name. If `base` is
    /// not a JSON object, it is returned unchanged.
    pub fn merge_provider_params(
        base: serde_json::Value,
        provider: &ProviderParams,
    ) -> serde_json::Value {
        match base {
            serde_json::Value::Object(mut merged) => {
                merged.extend(
                    provider
                        .params
                        .iter()
                        .filter(|(_, value)| !value.is_null())
                        .map(|(key, value)| (key.clone(), value.clone())),
                );
                serde_json::Value::Object(merged)
            }
            other => other,
        }
    }

    /// Produces a `(field_name, [..stop sequences..])` pair when
    /// `stop_sequences` is set. Returns `None` otherwise.
    pub fn map_stop_sequences(
        stop_sequences: &Option<Vec<String>>,
        field_name: &str,
    ) -> Option<(String, serde_json::Value)> {
        let sequences = stop_sequences.as_ref()?;
        Some((
            field_name.to_owned(),
            serde_json::Value::from(sequences.clone()),
        ))
    }
}
/// Translates common parameter names and values into provider-specific forms.
pub struct ParameterConverter;

impl ParameterConverter {
    /// Returns the name `provider_type` uses for the common parameter
    /// `common_name`.
    ///
    /// Gemini renames `max_tokens`, `top_p` and `stop_sequences` to
    /// `maxOutputTokens`, `topP` and `stopSequences`. OpenAI renames
    /// `stop_sequences` to `stop`. Every other combination keeps the common
    /// name unchanged.
    pub fn convert_param_name(common_name: &str, provider_type: &ProviderType) -> String {
        let renamed = match provider_type {
            ProviderType::Gemini => match common_name {
                "max_tokens" => Some("maxOutputTokens"),
                "top_p" => Some("topP"),
                "stop_sequences" => Some("stopSequences"),
                _ => None,
            },
            ProviderType::OpenAi if common_name == "stop_sequences" => Some("stop"),
            _ => None,
        };
        renamed.unwrap_or(common_name).to_string()
    }

    /// Returns a provider-specific form of `value`.
    ///
    /// The current implementation is an identity conversion: it returns an
    /// owned copy of `value`, regardless of parameter name or provider.
    pub fn convert_param_value(
        value: &serde_json::Value,
        _param_name: &str,
        _provider_type: &ProviderType,
    ) -> serde_json::Value {
        value.to_owned()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parameter_validator() {
        assert!(ParameterValidator::validate_temperature(0.7, 0.0, 2.0, "test").is_ok());
        // Values above the suggested maximum are accepted; negatives are rejected.
        assert!(ParameterValidator::validate_temperature(3.0, 0.0, 2.0, "test").is_ok());
        assert!(ParameterValidator::validate_temperature(-1.0, 0.0, 2.0, "test").is_err());
        assert!(ParameterValidator::validate_top_p(0.9).is_ok());
        assert!(ParameterValidator::validate_top_p(1.5).is_err());
        assert!(ParameterValidator::validate_max_tokens(1000, 1, 200_000, "test").is_ok());
        // Values above the suggested maximum are accepted; zero is rejected.
        assert!(ParameterValidator::validate_max_tokens(500_000, 1, 200_000, "test").is_ok());
        assert!(ParameterValidator::validate_max_tokens(0, 1, 200_000, "test").is_err());
    }

    #[test]
    fn test_validate_numeric_range() {
        assert!(ParameterValidator::validate_numeric_range(5, 1, 10, "top_k", "test").is_ok());
        assert!(ParameterValidator::validate_numeric_range(1, 1, 10, "top_k", "test").is_ok());
        assert!(ParameterValidator::validate_numeric_range(10, 1, 10, "top_k", "test").is_ok());
        assert!(ParameterValidator::validate_numeric_range(0, 1, 10, "top_k", "test").is_err());
        assert!(ParameterValidator::validate_numeric_range(11, 1, 10, "top_k", "test").is_err());
    }

    #[test]
    fn test_parameter_converter() {
        assert_eq!(
            ParameterConverter::convert_param_name("max_tokens", &ProviderType::Gemini),
            "maxOutputTokens"
        );
        assert_eq!(
            ParameterConverter::convert_param_name("top_p", &ProviderType::Gemini),
            "topP"
        );
        assert_eq!(
            ParameterConverter::convert_param_name("stop_sequences", &ProviderType::Gemini),
            "stopSequences"
        );
        assert_eq!(
            ParameterConverter::convert_param_name("temperature", &ProviderType::Gemini),
            "temperature"
        );
        assert_eq!(
            ParameterConverter::convert_param_name("max_tokens", &ProviderType::OpenAi),
            "max_tokens"
        );
        assert_eq!(
            ParameterConverter::convert_param_name("stop_sequences", &ProviderType::OpenAi),
            "stop"
        );
        assert_eq!(
            ParameterConverter::convert_param_name("stop_sequences", &ProviderType::Anthropic),
            "stop_sequences"
        );
    }

    #[test]
    fn test_common_parameter_mapping() {
        let params = CommonParams {
            model: "test-model".to_string(),
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: Some(0.9),
            stop_sequences: None,
            seed: Some(42),
        };
        let json = ParameterMapper::map_common_to_json(&params);
        assert_eq!(json["model"], "test-model");
        assert_eq!(json["max_tokens"], 1000);
        assert_eq!(json["seed"], 42);
        assert!(json.get("temperature").is_some());
        assert!(json.get("top_p").is_some());
    }

    #[test]
    fn test_common_parameter_mapping_omits_unset_fields() {
        let params = CommonParams {
            model: "m".to_string(),
            temperature: None,
            max_tokens: None,
            top_p: None,
            stop_sequences: None,
            seed: None,
        };
        let json = ParameterMapper::map_common_to_json(&params);
        assert_eq!(json, serde_json::json!({ "model": "m" }));
    }

    #[test]
    fn test_map_stop_sequences() {
        let stops = Some(vec!["END".to_string(), "STOP".to_string()]);
        let (key, value) = ParameterMapper::map_stop_sequences(&stops, "stop")
            .expect("Some input must produce Some output");
        assert_eq!(key, "stop");
        assert_eq!(value, serde_json::json!(["END", "STOP"]));
        assert!(ParameterMapper::map_stop_sequences(&None, "stop").is_none());
    }

    #[test]
    fn test_merge_provider_params_skips_null_values() {
        use std::collections::HashMap;
        let base = serde_json::json!({
            "model": "gpt-4",
            "temperature": 0.7
        });
        let mut params_map = HashMap::new();
        params_map.insert("parallel_tool_calls".to_string(), serde_json::Value::Null);
        params_map.insert("frequency_penalty".to_string(), serde_json::json!(0.1));
        params_map.insert("presence_penalty".to_string(), serde_json::Value::Null);
        let provider_params = ProviderParams { params: params_map };
        let result = ParameterMapper::merge_provider_params(base, &provider_params);
        let obj = result.as_object().expect("merged result must be an object");
        assert!(!obj.contains_key("parallel_tool_calls"));
        assert!(!obj.contains_key("presence_penalty"));
        assert_eq!(result["frequency_penalty"], 0.1);
        assert_eq!(result["model"], "gpt-4");
        assert_eq!(result["temperature"], 0.7);
    }
}