Skip to main content

langhub_core/llms/
openai.rs

1/// OpenAI (GPT-4, GPT-3.5, O1)
2use crate::types::*;
3
4use super::{ChatMessage, LLM, LLMOptions, LLMResult, ResponseFormat};
5use serde_json::json;
6use std::future::Future;
7use std::pin::Pin;
8
/// OpenAI chat models this client can target.
///
/// The enum is fieldless, so it also derives equality and hashing, letting
/// callers compare variants or use them as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OpenAIModel {
    /// GPT-4
    Gpt4,
    /// GPT-4 Turbo
    Gpt4Turbo,
    /// GPT-4 Vision
    Gpt4Vision,
    /// GPT-4 32K
    Gpt432k,
    /// GPT-3.5 Turbo
    Gpt35Turbo,
    /// GPT-3.5 Turbo 16K
    Gpt35Turbo16k,
    /// O1 Preview
    O1Preview,
    /// O1 Mini
    O1Mini,
}
20
21impl OpenAIModel {
22    fn as_str(&self) -> &'static str {
23        match self {
24            OpenAIModel::Gpt4 => "gpt-4",
25            OpenAIModel::Gpt4Turbo => "gpt-4-turbo-preview",
26            OpenAIModel::Gpt4Vision => "gpt-4-vision-preview",
27            OpenAIModel::Gpt432k => "gpt-4-32k",
28            OpenAIModel::Gpt35Turbo => "gpt-3.5-turbo",
29            OpenAIModel::Gpt35Turbo16k => "gpt-3.5-turbo-16k",
30            OpenAIModel::O1Preview => "o1-preview",
31            OpenAIModel::O1Mini => "o1-mini",
32        }
33    }
34}
35
36impl From<OpenAIModel> for String {
37    fn from(model: OpenAIModel) -> Self {
38        model.as_str().to_string()
39    }
40}
41
42pub struct OpenAI {
43    api_key: String,
44    model: OpenAIModel,
45    base_url: String,
46    client: reqwest::Client,
47    default_options: LLMOptions,
48}
49
50impl OpenAI {
51    pub fn new(api_key: String) -> Self {
52        Self {
53            api_key,
54            model: OpenAIModel::Gpt4Turbo,
55            base_url: "https://api.openai.com/v1".to_string(),
56            client: reqwest::Client::new(),
57            default_options: LLMOptions::default(),
58        }
59    }
60
61    pub fn with_model(mut self, model: OpenAIModel) -> Self {
62        self.model = model;
63        self
64    }
65
66    pub fn gpt4(self) -> Self {
67        self.with_model(OpenAIModel::Gpt4)
68    }
69
70    pub fn gpt4_turbo(self) -> Self {
71        self.with_model(OpenAIModel::Gpt4Turbo)
72    }
73
74    pub fn gpt35_turbo(self) -> Self {
75        self.with_model(OpenAIModel::Gpt35Turbo)
76    }
77
78    pub fn o1_preview(self) -> Self {
79        self.with_model(OpenAIModel::O1Preview)
80    }
81
82    pub fn with_temperature(mut self, temperature: f32) -> Self {
83        self.default_options.temperature = Some(temperature);
84        self
85    }
86
87    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
88        self.default_options.max_tokens = Some(max_tokens);
89        self
90    }
91
92    pub fn with_top_p(mut self, top_p: f32) -> Self {
93        self.default_options.top_p = Some(top_p);
94        self
95    }
96
97    pub fn with_base_url(mut self, base_url: &str) -> Self {
98        self.base_url = base_url.to_string();
99        self
100    }
101
102    pub fn with_json_mode(mut self) -> Self {
103        self.default_options.response_format = Some(ResponseFormat::Json);
104        self
105    }
106
107    async fn chat_completion(
108        &self,
109        messages: &[ChatMessage],
110        options: &LLMOptions,
111    ) -> Result<String> {
112        let model_name: String = self.model.clone().into();
113
114        let mut request_body = json!({
115            "model": model_name,
116            "messages": messages.iter().map(|m| json!({
117                "role": m.role,
118                "content": m.content,
119            })).collect::<Vec<_>>(),
120        });
121
122        if let Some(temp) = options.temperature.or(self.default_options.temperature) {
123            request_body["temperature"] = json!(temp);
124        }
125        if let Some(max_tokens) = options.max_tokens.or(self.default_options.max_tokens) {
126            request_body["max_tokens"] = json!(max_tokens);
127        }
128        if let Some(top_p) = options.top_p.or(self.default_options.top_p) {
129            request_body["top_p"] = json!(top_p);
130        }
131        if let Some(freq_penalty) = options
132            .frequency_penalty
133            .or(self.default_options.frequency_penalty)
134        {
135            request_body["frequency_penalty"] = json!(freq_penalty);
136        }
137        if let Some(pres_penalty) = options
138            .presence_penalty
139            .or(self.default_options.presence_penalty)
140        {
141            request_body["presence_penalty"] = json!(pres_penalty);
142        }
143        if let Some(stop) = options
144            .stop_sequences
145            .as_ref()
146            .or(self.default_options.stop_sequences.as_ref())
147        {
148            request_body["stop"] = json!(stop);
149        }
150        if let Some(response_format) = &options.response_format {
151            match response_format {
152                ResponseFormat::Json => {
153                    request_body["response_format"] = json!({ "type": "json_object" });
154                }
155                ResponseFormat::JsonSchema { schema } => {
156                    request_body["response_format"] = json!({
157                        "type": "json_schema",
158                        "json_schema": schema
159                    });
160                }
161                _ => {}
162            }
163        }
164
165        let response = self
166            .client
167            .post(format!("{}/chat/completions", self.base_url))
168            .header("Authorization", format!("Bearer {}", self.api_key))
169            .header("Content-Type", "application/json")
170            .json(&request_body)
171            .send()
172            .await
173            .map_err(|e| LangHubError::LLMError(format!("OpenAI request error: {}", e)))?;
174
175        if !response.status().is_success() {
176            let status = response.status();
177            let error_text = response.text().await.unwrap_or_default();
178            return Err(LangHubError::LLMError(format!(
179                "OpenAI API error ({}): {}",
180                status, error_text
181            )));
182        }
183
184        let json: serde_json::Value = response
185            .json()
186            .await
187            .map_err(|e| LangHubError::LLMError(format!("JSON parse error: {}", e)))?;
188
189        let text = json["choices"][0]["message"]["content"]
190            .as_str()
191            .ok_or_else(|| {
192                LangHubError::ParseError("Missing 'content' field in response".to_string())
193            })?
194            .to_string();
195
196        Ok(text)
197    }
198}
199
200impl LLM for OpenAI {
201    fn generate(
202        &self,
203        prompt: &str,
204    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
205        let prompt = prompt.to_string();
206        let options = self.default_options.clone();
207        Box::pin(async move {
208            let messages = vec![ChatMessage::user(&prompt)];
209            let text = self.chat_completion(&messages, &options).await?;
210            Ok(LLMResult {
211                text,
212                metadata: None,
213            })
214        })
215    }
216
217    fn generate_with_options(
218        &self,
219        prompt: &str,
220        options: LLMOptions,
221    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
222        let prompt = prompt.to_string();
223        Box::pin(async move {
224            let messages = vec![ChatMessage::user(&prompt)];
225            let text = self.chat_completion(&messages, &options).await?;
226            Ok(LLMResult {
227                text,
228                metadata: None,
229            })
230        })
231    }
232
233    fn chat(
234        &self,
235        messages: Vec<ChatMessage>,
236    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
237        Box::pin(async move {
238            let text = self
239                .chat_completion(&messages, &LLMOptions::default())
240                .await?;
241            Ok(LLMResult {
242                text,
243                metadata: None,
244            })
245        })
246    }
247
248    fn get_model_name(&self) -> &str {
249        match self.model {
250            OpenAIModel::Gpt4 => "gpt-4",
251            OpenAIModel::Gpt4Turbo => "gpt-4-turbo-preview",
252            OpenAIModel::Gpt4Vision => "gpt-4-vision-preview",
253            OpenAIModel::Gpt432k => "gpt-4-32k",
254            OpenAIModel::Gpt35Turbo => "gpt-3.5-turbo",
255            OpenAIModel::Gpt35Turbo16k => "gpt-3.5-turbo-16k",
256            OpenAIModel::O1Preview => "o1-preview",
257            OpenAIModel::O1Mini => "o1-mini",
258        }
259    }
260
261    fn get_provider_name(&self) -> &str {
262        match self.model {
263            OpenAIModel::Gpt4 => "OpenAI-GPT4",
264            OpenAIModel::Gpt4Turbo => "OpenAI-GPT4-Turbo",
265            OpenAIModel::Gpt4Vision => "OpenAI-GPT4-Vision",
266            OpenAIModel::Gpt432k => "OpenAI-GPT4-32K",
267            OpenAIModel::Gpt35Turbo => "OpenAI-GPT3.5-Turbo",
268            OpenAIModel::Gpt35Turbo16k => "OpenAI-GPT3.5-Turbo-16K",
269            OpenAIModel::O1Preview => "OpenAI-O1-Preview",
270            OpenAIModel::O1Mini => "OpenAI-O1-Mini",
271        }
272    }
273
274    fn get_provider_enum(&self) -> ModelProvider {
275        ModelProvider::OpenAI
276    }
277
278    fn supports_function_calling(&self) -> bool {
279        true
280    }
281
282    fn supports_json_mode(&self) -> bool {
283        true
284    }
285
286    fn max_context_length(&self) -> Option<usize> {
287        match self.model {
288            OpenAIModel::Gpt4Turbo => Some(128000),
289            OpenAIModel::Gpt4 => Some(8192),
290            OpenAIModel::Gpt432k => Some(32768),
291            OpenAIModel::Gpt35Turbo16k => Some(16384),
292            OpenAIModel::O1Preview => Some(128000),
293            OpenAIModel::O1Mini => Some(128000),
294            _ => Some(4096),
295        }
296    }
297}