litellm-rs 0.4.16

A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with intelligent routing, load balancing, and enterprise features.
Documentation
use crate::core::providers::unified_provider::ProviderError;
use std::collections::HashMap;
use std::env;

/// Field-less namespace type for authentication helpers.
///
/// The associated functions defined on it in this file take no `self`; they
/// cover API-key format checks, environment-variable lookup, key masking and
/// `Bearer` token formatting.
pub struct AuthUtils;

impl AuthUtils {
    /// Checks `api_key` against the format rules of the provider implied by
    /// the prefix of `model`.
    ///
    /// Models starting with `gpt-`, `claude-` or `gemini-` are delegated to the
    /// OpenAI, Anthropic and Google validators respectively. For any other
    /// model the result is `Ok(true)` when the key is at least 8 bytes long and
    /// `Ok(false)` otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Authentication`] when `api_key` is empty, or
    /// whatever error the delegated provider validator reports.
    pub fn check_valid_key(model: &str, api_key: &str) -> Result<bool, ProviderError> {
        if api_key.is_empty() {
            return Err(ProviderError::Authentication {
                provider: "unknown",
                message: "API key cannot be empty".to_string(),
            });
        }

        // Ordered table of model-name prefixes and the validator each selects.
        let validators: [(&str, fn(&str) -> Result<bool, ProviderError>); 3] = [
            ("gpt-", Self::validate_openai_key),
            ("claude-", Self::validate_anthropic_key),
            ("gemini-", Self::validate_google_key),
        ];

        match validators
            .iter()
            .find(|(prefix, _)| model.starts_with(*prefix))
        {
            Some((_, validate)) => validate(api_key),
            // Unrecognised model family: only a minimal length check applies.
            None => Ok(api_key.len() >= 8),
        }
    }

    /// Checks that `api_key` has the shape of an OpenAI key.
    ///
    /// Returns `Ok(true)` when the key starts with `sk-` and is at least 20
    /// bytes long.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Authentication`] (provider `"openai"`) when the
    /// prefix is missing or the key is shorter than 20 bytes. The prefix is
    /// checked first.
    pub fn validate_openai_key(api_key: &str) -> Result<bool, ProviderError> {
        let reject = |reason: &str| ProviderError::Authentication {
            provider: "openai",
            message: reason.to_string(),
        };

        match api_key.strip_prefix("sk-") {
            None => Err(reject("OpenAI API key must start with 'sk-'")),
            // The length rule applies to the whole key, prefix included.
            Some(_) if api_key.len() < 20 => Err(reject("OpenAI API key too short")),
            Some(_) => Ok(true),
        }
    }

    /// Checks that `api_key` has the shape of an Anthropic key.
    ///
    /// Returns `Ok(true)` when the key starts with `sk-ant-`; no length check
    /// is performed.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Authentication`] (provider `"anthropic"`) when
    /// the prefix is missing.
    pub fn validate_anthropic_key(api_key: &str) -> Result<bool, ProviderError> {
        if api_key.starts_with("sk-ant-") {
            Ok(true)
        } else {
            Err(ProviderError::Authentication {
                provider: "anthropic",
                message: String::from("Anthropic API key must start with 'sk-ant-'"),
            })
        }
    }

    /// Checks that `api_key` is long enough to be a Google key.
    ///
    /// Returns `Ok(true)` when the key is at least 10 bytes long; no prefix
    /// check is performed.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Authentication`] (provider `"google"`) when the
    /// key is shorter than 10 bytes.
    pub fn validate_google_key(api_key: &str) -> Result<bool, ProviderError> {
        let long_enough = api_key.len() >= 10;

        long_enough
            .then_some(true)
            .ok_or_else(|| ProviderError::Authentication {
                provider: "google",
                message: "Google API key too short".to_string(),
            })
    }

    /// Fills `kwargs` with credentials found in the process environment.
    ///
    /// When `kwargs` contains a `credential_name` entry, its upper-cased value
    /// is used as a prefix and the variables `<PREFIX>_API_KEY`,
    /// `<PREFIX>_API_BASE` and `<PREFIX>_API_VERSION` are looked up. Each one
    /// that can be read is stored under `api_key`, `api_base` or `api_version`,
    /// replacing any existing entry. Variables that cannot be read are skipped.
    ///
    /// Without a `credential_name` entry the map is left untouched. The
    /// function currently always returns `Ok(())`.
    pub fn load_credentials_from_list(
        kwargs: &mut HashMap<String, String>,
    ) -> Result<(), ProviderError> {
        let Some(prefix) = kwargs
            .get("credential_name")
            .map(|name| name.to_uppercase())
        else {
            return Ok(());
        };

        // (environment variable suffix, destination key in `kwargs`)
        for (suffix, target) in [
            ("API_KEY", "api_key"),
            ("API_BASE", "api_base"),
            ("API_VERSION", "api_version"),
        ] {
            if let Ok(value) = env::var(format!("{}_{}", prefix, suffix)) {
                kwargs.insert(target.to_string(), value);
            }
        }

        Ok(())
    }

