1use anyhow::{Context, Result};
2use futures::stream::StreamExt;
3use reqwest::{Client as HttpClient, header::HeaderValue};
4use serde::{Deserialize, Serialize};
5use std::{env::var, time::Duration};
6
7#[derive(Debug, Clone, Serialize)]
12pub struct ChatCompletionRequest {
13 pub model: String,
14 pub messages: Vec<Message>,
15 #[serde(skip_serializing_if = "Option::is_none")]
16 pub temperature: Option<f32>,
17 #[serde(skip_serializing_if = "Option::is_none")]
18 pub max_tokens: Option<u32>,
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub top_p: Option<f32>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub stream: Option<bool>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub stop: Option<Vec<String>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct Message {
29 pub role: String,
30 pub content: String,
31}
32
33impl Message {
34 pub fn system(content: impl Into<String>) -> Self {
35 Self {
36 role: "system".to_string(),
37 content: content.into(),
38 }
39 }
40
41 pub fn user(content: impl Into<String>) -> Self {
42 Self {
43 role: "user".to_string(),
44 content: content.into(),
45 }
46 }
47
48 pub fn assistant(content: impl Into<String>) -> Self {
49 Self {
50 role: "assistant".to_string(),
51 content: content.into(),
52 }
53 }
54}
55
56#[derive(Debug, Clone, Deserialize)]
61pub struct ChatCompletionResponse {
62 pub id: String,
63 pub object: String,
64 pub created: u64,
65 pub model: String,
66 pub choices: Vec<ChatChoice>,
67 pub usage: Usage,
68}
69
70#[derive(Debug, Clone, Deserialize)]
71pub struct ChatChoice {
72 pub index: u32,
73 pub message: Message,
74 pub finish_reason: Option<String>,
75}
76
77#[derive(Debug, Clone, Deserialize)]
78pub struct ChatCompletionChunk {
79 pub id: String,
80 pub object: String,
81 pub created: u64,
82 pub model: String,
83 pub choices: Vec<ChatChoiceDelta>,
84}
85
86#[derive(Debug, Clone, Deserialize)]
87pub struct ChatChoiceDelta {
88 pub index: u32,
89 pub delta: Delta,
90 pub finish_reason: Option<String>,
91}
92
93#[derive(Debug, Clone, Deserialize)]
94pub struct Delta {
95 #[serde(skip_serializing_if = "Option::is_none")]
96 pub role: Option<String>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub content: Option<String>,
99}
100
101#[derive(Debug, Clone, Deserialize)]
102pub struct Usage {
103 pub prompt_tokens: u32,
104 pub completion_tokens: Option<u32>,
105 pub total_tokens: u32,
106}
107
108#[derive(Debug, Clone, Deserialize)]
109pub struct ChatModelResponse {
110 pub object: String,
111 pub data: Vec<ChatModel>,
112}
113
114#[derive(Debug, Clone, Deserialize)]
115pub struct ChatModel {
116 pub id: String,
117 pub object: String,
118 pub created: u64,
119 pub owned_by: String,
120}
121
122#[derive(Debug, Clone)]
127pub struct OpenAIClient {
128 http_client: HttpClient,
129 base_url: String,
130 api_key: Option<String>,
131}
132
133impl OpenAIClient {
134 pub fn new(base_url: impl Into<String>) -> Result<Self> {
136 let http_client = HttpClient::builder()
137 .timeout(Duration::from_secs(300))
138 .connect_timeout(Duration::from_secs(10))
139 .build()
140 .context("Failed to build HTTP client")?;
141
142 Ok(Self {
143 http_client,
144 base_url: base_url.into(),
145 api_key: None,
146 })
147 }
148
149 pub fn with_api_key(base_url: impl Into<String>, api_key: impl Into<String>) -> Result<Self> {
151 let http_client = HttpClient::builder()
152 .timeout(Duration::from_secs(300))
153 .connect_timeout(Duration::from_secs(10))
154 .build()
155 .context("Failed to build HTTP client")?;
156
157 Ok(Self {
158 http_client,
159 base_url: base_url.into(),
160 api_key: Some(api_key.into()),
161 })
162 }
163
164 pub async fn list_models(&self) -> Result<ChatModelResponse> {
166 let url = format!("{}/models", self.base_url);
167
168 let mut req = self.http_client.get(&url);
169
170 if let Some(api_key) = &self.api_key {
171 req = req.header("Authorization", format!("Bearer {}", api_key));
172 }
173
174 let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
175 if !user.is_empty() {
176 req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
177 }
178 let response = req
179 .send()
180 .await
181 .context("Failed to send list model request")?;
182
183 if !response.status().is_success() {
184 let status = response.status();
185 let error_text = response.text().await.unwrap_or_default();
186 anyhow::bail!("API error ({}): {}", status, error_text);
187 }
188
189 response
190 .json()
191 .await
192 .context("Failed to parse list models response")
193 }
194
195 pub async fn chat_completion(
197 &self,
198 request: ChatCompletionRequest,
199 ) -> Result<ChatCompletionResponse> {
200 let url = format!("{}/chat/completions", self.base_url);
201
202 let mut req = self.http_client.post(&url).json(&request);
203
204 if let Some(api_key) = &self.api_key {
205 req = req.header("Authorization", format!("Bearer {}", api_key));
206 }
207
208 let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
209 if !user.is_empty() {
210 req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
211 }
212 let response = req
213 .send()
214 .await
215 .context("Failed to send chat completion request")?;
216
217 if !response.status().is_success() {
218 let status = response.status();
219 let error_text = response.text().await.unwrap_or_default();
220 anyhow::bail!("API error ({}): {}", status, error_text);
221 }
222
223 response
224 .json()
225 .await
226 .context("Failed to parse chat completion response")
227 }
228
229 pub async fn chat_completion_stream(
231 &self,
232 request: ChatCompletionRequest,
233 ) -> Result<impl futures::Stream<Item = Result<ChatCompletionChunk>>> {
234 let url = format!("{}/chat/completions", self.base_url);
235
236 let mut req = self.http_client.post(&url).json(&request);
237
238 if let Some(api_key) = &self.api_key {
239 req = req.header("Authorization", format!("Bearer {}", api_key));
240 }
241
242 let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
243 if !user.is_empty() {
244 req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
245 }
246 let response = req
247 .send()
248 .await
249 .context("Failed to send streaming chat completion request")?;
250
251 if !response.status().is_success() {
252 let status = response.status();
253 let error_text = response.text().await.unwrap_or_default();
254 anyhow::bail!("API error({}): {}", status, error_text);
255 }
256
257 let stream = response.bytes_stream().map(|result| {
258 let bytes = result.context("Failed to read stream chunk")?;
259 let text = String::from_utf8_lossy(&bytes);
260
261 for line in text.lines() {
262 if let Some(data) = line.strip_prefix("data: ") {
263 if data == "[DONE]" {
264 continue;
265 }
266 let chunk: ChatCompletionChunk =
267 serde_json::from_str(data).context("Failed to parse chunk")?;
268 return Ok(chunk);
269 }
270 }
271
272 anyhow::bail!("No valid data in chunk")
273 });
274
275 Ok(stream)
276 }
277}
278
279impl ChatCompletionRequest {
284 pub fn new(model: impl Into<String>) -> Self {
285 Self {
286 model: model.into(),
287 messages: Vec::new(),
288 temperature: None,
289 max_tokens: None,
290 top_p: None,
291 stream: None,
292 stop: None,
293 }
294 }
295
296 pub fn message(mut self, message: Message) -> Self {
297 self.messages.push(message);
298 self
299 }
300
301 pub fn messages(mut self, messages: Vec<Message>) -> Self {
302 self.messages = messages;
303 self
304 }
305
306 pub fn temperature(mut self, temperature: f32) -> Self {
307 self.temperature = Some(temperature);
308 self
309 }
310
311 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
312 self.max_tokens = Some(max_tokens);
313 self
314 }
315
316 pub fn top_p(mut self, top_p: f32) -> Self {
317 self.top_p = Some(top_p);
318 self
319 }
320
321 pub fn stream(mut self, stream: bool) -> Self {
322 self.stream = Some(stream);
323 self
324 }
325
326 pub fn stop(mut self, stop: Vec<String>) -> Self {
327 self.stop = Some(stop);
328 self
329 }
330}