1use crate::types::*;
3
4use super::{ChatMessage, LLM, LLMOptions, LLMResult, ResponseFormat};
5use serde_json::json;
6use std::future::Future;
7use std::pin::Pin;
8
/// Chat models this client knows how to address.
///
/// Each variant maps to one model identifier string via `as_str`.
/// The enum is comparable and hashable so it can be used in equality checks
/// and as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OpenAIModel {
    /// `gpt-4`
    Gpt4,
    /// `gpt-4-turbo-preview`
    Gpt4Turbo,
    /// `gpt-4-vision-preview`
    Gpt4Vision,
    /// `gpt-4-32k`
    Gpt432k,
    /// `gpt-3.5-turbo`
    Gpt35Turbo,
    /// `gpt-3.5-turbo-16k`
    Gpt35Turbo16k,
    /// `o1-preview`
    O1Preview,
    /// `o1-mini`
    O1Mini,
}
20
21impl OpenAIModel {
22 fn as_str(&self) -> &'static str {
23 match self {
24 OpenAIModel::Gpt4 => "gpt-4",
25 OpenAIModel::Gpt4Turbo => "gpt-4-turbo-preview",
26 OpenAIModel::Gpt4Vision => "gpt-4-vision-preview",
27 OpenAIModel::Gpt432k => "gpt-4-32k",
28 OpenAIModel::Gpt35Turbo => "gpt-3.5-turbo",
29 OpenAIModel::Gpt35Turbo16k => "gpt-3.5-turbo-16k",
30 OpenAIModel::O1Preview => "o1-preview",
31 OpenAIModel::O1Mini => "o1-mini",
32 }
33 }
34}
35
36impl From<OpenAIModel> for String {
37 fn from(model: OpenAIModel) -> Self {
38 model.as_str().to_string()
39 }
40}
41
42pub struct OpenAI {
43 api_key: String,
44 model: OpenAIModel,
45 base_url: String,
46 client: reqwest::Client,
47 default_options: LLMOptions,
48}
49
50impl OpenAI {
51 pub fn new(api_key: String) -> Self {
52 Self {
53 api_key,
54 model: OpenAIModel::Gpt4Turbo,
55 base_url: "https://api.openai.com/v1".to_string(),
56 client: reqwest::Client::new(),
57 default_options: LLMOptions::default(),
58 }
59 }
60
61 pub fn with_model(mut self, model: OpenAIModel) -> Self {
62 self.model = model;
63 self
64 }
65
66 pub fn gpt4(self) -> Self {
67 self.with_model(OpenAIModel::Gpt4)
68 }
69
70 pub fn gpt4_turbo(self) -> Self {
71 self.with_model(OpenAIModel::Gpt4Turbo)
72 }
73
74 pub fn gpt35_turbo(self) -> Self {
75 self.with_model(OpenAIModel::Gpt35Turbo)
76 }
77
78 pub fn o1_preview(self) -> Self {
79 self.with_model(OpenAIModel::O1Preview)
80 }
81
82 pub fn with_temperature(mut self, temperature: f32) -> Self {
83 self.default_options.temperature = Some(temperature);
84 self
85 }
86
87 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
88 self.default_options.max_tokens = Some(max_tokens);
89 self
90 }
91
92 pub fn with_top_p(mut self, top_p: f32) -> Self {
93 self.default_options.top_p = Some(top_p);
94 self
95 }
96
97 pub fn with_base_url(mut self, base_url: &str) -> Self {
98 self.base_url = base_url.to_string();
99 self
100 }
101
102 pub fn with_json_mode(mut self) -> Self {
103 self.default_options.response_format = Some(ResponseFormat::Json);
104 self
105 }
106
107 async fn chat_completion(
108 &self,
109 messages: &[ChatMessage],
110 options: &LLMOptions,
111 ) -> Result<String> {
112 let model_name: String = self.model.clone().into();
113
114 let mut request_body = json!({
115 "model": model_name,
116 "messages": messages.iter().map(|m| json!({
117 "role": m.role,
118 "content": m.content,
119 })).collect::<Vec<_>>(),
120 });
121
122 if let Some(temp) = options.temperature.or(self.default_options.temperature) {
123 request_body["temperature"] = json!(temp);
124 }
125 if let Some(max_tokens) = options.max_tokens.or(self.default_options.max_tokens) {
126 request_body["max_tokens"] = json!(max_tokens);
127 }
128 if let Some(top_p) = options.top_p.or(self.default_options.top_p) {
129 request_body["top_p"] = json!(top_p);
130 }
131 if let Some(freq_penalty) = options
132 .frequency_penalty
133 .or(self.default_options.frequency_penalty)
134 {
135 request_body["frequency_penalty"] = json!(freq_penalty);
136 }
137 if let Some(pres_penalty) = options
138 .presence_penalty
139 .or(self.default_options.presence_penalty)
140 {
141 request_body["presence_penalty"] = json!(pres_penalty);
142 }
143 if let Some(stop) = options
144 .stop_sequences
145 .as_ref()
146 .or(self.default_options.stop_sequences.as_ref())
147 {
148 request_body["stop"] = json!(stop);
149 }
150 if let Some(response_format) = &options.response_format {
151 match response_format {
152 ResponseFormat::Json => {
153 request_body["response_format"] = json!({ "type": "json_object" });
154 }
155 ResponseFormat::JsonSchema { schema } => {
156 request_body["response_format"] = json!({
157 "type": "json_schema",
158 "json_schema": schema
159 });
160 }
161 _ => {}
162 }
163 }
164
165 let response = self
166 .client
167 .post(format!("{}/chat/completions", self.base_url))
168 .header("Authorization", format!("Bearer {}", self.api_key))
169 .header("Content-Type", "application/json")
170 .json(&request_body)
171 .send()
172 .await
173 .map_err(|e| LangHubError::LLMError(format!("OpenAI request error: {}", e)))?;
174
175 if !response.status().is_success() {
176 let status = response.status();
177 let error_text = response.text().await.unwrap_or_default();
178 return Err(LangHubError::LLMError(format!(
179 "OpenAI API error ({}): {}",
180 status, error_text
181 )));
182 }
183
184 let json: serde_json::Value = response
185 .json()
186 .await
187 .map_err(|e| LangHubError::LLMError(format!("JSON parse error: {}", e)))?;
188
189 let text = json["choices"][0]["message"]["content"]
190 .as_str()
191 .ok_or_else(|| {
192 LangHubError::ParseError("Missing 'content' field in response".to_string())
193 })?
194 .to_string();
195
196 Ok(text)
197 }
198}
199
200impl LLM for OpenAI {
201 fn generate(
202 &self,
203 prompt: &str,
204 ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
205 let prompt = prompt.to_string();
206 let options = self.default_options.clone();
207 Box::pin(async move {
208 let messages = vec![ChatMessage::user(&prompt)];
209 let text = self.chat_completion(&messages, &options).await?;
210 Ok(LLMResult {
211 text,
212 metadata: None,
213 })
214 })
215 }
216
217 fn generate_with_options(
218 &self,
219 prompt: &str,
220 options: LLMOptions,
221 ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
222 let prompt = prompt.to_string();
223 Box::pin(async move {
224 let messages = vec![ChatMessage::user(&prompt)];
225 let text = self.chat_completion(&messages, &options).await?;
226 Ok(LLMResult {
227 text,
228 metadata: None,
229 })
230 })
231 }
232
233 fn chat(
234 &self,
235 messages: Vec<ChatMessage>,
236 ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
237 Box::pin(async move {
238 let text = self
239 .chat_completion(&messages, &LLMOptions::default())
240 .await?;
241 Ok(LLMResult {
242 text,
243 metadata: None,
244 })
245 })
246 }
247
248 fn get_model_name(&self) -> &str {
249 match self.model {
250 OpenAIModel::Gpt4 => "gpt-4",
251 OpenAIModel::Gpt4Turbo => "gpt-4-turbo-preview",
252 OpenAIModel::Gpt4Vision => "gpt-4-vision-preview",
253 OpenAIModel::Gpt432k => "gpt-4-32k",
254 OpenAIModel::Gpt35Turbo => "gpt-3.5-turbo",
255 OpenAIModel::Gpt35Turbo16k => "gpt-3.5-turbo-16k",
256 OpenAIModel::O1Preview => "o1-preview",
257 OpenAIModel::O1Mini => "o1-mini",
258 }
259 }
260
261 fn get_provider_name(&self) -> &str {
262 match self.model {
263 OpenAIModel::Gpt4 => "OpenAI-GPT4",
264 OpenAIModel::Gpt4Turbo => "OpenAI-GPT4-Turbo",
265 OpenAIModel::Gpt4Vision => "OpenAI-GPT4-Vision",
266 OpenAIModel::Gpt432k => "OpenAI-GPT4-32K",
267 OpenAIModel::Gpt35Turbo => "OpenAI-GPT3.5-Turbo",
268 OpenAIModel::Gpt35Turbo16k => "OpenAI-GPT3.5-Turbo-16K",
269 OpenAIModel::O1Preview => "OpenAI-O1-Preview",
270 OpenAIModel::O1Mini => "OpenAI-O1-Mini",
271 }
272 }
273
274 fn get_provider_enum(&self) -> ModelProvider {
275 ModelProvider::OpenAI
276 }
277
278 fn supports_function_calling(&self) -> bool {
279 true
280 }
281
282 fn supports_json_mode(&self) -> bool {
283 true
284 }
285
286 fn max_context_length(&self) -> Option<usize> {
287 match self.model {
288 OpenAIModel::Gpt4Turbo => Some(128000),
289 OpenAIModel::Gpt4 => Some(8192),
290 OpenAIModel::Gpt432k => Some(32768),
291 OpenAIModel::Gpt35Turbo16k => Some(16384),
292 OpenAIModel::O1Preview => Some(128000),
293 OpenAIModel::O1Mini => Some(128000),
294 _ => Some(4096),
295 }
296 }
297}