Skip to main content

atlas/providers/
openai.rs

1//! OpenAI API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::openai;
6//!
7//! let client = openai::Client::new("YOUR_API_KEY");
8//!
9//! let gpt4o = client.completion_model(openai::GPT_4O);
10//! ```
11use crate::{
12    agent::AgentBuilder,
13    completion::{self, CompletionError, CompletionRequest},
14    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
15    extractor::ExtractorBuilder,
16    json_utils, Embed,
17};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use serde_json::json;
21
22// ================================================================
23// Main OpenAI Client
24// ================================================================
/// Base URL of the OpenAI REST API; the default target used by [`Client::new`].
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
26
27#[derive(Clone)]
28pub struct Client {
29    base_url: String,
30    http_client: reqwest::Client,
31}
32
33impl Client {
34    /// Create a new OpenAI client with the given API key.
35    pub fn new(api_key: &str) -> Self {
36        Self::from_url(api_key, OPENAI_API_BASE_URL)
37    }
38
39    /// Create a new OpenAI client with the given API key and base API URL.
40    pub fn from_url(api_key: &str, base_url: &str) -> Self {
41        Self {
42            base_url: base_url.to_string(),
43            http_client: reqwest::Client::builder()
44                .default_headers({
45                    let mut headers = reqwest::header::HeaderMap::new();
46                    headers.insert(
47                        "Authorization",
48                        format!("Bearer {}", api_key)
49                            .parse()
50                            .expect("Bearer token should parse"),
51                    );
52                    headers
53                })
54                .build()
55                .expect("OpenAI reqwest client should build"),
56        }
57    }
58
59    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
60    /// Panics if the environment variable is not set.
61    pub fn from_env() -> Self {
62        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
63        Self::new(&api_key)
64    }
65
66    fn post(&self, path: &str) -> reqwest::RequestBuilder {
67        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
68        self.http_client.post(url)
69    }
70
71    /// Create an embedding model with the given name.
72    /// Note: default embedding dimension of 0 will be used if model is not known.
73    /// If this is the case, it's better to use function `embedding_model_with_ndims`
74    ///
75    /// # Example
76    /// ```
77    /// use rig::providers::openai::{Client, self};
78    ///
79    /// // Initialize the OpenAI client
80    /// let openai = Client::new("your-open-ai-api-key");
81    ///
82    /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE);
83    /// ```
84    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
85        let ndims = match model {
86            TEXT_EMBEDDING_3_LARGE => 3072,
87            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
88            _ => 0,
89        };
90        EmbeddingModel::new(self.clone(), model, ndims)
91    }
92
93    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
94    ///
95    /// # Example
96    /// ```
97    /// use rig::providers::openai::{Client, self};
98    ///
99    /// // Initialize the OpenAI client
100    /// let openai = Client::new("your-open-ai-api-key");
101    ///
102    /// let embedding_model = openai.embedding_model("model-unknown-to-rig", 3072);
103    /// ```
104    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
105        EmbeddingModel::new(self.clone(), model, ndims)
106    }
107
108    /// Create an embedding builder with the given embedding model.
109    ///
110    /// # Example
111    /// ```
112    /// use rig::providers::openai::{Client, self};
113    ///
114    /// // Initialize the OpenAI client
115    /// let openai = Client::new("your-open-ai-api-key");
116    ///
117    /// let embeddings = openai.embeddings(openai::TEXT_EMBEDDING_3_LARGE)
118    ///     .simple_document("doc0", "Hello, world!")
119    ///     .simple_document("doc1", "Goodbye, world!")
120    ///     .build()
121    ///     .await
122    ///     .expect("Failed to embed documents");
123    /// ```
124    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
125        EmbeddingsBuilder::new(self.embedding_model(model))
126    }
127
128    /// Create a completion model with the given name.
129    ///
130    /// # Example
131    /// ```
132    /// use rig::providers::openai::{Client, self};
133    ///
134    /// // Initialize the OpenAI client
135    /// let openai = Client::new("your-open-ai-api-key");
136    ///
137    /// let gpt4 = openai.completion_model(openai::GPT_4);
138    /// ```
139    pub fn completion_model(&self, model: &str) -> CompletionModel {
140        CompletionModel::new(self.clone(), model)
141    }
142
143    /// Create an agent builder with the given completion model.
144    ///
145    /// # Example
146    /// ```
147    /// use rig::providers::openai::{Client, self};
148    ///
149    /// // Initialize the OpenAI client
150    /// let openai = Client::new("your-open-ai-api-key");
151    ///
152    /// let agent = openai.agent(openai::GPT_4)
153    ///    .preamble("You are comedian AI with a mission to make people laugh.")
154    ///    .temperature(0.0)
155    ///    .build();
156    /// ```
157    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
158        AgentBuilder::new(self.completion_model(model))
159    }
160
161    /// Create an extractor builder with the given completion model.
162    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
163        &self,
164        model: &str,
165    ) -> ExtractorBuilder<T, CompletionModel> {
166        ExtractorBuilder::new(self.completion_model(model))
167    }
168}
169
170#[derive(Debug, Deserialize)]
171struct ApiErrorResponse {
172    message: String,
173}
174
175#[derive(Debug, Deserialize)]
176#[serde(untagged)]
177enum ApiResponse<T> {
178    Ok(T),
179    Err(ApiErrorResponse),
180}
181
182// ================================================================
183// OpenAI Embedding API
184// ================================================================
/// Model name for `text-embedding-3-large`.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Model name for `text-embedding-3-small`.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Model name for `text-embedding-ada-002`.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
191
192#[derive(Debug, Deserialize)]
193pub struct EmbeddingResponse {
194    pub object: String,
195    pub data: Vec<EmbeddingData>,
196    pub model: String,
197    pub usage: Usage,
198}
199
200impl From<ApiErrorResponse> for EmbeddingError {
201    fn from(err: ApiErrorResponse) -> Self {
202        EmbeddingError::ProviderError(err.message)
203    }
204}
205
206impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
207    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
208        match value {
209            ApiResponse::Ok(response) => Ok(response),
210            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
211        }
212    }
213}
214
215#[derive(Debug, Deserialize)]
216pub struct EmbeddingData {
217    pub object: String,
218    pub embedding: Vec<f64>,
219    pub index: usize,
220}
221
222#[derive(Clone, Debug, Deserialize)]
223pub struct Usage {
224    pub prompt_tokens: usize,
225    pub total_tokens: usize,
226}
227
228impl std::fmt::Display for Usage {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        write!(
231            f,
232            "Prompt tokens: {} Total tokens: {}",
233            self.prompt_tokens, self.total_tokens
234        )
235    }
236}
237
238#[derive(Clone)]
239pub struct EmbeddingModel {
240    client: Client,
241    pub model: String,
242    ndims: usize,
243}
244
245impl embeddings::EmbeddingModel for EmbeddingModel {
246    const MAX_DOCUMENTS: usize = 1024;
247
248    fn ndims(&self) -> usize {
249        self.ndims
250    }
251
252    async fn embed_texts(
253        &self,
254        documents: impl IntoIterator<Item = String>,
255    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
256        let documents = documents.into_iter().collect::<Vec<_>>();
257
258        let response = self
259            .client
260            .post("/embeddings")
261            .json(&json!({
262                "model": self.model,
263                "input": documents,
264            }))
265            .send()
266            .await?;
267
268        if response.status().is_success() {
269            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
270                ApiResponse::Ok(response) => {
271                    tracing::info!(target: "rig",
272                        "OpenAI embedding token usage: {}",
273                        response.usage
274                    );
275
276                    if response.data.len() != documents.len() {
277                        return Err(EmbeddingError::ResponseError(
278                            "Response data length does not match input length".into(),
279                        ));
280                    }
281
282                    Ok(response
283                        .data
284                        .into_iter()
285                        .zip(documents.into_iter())
286                        .map(|(embedding, document)| embeddings::Embedding {
287                            document,
288                            vec: embedding.embedding,
289                        })
290                        .collect())
291                }
292                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
293            }
294        } else {
295            Err(EmbeddingError::ProviderError(response.text().await?))
296        }
297    }
298}
299
300impl EmbeddingModel {
301    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
302        Self {
303            client,
304            model: model.to_string(),
305            ndims,
306        }
307    }
308}
309
310// ================================================================
311// OpenAI Completion API
312// ================================================================
/// Model name for `o1-preview`.
pub const O1_PREVIEW: &str = "o1-preview";
/// Model name for `o1-preview-2024-09-12`.
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// Model name for `o1-mini`.
pub const O1_MINI: &str = "o1-mini";
/// Model name for `o1-mini-2024-09-12`.
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// Model name for `gpt-4o`.
pub const GPT_4O: &str = "gpt-4o";
/// Model name for `gpt-4o-2024-05-13`.
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// Model name for `gpt-4-turbo`.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// Model name for `gpt-4-turbo-2024-04-09`.
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// Model name for `gpt-4-turbo-preview`.
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// Model name for `gpt-4-0125-preview`.
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// Model name for `gpt-4-1106-preview`.
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// Model name for `gpt-4-vision-preview`.
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// Model name for `gpt-4-1106-vision-preview`.
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// Model name for `gpt-4`.
pub const GPT_4: &str = "gpt-4";
/// Model name for `gpt-4-0613`.
pub const GPT_4_0613: &str = "gpt-4-0613";
/// Model name for `gpt-4-32k`.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// Model name for `gpt-4-32k-0613`.
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// Model name for `gpt-3.5-turbo`.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Model name for `gpt-3.5-turbo-0125`.
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// Model name for `gpt-3.5-turbo-1106`.
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// Model name for `gpt-3.5-turbo-instruct`.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
355
356#[derive(Debug, Deserialize)]
357pub struct CompletionResponse {
358    pub id: String,
359    pub object: String,
360    pub created: u64,
361    pub model: String,
362    pub system_fingerprint: Option<String>,
363    pub choices: Vec<Choice>,
364    pub usage: Option<Usage>,
365}
366
367impl From<ApiErrorResponse> for CompletionError {
368    fn from(err: ApiErrorResponse) -> Self {
369        CompletionError::ProviderError(err.message)
370    }
371}
372
373impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
374    type Error = CompletionError;
375
376    fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
377        match value.choices.as_slice() {
378            [Choice {
379                message:
380                    Message {
381                        tool_calls: Some(calls),
382                        ..
383                    },
384                ..
385            }, ..] => {
386                let call = calls.first().ok_or(CompletionError::ResponseError(
387                    "Tool selection is empty".into(),
388                ))?;
389
390                Ok(completion::CompletionResponse {
391                    choice: completion::ModelChoice::ToolCall(
392                        call.function.name.clone(),
393                        serde_json::from_str(&call.function.arguments)?,
394                    ),
395                    raw_response: value,
396                })
397            }
398            [Choice {
399                message:
400                    Message {
401                        content: Some(content),
402                        ..
403                    },
404                ..
405            }, ..] => Ok(completion::CompletionResponse {
406                choice: completion::ModelChoice::Message(content.to_string()),
407                raw_response: value,
408            }),
409            _ => Err(CompletionError::ResponseError(
410                "Response did not contain a message or tool call".into(),
411            )),
412        }
413    }
414}
415
416#[derive(Debug, Deserialize)]
417pub struct Choice {
418    pub index: usize,
419    pub message: Message,
420    pub logprobs: Option<serde_json::Value>,
421    pub finish_reason: String,
422}
423
424#[derive(Debug, Deserialize)]
425pub struct Message {
426    pub role: String,
427    pub content: Option<String>,
428    pub tool_calls: Option<Vec<ToolCall>>,
429}
430
431#[derive(Debug, Deserialize)]
432pub struct ToolCall {
433    pub id: String,
434    pub r#type: String,
435    pub function: Function,
436}
437
438#[derive(Clone, Debug, Deserialize, Serialize)]
439pub struct ToolDefinition {
440    pub r#type: String,
441    pub function: completion::ToolDefinition,
442}
443
444impl From<completion::ToolDefinition> for ToolDefinition {
445    fn from(tool: completion::ToolDefinition) -> Self {
446        Self {
447            r#type: "function".into(),
448            function: tool,
449        }
450    }
451}
452
453#[derive(Debug, Deserialize)]
454pub struct Function {
455    pub name: String,
456    pub arguments: String,
457}
458
459#[derive(Clone)]
460pub struct CompletionModel {
461    client: Client,
462    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
463    pub model: String,
464}
465
466impl CompletionModel {
467    pub fn new(client: Client, model: &str) -> Self {
468        Self {
469            client,
470            model: model.to_string(),
471        }
472    }
473}
474
475impl completion::CompletionModel for CompletionModel {
476    type Response = CompletionResponse;
477
478    async fn completion(
479        &self,
480        mut completion_request: CompletionRequest,
481    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
482        // Add preamble to chat history (if available)
483        let mut full_history = if let Some(preamble) = &completion_request.preamble {
484            vec![completion::Message {
485                role: "system".into(),
486                content: preamble.clone(),
487            }]
488        } else {
489            vec![]
490        };
491
492        // Extend existing chat history
493        full_history.append(&mut completion_request.chat_history);
494
495        // Add context documents to chat history
496        let prompt_with_context = completion_request.prompt_with_context();
497
498        // Add context documents to chat history
499        full_history.push(completion::Message {
500            role: "user".into(),
501            content: prompt_with_context,
502        });
503
504        let request = if completion_request.tools.is_empty() {
505            json!({
506                "model": self.model,
507                "messages": full_history,
508                "temperature": completion_request.temperature,
509            })
510        } else {
511            json!({
512                "model": self.model,
513                "messages": full_history,
514                "temperature": completion_request.temperature,
515                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
516                "tool_choice": "auto",
517            })
518        };
519
520        let response = self
521            .client
522            .post("/chat/completions")
523            .json(
524                &if let Some(params) = completion_request.additional_params {
525                    json_utils::merge(request, params)
526                } else {
527                    request
528                },
529            )
530            .send()
531            .await?;
532
533        if response.status().is_success() {
534            match response.json::<ApiResponse<CompletionResponse>>().await? {
535                ApiResponse::Ok(response) => {
536                    tracing::info!(target: "rig",
537                        "OpenAI completion token usage: {:?}",
538                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
539                    );
540                    response.try_into()
541                }
542                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
543            }
544        } else {
545            Err(CompletionError::ProviderError(response.text().await?))
546        }
547    }
548}