llm_connector/providers/
openai.rs1use crate::core::{GenericProvider, HttpClient, Protocol};
6use crate::protocols::OpenAIProtocol;
7use crate::error::LlmConnectorError;
8use std::collections::HashMap;
9
10pub type OpenAIProvider = GenericProvider<OpenAIProtocol>;
12
13pub fn openai(api_key: &str) -> Result<OpenAIProvider, LlmConnectorError> {
28 openai_with_config(api_key, None, None, None)
29}
30
31pub fn openai_with_base_url(api_key: &str, base_url: &str) -> Result<OpenAIProvider, LlmConnectorError> {
45 openai_with_config(api_key, Some(base_url), None, None)
46}
47
48pub fn openai_with_config(
68 api_key: &str,
69 base_url: Option<&str>,
70 timeout_secs: Option<u64>,
71 proxy: Option<&str>,
72) -> Result<OpenAIProvider, LlmConnectorError> {
73 let protocol = OpenAIProtocol::new(api_key);
75
76 let client = HttpClient::with_config(
78 base_url.unwrap_or("https://api.openai.com"),
79 timeout_secs,
80 proxy,
81 )?;
82
83 let auth_headers: HashMap<String, String> = protocol.auth_headers().into_iter().collect();
85 let client = client.with_headers(auth_headers);
86
87 Ok(GenericProvider::new(protocol, client))
89}
90
91pub fn azure_openai(
109 api_key: &str,
110 endpoint: &str,
111 api_version: &str,
112) -> Result<OpenAIProvider, LlmConnectorError> {
113 let protocol = OpenAIProtocol::new(api_key);
114
115 let client = HttpClient::new(endpoint)?
117 .with_header("api-key".to_string(), api_key.to_string())
118 .with_header("api-version".to_string(), api_version.to_string());
119
120 Ok(GenericProvider::new(protocol, client))
121}
122
123pub fn openai_compatible(
152 api_key: &str,
153 base_url: &str,
154 service_name: &str,
155) -> Result<OpenAIProvider, LlmConnectorError> {
156 let protocol = OpenAIProtocol::new(api_key);
157
158 let client = HttpClient::new(base_url)?
160 .with_header("Authorization".to_string(), format!("Bearer {}", api_key))
161 .with_header("User-Agent".to_string(), format!("llm-connector/{}", service_name));
162
163 Ok(GenericProvider::new(protocol, client))
164}
165
/// Performs a cheap, offline format check on an OpenAI API key.
///
/// Returns `true` when `api_key`:
/// - starts with `sk-`;
/// - is longer than 20 bytes;
/// - consists only of ASCII alphanumerics, `-` and `_`.
///
/// This is a heuristic only. It does not contact the API and cannot tell
/// whether a key is active or authorised.
///
/// Any other character is rejected. This catches keys carrying stray
/// whitespace or control characters, such as a trailing newline left over from
/// reading a file or environment variable. The providers in this module embed
/// the key verbatim in request headers, so such a key would otherwise only
/// fail later, at request time.
pub fn validate_openai_key(api_key: &str) -> bool {
    api_key.starts_with("sk-")
        && api_key.len() > 20
        && api_key
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_provider_creation() {
        let provider = openai("test-key").expect("openai() should build a provider");

        assert_eq!(provider.protocol().name(), "openai");
        assert_eq!(provider.protocol().api_key(), "test-key");
    }

    #[test]
    fn test_openai_with_base_url() {
        let provider = openai_with_base_url("test-key", "https://custom.api.com")
            .expect("openai_with_base_url() should build a provider");

        assert_eq!(provider.client().base_url(), "https://custom.api.com");
    }

    #[test]
    fn test_azure_openai() {
        azure_openai(
            "test-key",
            "https://test.openai.azure.com",
            "2024-02-15-preview",
        )
        .expect("azure_openai() should build a provider");
    }

    #[test]
    fn test_openai_compatible() {
        openai_compatible("test-key", "https://api.deepseek.com", "deepseek")
            .expect("openai_compatible() should build a provider");
    }

    #[test]
    fn test_validate_openai_key() {
        assert!(validate_openai_key("sk-abcdefghijklmnopqrstuvwxyz012345"));
        assert!(!validate_openai_key("test-key"));
        assert!(!validate_openai_key("sk-short"));
        assert!(!validate_openai_key(""));
        // Length boundary: 20 bytes is rejected, 21 bytes is accepted.
        assert!(!validate_openai_key("sk-12345678901234567"));
        assert!(validate_openai_key("sk-123456789012345678"));
    }
}