1use super::chat::{ChatCompletion, ChatModel};
3use super::message::{ChatMessage, ChatMessageContent, TokenUsage};
4use anyhow::Error;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use serde_json;
8use std::collections::HashMap;
9use log::info;
10#[derive(Serialize, Deserialize, Clone)]
11struct OpenAIMessage {
12 role: String,
13 content: String,
14 name: Option<String>,
15 #[serde(skip_serializing_if = "Option::is_none")]
16 tool_call_id: Option<String>,
17}
18
/// Optional per-category breakdown of input (prompt) tokens, reported by
/// the API under `input_tokens_details`.
/// NOTE(review): none of these fields are read anywhere in this file;
/// they are deserialized and kept for forward compatibility — confirm
/// the field names against the provider's current usage schema.
#[derive(Deserialize, Default)]
struct InputTokenDetails {
    // Tokens attributable to audio input, when supported by the model.
    audio_tokens: Option<usize>,
    // Tokens served from the prompt cache — presumably the API's cached
    // count; verify the exact wire name.
    cache_read: Option<usize>,
    reasoning_tokens: Option<usize>,
}
27
/// Optional per-category breakdown of output (completion) tokens, reported
/// by the API under `output_tokens_details`.
/// NOTE(review): like `InputTokenDetails`, these fields are not consumed
/// anywhere in this file yet.
#[derive(Deserialize, Default)]
struct OutputTokenDetails {
    // Tokens written into the prompt cache — presumably; verify wire name.
    cache_write: Option<usize>,
    reasoning_tokens: Option<usize>,
}
34
/// Token accounting in the Chat Completions format
/// (`prompt_tokens` / `completion_tokens` / `total_tokens`).
#[derive(Deserialize, Default)]
struct OpenAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    // Optional detailed breakdowns; currently deserialized but not
    // consumed by `_create_usage_metadata` or `invoke`.
    input_tokens_details: Option<InputTokenDetails>,
    output_tokens_details: Option<OutputTokenDetails>,
}
45
/// Token accounting in the Responses API format, where the top-level
/// counters are named `input_tokens` / `output_tokens` and may be absent.
/// Mapped onto `TokenUsage` by `_create_usage_metadata_responses`, which
/// treats missing counters as zero.
#[derive(Deserialize, Default)]
struct OpenAIResponsesUsage {
    input_tokens: Option<usize>,
    output_tokens: Option<usize>,
    total_tokens: Option<usize>,
    // Optional detailed breakdowns; not currently consumed.
    input_tokens_details: Option<InputTokenDetails>,
    output_tokens_details: Option<OutputTokenDetails>,
}
56
57#[derive(Deserialize)]
59struct OpenAIResponse {
60 id: Option<String>,
61 object: Option<String>,
62 created: Option<u64>,
63 model: Option<String>,
64 choices: Vec<OpenAIChoice>, usage: Option<OpenAIUsage>,
66 output: Option<Vec<OpenAIChoice>>,
68 }
70
/// A single completion choice within the response.
/// NOTE(review): only `message` is read by `invoke`; `index` and
/// `finish_reason` are deserialized but currently unused.
#[derive(Deserialize)]
struct OpenAIChoice {
    index: u32,
    message: OpenAIMessage,
    finish_reason: String,
}
77
/// Which OpenAI HTTP API flavor this client should target.
/// NOTE(review): `invoke` currently always posts to `/chat/completions`;
/// the configured variant (see `with_api_type`) is stored but never
/// consulted — confirm whether endpoint routing was intended.
#[derive(Debug, Clone, Copy)]
enum OpenAIApiType {
    ChatCompletions,
    Responses,
}
84
/// Client for an OpenAI-compatible chat endpoint.
/// Configure via the `with_*` builder methods, then call
/// `ChatModel::invoke` to send a conversation.
#[derive(Clone)]
pub struct OpenAIChatModel {
    // Shared reqwest client (cheap to clone; connection pool is reused).
    client: Client,
    // Sent as a `Bearer` token in the Authorization header.
    api_key: String,
    // Endpoint prefix; "/chat/completions" is appended at request time.
    base_url: String,
    // Model identifier; serialized as "" when unset (server will reject).
    model_name: Option<String>,
    // Defaults to Some(0.7) in `new`.
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    // Stored but not yet used for routing — see `OpenAIApiType`.
    api_type: OpenAIApiType,
    // Extra HTTP headers merged into every request.
    additional_headers: HashMap<String, String>,
    // Extra top-level JSON fields merged into every request body; these
    // can overwrite the defaults (model, temperature, max_tokens).
    additional_params: HashMap<String, serde_json::Value>,
}
98
99impl OpenAIChatModel {
100 pub fn new(api_key: String, base_url: Option<String>) -> Self {
102 Self {
103 client: Client::new(),
104 api_key,
105 base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
106 model_name: None,
107 temperature: Some(0.7),
108 max_tokens: None,
109 api_type: OpenAIApiType::ChatCompletions,
110 additional_headers: HashMap::new(),
111 additional_params: HashMap::new(),
112 }
113 }
114
115 pub fn model_name(&self) -> Option<&String> {
117 self.model_name.as_ref()
118 }
119
120 pub fn base_url(&self) -> &String {
122 &self.base_url
123 }
124
125 pub fn temperature(&self) -> Option<f32> {
127 self.temperature
128 }
129
130 pub fn max_tokens(&self) -> Option<u32> {
132 self.max_tokens
133 }
134
135 pub fn with_model(mut self, model_name: String) -> Self {
137 self.model_name = Some(model_name);
138 self
139 }
140
141 pub fn with_temperature(mut self, temperature: f32) -> Self {
143 self.temperature = Some(temperature);
144 self
145 }
146
147 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
149 self.max_tokens = Some(max_tokens);
150 self
151 }
152
153 pub fn with_api_type(mut self, api_type: OpenAIApiType) -> Self {
155 self.api_type = api_type;
156 self
157 }
158
159 pub fn with_additional_header(mut self, key: String, value: String) -> Self {
161 self.additional_headers.insert(key, value);
162 self
163 }
164
165 pub fn with_additional_param(mut self, key: String, value: serde_json::Value) -> Self {
167 self.additional_params.insert(key, value);
168 self
169 }
170
171 fn _get_request_payload(&self, messages: &[OpenAIMessage]) -> Result<serde_json::Value, Error> {
173 Ok(serde_json::json!({"messages": messages}))
174 }
175
176 fn _convert_message_to_dict(&self, message: &OpenAIMessage) -> Result<serde_json::Value, Error> {
178 Ok(serde_json::to_value(message)?)
179 }
180
181 fn _construct_responses_api_payload(&self, messages: &[OpenAIMessage]) -> Result<serde_json::Value, Error> {
183 Ok(serde_json::json!({"messages": messages}))
184 }
185
186 fn _create_usage_metadata(&self, usage: &OpenAIUsage) -> TokenUsage {
188 TokenUsage {
189 prompt_tokens: usage.prompt_tokens,
190 completion_tokens: usage.completion_tokens,
191 total_tokens: usage.total_tokens,
192 }
193 }
194
195 fn _create_usage_metadata_responses(&self, usage: &OpenAIResponsesUsage) -> TokenUsage {
197 TokenUsage {
198 prompt_tokens: usage.input_tokens.unwrap_or(0),
199 completion_tokens: usage.output_tokens.unwrap_or(0),
200 total_tokens: usage.total_tokens.unwrap_or(0),
201 }
202 }
203
204 fn _convert_dict_to_message(&self, message_dict: serde_json::Value) -> Result<ChatMessage, Error> {
206 let role = message_dict.get("role").and_then(|v| v.as_str()).unwrap_or("assistant");
208 let content = message_dict.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string();
209
210 let chat_content = ChatMessageContent {
211 content,
212 name: None,
213 additional_kwargs: HashMap::new(),
214 };
215
216 match role {
217 "system" => Ok(ChatMessage::System(chat_content)),
218 "user" => Ok(ChatMessage::Human(chat_content)),
219 "assistant" => Ok(ChatMessage::AIMessage(chat_content)),
220 "tool" => Ok(ChatMessage::ToolMessage(chat_content)),
221 _ => Ok(ChatMessage::AIMessage(chat_content)),
222 }
223 }
224}
225
226impl ChatModel for OpenAIChatModel {
227 fn model_name(&self) -> Option<&str> {
228 self.model_name.as_deref()
229 }
230
231 fn base_url(&self) -> String {
232 self.base_url.to_string()
233 }
234
235 fn invoke(&self, messages: Vec<ChatMessage>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ChatCompletion, Error>> + Send + '_>> {
236 let messages = messages;
237 let client = self.client.clone();
238 let api_key = self.api_key.clone();
239 let base_url = self.base_url.clone();
240 let model_name = self.model_name.clone();
241 let temperature = self.temperature;
242 let max_tokens = self.max_tokens;
243 let additional_headers = self.additional_headers.clone();
244 let additional_params = self.additional_params.clone();
245
246 Box::pin(async move {
247 let openai_messages: Vec<OpenAIMessage> = messages
249 .into_iter()
250 .map(|msg| match msg {
251 ChatMessage::System(content) => OpenAIMessage {
252 role: "system".to_string(),
253 content: content.content,
254 name: content.name,
255 tool_call_id: None,
256 },
257 ChatMessage::Human(content) => OpenAIMessage {
258 role: "user".to_string(),
259 content: content.content,
260 name: content.name,
261 tool_call_id: None,
262 },
263 ChatMessage::AIMessage(content) => OpenAIMessage {
264 role: "assistant".to_string(),
265 content: content.content,
266 name: content.name,
267 tool_call_id: None,
268 },
269 ChatMessage::ToolMessage(content) => {
270 info!("Converting tool message: role=tool, content={}", content.content);
271 let tool_call_id = content.additional_kwargs.get("tool_call_id")
273 .and_then(|v| v.as_str())
274 .unwrap_or("default_tool_call_id").to_string();
275 OpenAIMessage {
276 role: "tool".to_string(),
277 content: content.content,
278 name: content.name,
279 tool_call_id: Some(tool_call_id),
280 }
281 },
282 })
283 .collect();
284
285 let mut request_body = serde_json::json!({
287 "messages": openai_messages,
288 "model": model_name.clone().unwrap_or("".to_string()),
289 });
290
291 if let Some(temp) = temperature {
293 request_body["temperature"] = serde_json::json!(temp);
294 }
295 if let Some(max) = max_tokens {
296 request_body["max_tokens"] = serde_json::json!(max);
297 }
298
299 for (key, value) in additional_params {
301 request_body[key] = value;
302 }
303
304 let api_url = format!("{}/chat/completions", base_url);
306
307 let mut request = client.post(&api_url)
309 .header("Authorization", format!("Bearer {}", api_key))
310 .header("Content-Type", "application/json");
311
312 for (key, value) in additional_headers {
314 request = request.header(key, value);
315 }
316
317 let response = request.json(&request_body).send().await?;
319
320 let status = response.status();
322 if !status.is_success() {
323 let error_text = response.text().await?;
324 return Err(Error::msg(format!("API request failed: {} - {}", status, error_text)));
325 }
326
327 let response: OpenAIResponse = response.json().await?;
329
330 let chat_message = match response.choices.first() {
332 Some(choice) => {
333 let message = &choice.message;
334 match message.role.as_str() {
335 "assistant" => ChatMessage::AIMessage(ChatMessageContent {
336 content: message.content.clone(),
337 name: message.name.clone(),
338 additional_kwargs: HashMap::new(),
339 }),
340 _ => {
341 return Err(Error::msg(format!("Unexpected message role: {}", message.role)));
342 }
343 }
344 },
345 None => {
346 match &response.output {
348 Some(outputs) => {
349 match outputs.first() {
350 Some(choice) => {
351 let message = &choice.message;
352 ChatMessage::AIMessage(ChatMessageContent {
353 content: message.content.clone(),
354 name: message.name.clone(),
355 additional_kwargs: HashMap::new(),
356 })
357 },
358 None => return Err(Error::msg("No output returned from API")),
359 }
360 },
361 None => return Err(Error::msg("No choices or output returned from API")),
362 }
363 },
364 };
365
366 let usage = match &response.usage {
368 Some(openai_usage) => {
369 Some(TokenUsage {
370 prompt_tokens: openai_usage.prompt_tokens,
371 completion_tokens: openai_usage.completion_tokens,
372 total_tokens: openai_usage.total_tokens,
373 })
374 },
375 None => None,
376 };
377
378 let model_name_str = response.model.as_deref().unwrap_or("unknown");
379 Ok(ChatCompletion {
380 message: chat_message,
381 usage,
382 model_name: model_name_str.to_string(),
383 })
384 })
385 }
386}