1use crate::json_utils::merge;
12use crate::streaming::StreamingResult;
13use crate::{
14 agent::AgentBuilder,
15 completion::{self, CompletionError, CompletionRequest},
16 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
17 extractor::ExtractorBuilder,
18 json_utils,
19 message::{self, AudioMediaType, ImageDetail},
20 one_or_many::string_or_one_or_many,
21 streaming,
22 streaming::StreamingCompletionModel,
23 transcription::{self, TranscriptionError},
24 Embed, OneOrMany,
25};
26use async_stream::stream;
27use futures::StreamExt;
28use reqwest::multipart::Part;
29use reqwest::RequestBuilder;
30use schemars::JsonSchema;
31use serde::{Deserialize, Serialize};
32use serde_json::{json, Value};
33use std::collections::HashMap;
34use std::{convert::Infallible, str::FromStr};
35
// Default base URL for the OpenAI REST API (v1).
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
40
/// OpenAI API client: the configured base URL plus a `reqwest` client
/// pre-loaded with the `Authorization: Bearer …` default header.
#[derive(Clone)]
pub struct Client {
    /// API root, e.g. `https://api.openai.com/v1`; request paths are joined onto it.
    base_url: String,
    http_client: reqwest::Client,
}
46
47impl Client {
48 pub fn new(api_key: &str) -> Self {
50 Self::from_url(api_key, OPENAI_API_BASE_URL)
51 }
52
53 pub fn from_url(api_key: &str, base_url: &str) -> Self {
55 Self {
56 base_url: base_url.to_string(),
57 http_client: reqwest::Client::builder()
58 .default_headers({
59 let mut headers = reqwest::header::HeaderMap::new();
60 headers.insert(
61 "Authorization",
62 format!("Bearer {}", api_key)
63 .parse()
64 .expect("Bearer token should parse"),
65 );
66 headers
67 })
68 .build()
69 .expect("OpenAI reqwest client should build"),
70 }
71 }
72
73 pub fn from_env() -> Self {
76 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
77 Self::new(&api_key)
78 }
79
80 fn post(&self, path: &str) -> reqwest::RequestBuilder {
81 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
82 self.http_client.post(url)
83 }
84
85 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
99 let ndims = match model {
100 TEXT_EMBEDDING_3_LARGE => 3072,
101 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
102 _ => 0,
103 };
104 EmbeddingModel::new(self.clone(), model, ndims)
105 }
106
107 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
119 EmbeddingModel::new(self.clone(), model, ndims)
120 }
121
122 pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
139 EmbeddingsBuilder::new(self.embedding_model(model))
140 }
141
142 pub fn completion_model(&self, model: &str) -> CompletionModel {
154 CompletionModel::new(self.clone(), model)
155 }
156
157 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
172 AgentBuilder::new(self.completion_model(model))
173 }
174
175 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
177 &self,
178 model: &str,
179 ) -> ExtractorBuilder<T, CompletionModel> {
180 ExtractorBuilder::new(self.completion_model(model))
181 }
182
183 pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
195 TranscriptionModel::new(self.clone(), model)
196 }
197}
198
/// Error payload returned by the OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper over a response body: either the expected payload `T`
/// or an API error. Serde attempts the variants in declaration order, so
/// `Ok(T)` is tried before the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
210
/// `text-embedding-3-large` model (3072 dimensions — see `Client::embedding_model`).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` model (1536 dimensions).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` model (1536 dimensions).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
220
/// Response body of the OpenAI `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    /// One entry per input document.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
228
229impl From<ApiErrorResponse> for EmbeddingError {
230 fn from(err: ApiErrorResponse) -> Self {
231 EmbeddingError::ProviderError(err.message)
232 }
233}
234
235impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
236 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
237 match value {
238 ApiResponse::Ok(response) => Ok(response),
239 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
240 }
241 }
242}
243
/// One embedding entry in an `/embeddings` response.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    /// Position of this embedding relative to the request inputs.
    pub index: usize,
}

