1use std::{convert::Infallible, str::FromStr};
12
13use crate::{
14 agent::AgentBuilder,
15 completion::{self, CompletionError, CompletionRequest},
16 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
17 extractor::ExtractorBuilder,
18 json_utils,
19 message::{self, AudioMediaType, ImageDetail},
20 one_or_many::string_or_one_or_many,
21 transcription::{self, TranscriptionError},
22 Embed, OneOrMany,
23};
24use reqwest::multipart::Part;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use serde_json::json;
28
/// Default base URL for the OpenAI REST API.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
33
/// OpenAI API client: a base URL plus a `reqwest` client preconfigured
/// with the `Authorization: Bearer …` header.
#[derive(Clone)]
pub struct Client {
    // Base URL requests are issued against (defaults to `OPENAI_API_BASE_URL`).
    base_url: String,
    // HTTP client carrying the auth header as a default header.
    http_client: reqwest::Client,
}
39
40impl Client {
41 pub fn new(api_key: &str) -> Self {
43 Self::from_url(api_key, OPENAI_API_BASE_URL)
44 }
45
46 pub fn from_url(api_key: &str, base_url: &str) -> Self {
48 Self {
49 base_url: base_url.to_string(),
50 http_client: reqwest::Client::builder()
51 .default_headers({
52 let mut headers = reqwest::header::HeaderMap::new();
53 headers.insert(
54 "Authorization",
55 format!("Bearer {}", api_key)
56 .parse()
57 .expect("Bearer token should parse"),
58 );
59 headers
60 })
61 .build()
62 .expect("OpenAI reqwest client should build"),
63 }
64 }
65
66 pub fn from_env() -> Self {
69 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
70 Self::new(&api_key)
71 }
72
73 fn post(&self, path: &str) -> reqwest::RequestBuilder {
74 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
75 self.http_client.post(url)
76 }
77
78 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
92 let ndims = match model {
93 TEXT_EMBEDDING_3_LARGE => 3072,
94 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
95 _ => 0,
96 };
97 EmbeddingModel::new(self.clone(), model, ndims)
98 }
99
100 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
112 EmbeddingModel::new(self.clone(), model, ndims)
113 }
114
115 pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
132 EmbeddingsBuilder::new(self.embedding_model(model))
133 }
134
135 pub fn completion_model(&self, model: &str) -> CompletionModel {
147 CompletionModel::new(self.clone(), model)
148 }
149
150 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
165 AgentBuilder::new(self.completion_model(model))
166 }
167
168 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
170 &self,
171 model: &str,
172 ) -> ExtractorBuilder<T, CompletionModel> {
173 ExtractorBuilder::new(self.completion_model(model))
174 }
175
176 pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
188 TranscriptionModel::new(self.clone(), model)
189 }
190}
191
/// Error payload returned by the OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
196
/// Either a successful payload of type `T` or an API error payload.
///
/// `untagged`: serde tries each variant in order, so a body that does not
/// parse as `T` falls through to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
203
// Embedding model identifiers published by OpenAI.
/// `text-embedding-3-large` (3072 dimensions by default; see `embedding_model`).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` (1536 dimensions by default).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` (1536 dimensions by default).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
213
/// Response body of the `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
221
222impl From<ApiErrorResponse> for EmbeddingError {
223 fn from(err: ApiErrorResponse) -> Self {
224 EmbeddingError::ProviderError(err.message)
225 }
226}
227
228impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
229 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
230 match value {
231 ApiResponse::Ok(response) => Ok(response),
232 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
233 }
234 }
235}
236
/// A single embedding vector in an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    /// Position of the corresponding input in the request.
    pub index: usize,
}
243
/// Token usage statistics reported by the API.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
249
250impl std::fmt::Display for Usage {
251 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 write!(
253 f,
254 "Prompt tokens: {} Total tokens: {}",
255 self.prompt_tokens, self.total_tokens
256 )
257 }
258}
259
/// An OpenAI embedding model bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    /// Model identifier, e.g. `text-embedding-3-small`.
    pub model: String,
    // Dimension count reported by `ndims()`.
    ndims: usize,
}
266
impl embeddings::EmbeddingModel for EmbeddingModel {
    // Maximum number of documents accepted per embed_texts call.
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed a batch of documents via the `/embeddings` endpoint.
    ///
    /// Returns one `Embedding` per input document, paired in order.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Collect up front: the documents are needed again to pair with results.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let response = self
            .client
            .post("/embeddings")
            .json(&json!({
                "model": self.model,
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenAI embedding token usage: {}",
                        response.usage
                    );

                    // The API must return exactly one embedding per input,
                    // otherwise the zip below would silently mis-pair them.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: pass the raw body through as the provider error.
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}
322
323impl EmbeddingModel {
324 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
325 Self {
326 client,
327 model: model.to_string(),
328 ndims,
329 }
330 }
331}
332
// Chat completion model identifiers published by OpenAI.
pub const O3_MINI: &str = "o3-mini";
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
pub const O1: &str = "o1";
pub const O1_2024_12_17: &str = "o1-2024-12-17";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
pub const O1_MINI: &str = "o1-mini";
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_0613: &str = "gpt-4-0613";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
388
/// Response body of the `/chat/completions` endpoint.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}
399
400impl From<ApiErrorResponse> for CompletionError {
401 fn from(err: ApiErrorResponse) -> Self {
402 CompletionError::ProviderError(err.message)
403 }
404}
405
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Convert the raw API response into the generic completion response,
    /// flattening text, refusals and tool calls of the first choice.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        // Only the first choice is used.
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Refusals are surfaced as plain text content.
                let mut content = content
                    .iter()
                    .map(|c| match c {
                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
                        AssistantContent::Refusal { refusal } => {
                            completion::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                // Tool calls are appended after any text content.
                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            // A non-assistant first message is a malformed completion response.
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // At least one content item (text or tool call) must be present.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        Ok(completion::CompletionResponse {
            choice,
            raw_response: response,
        })
    }
}
461
/// One completion choice within a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
469
/// A chat message in OpenAI's wire format, tagged by `role`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        // Accepts either a bare string or an array of content parts.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // `content` may be a string, an array, or absent — hence `default`.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `tool_calls` may be `null` in responses; treated as an empty vec.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    // Result of a tool invocation, serialized with role "tool".
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<ToolResultContent>,
    },
}
507
508impl Message {
509 pub fn system(content: &str) -> Self {
510 Message::System {
511 content: OneOrMany::one(content.to_owned().into()),
512 name: None,
513 }
514 }
515}
516
/// Reference to an assistant audio response (identifier only).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    id: String,
}
521
/// A content part of a system message (currently always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}
528
/// Type tag for [`SystemContent`]; only `text` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}
535
/// A content part of an assistant message: normal text or a refusal.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}
542
/// A content part of a user message, tagged by `type`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    Audio { input_audio: InputAudio },
}
550
/// Image reference with an optional detail level.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    #[serde(default)]
    pub detail: ImageDetail,
}
557
/// Inline audio input: encoded data plus its media format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    pub data: String,
    pub format: AudioMediaType,
}
563
/// Content of a tool-result ("tool" role) message; text only.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}
570
/// Type tag for [`ToolResultContent`]; only `text` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
577
578impl FromStr for ToolResultContent {
579 type Err = Infallible;
580
581 fn from_str(s: &str) -> Result<Self, Self::Err> {
582 Ok(s.to_owned().into())
583 }
584}
585
586impl From<String> for ToolResultContent {
587 fn from(s: String) -> Self {
588 ToolResultContent {
589 r#type: ToolResultContentType::default(),
590 text: s,
591 }
592 }
593}
594
/// A tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
602
/// Type tag for [`ToolCall`]; only `function` exists today.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
609
/// OpenAI wire representation of a tool: `{"type": "function", "function": {...}}`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
615
616impl From<completion::ToolDefinition> for ToolDefinition {
617 fn from(tool: completion::ToolDefinition) -> Self {
618 Self {
619 r#type: "function".into(),
620 function: tool,
621 }
622 }
623}
624
/// A called function: name plus its JSON arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Arguments travel on the wire as a JSON-encoded string; `stringified_json`
    /// (de)serializes them to and from a `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
