1use crate::{
12 agent::AgentBuilder,
13 completion::{self, CompletionError, CompletionRequest},
14 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
15 extractor::ExtractorBuilder,
16 json_utils, Embed,
17};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use serde_json::json;
21
/// Base URL used by `Client::new` when the caller does not supply one.
22const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
26
/// Handle used to issue HTTP requests against an OpenAI-style API.
///
/// `Clone` is derived; the model types in this file each store their own copy.
27#[derive(Clone)]
28pub struct Client {
    /// Base URL that request paths are joined onto in `post`.
29 base_url: String,
    /// HTTP client built in `from_url` with a default `Authorization` header.
30 http_client: reqwest::Client,
31}
32
33impl Client {
34 pub fn new(api_key: &str) -> Self {
36 Self::from_url(api_key, OPENAI_API_BASE_URL)
37 }
38
39 pub fn from_url(api_key: &str, base_url: &str) -> Self {
41 Self {
42 base_url: base_url.to_string(),
43 http_client: reqwest::Client::builder()
44 .default_headers({
45 let mut headers = reqwest::header::HeaderMap::new();
46 headers.insert(
47 "Authorization",
48 format!("Bearer {}", api_key)
49 .parse()
50 .expect("Bearer token should parse"),
51 );
52 headers
53 })
54 .build()
55 .expect("OpenAI reqwest client should build"),
56 }
57 }
58
59 pub fn from_env() -> Self {
62 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
63 Self::new(&api_key)
64 }
65
66 fn post(&self, path: &str) -> reqwest::RequestBuilder {
67 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
68 self.http_client.post(url)
69 }
70
71 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
85 let ndims = match model {
86 TEXT_EMBEDDING_3_LARGE => 3072,
87 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
88 _ => 0,
89 };
90 EmbeddingModel::new(self.clone(), model, ndims)
91 }
92
93 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
105 EmbeddingModel::new(self.clone(), model, ndims)
106 }
107
108 pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
125 EmbeddingsBuilder::new(self.embedding_model(model))
126 }
127
128 pub fn completion_model(&self, model: &str) -> CompletionModel {
140 CompletionModel::new(self.clone(), model)
141 }
142
143 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
158 AgentBuilder::new(self.completion_model(model))
159 }
160
161 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
163 &self,
164 model: &str,
165 ) -> ExtractorBuilder<T, CompletionModel> {
166 ExtractorBuilder::new(self.completion_model(model))
167 }
168}
169
/// Error shape that a response body may deserialize into.
170#[derive(Debug, Deserialize)]
171struct ApiErrorResponse {
    /// Error text; wrapped into `ProviderError` variants elsewhere in this file.
172 message: String,
173}
174
/// A response body that is either the expected payload `T` or an
/// `ApiErrorResponse`.
///
/// The enum is `#[serde(untagged)]`, so the variant is chosen by which shape
/// deserializes, with `Ok(T)` declared (and therefore tried) first.
175#[derive(Debug, Deserialize)]
176#[serde(untagged)]
177enum ApiResponse<T> {
    /// The body matched the expected payload type.
178 Ok(T),
    /// The body matched the error shape instead.
179 Err(ApiErrorResponse),
180}
181
/// Model name `text-embedding-3-large`; `Client::embedding_model` maps it to 3072 dimensions.
182pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Model name `text-embedding-3-small`; `Client::embedding_model` maps it to 1536 dimensions.
187pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Model name `text-embedding-ada-002`; `Client::embedding_model` maps it to 1536 dimensions.
189pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
191
/// Successful body of an embeddings request, as deserialized in `embed_texts`.
192#[derive(Debug, Deserialize)]
193pub struct EmbeddingResponse {
    /// Value of the `object` field of the response body.
194 pub object: String,
    /// Returned embedding entries; `embed_texts` requires one per input document.
195 pub data: Vec<EmbeddingData>,
    /// Value of the `model` field of the response body.
196 pub model: String,
    /// Token usage for the request; logged by `embed_texts`.
197 pub usage: Usage,
198}
199
200impl From<ApiErrorResponse> for EmbeddingError {
201 fn from(err: ApiErrorResponse) -> Self {
202 EmbeddingError::ProviderError(err.message)
203 }
204}
205
206impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
207 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
208 match value {
209 ApiResponse::Ok(response) => Ok(response),
210 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
211 }
212 }
213}
214
/// One entry of `EmbeddingResponse::data`.
215#[derive(Debug, Deserialize)]
216pub struct EmbeddingData {
    /// Value of the `object` field of this entry.
217 pub object: String,
    /// The embedding vector; becomes `Embedding::vec` in `embed_texts`.
218 pub embedding: Vec<f64>,
    /// Value of the `index` field of this entry.
    // NOTE(review): presumably the position of the matching input; `embed_texts`
    // pairs entries with inputs by order and does not read this field.
219 pub index: usize,
220}
221
/// Token counts reported in a response; rendered through its `Display` impl
/// in the token-usage log lines.
222#[derive(Clone, Debug, Deserialize)]
223pub struct Usage {
    /// Value of the `prompt_tokens` field.
224 pub prompt_tokens: usize,
    /// Value of the `total_tokens` field.
225 pub total_tokens: usize,
226}
227
228impl std::fmt::Display for Usage {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 write!(
231 f,
232 "Prompt tokens: {} Total tokens: {}",
233 self.prompt_tokens, self.total_tokens
234 )
235 }
236}
237
/// Embedding model handle: a `Client` plus the model name and vector size.
238#[derive(Clone)]
239pub struct EmbeddingModel {
    /// Client used to send the embeddings request.
240 client: Client,
    /// Model name sent as the `model` field of the request.
241 pub model: String,
    /// Dimension count returned by `ndims()`.
242 ndims: usize,
243}
244
245impl embeddings::EmbeddingModel for EmbeddingModel {
246 const MAX_DOCUMENTS: usize = 1024;
247
248 fn ndims(&self) -> usize {
249 self.ndims
250 }
251
252 async fn embed_texts(
253 &self,
254 documents: impl IntoIterator<Item = String>,
255 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
256 let documents = documents.into_iter().collect::<Vec<_>>();
257
258 let response = self
259 .client
260 .post("/embeddings")
261 .json(&json!({
262 "model": self.model,
263 "input": documents,
264 }))
265 .send()
266 .await?;
267
268 if response.status().is_success() {
269 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
270 ApiResponse::Ok(response) => {
271 tracing::info!(target: "rig",
272 "OpenAI embedding token usage: {}",
273 response.usage
274 );
275
276 if response.data.len() != documents.len() {
277 return Err(EmbeddingError::ResponseError(
278 "Response data length does not match input length".into(),
279 ));
280 }
281
282 Ok(response
283 .data
284 .into_iter()
285 .zip(documents.into_iter())
286 .map(|(embedding, document)| embeddings::Embedding {
287 document,
288 vec: embedding.embedding,
289 })
290 .collect())
291 }
292 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
293 }
294 } else {
295 Err(EmbeddingError::ProviderError(response.text().await?))
296 }
297 }
298}
299
300impl EmbeddingModel {
301 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
302 Self {
303 client,
304 model: model.to_string(),
305 ndims,
306 }
307 }
308}
309
/// Model name `o1-preview`.
310pub const O1_PREVIEW: &str = "o1-preview";
/// Model name `o1-preview-2024-09-12`.
315pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// Model name `o1-mini`.
317pub const O1_MINI: &str = "o1-mini";
/// Model name `o1-mini-2024-09-12`.
319pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// Model name `gpt-4o`.
321pub const GPT_4O: &str = "gpt-4o";
/// Model name `gpt-4o-2024-05-13`.
323pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// Model name `gpt-4-turbo`.
325pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// Model name `gpt-4-turbo-2024-04-09`.
327pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// Model name `gpt-4-turbo-preview`.
329pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// Model name `gpt-4-0125-preview`.
331pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// Model name `gpt-4-1106-preview`.
333pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// Model name `gpt-4-vision-preview`.
335pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// Model name `gpt-4-1106-vision-preview`.
337pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// Model name `gpt-4`.
339pub const GPT_4: &str = "gpt-4";
/// Model name `gpt-4-0613`.
341pub const GPT_4_0613: &str = "gpt-4-0613";
/// Model name `gpt-4-32k`.
343pub const GPT_4_32K: &str = "gpt-4-32k";
/// Model name `gpt-4-32k-0613`.
345pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// Model name `gpt-3.5-turbo`.
347pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Model name `gpt-3.5-turbo-0125`.
349pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// Model name `gpt-3.5-turbo-1106`.
351pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// Model name `gpt-3.5-turbo-instruct`.
353pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
355
/// Successful body of a `/chat/completions` request, as deserialized in
/// `CompletionModel::completion`.
356#[derive(Debug, Deserialize)]
357pub struct CompletionResponse {
    /// Value of the `id` field.
358 pub id: String,
    /// Value of the `object` field.
359 pub object: String,
    /// Value of the `created` field.
    // NOTE(review): presumably a Unix timestamp — confirm against the API docs.
360 pub created: u64,
    /// Value of the `model` field.
361 pub model: String,
    /// Value of the `system_fingerprint` field, when present.
362 pub system_fingerprint: Option<String>,
    /// Returned choices; the `TryFrom` conversion in this file reads only the first.
363 pub choices: Vec<Choice>,
    /// Token usage, when present; logged by `CompletionModel::completion`.
364 pub usage: Option<Usage>,
365}
366
367impl From<ApiErrorResponse> for CompletionError {
368 fn from(err: ApiErrorResponse) -> Self {
369 CompletionError::ProviderError(err.message)
370 }
371}
372
373impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
374 type Error = CompletionError;
375
376 fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
377 match value.choices.as_slice() {
378 [Choice {
379 message:
380 Message {
381 tool_calls: Some(calls),
382 ..
383 },
384 ..
385 }, ..] => {
386 let call = calls.first().ok_or(CompletionError::ResponseError(
387 "Tool selection is empty".into(),
388 ))?;
389
390 Ok(completion::CompletionResponse {
391 choice: completion::ModelChoice::ToolCall(
392 call.function.name.clone(),
393 serde_json::from_str(&call.function.arguments)?,
394 ),
395 raw_response: value,
396 })
397 }
398 [Choice {
399 message:
400 Message {
401 content: Some(content),
402 ..
403 },
404 ..
405 }, ..] => Ok(completion::CompletionResponse {
406 choice: completion::ModelChoice::Message(content.to_string()),
407 raw_response: value,
408 }),
409 _ => Err(CompletionError::ResponseError(
410 "Response did not contain a message or tool call".into(),
411 )),
412 }
413 }
414}
415
/// One entry of `CompletionResponse::choices`.
416#[derive(Debug, Deserialize)]
417pub struct Choice {
    /// Value of the `index` field.
418 pub index: usize,
    /// The message carried by this choice; inspected by the `TryFrom` conversion.
419 pub message: Message,
    /// Value of the `logprobs` field, kept as untyped JSON when present.
420 pub logprobs: Option<serde_json::Value>,
    /// Value of the `finish_reason` field.
421 pub finish_reason: String,
422}
423
/// Message carried by a `Choice`.
424#[derive(Debug, Deserialize)]
425pub struct Message {
    /// Value of the `role` field.
426 pub role: String,
    /// Text content, when present.
427 pub content: Option<String>,
    /// Tool calls, when present.
428 pub tool_calls: Option<Vec<ToolCall>>,
429}
430
/// One entry of `Message::tool_calls`.
431#[derive(Debug, Deserialize)]
432pub struct ToolCall {
    /// Value of the `id` field.
433 pub id: String,
    /// Value of the `type` field (raw identifier because `type` is a keyword).
434 pub r#type: String,
    /// Name and argument string of the call.
435 pub function: Function,
436}
437
/// Wrapper that pairs a `completion::ToolDefinition` with a `type` string;
/// serialized into the `tools` array of a completion request.
438#[derive(Clone, Debug, Deserialize, Serialize)]
439pub struct ToolDefinition {
    /// Value of the `type` field; the `From` impl in this file sets it to `"function"`.
440 pub r#type: String,
    /// The wrapped tool definition.
441 pub function: completion::ToolDefinition,
442}
443
444impl From<completion::ToolDefinition> for ToolDefinition {
445 fn from(tool: completion::ToolDefinition) -> Self {
446 Self {
447 r#type: "function".into(),
448 function: tool,
449 }
450 }
451}
452
/// Function part of a `ToolCall`.
453#[derive(Debug, Deserialize)]
454pub struct Function {
    /// Name of the function to call.
455 pub name: String,
    /// Arguments as a string; parsed as JSON by the `TryFrom` conversion in this file.
456 pub arguments: String,
457}
458
/// Completion model handle: a `Client` plus the model name to request.
459#[derive(Clone)]
460pub struct CompletionModel {
    /// Client used to send the completion request.
461 client: Client,
    /// Model name sent as the `model` field of the request.
462 pub model: String,
464}
465
466impl CompletionModel {
467 pub fn new(client: Client, model: &str) -> Self {
468 Self {
469 client,
470 model: model.to_string(),
471 }
472 }
473}
474
475impl completion::CompletionModel for CompletionModel {
476 type Response = CompletionResponse;
477
478 async fn completion(
479 &self,
480 mut completion_request: CompletionRequest,
481 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
482 let mut full_history = if let Some(preamble) = &completion_request.preamble {
484 vec![completion::Message {
485 role: "system".into(),
486 content: preamble.clone(),
487 }]
488 } else {
489 vec![]
490 };
491
492 full_history.append(&mut completion_request.chat_history);
494
495 let prompt_with_context = completion_request.prompt_with_context();
497
498 full_history.push(completion::Message {
500 role: "user".into(),
501 content: prompt_with_context,
502 });
503
504 let request = if completion_request.tools.is_empty() {
505 json!({
506 "model": self.model,
507 "messages": full_history,
508 "temperature": completion_request.temperature,
509 })
510 } else {
511 json!({
512 "model": self.model,
513 "messages": full_history,
514 "temperature": completion_request.temperature,
515 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
516 "tool_choice": "auto",
517 })
518 };
519
520 let response = self
521 .client
522 .post("/chat/completions")
523 .json(
524 &if let Some(params) = completion_request.additional_params {
525 json_utils::merge(request, params)
526 } else {
527 request
528 },
529 )
530 .send()
531 .await?;
532
533 if response.status().is_success() {
534 match response.json::<ApiResponse<CompletionResponse>>().await? {
535 ApiResponse::Ok(response) => {
536 tracing::info!(target: "rig",
537 "OpenAI completion token usage: {:?}",
538 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
539 );
540 response.try_into()
541 }
542 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
543 }
544 } else {
545 Err(CompletionError::ProviderError(response.text().await?))
546 }
547 }
548}