/// Token usage statistics reported by the API.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
256
257impl std::fmt::Display for Usage {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 write!(
260 f,
261 "Prompt tokens: {} Total tokens: {}",
262 self.prompt_tokens, self.total_tokens
263 )
264 }
265}
266
/// Handle for an OpenAI embedding model bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    /// Dimensionality of the vectors this model produces (0 if unknown).
    ndims: usize,
}
273
impl embeddings::EmbeddingModel for EmbeddingModel {
    /// Maximum number of documents sent in a single embeddings request.
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed all `documents` in one `/embeddings` API call, pairing each input
    /// string with the vector returned for it.
    ///
    /// # Errors
    /// Returns an `EmbeddingError` if the HTTP call fails, the API returns an
    /// error payload, or the response length does not match the input length.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize the iterator: the length is needed for validation and
        // the strings are zipped back with the returned vectors below.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let response = self
            .client
            .post("/embeddings")
            .json(&json!({
                "model": self.model,
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenAI embedding token usage: {}",
                        response.usage
                    );

                    // Guard against a truncated or mismatched response before zipping.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // NOTE(review): assumes the API returns embeddings in input
                    // order — the `index` field is not consulted; confirm.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx status: surface the raw body as a provider error.
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}
329
330impl EmbeddingModel {
331 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
332 Self {
333 client,
334 model: model.to_string(),
335 ndims,
336 }
337 }
338}
339
// ---------------------------------------------------------------------------
// OpenAI chat-completion model identifiers (dated aliases pin a snapshot).
// ---------------------------------------------------------------------------
pub const O3_MINI: &str = "o3-mini";
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
pub const O1: &str = "o1";
pub const O1_2024_12_17: &str = "o1-2024-12-17";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
pub const O1_MINI: &str = "o1-mini";
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_0613: &str = "gpt-4-0613";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
400
/// Raw response body of the OpenAI `/chat/completions` endpoint.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    /// Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    /// Absent on some responses, hence `Option`.
    pub usage: Option<Usage>,
}
411
412impl From<ApiErrorResponse> for CompletionError {
413 fn from(err: ApiErrorResponse) -> Self {
414 CompletionError::ProviderError(err.message)
415 }
416}
417
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Convert a raw OpenAI chat-completion response into rig's
    /// provider-agnostic response, keeping the raw payload alongside.
    ///
    /// Only the first choice is consumed. Empty text/refusal fragments are
    /// dropped; tool calls are appended after the text content.
    ///
    /// # Errors
    /// Fails if the response has no choices, the first choice is not an
    /// assistant message, or it carries neither content nor tool calls.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Keep only non-empty fragments; refusals are surfaced as text.
                let mut content = content
                    .iter()
                    .filter_map(|c| {
                        let s = match c {
                            AssistantContent::Text { text } => text,
                            AssistantContent::Refusal { refusal } => refusal,
                        };
                        if s.is_empty() {
                            None
                        } else {
                            Some(completion::AssistantContent::text(s))
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // `OneOrMany::many` rejects an empty vec, enforcing "at least one item".
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        Ok(completion::CompletionResponse {
            choice,
            raw_response: response,
        })
    }
}
478
/// One candidate completion within a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    /// Kept as raw JSON; not interpreted by this module.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
486
/// An OpenAI chat message, tagged by `role` on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `system` role; content may arrive as a bare string or a list.
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `user` role (text, images and/or audio).
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `assistant` role; may carry text and/or tool calls.
    Assistant {
        // May be empty (pure tool-call messages); accepts a bare string or a
        // list on deserialization.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `null` on the wire deserializes to an empty vec.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// `tool` role: the result of a prior tool call.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: OneOrMany<ToolResultContent>,
    },
}
524
525impl Message {
526 pub fn system(content: &str) -> Self {
527 Message::System {
528 content: OneOrMany::one(content.to_owned().into()),
529 name: None,
530 }
531 }
532}
533
/// Reference to an audio response generated earlier by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
    id: String,
}

