ferrous_llm_openai/provider.rs

//! OpenAI provider implementation.
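//!
//! A minimal end-to-end sketch (hedged: this assumes `OpenAIConfig` and
//! `OpenAIProvider` are re-exported at the crate root; adjust the paths to
//! match the actual crate layout):
//!
//! ```ignore
//! use ferrous_llm_core::{ChatProvider, ChatRequest, Message, Metadata, Parameters};
//! use ferrous_llm_openai::{OpenAIConfig, OpenAIProvider};
//!
//! async fn run() -> Result<(), Box<dyn std::error::Error>> {
//!     // Illustrative key/model values; replace with real ones.
//!     let provider = OpenAIProvider::new(OpenAIConfig::new("sk-...", "gpt-3.5-turbo"))?;
//!     let request = ChatRequest {
//!         messages: vec![Message::user("Hello!")],
//!         parameters: Parameters::default(),
//!         metadata: Metadata::default(),
//!     };
//!     let _response = provider.chat(request).await?;
//!     Ok(())
//! }
//! ```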

use crate::{config::OpenAIConfig, error::OpenAIError, types::*};
use async_trait::async_trait;
use ferrous_llm_core::{
    ChatProvider, ChatRequest, CompletionProvider, CompletionRequest, Embedding, EmbeddingProvider,
    ProviderResult, StreamingProvider, Tool, ToolProvider,
};
use futures::Stream;
use reqwest::{Client, RequestBuilder};
use serde_json::json;
use std::pin::Pin;
use tokio_stream::{StreamExt, wrappers::ReceiverStream};

/// OpenAI provider implementation.
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
    config: OpenAIConfig,
    client: Client,
}

impl OpenAIProvider {
    /// Create a new OpenAI provider with the given configuration.
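    ///
    /// Builds a `reqwest` client with the authorization, organization, project,
    /// and custom headers installed as defaults, so individual requests only
    /// need a URL and a JSON body. A hedged call-site sketch (reusing the
    /// config constructor from the tests below):
    ///
    /// ```ignore
    /// let config = OpenAIConfig::new("sk-...", "gpt-3.5-turbo");
    /// let provider = OpenAIProvider::new(config)?;
    /// ```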
    pub fn new(config: OpenAIConfig) -> Result<Self, OpenAIError> {
        let mut headers = reqwest::header::HeaderMap::new();

        // Add authorization header
        let auth_value = format!("Bearer {}", config.api_key.expose_secret());
        headers.insert(
            reqwest::header::AUTHORIZATION,
            auth_value.parse().map_err(|_| OpenAIError::Config {
                source: ferrous_llm_core::ConfigError::invalid_value(
                    "api_key",
                    "Invalid API key format",
                ),
            })?,
        );

        // Add organization header if provided
        if let Some(ref org) = config.organization {
            headers.insert(
                "OpenAI-Organization",
                org.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "organization",
                        "Invalid organization format",
                    ),
                })?,
            );
        }

        // Add project header if provided
        if let Some(ref project) = config.project {
            headers.insert(
                "OpenAI-Project",
                project.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "project",
                        "Invalid project format",
                    ),
                })?,
            );
        }

        // Add user agent
        if let Some(ref user_agent) = config.http.user_agent {
            headers.insert(
                reqwest::header::USER_AGENT,
                user_agent.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "user_agent",
                        "Invalid user agent format",
                    ),
                })?,
            );
        }

        // Add custom headers
        for (key, value) in &config.http.headers {
            let header_name: reqwest::header::HeaderName =
                key.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "headers",
                        "Invalid header name",
                    ),
                })?;
            let header_value: reqwest::header::HeaderValue =
                value.parse().map_err(|_| OpenAIError::Config {
                    source: ferrous_llm_core::ConfigError::invalid_value(
                        "headers",
                        "Invalid header value",
                    ),
                })?;
            headers.insert(header_name, header_value);
        }

        let mut client_builder = Client::builder()
            .timeout(config.http.timeout)
            .default_headers(headers);

        // Configure compression
        if !config.http.compression {
            client_builder = client_builder.no_gzip();
        }

        // Configure connection pool
        client_builder = client_builder
            .pool_max_idle_per_host(config.http.pool.max_idle_connections)
            .pool_idle_timeout(config.http.pool.idle_timeout)
            .connect_timeout(config.http.pool.connect_timeout);

        let client = client_builder
            .build()
            .map_err(|e| OpenAIError::Network { source: e })?;

        Ok(Self { config, client })
    }

    /// Create a request builder with common settings.
    fn request_builder(&self, method: reqwest::Method, url: &str) -> RequestBuilder {
        self.client.request(method, url)
    }

    /// Handle HTTP response and convert to appropriate error.
    async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T, OpenAIError>
    where
        T: serde::de::DeserializeOwned,
    {
        let status = response.status();

        if status.is_success() {
            response
                .json()
                .await
                .map_err(|e| OpenAIError::Network { source: e })
        } else {
            let body = response.text().await.unwrap_or_default();
            Err(OpenAIError::from_response(status.as_u16(), &body))
        }
    }

    /// Convert core ChatRequest to OpenAI format.
    fn convert_chat_request(&self, request: &ChatRequest) -> OpenAIChatRequest {
        OpenAIChatRequest {
            model: self.config.model.clone(),
            messages: request.messages.iter().map(|m| m.into()).collect(),
            temperature: request.parameters.temperature,
            max_tokens: request.parameters.max_tokens,
            top_p: request.parameters.top_p,
            frequency_penalty: request.parameters.frequency_penalty,
            presence_penalty: request.parameters.presence_penalty,
            stop: request.parameters.stop_sequences.clone(),
            stream: Some(false),
            tools: None, // Will be set by chat_with_tools
            tool_choice: None,
            user: request.metadata.user_id.clone(),
        }
    }

    /// Convert core CompletionRequest to OpenAI format.
    fn convert_completion_request(&self, request: &CompletionRequest) -> OpenAICompletionRequest {
        OpenAICompletionRequest {
            model: self.config.model.clone(),
            prompt: request.prompt.clone(),
            max_tokens: request.parameters.max_tokens,
            temperature: request.parameters.temperature,
            top_p: request.parameters.top_p,
            frequency_penalty: request.parameters.frequency_penalty,
            presence_penalty: request.parameters.presence_penalty,
            stop: request.parameters.stop_sequences.clone(),
            stream: Some(false),
            user: request.metadata.user_id.clone(),
        }
    }
}

#[async_trait]
impl ChatProvider for OpenAIProvider {
    type Config = OpenAIConfig;
    type Response = OpenAIChatResponse;
    type Error = OpenAIError;

    async fn chat(&self, request: ChatRequest) -> ProviderResult<Self::Response, Self::Error> {
        let openai_request = self.convert_chat_request(&request);

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.chat_url())
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        self.handle_response(response).await
    }
}

#[async_trait]
impl CompletionProvider for OpenAIProvider {
    type Config = OpenAIConfig;
    type Response = OpenAICompletionResponse;
    type Error = OpenAIError;

    async fn complete(
        &self,
        request: CompletionRequest,
    ) -> ProviderResult<Self::Response, Self::Error> {
        let openai_request = self.convert_completion_request(&request);

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.completions_url())
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        self.handle_response(response).await
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIProvider {
    type Config = OpenAIConfig;
    type Error = OpenAIError;

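    /// Embed a batch of texts with the configured embedding model, falling
    /// back to `text-embedding-ada-002`. A hedged usage sketch:
    ///
    /// ```ignore
    /// let embeddings = provider.embed(&["hello world".to_string()]).await?;
    /// assert_eq!(embeddings.len(), 1);
    /// ```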
    async fn embed(&self, texts: &[String]) -> ProviderResult<Vec<Embedding>, Self::Error> {
        let request = OpenAIEmbeddingsRequest {
            model: self
                .config
                .embedding_model
                .clone()
                .unwrap_or_else(|| "text-embedding-ada-002".to_string()),
            // The embeddings API accepts either a single string or an array
            // of strings as input.
            input: if texts.len() == 1 {
                json!(texts[0])
            } else {
                json!(texts)
            },
            encoding_format: Some("float".to_string()),
            dimensions: None,
            user: None,
        };

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.embeddings_url())
            .json(&request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        let embeddings_response: OpenAIEmbeddingsResponse = self.handle_response(response).await?;

        let embeddings = embeddings_response
            .data
            .into_iter()
            .map(|e| Embedding {
                embedding: e.embedding,
                index: e.index,
            })
            .collect();

        Ok(embeddings)
    }
}

#[async_trait]
impl StreamingProvider for OpenAIProvider {
    type StreamItem = String;
    type Stream = Pin<Box<dyn Stream<Item = Result<Self::StreamItem, Self::Error>> + Send>>;

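    /// Stream incremental content deltas parsed from the OpenAI server-sent
    /// events (SSE) response. A hedged consumption sketch (this module already
    /// imports `tokio_stream::StreamExt`, which provides `next`):
    ///
    /// ```ignore
    /// let mut stream = provider.chat_stream(request).await?;
    /// while let Some(delta) = stream.next().await {
    ///     print!("{}", delta?);
    /// }
    /// ```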
    async fn chat_stream(&self, request: ChatRequest) -> ProviderResult<Self::Stream, Self::Error> {
        let mut openai_request = self.convert_chat_request(&request);
        openai_request.stream = Some(true);

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.chat_url())
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        if !response.status().is_success() {
            let status = response.status().as_u16();
            let body = response.text().await.unwrap_or_default();
            return Err(OpenAIError::from_response(status, &body));
        }

        // Create a tokio channel for streaming
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<String, OpenAIError>>(100);

        // Spawn a task to process the SSE stream; the sender moves into the
        // task, so the channel closes as soon as the task finishes.
        tokio::spawn(async move {
            let mut byte_stream = response.bytes_stream();
            let mut buffer = Vec::new();

            while let Some(chunk_result) = byte_stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        buffer.extend_from_slice(chunk.as_ref());

                        // Process complete lines
                        let mut start = 0;
                        while let Some(pos) = buffer[start..].iter().position(|&b| b == b'\n') {
                            let line_end = start + pos;
                            let line = String::from_utf8_lossy(&buffer[start..line_end])
                                .trim()
                                .to_string();
                            start = line_end + 1;

                            // Process SSE format: "data: {json}" or "data: [DONE]"
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // End of stream; returning drops the sender
                                    return;
                                }

                                // Try to parse the JSON chunk
                                if let Ok(chunk) = serde_json::from_str::<OpenAIStreamChunk>(data)
                                    && let Some(choice) = chunk.choices.first()
                                    && let Some(content) = &choice.delta.content
                                    && !content.is_empty()
                                    && tx.send(Ok(content.clone())).await.is_err()
                                {
                                    // Receiver dropped
                                    return;
                                }
                            }
                        }

                        // Keep remaining bytes in buffer
                        buffer.drain(0..start);
                    }
                    Err(e) => {
                        let _ = tx.send(Err(OpenAIError::Network { source: e })).await;
                        return;
                    }
                }
            }
            // Falling out of the loop drops the sender and closes the channel
        });

        // Convert the receiver to a stream
        let content_stream = ReceiverStream::new(rx);

        Ok(Box::pin(content_stream))
    }
}

#[async_trait]
impl ToolProvider for OpenAIProvider {
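    /// Chat with tool definitions attached; `tool_choice` is set to `"auto"`
    /// whenever any tools are supplied. A hedged call-site sketch (assumes a
    /// `Vec<Tool>` built elsewhere):
    ///
    /// ```ignore
    /// let response = provider.chat_with_tools(request, &tools).await?;
    /// ```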
    async fn chat_with_tools(
        &self,
        request: ChatRequest,
        tools: &[Tool],
    ) -> ProviderResult<Self::Response, Self::Error> {
        let mut openai_request = self.convert_chat_request(&request);

        if !tools.is_empty() {
            openai_request.tools = Some(tools.iter().map(|t| t.into()).collect());
            openai_request.tool_choice = Some(json!("auto"));
        }

        let response = self
            .request_builder(reqwest::Method::POST, &self.config.chat_url())
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| OpenAIError::Network { source: e })?;

        self.handle_response(response).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ferrous_llm_core::{Message, Metadata, Parameters};

    fn create_test_config() -> OpenAIConfig {
        OpenAIConfig::new("sk-test123456789", "gpt-3.5-turbo")
    }

    #[test]
    fn test_provider_creation() {
        let config = create_test_config();
        let provider = OpenAIProvider::new(config);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_convert_chat_request() {
        let config = create_test_config();
        let provider = OpenAIProvider::new(config).unwrap();

        let request = ChatRequest {
            messages: vec![Message::user("Hello")],
            parameters: Parameters {
                temperature: Some(0.7),
                max_tokens: Some(100),
                ..Default::default()
            },
            metadata: Metadata::default(),
        };

        let openai_request = provider.convert_chat_request(&request);
        assert_eq!(openai_request.model, "gpt-3.5-turbo");
        assert_eq!(openai_request.temperature, Some(0.7));
        assert_eq!(openai_request.max_tokens, Some(100));
        assert_eq!(openai_request.messages.len(), 1);
    }
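
    // A hedged extra check: `convert_chat_request` should leave streaming off
    // and tools unset by default, since those fields are only set later by
    // `chat_stream` and `chat_with_tools`. Field names follow the conversion
    // code above.
    #[test]
    fn test_convert_chat_request_defaults() {
        let provider = OpenAIProvider::new(create_test_config()).unwrap();

        let request = ChatRequest {
            messages: vec![Message::user("Hello")],
            parameters: Parameters::default(),
            metadata: Metadata::default(),
        };

        let openai_request = provider.convert_chat_request(&request);
        assert_eq!(openai_request.stream, Some(false));
        assert!(openai_request.tools.is_none());
        assert!(openai_request.tool_choice.is_none());
    }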
}