use crate::{config::OpenAIConfig, error::OpenAIError, types::*};

use async_trait::async_trait;
use ferrous_llm_core::{
    ChatProvider, ChatRequest, CompletionProvider, CompletionRequest, Embedding, EmbeddingProvider,
    ProviderResult, StreamingProvider, Tool, ToolProvider,
};
use futures::Stream;
use reqwest::{Client, RequestBuilder};
use serde_json::json;
use std::pin::Pin;
use tokio_stream::{StreamExt, wrappers::ReceiverStream};

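/// OpenAI API provider backed by a shared `reqwest` client.
///
/// Implements [`ChatProvider`], [`CompletionProvider`], [`EmbeddingProvider`],
/// [`StreamingProvider`], and [`ToolProvider`].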
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
    config: OpenAIConfig,
    client: Client,
}

impl OpenAIProvider {
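    /// Creates a new provider from the given configuration.
    ///
    /// Builds a `reqwest` client with the authorization, organization,
    /// project, user-agent, and any custom headers installed as defaults, so
    /// every request sent through this provider carries them automatically.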
    pub fn new(config: OpenAIConfig) -> Result<Self, OpenAIError> {
        let mut headers = reqwest::header::HeaderMap::new();

        // Authorization header from the configured API key.
        let auth_value = format!("Bearer {}", config.api_key.expose_secret());
        headers.insert(
            reqwest::header::AUTHORIZATION,
            auth_value.parse().map_err(|_| OpenAIError::Config {
                source: ferrous_llm_core::ConfigError::invalid_value(
                    "api_key",
                    "Invalid API key format",
                ),
            })?,
        );

        // Optional OpenAI organization header.
        if let Some(ref org) = config.organization {
            headers.insert(
                "OpenAI-Organization",
                org.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "organization",
                        "Invalid organization format",
                    ),
                })?,
            );
        }

        // Optional OpenAI project header.
        if let Some(ref project) = config.project {
            headers.insert(
                "OpenAI-Project",
                project.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "project",
                        "Invalid project format",
                    ),
                })?,
            );
        }

        // Optional custom user agent.
        if let Some(ref user_agent) = config.http.user_agent {
            headers.insert(
                reqwest::header::USER_AGENT,
                user_agent.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "user_agent",
                        "Invalid user agent format",
                    ),
                })?,
            );
        }

        // Any additional headers from the HTTP configuration.
        for (key, value) in &config.http.headers {
            let header_name: reqwest::header::HeaderName =
                key.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "headers",
                        "Invalid header name",
                    ),
                })?;
            let header_value: reqwest::header::HeaderValue =
                value.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "headers",
                        "Invalid header value",
                    ),
                })?;
            headers.insert(header_name, header_value);
        }

        let mut client_builder = Client::builder()
            .timeout(config.http.timeout)
            .default_headers(headers);

        // Disable gzip when compression is turned off in the configuration.
        if !config.http.compression {
            client_builder = client_builder.no_gzip();
        }

        // Connection pool settings.
        client_builder = client_builder
            .pool_max_idle_per_host(config.http.pool.max_idle_connections)
            .pool_idle_timeout(config.http.pool.idle_timeout)
            .connect_timeout(config.http.pool.connect_timeout);

        let client = client_builder
            .build()
            .map_err(|e| OpenAIError::Network { source: e })?;

        Ok(Self { config, client })
    }

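    /// Starts a request to `url` on the shared client; the default headers
    /// configured in [`Self::new`] are attached automatically.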
    fn request_builder(&self, method: reqwest::Method, url: &str) -> RequestBuilder {
        self.client.request(method, url)
    }

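    /// Deserializes a successful response body as JSON, or maps a non-success
    /// status and body into an [`OpenAIError`].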
    async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T, OpenAIError>
    where
        T: serde::de::DeserializeOwned,
    {
        let status = response.status();

        if status.is_success() {
            response
                .json()
                .await
                .map_err(|e| OpenAIError::Network { source: e })
        } else {
            let body = response.text().await.unwrap_or_default();
            Err(OpenAIError::from_response(status.as_u16(), &body))
        }
    }

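    /// Maps a provider-agnostic [`ChatRequest`] onto the OpenAI wire format,
    /// filling in the configured model. Tools are left unset here; they are
    /// attached by [`ToolProvider::chat_with_tools`].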
    fn convert_chat_request(&self, request: &ChatRequest) -> OpenAIChatRequest {
        OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: request.messages.iter().map(|m| m.into()).collect(),
            temperature: request.parameters.temperature,
            max_tokens: request.parameters.max_tokens,
            top_p: request.parameters.top_p,
            frequency_penalty: request.parameters.frequency_penalty,
            presence_penalty: request.parameters.presence_penalty,
            stop: request.parameters.stop_sequences.clone(),
            stream: Some(false),
            tools: None,
            tool_choice: None,
            user: request.metadata.user_id.clone(),
        }
    }

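    /// Maps a provider-agnostic [`CompletionRequest`] onto the OpenAI wire
    /// format, filling in the configured model.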
    fn convert_completion_request(&self, request: &CompletionRequest) -> OpenAICompletionRequest {
        OpenAICompletionRequest {
            model: self.config.model.clone(),
            prompt: request.prompt.clone(),
            max_tokens: request.parameters.max_tokens,
            temperature: request.parameters.temperature,
            top_p: request.parameters.top_p,
            frequency_penalty: request.parameters.frequency_penalty,
            presence_penalty: request.parameters.presence_penalty,
            stop: request.parameters.stop_sequences.clone(),
            stream: Some(false),
            user: request.metadata.user_id.clone(),
        }
    }
}

#[async_trait]
impl ChatProvider for OpenAIProvider {
    type Config = OpenAIConfig;
    type Response = OpenAIChatResponse;
    type Error = OpenAIError;

    async fn chat(&self, request: ChatRequest) -> ProviderResult<Self::Response, Self::Error> {
        let openai_request = self.convert_chat_request(&request);

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.chat_url())
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        self.handle_response(response).await
    }
}

#[async_trait]
impl CompletionProvider for OpenAIProvider {
    type Config = OpenAIConfig;
    type Response = OpenAICompletionResponse;
    type Error = OpenAIError;

    async fn complete(
        &self,
        request: CompletionRequest,
    ) -> ProviderResult<Self::Response, Self::Error> {
        let openai_request = self.convert_completion_request(&request);

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.completions_url())
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        self.handle_response(response).await
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIProvider {
    type Config = OpenAIConfig;
    type Error = OpenAIError;

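    /// Embeds the given texts, defaulting to `text-embedding-ada-002` when no
    /// embedding model is configured. A single text is sent as a bare JSON
    /// string and a batch as a JSON array, matching the shapes the embeddings
    /// endpoint accepts.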
    async fn embed(&self, texts: &[String]) -> ProviderResult<Vec<Embedding>, Self::Error> {
        let request = OpenAIEmbeddingsRequest {
            model: self
                .config
                .embedding_model
                .clone()
                .unwrap_or_else(|| "text-embedding-ada-002".to_string()),
            input: if texts.len() == 1 {
                json!(texts[0])
            } else {
                json!(texts)
            },
            encoding_format: Some("float".to_string()),
            dimensions: None,
            user: None,
        };

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.embeddings_url())
            .json(&request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        let embeddings_response: OpenAIEmbeddingsResponse = self.handle_response(response).await?;

        let embeddings = embeddings_response
            .data
            .into_iter()
            .map(|e| Embedding {
                embedding: e.embedding,
                index: e.index,
            })
            .collect();

        Ok(embeddings)
    }
}

#[async_trait]
impl StreamingProvider for OpenAIProvider {
    type StreamItem = String;
    type Stream = Pin<Box<dyn Stream<Item = Result<Self::StreamItem, Self::Error>> + Send>>;

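    /// Streams chat content deltas as they arrive.
    ///
    /// Sends the request with `stream: true`, then parses the response body as
    /// server-sent events: each `data: ` line carries one JSON chunk, and the
    /// literal `data: [DONE]` line terminates the stream. Content deltas are
    /// forwarded through a bounded channel to the returned stream.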
    async fn chat_stream(&self, request: ChatRequest) -> ProviderResult<Self::Stream, Self::Error> {
        let mut openai_request = self.convert_chat_request(&request);
        openai_request.stream = Some(true);

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.chat_url())
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        if !response.status().is_success() {
            let status = response.status().as_u16();
            let body = response.text().await.unwrap_or_default();
            return Err(OpenAIError::from_response(status, &body));
        }

        let (tx, rx) = tokio::sync::mpsc::channel::<Result<String, OpenAIError>>(100);

        let tx_clone = tx.clone();
        tokio::spawn(async move {
            let mut byte_stream = response.bytes_stream();
            let mut buffer = Vec::new();

            while let Some(chunk_result) = byte_stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        buffer.extend_from_slice(chunk.as_ref());

                        // Process every complete newline-terminated line in the buffer.
                        let mut start = 0;
                        while let Some(pos) = buffer[start..].iter().position(|&b| b == b'\n') {
                            let line_end = start + pos;
                            let line = String::from_utf8_lossy(&buffer[start..line_end])
                                .trim()
                                .to_string();
                            start = line_end + 1;

                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // End of the SSE stream; close the channel.
                                    drop(tx_clone);
                                    return;
                                }

                                // Forward non-empty content deltas; stop if the
                                // receiver side has been dropped.
                                if let Ok(chunk) = serde_json::from_str::<OpenAIStreamChunk>(data)
                                    && let Some(choice) = chunk.choices.first()
                                    && let Some(content) = &choice.delta.content
                                    && !content.is_empty()
                                    && tx_clone.send(Ok(content.clone())).await.is_err()
                                {
                                    return;
                                }
                            }
                        }

                        // Keep any trailing partial line for the next chunk.
                        buffer.drain(0..start);
                    }
                    Err(e) => {
                        let _ = tx_clone.send(Err(OpenAIError::Network { source: e })).await;
                        return;
                    }
                }
            }

            drop(tx_clone);
        });

        // Expose the receiver side of the channel as the returned stream.
        let content_stream = ReceiverStream::new(rx);

        Ok(Box::pin(content_stream))
    }
}

#[async_trait]
impl ToolProvider for OpenAIProvider {
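    /// Like [`ChatProvider::chat`], but attaches the given tool definitions
    /// and lets the model decide when to call them (`tool_choice: "auto"`).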
    async fn chat_with_tools(
        &self,
        request: ChatRequest,
        tools: &[Tool],
    ) -> ProviderResult<Self::Response, Self::Error> {
        let mut openai_request = self.convert_chat_request(&request);

        if !tools.is_empty() {
            openai_request.tools = Some(tools.iter().map(|t| t.into()).collect());
            openai_request.tool_choice = Some(json!("auto"));
        }

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.chat_url())
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        self.handle_response(response).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ferrous_llm_core::{Message, Metadata, Parameters};

    fn create_test_config() -> OpenAIConfig {
        OpenAIConfig::new("sk-test123456789", "gpt-3.5-turbo")
    }

    #[test]
    fn test_provider_creation() {
        let config = create_test_config();
        let provider = OpenAIProvider::new(config);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_convert_chat_request() {
        let config = create_test_config();
        let provider = OpenAIProvider::new(config).unwrap();

        let request = ChatRequest {
            messages: vec![Message::user("Hello")],
            parameters: Parameters {
                temperature: Some(0.7),
                max_tokens: Some(100),
                ..Default::default()
            },
            metadata: Metadata::default(),
        };

        let openai_request = provider.convert_chat_request(&request);
        assert_eq!(openai_request.model, "gpt-3.5-turbo");
        assert_eq!(openai_request.temperature, Some(0.7));
        assert_eq!(openai_request.max_tokens, Some(100));
        assert_eq!(openai_request.messages.len(), 1);
    }
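
    // A companion sketch to test_convert_chat_request for the completions
    // path. It assumes CompletionRequest exposes the same { prompt,
    // parameters, metadata } fields that convert_completion_request reads
    // above; adjust the construction if the core types differ.
    #[test]
    fn test_convert_completion_request() {
        let config = create_test_config();
        let provider = OpenAIProvider::new(config).unwrap();

        let request = CompletionRequest {
            prompt: "Hello".to_string(),
            parameters: Parameters {
                temperature: Some(0.7),
                max_tokens: Some(100),
                ..Default::default()
            },
            metadata: Metadata::default(),
        };

        let openai_request = provider.convert_completion_request(&request);
        assert_eq!(openai_request.model, "gpt-3.5-turbo");
        assert_eq!(openai_request.prompt, "Hello");
        assert_eq!(openai_request.temperature, Some(0.7));
        assert_eq!(openai_request.max_tokens, Some(100));
    }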
}