/// One content part of a system message (always text today).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
    #[serde(default)]
    r#type: SystemContentType,
    text: String,
}

/// Discriminator for [`SystemContent`]; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
    #[default]
    Text,
}

/// One content part of an assistant message: plain text or a refusal.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Refusal { refusal: String },
}

/// One content part of a user message: text, an image URL, or inline audio.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
    Image { image_url: ImageUrl },
    Audio { input_audio: InputAudio },
}

/// Image reference for user content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
    /// Requested image fidelity; defaults per [`ImageDetail`].
    #[serde(default)]
    pub detail: ImageDetail,
}

/// Inline audio payload for user content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
    /// Audio bytes — presumably base64-encoded per the OpenAI API; confirm.
    pub data: String,
    pub format: AudioMediaType,
}

/// One content part of a tool-result message (always text today).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}

/// Discriminator for [`ToolResultContent`]; only `text` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
594
595impl FromStr for ToolResultContent {
596 type Err = Infallible;
597
598 fn from_str(s: &str) -> Result<Self, Self::Err> {
599 Ok(s.to_owned().into())
600 }
601}
602
603impl From<String> for ToolResultContent {
604 fn from(s: String) -> Self {
605 ToolResultContent {
606 r#type: ToolResultContentType::default(),
607 text: s,
608 }
609 }
610}
611
/// A tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Kind of tool being called; only `function` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// Wire format for advertising a tool to the API (`{"type": "function", "function": {…}}`).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
632
633impl From<completion::ToolDefinition> for ToolDefinition {
634 fn from(tool: completion::ToolDefinition) -> Self {
635 Self {
636 r#type: "function".into(),
637 function: tool,
638 }
639 }
640}
641
/// Function name/arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // On the wire the arguments are a JSON-encoded *string*; this adapter
    // (de)serializes them to/from a structured `Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
648
impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    /// Convert a provider-agnostic rig message into one or more OpenAI wire
    /// messages. A single rig user message can fan out into several messages
    /// because OpenAI represents each tool result as its own `tool`-role
    /// message.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                // Split tool results from regular content: they use different roles.
                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
                    .into_iter()
                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

                if !tool_results.is_empty() {
                    // NOTE(review): when tool results are present, any other user
                    // content in the same message is discarded — confirm intended.
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::ToolResult(message::ToolResult {
                                id,
                                content,
                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
                                tool_call_id: id,
                                content: content.try_map(|content| match content {
                                    message::ToolResultContent::Text(message::Text { text }) => {
                                        Ok(text.into())
                                    }
                                    _ => Err(message::MessageError::ConversionError(
                                        "Tool result content does not support non-text".into(),
                                    )),
                                })?,
                            }),
                            // The partition above guarantees only ToolResult variants here.
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let other_content = OneOrMany::many(other_content).expect(
                        "There must be other content here if there were no tool result content",
                    );

                    Ok(vec![Message::User {
                        content: other_content.map(|content| match content {
                            message::UserContent::Text(message::Text { text }) => {
                                UserContent::Text { text }
                            }
                            message::UserContent::Image(message::Image {
                                data, detail, ..
                            }) => UserContent::Image {
                                image_url: ImageUrl {
                                    url: data,
                                    detail: detail.unwrap_or_default(),
                                },
                            },
                            // Documents have no OpenAI content type; the text is sent inline.
                            message::UserContent::Document(message::Document { data, .. }) => {
                                UserContent::Text { text: data }
                            }
                            message::UserContent::Audio(message::Audio {
                                data,
                                media_type,
                                ..
                            }) => UserContent::Audio {
                                input_audio: InputAudio {
                                    data,
                                    // NOTE(review): a missing media type defaults
                                    // to MP3 — confirm that is the right fallback.
                                    format: match media_type {
                                        Some(media_type) => media_type,
                                        None => AudioMediaType::MP3,
                                    },
                                },
                            },
                            // Tool results were filtered into the other partition.
                            _ => unreachable!(),
                        }),
                        name: None,
                    }])
                }
            }
            message::Message::Assistant { content } => {
                // Separate plain text from tool calls; OpenAI carries them in
                // distinct fields of a single assistant message.
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            message::AssistantContent::Text(text) => texts.push(text),
                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
                        }
                        (texts, tools)
                    },
                );

                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .map(|content| content.text.into())
                        .collect::<Vec<_>>(),
                    refusal: None,
                    audio: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
