// rust_agent/models/openai.rs
// OpenAI model implementation - based on LangChain design
use std::collections::HashMap;

use anyhow::Error;
use log::info;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json;

use super::chat::{ChatCompletion, ChatModel};
use super::message::{ChatMessage, ChatMessageContent, TokenUsage};
10#[derive(Serialize, Deserialize, Clone)]
11struct OpenAIMessage {
12    role: String,
13    content: String,
14    name: Option<String>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    tool_call_id: Option<String>,
17}
18
19// Token usage details structure - referencing LangChain's InputTokenDetails and OutputTokenDetails
20#[derive(Deserialize, Default)]
21struct InputTokenDetails {
22    audio_tokens: Option<usize>,
23    cache_read: Option<usize>,
24    reasoning_tokens: Option<usize>,
25    // Other possible fields
26}
27
28#[derive(Deserialize, Default)]
29struct OutputTokenDetails {
30    cache_write: Option<usize>,
31    reasoning_tokens: Option<usize>,
32    // Other possible fields
33}
34
35// OpenAI traditional API usage statistics
36#[derive(Deserialize, Default)]
37struct OpenAIUsage {
38    prompt_tokens: usize,
39    completion_tokens: usize,
40    total_tokens: usize,
41    // Extended fields, supporting more details
42    input_tokens_details: Option<InputTokenDetails>,
43    output_tokens_details: Option<OutputTokenDetails>,
44}
45
46// Responses API usage statistics format
47#[derive(Deserialize, Default)]
48struct OpenAIResponsesUsage {
49    input_tokens: Option<usize>,
50    output_tokens: Option<usize>,
51    total_tokens: Option<usize>,
52    // Fields specific to Responses API
53    input_tokens_details: Option<InputTokenDetails>,
54    output_tokens_details: Option<OutputTokenDetails>,
55}
56
57// Generic API response structure - compatible with OpenAI and other providers
58#[derive(Deserialize)]
59struct OpenAIResponse {
60    id: Option<String>,
61    object: Option<String>,
62    created: Option<u64>,
63    model: Option<String>,
64    choices: Vec<OpenAIChoice>, // This field is usually required
65    usage: Option<OpenAIUsage>,
66    // Fields compatible with Responses API
67    output: Option<Vec<OpenAIChoice>>,
68    // Other possible response fields
69}
70
71#[derive(Deserialize)]
72struct OpenAIChoice {
73    index: u32,
74    message: OpenAIMessage,
75    finish_reason: String,
76}
77
78// API type enumeration - supporting traditional Chat Completions API and new Responses API
79#[derive(Debug, Clone, Copy)]
80enum OpenAIApiType {
81    ChatCompletions,
82    Responses,
83}
84
85// OpenAI model implementation - supporting multiple API formats
86#[derive(Clone)]
87pub struct OpenAIChatModel {
88    client: Client,
89    api_key: String,
90    base_url: String,
91    model_name: Option<String>,
92    temperature: Option<f32>,
93    max_tokens: Option<u32>,
94    api_type: OpenAIApiType,
95    additional_headers: HashMap<String, String>,
96    additional_params: HashMap<String, serde_json::Value>,
97}
98
99impl OpenAIChatModel {
100    /// Create a new OpenAI chat model instance
101    pub fn new(api_key: String, base_url: Option<String>) -> Self {
102        Self {
103            client: Client::new(),
104            api_key,
105            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
106            model_name: None,
107            temperature: Some(0.7),
108            max_tokens: None,
109            api_type: OpenAIApiType::ChatCompletions,
110            additional_headers: HashMap::new(),
111            additional_params: HashMap::new(),
112        }
113    }
114
115    /// Get model name
116    pub fn model_name(&self) -> Option<&String> {
117        self.model_name.as_ref()
118    }
119
120    /// Get base URL
121    pub fn base_url(&self) -> &String {
122        &self.base_url
123    }
124
125    /// Get temperature parameter
126    pub fn temperature(&self) -> Option<f32> {
127        self.temperature
128    }
129
130    /// Get maximum number of tokens
131    pub fn max_tokens(&self) -> Option<u32> {
132        self.max_tokens
133    }
134
135    /// Set model name
136    pub fn with_model(mut self, model_name: String) -> Self {
137        self.model_name = Some(model_name);
138        self
139    }
140
141    /// Set temperature parameter
142    pub fn with_temperature(mut self, temperature: f32) -> Self {
143        self.temperature = Some(temperature);
144        self
145    }
146
147    /// Set maximum number of tokens
148    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
149        self.max_tokens = Some(max_tokens);
150        self
151    }
152
153    /// Set API type (Chat Completions or Responses)
154    pub fn with_api_type(mut self, api_type: OpenAIApiType) -> Self {
155        self.api_type = api_type;
156        self
157    }
158
159    /// Add additional request headers
160    pub fn with_additional_header(mut self, key: String, value: String) -> Self {
161        self.additional_headers.insert(key, value);
162        self
163    }
164
165    /// Add additional request parameters
166    pub fn with_additional_param(mut self, key: String, value: serde_json::Value) -> Self {
167        self.additional_params.insert(key, value);
168        self
169    }
170
171    /// Build request payload - referencing LangChain's _get_request_payload method
172    fn _get_request_payload(&self, messages: &[OpenAIMessage]) -> Result<serde_json::Value, Error> {
173        Ok(serde_json::json!({"messages": messages}))
174    }
175
176    /// Convert message to dictionary format - referencing LangChain's _convert_message_to_dict
177    fn _convert_message_to_dict(&self, message: &OpenAIMessage) -> Result<serde_json::Value, Error> {
178        Ok(serde_json::to_value(message)?)  
179    }
180
181    /// Build Responses API payload - referencing LangChain's _construct_responses_api_payload
182    fn _construct_responses_api_payload(&self, messages: &[OpenAIMessage]) -> Result<serde_json::Value, Error> {
183        Ok(serde_json::json!({"messages": messages}))
184    }
185
186    /// Create usage metadata - referencing LangChain's _create_usage_metadata
187    fn _create_usage_metadata(&self, usage: &OpenAIUsage) -> TokenUsage {
188        TokenUsage {
189            prompt_tokens: usage.prompt_tokens,
190            completion_tokens: usage.completion_tokens,
191            total_tokens: usage.total_tokens,
192        }
193    }
194
195    /// Create usage metadata for Responses API - referencing LangChain's _create_usage_metadata_responses
196    fn _create_usage_metadata_responses(&self, usage: &OpenAIResponsesUsage) -> TokenUsage {
197        TokenUsage {
198            prompt_tokens: usage.input_tokens.unwrap_or(0),
199            completion_tokens: usage.output_tokens.unwrap_or(0),
200            total_tokens: usage.total_tokens.unwrap_or(0),
201        }
202    }
203
204    /// Convert dictionary to message - referencing LangChain's _convert_dict_to_message
205    fn _convert_dict_to_message(&self, message_dict: serde_json::Value) -> Result<ChatMessage, Error> {
206        // Simple implementation: try to extract role and content from JSON
207        let role = message_dict.get("role").and_then(|v| v.as_str()).unwrap_or("assistant");
208        let content = message_dict.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string();
209        
210        let chat_content = ChatMessageContent {
211            content,
212            name: None,
213            additional_kwargs: HashMap::new(),
214        };
215        
216        match role {
217            "system" => Ok(ChatMessage::System(chat_content)),
218            "user" => Ok(ChatMessage::Human(chat_content)),
219            "assistant" => Ok(ChatMessage::AIMessage(chat_content)),
220            "tool" => Ok(ChatMessage::ToolMessage(chat_content)),
221            _ => Ok(ChatMessage::AIMessage(chat_content)),
222        }
223    }
224}
225
226impl ChatModel for OpenAIChatModel {
227    fn model_name(&self) -> Option<&str> {
228        self.model_name.as_deref()
229    }
230
231    fn base_url(&self) -> String {
232        self.base_url.to_string()
233    }
234
235    fn invoke(&self, messages: Vec<ChatMessage>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ChatCompletion, Error>> + Send + '_>> {
236        let messages = messages;
237        let client = self.client.clone();
238        let api_key = self.api_key.clone();
239        let base_url = self.base_url.clone();
240        let model_name = self.model_name.clone();
241        let temperature = self.temperature;
242        let max_tokens = self.max_tokens;
243        let additional_headers = self.additional_headers.clone();
244        let additional_params = self.additional_params.clone();
245
246        Box::pin(async move {
247            // Convert message format
248            let openai_messages: Vec<OpenAIMessage> = messages
249                .into_iter()
250                .map(|msg| match msg {
251                    ChatMessage::System(content) => OpenAIMessage {
252                        role: "system".to_string(),
253                        content: content.content,
254                        name: content.name,
255                        tool_call_id: None,
256                    },
257                    ChatMessage::Human(content) => OpenAIMessage {
258                        role: "user".to_string(),
259                        content: content.content,
260                        name: content.name,
261                        tool_call_id: None,
262                    },
263                    ChatMessage::AIMessage(content) => OpenAIMessage {
264                        role: "assistant".to_string(),
265                        content: content.content,
266                        name: content.name,
267                        tool_call_id: None,
268                    },
269                    ChatMessage::ToolMessage(content) => {
270                        info!("Converting tool message: role=tool, content={}", content.content);
271                        // Add tool_call_id for tool messages
272                        let tool_call_id = content.additional_kwargs.get("tool_call_id")
273                            .and_then(|v| v.as_str())
274                            .unwrap_or("default_tool_call_id").to_string();
275                        OpenAIMessage {
276                            role: "tool".to_string(),
277                            content: content.content,
278                            name: content.name,
279                            tool_call_id: Some(tool_call_id),
280                        }
281                    },
282                })
283                .collect();
284
285            // Build request body
286            let mut request_body = serde_json::json!({
287                "messages": openai_messages,
288                "model": model_name.clone().unwrap_or("".to_string()),
289            });
290
291            // Add optional parameters
292            if let Some(temp) = temperature {
293                request_body["temperature"] = serde_json::json!(temp);
294            }
295            if let Some(max) = max_tokens {
296                request_body["max_tokens"] = serde_json::json!(max);
297            }
298            
299            // Add additional parameters
300            for (key, value) in additional_params {
301                request_body[key] = value;
302            }
303
304            // Build complete API path, concatenating base_url with specific endpoint
305            let api_url = format!("{}/chat/completions", base_url);
306            
307            // Build request
308            let mut request = client.post(&api_url)
309                .header("Authorization", format!("Bearer {}", api_key))
310                .header("Content-Type", "application/json");
311
312            // Add additional request headers
313            for (key, value) in additional_headers {
314                request = request.header(key, value);
315            }
316            
317            // Send request
318            let response = request.json(&request_body).send().await?;
319            
320            // Check response status
321            let status = response.status();
322            if !status.is_success() {
323                let error_text = response.text().await?;
324                return Err(Error::msg(format!("API request failed: {} - {}", status, error_text)));
325            }
326
327            // Parse response
328            let response: OpenAIResponse = response.json().await?;
329
330            // Handle response
331            let chat_message = match response.choices.first() {
332                Some(choice) => {
333                    let message = &choice.message;
334                    match message.role.as_str() {
335                        "assistant" => ChatMessage::AIMessage(ChatMessageContent {
336                            content: message.content.clone(),
337                            name: message.name.clone(),
338                            additional_kwargs: HashMap::new(),
339                        }),
340                        _ => {
341                            return Err(Error::msg(format!("Unexpected message role: {}", message.role)));
342                        }
343                    }
344                },
345                None => {
346                    // Try to use output field (Responses API)
347                    match &response.output {
348                        Some(outputs) => {
349                            match outputs.first() {
350                                Some(choice) => {
351                                    let message = &choice.message;
352                                    ChatMessage::AIMessage(ChatMessageContent {
353                                        content: message.content.clone(),
354                                        name: message.name.clone(),
355                                        additional_kwargs: HashMap::new(),
356                                    })
357                                },
358                                None => return Err(Error::msg("No output returned from API")),
359                            }
360                        },
361                        None => return Err(Error::msg("No choices or output returned from API")),
362                    }
363                },
364            };
365
366            // Convert usage statistics
367            let usage = match &response.usage {
368                Some(openai_usage) => {
369                    Some(TokenUsage {
370                        prompt_tokens: openai_usage.prompt_tokens,
371                        completion_tokens: openai_usage.completion_tokens,
372                        total_tokens: openai_usage.total_tokens,
373                    })
374                },
375                None => None,
376            };
377
378            let model_name_str = response.model.as_deref().unwrap_or("unknown");
379            Ok(ChatCompletion {
380                message: chat_message,
381                usage,
382                model_name: model_name_str.to_string(),
383            })
384        })
385    }
386}