Skip to main content

oneshot_openai_cli/
lib.rs

1use anyhow::{Context, Result};
2use futures::stream::StreamExt;
3use reqwest::{Client as HttpClient, header::HeaderValue};
4use serde::{Deserialize, Serialize};
5use std::{env::var, time::Duration};
6
7// ============================================================================
8// Request Types
9// ============================================================================
10#[derive(Debug, Clone, Serialize)]
11pub struct ChatCompletionRequest {
12    pub model: String,
13    pub messages: Vec<Message>,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub temperature: Option<f32>,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub max_tokens: Option<u32>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub top_p: Option<f32>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub stream: Option<bool>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub stop: Option<Vec<String>>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub extra_body: Option<ExtraBody>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ExtraBody {
30    pub thinking: ThinkType,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ThinkType {
35    #[serde(rename = "type")]
36    pub r#type: String,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Message {
41    pub role: String,
42    pub content: String,
43}
44
45impl Message {
46    pub fn system(content: impl Into<String>) -> Self {
47        Self {
48            role: "system".to_string(),
49            content: content.into(),
50        }
51    }
52
53    pub fn user(content: impl Into<String>) -> Self {
54        Self {
55            role: "user".to_string(),
56            content: content.into(),
57        }
58    }
59
60    pub fn assistant(content: impl Into<String>) -> Self {
61        Self {
62            role: "assistant".to_string(),
63            content: content.into(),
64        }
65    }
66}
67
68// ============================================================================
69// Response Types
70// ============================================================================
71
72#[derive(Debug, Clone, Deserialize)]
73pub struct ChatCompletionResponse {
74    pub id: String,
75    pub object: String,
76    pub created: u64,
77    pub model: String,
78    pub choices: Vec<ChatChoice>,
79    pub usage: Usage,
80}
81
82#[derive(Debug, Clone, Deserialize)]
83pub struct ChatChoice {
84    pub index: u32,
85    pub message: Message,
86    pub finish_reason: Option<String>,
87}
88
89#[derive(Debug, Clone, Deserialize)]
90pub struct ChatCompletionChunk {
91    pub id: String,
92    pub object: String,
93    pub created: u64,
94    pub model: String,
95    pub choices: Vec<ChatChoiceDelta>,
96}
97
98#[derive(Debug, Clone, Deserialize)]
99pub struct ChatChoiceDelta {
100    pub index: u32,
101    pub delta: Delta,
102    pub finish_reason: Option<String>,
103}
104
105#[derive(Debug, Clone, Deserialize)]
106pub struct Delta {
107    #[serde(skip_serializing_if = "Option::is_none")]
108    pub role: Option<String>,
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub content: Option<String>,
111}
112
113#[derive(Debug, Clone, Deserialize)]
114pub struct Usage {
115    pub prompt_tokens: u32,
116    pub completion_tokens: Option<u32>,
117    pub total_tokens: u32,
118}
119
120#[derive(Debug, Clone, Deserialize)]
121pub struct ChatModelResponse {
122    pub object: String,
123    pub data: Vec<ChatModel>,
124}
125
126#[derive(Debug, Clone, Deserialize)]
127pub struct ChatModel {
128    pub id: String,
129    pub object: String,
130    pub created: u64,
131    pub owned_by: String,
132}
133
134// ============================================================================
135// Client
136// ============================================================================
137
138#[derive(Debug, Clone)]
139pub struct OpenAIClient {
140    http_client: HttpClient,
141    base_url: String,
142    api_key: Option<String>,
143}
144
145impl OpenAIClient {
146    /// Create a new client with the specified base URL
147    pub fn new(base_url: impl Into<String>) -> Result<Self> {
148        let http_client = HttpClient::builder()
149            .timeout(Duration::from_secs(300))
150            .connect_timeout(Duration::from_secs(10))
151            .build()
152            .context("Failed to build HTTP client")?;
153
154        Ok(Self {
155            http_client,
156            base_url: base_url.into(),
157            api_key: None,
158        })
159    }
160
161    /// Create a new client with the specified base URL and API key
162    pub fn with_api_key(base_url: impl Into<String>, api_key: impl Into<String>) -> Result<Self> {
163        let http_client = HttpClient::builder()
164            .timeout(Duration::from_secs(300))
165            .connect_timeout(Duration::from_secs(10))
166            .build()
167            .context("Failed to build HTTP client")?;
168
169        Ok(Self {
170            http_client,
171            base_url: base_url.into(),
172            api_key: Some(api_key.into()),
173        })
174    }
175
176    /// get all models
177    pub async fn list_models(&self) -> Result<ChatModelResponse> {
178        let url = format!("{}/models", self.base_url);
179
180        let mut req = self.http_client.get(&url);
181
182        if let Some(api_key) = &self.api_key {
183            req = req.header("Authorization", format!("Bearer {}", api_key));
184        }
185
186        let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
187        if !user.is_empty() {
188            req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
189        }
190        let response = req
191            .send()
192            .await
193            .context("Failed to send list model request")?;
194
195        if !response.status().is_success() {
196            let status = response.status();
197            let error_text = response.text().await.unwrap_or_default();
198            anyhow::bail!("API error ({}): {}", status, error_text);
199        }
200
201        response
202            .json()
203            .await
204            .context("Failed to parse list models response")
205    }
206
207    /// Send a chat completion request
208    pub async fn chat_completion(
209        &self,
210        request: ChatCompletionRequest,
211    ) -> Result<ChatCompletionResponse> {
212        let url = format!("{}/chat/completions", self.base_url);
213
214        let mut req = self.http_client.post(&url).json(&request);
215
216        if let Some(api_key) = &self.api_key {
217            req = req.header("Authorization", format!("Bearer {}", api_key));
218        }
219
220        let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
221        if !user.is_empty() {
222            req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
223        }
224        let response = req
225            .send()
226            .await
227            .context("Failed to send chat completion request")?;
228
229        if !response.status().is_success() {
230            let status = response.status();
231            let error_text = response.text().await.unwrap_or_default();
232            anyhow::bail!("API error ({}): {}", status, error_text);
233        }
234
235        response
236            .json()
237            .await
238            .context("Failed to parse chat completion response")
239    }
240
241    /// Send a streaming chat completion request
242    pub async fn chat_completion_stream(
243        &self,
244        request: ChatCompletionRequest,
245    ) -> Result<impl futures::Stream<Item = Result<ChatCompletionChunk>>> {
246        let url = format!("{}/chat/completions", self.base_url);
247
248        let mut req = self.http_client.post(&url).json(&request);
249
250        if let Some(api_key) = &self.api_key {
251            req = req.header("Authorization", format!("Bearer {}", api_key));
252        }
253
254        let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
255        if !user.is_empty() {
256            req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
257        }
258        let response = req
259            .send()
260            .await
261            .context("Failed to send streaming chat completion request")?;
262
263        if !response.status().is_success() {
264            let status = response.status();
265            let error_text = response.text().await.unwrap_or_default();
266            anyhow::bail!("API error({}): {}", status, error_text);
267        }
268
269        let stream = response.bytes_stream().map(|result| {
270            let bytes = result.context("Failed to read stream chunk")?;
271            let text = String::from_utf8_lossy(&bytes);
272
273            for line in text.lines() {
274                if let Some(data) = line.strip_prefix("data: ") {
275                    if data == "[DONE]" {
276                        continue;
277                    }
278                    let chunk: ChatCompletionChunk =
279                        serde_json::from_str(data).context("Failed to parse chunk")?;
280                    return Ok(chunk);
281                }
282            }
283
284            anyhow::bail!("No valid data in chunk")
285        });
286
287        Ok(stream)
288    }
289}
290
291// ============================================================================
292// Builder Pattern for Requests
293// ============================================================================
294
295impl ChatCompletionRequest {
296    pub fn new(model: impl Into<String>) -> Self {
297        Self {
298            model: model.into(),
299            messages: Vec::new(),
300            temperature: None,
301            max_tokens: None,
302            top_p: None,
303            stream: None,
304            stop: None,
305            extra_body: None,
306        }
307    }
308
309    pub fn message(mut self, message: Message) -> Self {
310        self.messages.push(message);
311        self
312    }
313
314    pub fn messages(mut self, messages: Vec<Message>) -> Self {
315        self.messages = messages;
316        self
317    }
318
319    pub fn temperature(mut self, temperature: f32) -> Self {
320        self.temperature = Some(temperature);
321        self
322    }
323
324    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
325        self.max_tokens = Some(max_tokens);
326        self
327    }
328
329    pub fn top_p(mut self, top_p: f32) -> Self {
330        self.top_p = Some(top_p);
331        self
332    }
333
334    pub fn stream(mut self, stream: bool) -> Self {
335        self.stream = Some(stream);
336        self
337    }
338
339    pub fn stop(mut self, stop: Vec<String>) -> Self {
340        self.stop = Some(stop);
341        self
342    }
343
344    pub fn thinking(mut self, thinking: bool) -> Self {
345        if thinking {
346            self.extra_body = Some(ExtraBody {
347                thinking: ThinkType {
348                    r#type: "enabled".into(),
349                },
350            })
351        }
352        self
353    }
354}