753
754impl From<message::ToolCall> for ToolCall {
755 fn from(tool_call: message::ToolCall) -> Self {
756 Self {
757 id: tool_call.id,
758 r#type: ToolType::default(),
759 function: Function {
760 name: tool_call.function.name,
761 arguments: tool_call.function.arguments,
762 },
763 }
764 }
765}
766
767impl From<ToolCall> for message::ToolCall {
768 fn from(tool_call: ToolCall) -> Self {
769 Self {
770 id: tool_call.id,
771 function: message::ToolFunction {
772 name: tool_call.function.name,
773 arguments: tool_call.function.arguments,
774 },
775 }
776 }
777}
778
779impl TryFrom<Message> for message::Message {
780 type Error = message::MessageError;
781
782 fn try_from(message: Message) -> Result<Self, Self::Error> {
783 Ok(match message {
784 Message::User { content, .. } => message::Message::User {
785 content: content.map(|content| content.into()),
786 },
787 Message::Assistant {
788 content,
789 tool_calls,
790 ..
791 } => {
792 let mut content = content
793 .into_iter()
794 .map(|content| match content {
795 AssistantContent::Text { text } => message::AssistantContent::text(text),
796
797 AssistantContent::Refusal { refusal } => {
800 message::AssistantContent::text(refusal)
801 }
802 })
803 .collect::<Vec<_>>();
804
805 content.extend(
806 tool_calls
807 .into_iter()
808 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
809 .collect::<Result<Vec<_>, _>>()?,
810 );
811
812 message::Message::Assistant {
813 content: OneOrMany::many(content).map_err(|_| {
814 message::MessageError::ConversionError(
815 "Neither `content` nor `tool_calls` was provided to the Message"
816 .to_owned(),
817 )
818 })?,
819 }
820 }
821
822 Message::ToolResult {
823 tool_call_id,
824 content,
825 } => message::Message::User {
826 content: OneOrMany::one(message::UserContent::tool_result(
827 tool_call_id,
828 content.map(|content| message::ToolResultContent::text(content.text)),
829 )),
830 },
831
832 Message::System { content, .. } => message::Message::User {
835 content: content.map(|content| message::UserContent::text(content.text)),
836 },
837 })
838 }
839}
840
841impl From<UserContent> for message::UserContent {
842 fn from(content: UserContent) -> Self {
843 match content {
844 UserContent::Text { text } => message::UserContent::text(text),
845 UserContent::Image { image_url } => message::UserContent::image(
846 image_url.url,
847 Some(message::ContentFormat::default()),
848 None,
849 Some(image_url.detail),
850 ),
851 UserContent::Audio { input_audio } => message::UserContent::audio(
852 input_audio.data,
853 Some(message::ContentFormat::default()),
854 Some(input_audio.format),
855 ),
856 }
857 }
858}
859
860impl From<String> for UserContent {
861 fn from(s: String) -> Self {
862 UserContent::Text { text: s }
863 }
864}
865
866impl FromStr for UserContent {
867 type Err = Infallible;
868
869 fn from_str(s: &str) -> Result<Self, Self::Err> {
870 Ok(UserContent::Text {
871 text: s.to_string(),
872 })
873 }
874}
875
876impl From<String> for AssistantContent {
877 fn from(s: String) -> Self {
878 AssistantContent::Text { text: s }
879 }
880}
881
882impl FromStr for AssistantContent {
883 type Err = Infallible;
884
885 fn from_str(s: &str) -> Result<Self, Self::Err> {
886 Ok(AssistantContent::Text {
887 text: s.to_string(),
888 })
889 }
890}
891impl From<String> for SystemContent {
892 fn from(s: String) -> Self {
893 SystemContent {
894 r#type: SystemContentType::default(),
895 text: s,
896 }
897 }
898}
899
900impl FromStr for SystemContent {
901 type Err = Infallible;
902
903 fn from_str(s: &str) -> Result<Self, Self::Err> {
904 Ok(SystemContent {
905 r#type: SystemContentType::default(),
906 text: s.to_string(),
907 })
908 }
909}
910
/// Handle for an OpenAI chat-completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    pub model: String,
}
917
impl CompletionModel {
    /// Create a completion model handle for `model` using `client`.
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    /// Build the JSON body for a chat-completion request.
    ///
    /// Message order: optional system preamble, then chat history, then the
    /// prompt (with context documents). Tools and `tool_choice: "auto"` are
    /// included only when the request defines tools; `temperature` and any
    /// `additional_params` are merged in last, so `additional_params` can
    /// override earlier fields.
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Start with the optional system preamble.
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // The prompt (plus context documents) may expand into several wire messages.
        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;

