1use crate::http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
2use anyhow::{Context as _, Result, anyhow};
3use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, future::Future};
7use strum::EnumIter;
8
/// Base URL of the public OpenAI REST API, without a trailing slash.
///
/// The request functions in this module append endpoint paths such as
/// `/chat/completions` and `/embeddings` to the `api_url` they are given;
/// NOTE(review): presumably callers pass this constant as that `api_url`.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
10
/// `skip_serializing_if` predicate: `true` when `opt` is `None` or holds an
/// empty sequence (anything viewable as a slice, e.g. `Vec<T>`).
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(seq) => seq.as_ref().is_empty(),
        None => true,
    }
}
14
/// Author role of a chat message.
///
/// Serialized in lowercase: `"user"`, `"assistant"`, `"system"`, `"tool"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
23
24impl TryFrom<String> for Role {
25 type Error = anyhow::Error;
26
27 fn try_from(value: String) -> Result<Self> {
28 match value.as_str() {
29 "user" => Ok(Self::User),
30 "assistant" => Ok(Self::Assistant),
31 "system" => Ok(Self::System),
32 "tool" => Ok(Self::Tool),
33 _ => anyhow::bail!("invalid role '{value}'"),
34 }
35 }
36}
37
38impl From<Role> for String {
39 fn from(val: Role) -> Self {
40 match val {
41 Role::User => "user".to_owned(),
42 Role::Assistant => "assistant".to_owned(),
43 Role::System => "system".to_owned(),
44 Role::Tool => "tool".to_owned(),
45 }
46 }
47}
48
/// An OpenAI chat model.
///
/// Built-in variants serialize to their API model ids via `serde(rename)`.
/// `FourOmni` (`"gpt-4o"`) is the `Default`. `Custom` describes a model that
/// is configured at runtime rather than built into this enum.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,

    /// A user-configured model; serialized under the `"custom"` variant name.
    #[serde(rename = "custom")]
    Custom {
        /// Model id sent to the API (returned by `Model::id`).
        name: String,
        /// Optional UI label; `Model::display_name` falls back to `name` when absent.
        display_name: Option<String>,
        /// Total token limit reported by `Model::max_token_count`.
        max_tokens: u64,
        /// Output token limit reported by `Model::max_output_tokens`.
        max_output_tokens: Option<u64>,
        /// NOTE(review): not read anywhere in this module; confirm how callers
        /// use it relative to `max_output_tokens`.
        max_completion_tokens: Option<u64>,
    },
}
88
89impl Model {
90 pub fn default_fast() -> Self {
91 Self::FourPointOneMini
92 }
93
94 pub fn from_id(id: &str) -> Result<Self> {
95 match id {
96 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
97 "gpt-4" => Ok(Self::Four),
98 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
99 "gpt-4o" => Ok(Self::FourOmni),
100 "gpt-4o-mini" => Ok(Self::FourOmniMini),
101 "gpt-4.1" => Ok(Self::FourPointOne),
102 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
103 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
104 "o1" => Ok(Self::O1),
105 "o3-mini" => Ok(Self::O3Mini),
106 "o3" => Ok(Self::O3),
107 "o4-mini" => Ok(Self::O4Mini),
108 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
109 }
110 }
111
112 pub fn id(&self) -> &str {
113 match self {
114 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
115 Self::Four => "gpt-4",
116 Self::FourTurbo => "gpt-4-turbo",
117 Self::FourOmni => "gpt-4o",
118 Self::FourOmniMini => "gpt-4o-mini",
119 Self::FourPointOne => "gpt-4.1",
120 Self::FourPointOneMini => "gpt-4.1-mini",
121 Self::FourPointOneNano => "gpt-4.1-nano",
122 Self::O1 => "o1",
123 Self::O3Mini => "o3-mini",
124 Self::O3 => "o3",
125 Self::O4Mini => "o4-mini",
126 Self::Custom { name, .. } => name,
127 }
128 }
129
130 pub fn display_name(&self) -> &str {
131 match self {
132 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
133 Self::Four => "gpt-4",
134 Self::FourTurbo => "gpt-4-turbo",
135 Self::FourOmni => "gpt-4o",
136 Self::FourOmniMini => "gpt-4o-mini",
137 Self::FourPointOne => "gpt-4.1",
138 Self::FourPointOneMini => "gpt-4.1-mini",
139 Self::FourPointOneNano => "gpt-4.1-nano",
140 Self::O1 => "o1",
141 Self::O3Mini => "o3-mini",
142 Self::O3 => "o3",
143 Self::O4Mini => "o4-mini",
144 Self::Custom {
145 name, display_name, ..
146 } => display_name.as_ref().unwrap_or(name),
147 }
148 }
149
150 pub fn max_token_count(&self) -> u64 {
151 match self {
152 Self::ThreePointFiveTurbo => 16_385,
153 Self::Four => 8_192,
154 Self::FourTurbo => 128_000,
155 Self::FourOmni => 128_000,
156 Self::FourOmniMini => 128_000,
157 Self::FourPointOne => 1_047_576,
158 Self::FourPointOneMini => 1_047_576,
159 Self::FourPointOneNano => 1_047_576,
160 Self::O1 => 200_000,
161 Self::O3Mini => 200_000,
162 Self::O3 => 200_000,
163 Self::O4Mini => 200_000,
164 Self::Custom { max_tokens, .. } => *max_tokens,
165 }
166 }
167
168 pub fn max_output_tokens(&self) -> Option<u64> {
169 match self {
170 Self::Custom {
171 max_output_tokens, ..
172 } => *max_output_tokens,
173 Self::ThreePointFiveTurbo => Some(4_096),
174 Self::Four => Some(8_192),
175 Self::FourTurbo => Some(4_096),
176 Self::FourOmni => Some(16_384),
177 Self::FourOmniMini => Some(16_384),
178 Self::FourPointOne => Some(32_768),
179 Self::FourPointOneMini => Some(32_768),
180 Self::FourPointOneNano => Some(32_768),
181 Self::O1 => Some(100_000),
182 Self::O3Mini => Some(100_000),
183 Self::O3 => Some(100_000),
184 Self::O4Mini => Some(100_000),
185 }
186 }
187
188 pub fn supports_parallel_tool_calls(&self) -> bool {
192 match self {
193 Self::ThreePointFiveTurbo
194 | Self::Four
195 | Self::FourTurbo
196 | Self::FourOmni
197 | Self::FourOmniMini
198 | Self::FourPointOne
199 | Self::FourPointOneMini
200 | Self::FourPointOneNano => true,
201 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
202 }
203 }
204}
205
/// JSON body of a chat-completion request, serialized by `stream_completion`.
///
/// Fields marked `skip_serializing_if` are left out of the JSON entirely when
/// they are `None` or empty, instead of being sent as `null` or `[]`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id. NOTE(review): presumably the value of `Model::id`.
    pub model: String,
    /// Messages sent to the model.
    pub messages: Vec<RequestMessage>,
    /// Streaming flag; always serialized. NOTE(review): `stream_completion`
    /// always parses the response as a `data:` event stream, so this
    /// presumably must be `true` when sent through it.
    pub stream: bool,
    /// Optional cap on generated tokens; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; always serialized.
    pub temperature: f32,
    /// Tool-selection policy; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to allow parallel tool calls; omitted when `None`.
    /// See `Model::supports_parallel_tool_calls`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
