1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value as JsonValue};
6
7use crate::errors::{GentError, GentResult};
8use crate::runtime::llm::{LLMClient, LLMResponse, Message, Role, ToolCall, ToolDefinition};
9
/// Client for the OpenAI Chat Completions endpoint (`{base_url}/v1/chat/completions`).
///
/// Construct with [`OpenAIClient::new`] and customise with the `with_*` builder methods.
/// Deliberately not `Debug`: it holds the API key.
pub struct OpenAIClient {
    // Secret sent as the bearer token on every request.
    api_key: String,
    // Model used when `chat` is called without an explicit model override.
    model: String,
    // Scheme + host (and optional prefix) that `/v1/chat/completions` is appended to.
    base_url: String,
    // Shared HTTP client, reused across requests.
    client: reqwest::Client,
}
17
18impl OpenAIClient {
19 pub fn new(api_key: String) -> Self {
20 Self {
21 api_key,
22 model: "gpt-4o-mini".to_string(),
23 base_url: "https://api.openai.com".to_string(),
24 client: reqwest::Client::new(),
25 }
26 }
27
28 pub fn with_model(mut self, model: &str) -> Self {
29 self.model = model.to_string();
30 self
31 }
32
33 pub fn with_base_url(mut self, url: &str) -> Self {
34 self.base_url = url.to_string();
35 self
36 }
37
38 pub fn model(&self) -> &str {
39 &self.model
40 }
41
42 fn to_openai_messages(&self, messages: &[Message]) -> Vec<OpenAIMessage> {
43 messages.iter().map(|m| self.to_openai_message(m)).collect()
44 }
45
46 fn to_openai_message(&self, message: &Message) -> OpenAIMessage {
47 OpenAIMessage {
48 role: match message.role {
49 Role::System => "system".to_string(),
50 Role::User => "user".to_string(),
51 Role::Assistant => "assistant".to_string(),
52 Role::Tool => "tool".to_string(),
53 },
54 content: if message.content.is_empty() {
55 None
56 } else {
57 Some(message.content.clone())
58 },
59 tool_call_id: message.tool_call_id.clone(),
60 tool_calls: message.tool_calls.as_ref().map(|tcs| {
61 tcs.iter()
62 .map(|tc| OpenAIToolCall {
63 id: tc.id.clone(),
64 r#type: "function".to_string(),
65 function: OpenAIFunction {
66 name: tc.name.clone(),
67 arguments: tc.arguments.to_string(),
68 },
69 })
70 .collect()
71 }),
72 }
73 }
74
75 fn to_openai_tools(&self, tools: &[ToolDefinition]) -> Vec<OpenAITool> {
76 tools
77 .iter()
78 .map(|t| OpenAITool {
79 r#type: "function".to_string(),
80 function: OpenAIFunctionDef {
81 name: t.name.clone(),
82 description: t.description.clone(),
83 parameters: t.parameters.clone(),
84 },
85 })
86 .collect()
87 }
88}
89
90#[async_trait]
91impl LLMClient for OpenAIClient {
92 async fn chat(
93 &self,
94 messages: Vec<Message>,
95 tools: Vec<ToolDefinition>,
96 model: Option<&str>,
97 json_mode: bool,
98 ) -> GentResult<LLMResponse> {
99 let url = format!("{}/v1/chat/completions", self.base_url);
100
101 let model_to_use = model.unwrap_or(&self.model);
103
104 let mut body = json!({
105 "model": model_to_use,
106 "messages": self.to_openai_messages(&messages),
107 });
108
109 if !tools.is_empty() {
110 body["tools"] = json!(self.to_openai_tools(&tools));
111 }
112
113 if json_mode {
114 body["response_format"] = json!({"type": "json_object"});
115 }
116
117 let response = self
118 .client
119 .post(&url)
120 .header("Authorization", format!("Bearer {}", self.api_key))
121 .header("Content-Type", "application/json")
122 .json(&body)
123 .send()
124 .await
125 .map_err(|e| GentError::ApiError {
126 message: format!("Request failed: {}", e),
127 })?;
128
129 if !response.status().is_success() {
130 let status = response.status();
131 let text = response.text().await.unwrap_or_default();
132 return Err(GentError::ApiError {
133 message: format!("API error ({}): {}", status, text),
134 });
135 }
136
137 let api_response: OpenAIResponse =
138 response.json().await.map_err(|e| GentError::ApiError {
139 message: format!("Failed to parse response: {}", e),
140 })?;
141
142 let choice =
143 api_response
144 .choices
145 .into_iter()
146 .next()
147 .ok_or_else(|| GentError::ApiError {
148 message: "No choices in response".to_string(),
149 })?;
150
151 let tool_calls = choice
152 .message
153 .tool_calls
154 .unwrap_or_default()
155 .into_iter()
156 .map(|tc| ToolCall {
157 id: tc.id,
158 name: tc.function.name,
159 arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(JsonValue::Null),
160 })
161 .collect();
162
163 Ok(LLMResponse {
164 content: choice.message.content,
165 tool_calls,
166 })
167 }
168}
169
/// Wire form of one entry in the request's `messages` array.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    // One of "system", "user", "assistant", "tool".
    role: String,
    // Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Set on tool-result messages to link them to the originating call; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    // Tool calls previously made by the assistant; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
}
181
/// A tool call, used both when echoing assistant turns in requests and when
/// parsing them out of responses.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIToolCall {
    id: String,
    // Serialized as the JSON key `type`; this client always sends "function".
    r#type: String,
    function: OpenAIFunction,
}
188
/// The function part of a tool call.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunction {
    name: String,
    // JSON-encoded argument object carried as a string, not as nested JSON.
    arguments: String,
}
194
/// One entry of the request's `tools` array.
#[derive(Debug, Serialize)]
struct OpenAITool {
    // Serialized as the JSON key `type`; this client always sends "function".
    r#type: String,
    function: OpenAIFunctionDef,
}
200
/// Declaration of a callable function advertised to the model.
#[derive(Debug, Serialize)]
struct OpenAIFunctionDef {
    name: String,
    description: String,
    // Copied verbatim from `ToolDefinition::parameters`; presumably a JSON Schema
    // object describing the arguments — TODO confirm against ToolDefinition.
    parameters: JsonValue,
}
207
/// Subset of the Chat Completions response body that this client reads.
/// Any other top-level fields in the response are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    // `chat` only uses the first element.
    choices: Vec<OpenAIChoice>,
}
212
/// One completion alternative; only its message is read.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIResponseMessage,
}
217
/// The assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct OpenAIResponseMessage {
    // `None` when the field is null or missing (e.g. presumably on tool-call-only replies).
    content: Option<String>,
    // `None` when the model made no tool calls.
    tool_calls: Option<Vec<OpenAIToolCall>>,
}