        // Each rig history entry can also fan out (e.g. one `tool` message per
        // tool result), so the nested vecs are flattened.
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);
        full_history.extend(prompt);

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,

            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };

        let request = if let Some(temperature) = completion_request.temperature {
            json_utils::merge(
                request,
                json!({
                    "temperature": temperature,
                }),
            )
        } else {
            request
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
990
991impl completion::CompletionModel for CompletionModel {
992 type Response = CompletionResponse;
993
994 #[cfg_attr(feature = "worker", worker::send)]
995 async fn completion(
996 &self,
997 completion_request: CompletionRequest,
998 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
999 let request = self.create_completion_request(completion_request)?;
1000
1001 let response = self
1002 .client
1003 .post("/chat/completions")
1004 .json(&request)
1005 .send()
1006 .await?;
1007
1008 if response.status().is_success() {
1009 let t = response.text().await?;
1010 tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
1011
1012 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
1013 ApiResponse::Ok(response) => {
1014 tracing::info!(target: "rig",
1015 "OpenAI completion token usage: {:?}",
1016 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
1017 );
1018 response.try_into()
1019 }
1020 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1021 }
1022 } else {
1023 Err(CompletionError::ProviderError(response.text().await?))
1024 }
1025 }
1026}
1027
/// `whisper-1` transcription model.
pub const WHISPER_1: &str = "whisper-1";

