use crate::json_utils::merge;
use crate::streaming::StreamingResult;
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
extractor::ExtractorBuilder,
json_utils,
message::{self, AudioMediaType, ImageDetail},
one_or_many::string_or_one_or_many,
streaming,
streaming::StreamingCompletionModel,
transcription::{self, TranscriptionError},
Embed, OneOrMany,
};
use async_stream::stream;
use futures::StreamExt;
use reqwest::multipart::Part;
use reqwest::RequestBuilder;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::{convert::Infallible, str::FromStr};
/// Base URL used by [`Client::new`] when no custom endpoint is supplied.
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
/// HTTP client for the OpenAI API (or an endpoint reachable under a custom base URL).
///
/// `Clone` is derived; the model types in this file each store their own copy.
#[derive(Clone)]
pub struct Client {
/// Base URL that request paths are appended to.
base_url: String,
/// Underlying `reqwest` client used to send every request.
http_client: reqwest::Client,
}
impl Client {
    /// Creates a client that targets the default OpenAI endpoint.
    ///
    /// # Panics
    /// Panics under the same conditions as [`Client::from_url`].
    pub fn new(api_key: &str) -> Self {
        Self::from_url(api_key, OPENAI_API_BASE_URL)
    }

    /// Creates a client that targets `base_url` and sends `api_key` as a bearer
    /// token in the `Authorization` header of every request.
    ///
    /// # Panics
    /// Panics if `api_key` cannot be turned into a valid header value, or if the
    /// underlying `reqwest` client fails to build.
    pub fn from_url(api_key: &str, base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers({
                    let mut headers = reqwest::header::HeaderMap::new();
                    headers.insert(
                        "Authorization",
                        format!("Bearer {}", api_key)
                            .parse()
                            .expect("Bearer token should parse"),
                    );
                    headers
                })
                .build()
                .expect("OpenAI reqwest client should build"),
        }
    }

    /// Creates a client from the `OPENAI_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if `OPENAI_API_KEY` is not set.
    pub fn from_env() -> Self {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        Self::new(&api_key)
    }

    /// Starts a POST request to `path`, resolved relative to the base URL.
    ///
    /// `path` may be given with or without a leading slash.
    fn post(&self, path: &str) -> reqwest::RequestBuilder {
        // Normalise only the join point. A blanket `replace("//", "/")` over the
        // whole URL would also collapse the `//` that follows the scheme.
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        self.http_client.post(url)
    }

    /// Creates an embedding model handle for `model`.
    ///
    /// The dimension count is filled in for the models known to this module;
    /// any other model name gets `0`. Use
    /// [`Client::embedding_model_with_ndims`] to set it explicitly.
    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
        let ndims = match model {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Creates an embedding model handle for `model` with an explicit dimension count.
    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Creates an embeddings builder backed by [`Client::embedding_model`].
    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
        EmbeddingsBuilder::new(self.embedding_model(model))
    }

    /// Creates a completion model handle for `model`.
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Creates an agent builder backed by [`Client::completion_model`].
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Creates an extractor builder for `T` backed by [`Client::completion_model`].
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }

    /// Creates a transcription model handle for `model`.
    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
/// Error payload deserialized from a response body; only `message` is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
/// Error text that callers in this file wrap into `ProviderError` variants.
message: String,
}
/// A response body that is either the expected payload `T` or an [`ApiErrorResponse`].
///
/// Marked `#[serde(untagged)]`, so the variant is picked from the shape of the
/// body rather than from a tag field.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
/// The body deserialized as the expected payload.
Ok(T),
/// The body deserialized as an error payload.
Err(ApiErrorResponse),
}
/// `text-embedding-3-large` embedding model (`Client::embedding_model` uses 3072 dimensions for it).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (`Client::embedding_model` uses 1536 dimensions for it).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (`Client::embedding_model` uses 1536 dimensions for it).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
/// Successful response body of the embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
/// Value of the body's `object` field.
pub object: String,
/// One entry per embedded input.
pub data: Vec<EmbeddingData>,
/// Value of the body's `model` field.
pub model: String,
/// Token usage reported for the request.
pub usage: Usage,
}
impl From<ApiErrorResponse> for EmbeddingError {
fn from(err: ApiErrorResponse) -> Self {
EmbeddingError::ProviderError(err.message)
}
}
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
match value {
ApiResponse::Ok(response) => Ok(response),
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
}
}
/// A single entry of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
/// Value of the entry's `object` field.
pub object: String,
/// The embedding vector.
pub embedding: Vec<f64>,
/// Value of the entry's `index` field (presumably the input position — confirm against the API docs).
pub index: usize,
}
/// Token usage counters reported in a response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
/// Value of the `prompt_tokens` field.
pub prompt_tokens: usize,
/// Value of the `total_tokens` field.
pub total_tokens: usize,
}
/// Renders the usage counters on a single line for log output.
impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            prompt_tokens,
            total_tokens,
        } = self;
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            prompt_tokens, total_tokens
        )
    }
}
/// Handle to an embedding model, bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
/// Client used to send embedding requests.
client: Client,
/// Model name sent in the request body.
pub model: String,
/// Dimension count reported by `ndims()`.
ndims: usize,
}
impl embeddings::EmbeddingModel for EmbeddingModel {
    const MAX_DOCUMENTS: usize = 1024;

