// gent/runtime/providers/openai.rs

1//! OpenAI API client
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value as JsonValue};
6
7use crate::errors::{GentError, GentResult};
8use crate::runtime::llm::{LLMClient, LLMResponse, Message, Role, ToolCall, ToolDefinition};
9
/// OpenAI API client
///
/// Holds the credentials and request defaults used to talk to the
/// chat-completions endpoint, plus a reusable `reqwest::Client`.
pub struct OpenAIClient {
    // Bearer token sent in the `Authorization` header of every request.
    api_key: String,
    // Default model name; individual calls may override it.
    model: String,
    // API root (no trailing slash) — joined with `/v1/chat/completions`.
    base_url: String,
    // Shared HTTP client reused across requests.
    client: reqwest::Client,
}
17
18impl OpenAIClient {
19    pub fn new(api_key: String) -> Self {
20        Self {
21            api_key,
22            model: "gpt-4o-mini".to_string(),
23            base_url: "https://api.openai.com".to_string(),
24            client: reqwest::Client::new(),
25        }
26    }
27
28    pub fn with_model(mut self, model: &str) -> Self {
29        self.model = model.to_string();
30        self
31    }
32
33    pub fn with_base_url(mut self, url: &str) -> Self {
34        self.base_url = url.to_string();
35        self
36    }
37
38    pub fn model(&self) -> &str {
39        &self.model
40    }
41
42    fn to_openai_messages(&self, messages: &[Message]) -> Vec<OpenAIMessage> {
43        messages.iter().map(|m| self.to_openai_message(m)).collect()
44    }
45
46    fn to_openai_message(&self, message: &Message) -> OpenAIMessage {
47        OpenAIMessage {
48            role: match message.role {
49                Role::System => "system".to_string(),
50                Role::User => "user".to_string(),
51                Role::Assistant => "assistant".to_string(),
52                Role::Tool => "tool".to_string(),
53            },
54            content: if message.content.is_empty() {
55                None
56            } else {
57                Some(message.content.clone())
58            },
59            tool_call_id: message.tool_call_id.clone(),
60            tool_calls: message.tool_calls.as_ref().map(|tcs| {
61                tcs.iter()
62                    .map(|tc| OpenAIToolCall {
63                        id: tc.id.clone(),
64                        r#type: "function".to_string(),
65                        function: OpenAIFunction {
66                            name: tc.name.clone(),
67                            arguments: tc.arguments.to_string(),
68                        },
69                    })
70                    .collect()
71            }),
72        }
73    }
74
75    fn to_openai_tools(&self, tools: &[ToolDefinition]) -> Vec<OpenAITool> {
76        tools
77            .iter()
78            .map(|t| OpenAITool {
79                r#type: "function".to_string(),
80                function: OpenAIFunctionDef {
81                    name: t.name.clone(),
82                    description: t.description.clone(),
83                    parameters: t.parameters.clone(),
84                },
85            })
86            .collect()
87    }
88}
89
90#[async_trait]
91impl LLMClient for OpenAIClient {
92    async fn chat(
93        &self,
94        messages: Vec<Message>,
95        tools: Vec<ToolDefinition>,
96        model: Option<&str>,
97        json_mode: bool,
98    ) -> GentResult<LLMResponse> {
99        let url = format!("{}/v1/chat/completions", self.base_url);
100
101        // Use provided model or fall back to client default
102        let model_to_use = model.unwrap_or(&self.model);
103
104        let mut body = json!({
105            "model": model_to_use,
106            "messages": self.to_openai_messages(&messages),
107        });
108
109        if !tools.is_empty() {
110            body["tools"] = json!(self.to_openai_tools(&tools));
111        }
112
113        if json_mode {
114            body["response_format"] = json!({"type": "json_object"});
115        }
116
117        let response = self
118            .client
119            .post(&url)
120            .header("Authorization", format!("Bearer {}", self.api_key))
121            .header("Content-Type", "application/json")
122            .json(&body)
123            .send()
124            .await
125            .map_err(|e| GentError::ApiError {
126                message: format!("Request failed: {}", e),
127            })?;
128
129        if !response.status().is_success() {
130            let status = response.status();
131            let text = response.text().await.unwrap_or_default();
132            return Err(GentError::ApiError {
133                message: format!("API error ({}): {}", status, text),
134            });
135        }
136
137        let api_response: OpenAIResponse =
138            response.json().await.map_err(|e| GentError::ApiError {
139                message: format!("Failed to parse response: {}", e),
140            })?;
141
142        let choice =
143            api_response
144                .choices
145                .into_iter()
146                .next()
147                .ok_or_else(|| GentError::ApiError {
148                    message: "No choices in response".to_string(),
149                })?;
150
151        let tool_calls = choice
152            .message
153            .tool_calls
154            .unwrap_or_default()
155            .into_iter()
156            .map(|tc| ToolCall {
157                id: tc.id,
158                name: tc.function.name,
159                arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(JsonValue::Null),
160            })
161            .collect();
162
163        Ok(LLMResponse {
164            content: choice.message.content,
165            tool_calls,
166        })
167    }
168}
169
170// OpenAI API types
/// Request-side chat message in OpenAI's wire format.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    // One of "system", "user", "assistant", or "tool".
    role: String,
    // Message text; omitted from the JSON when absent (e.g. an assistant
    // turn that only carries tool calls).
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Links a "tool" role message back to the call it answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    // Tool calls issued by an assistant turn, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
}
181
/// Tool-call entry; appears in both requests (assistant history) and responses.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIToolCall {
    id: String,
    // Always "function" as produced by this client.
    r#type: String,
    function: OpenAIFunction,
}
188
/// Function name plus its arguments for a tool call.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunction {
    name: String,
    // JSON-encoded string, not a parsed value — decoded by the caller.
    arguments: String,
}
194
/// Entry of the request's `tools` array.
#[derive(Debug, Serialize)]
struct OpenAITool {
    // Always "function" as produced by this client.
    r#type: String,
    function: OpenAIFunctionDef,
}
200
/// Declaration of a callable tool: name, description, and parameter schema.
#[derive(Debug, Serialize)]
struct OpenAIFunctionDef {
    name: String,
    description: String,
    // JSON Schema describing the tool's arguments, passed through verbatim.
    parameters: JsonValue,
}
207
/// Top-level chat-completions response; only `choices` is consumed here.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}
212
/// One completion candidate; this client reads only its message.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIResponseMessage,
}
217
/// Assistant message returned by the API; either or both fields may be absent.
#[derive(Debug, Deserialize)]
struct OpenAIResponseMessage {
    content: Option<String>,
    tool_calls: Option<Vec<OpenAIToolCall>>,
}