/// Response body of the OpenAI `audio/transcriptions` endpoint.
#[derive(Debug, Deserialize)]
pub struct TranscriptionResponse {
    pub text: String,
}
1037
1038impl TryFrom<TranscriptionResponse>
1039 for transcription::TranscriptionResponse<TranscriptionResponse>
1040{
1041 type Error = TranscriptionError;
1042
1043 fn try_from(value: TranscriptionResponse) -> Result<Self, Self::Error> {
1044 Ok(transcription::TranscriptionResponse {
1045 text: value.text.clone(),
1046 response: value,
1047 })
1048 }
1049}
1050
/// Handle for an OpenAI audio transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    pub model: String,
}
1057
1058impl TranscriptionModel {
1059 pub fn new(client: Client, model: &str) -> Self {
1060 Self {
1061 client,
1062 model: model.to_string(),
1063 }
1064 }
1065}
1066
1067impl transcription::TranscriptionModel for TranscriptionModel {
1068 type Response = TranscriptionResponse;
1069
1070 #[cfg_attr(feature = "worker", worker::send)]
1071 async fn transcription(
1072 &self,
1073 request: transcription::TranscriptionRequest,
1074 ) -> Result<
1075 transcription::TranscriptionResponse<Self::Response>,
1076 transcription::TranscriptionError,
1077 > {
1078 let data = request.data;
1079
1080 let mut body = reqwest::multipart::Form::new()
1081 .text("model", self.model.clone())
1082 .text("language", request.language)
1083 .part(
1084 "file",
1085 Part::bytes(data).file_name(request.filename.clone()),
1086 );
1087
1088 if let Some(prompt) = request.prompt {
1089 body = body.text("prompt", prompt.clone());
1090 }
1091
1092 if let Some(ref temperature) = request.temperature {
1093 body = body.text("temperature", temperature.to_string());
1094 }
1095
1096 if let Some(ref additional_params) = request.additional_params {
1097 for (key, value) in additional_params
1098 .as_object()
1099 .expect("Additional Parameters to OpenAI Transcription should be a map")
1100 {
1101 body = body.text(key.to_owned(), value.to_string());
1102 }
1103 }
1104
1105 let response = self
1106 .client
1107 .post("audio/transcriptions")
1108 .multipart(body)
1109 .send()
1110 .await?;
1111
1112 if response.status().is_success() {
1113 match response
1114 .json::<ApiResponse<TranscriptionResponse>>()
1115 .await?
1116 {
1117 ApiResponse::Ok(response) => response.try_into(),
1118 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
1119 api_error_response.message,
1120 )),
1121 }
1122 } else {
1123 Err(TranscriptionError::ProviderError(response.text().await?))
1124 }
1125 }
1126}
1127
/// Partial function-call data in a streamed delta: the name arrives once,
/// then argument text accumulates across chunks.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingFunction {
    #[serde(default)]
    name: Option<String>,
    /// Argument JSON fragment for this chunk (may be a partial string).
    #[serde(default)]
    arguments: String,
}

/// A tool-call fragment within a streamed delta, keyed by `index`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingToolCall {
    pub index: usize,
    pub function: StreamingFunction,
}

/// The `delta` object of a streamed chat-completion chunk.
#[derive(Deserialize)]
struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    // `null` on the wire deserializes to an empty vec.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}

/// One choice within a streamed chunk.
#[derive(Deserialize)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// A single SSE chunk of a streamed chat completion.
#[derive(Deserialize)]
struct StreamingCompletionResponse {
    choices: Vec<StreamingChoice>,
}
1162
impl StreamingCompletionModel for CompletionModel {
    /// Start a streaming chat completion: identical request body to
    /// `completion`, with `"stream": true` merged in; SSE handling is
    /// delegated to `send_compatible_streaming_request`.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingResult, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;
        request = merge(request, json!({"stream": true}));

