Skip to main content

oneshot_openai_cli/
lib.rs

1use anyhow::{Context, Result};
2use futures::stream::StreamExt;
3use reqwest::{Client as HttpClient, header::HeaderValue};
4use serde::{Deserialize, Serialize};
5use std::{env::var, time::Duration};
6
// ============================================================================
// Request Types
// ============================================================================
10
11#[derive(Debug, Clone, Serialize)]
12pub struct ChatCompletionRequest {
13    pub model: String,
14    pub messages: Vec<Message>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub temperature: Option<f32>,
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub max_tokens: Option<u32>,
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub top_p: Option<f32>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub stream: Option<bool>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub stop: Option<Vec<String>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct Message {
29    pub role: String,
30    pub content: String,
31}
32
33impl Message {
34    pub fn system(content: impl Into<String>) -> Self {
35        Self {
36            role: "system".to_string(),
37            content: content.into(),
38        }
39    }
40
41    pub fn user(content: impl Into<String>) -> Self {
42        Self {
43            role: "user".to_string(),
44            content: content.into(),
45        }
46    }
47
48    pub fn assistant(content: impl Into<String>) -> Self {
49        Self {
50            role: "assistant".to_string(),
51            content: content.into(),
52        }
53    }
54}
55
// ============================================================================
// Response Types
// ============================================================================
59
60#[derive(Debug, Clone, Deserialize)]
61pub struct ChatCompletionResponse {
62    pub id: String,
63    pub object: String,
64    pub created: u64,
65    pub model: String,
66    pub choices: Vec<ChatChoice>,
67    pub usage: Usage,
68}
69
70#[derive(Debug, Clone, Deserialize)]
71pub struct ChatChoice {
72    pub index: u32,
73    pub message: Message,
74    pub finish_reason: Option<String>,
75}
76
77#[derive(Debug, Clone, Deserialize)]
78pub struct ChatCompletionChunk {
79    pub id: String,
80    pub object: String,
81    pub created: u64,
82    pub model: String,
83    pub choices: Vec<ChatChoiceDelta>,
84}
85
86#[derive(Debug, Clone, Deserialize)]
87pub struct ChatChoiceDelta {
88    pub index: u32,
89    pub delta: Delta,
90    pub finish_reason: Option<String>,
91}
92
93#[derive(Debug, Clone, Deserialize)]
94pub struct Delta {
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub role: Option<String>,
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub content: Option<String>,
99}
100
101#[derive(Debug, Clone, Deserialize)]
102pub struct Usage {
103    pub prompt_tokens: u32,
104    pub completion_tokens: Option<u32>,
105    pub total_tokens: u32,
106}
107
108#[derive(Debug, Clone, Deserialize)]
109pub struct ChatModelResponse {
110    pub object: String,
111    pub data: Vec<ChatModel>,
112}
113
114#[derive(Debug, Clone, Deserialize)]
115pub struct ChatModel {
116    pub id: String,
117    pub object: String,
118    pub created: u64,
119    pub owned_by: String,
120}
121
// ============================================================================
// Client
// ============================================================================
125
126#[derive(Debug, Clone)]
127pub struct OpenAIClient {
128    http_client: HttpClient,
129    base_url: String,
130    api_key: Option<String>,
131}
132
133impl OpenAIClient {
134    /// Create a new client with the specified base URL
135    pub fn new(base_url: impl Into<String>) -> Result<Self> {
136        let http_client = HttpClient::builder()
137            .timeout(Duration::from_secs(300))
138            .connect_timeout(Duration::from_secs(10))
139            .build()
140            .context("Failed to build HTTP client")?;
141
142        Ok(Self {
143            http_client,
144            base_url: base_url.into(),
145            api_key: None,
146        })
147    }
148
149    /// Create a new client with the specified base URL and API key
150    pub fn with_api_key(base_url: impl Into<String>, api_key: impl Into<String>) -> Result<Self> {
151        let http_client = HttpClient::builder()
152            .timeout(Duration::from_secs(300))
153            .connect_timeout(Duration::from_secs(10))
154            .build()
155            .context("Failed to build HTTP client")?;
156
157        Ok(Self {
158            http_client,
159            base_url: base_url.into(),
160            api_key: Some(api_key.into()),
161        })
162    }
163
164    /// get all models
165    pub async fn list_models(&self) -> Result<ChatModelResponse> {
166        let url = format!("{}/models", self.base_url);
167
168        let mut req = self.http_client.get(&url);
169
170        if let Some(api_key) = &self.api_key {
171            req = req.header("Authorization", format!("Bearer {}", api_key));
172        }
173
174        let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
175        if !user.is_empty() {
176            req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
177        }
178        let response = req
179            .send()
180            .await
181            .context("Failed to send list model request")?;
182
183        if !response.status().is_success() {
184            let status = response.status();
185            let error_text = response.text().await.unwrap_or_default();
186            anyhow::bail!("API error ({}): {}", status, error_text);
187        }
188
189        response
190            .json()
191            .await
192            .context("Failed to parse list models response")
193    }
194
195    /// Send a chat completion request
196    pub async fn chat_completion(
197        &self,
198        request: ChatCompletionRequest,
199    ) -> Result<ChatCompletionResponse> {
200        let url = format!("{}/chat/completions", self.base_url);
201
202        let mut req = self.http_client.post(&url).json(&request);
203
204        if let Some(api_key) = &self.api_key {
205            req = req.header("Authorization", format!("Bearer {}", api_key));
206        }
207
208        let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
209        if !user.is_empty() {
210            req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
211        }
212        let response = req
213            .send()
214            .await
215            .context("Failed to send chat completion request")?;
216
217        if !response.status().is_success() {
218            let status = response.status();
219            let error_text = response.text().await.unwrap_or_default();
220            anyhow::bail!("API error ({}): {}", status, error_text);
221        }
222
223        response
224            .json()
225            .await
226            .context("Failed to parse chat completion response")
227    }
228
229    /// Send a streaming chat completion request
230    pub async fn chat_completion_stream(
231        &self,
232        request: ChatCompletionRequest,
233    ) -> Result<impl futures::Stream<Item = Result<ChatCompletionChunk>>> {
234        let url = format!("{}/chat/completions", self.base_url);
235
236        let mut req = self.http_client.post(&url).json(&request);
237
238        if let Some(api_key) = &self.api_key {
239            req = req.header("Authorization", format!("Bearer {}", api_key));
240        }
241
242        let user = var("USERNAME").unwrap_or(var("USER").unwrap_or_default());
243        if !user.is_empty() {
244            req = req.header("X-User-ID", HeaderValue::from_str(user.as_str()).unwrap());
245        }
246        let response = req
247            .send()
248            .await
249            .context("Failed to send streaming chat completion request")?;
250
251        if !response.status().is_success() {
252            let status = response.status();
253            let error_text = response.text().await.unwrap_or_default();
254            anyhow::bail!("API error({}): {}", status, error_text);
255        }
256
257        let stream = response.bytes_stream().map(|result| {
258            let bytes = result.context("Failed to read stream chunk")?;
259            let text = String::from_utf8_lossy(&bytes);
260
261            for line in text.lines() {
262                if let Some(data) = line.strip_prefix("data: ") {
263                    if data == "[DONE]" {
264                        continue;
265                    }
266                    let chunk: ChatCompletionChunk =
267                        serde_json::from_str(data).context("Failed to parse chunk")?;
268                    return Ok(chunk);
269                }
270            }
271
272            anyhow::bail!("No valid data in chunk")
273        });
274
275        Ok(stream)
276    }
277}
278
// ============================================================================
// Builder Pattern for Requests
// ============================================================================
282
283impl ChatCompletionRequest {
284    pub fn new(model: impl Into<String>) -> Self {
285        Self {
286            model: model.into(),
287            messages: Vec::new(),
288            temperature: None,
289            max_tokens: None,
290            top_p: None,
291            stream: None,
292            stop: None,
293        }
294    }
295
296    pub fn message(mut self, message: Message) -> Self {
297        self.messages.push(message);
298        self
299    }
300
301    pub fn messages(mut self, messages: Vec<Message>) -> Self {
302        self.messages = messages;
303        self
304    }
305
306    pub fn temperature(mut self, temperature: f32) -> Self {
307        self.temperature = Some(temperature);
308        self
309    }
310
311    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
312        self.max_tokens = Some(max_tokens);
313        self
314    }
315
316    pub fn top_p(mut self, top_p: f32) -> Self {
317        self.top_p = Some(top_p);
318        self
319    }
320
321    pub fn stream(mut self, stream: bool) -> Self {
322        self.stream = Some(stream);
323        self
324    }
325
326    pub fn stop(mut self, stop: Vec<String>) -> Self {
327        self.stop = Some(stop);
328        self
329    }
330}