631
impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    /// Convert a provider-agnostic message into one or more OpenAI messages.
    ///
    /// A user message containing tool results expands into one `tool`-role
    /// message per result; otherwise it becomes a single user message.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                // Split tool results from the rest; OpenAI requires tool results
                // to be sent as separate "tool" role messages.
                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
                    .into_iter()
                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

                if !tool_results.is_empty() {
                    // NOTE: any non-tool-result content alongside tool results is
                    // dropped on this branch.
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::ToolResult(message::ToolResult {
                                id,
                                content,
                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
                                tool_call_id: id,
                                content: content.try_map(|content| match content {
                                    message::ToolResultContent::Text(message::Text { text }) => {
                                        Ok(text.into())
                                    }
                                    _ => Err(message::MessageError::ConversionError(
                                        "Tool result content does not support non-text".into(),
                                    )),
                                })?,
                            }),
                            // The partition above guarantees only ToolResult here.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let other_content = OneOrMany::many(other_content).expect(
                        "There must be other content here if there were no tool result content",
                    );

                    Ok(vec![Message::User {
                        content: other_content.map(|content| match content {
                            message::UserContent::Text(message::Text { text }) => {
                                UserContent::Text { text }
                            }
                            message::UserContent::Image(message::Image {
                                data, detail, ..
                            }) => UserContent::Image {
                                image_url: ImageUrl {
                                    url: data,
                                    detail: detail.unwrap_or_default(),
                                },
                            },
                            // Documents are flattened into plain text.
                            message::UserContent::Document(message::Document { data, .. }) => {
                                UserContent::Text { text: data }
                            }
                            message::UserContent::Audio(message::Audio {
                                data,
                                media_type,
                                ..
                            }) => UserContent::Audio {
                                input_audio: InputAudio {
                                    data,
                                    // MP3 is used when no media type was specified.
                                    format: match media_type {
                                        Some(media_type) => media_type,
                                        None => AudioMediaType::MP3,
                                    },
                                },
                            },
                            // Tool results were filtered out above.
                            _ => unreachable!(),
                        }),
                        name: None,
                    }])
                }
            }
            message::Message::Assistant { content } => {
                // Separate plain text from tool calls; OpenAI carries them in
                // distinct fields of a single assistant message.
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            message::AssistantContent::Text(text) => texts.push(text),
                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
                        }
                        (texts, tools)
                    },
                );

                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .map(|content| content.text.into())
                        .collect::<Vec<_>>(),
                    refusal: None,
                    audio: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
