use async_openai::{
    config::OpenAIConfig,
    types::{
        ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
        ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    },
    Client,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use std::pin::Pin;
use tracing::{debug, instrument};

use crate::{
    error::LLMError,
    traits::{FinishReason, LLMAdapter, LLMMessage, LLMResponse, Role, StreamChunk, TokenUsage},
};

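/// Adapter that implements [`LLMAdapter`] for the OpenAI chat completions
/// API on top of the `async-openai` client.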
pub struct OpenAIAdapter {
    client: Client<OpenAIConfig>,
    model: String,
    temperature: f32,
    max_tokens: Option<u32>,
}

impl OpenAIAdapter {
    #[must_use]
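    /// Creates an adapter for the given API key and model.
    ///
    /// Defaults to a temperature of 0.7 and no completion-token cap; both
    /// can be overridden with [`Self::with_temperature`] and
    /// [`Self::with_max_tokens`].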
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let config = OpenAIConfig::new().with_api_key(api_key);
        Self {
            client: Client::with_config(config),
            model: model.into(),
            temperature: 0.7,
            max_tokens: None,
        }
    }

    #[must_use]
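    /// Sets the sampling temperature (OpenAI accepts values in `0.0..=2.0`).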
    pub const fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }

    #[must_use]
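    /// Caps the number of completion tokens the model may generate.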
    pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

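    /// Converts the crate's provider-agnostic messages into the request
    /// message variants expected by `async-openai`.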
    fn convert_messages(messages: &[LLMMessage]) -> Vec<ChatCompletionRequestMessage> {
        messages
            .iter()
            .map(|msg| match msg.role {
                Role::System => {
                    ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                        content: msg.content.clone().into(),
                        ..Default::default()
                    })
                }
                Role::User => {
                    ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                        content: msg.content.clone().into(),
                        ..Default::default()
                    })
                }
                Role::Assistant => {
                    ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
                            msg.content.clone(),
                        )),
                        ..Default::default()
                    })
                }
            })
            .collect()
    }
}

#[async_trait]
impl LLMAdapter for OpenAIAdapter {
    fn provider(&self) -> &'static str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    #[instrument(skip(self, messages), fields(provider = "openai", model = %self.model))]
    async fn generate(&self, messages: &[LLMMessage]) -> Result<LLMResponse, LLMError> {
        debug!("Generating completion with {} messages", messages.len());

        let request = CreateChatCompletionRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            temperature: Some(self.temperature),
            max_completion_tokens: self.max_tokens,
            ..Default::default()
        };

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(|e| LLMError::ApiError(e.to_string()))?;

        let choice = response.choices.first().ok_or(LLMError::EmptyResponse)?;

        let content = choice.message.content.clone().unwrap_or_default();

        let usage = response.usage.as_ref();

        Ok(LLMResponse {
            content,
            tokens_used: TokenUsage {
                prompt: usage.map_or(0, |u| u.prompt_tokens),
                completion: usage.map_or(0, |u| u.completion_tokens),
                total: usage.map_or(0, |u| u.total_tokens),
            },
            // Anything other than a length cutoff (stop, tool calls, content
            // filter) is collapsed to a normal stop.
            finish_reason: match choice.finish_reason {
                Some(async_openai::types::FinishReason::Length) => FinishReason::Length,
                _ => FinishReason::Stop,
            },
            model: response.model,
        })
    }

    fn generate_stream(
        &self,
        messages: &[LLMMessage],
    ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send + '_>> {
        // Build the request up front so the stream below owns it outright.
        let request = CreateChatCompletionRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            temperature: Some(self.temperature),
            max_completion_tokens: self.max_tokens,
            stream: Some(true),
            ..Default::default()
        };

        Box::pin(async_stream::try_stream! {
            let mut stream = self
                .client
                .chat()
                .create_stream(request)
                .await
                .map_err(|e| LLMError::ApiError(e.to_string()))?;

            while let Some(result) = stream.next().await {
                let response = result.map_err(|e| LLMError::ApiError(e.to_string()))?;

                if let Some(choice) = response.choices.first() {
                    let content = choice.delta.content.clone().unwrap_or_default();
                    let done = choice.finish_reason.is_some();

                    yield StreamChunk {
                        content,
                        done,
                        tokens_used: None,
                        finish_reason: choice.finish_reason.map(|r| match r {
                            async_openai::types::FinishReason::Length => FinishReason::Length,
                            _ => FinishReason::Stop,
                        }),
                    };
                }
            }
        })
    }

    async fn health_check(&self) -> Result<bool, LLMError> {
        // Listing models is a cheap authenticated call, so it doubles as a
        // connectivity and credentials probe.
        self.client
            .models()
            .list()
            .await
            .map(|_| true)
            .map_err(|e| LLMError::ConnectionError(e.to_string()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            LLMMessage::system("You are helpful."),
            LLMMessage::user("Hello"),
        ];

        let converted = OpenAIAdapter::convert_messages(&messages);
        assert_eq!(converted.len(), 2);
    }
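
    // A quick builder sanity check; the key and model below are arbitrary
    // placeholders and no network request is made.
    #[test]
    fn test_builder_configuration() {
        let adapter = OpenAIAdapter::new("sk-test", "gpt-4o-mini")
            .with_temperature(0.2)
            .with_max_tokens(256);

        assert_eq!(adapter.provider(), "openai");
        assert_eq!(adapter.model(), "gpt-4o-mini");
        assert!((adapter.temperature - 0.2).abs() < f32::EPSILON);
        assert_eq!(adapter.max_tokens, Some(256));
    }
}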
208}