224
225#[derive(Debug, Serialize, Deserialize)]
226#[serde(untagged)]
227pub enum ToolChoice {
228 Auto,
229 Required,
230 None,
231 Other(ToolDefinition),
232}
233
/// A tool the model may call, internally tagged by `"type"` (currently only
/// `"function"`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): presumably silences an unused-field warning for
    // `function`; confirm it is still needed.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// Declaration of a callable function exposed to the model.
///
/// `description` and `parameters` have no `skip_serializing_if`, so serde
/// writes them as `null` when they are `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// Arbitrary JSON. NOTE(review): presumably a JSON Schema for the
    /// function's arguments; confirm with callers.
    pub parameters: Option<Value>,
}
247
/// A message in a [`Request`], internally tagged by `"role"`
/// (`"assistant"`, `"user"`, `"system"` or `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// Assistant turn. `content` may be absent, and `tool_calls` is omitted
    /// from the JSON when empty.
    Assistant {
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// Tool output. NOTE(review): `tool_call_id` presumably matches the `id`
    /// of a `ToolCall` in an earlier assistant message; verify.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
267
/// Message content, serialized untagged.
///
/// `Plain` is written as a JSON string and `Multipart` as a JSON array of
/// [`MessagePart`]s. On input, serde tries `Plain` first, then `Multipart`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
274
275impl MessageContent {
276 pub fn empty() -> Self {
277 MessageContent::Multipart(vec![])
278 }
279
280 pub fn push_part(&mut self, part: MessagePart) {
281 match self {
282 MessageContent::Plain(text) => {
283 *self =
284 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
285 }
286 MessageContent::Multipart(parts) if parts.is_empty() => match part {
287 MessagePart::Text { text } => *self = MessageContent::Plain(text),
288 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
289 },
290 MessageContent::Multipart(parts) => parts.push(part),
291 }
292 }
293}
294
295impl From<Vec<MessagePart>> for MessageContent {
296 fn from(mut parts: Vec<MessagePart>) -> Self {
297 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
298 MessageContent::Plain(std::mem::take(text))
299 } else {
300 MessageContent::Multipart(parts)
301 }
302 }
303}
304
/// One element of [`MessageContent::Multipart`], internally tagged by `"type"`.
///
/// The JSON forms are `{"type": "text", "text": ...}` and
/// `{"type": "image_url", "image_url": {...}}`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// Image reference carried by [`MessagePart::Image`].
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// Image location. NOTE(review): may be an http(s) URL or a data URL;
    /// confirm what callers send.
    pub url: String,
    /// Optional detail hint; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
