#![allow(dead_code)]
use crate::core::models::openai::MessageContent;
use crate::utils::error::{GatewayError, Result};
use regex::Regex;
use std::collections::HashSet;
use uuid::Uuid;
pub struct RequestValidator;
impl RequestValidator {
pub fn validate_chat_completion_request(
model: &str,
messages: &[crate::core::models::openai::ChatMessage],
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<()> {
Self::validate_model_name(model)?;
if messages.is_empty() {
return Err(GatewayError::Validation(
"Messages cannot be empty".to_string(),
));
}
for (i, message) in messages.iter().enumerate() {
Self::validate_chat_message(message, i)?;
}
if let Some(max_tokens) = max_tokens {
if max_tokens == 0 {
return Err(GatewayError::Validation(
"max_tokens must be greater than 0".to_string(),
));
}
if max_tokens > 100000 {
return Err(GatewayError::Validation(
"max_tokens cannot exceed 100000".to_string(),
));
}
}
if let Some(temperature) = temperature {
if !(0.0..=2.0).contains(&temperature) {
return Err(GatewayError::Validation(
"temperature must be between 0.0 and 2.0".to_string(),
));
}
}
Ok(())
}
fn validate_chat_message(
message: &crate::core::models::openai::ChatMessage,
index: usize,
) -> Result<()> {
use crate::core::models::openai::MessageRole;
match message.role {
MessageRole::System | MessageRole::User | MessageRole::Assistant => {
if message.content.is_none() {
return Err(GatewayError::Validation(format!(
"Message at index {} with role {:?} must have content",
index, message.role
)));
}
}
MessageRole::Function => {
if message.name.is_none() {
return Err(GatewayError::Validation(format!(
"Function message at index {} must have a name",
index
)));
}
if message.content.is_none() {
return Err(GatewayError::Validation(format!(
"Function message at index {} must have content",
index
)));
}
}
MessageRole::Tool => {
if message.tool_call_id.is_none() {
return Err(GatewayError::Validation(format!(
"Tool message at index {} must have tool_call_id",
index
)));
}
if message.content.is_none() {
return Err(GatewayError::Validation(format!(
"Tool message at index {} must have content",
index
)));
}
}
}
if let Some(content) = &message.content {
Self::validate_message_content(content, index)?;
}
if let Some(name) = &message.name {
Self::validate_function_name(name)?;
}
Ok(())
}
fn validate_message_content(
content: &crate::core::models::openai::MessageContent,
index: usize,
) -> Result<()> {
match content {
MessageContent::Text(text) => {
if text.trim().is_empty() {
return Err(GatewayError::Validation(format!(
"Text content at message index {} cannot be empty",
index
)));
}
if text.len() > 1_000_000 {
return Err(GatewayError::Validation(format!(
"Text content at message index {} is too long (max 1M characters)",
index
)));
}
}
MessageContent::Parts(parts) => {
if parts.is_empty() {
return Err(GatewayError::Validation(format!(
"Content parts at message index {} cannot be empty",
index
)));
}
for (part_index, part) in parts.iter().enumerate() {
Self::validate_content_part(part, index, part_index)?;
}
}
}
Ok(())
}
fn validate_content_part(
part: &crate::core::models::openai::ContentPart,
message_index: usize,
part_index: usize,
) -> Result<()> {
use crate::core::models::openai::ContentPart;
match part {
ContentPart::Text { text } => {
if text.trim().is_empty() {
return Err(GatewayError::Validation(format!(
"Text part at message {} part {} cannot be empty",
message_index, part_index
)));
}
}
ContentPart::ImageUrl { image_url } => {
Self::validate_image_url(&image_url.url)?;
if let Some(detail) = &image_url.detail {
if !["low", "high", "auto"].contains(&detail.as_str()) {
return Err(GatewayError::Validation(
"Image detail must be 'low', 'high', or 'auto'".to_string(),
));
}
}
}
ContentPart::Audio { audio } => {
Self::validate_audio_data(&audio.data)?;
Self::validate_audio_format(&audio.format)?;
}
}
Ok(())
}
fn validate_model_name(model: &str) -> Result<()> {
if model.trim().is_empty() {
return Err(GatewayError::Validation(
"Model name cannot be empty".to_string(),
));
}
let model_regex = Regex::new(r"^[a-zA-Z0-9._/-]+$")
.map_err(|e| GatewayError::Internal(format!("Regex error: {}", e)))?;
if !model_regex.is_match(model) {
return Err(GatewayError::Validation(
"Model name contains invalid characters".to_string(),
));
}
Ok(())
}
fn validate_function_name(name: &str) -> Result<()> {
if name.trim().is_empty() {
return Err(GatewayError::Validation(
"Function name cannot be empty".to_string(),
));
}
let name_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
.map_err(|e| GatewayError::Internal(format!("Regex error: {}", e)))?;
if !name_regex.is_match(name) {
return Err(GatewayError::Validation(
"Function name must be a valid identifier".to_string(),
));
}
Ok(())
}
fn validate_image_url(url: &str) -> Result<()> {
if url.starts_with("data:image/") {
Self::validate_base64_image(url)?;
} else {
url::Url::parse(url)
.map_err(|e| GatewayError::Validation(format!("Invalid image URL: {}", e)))?;
}
Ok(())
}
fn validate_base64_image(data_url: &str) -> Result<()> {
if !data_url.starts_with("data:image/") {
return Err(GatewayError::Validation(
"Invalid image data URL format".to_string(),
));
}
let parts: Vec<&str> = data_url.splitn(2, ',').collect();
if parts.len() != 2 {
return Err(GatewayError::Validation(
"Invalid image data URL format".to_string(),
));
}
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, parts[1])
.map_err(|e| GatewayError::Validation(format!("Invalid base64 image data: {}", e)))?;
Ok(())
}
fn validate_audio_data(data: &str) -> Result<()> {
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, data)
.map_err(|e| GatewayError::Validation(format!("Invalid base64 audio data: {}", e)))?;
Ok(())
}
fn validate_audio_format(format: &str) -> Result<()> {
let valid_formats = ["mp3", "wav", "flac", "m4a", "ogg", "webm"];
if !valid_formats.contains(&format) {
return Err(GatewayError::Validation(format!(
"Invalid audio format: {}. Supported formats: {:?}",
format, valid_formats
)));
}
Ok(())
}
}
pub struct ApiValidator;
impl ApiValidator {
pub fn validate_api_key(api_key: &str) -> Result<()> {
if api_key.trim().is_empty() {
return Err(GatewayError::Validation(
"API key cannot be empty".to_string(),
));
}
if api_key.len() < 10 {
return Err(GatewayError::Validation("API key is too short".to_string()));
}
if api_key.len() > 200 {
return Err(GatewayError::Validation("API key is too long".to_string()));
}
Ok(())
}
pub fn validate_uuid(uuid_str: &str) -> Result<Uuid> {
Uuid::parse_str(uuid_str)
.map_err(|e| GatewayError::Validation(format!("Invalid UUID format: {}", e)))
}
pub fn validate_pagination(page: Option<u32>, limit: Option<u32>) -> Result<(u32, u32)> {
let page = page.unwrap_or(1);
let limit = limit.unwrap_or(20);
if page == 0 {
return Err(GatewayError::Validation(
"Page must be greater than 0".to_string(),
));
}
if limit == 0 {
return Err(GatewayError::Validation(
"Limit must be greater than 0".to_string(),
));
}
if limit > 1000 {
return Err(GatewayError::Validation(
"Limit cannot exceed 1000".to_string(),
));
}
Ok((page, limit))
}
pub fn validate_date_range(
start_date: Option<chrono::DateTime<chrono::Utc>>,
end_date: Option<chrono::DateTime<chrono::Utc>>,
) -> Result<()> {
if let (Some(start), Some(end)) = (start_date, end_date) {
if start >= end {
return Err(GatewayError::Validation(
"Start date must be before end date".to_string(),
));
}
let max_range = chrono::Duration::days(365);
if end - start > max_range {
return Err(GatewayError::Validation(
"Date range cannot exceed 365 days".to_string(),
));
}
}
Ok(())
}
pub fn validate_sort_params(
sort_by: &str,
sort_order: &str,
valid_fields: &[&str],
) -> Result<()> {
if !valid_fields.contains(&sort_by) {
return Err(GatewayError::Validation(format!(
"Invalid sort field: {}. Valid fields: {:?}",
sort_by, valid_fields
)));
}
if !["asc", "desc"].contains(&sort_order) {
return Err(GatewayError::Validation(
"Sort order must be 'asc' or 'desc'".to_string(),
));
}
Ok(())
}
pub fn validate_filters(
filters: &std::collections::HashMap<String, String>,
valid_filters: &[&str],
) -> Result<()> {
for key in filters.keys() {
if !valid_filters.contains(&key.as_str()) {
return Err(GatewayError::Validation(format!(
"Invalid filter: {}. Valid filters: {:?}",
key, valid_filters
)));
}
}
Ok(())
}
}
pub struct DataValidator;
impl DataValidator {
pub fn validate_username(username: &str) -> Result<()> {
if username.trim().is_empty() {
return Err(GatewayError::Validation(
"Username cannot be empty".to_string(),
));
}
if username.len() < 3 {
return Err(GatewayError::Validation(
"Username must be at least 3 characters".to_string(),
));
}
if username.len() > 50 {
return Err(GatewayError::Validation(
"Username cannot exceed 50 characters".to_string(),
));
}
let username_regex = Regex::new(r"^[a-zA-Z0-9_-]+$")
.map_err(|e| GatewayError::Internal(format!("Regex error: {}", e)))?;
if !username_regex.is_match(username) {
return Err(GatewayError::Validation(
"Username can only contain letters, numbers, underscores, and hyphens".to_string(),
));
}
Ok(())
}
pub fn validate_password(password: &str) -> Result<()> {
if password.len() < 8 {
return Err(GatewayError::Validation(
"Password must be at least 8 characters".to_string(),
));
}
if password.len() > 128 {
return Err(GatewayError::Validation(
"Password cannot exceed 128 characters".to_string(),
));
}
let has_lowercase = password.chars().any(|c| c.is_lowercase());
let has_uppercase = password.chars().any(|c| c.is_uppercase());
let has_digit = password.chars().any(|c| c.is_ascii_digit());
let has_special = password
.chars()
.any(|c| "!@#$%^&*()_+-=[]{}|;:,.<>?".contains(c));
let strength_count = [has_lowercase, has_uppercase, has_digit, has_special]
.iter()
.filter(|&&x| x)
.count();
if strength_count < 3 {
return Err(GatewayError::Validation(
"Password must contain at least 3 of: lowercase, uppercase, digit, special character".to_string()
));
}
Ok(())
}
pub fn validate_team_name(name: &str) -> Result<()> {
if name.trim().is_empty() {
return Err(GatewayError::Validation(
"Team name cannot be empty".to_string(),
));
}
if name.len() < 2 {
return Err(GatewayError::Validation(
"Team name must be at least 2 characters".to_string(),
));
}
if name.len() > 100 {
return Err(GatewayError::Validation(
"Team name cannot exceed 100 characters".to_string(),
));
}
Ok(())
}
pub fn validate_tags(tags: &[String]) -> Result<()> {
if tags.len() > 20 {
return Err(GatewayError::Validation(
"Cannot have more than 20 tags".to_string(),
));
}
let mut unique_tags = HashSet::new();
for tag in tags {
if tag.trim().is_empty() {
return Err(GatewayError::Validation("Tag cannot be empty".to_string()));
}
if tag.len() > 50 {
return Err(GatewayError::Validation(
"Tag cannot exceed 50 characters".to_string(),
));
}
if !unique_tags.insert(tag.to_lowercase()) {
return Err(GatewayError::Validation(format!("Duplicate tag: {}", tag)));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_name_validation() {
assert!(RequestValidator::validate_model_name("gpt-4").is_ok());
assert!(RequestValidator::validate_model_name("claude-3.5-sonnet").is_ok());
assert!(RequestValidator::validate_model_name("").is_err());
assert!(RequestValidator::validate_model_name("invalid@model").is_err());
}
#[test]
fn test_api_key_validation() {
assert!(ApiValidator::validate_api_key("valid_api_key_123").is_ok());
assert!(ApiValidator::validate_api_key("").is_err());
assert!(ApiValidator::validate_api_key("short").is_err());
}
#[test]
fn test_uuid_validation() {
let valid_uuid = "550e8400-e29b-41d4-a716-446655440000";
assert!(ApiValidator::validate_uuid(valid_uuid).is_ok());
assert!(ApiValidator::validate_uuid("invalid-uuid").is_err());
}
#[test]
fn test_pagination_validation() {
assert!(ApiValidator::validate_pagination(Some(1), Some(20)).is_ok());
assert!(ApiValidator::validate_pagination(Some(0), Some(20)).is_err());
assert!(ApiValidator::validate_pagination(Some(1), Some(0)).is_err());
assert!(ApiValidator::validate_pagination(Some(1), Some(2000)).is_err());
}
#[test]
fn test_username_validation() {
assert!(DataValidator::validate_username("valid_user").is_ok());
assert!(DataValidator::validate_username("user123").is_ok());
assert!(DataValidator::validate_username("").is_err());
assert!(DataValidator::validate_username("ab").is_err());
assert!(DataValidator::validate_username("invalid@user").is_err());
}
#[test]
fn test_password_validation() {
assert!(DataValidator::validate_password("StrongPass123!").is_ok());
assert!(DataValidator::validate_password("NoSpecialChars123").is_ok()); assert!(DataValidator::validate_password("weak").is_err()); assert!(DataValidator::validate_password("onlylowercase").is_err()); assert!(DataValidator::validate_password("ONLYUPPERCASE").is_err()); assert!(DataValidator::validate_password("OnlyTwoTypes").is_err()); }
#[test]
fn test_tags_validation() {
assert!(DataValidator::validate_tags(&["tag1".to_string(), "tag2".to_string()]).is_ok());
assert!(DataValidator::validate_tags(&["".to_string()]).is_err());
assert!(DataValidator::validate_tags(&["tag1".to_string(), "tag1".to_string()]).is_err());
}
}