use std::fmt::Display;
// ----- Sampling parameters (f32, inclusive bounds unless noted) -----

/// Lower bound for `temperature`.
pub const MIN_TEMPERATURE: f32 = 0.0;
/// Upper bound for `temperature`.
pub const MAX_TEMPERATURE: f32 = 2.0;
/// `(min, max)` pair for `temperature`; shaped to fit `validate_range`.
pub const TEMPERATURE_RANGE: (f32, f32) = (MIN_TEMPERATURE, MAX_TEMPERATURE);

/// Lower bound for `top_p`.
pub const MIN_TOP_P: f32 = 0.0;
/// Upper bound for `top_p`.
pub const MAX_TOP_P: f32 = 1.0;
/// `(min, max)` pair for `top_p`.
pub const TOP_P_RANGE: (f32, f32) = (MIN_TOP_P, MAX_TOP_P);

/// Lower bound for `min_p`.
pub const MIN_MIN_P: f32 = 0.0;
/// Upper bound for `min_p`.
pub const MAX_MIN_P: f32 = 1.0;
/// `(min, max)` pair for `min_p`.
pub const MIN_P_RANGE: (f32, f32) = (MIN_MIN_P, MAX_MIN_P);

/// Lower bound for `frequency_penalty`.
pub const MIN_FREQUENCY_PENALTY: f32 = -2.0;
/// Upper bound for `frequency_penalty`.
pub const MAX_FREQUENCY_PENALTY: f32 = 2.0;
/// `(min, max)` pair for `frequency_penalty`.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (MIN_FREQUENCY_PENALTY, MAX_FREQUENCY_PENALTY);

/// Lower bound for `presence_penalty`.
pub const MIN_PRESENCE_PENALTY: f32 = -2.0;
/// Upper bound for `presence_penalty`.
pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// `(min, max)` pair for `presence_penalty`.
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);

/// Lower bound for `length_penalty`.
pub const MIN_LENGTH_PENALTY: f32 = -2.0;
/// Upper bound for `length_penalty`.
pub const MAX_LENGTH_PENALTY: f32 = 2.0;
/// `(min, max)` pair for `length_penalty`.
pub const LENGTH_PENALTY_RANGE: (f32, f32) = (MIN_LENGTH_PENALTY, MAX_LENGTH_PENALTY);

// ----- Log-probability counts -----

/// Lower bound for `top_logprobs`.
pub const MIN_TOP_LOGPROBS: u8 = 0;
/// Upper bound for `top_logprobs`.
pub const MAX_TOP_LOGPROBS: u8 = 20;
/// Lower bound for `logprobs`.
pub const MIN_LOGPROBS: u8 = 0;
/// Upper bound for `logprobs`.
pub const MAX_LOGPROBS: u8 = 5;

// ----- Choice counts -----

/// Lower bound for `n`.
pub const MIN_N: u8 = 1;
/// Upper bound for `n`.
pub const MAX_N: u8 = 128;
/// `(min, max)` pair for `n`.
pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N);
/// Ceiling on `batch_size * n` enforced by `validate_total_choices`.
pub const MAX_TOTAL_CHOICES: usize = 128;

// ----- Logit bias -----

/// Lowest bias accepted for a single token.
pub const MIN_LOGIT_BIAS: f32 = -100.0;
/// Highest bias accepted for a single token.
pub const MAX_LOGIT_BIAS: f32 = 100.0;

// ----- best_of -----

/// Lower bound for `best_of`.
pub const MIN_BEST_OF: u8 = 0;
/// Upper bound for `best_of`.
pub const MAX_BEST_OF: u8 = 20;
/// `(min, max)` pair for `best_of`.
pub const BEST_OF_RANGE: (u8, u8) = (MIN_BEST_OF, MAX_BEST_OF);

// ----- Collection / string size limits -----

/// Maximum number of entries allowed in a `stop` array.
pub const MAX_STOP_SEQUENCES: usize = 32;
/// Maximum number of tools per request.
pub const MAX_TOOLS: usize = 1_536;
/// Maximum tool function-name length, compared against `str::len` (UTF-8 bytes).
pub const MAX_FUNCTION_NAME_LENGTH: usize = 96;