320
/// A tool call recorded on [`RequestMessage::Assistant`].
///
/// `content` is flattened, so the JSON carries `id` next to the `type` and
/// `function` keys of [`ToolCallContent`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a [`ToolCall`], internally tagged by `"type"` (`"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Name of the called function and its arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a raw string. NOTE(review): presumably JSON-encoded; confirm.
    pub arguments: String,
}
339
/// Incremental message data carried by one streamed [`ChoiceDelta`].
///
/// Every field is optional. NOTE(review): presumably because each chunk
/// carries only a fragment of the final message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Omitted from the JSON when `None` or empty (see `is_none_or_empty`).
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A fragment of a streamed tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// NOTE(review): presumably identifies which tool call this fragment
    /// belongs to, so fragments can be assembled by index.
    pub index: usize,
    pub id: Option<String>,

    // NOTE(review): the wire format may also carry a `type` field; it is not
    // modeled here, so presence of `function` is what callers can rely on.
    pub function: Option<FunctionChunk>,
}

/// Partial function name and/or argument text from a [`ToolCallChunk`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
364
/// Token accounting reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// NOTE(review): presumably `None` until this choice finishes.
    pub finish_reason: Option<String>,
}

/// A single decoded `data:` payload from the completion stream.
///
/// The enum is untagged, so serde tries `Ok(ResponseStreamEvent)` first and
/// then `{"error": "<string>"}`. `stream_completion` maps `Err` to an error
/// item in the stream.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

/// A successfully decoded streaming chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// NOTE(review): presumably present only on some chunks (e.g. the last).
    pub usage: Option<Usage>,
}
392
393pub async fn stream_completion(
394 client: &dyn HttpClient,
395 api_url: &str,
396 api_key: &str,
397 request: Request,
398) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
399 let uri = format!("{api_url}/chat/completions");
400 let request_builder = HttpRequest::builder()
402 .method(Method::POST)
403 .uri(uri)
404 .header("Content-Type", "application/json")
405 .header("Authorization", format!("Bearer {}", api_key));
406
407 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
408 let mut response = client.send(request).await?;
409 if response.status().is_success() {
410 let reader = BufReader::new(response.into_body());
411 Ok(reader
412 .lines()
413 .filter_map(|line| async move {
414 match line {
415 Ok(line) => {
416 let line = line.strip_prefix("data: ")?;
417 if line == "[DONE]" {
418 None
419 } else {
420 match serde_json::from_str(line) {
421 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
422 Ok(ResponseStreamResult::Err { error }) => {
423 Some(Err(anyhow!(error)))
424 }
425 Err(error) => Some(Err(anyhow!(error))),
426 }
427 }
428 }
429 Err(error) => Some(Err(anyhow!(error))),
430 }
431 })
432 .boxed())
433 } else {
434 let mut body = String::new();
435 response.body_mut().read_to_string(&mut body).await?;
436
437 #[derive(Deserialize)]
438 struct OpenAiResponse {
439 error: OpenAiError,
440 }
441
442 #[derive(Deserialize)]
443 struct OpenAiError {
444 message: String,
445 }
446
447 match serde_json::from_str::<OpenAiResponse>(&body) {
448 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
449 "API request to {} failed: {}",
450 api_url,
451 response.error.message,
452 )),
453
454 _ => anyhow::bail!(
455 "API request to {} failed with status {}: {}",
456 api_url,
457 response.status(),
458 body,
459 ),
460 }
461 }
462}
463
464#[derive(Copy, Clone, Serialize, Deserialize)]
465pub enum OpenAiEmbeddingModel {
466 #[serde(rename = "text-embedding-3-small")]
467 TextEmbedding3Small,
468 #[serde(rename = "text-embedding-3-large")]
469 TextEmbedding3Large,
470}
471
472#[derive(Serialize)]
473struct OpenAiEmbeddingRequest<'a> {
474 model: OpenAiEmbeddingModel,
475 input: Vec<&'a str>,
476}
477
478#[derive(Deserialize)]
479pub struct OpenAiEmbeddingResponse {
480 pub data: Vec<OpenAiEmbedding>,
481}
482
483#[derive(Deserialize)]
484pub struct OpenAiEmbedding {
485 pub embedding: Vec<f32>,
486}
487
488pub fn embed<'a>(
489 client: &dyn HttpClient,
490 api_url: &str,
491 api_key: &str,
492 model: OpenAiEmbeddingModel,
493 texts: impl IntoIterator<Item = &'a str>,
494) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
495 let uri = format!("{api_url}/embeddings");
496
497 let request = OpenAiEmbeddingRequest {
498 model,
499 input: texts.into_iter().collect(),
500 };
501 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
502 let request = HttpRequest::builder()
503 .method(Method::POST)
504 .uri(uri)
505 .header("Content-Type", "application/json")
506 .header("Authorization", format!("Bearer {}", api_key))
507 .body(body)
508 .map(|request| client.send(request));
509
510 async move {
511 let mut response = request?.await?;
512 let mut body = String::new();
513 response.body_mut().read_to_string(&mut body).await?;
514
515 anyhow::ensure!(
516 response.status().is_success(),
517 "error during embedding, status: {:?}, body: {:?}",
518 response.status(),
519 body
520 );
521 let response: OpenAiEmbeddingResponse =
522 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
523 Ok(response)
524 }
525}