    /// Returns the dimension count this handle was created with.
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds `documents` with a single request to the `/embeddings` endpoint.
    ///
    /// Each returned embedding is paired with the input document at the same
    /// position in the response.
    ///
    /// # Errors
    /// Returns `ProviderError` for a non-success status or an error payload,
    /// and `ResponseError` when the number of returned embeddings differs from
    /// the number of inputs.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents: Vec<String> = documents.into_iter().collect();

        let payload = json!({
            "model": self.model,
            "input": documents,
        });
        let response = self
            .client
            .post("/embeddings")
            .json(&payload)
            .send()
            .await?;

        // Non-success statuses carry the error text directly in the body.
        if !response.status().is_success() {
            return Err(EmbeddingError::ProviderError(response.text().await?));
        }

        let parsed = match response.json::<ApiResponse<EmbeddingResponse>>().await? {
            ApiResponse::Ok(parsed) => parsed,
            ApiResponse::Err(err) => return Err(EmbeddingError::ProviderError(err.message)),
        };

        tracing::info!(target: "rig",
            "OpenAI embedding token usage: {}",
            parsed.usage
        );

        if parsed.data.len() != documents.len() {
            return Err(EmbeddingError::ResponseError(
                "Response data length does not match input length".into(),
            ));
        }

        Ok(parsed
            .data
            .into_iter()
            .zip(documents)
            .map(|(item, document)| embeddings::Embedding {
                document,
                vec: item.embedding,
            })
            .collect())
    }
}
impl EmbeddingModel {
    /// Builds an embedding model handle from a client, a model name and the
    /// dimension count to report.
    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
        let model = model.to_owned();
        Self {
            client,
            model,
            ndims,
        }
    }
}
// Completion model identifiers. Each constant holds the model name string
// passed as the `model` field of a request.