// ----- Repetition penalty -----

/// Exclusive lower bound for `repetition_penalty` (the value itself is rejected).
pub const MIN_REPETITION_PENALTY: f32 = 0.0;
/// Inclusive upper bound for `repetition_penalty`.
pub const MAX_REPETITION_PENALTY: f32 = 2.0;
/// Fails if any unsupported request fields were collected.
///
/// # Errors
///
/// Returns an error listing every unsupported field name, each wrapped in
/// backticks and separated by `", "`. Names are sorted so the message is the
/// same for the same request. `HashMap` iteration order is randomized per
/// process, so the unsorted version could produce different text each time.
pub fn validate_no_unsupported_fields(
    unsupported_fields: &std::collections::HashMap<String, serde_json::Value>,
) -> Result<(), anyhow::Error> {
    if unsupported_fields.is_empty() {
        return Ok(());
    }
    let mut names: Vec<&str> = unsupported_fields.keys().map(String::as_str).collect();
    names.sort_unstable();
    let fields: Vec<String> = names.iter().map(|name| format!("`{name}`")).collect();
    anyhow::bail!("Unsupported parameter(s): {}", fields.join(", "));
}
/// Validates the optional `response_format` request field.
///
/// `None`, `text` and `json_object` are accepted unconditionally. For
/// `json_schema`, the schema `name` must be non-empty (checked first) and
/// `schema` must be present.
pub fn validate_response_format(
    response_format: &Option<dynamo_async_openai::types::ResponseFormat>,
) -> Result<(), anyhow::Error> {
    use dynamo_async_openai::types::ResponseFormat;

    let json_schema = match response_format {
        Some(ResponseFormat::JsonSchema { json_schema }) => json_schema,
        // Nothing further to inspect for these variants.
        None | Some(ResponseFormat::Text | ResponseFormat::JsonObject) => return Ok(()),
    };

    anyhow::ensure!(
        !json_schema.name.is_empty(),
        "`response_format.json_schema.name` cannot be empty"
    );
    anyhow::ensure!(
        json_schema.schema.is_some(),
        "`response_format.json_schema.schema` is required when `response_format.type` is `json_schema`"
    );
    Ok(())
}
/// Checks that `temperature`, when present, lies in the inclusive range
/// `[MIN_TEMPERATURE, MAX_TEMPERATURE]`.
///
/// NaN is rejected, because `RangeInclusive::contains` returns `false` for it.
pub fn validate_temperature(temperature: Option<f32>) -> Result<(), anyhow::Error> {
    let Some(temp) = temperature else {
        return Ok(());
    };
    if (MIN_TEMPERATURE..=MAX_TEMPERATURE).contains(&temp) {
        return Ok(());
    }
    Err(anyhow::anyhow!(
        "Temperature must be between {} and {}, got {}",
        MIN_TEMPERATURE,
        MAX_TEMPERATURE,
        temp
    ))
}
/// Checks that `top_p`, when present, lies in the inclusive range
/// `[MIN_TOP_P, MAX_TOP_P]`. NaN is rejected.
pub fn validate_top_p(top_p: Option<f32>) -> Result<(), anyhow::Error> {
    match top_p {
        Some(p) if !(MIN_TOP_P..=MAX_TOP_P).contains(&p) => Err(anyhow::anyhow!(
            "Top_p must be between {} and {}, got {}",
            MIN_TOP_P,
            MAX_TOP_P,
            p
        )),
        _ => Ok(()),
    }
}
/// Accepts a missing `top_k`, exactly `-1`, or any value `>= 1`.
///
/// Zero and every other negative value are rejected.
pub fn validate_top_k(top_k: Option<i32>) -> Result<(), anyhow::Error> {
    match top_k {
        None | Some(-1 | 1..) => Ok(()),
        Some(_) => Err(anyhow::anyhow!(
            "Top_k must be null, -1, or greater than or equal to 1"
        )),
    }
}
/// Rejects requests that set both `temperature` and `top_p` to non-default
/// values.
///
/// A value of exactly `1.0` counts as "not set", so either parameter may be
/// sent as `1.0` alongside a customized value of the other.
pub fn validate_temperature_top_p_exclusion(
    temperature: Option<f32>,
    top_p: Option<f32>,
) -> Result<(), anyhow::Error> {
    let is_customized = |value: Option<f32>| value.is_some_and(|v| v != 1.0);
    if is_customized(temperature) && is_customized(top_p) {
        anyhow::bail!("Only one of temperature or top_p should be set (not both)");
    }
    Ok(())
}
/// Checks that `frequency_penalty`, when present, lies in the inclusive range
/// `[MIN_FREQUENCY_PENALTY, MAX_FREQUENCY_PENALTY]`. NaN is rejected.
pub fn validate_frequency_penalty(frequency_penalty: Option<f32>) -> Result<(), anyhow::Error> {
    let Some(penalty) = frequency_penalty else {
        return Ok(());
    };
    let allowed = MIN_FREQUENCY_PENALTY..=MAX_FREQUENCY_PENALTY;
    if allowed.contains(&penalty) {
        Ok(())
    } else {
        Err(anyhow::anyhow!(
            "Frequency penalty must be between {} and {}, got {}",
            MIN_FREQUENCY_PENALTY,
            MAX_FREQUENCY_PENALTY,
            penalty
        ))
    }
}
/// Checks that `presence_penalty`, when present, lies in the inclusive range
/// `[MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY]`. NaN is rejected.
pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), anyhow::Error> {
    match presence_penalty {
        Some(penalty) if !(MIN_PRESENCE_PENALTY..=MAX_PRESENCE_PENALTY).contains(&penalty) => {
            Err(anyhow::anyhow!(
                "Presence penalty must be between {} and {}, got {}",
                MIN_PRESENCE_PENALTY,
                MAX_PRESENCE_PENALTY,
                penalty
            ))
        }
        _ => Ok(()),
    }
}
/// Checks that `repetition_penalty`, when present, lies in the half-open range
/// `(MIN_REPETITION_PENALTY, MAX_REPETITION_PENALTY]`.
///
/// The lower bound itself is rejected and the upper bound is accepted.
///
/// # Errors
///
/// Fails for out-of-range values and for NaN. Comparisons with NaN are always
/// false, so a plain `<= min || > max` test would let NaN through. The other
/// float validators reject NaN via `RangeInclusive::contains`, and this check
/// now matches them.
pub fn validate_repetition_penalty(repetition_penalty: Option<f32>) -> Result<(), anyhow::Error> {
    if let Some(penalty) = repetition_penalty
        && (penalty.is_nan()
            || penalty <= MIN_REPETITION_PENALTY
            || penalty > MAX_REPETITION_PENALTY)
    {
        anyhow::bail!(
            "Repetition penalty must be between {} and {}, got {}",
            MIN_REPETITION_PENALTY,
            MAX_REPETITION_PENALTY,
            penalty
        );
    }
    Ok(())
}
/// Checks that `min_p`, when present, lies in the inclusive range
/// `[MIN_MIN_P, MAX_MIN_P]`. NaN is rejected.
pub fn validate_min_p(min_p: Option<f32>) -> Result<(), anyhow::Error> {
    let out_of_range = min_p.filter(|p| !(MIN_MIN_P..=MAX_MIN_P).contains(p));
    if let Some(p) = out_of_range {
        anyhow::bail!(
            "Min_p must be between {} and {}, got {}",
            MIN_MIN_P,
            MAX_MIN_P,
            p
        );
    }
    Ok(())
}
pub fn validate_logit_bias(
logit_bias: &Option<std::collections::HashMap<String, serde_json::Value>>,
) -> Result<(), anyhow::Error> {
let logit_bias = match logit_bias {
Some(val) => val,
None => return Ok(()),
};
for (token, bias_value) in logit_bias {
let bias = bias_value.as_f64().ok_or_else(|| {
anyhow::anyhow!(
"Logit bias value for token '{}' must be a number, got {:?}",
token,
bias_value
)
})? as f32;
if !(MIN_LOGIT_BIAS..=MAX_LOGIT_BIAS).contains(&bias) {
anyhow::bail!(
"Logit bias for token '{}' must be between {} and {}, got {}",
token,
MIN_LOGIT_BIAS,
MAX_LOGIT_BIAS,
bias
);
}
}
Ok(())
}
pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = n
&& !(MIN_N..=MAX_N).contains(&value)
{
anyhow::bail!("n must be between {} and {}, got {}", MIN_N, MAX_N, value);
}
Ok(())
}
/// Rejects requests whose `batch_size * n` exceeds `MAX_TOTAL_CHOICES`.
///
/// The product is computed with plain `usize` multiplication.
pub fn validate_total_choices(batch_size: usize, n: u8) -> Result<(), anyhow::Error> {
    let total_choices = batch_size * usize::from(n);
    if total_choices <= MAX_TOTAL_CHOICES {
        return Ok(());
    }
    Err(anyhow::anyhow!(
        "Total choices (batch_size × n = {} × {} = {}) exceeds maximum of {}",
        batch_size,
        n,
        total_choices,
        MAX_TOTAL_CHOICES
    ))
}
pub fn validate_n_with_temperature(
n: Option<u8>,
temperature: Option<f32>,
) -> Result<(), anyhow::Error> {
if let Some(n_value) = n
&& n_value > 1
{
let temp = temperature.unwrap_or(1.0);
if temp == 0.0 {
anyhow::bail!(
"When n > 1, temperature must be greater than 0 to ensure diverse outputs. Got n={}, temperature={}",
n_value,
temp
);
}
}
Ok(())
}
/// Rejects a model name that is empty or consists only of whitespace.
pub fn validate_model(model: &str) -> Result<(), anyhow::Error> {
    anyhow::ensure!(!model.trim().is_empty(), "Model cannot be empty");
    Ok(())
}
pub fn validate_user(user: Option<&str>) -> Result<(), anyhow::Error> {
if let Some(user_id) = user
&& user_id.trim().is_empty()
{
anyhow::bail!("User ID cannot be empty");
}
Ok(())
}
pub fn validate_stop(stop: &Option<dynamo_async_openai::types::Stop>) -> Result<(), anyhow::Error> {
if let Some(stop_value) = stop {
match stop_value {
dynamo_async_openai::types::Stop::String(s) => {
if s.is_empty() {
anyhow::bail!("Stop sequence cannot be empty");
}
}
dynamo_async_openai::types::Stop::StringArray(sequences) => {
if sequences.is_empty() {
anyhow::bail!("Stop sequences array cannot be empty");
}
if sequences.len() > MAX_STOP_SEQUENCES {
anyhow::bail!(
"Maximum of {} stop sequences allowed, got {}",
MAX_STOP_SEQUENCES,
sequences.len()
);
}
for (i, sequence) in sequences.iter().enumerate() {
if sequence.is_empty() {
anyhow::bail!("Stop sequence at index {} cannot be empty", i);
}
}
}
}
}
Ok(())
}
/// Requires at least one chat message. Message contents are not inspected.
pub fn validate_messages(
    messages: &[dynamo_async_openai::types::ChatCompletionRequestMessage],
) -> Result<(), anyhow::Error> {
    anyhow::ensure!(!messages.is_empty(), "Messages array cannot be empty");
    Ok(())
}
pub fn validate_top_logprobs(top_logprobs: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = top_logprobs
&& !(0..=20).contains(&value)
{
anyhow::bail!(
"Top_logprobs must be between 0 and {}, got {}",
MAX_TOP_LOGPROBS,
value
);
}
Ok(())
}
pub fn validate_tools(
tools: &Option<&[dynamo_async_openai::types::ChatCompletionTool]>,
) -> Result<(), anyhow::Error> {
let tools = match tools {
Some(val) => val,
None => return Ok(()),
};
if tools.len() > MAX_TOOLS {
anyhow::bail!(
"Maximum of {} tools are supported, got {}",
MAX_TOOLS,
tools.len()
);
}
for (i, tool) in tools.iter().enumerate() {
if tool.function.name.len() > MAX_FUNCTION_NAME_LENGTH {
anyhow::bail!(
"Function name at index {} exceeds {} character limit, got {} characters",
i,
MAX_FUNCTION_NAME_LENGTH,
tool.function.name.len()
);
}
if tool.function.name.trim().is_empty() {
anyhow::bail!("Function name at index {} cannot be empty", i);
}
}
Ok(())
}
/// Placeholder validator for `reasoning_effort`.
///
/// Performs no checks: every value, including `None`, is accepted.
pub fn validate_reasoning_effort(
    _reasoning_effort: &Option<dynamo_async_openai::types::ReasoningEffort>,
) -> Result<(), anyhow::Error> {
    // Intentionally empty; the argument is unused.
    Ok(())
}

