use crate::core::{GenericProvider, HttpClient, Protocol};
use crate::protocols::OpenAIProtocol;
use crate::error::LlmConnectorError;
use std::collections::HashMap;
/// Provider returned by every constructor in this module: a [`GenericProvider`]
/// specialised to the [`OpenAIProtocol`] wire format.
///
/// The same type is used for the official OpenAI API, Azure OpenAI and
/// third-party OpenAI-compatible services; they differ only in how the
/// underlying [`HttpClient`] is configured.
pub type OpenAIProvider = GenericProvider<OpenAIProtocol>;
/// Creates a provider for the official OpenAI API with default client settings.
///
/// Shorthand for [`openai_with_config`] with no base-URL override (so
/// `https://api.openai.com` is used), no explicit timeout and no proxy.
///
/// # Errors
///
/// Propagates any [`LlmConnectorError`] raised while building the HTTP client.
pub fn openai(api_key: &str) -> Result<OpenAIProvider, LlmConnectorError> {
    // Every optional setting is left unset; `openai_with_config` applies the defaults.
    let (base_url, timeout_secs, proxy) = (None, None, None);
    openai_with_config(api_key, base_url, timeout_secs, proxy)
}
/// Creates an OpenAI provider that targets `base_url` instead of the default
/// `https://api.openai.com` endpoint.
///
/// Timeout and proxy are left unset; use [`openai_with_config`] to control them.
///
/// # Errors
///
/// Propagates any [`LlmConnectorError`] raised while building the HTTP client.
pub fn openai_with_base_url(
    api_key: &str,
    base_url: &str,
) -> Result<OpenAIProvider, LlmConnectorError> {
    let endpoint_override = Some(base_url);
    openai_with_config(api_key, endpoint_override, None, None)
}
/// Creates an OpenAI provider with full control over the HTTP client settings.
///
/// * `api_key` — passed to [`OpenAIProtocol::new`]; the protocol's
///   `auth_headers()` are installed on the client.
/// * `base_url` — API root; `None` falls back to `https://api.openai.com`.
/// * `timeout_secs` — forwarded unchanged to [`HttpClient::with_config`]
///   (how `None` is interpreted is up to `HttpClient`).
/// * `proxy` — forwarded unchanged to [`HttpClient::with_config`].
///
/// # Errors
///
/// Returns the [`LlmConnectorError`] produced by [`HttpClient::with_config`]
/// when the client cannot be built.
pub fn openai_with_config(
    api_key: &str,
    base_url: Option<&str>,
    timeout_secs: Option<u64>,
    proxy: Option<&str>,
) -> Result<OpenAIProvider, LlmConnectorError> {
    const DEFAULT_BASE_URL: &str = "https://api.openai.com";

    let protocol = OpenAIProtocol::new(api_key);
    let http = HttpClient::with_config(
        base_url.unwrap_or(DEFAULT_BASE_URL),
        timeout_secs,
        proxy,
    )?;

    // The protocol owns the knowledge of which auth headers it needs; copy them
    // onto the transport so every request carries them.
    let headers = protocol
        .auth_headers()
        .into_iter()
        .collect::<HashMap<String, String>>();

    Ok(GenericProvider::new(protocol, http.with_headers(headers)))
}
/// Creates a provider for an Azure OpenAI resource at `endpoint`.
///
/// Azure authenticates with an `api-key` header rather than a bearer token, so
/// this constructor sets `api-key` itself. It does not install
/// `protocol.auth_headers()` on the client.
///
/// NOTE(review): `api_version` is attached as an `api-version` *header*. The
/// Azure OpenAI REST reference documents `api-version` as a required *query*
/// parameter on deployment-scoped paths
/// (`/openai/deployments/{deployment}/...?api-version=...`). Confirm that
/// `HttpClient` or `OpenAIProtocol` rewrites requests accordingly; otherwise
/// Azure will likely reject them.
///
/// # Errors
///
/// Returns the [`LlmConnectorError`] from [`HttpClient::new`] when the client
/// cannot be created for `endpoint`.
pub fn azure_openai(
    api_key: &str,
    endpoint: &str,
    api_version: &str,
) -> Result<OpenAIProvider, LlmConnectorError> {
    let protocol = OpenAIProtocol::new(api_key);
    let base = HttpClient::new(endpoint)?;
    let keyed = base.with_header("api-key".to_string(), api_key.to_string());
    let client = keyed.with_header("api-version".to_string(), api_version.to_string());
    Ok(GenericProvider::new(protocol, client))
}
/// Creates a provider for a third-party service that speaks the OpenAI API
/// (for example a DeepSeek-style endpoint at `base_url`).
///
/// Registers two headers on the HTTP client:
/// * `Authorization: Bearer <api_key>`
/// * `User-Agent: llm-connector/<service_name>`
///
/// NOTE(review): unlike `openai_with_config`, this builds the auth header by
/// hand instead of using `protocol.auth_headers()`, and exposes no
/// timeout/proxy options.
///
/// # Errors
///
/// Returns the [`LlmConnectorError`] from [`HttpClient::new`] when the client
/// cannot be created for `base_url`.
pub fn openai_compatible(
    api_key: &str,
    base_url: &str,
    service_name: &str,
) -> Result<OpenAIProvider, LlmConnectorError> {
    let protocol = OpenAIProtocol::new(api_key);
    let client = HttpClient::new(base_url)?;
    let client = client.with_header(
        "Authorization".to_string(),
        format!("Bearer {}", api_key),
    );
    let client = client.with_header(
        "User-Agent".to_string(),
        format!("llm-connector/{}", service_name),
    );
    Ok(GenericProvider::new(protocol, client))
}
/// Performs a cheap, offline plausibility check on an OpenAI API key.
///
/// Returns `true` when the key starts with `sk-` and is longer than 20 bytes.
/// Length is measured in bytes (`str::len`), not characters. This only checks
/// the shape of the string; it never contacts OpenAI, so a `true` result does
/// not mean the key is valid or active.
pub fn validate_openai_key(api_key: &str) -> bool {
    const REQUIRED_PREFIX: &str = "sk-";
    const MIN_LEN_EXCLUSIVE: usize = 20;

    let long_enough = api_key.len() > MIN_LEN_EXCLUSIVE;
    long_enough && api_key.starts_with(REQUIRED_PREFIX)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default constructor wires up the OpenAI protocol with the given key.
    #[test]
    fn test_openai_provider_creation() {
        let provider = openai("test-key").unwrap();
        assert_eq!(provider.protocol().name(), "openai");
        assert_eq!(provider.protocol().api_key(), "test-key");
    }

    /// A custom base URL is passed through to the HTTP client.
    #[test]
    fn test_openai_with_base_url() {
        let custom = "https://custom.api.com";
        let provider = openai_with_base_url("test-key", custom).unwrap();
        assert_eq!(provider.client().base_url(), custom);
    }

    /// Building an Azure provider with a well-formed endpoint succeeds.
    #[test]
    fn test_azure_openai() {
        let result = azure_openai("test-key", "https://test.openai.azure.com", "2024-02-15-preview");
        assert!(result.is_ok());
    }

    /// Building a provider for an OpenAI-compatible service succeeds.
    #[test]
    fn test_openai_compatible() {
        let result = openai_compatible("test-key", "https://api.deepseek.com", "deepseek");
        assert!(result.is_ok());
    }
}