// o-series models.
pub const O3_MINI: &str = "o3-mini";
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
pub const O1: &str = "o1";
pub const O1_2024_12_17: &str = "o1-2024-12-17";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
pub const O1_MINI: &str = "o1-mini";
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
// GPT-4.5 models.
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
// GPT-4o models.
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
// GPT-4 Turbo and preview models.
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
// GPT-4 models.
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_0613: &str = "gpt-4-0613";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
// GPT-3.5 models.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// Successful response body of the chat completions endpoint.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
/// Value of the body's `id` field.
pub id: String,
/// Value of the body's `object` field.
pub object: String,
/// Value of the body's `created` field (presumably a Unix timestamp — confirm against the API docs).
pub created: u64,
/// Value of the body's `model` field.
pub model: String,
/// Optional `system_fingerprint` field.
pub system_fingerprint: Option<String>,
/// Generated choices; the conversion in this file reads only the first one.
pub choices: Vec<Choice>,
/// Token usage, when reported.
pub usage: Option<Usage>,
}
impl From<ApiErrorResponse> for CompletionError {
fn from(err: ApiErrorResponse) -> Self {
CompletionError::ProviderError(err.message)
}
}
/// Converts the raw provider response into the provider-agnostic completion
/// response, keeping the raw response alongside it.
///
/// Only the first choice is inspected. Non-empty text and refusal parts become
/// text content, followed by any tool calls.
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let Some(first_choice) = response.choices.first() else {
            return Err(CompletionError::ResponseError(
                "Response contained no choices".to_owned(),
            ));
        };

        // Anything other than an assistant message is rejected.
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = &first_choice.message
        else {
            return Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            ));
        };

        // Empty text/refusal parts are dropped.
        let text_parts = content.iter().filter_map(|part| {
            let s = match part {
                AssistantContent::Text { text } => text,
                AssistantContent::Refusal { refusal } => refusal,
            };
            (!s.is_empty()).then(|| completion::AssistantContent::text(s))
        });

        let call_parts = tool_calls.iter().map(|call| {
            completion::AssistantContent::tool_call(
                &call.id,
                &call.function.name,
                call.function.arguments.clone(),
            )
        });

        let parts: Vec<_> = text_parts.chain(call_parts).collect();

        let choice = OneOrMany::many(parts).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        Ok(completion::CompletionResponse {
            choice,
            raw_response: response,
        })
    }
}
/// One entry of [`CompletionResponse::choices`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
/// Value of the choice's `index` field.
pub index: usize,
/// The message produced for this choice.
pub message: Message,
/// Optional `logprobs` field, kept as untyped JSON.
pub logprobs: Option<serde_json::Value>,
/// Value of the choice's `finish_reason` field.
pub finish_reason: String,
}
/// A chat message in the provider's wire format.
///
/// Serialized with an internal `role` tag using lowercase variant names;
/// `ToolResult` is renamed to `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
/// A `system` message.
System {
/// Message content, deserialized through `string_or_one_or_many`.
#[serde(deserialize_with = "string_or_one_or_many")]
content: OneOrMany<SystemContent>,
/// Optional `name` field; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// A `user` message.
User {
/// Message content, deserialized through `string_or_one_or_many`.
#[serde(deserialize_with = "string_or_one_or_many")]
content: OneOrMany<UserContent>,
/// Optional `name` field; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// An `assistant` message.
Assistant {
/// Message content, deserialized through `json_utils::string_or_vec`.
/// Defaults to empty when the field is absent; the tests in this file
/// exercise a plain string, an array of parts, and `null`.
#[serde(default, deserialize_with = "json_utils::string_or_vec")]
content: Vec<AssistantContent>,
/// Optional `refusal` field; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
refusal: Option<String>,
/// Optional `audio` field; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
audio: Option<AudioAssistant>,
/// Optional `name` field; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
/// Tool calls, deserialized through `json_utils::null_or_vec`.
/// Defaults to empty when absent and is omitted from output when empty.
#[serde(
default,
deserialize_with = "json_utils::null_or_vec",
skip_serializing_if = "Vec::is_empty"
)]
tool_calls: Vec<ToolCall>,
},
/// A tool result message, tagged as `tool` on the wire.
#[serde(rename = "tool")]
ToolResult {
/// Identifier of the tool call this result belongs to.
tool_call_id: String,
/// Result content.
content: OneOrMany<ToolResultContent>,
},
}
impl Message {
pub fn system(content: &str) -> Self {
Message::System {
content: OneOrMany::one(content.to_owned().into()),
name: None,
}
}
}
/// Payload of the `audio` field on an assistant message; only `id` is modelled.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
/// Value of the `id` field.
id: String,
}
/// A content part of a system message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
/// Part type; falls back to the default (`Text`) when absent from the input.
#[serde(default)]
r#type: SystemContentType,
/// The text of this part.
text: String,
}
/// Type tag of a [`SystemContent`] part, serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
/// Plain text (the default and only variant).
#[default]
Text,
}
/// A content part of an assistant message, tagged by a lowercase `type` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
/// A `text` part.
Text { text: String },
/// A `refusal` part.
Refusal { refusal: String },
}
/// A content part of a user message, tagged by a lowercase `type` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
/// A `text` part.
Text { text: String },
/// An `image` part carrying an [`ImageUrl`].
Image { image_url: ImageUrl },
/// An `audio` part carrying an [`InputAudio`].
Audio { input_audio: InputAudio },
}
/// Image reference used by [`UserContent::Image`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
/// Value of the `url` field.
pub url: String,
/// Detail setting; falls back to `ImageDetail::default()` when absent from the input.
#[serde(default)]
pub detail: ImageDetail,
}
/// Audio payload used by [`UserContent::Audio`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
/// Value of the `data` field (presumably encoded audio — confirm the expected encoding).
pub data: String,
/// Media type of the audio.
pub format: AudioMediaType,
}
/// A content part of a tool result message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
/// Part type; falls back to the default (`Text`) when absent from the input.
#[serde(default)]
r#type: ToolResultContentType,
/// The text of this part.
text: String,
}
/// Type tag of a [`ToolResultContent`] part, serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
/// Plain text (the default and only variant).
#[default]
Text,
}
impl FromStr for ToolResultContent {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(s.to_owned().into())
}
}
/// Wraps an owned string as a text tool-result part.
impl From<String> for ToolResultContent {
    fn from(s: String) -> Self {
        let r#type = ToolResultContentType::default();
        Self { r#type, text: s }
    }
}
/// A tool call attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
/// Identifier of the tool call.
pub id: String,
/// Call type; falls back to the default (`Function`) when absent from the input.
#[serde(default)]
pub r#type: ToolType,
/// The function name and arguments of the call.
pub function: Function,
}
/// Type tag of a [`ToolCall`], serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
/// A function call (the default and only variant).
#[default]
Function,
}
/// Wire-format wrapper around a [`completion::ToolDefinition`], as sent in the
/// `tools` array of a completion request.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
/// Tool type string; the `From` conversion in this file sets it to `"function"`.
pub r#type: String,
/// The wrapped tool definition.
pub function: completion::ToolDefinition,
}
/// Wraps a provider-agnostic tool definition as a `"function"` tool.
impl From<completion::ToolDefinition> for ToolDefinition {
    fn from(tool: completion::ToolDefinition) -> Self {
        let r#type = String::from("function");
        Self {
            r#type,
            function: tool,
        }
    }
}
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
/// Name of the function to call.
pub name: String,
/// Call arguments, (de)serialized through `json_utils::stringified_json`.
/// The tests in this file feed a JSON-encoded string and expect a parsed value.
#[serde(with = "json_utils::stringified_json")]
pub arguments: serde_json::Value,
}
/// Converts a provider-agnostic message into one or more wire-format messages.
///
/// A user message containing tool results becomes one `Message::ToolResult`
/// per result; any other user message becomes a single `Message::User`.
/// An assistant message becomes a single `Message::Assistant`.
///
/// # Errors
/// Returns `MessageError::ConversionError` when a tool result contains
/// non-text content.
impl TryFrom<message::Message> for Vec<Message> {
type Error = message::MessageError;
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
match message {
message::Message::User { content } => {
// Split the parts into tool results and everything else.
let (tool_results, other_content): (Vec<_>, Vec<_>) = content
.into_iter()
.partition(|content| matches!(content, message::UserContent::ToolResult(_)));
// NOTE(review): when at least one tool result is present, `other_content`
// is not used, so any text/image/audio parts in the same message are
// dropped — confirm this is intended.
if !tool_results.is_empty() {
tool_results
.into_iter()
.map(|content| match content {
message::UserContent::ToolResult(message::ToolResult {
id,
content,
}) => Ok::<_, message::MessageError>(Message::ToolResult {
tool_call_id: id,
content: content.try_map(|content| match content {
message::ToolResultContent::Text(message::Text { text }) => {
Ok(text.into())
}
_ => Err(message::MessageError::ConversionError(
"Tool result content does not support non-text".into(),
)),
})?,
}),
// The partition above only puts `ToolResult` parts in `tool_results`.
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()
} else {
let other_content = OneOrMany::many(other_content).expect(
"There must be other content here if there were no tool result content",
);
Ok(vec![Message::User {
content: other_content.map(|content| match content {
message::UserContent::Text(message::Text { text }) => {
UserContent::Text { text }
}
message::UserContent::Image(message::Image {
data, detail, ..
}) => UserContent::Image {
image_url: ImageUrl {
url: data,
detail: detail.unwrap_or_default(),
},
},
// Documents are forwarded as plain text parts.
message::UserContent::Document(message::Document { data, .. }) => {
UserContent::Text { text: data }
}
message::UserContent::Audio(message::Audio {
data,
media_type,
..
}) => UserContent::Audio {
input_audio: InputAudio {
data,
// A missing media type is treated as MP3.
format: match media_type {
Some(media_type) => media_type,
None => AudioMediaType::MP3,
},
},
},
// `ToolResult` parts were partitioned out above.
_ => unreachable!(),
}),
name: None,
}])
}
}
message::Message::Assistant { content } => {
// Separate text parts from tool calls, preserving their relative order.
let (text_content, tool_calls) = content.into_iter().fold(
(Vec::new(), Vec::new()),
|(mut texts, mut tools), content| {
match content {
message::AssistantContent::Text(text) => texts.push(text),
message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
}
(texts, tools)
},
);
Ok(vec![Message::Assistant {
content: text_content
.into_iter()
.map(|content| content.text.into())
.collect::<Vec<_>>(),
refusal: None,
audio: None,
name: None,
tool_calls: tool_calls
.into_iter()
.map(|tool_call| tool_call.into())
.collect::<Vec<_>>(),
}])
}
}
}
}
/// Converts a provider-agnostic tool call into the wire format, using the
/// default tool type.
impl From<message::ToolCall> for ToolCall {
    fn from(tool_call: message::ToolCall) -> Self {
        let function = Function {
            name: tool_call.function.name,
            arguments: tool_call.function.arguments,
        };
        Self {
            id: tool_call.id,
            r#type: ToolType::default(),
            function,
        }
    }
}
impl From<ToolCall> for message::ToolCall {
fn from(tool_call: ToolCall) -> Self {
Self {
id: tool_call.id,
function: message::ToolFunction {
name: tool_call.function.name,
arguments: tool_call.function.arguments,
},
}
}
}
/// Converts a wire-format message into the provider-agnostic form.
///
/// System messages are mapped to user messages, tool results to user messages
/// holding a single tool-result part, and assistant text/refusal parts to text
/// content followed by tool calls.
///
/// # Errors
/// Returns `MessageError::ConversionError` when an assistant message has
/// neither content nor tool calls.
impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        match message {
            Message::User { content, .. } => Ok(message::Message::User {
                content: content.map(|part| part.into()),
            }),
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let text_parts = content.into_iter().map(|part| match part {
                    AssistantContent::Text { text } => message::AssistantContent::text(text),
                    AssistantContent::Refusal { refusal } => {
                        message::AssistantContent::text(refusal)
                    }
                });
                let call_parts = tool_calls
                    .into_iter()
                    .map(|call| message::AssistantContent::ToolCall(call.into()));
                let parts: Vec<_> = text_parts.chain(call_parts).collect();

                let content = OneOrMany::many(parts).map_err(|_| {
                    message::MessageError::ConversionError(
                        "Neither `content` nor `tool_calls` was provided to the Message"
                            .to_owned(),
                    )
                })?;
                Ok(message::Message::Assistant { content })
            }
            Message::ToolResult {
                tool_call_id,
                content,
            } => {
                let result = message::UserContent::tool_result(
                    tool_call_id,
                    content.map(|part| message::ToolResultContent::text(part.text)),
                );
                Ok(message::Message::User {
                    content: OneOrMany::one(result),
                })
            }
            Message::System { content, .. } => Ok(message::Message::User {
                content: content.map(|part| message::UserContent::text(part.text)),
            }),
        }
    }
}
impl From<UserContent> for message::UserContent {
fn from(content: UserContent) -> Self {
match content {
UserContent::Text { text } => message::UserContent::text(text),
UserContent::Image { image_url } => message::UserContent::image(
image_url.url,
Some(message::ContentFormat::default()),
None,
Some(image_url.detail),
),
UserContent::Audio { input_audio } => message::UserContent::audio(
input_audio.data,
Some(message::ContentFormat::default()),
Some(input_audio.format),
),
}
}
}
impl From<String> for UserContent {
fn from(s: String) -> Self {
UserContent::Text { text: s }
}
}
impl FromStr for UserContent {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(UserContent::Text {
text: s.to_string(),
})
}
}
impl From<String> for AssistantContent {
fn from(s: String) -> Self {
AssistantContent::Text { text: s }
}
}
impl FromStr for AssistantContent {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(AssistantContent::Text {
text: s.to_string(),
})
}
}
/// Wraps an owned string as a text system part.
impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        let r#type = SystemContentType::default();
        Self { r#type, text: s }
    }
}
/// Parses any string into a text system part; this never fails.
impl FromStr for SystemContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let text = s.to_owned();
        Ok(Self {
            r#type: SystemContentType::default(),
            text,
        })
    }
}
/// Handle to a chat completion model, bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
/// Client used to send completion requests.
client: Client,
/// Model name sent in the request body.
pub model: String,
}
impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
fn create_completion_request(
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect();
full_history.extend(chat_history);
full_history.extend(prompt);
let request = if completion_request.tools.is_empty() {
json!({
"model": self.model,
"messages": full_history,
})
} else {
json!({
"model": self.model,
"messages": full_history,
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
"tool_choice": "auto",
})
};
let request = if let Some(temperature) = completion_request.temperature {
json_utils::merge(
request,
json!({
"temperature": temperature,
}),
)
} else {
request
};
let request = if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
};
Ok(request)
}
}
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;

    /// Sends a chat completion request and converts the response.
    ///
    /// # Errors
    /// Returns `ProviderError` for a non-success status or an error payload,
    /// and propagates request-building, transport, JSON and conversion errors.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;

        // Non-success statuses carry the error text directly in the body.
        if !response.status().is_success() {
            return Err(CompletionError::ProviderError(response.text().await?));
        }

        let body = response.text().await?;
        // This is the body of a *successful* HTTP response, so it is logged as
        // a response rather than as an error.
        tracing::debug!(target: "rig", "OpenAI completion response: {}", body);

        match serde_json::from_str::<ApiResponse<CompletionResponse>>(&body)? {
            ApiResponse::Ok(response) => {
                // Formatted with `{}` so the log line is not wrapped in quotes.
                tracing::info!(target: "rig",
                    "OpenAI completion token usage: {}",
                    response
                        .usage
                        .as_ref()
                        .map(|usage| usage.to_string())
                        .unwrap_or_else(|| "N/A".to_string())
                );
                response.try_into()
            }
            ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
        }
    }
}
/// `whisper-1` transcription model.
pub const WHISPER_1: &str = "whisper-1";
/// Successful response body of the transcription endpoint; only `text` is modelled.
#[derive(Debug, Deserialize)]
pub struct TranscriptionResponse {
/// The transcribed text.
pub text: String,
}
/// Wraps the raw transcription response, copying its text to the top level.
/// The conversion as written never returns an error.
impl TryFrom<TranscriptionResponse>
    for transcription::TranscriptionResponse<TranscriptionResponse>
{
    type Error = TranscriptionError;

    fn try_from(value: TranscriptionResponse) -> Result<Self, Self::Error> {
        let text = value.text.clone();
        Ok(transcription::TranscriptionResponse {
            text,
            response: value,
        })
    }
}
/// Handle to a transcription model, bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
/// Client used to send transcription requests.
client: Client,
/// Model name sent in the multipart form.
pub model: String,
}
impl TranscriptionModel {
    /// Builds a transcription model handle from a client and a model name.
    pub fn new(client: Client, model: &str) -> Self {
        let model = model.to_owned();
        Self { client, model }
    }
}
impl transcription::TranscriptionModel for TranscriptionModel {
type Response = TranscriptionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn transcription(
&self,
request: transcription::TranscriptionRequest,
) -> Result<
transcription::TranscriptionResponse<Self::Response>,
transcription::TranscriptionError,
> {
let data = request.data;
let mut body = reqwest::multipart::Form::new()
.text("model", self.model.clone())
.text("language", request.language)
.part(
"file",
Part::bytes(data).file_name(request.filename.clone()),
);
if let Some(prompt) = request.prompt {
body = body.text("prompt", prompt.clone());
}
if let Some(ref temperature) = request.temperature {
body = body.text("temperature", temperature.to_string());
}
if let Some(ref additional_params) = request.additional_params {
for (key, value) in additional_params
.as_object()
.expect("Additional Parameters to OpenAI Transcription should be a map")
{
body = body.text(key.to_owned(), value.to_string());
}
}
let response = self
.client
.post("audio/transcriptions")
.multipart(body)
.send()
.await?;
if response.status().is_success() {
match response
.json::<ApiResponse<TranscriptionResponse>>()
.await?
{
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
api_error_response.message,
)),
}
} else {
Err(TranscriptionError::ProviderError(response.text().await?))
}
}
}
/// Function fragment carried by a streamed tool-call delta.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingFunction {
/// Function name; `None` when the field is absent from the delta.
#[serde(default)]
name: Option<String>,
/// Argument text of this delta; empty when the field is absent.
#[serde(default)]
arguments: String,
}
/// A tool-call entry inside a streamed delta.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingToolCall {
/// Index used by the stream parser to group fragments of the same call.
pub index: usize,
/// The function fragment of this entry.
pub function: StreamingFunction,
}
/// The `delta` object of a streamed choice.
#[derive(Deserialize)]
struct StreamingDelta {
/// Text fragment, if any.
#[serde(default)]
content: Option<String>,
/// Tool-call fragments, deserialized through `json_utils::null_or_vec`;
/// defaults to empty when the field is absent.
#[serde(default, deserialize_with = "json_utils::null_or_vec")]
tool_calls: Vec<StreamingToolCall>,
}
/// One entry of a streamed chunk's `choices`; only `delta` is modelled.
#[derive(Deserialize)]
struct StreamingChoice {
/// The incremental update carried by this choice.
delta: StreamingDelta,
}
/// A single streamed chunk; only `choices` is modelled.
#[derive(Deserialize)]
struct StreamingCompletionResponse {
/// Choices of this chunk; the stream parser reads only the first one.
choices: Vec<StreamingChoice>,
}
impl StreamingCompletionModel for CompletionModel {
    /// Sends the completion request with `"stream": true` merged into the body
    /// and returns the resulting stream of choices.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingResult, CompletionError> {
        let base_request = self.create_completion_request(completion_request)?;
        let streaming_request = merge(base_request, json!({"stream": true}));