    /// Looks up an API key for `provider` in the process environment.
    ///
    /// The provider name is matched case-insensitively. Each known provider
    /// has an ordered list of candidate variable names; any other provider
    /// falls back to the generic `API_KEY` variable. The first candidate that
    /// can be read and is non-empty wins.
    ///
    /// Returns `None` when every candidate is unset, unreadable or empty.
    pub fn get_api_key_from_env(provider: &str) -> Option<String> {
        // Static slices: the candidate lists are fixed, so no allocation is needed.
        let candidates: &[&str] = match provider.to_lowercase().as_str() {
            "openai" => &["OPENAI_API_KEY", "OPENAI_KEY"],
            "anthropic" => &["ANTHROPIC_API_KEY", "CLAUDE_API_KEY"],
            "google" => &["GOOGLE_API_KEY", "GEMINI_API_KEY"],
            "azure" => &["AZURE_API_KEY", "AZURE_OPENAI_API_KEY"],
            "cohere" => &["COHERE_API_KEY"],
            "mistral" => &["MISTRAL_API_KEY"],
            _ => &["API_KEY"],
        };

        // Lazy chain: stops reading the environment at the first usable value.
        candidates
            .iter()
            .filter_map(|name| env::var(*name).ok())
            .find(|value| !value.is_empty())
    }

    /// Returns a display-safe version of `api_key`.
    ///
    /// Keys of up to 8 characters are replaced entirely by `*`. Longer keys
    /// keep their first 4 and last 4 characters and have every character in
    /// between replaced by `*`, so the result has the same number of
    /// characters as the input.
    ///
    /// Counting and cutting are done on `char`s rather than bytes, so a key
    /// containing multi-byte UTF-8 characters never causes a slicing panic.
    pub fn mask_api_key(api_key: &str) -> String {
        // Number of characters left readable at each end of a long key.
        const VISIBLE: usize = 4;

        let total = api_key.chars().count();
        if total <= VISIBLE * 2 {
            return "*".repeat(total);
        }

        let hidden = total - VISIBLE * 2;
        let mut masked = String::with_capacity(api_key.len());
        masked.extend(api_key.chars().take(VISIBLE));
        masked.extend(std::iter::repeat('*').take(hidden));
        masked.extend(api_key.chars().skip(total - VISIBLE));
        masked
    }

    /// Verifies that the environment holds an API key for `provider`.
    ///
    /// The provider name is matched case-insensitively. Only `openai`,
    /// `anthropic` and `google` are checked; every other provider is accepted
    /// without inspecting the environment.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::InvalidRequest`] naming the primary variable to
    /// set when no non-empty key is found for one of the checked providers.
    pub fn validate_environment_for_provider(provider: &str) -> Result<(), ProviderError> {
        let normalized = provider.to_lowercase();

        // Canonical provider name plus the message reported when its key is absent.
        let (canonical, hint): (&'static str, &'static str) = match normalized.as_str() {
            "openai" => (
                "openai",
                "Missing OpenAI API key. Set OPENAI_API_KEY environment variable",
            ),
            "anthropic" => (
                "anthropic",
                "Missing Anthropic API key. Set ANTHROPIC_API_KEY environment variable",
            ),
            "google" => (
                "google",
                "Missing Google API key. Set GOOGLE_API_KEY environment variable",
            ),
            _ => return Ok(()),
        };

        match Self::get_api_key_from_env(canonical) {
            Some(_) => Ok(()),
            None => Err(ProviderError::InvalidRequest {
                provider: canonical,
                message: hint.to_string(),
            }),
        }
    }

    /// Formats `api_key` as a `Bearer` authorization value.
    ///
    /// Input that already starts with `"Bearer "` (exact case, trailing space
    /// included) is returned unchanged; anything else is prefixed with it.
    pub fn get_bearer_token(api_key: &str) -> String {
        const SCHEME: &str = "Bearer ";

        match api_key.strip_prefix(SCHEME) {
            Some(_) => api_key.to_owned(),
            None => [SCHEME, api_key].concat(),
        }
    }

    /// Removes one leading `"Bearer "` prefix from `bearer_token`, if present.
    ///
    /// The match is exact (case-sensitive, trailing space included). Input
    /// without the prefix is returned unchanged as an owned string.
    pub fn extract_api_key_from_bearer(bearer_token: &str) -> String {
        let credential = match bearer_token.strip_prefix("Bearer ") {
            Some(rest) => rest,
            None => bearer_token,
        };

        credential.to_owned()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed OpenAI key passes; a wrong prefix or a short key fails.
    #[test]
    fn test_validate_openai_key() {
        let well_formed = AuthUtils::validate_openai_key("sk-1234567890abcdef1234");
        assert!(well_formed.is_ok());

        for rejected in ["invalid-key", "sk-short"] {
            assert!(AuthUtils::validate_openai_key(rejected).is_err());
        }
    }

    /// Only keys carrying the `sk-ant-` prefix are accepted for Anthropic.
    #[test]
    fn test_validate_anthropic_key() {
        let well_formed = AuthUtils::validate_anthropic_key("sk-ant-api03-1234567890");
        assert!(well_formed.is_ok());

        let rejected = AuthUtils::validate_anthropic_key("invalid-key");
        assert!(rejected.is_err());
    }

    /// Long keys keep four characters at each end; short keys are fully hidden.
    #[test]
    fn test_mask_api_key() {
        assert_eq!(
            AuthUtils::mask_api_key("sk-1234567890abcdef"),
            "sk-1***********cdef"
        );
        assert_eq!(AuthUtils::mask_api_key("short"), "*****");
    }

    /// The `Bearer ` prefix is added once and never duplicated.
    #[test]
    fn test_bearer_token() {
        let expected = "Bearer sk-1234567890";

        for input in ["sk-1234567890", "Bearer sk-1234567890"] {
            assert_eq!(AuthUtils::get_bearer_token(input), expected);
        }
    }
}