        let builder = self.client.post("/chat/completions").json(&request);
        send_compatible_streaming_request(builder).await
    }
}
1175
1176pub async fn send_compatible_streaming_request(
1177 request_builder: RequestBuilder,
1178) -> Result<StreamingResult, CompletionError> {
1179 let response = request_builder.send().await?;
1180
1181 if !response.status().is_success() {
1182 return Err(CompletionError::ProviderError(format!(
1183 "{}: {}",
1184 response.status(),
1185 response.text().await?
1186 )));
1187 }
1188
1189 Ok(Box::pin(stream! {
1191 let mut stream = response.bytes_stream();
1192
1193 let mut partial_data = None;
1194 let mut calls: HashMap<usize, (String, String)> = HashMap::new();
1195
1196 while let Some(chunk_result) = stream.next().await {
1197 let chunk = match chunk_result {
1198 Ok(c) => c,
1199 Err(e) => {
1200 yield Err(CompletionError::from(e));
1201 break;
1202 }
1203 };
1204
1205 let text = match String::from_utf8(chunk.to_vec()) {
1206 Ok(t) => t,
1207 Err(e) => {
1208 yield Err(CompletionError::ResponseError(e.to_string()));
1209 break;
1210 }
1211 };
1212
1213
1214 for line in text.lines() {
1215 let mut line = line.to_string();
1216
1217
1218
1219 if partial_data.is_some() {
1221 line = format!("{}{}", partial_data.unwrap(), line);
1222 partial_data = None;
1223 }
1224 else {
1226 let Some(data) = line.strip_prefix("data: ") else {
1227 continue;
1228 };
1229
1230 if !line.ends_with("}") {
1232 partial_data = Some(data.to_string());
1233 } else {
1234 line = data.to_string();
1235 }
1236 }
1237
1238 let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
1239
1240 let Ok(data) = data else {
1241 continue;
1242 };
1243
1244 let choice = data.choices.first().expect("Should have at least one choice");
1245
1246 let delta = &choice.delta;
1247
1248 if !delta.tool_calls.is_empty() {
1249 for tool_call in &delta.tool_calls {
1250 let function = tool_call.function.clone();
1251
1252 if function.name.is_some() && function.arguments.is_empty() {
1256 calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
1257 }
1258 else if function.name.is_none() && !function.arguments.is_empty() {
1262 let Some((name, arguments)) = calls.get(&tool_call.index) else {
1263 continue;
1264 };
1265
1266 let new_arguments = &tool_call.function.arguments;
1267 let arguments = format!("{}{}", arguments, new_arguments);
1268
1269 calls.insert(tool_call.index, (name.clone(), arguments));
1270 }
1271 else {
1273 let name = function.name.unwrap();
1274 let arguments = function.arguments;
1275 let Ok(arguments) = serde_json::from_str(&arguments) else {
1276 continue;
1277 };
1278
1279 yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
1280 }
1281 }
1282 }
1283
1284 if let Some(content) = &choice.delta.content {
1285 yield Ok(streaming::StreamingChoice::Message(content.clone()))
1286 }
1287 }
1288 }
1289
1290 for (_, (name, arguments)) in calls {
1291 let Ok(arguments) = serde_json::from_str(&arguments) else {
1292 continue;
1293 };
1294
1295 yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
1296 }
1297 }))
1298}
1299
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into a [`Message`], panicking with the precise
    /// error path and line/column on failure.
    ///
    /// Extracted helper: the same deserialize-or-panic dance was previously
    /// copy-pasted once per fixture; the `serde_path_to_error` wrapper keeps
    /// failure diagnostics pointing at the exact offending JSON field.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    /// Verifies that the provider-side `Message` type accepts the JSON shapes
    /// the OpenAI API actually emits: plain-string content, content-part
    /// arrays, null content alongside tool calls, and multimodal user content.
    #[test]
    fn test_deserialize_message() {
        // Assistant message with bare string content.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with content as an array of typed parts and an
        // explicitly null `tool_calls` field.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message carrying only tool calls (content/refusal null).
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\": 2, \"y\": 5}"
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // Multimodal user message: text + image + audio content parts.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                },
                {
                    "type": "audio",
                    "input_audio": {
                        "data": "...",
                        "format": "mp3"
                    }
                }
            ]
        }
        "#;

        let assistant_message = parse_message(assistant_message_json);
        let assistant_message2 = parse_message(assistant_message_json2);
        let assistant_message3 = parse_message(assistant_message_json3);
        let user_message = parse_message(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // A JSON `null` for tool_calls must deserialize to an empty vec.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                refusal,
                ..
            } => {
                assert!(content.is_empty());
                assert!(refusal.is_none());
                // The stringified `arguments` payload must be parsed into
                // structured JSON on the way in.
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    /// Round-trips `message::Message` -> provider `Message` -> `message::Message`
    /// and asserts the conversion is lossless in both directions.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // A single generic message may expand into multiple provider messages.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0].clone(),
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original messages exactly.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Round-trips provider `Message` -> `message::Message` -> provider `Message`
    /// (the opposite direction of `test_message_to_message_conversion`).
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
            name: None,
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original provider messages exactly.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }
}