        let builder = self
            .client
            .post("/chat/completions")
            .json(&streaming_request);

        send_compatible_streaming_request(builder).await
    }
}
/// Sends `request_builder` and turns the `data: ` lines of the response body
/// into a stream of message and tool-call choices.
///
/// Text deltas are yielded as they arrive. A tool call whose name and
/// arguments arrive in one delta is yielded immediately; calls that arrive in
/// fragments are accumulated per `index` and yielded after the body ends.
/// Lines that do not parse as a streamed chunk, and chunks without any
/// choice, are skipped.
///
/// # Errors
/// Returns `ProviderError` ("status: body") for a non-success status. Transport
/// and UTF-8 errors encountered mid-stream are yielded as an `Err` item, after
/// which reading of the body stops.
pub async fn send_compatible_streaming_request(
    request_builder: RequestBuilder,
) -> Result<StreamingResult, CompletionError> {
    let response = request_builder.send().await?;

    if !response.status().is_success() {
        let status = response.status();
        return Err(CompletionError::ProviderError(format!(
            "{}: {}",
            status,
            response.text().await?
        )));
    }

    Ok(Box::pin(stream! {
        let mut stream = response.bytes_stream();
        // Payload of a `data: ` line that did not end with `}`; it is prepended
        // to the next line.
        let mut partial_data: Option<String> = None;
        // Tool calls being assembled, keyed by index: (name, argument text so far).
        let mut calls: HashMap<usize, (String, String)> = HashMap::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    yield Err(CompletionError::from(e));
                    break;
                }
            };

            // NOTE(review): a multi-byte character split across two network
            // chunks would fail this conversion and end the stream; buffering
            // raw bytes per line would avoid that.
            let text = match String::from_utf8(chunk.to_vec()) {
                Ok(t) => t,
                Err(e) => {
                    yield Err(CompletionError::ResponseError(e.to_string()));
                    break;
                }
            };

            for line in text.lines() {
                let line = if let Some(prefix) = partial_data.take() {
                    format!("{}{}", prefix, line)
                } else {
                    let Some(data) = line.strip_prefix("data: ") else {
                        continue;
                    };
                    if !line.ends_with("}") {
                        // Incomplete payload: keep it and wait for the rest.
                        partial_data = Some(data.to_string());
                        continue;
                    }
                    data.to_string()
                };

                let Ok(data) = serde_json::from_str::<StreamingCompletionResponse>(&line) else {
                    continue;
                };

                // A chunk without choices carries nothing to yield; skip it
                // instead of panicking.
                let Some(choice) = data.choices.first() else {
                    continue;
                };
                let delta = &choice.delta;

                for tool_call in &delta.tool_calls {
                    let function = &tool_call.function;
                    match (&function.name, function.arguments.is_empty()) {
                        // Start of a fragmented call: remember the name.
                        (Some(name), true) => {
                            calls.insert(tool_call.index, (name.clone(), String::new()));
                        }
                        // Continuation: append to a call already started.
                        (None, false) => {
                            if let Some((_, arguments)) = calls.get_mut(&tool_call.index) {
                                arguments.push_str(&function.arguments);
                            }
                        }
                        // Complete call in a single delta.
                        (Some(name), false) => {
                            let Ok(arguments) = serde_json::from_str(&function.arguments) else {
                                continue;
                            };
                            yield Ok(streaming::StreamingChoice::ToolCall(name.clone(), "".to_string(), arguments));
                        }
                        // Neither a name nor arguments: nothing to record.
                        (None, true) => {}
                    }
                }

                if let Some(content) = &delta.content {
                    yield Ok(streaming::StreamingChoice::Message(content.clone()));
                }
            }
        }

        // Emit the calls that were assembled from fragments.
        for (_, (name, arguments)) in calls {
            let Ok(arguments) = serde_json::from_str(&arguments) else {
                continue;
            };
            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments));
        }
    }))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_path_to_error::deserialize;
