1use anyhow::{Context, Result};
4use futures_core::Stream;
5use reqwest::header;
6use serde::{Deserialize, Serialize};
7
/// Chat Completions endpoint that every request issued by [`ChatGptClient`] is POSTed to.
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
9
10#[derive(Clone)]
12pub struct ChatGptClient {
13 client: reqwest::Client,
14 model: String,
15}
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct Message {
19 pub role: String,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub content: Option<String>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub tool_calls: Option<Vec<ToolCall>>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub tool_call_id: Option<String>,
26}
27
28impl Message {
29 pub fn system(content: impl Into<String>) -> Self {
30 Self {
31 role: "system".to_owned(),
32 content: Some(content.into()),
33 tool_calls: None,
34 tool_call_id: None,
35 }
36 }
37
38 pub fn user(content: impl Into<String>) -> Self {
39 Self {
40 role: "user".to_owned(),
41 content: Some(content.into()),
42 tool_calls: None,
43 tool_call_id: None,
44 }
45 }
46
47 pub fn assistant(content: impl Into<String>) -> Self {
48 Self {
49 role: "assistant".to_owned(),
50 content: Some(content.into()),
51 tool_calls: None,
52 tool_call_id: None,
53 }
54 }
55
56 pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
57 Self {
58 role: "assistant".to_owned(),
59 content: None,
60 tool_calls: Some(tool_calls),
61 tool_call_id: None,
62 }
63 }
64
65 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
66 Self {
67 role: "tool".to_owned(),
68 content: Some(content.into()),
69 tool_calls: None,
70 tool_call_id: Some(tool_call_id.into()),
71 }
72 }
73}
74
75#[derive(Debug, Serialize, Deserialize, Clone)]
78pub struct ToolCall {
79 pub id: String,
80 #[serde(rename = "type")]
81 pub call_type: String,
82 pub function: FunctionCall,
83}
84
85#[derive(Debug, Serialize, Deserialize, Clone)]
86pub struct FunctionCall {
87 pub name: String,
88 pub arguments: String,
89}
90
91#[derive(Debug, Serialize, Clone)]
92pub struct ToolDefinition {
93 #[serde(rename = "type")]
94 pub tool_type: String,
95 pub function: FunctionDefinition,
96}
97
98#[derive(Debug, Serialize, Clone)]
99pub struct FunctionDefinition {
100 pub name: String,
101 pub description: String,
102 pub parameters: serde_json::Value,
103}
104
105#[derive(Debug, Serialize)]
108struct ChatRequest {
109 model: String,
110 messages: Vec<Message>,
111 #[serde(skip_serializing_if = "Option::is_none")]
112 temperature: Option<f32>,
113 #[serde(skip_serializing_if = "std::ops::Not::not")]
114 stream: bool,
115 #[serde(skip_serializing_if = "Option::is_none")]
116 tools: Option<Vec<ToolDefinition>>,
117}
118
119#[derive(Debug, Deserialize)]
122struct ChatResponse {
123 choices: Vec<Choice>,
124}
125
126#[derive(Debug, Deserialize)]
127struct Choice {
128 message: Message,
129 finish_reason: Option<String>,
130}
131
132#[derive(Debug, Deserialize)]
135struct StreamChunk {
136 choices: Vec<StreamChoice>,
137}
138
139#[derive(Debug, Deserialize)]
140struct StreamChoice {
141 delta: Delta,
142}
143
144#[derive(Debug, Deserialize)]
145struct Delta {
146 content: Option<String>,
147}
148
149#[derive(Debug, thiserror::Error)]
152pub enum LlmError {
153 #[error("OpenAI API error (HTTP {status}): {body}")]
154 Api { status: u16, body: String },
155
156 #[error(transparent)]
157 Transport(#[from] reqwest::Error),
158
159 #[error(transparent)]
160 Other(#[from] anyhow::Error),
161}
162
163impl ChatGptClient {
164 pub fn new(api_key: &str, model: &str) -> Result<Self> {
168 let mut headers = header::HeaderMap::new();
169 let mut auth = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
170 .context("invalid API key characters")?;
171 auth.set_sensitive(true);
172 headers.insert(header::AUTHORIZATION, auth);
173
174 let client = reqwest::Client::builder()
175 .default_headers(headers)
176 .build()
177 .context("failed to build HTTP client")?;
178
179 Ok(Self {
180 client,
181 model: model.to_owned(),
182 })
183 }
184
185 pub async fn chat(&self, messages: Vec<Message>) -> Result<String, LlmError> {
187 let request = ChatRequest {
188 model: self.model.clone(),
189 messages,
190 temperature: None,
191 stream: false,
192 tools: None,
193 };
194
195 let response = self.client.post(API_URL).json(&request).send().await?;
196 let status = response.status();
197 if !status.is_success() {
198 let body = response.text().await.unwrap_or_default();
199 return Err(LlmError::Api {
200 status: status.as_u16(),
201 body,
202 });
203 }
204
205 let parsed: ChatResponse = response.json().await?;
206 Ok(parsed
207 .choices
208 .into_iter()
209 .next()
210 .and_then(|c| c.message.content)
211 .unwrap_or_default())
212 }
213
214 pub async fn chat_with_tools(
220 &self,
221 messages: Vec<Message>,
222 tools: Option<&[ToolDefinition]>,
223 ) -> Result<(Message, Option<String>), LlmError> {
224 let request = ChatRequest {
225 model: self.model.clone(),
226 messages,
227 temperature: None,
228 stream: false,
229 tools: tools.map(|t| t.to_vec()),
230 };
231
232 let response = self.client.post(API_URL).json(&request).send().await?;
233 let status = response.status();
234 if !status.is_success() {
235 let body = response.text().await.unwrap_or_default();
236 return Err(LlmError::Api {
237 status: status.as_u16(),
238 body,
239 });
240 }
241
242 let parsed: ChatResponse = response.json().await?;
243 let choice = parsed
244 .choices
245 .into_iter()
246 .next()
247 .ok_or_else(|| LlmError::Other(anyhow::anyhow!("no choices in response")))?;
248
249 Ok((choice.message, choice.finish_reason))
250 }
251
252 pub fn chat_stream(
257 &self,
258 messages: Vec<Message>,
259 ) -> impl Stream<Item = Result<String, LlmError>> + Send {
260 let client = self.client.clone();
261 let model = self.model.clone();
262
263 async_stream::try_stream! {
264 let request = ChatRequest {
265 model,
266 messages,
267 temperature: None,
268 stream: true,
269 tools: None,
270 };
271
272 let mut response = client.post(API_URL).json(&request).send().await?;
273 if !response.status().is_success() {
274 let status = response.status().as_u16();
275 let mut body = String::new();
277 while let Some(chunk) = response.chunk().await? {
278 body.push_str(&String::from_utf8_lossy(&chunk));
279 }
280 Err(LlmError::Api { status, body })?;
281 }
282 let mut buffer = String::new();
283
284 while let Some(chunk) = response.chunk().await? {
285 buffer.push_str(&String::from_utf8_lossy(&chunk));
286
287 while let Some(pos) = buffer.find("\n\n") {
289 let event = buffer[..pos].to_owned();
290 buffer = buffer[pos + 2..].to_owned();
291
292 for line in event.lines() {
293 let data = match line.strip_prefix("data: ") {
294 Some(d) => d.trim(),
295 None => continue,
296 };
297
298 if data == "[DONE]" {
299 return;
300 }
301
302 if let Ok(parsed) = serde_json::from_str::<StreamChunk>(data) {
303 for choice in parsed.choices {
304 if let Some(content) = choice.delta.content {
305 yield content;
306 }
307 }
308 }
309 }
310 }
311 }
312 }
313 }
314}