736
737impl From<message::ToolCall> for ToolCall {
738 fn from(tool_call: message::ToolCall) -> Self {
739 Self {
740 id: tool_call.id,
741 r#type: ToolType::default(),
742 function: Function {
743 name: tool_call.function.name,
744 arguments: tool_call.function.arguments,
745 },
746 }
747 }
748}
749
750impl From<ToolCall> for message::ToolCall {
751 fn from(tool_call: ToolCall) -> Self {
752 Self {
753 id: tool_call.id,
754 function: message::ToolFunction {
755 name: tool_call.function.name,
756 arguments: tool_call.function.arguments,
757 },
758 }
759 }
760}
761
762impl TryFrom<Message> for message::Message {
763 type Error = message::MessageError;
764
765 fn try_from(message: Message) -> Result<Self, Self::Error> {
766 Ok(match message {
767 Message::User { content, .. } => message::Message::User {
768 content: content.map(|content| content.into()),
769 },
770 Message::Assistant {
771 content,
772 tool_calls,
773 ..
774 } => {
775 let mut content = content
776 .into_iter()
777 .map(|content| match content {
778 AssistantContent::Text { text } => message::AssistantContent::text(text),
779
780 AssistantContent::Refusal { refusal } => {
783 message::AssistantContent::text(refusal)
784 }
785 })
786 .collect::<Vec<_>>();
787
788 content.extend(
789 tool_calls
790 .into_iter()
791 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
792 .collect::<Result<Vec<_>, _>>()?,
793 );
794
795 message::Message::Assistant {
796 content: OneOrMany::many(content).map_err(|_| {
797 message::MessageError::ConversionError(
798 "Neither `content` nor `tool_calls` was provided to the Message"
799 .to_owned(),
800 )
801 })?,
802 }
803 }
804
805 Message::ToolResult {
806 tool_call_id,
807 content,
808 } => message::Message::User {
809 content: OneOrMany::one(message::UserContent::tool_result(
810 tool_call_id,
811 content.map(|content| message::ToolResultContent::text(content.text)),
812 )),
813 },
814
815 Message::System { content, .. } => message::Message::User {
818 content: content.map(|content| message::UserContent::text(content.text)),
819 },
820 })
821 }
822}
823
impl From<UserContent> for message::UserContent {
    /// Convert OpenAI user content parts into the provider-agnostic form.
    fn from(content: UserContent) -> Self {
        match content {
            UserContent::Text { text } => message::UserContent::text(text),
            // The URL field carries the image reference; media type is unknown here.
            UserContent::Image { image_url } => message::UserContent::image(
                image_url.url,
                Some(message::ContentFormat::default()),
                None,
                Some(image_url.detail),
            ),
            UserContent::Audio { input_audio } => message::UserContent::audio(
                input_audio.data,
                Some(message::ContentFormat::default()),
                Some(input_audio.format),
            ),
        }
    }
}
842
843impl From<String> for UserContent {
844 fn from(s: String) -> Self {
845 UserContent::Text { text: s }
846 }
847}
848
849impl FromStr for UserContent {
850 type Err = Infallible;
851
852 fn from_str(s: &str) -> Result<Self, Self::Err> {
853 Ok(UserContent::Text {
854 text: s.to_string(),
855 })
856 }
857}
858
859impl From<String> for AssistantContent {
860 fn from(s: String) -> Self {
861 AssistantContent::Text { text: s }
862 }
863}
864
865impl FromStr for AssistantContent {
866 type Err = Infallible;
867
868 fn from_str(s: &str) -> Result<Self, Self::Err> {
869 Ok(AssistantContent::Text {
870 text: s.to_string(),
871 })
872 }
873}
874impl From<String> for SystemContent {
875 fn from(s: String) -> Self {
876 SystemContent {
877 r#type: SystemContentType::default(),
878 text: s,
879 }
880 }
881}
882
883impl FromStr for SystemContent {
884 type Err = Infallible;
885
886 fn from_str(s: &str) -> Result<Self, Self::Err> {
887 Ok(SystemContent {
888 r#type: SystemContentType::default(),
889 text: s.to_string(),
890 })
891 }
892}
893
/// An OpenAI chat completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier, e.g. `gpt-4o`.
    pub model: String,
}
900
901impl CompletionModel {
902 pub fn new(client: Client, model: &str) -> Self {
903 Self {
904 client,
905 model: model.to_string(),
906 }
907 }
908}
909
910impl completion::CompletionModel for CompletionModel {
911 type Response = CompletionResponse;
912
913 #[cfg_attr(feature = "worker", worker::send)]
914 async fn completion(
915 &self,
916 completion_request: CompletionRequest,
917 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
918 let mut full_history: Vec<Message> = match &completion_request.preamble {
920 Some(preamble) => vec![Message::system(preamble)],
921 None => vec![],
922 };
923
924 let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
926
927 let chat_history: Vec<Message> = completion_request
929 .chat_history
930 .into_iter()
931 .map(|message| message.try_into())
932 .collect::<Result<Vec<Vec<Message>>, _>>()?
933 .into_iter()
934 .flatten()
935 .collect();
936
937 full_history.extend(chat_history);
939 full_history.extend(prompt);
940
941 let request = if completion_request.tools.is_empty() {
942 json!({
943 "model": self.model,
944 "messages": full_history,
945
946 })
947 } else {
948 json!({
949 "model": self.model,
950 "messages": full_history,
951 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
952 "tool_choice": "auto",
953 })
954 };
955
956 let request = if let Some(temperature) = completion_request.temperature {
959 json_utils::merge(
960 request,
961 json!({
962 "temperature": temperature,
963 }),
964 )
965 } else {
966 request
967 };
968
969 let response = self
970 .client
971 .post("/chat/completions")
972 .json(
973 &if let Some(params) = completion_request.additional_params {
974 json_utils::merge(request, params)
975 } else {
976 request
977 },
978 )
979 .send()
980 .await?;
981
982 if response.status().is_success() {
983 let t = response.text().await?;
984 tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
985
986 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
987 ApiResponse::Ok(response) => {
988 tracing::info!(target: "rig",
989 "OpenAI completion token usage: {:?}",
990 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
991 );
992 response.try_into()
993 }
994 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
995 }
996 } else {
997 Err(CompletionError::ProviderError(response.text().await?))
998 }
999 }
1000}
1001
/// `whisper-1` transcription model identifier.
pub const WHISPER_1: &str = "whisper-1";
1006
/// Response body of the `audio/transcriptions` endpoint.
#[derive(Debug, Deserialize)]
pub struct TranscriptionResponse {
    pub text: String,
}
1011
1012impl TryFrom<TranscriptionResponse>
1013 for transcription::TranscriptionResponse<TranscriptionResponse>
1014{
1015 type Error = TranscriptionError;
1016
1017 fn try_from(value: TranscriptionResponse) -> Result<Self, Self::Error> {
1018 Ok(transcription::TranscriptionResponse {
1019 text: value.text.clone(),
1020 response: value,
1021 })
1022 }
1023}
1024
/// An OpenAI transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Model identifier, e.g. `whisper-1`.
    pub model: String,
}
1031
1032impl TranscriptionModel {
1033 pub fn new(client: Client, model: &str) -> Self {
1034 Self {
1035 client,
1036 model: model.to_string(),
1037 }
1038 }
1039}
1040impl transcription::TranscriptionModel for TranscriptionModel {
1041 type Response = TranscriptionResponse;
1042
1043 #[cfg_attr(feature = "worker", worker::send)]
1044 async fn transcription(
1045 &self,
1046 request: transcription::TranscriptionRequest,
1047 ) -> Result<
1048 transcription::TranscriptionResponse<Self::Response>,
1049 transcription::TranscriptionError,
1050 > {
1051 let data = request.data;
1052
1053 let mut body = reqwest::multipart::Form::new()
1054 .text("model", self.model.clone())
1055 .text("language", request.language)
1056 .part(
1057 "file",
1058 Part::bytes(data).file_name(request.filename.clone()),
1059 );
1060
1061 if let Some(prompt) = request.prompt {
1062 body = body.text("prompt", prompt.clone());
1063 }
1064
1065 if let Some(ref temperature) = request.temperature {
1066 body = body.text("temperature", temperature.to_string());
1067 }
1068
1069 if let Some(ref additional_params) = request.additional_params {
1070 for (key, value) in additional_params
1071 .as_object()
1072 .expect("Additional Parameters to OpenAI Transcription should be a map")
1073 {
1074 body = body.text(key.to_owned(), value.to_string());
1075 }
1076 }
1077
1078 let response = self
1079 .client
1080 .post("audio/transcriptions")
1081 .multipart(body)
1082 .send()
1083 .await?;
1084
1085 if response.status().is_success() {
1086 match response
1087 .json::<ApiResponse<TranscriptionResponse>>()
1088 .await?
1089 {
1090 ApiResponse::Ok(response) => response.try_into(),
1091 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
1092 api_error_response.message,
1093 )),
1094 }
1095 } else {
1096 Err(TranscriptionError::ProviderError(response.text().await?))
1097 }
1098 }
1099}
1100
// Unit tests: wire-format deserialization and round-trip conversion between
// the OpenAI `Message` type and the provider-agnostic `message::Message`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    #[test]
    fn test_deserialize_message() {
        // Assistant message with plain string content.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with array content and an explicit null `tool_calls`.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Tool-call-only assistant message; `content` and `refusal` are null.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text, image and audio content parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        // `serde_path_to_error` pinpoints the failing JSON path on failure.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
                &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // Null `tool_calls` deserializes to an empty vec.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                // The stringified arguments parse into a JSON object.
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    // Generic -> OpenAI -> generic round trip preserves the message.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    // OpenAI -> generic -> OpenAI round trip preserves the message.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}