// Checks that wire-format JSON deserializes into the expected `Message` variants:
// assistant content as a plain string, as an array of parts, as `null` with
// tool calls, and a multi-part user message.
#[test]
fn test_deserialize_message() {
// Assistant message whose content is a plain string.
let assistant_message_json = r#"
{
"role": "assistant",
"content": "\n\nHello there, how may I assist you today?"
}
"#;
// Assistant message whose content is an array of parts and `tool_calls` is null.
let assistant_message_json2 = r#"
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "\n\nHello there, how may I assist you today?"
}
],
"tool_calls": null
}
"#;
// Assistant message with a tool call and null content/refusal.
let assistant_message_json3 = r#"
{
"role": "assistant",
"tool_calls": [
{
"id": "call_h89ipqYUjEpCPI6SxspMnoUU",
"type": "function",
"function": {
"name": "subtract",
"arguments": "{\"x\": 2, \"y\": 5}"
}
}
],
"content": null,
"refusal": null
}
"#;
// User message with text, image and audio parts.
let user_message_json = r#"
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
},
{
"type": "audio",
"input_audio": {
"data": "...",
"format": "mp3"
}
}
]
}
"#;
// Each fixture goes through `serde_path_to_error` so a failure reports the
// offending path and position.
let assistant_message: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
deserialize(jd).unwrap_or_else(|err| {
panic!(
"Deserialization error at {} ({}:{}): {}",
err.path(),
err.inner().line(),
err.inner().column(),
err
);
})
};
let assistant_message2: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
deserialize(jd).unwrap_or_else(|err| {
panic!(
"Deserialization error at {} ({}:{}): {}",
err.path(),
err.inner().line(),
err.inner().column(),
err
);
})
};
let assistant_message3: Message = {
let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
&mut serde_json::Deserializer::from_str(assistant_message_json3);
deserialize(jd).unwrap_or_else(|err| {
panic!(
"Deserialization error at {} ({}:{}): {}",
err.path(),
err.inner().line(),
err.inner().column(),
err
);
})
};
let user_message: Message = {
let jd = &mut serde_json::Deserializer::from_str(user_message_json);
deserialize(jd).unwrap_or_else(|err| {
panic!(
"Deserialization error at {} ({}:{}): {}",
err.path(),
err.inner().line(),
err.inner().column(),
err
);
})
};
// A plain string becomes a single text part.
match assistant_message {
Message::Assistant { content, .. } => {
assert_eq!(
content[0],
AssistantContent::Text {
text: "\n\nHello there, how may I assist you today?".to_string()
}
);
}
_ => panic!("Expected assistant message"),
}
// An array of parts is kept, and null `tool_calls` becomes an empty vec.
match assistant_message2 {
Message::Assistant {
content,
tool_calls,
..
} => {
assert_eq!(
content[0],
AssistantContent::Text {
text: "\n\nHello there, how may I assist you today?".to_string()
}
);
assert_eq!(tool_calls, vec![]);
}
_ => panic!("Expected assistant message"),
}
// Null content becomes empty, and the stringified arguments are parsed into JSON.
match assistant_message3 {
Message::Assistant {
content,
tool_calls,
refusal,
..
} => {
assert!(content.is_empty());
assert!(refusal.is_none());
assert_eq!(
tool_calls[0],
ToolCall {
id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
r#type: ToolType::Function,
function: Function {
name: "subtract".to_string(),
arguments: serde_json::json!({"x": 2, "y": 5}),
},
}
);
}
_ => panic!("Expected assistant message"),
}
// Only the first two user parts (text and image) are asserted here.
match user_message {
Message::User { content, .. } => {
let (first, second) = {
let mut iter = content.into_iter();
(iter.next().unwrap(), iter.next().unwrap())
};
assert_eq!(
first,
UserContent::Text {
text: "What's in this image?".to_string()
}
);
assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
}
_ => panic!("Expected user message"),
}
}
// Round-trips provider-agnostic messages through the wire format:
// `message::Message` -> `Vec<Message>` -> `message::Message`.
#[test]
fn test_message_to_message_conversion() {
let user_message = message::Message::User {
content: OneOrMany::one(message::UserContent::text("Hello")),
};
let assistant_message = message::Message::Assistant {
content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
};
let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
let converted_assistant_message: Vec<Message> =
assistant_message.clone().try_into().unwrap();
// The wire-format messages carry the same text.
match converted_user_message[0].clone() {
Message::User { content, .. } => {
assert_eq!(
content.first(),
UserContent::Text {
text: "Hello".to_string()
}
);
}
_ => panic!("Expected user message"),
}
match converted_assistant_message[0].clone() {
Message::Assistant { content, .. } => {
assert_eq!(
content[0].clone(),
AssistantContent::Text {
text: "Hi there!".to_string()
}
);
}
_ => panic!("Expected assistant message"),
}
// Converting back yields the original messages.
let original_user_message: message::Message =
converted_user_message[0].clone().try_into().unwrap();
let original_assistant_message: message::Message =
converted_assistant_message[0].clone().try_into().unwrap();
assert_eq!(original_user_message, user_message);
assert_eq!(original_assistant_message, assistant_message);
}
// Round-trips wire-format messages through the provider-agnostic form:
// `Message` -> `message::Message` -> `Vec<Message>`.
#[test]
fn test_message_from_message_conversion() {
let user_message = Message::User {
content: OneOrMany::one(UserContent::Text {
text: "Hello".to_string(),
}),
name: None,
};
let assistant_message = Message::Assistant {
content: vec![AssistantContent::Text {
text: "Hi there!".to_string(),
}],
refusal: None,
audio: None,
name: None,
tool_calls: vec![],
};
let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
let converted_assistant_message: message::Message =
assistant_message.clone().try_into().unwrap();
// The provider-agnostic messages carry the same text.
match converted_user_message.clone() {
message::Message::User { content } => {
assert_eq!(content.first(), message::UserContent::text("Hello"));
}
_ => panic!("Expected user message"),
}
match converted_assistant_message.clone() {
message::Message::Assistant { content } => {
assert_eq!(
content.first(),
message::AssistantContent::text("Hi there!")
);
}
_ => panic!("Expected assistant message"),
}
// Converting back yields the original wire-format messages.
let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
let original_assistant_message: Vec<Message> =
converted_assistant_message.try_into().unwrap();
assert_eq!(original_user_message[0], user_message);
assert_eq!(original_assistant_message[0], assistant_message);
}
}