/// Placeholder validator for `service_tier`.
///
/// Performs no checks: every value, including `None`, is accepted.
pub fn validate_service_tier(
    _service_tier: &Option<dynamo_async_openai::types::ServiceTier>,
) -> Result<(), anyhow::Error> {
    // Intentionally empty; the argument is unused.
    Ok(())
}
/// Rejects empty prompts in each of the four `Prompt` shapes.
///
/// - a single string must be non-empty;
/// - a string array must be non-empty and contain no empty strings;
/// - an integer array must be non-empty (its elements are not inspected);
/// - an array of integer arrays must be non-empty and contain no empty inner
///   arrays.
///
/// For arrays, the error reports the index of the first empty element.
pub fn validate_prompt(prompt: &dynamo_async_openai::types::Prompt) -> Result<(), anyhow::Error> {
    use dynamo_async_openai::types::Prompt;

    match prompt {
        Prompt::String(text) => {
            anyhow::ensure!(!text.is_empty(), "Prompt string cannot be empty");
        }
        Prompt::StringArray(texts) => {
            anyhow::ensure!(!texts.is_empty(), "Prompt string array cannot be empty");
            if let Some(i) = texts.iter().position(|t| t.is_empty()) {
                anyhow::bail!("Prompt string at index {} cannot be empty", i);
            }
        }
        Prompt::IntegerArray(ids) => {
            anyhow::ensure!(!ids.is_empty(), "Prompt integer array cannot be empty");
        }
        Prompt::ArrayOfIntegerArray(batches) => {
            anyhow::ensure!(
                !batches.is_empty(),
                "Prompt array of integer arrays cannot be empty"
            );
            if let Some(i) = batches.iter().position(|ids| ids.is_empty()) {
                anyhow::bail!("Prompt integer array at index {} cannot be empty", i);
            }
        }
    }
    Ok(())
}
/// Requires at least one of `prompt` or `prompt_embeds`.
///
/// If `prompt_embeds` is present, only its encoding and size are validated,
/// and `prompt` is not inspected even when it is also present. Otherwise the
/// prompt is validated with `validate_prompt`.
pub fn validate_prompt_or_embeds(
    prompt: Option<&dynamo_async_openai::types::Prompt>,
    prompt_embeds: Option<&str>,
) -> Result<(), anyhow::Error> {
    match (prompt, prompt_embeds) {
        (_, Some(embeds)) => validate_prompt_embeds_format(embeds),
        (Some(p), None) => validate_prompt(p),
        (None, None) => Err(anyhow::anyhow!(
            "At least one of 'prompt' or 'prompt_embeds' must be provided"
        )),
    }
}
/// Checks that `embeds` is standard-alphabet base64 and that the decoded
/// payload is between `MIN_SIZE` and `MAX_SIZE` bytes (both inclusive).
///
/// The payload is fully decoded before its size is checked. Its contents are
/// not interpreted.
fn validate_prompt_embeds_format(embeds: &str) -> Result<(), anyhow::Error> {
    use base64::{Engine as _, engine::general_purpose};

    // Smallest accepted decoded payload, in bytes.
    const MIN_SIZE: usize = 100;
    // Largest accepted decoded payload, in bytes (10 MiB).
    const MAX_SIZE: usize = 10 * 1024 * 1024;

    let byte_count = general_purpose::STANDARD
        .decode(embeds)
        .map_err(|_| anyhow::anyhow!("prompt_embeds must be valid base64-encoded data"))?
        .len();

    anyhow::ensure!(
        byte_count >= MIN_SIZE,
        "prompt_embeds decoded data must be at least {MIN_SIZE} bytes, got {} bytes",
        byte_count
    );
    anyhow::ensure!(
        byte_count <= MAX_SIZE,
        "prompt_embeds decoded data exceeds maximum size of 10MB, got {} bytes",
        byte_count
    );
    Ok(())
}
pub fn validate_prompt_embeds(prompt_embeds: Option<&str>) -> Result<(), anyhow::Error> {
if let Some(embeds) = prompt_embeds {
validate_prompt_embeds_format(embeds)?;
}
Ok(())
}
/// Checks that the legacy-completions `logprobs` count, when present, lies in
/// the inclusive range `[MIN_LOGPROBS, MAX_LOGPROBS]`.
pub fn validate_logprobs(logprobs: Option<u8>) -> Result<(), anyhow::Error> {
    let Some(value) = logprobs.filter(|v| !(MIN_LOGPROBS..=MAX_LOGPROBS).contains(v)) else {
        return Ok(());
    };
    Err(anyhow::anyhow!(
        "Logprobs must be between 0 and {}, got {}",
        MAX_LOGPROBS,
        value
    ))
}
pub fn validate_best_of(best_of: Option<u8>, n: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(best_of_value) = best_of {
if !(MIN_BEST_OF..=MAX_BEST_OF).contains(&best_of_value) {
anyhow::bail!(
"Best_of must be between 0 and {}, got {}",
MAX_BEST_OF,
best_of_value
);
}
if let Some(n_value) = n
&& best_of_value < n_value
{
anyhow::bail!(
"Best_of must be greater than or equal to n, got best_of={} and n={}",
best_of_value,
n_value
);
}
}
Ok(())
}
/// Rejects a `suffix` longer than 10000.
///
/// Length is measured with `str::len`, i.e. in UTF-8 bytes, even though the
/// error message says "characters". A missing suffix is accepted.
pub fn validate_suffix(suffix: Option<&str>) -> Result<(), anyhow::Error> {
    let too_long = suffix.is_some_and(|text| text.len() > 10000);
    anyhow::ensure!(!too_long, "Suffix is too long, maximum 10000 characters");
    Ok(())
}
pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), anyhow::Error> {
if let Some(tokens) = max_tokens
&& tokens == 0
{
anyhow::bail!("Max tokens must be greater than 0, got {}", tokens);
}
Ok(())
}
pub fn validate_max_completion_tokens(
max_completion_tokens: Option<u32>,
) -> Result<(), anyhow::Error> {
if let Some(tokens) = max_completion_tokens
&& tokens == 0
{
anyhow::bail!(
"Max completion tokens must be greater than 0, got {}",
tokens
);
}
Ok(())
}
/// Checks that an optional value lies in the inclusive range
/// `[range.0, range.1]` and passes it through unchanged.
///
/// Returns `Ok(None)` for `None` and `Ok(Some(value))` for an in-range value.
///
/// # Errors
///
/// Fails when the value is below `range.0`, above `range.1`, or not comparable
/// with the bounds, such as an `f32`/`f64` NaN. The test is written as a
/// positive membership check that is then negated. The previous
/// `value < lo || value > hi` form is false for NaN, so NaN was accepted.
pub fn validate_range<T>(value: Option<T>, range: &(T, T)) -> anyhow::Result<Option<T>>
where
    T: PartialOrd + Display,
{
    let Some(value) = value else {
        return Ok(None);
    };
    let (low, high) = range;
    let in_range = value >= *low && value <= *high;
    if !in_range {
        anyhow::bail!("Value {} is out of range [{}, {}]", value, low, high);
    }
    Ok(Some(value))
}