agentic_llm/openai.rs

//! `OpenAI` adapter implementation.

use async_openai::{
    config::OpenAIConfig,
    types::{
        ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
        ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    },
    Client,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use std::pin::Pin;
use tracing::{debug, instrument};

use crate::{
    error::LLMError,
    traits::{FinishReason, LLMAdapter, LLMMessage, LLMResponse, Role, StreamChunk, TokenUsage},
};

/// `OpenAI` adapter for GPT models.
pub struct OpenAIAdapter {
    client: Client<OpenAIConfig>,
    model: String,
    temperature: f32,
    max_tokens: Option<u32>,
}

impl OpenAIAdapter {
    /// Create a new `OpenAI` adapter.
    ///
    /// # Arguments
    ///
    /// * `api_key` - `OpenAI` API key
    /// * `model` - Model to use (e.g., "gpt-4o", "gpt-4o-mini")
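    ///
    /// # Examples
    ///
    /// A minimal construction sketch; the `agentic_llm::openai` path is
    /// assumed from this file's location in the crate, and the builder
    /// calls are the methods defined below.
    ///
    /// ```no_run
    /// use agentic_llm::openai::OpenAIAdapter;
    ///
    /// let adapter = OpenAIAdapter::new("sk-...", "gpt-4o-mini")
    ///     .with_temperature(0.2)
    ///     .with_max_tokens(1024);
    /// ```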
    #[must_use]
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let config = OpenAIConfig::new().with_api_key(api_key);
        Self {
            client: Client::with_config(config),
            model: model.into(),
            temperature: 0.7,
            max_tokens: None,
        }
    }

    /// Set the temperature for generation.
    #[must_use]
    pub const fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }

    /// Set the maximum tokens for generation.
    #[must_use]
    pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Convert our message format to `OpenAI`'s format.
    fn convert_messages(messages: &[LLMMessage]) -> Vec<ChatCompletionRequestMessage> {
        messages
            .iter()
            .map(|msg| match msg.role {
                Role::System => {
                    ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                        content: msg.content.clone().into(),
                        ..Default::default()
                    })
                }
                Role::User => {
                    ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                        content: msg.content.clone().into(),
                        ..Default::default()
                    })
                }
                Role::Assistant => {
                    ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
                            msg.content.clone(),
                        )),
                        ..Default::default()
                    })
                }
            })
            .collect()
    }
}

#[async_trait]
impl LLMAdapter for OpenAIAdapter {
    fn provider(&self) -> &'static str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    #[instrument(skip(self, messages), fields(provider = "openai", model = %self.model))]
    async fn generate(&self, messages: &[LLMMessage]) -> Result<LLMResponse, LLMError> {
        debug!("Generating completion with {} messages", messages.len());

        let request = CreateChatCompletionRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            temperature: Some(self.temperature),
            max_completion_tokens: self.max_tokens,
            ..Default::default()
        };

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(|e| LLMError::ApiError(e.to_string()))?;

        let choice = response.choices.first().ok_or(LLMError::EmptyResponse)?;

        let content = choice.message.content.clone().unwrap_or_default();

        let usage = response.usage.as_ref();

        Ok(LLMResponse {
            content,
            tokens_used: TokenUsage {
                prompt: usage.map_or(0, |u| u.prompt_tokens),
                completion: usage.map_or(0, |u| u.completion_tokens),
                total: usage.map_or(0, |u| u.total_tokens),
            },
            finish_reason: match choice.finish_reason {
                Some(async_openai::types::FinishReason::Length) => FinishReason::Length,
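                // All other finish reasons (normal stop, tool calls,
                // content filtering) are collapsed to `Stop` here.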
                _ => FinishReason::Stop,
            },
            model: response.model,
        })
    }

    fn generate_stream(
        &self,
        messages: &[LLMMessage],
    ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send + '_>> {
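        // Build the request eagerly, before entering the stream, so the
        // returned stream borrows only `self` and not `messages`.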
        let request = CreateChatCompletionRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            temperature: Some(self.temperature),
            max_completion_tokens: self.max_tokens,
            stream: Some(true),
            ..Default::default()
        };

        Box::pin(async_stream::try_stream! {
            let mut stream = self
                .client
                .chat()
                .create_stream(request)
                .await
                .map_err(|e| LLMError::ApiError(e.to_string()))?;

            while let Some(result) = stream.next().await {
                let response = result.map_err(|e| LLMError::ApiError(e.to_string()))?;

                if let Some(choice) = response.choices.first() {
                    let content = choice.delta.content.clone().unwrap_or_default();
                    let done = choice.finish_reason.is_some();

                    yield StreamChunk {
                        content,
                        done,
                        tokens_used: None,
                        finish_reason: choice.finish_reason.map(|r| match r {
                            async_openai::types::FinishReason::Length => FinishReason::Length,
                            _ => FinishReason::Stop,
                        }),
                    };
                }
            }
        })
    }

    async fn health_check(&self) -> Result<bool, LLMError> {
        self.client
            .models()
            .list()
            .await
            .map(|_| true)
            .map_err(|e| LLMError::ConnectionError(e.to_string()))
    }
}

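// Hedged usage sketch (not part of the adapter itself): one way to drive
// `generate` and `generate_stream`, using only items already in scope in
// this file. `LLMMessage::user` matches the constructor exercised in the
// tests below; everything else comes from this file's own types.
#[allow(dead_code)]
async fn usage_sketch(adapter: &OpenAIAdapter) -> Result<String, LLMError> {
    let messages = vec![LLMMessage::user("Hello")];

    // One-shot completion: the whole response arrives at once.
    let response = adapter.generate(&messages).await?;
    debug!(
        "one-shot reply used {} total tokens",
        response.tokens_used.total
    );

    // Streaming: concatenate deltas until a chunk reports `done`.
    let mut stream = adapter.generate_stream(&messages);
    let mut text = String::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        text.push_str(&chunk.content);
        if chunk.done {
            break;
        }
    }
    Ok(text)
}
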
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            LLMMessage::system("You are helpful."),
            LLMMessage::user("Hello"),
        ];

        let converted = OpenAIAdapter::convert_messages(&messages);
        assert_eq!(converted.len(), 2);
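
        // Slightly stronger checks (added sketch): the variant of each
        // converted message should match its role, per `convert_messages`.
        assert!(matches!(
            converted[0],
            ChatCompletionRequestMessage::System(_)
        ));
        assert!(matches!(
            converted[1],
            ChatCompletionRequestMessage::User(_)
        ));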
    }
}
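
    // An additional sketch: the builder setters should store what they are
    // given. The private fields are visible here because this test module
    // lives in the same file.
    #[test]
    fn test_builder_configuration() {
        let adapter = OpenAIAdapter::new("test-key", "gpt-4o")
            .with_temperature(0.2)
            .with_max_tokens(512);

        assert_eq!(adapter.model(), "gpt-4o");
        assert!((adapter.temperature - 0.2).abs() < f32::EPSILON);
        assert_eq!(adapter.max_tokens, Some(512));
    }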