use crate::common::errors::{OpenAIToolError, Result};
use dotenvy::dotenv;
use request::header::{HeaderMap, HeaderValue};
use std::env;
/// Base URL assigned by [`OpenAIAuth::new`] when the caller does not supply a custom one.
const OPENAI_DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Authentication backend used to build request URLs and request headers.
///
/// The variant determines both which header carries the key and how an
/// endpoint path is turned into a request URL.
#[derive(Clone, Debug)]
pub enum AuthProvider {
    /// OpenAI API or any OpenAI-compatible server; authenticates with an
    /// `Authorization: Bearer <key>` header.
    OpenAI(OpenAIAuth),
    /// Azure OpenAI deployment; authenticates with an `api-key` header.
    Azure(AzureAuth),
}
#[derive(Debug, Clone)]
pub struct OpenAIAuth {
api_key: String,
base_url: String,
}
impl OpenAIAuth {
pub fn new<T: Into<String>>(api_key: T) -> Self {
Self { api_key: api_key.into(), base_url: OPENAI_DEFAULT_BASE_URL.to_string() }
}
pub fn with_base_url<T: Into<String>>(mut self, url: T) -> Self {
self.base_url = url.into();
self
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn base_url(&self) -> &str {
&self.base_url
}
fn endpoint(&self, path: &str) -> String {
format!("{}/{}", self.base_url.trim_end_matches('/'), path.trim_start_matches('/'))
}
fn apply_headers(&self, headers: &mut HeaderMap) -> Result<()> {
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {}", self.api_key)).map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?,
);
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct AzureAuth {
api_key: String,
base_url: String,
}
impl AzureAuth {
pub fn new<T: Into<String>>(api_key: T, base_url: T) -> Self {
Self { api_key: api_key.into(), base_url: base_url.into() }
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn base_url(&self) -> &str {
&self.base_url
}
fn endpoint(&self, _path: &str) -> String {
self.base_url.clone()
}
fn apply_headers(&self, headers: &mut HeaderMap) -> Result<()> {
headers.insert("api-key", HeaderValue::from_str(&self.api_key).map_err(|e| OpenAIToolError::Error(format!("Invalid header value: {}", e)))?);
Ok(())
}
}
impl AuthProvider {
pub fn openai_from_env() -> Result<Self> {
dotenv().ok();
let api_key = env::var("OPENAI_API_KEY").map_err(|_| OpenAIToolError::Error("OPENAI_API_KEY environment variable not set".into()))?;
Ok(Self::OpenAI(OpenAIAuth::new(api_key)))
}
pub fn azure_from_env() -> Result<Self> {
dotenv().ok();
let api_key =
env::var("AZURE_OPENAI_API_KEY").map_err(|_| OpenAIToolError::Error("AZURE_OPENAI_API_KEY environment variable not set".into()))?;
let base_url =
env::var("AZURE_OPENAI_BASE_URL").map_err(|_| OpenAIToolError::Error("AZURE_OPENAI_BASE_URL environment variable not set".into()))?;
Ok(Self::Azure(AzureAuth::new(api_key, base_url)))
}
pub fn from_env() -> Result<Self> {
dotenv().ok();
if env::var("AZURE_OPENAI_API_KEY").is_ok() {
return Self::azure_from_env();
}
Self::openai_from_env()
}
pub fn endpoint(&self, path: &str) -> String {
match self {
Self::OpenAI(auth) => auth.endpoint(path),
Self::Azure(auth) => auth.endpoint(path),
}
}
pub fn apply_headers(&self, headers: &mut HeaderMap) -> Result<()> {
match self {
Self::OpenAI(auth) => auth.apply_headers(headers),
Self::Azure(auth) => auth.apply_headers(headers),
}
}
pub fn api_key(&self) -> &str {
match self {
Self::OpenAI(auth) => auth.api_key(),
Self::Azure(auth) => auth.api_key(),
}
}
pub fn is_azure(&self) -> bool {
matches!(self, Self::Azure(_))
}
pub fn is_openai(&self) -> bool {
matches!(self, Self::OpenAI(_))
}
pub fn from_url_with_key<S: Into<String>>(base_url: S, api_key: S) -> Self {
let url_str = base_url.into();
let api_key_str = api_key.into();
if url_str.contains(".openai.azure.com") {
Self::Azure(AzureAuth::new(api_key_str, url_str))
} else {
Self::OpenAI(OpenAIAuth::new(api_key_str).with_base_url(url_str))
}
}
pub fn from_url<S: Into<String>>(base_url: S) -> Result<Self> {
let url_str = base_url.into();
dotenv().ok();
if url_str.contains(".openai.azure.com") {
let api_key = env::var("AZURE_OPENAI_API_KEY")
.map_err(|_| OpenAIToolError::Error("Azure URL detected but AZURE_OPENAI_API_KEY is not set".into()))?;
Ok(Self::Azure(AzureAuth::new(api_key, url_str)))
} else {
let api_key = env::var("OPENAI_API_KEY").map_err(|_| OpenAIToolError::Error("OPENAI_API_KEY environment variable not set".into()))?;
Ok(Self::OpenAI(OpenAIAuth::new(api_key).with_base_url(url_str)))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- OpenAIAuth ----

    #[test]
    fn test_openai_auth_new() {
        let auth = OpenAIAuth::new("sk-test-key");
        assert_eq!(auth.api_key(), "sk-test-key");
        assert_eq!(auth.base_url(), OPENAI_DEFAULT_BASE_URL);
    }

    #[test]
    fn test_openai_auth_with_base_url() {
        let auth = OpenAIAuth::new("sk-test-key").with_base_url("https://custom.example.com/v1");
        assert_eq!(auth.base_url(), "https://custom.example.com/v1");
    }

    #[test]
    fn test_openai_endpoint() {
        let auth = OpenAIAuth::new("sk-test-key");
        // A leading slash on the path must not produce a double slash.
        assert_eq!(auth.endpoint("chat/completions"), "https://api.openai.com/v1/chat/completions");
        assert_eq!(auth.endpoint("/chat/completions"), "https://api.openai.com/v1/chat/completions");
    }

    #[test]
    fn test_openai_apply_headers() {
        let auth = OpenAIAuth::new("sk-test-key");
        let mut headers = HeaderMap::new();
        auth.apply_headers(&mut headers).unwrap();
        assert_eq!(headers.get("Authorization").unwrap(), "Bearer sk-test-key");
    }

    #[test]
    fn test_openai_endpoint_trailing_slash_handling() {
        // A trailing slash on the base URL, combined with either path form, yields exactly one slash.
        let auth = OpenAIAuth::new("key").with_base_url("https://example.com/v1/");
        assert_eq!(auth.endpoint("chat/completions"), "https://example.com/v1/chat/completions");
        assert_eq!(auth.endpoint("/chat/completions"), "https://example.com/v1/chat/completions");
    }

    // ---- AzureAuth ----

    #[test]
    fn test_azure_auth_new() {
        let auth = AzureAuth::new(
            "api-key",
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview",
        );
        assert_eq!(auth.api_key(), "api-key");
        assert_eq!(auth.base_url(), "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview");
    }

    #[test]
    fn test_azure_endpoint_returns_base_url() {
        // Azure URLs already name the deployment and operation, so the path argument is ignored.
        let auth =
            AzureAuth::new("key", "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview");
        let endpoint = auth.endpoint("ignored");
        assert_eq!(endpoint, "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview");
    }

    #[test]
    fn test_azure_apply_headers() {
        let auth = AzureAuth::new("my-api-key", "https://my-resource.openai.azure.com");
        let mut headers = HeaderMap::new();
        auth.apply_headers(&mut headers).unwrap();
        assert_eq!(headers.get("api-key").unwrap(), "my-api-key");
        // Azure must not also send a Bearer token.
        assert!(headers.get("Authorization").is_none());
    }

    // ---- AuthProvider dispatch ----

    #[test]
    fn test_auth_provider_openai() {
        let auth = AuthProvider::OpenAI(OpenAIAuth::new("sk-key"));
        assert!(auth.is_openai());
        assert!(!auth.is_azure());
        assert_eq!(auth.api_key(), "sk-key");
    }

    #[test]
    fn test_auth_provider_azure() {
        let auth = AuthProvider::Azure(AzureAuth::new("key", "https://my-resource.openai.azure.com/openai/deployments/gpt-4o"));
        assert!(auth.is_azure());
        assert!(!auth.is_openai());
        assert_eq!(auth.api_key(), "key");
    }

    #[test]
    fn test_auth_provider_endpoint_openai() {
        let auth = AuthProvider::OpenAI(OpenAIAuth::new("key"));
        assert_eq!(auth.endpoint("chat/completions"), "https://api.openai.com/v1/chat/completions");
    }

    #[test]
    fn test_auth_provider_endpoint_azure() {
        let base_url = "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview";
        let auth = AuthProvider::Azure(AzureAuth::new("key", base_url));
        let endpoint = auth.endpoint("ignored");
        assert_eq!(endpoint, base_url);
    }

    #[test]
    fn test_auth_provider_apply_headers() {
        let openai_auth = AuthProvider::OpenAI(OpenAIAuth::new("sk-key"));
        let mut headers = HeaderMap::new();
        openai_auth.apply_headers(&mut headers).unwrap();
        assert!(headers.get("Authorization").unwrap().to_str().unwrap().starts_with("Bearer"));
        let azure_auth = AuthProvider::Azure(AzureAuth::new("azure-key", "https://my-resource.openai.azure.com"));
        let mut headers = HeaderMap::new();
        azure_auth.apply_headers(&mut headers).unwrap();
        assert_eq!(headers.get("api-key").unwrap(), "azure-key");
    }

    // Previously named `test_from_env_returns_correct_provider_type`, but it never
    // calls `from_env`. It only checks that the kind predicates are mutually
    // exclusive. `from_env` is not exercised here because it reads process-wide
    // environment variables, which would race with other tests run in parallel.
    #[test]
    fn test_provider_kind_predicates_are_exclusive() {
        let openai = AuthProvider::OpenAI(OpenAIAuth::new("sk-test"));
        assert!(openai.is_openai());
        assert!(!openai.is_azure());
        let azure = AuthProvider::Azure(AzureAuth::new("key", "https://my-resource.openai.azure.com/openai/deployments/gpt-4o"));
        assert!(azure.is_azure());
        assert!(!azure.is_openai());
    }

    // ---- AuthProvider::from_url_with_key routing ----

    #[test]
    fn test_from_url_with_key_openai_api() {
        let auth = AuthProvider::from_url_with_key("https://api.openai.com/v1", "sk-test-key");
        assert!(auth.is_openai());
        assert!(!auth.is_azure());
        assert_eq!(auth.api_key(), "sk-test-key");
        assert_eq!(auth.endpoint("chat/completions"), "https://api.openai.com/v1/chat/completions");
    }

    #[test]
    fn test_from_url_with_key_azure() {
        let base_url = "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview";
        let auth = AuthProvider::from_url_with_key(base_url, "azure-api-key");
        assert!(auth.is_azure());
        assert!(!auth.is_openai());
        assert_eq!(auth.api_key(), "azure-api-key");
        let endpoint = auth.endpoint("ignored");
        assert_eq!(endpoint, base_url);
    }

    #[test]
    fn test_from_url_with_key_local_api_ollama() {
        let auth = AuthProvider::from_url_with_key("http://localhost:11434/v1", "ollama");
        assert!(auth.is_openai());
        assert_eq!(auth.endpoint("chat/completions"), "http://localhost:11434/v1/chat/completions");
    }

    #[test]
    fn test_from_url_with_key_custom_openai_compatible() {
        let auth = AuthProvider::from_url_with_key("https://my-proxy.example.com/openai/v1", "proxy-key");
        assert!(auth.is_openai());
        assert_eq!(auth.endpoint("embeddings"), "https://my-proxy.example.com/openai/v1/embeddings");
    }

    #[test]
    fn test_from_url_with_key_azure_various_patterns() {
        let patterns = [
            "https://eastus.openai.azure.com/openai/deployments/gpt-4o",
            "https://my-company-resource.openai.azure.com/openai/deployments/gpt-4o",
            "https://test.openai.azure.com/openai/deployments/gpt-4o?api-version=2024-08-01-preview",
        ];
        for url in patterns {
            let auth = AuthProvider::from_url_with_key(url, "key");
            assert!(auth.is_azure(), "Should be Azure provider for URL: {}", url);
        }
    }

    #[test]
    fn test_from_url_with_key_headers_openai() {
        let auth = AuthProvider::from_url_with_key("https://api.openai.com/v1", "sk-secret-key");
        let mut headers = HeaderMap::new();
        auth.apply_headers(&mut headers).unwrap();
        assert_eq!(headers.get("Authorization").unwrap(), "Bearer sk-secret-key");
    }

    #[test]
    fn test_from_url_with_key_headers_azure() {
        let auth = AuthProvider::from_url_with_key("https://resource.openai.azure.com/openai/deployments/gpt-4o", "azure-secret");
        let mut headers = HeaderMap::new();
        auth.apply_headers(&mut headers).unwrap();
        assert_eq!(headers.get("api-key").unwrap(), "azure-secret");
        assert!(headers.get("Authorization").is_none());
    }
}