use std::{convert::Infallible, str::FromStr};
use crate::OneOrMany;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::CompletionError;
/// Conversion from a generic [`Message`] into one or more values of the implementing type.
pub trait ConvertMessage: Sized + Send + Sync {
/// Error produced when a message cannot be converted.
type Error: std::error::Error + Send;
/// Converts `message` into a list of `Self` values, or returns [`Self::Error`] on failure.
fn convert_from_message(message: Message) -> Result<Vec<Self>, Self::Error>;
}
/// A chat message, split by the role that produced it.
///
/// Serialized with an internal `role` tag holding the lowercased variant name
/// (`"user"` or `"assistant"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
/// A user-role message carrying one or more [`UserContent`] items.
User { content: OneOrMany<UserContent> },
/// An assistant-role message carrying one or more [`AssistantContent`] items.
Assistant {
/// Optional identifier for the message.
id: Option<String>,
/// The assistant's content items.
content: OneOrMany<AssistantContent>,
},
}
/// One piece of content inside a [`Message::User`].
///
/// Serialized with an internal `type` tag holding the lowercased variant name
/// (for example `"text"` or `"toolresult"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
/// Plain text.
Text(Text),
/// The result of a tool invocation.
ToolResult(ToolResult),
/// An image.
Image(Image),
/// An audio clip.
Audio(Audio),
/// A video.
Video(Video),
/// A document.
Document(Document),
}
/// One piece of content inside a [`Message::Assistant`].
///
/// The enum is `untagged`: no discriminator field is written, and on
/// deserialization serde tries the variants in the order declared here, so the
/// declaration order is significant.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum AssistantContent {
/// Plain text.
Text(Text),
/// A request to call a tool.
ToolCall(ToolCall),
/// A reasoning block.
Reasoning(Reasoning),
/// An image.
Image(Image),
}
/// A reasoning block attached to assistant content.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal; use the constructors on `Reasoning` instead.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[non_exhaustive]
pub struct Reasoning {
/// Optional identifier for this reasoning block.
pub id: Option<String>,
/// The reasoning text, as a list of strings.
pub reasoning: Vec<String>,
/// Optional signature string; omitted from serialized output when `None`.
// NOTE(review): presumably a provider-issued signature for the reasoning; confirm against the providers that set it.
#[serde(skip_serializing_if = "Option::is_none")]
pub signature: Option<String>,
}
impl Reasoning {
    /// Creates a reasoning block with a single entry copied from `input`.
    ///
    /// Both `id` and `signature` start out as `None`.
    pub fn new(input: &str) -> Self {
        Self {
            reasoning: vec![input.to_owned()],
            id: None,
            signature: None,
        }
    }

    /// Replaces the id with `id`, which may be `None`.
    pub fn optional_id(self, id: Option<String>) -> Self {
        Self { id, ..self }
    }

    /// Sets the id to `Some(id)`.
    pub fn with_id(self, id: String) -> Self {
        Self {
            id: Some(id),
            ..self
        }
    }

    /// Replaces the signature with `signature`, which may be `None`.
    pub fn with_signature(self, signature: Option<String>) -> Self {
        Self { signature, ..self }
    }

    /// Creates a reasoning block from an already-built list of entries.
    ///
    /// Both `id` and `signature` start out as `None`.
    pub fn multi(input: Vec<String>) -> Self {
        Self {
            reasoning: input,
            id: None,
            signature: None,
        }
    }
}
/// The outcome of a tool invocation, carried as user content.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolResult {
/// Identifier of this tool result.
// NOTE(review): presumably matches the `id` of the originating `ToolCall`; confirm against callers.
pub id: String,
/// Optional secondary call identifier; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub call_id: Option<String>,
/// One or more content items produced by the tool.
pub content: OneOrMany<ToolResultContent>,
}
/// One piece of content inside a [`ToolResult`].
///
/// Serialized with an internal `type` tag holding the lowercased variant name
/// (`"text"` or `"image"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolResultContent {
/// Plain text.
Text(Text),
/// An image.
Image(Image),
}
/// A tool invocation requested in assistant content.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolCall {
/// Identifier of this tool call.
pub id: String,
/// Optional secondary call identifier.
pub call_id: Option<String>,
/// The function name and arguments to invoke.
pub function: ToolFunction,
/// Optional signature string.
// NOTE(review): semantics are provider-specific and not visible in this file; confirm before relying on it.
pub signature: Option<String>,
/// Optional extra JSON parameters attached to the call.
pub additional_params: Option<serde_json::Value>,
}
impl ToolCall {
    /// Creates a tool call with the given `id` and `function`.
    ///
    /// `call_id`, `signature` and `additional_params` all start out as `None`.
    pub fn new(id: String, function: ToolFunction) -> Self {
        Self {
            id,
            function,
            call_id: None,
            signature: None,
            additional_params: None,
        }
    }

    /// Sets `call_id` to `Some(call_id)`.
    pub fn with_call_id(self, call_id: String) -> Self {
        Self {
            call_id: Some(call_id),
            ..self
        }
    }

    /// Replaces the signature with `signature`, which may be `None`.
    pub fn with_signature(self, signature: Option<String>) -> Self {
        Self { signature, ..self }
    }

    /// Replaces the extra parameters with `additional_params`, which may be `None`.
    pub fn with_additional_params(self, additional_params: Option<serde_json::Value>) -> Self {
        Self {
            additional_params,
            ..self
        }
    }
}
/// The function part of a [`ToolCall`]: which function to call and with what arguments.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolFunction {
/// Name of the function to invoke.
pub name: String,
/// Arguments for the call, as an arbitrary JSON value.
pub arguments: serde_json::Value,
}
impl ToolFunction {
pub fn new(name: String, arguments: serde_json::Value) -> Self {
Self { name, arguments }
}
}
/// A plain-text content item.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Text {
/// The text itself.
pub text: String,
}
impl Text {
    /// Borrows the contained text as a string slice.
    pub fn text(&self) -> &str {
        self.text.as_str()
    }
}
impl std::fmt::Display for Text {
    /// Writes the contained text verbatim.
    ///
    /// Width, fill and precision flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.text)
    }
}
/// An image content item.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Image {
/// Where the image data comes from (URL, base64 string, raw bytes, ...).
pub data: DocumentSourceKind,
/// Image format, if known; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub media_type: Option<ImageMediaType>,
/// Requested detail level, if any; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<ImageDetail>,
/// Extra JSON parameters, flattened into the surrounding object when present.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub additional_params: Option<serde_json::Value>,
}
impl Image {
    /// Converts the image into a URL string.
    ///
    /// * `DocumentSourceKind::Url` is returned unchanged.
    /// * `DocumentSourceKind::Base64` becomes a `data:` URL of the form
    ///   `data:<mime type>;base64,<data>`.
    ///
    /// # Errors
    ///
    /// Returns [`MessageError::ConversionError`] when the data is base64 but no
    /// media type is set, or when the data is any other source kind.
    pub fn try_into_url(self) -> Result<String, MessageError> {
        match self.data {
            DocumentSourceKind::Url(url) => Ok(url),
            DocumentSourceKind::Base64(data) => {
                let Some(media_type) = self.media_type else {
                    return Err(MessageError::ConversionError(
                        "A media type is required to create a valid base64-encoded image URL"
                            .to_string(),
                    ));
                };
                // `to_mime_type` already returns the full `image/<subtype>` string,
                // so it must not be prefixed with `image/` a second time.
                Ok(format!(
                    "data:{ty};base64,{data}",
                    ty = media_type.to_mime_type()
                ))
            }
            unknown => Err(MessageError::ConversionError(format!(
                "Tried to convert unknown type to a URL: {unknown:?}"
            ))),
        }
    }
}
/// Where the payload of a media or document item comes from.
///
/// Serialized adjacently tagged: the camelCase variant name goes in `type` and
/// the payload in `value`. Defaults to [`DocumentSourceKind::Unknown`].
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
#[serde(tag = "type", content = "value", rename_all = "camelCase")]
#[non_exhaustive]
pub enum DocumentSourceKind {
/// A URL string.
Url(String),
/// A base64-encoded string.
Base64(String),
/// Raw bytes.
Raw(Vec<u8>),
/// A plain string payload.
String(String),
/// No known source; this is the default.
#[default]
Unknown,
}
impl DocumentSourceKind {
    /// Creates a `Url` source from a copy of `url`.
    pub fn url(url: &str) -> Self {
        Self::Url(String::from(url))
    }

    /// Creates a `Base64` source from a copy of `base64_string`.
    pub fn base64(base64_string: &str) -> Self {
        Self::Base64(String::from(base64_string))
    }

    /// Creates a `Raw` source from anything convertible into a byte vector.
    pub fn raw(bytes: impl Into<Vec<u8>>) -> Self {
        let bytes: Vec<u8> = bytes.into();
        Self::Raw(bytes)
    }

    /// Creates a `String` source from a copy of `input`.
    pub fn string(input: &str) -> Self {
        Self::String(String::from(input))
    }

    /// Creates the `Unknown` source.
    pub fn unknown() -> Self {
        Self::Unknown
    }

    /// Extracts the inner string of a `Url` or `Base64` source.
    ///
    /// Every other variant, including `String`, yields `None`.
    pub fn try_into_inner(self) -> Option<String> {
        if let Self::Url(inner) | Self::Base64(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl std::fmt::Display for DocumentSourceKind {
    /// Writes string payloads verbatim and uses fixed placeholders for raw
    /// bytes (`<binary data>`) and the unknown source (`<unknown>`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered: &str = match self {
            Self::Url(inner) | Self::Base64(inner) | Self::String(inner) => inner.as_str(),
            Self::Raw(_) => "<binary data>",
            Self::Unknown => "<unknown>",
        };
        f.write_str(rendered)
    }
}
/// An audio content item.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Audio {
/// Where the audio data comes from.
pub data: DocumentSourceKind,
/// Audio format, if known; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub media_type: Option<AudioMediaType>,
/// Extra JSON parameters, flattened into the surrounding object when present.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub additional_params: Option<serde_json::Value>,
}
/// A video content item.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Video {
/// Where the video data comes from.
pub data: DocumentSourceKind,
/// Video format, if known; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub media_type: Option<VideoMediaType>,
/// Extra JSON parameters, flattened into the surrounding object when present.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub additional_params: Option<serde_json::Value>,
}
/// A document content item.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Document {
/// Where the document data comes from.
pub data: DocumentSourceKind,
/// Document format, if known; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub media_type: Option<DocumentMediaType>,
/// Extra JSON parameters, flattened into the surrounding object when present.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub additional_params: Option<serde_json::Value>,
}
/// How a piece of content is encoded.
///
/// Serialized as the lowercased variant name; defaults to [`ContentFormat::Base64`].
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ContentFormat {
/// Base64-encoded data; this is the default.
#[default]
Base64,
/// A plain string.
String,
/// A URL.
Url,
}
/// A media type from any of the supported categories, wrapping the
/// category-specific enum.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum MediaType {
/// An image format.
Image(ImageMediaType),
/// An audio format.
Audio(AudioMediaType),
/// A document format.
Document(DocumentMediaType),
/// A video format.
Video(VideoMediaType),
}
/// Supported image formats.
///
/// Serialized as the lowercased variant name (for example `"jpeg"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageMediaType {
JPEG,
PNG,
GIF,
WEBP,
HEIC,
HEIF,
SVG,
}
/// Supported document formats.
///
/// Serialized as the lowercased variant name (for example `"pdf"` or `"javascript"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentMediaType {
PDF,
TXT,
RTF,
HTML,
CSS,
MARKDOWN,
CSV,
XML,
/// JavaScript source; reported as code by `is_code`.
Javascript,
/// Python source; reported as code by `is_code`.
Python,
}
impl DocumentMediaType {
    /// Returns `true` for the source-code formats (`Javascript`, `Python`) and
    /// `false` for every other document format.
    pub fn is_code(&self) -> bool {
        match self {
            Self::Javascript | Self::Python => true,
            Self::PDF
            | Self::TXT
            | Self::RTF
            | Self::HTML
            | Self::CSS
            | Self::MARKDOWN
            | Self::CSV
            | Self::XML => false,
        }
    }
}
/// Supported audio formats.
///
/// Serialized as the lowercased variant name (for example `"wav"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AudioMediaType {
WAV,
MP3,
AIFF,
AAC,
OGG,
FLAC,
}
/// Supported video formats.
///
/// Serialized as the lowercased variant name (for example `"mp4"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum VideoMediaType {
AVI,
MP4,
MPEG,
}
/// Detail level requested for an image.
///
/// Serialized as the lowercased variant name; defaults to [`ImageDetail::Auto`].
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
/// Low detail.
Low,
/// High detail.
High,
/// Automatic detail selection; this is the default.
#[default]
Auto,
}
impl Message {
    /// Returns a copy of the first text item of a user message.
    ///
    /// Yields `None` for assistant messages and for user messages that contain
    /// no text item.
    pub(crate) fn rag_text(&self) -> Option<String> {
        let Message::User { content } = self else {
            return None;
        };
        content.iter().find_map(|item| match item {
            UserContent::Text(Text { text }) => Some(text.clone()),
            _ => None,
        })
    }

    /// Creates a user message holding a single text item.
    pub fn user(text: impl Into<String>) -> Self {
        UserContent::text(text).into()
    }

    /// Creates an assistant message, without an id, holding a single text item.
    pub fn assistant(text: impl Into<String>) -> Self {
        AssistantContent::text(text).into()
    }

    /// Creates an assistant message with the given id, holding a single text item.
    pub fn assistant_with_id(id: String, text: impl Into<String>) -> Self {
        let content = OneOrMany::one(AssistantContent::text(text));
        Message::Assistant {
            id: Some(id),
            content,
        }
    }

    /// Creates a user message holding a single text-only tool result with no `call_id`.
    pub fn tool_result(id: impl Into<String>, content: impl Into<String>) -> Self {
        Self::tool_result_with_call_id(id, None, content)
    }

    /// Creates a user message holding a single text-only tool result with the
    /// given optional `call_id`.
    pub fn tool_result_with_call_id(
        id: impl Into<String>,
        call_id: Option<String>,
        content: impl Into<String>,
    ) -> Self {
        let tool_result = ToolResult {
            id: id.into(),
            call_id,
            content: OneOrMany::one(ToolResultContent::text(content)),
        };
        tool_result.into()
    }
}
impl UserContent {
pub fn text(text: impl Into<String>) -> Self {
UserContent::Text(text.into().into())
}
pub fn image_base64(
data: impl Into<String>,
media_type: Option<ImageMediaType>,
detail: Option<ImageDetail>,
) -> Self {
UserContent::Image(Image {
data: DocumentSourceKind::Base64(data.into()),
media_type,
detail,
additional_params: None,
})
}
pub fn image_raw(
data: impl Into<Vec<u8>>,
media_type: Option<ImageMediaType>,
detail: Option<ImageDetail>,
) -> Self {
UserContent::Image(Image {
data: DocumentSourceKind::Raw(data.into()),
media_type,
detail,
..Default::default()
})
}
pub fn image_url(
url: impl Into<String>,
media_type: Option<ImageMediaType>,
detail: Option<ImageDetail>,
) -> Self {
UserContent::Image(Image {
data: DocumentSourceKind::Url(url.into()),
media_type,
detail,
additional_params: None,
})
}
pub fn audio(data: impl Into<String>, media_type: Option<AudioMediaType>) -> Self {
UserContent::Audio(Audio {
data: DocumentSourceKind::Base64(data.into()),
media_type,
additional_params: None,
})
}
pub fn audio_raw(data: impl Into<Vec<u8>>, media_type: Option<AudioMediaType>) -> Self {
UserContent::Audio(Audio {
data: DocumentSourceKind::Raw(data.into()),
media_type,
..Default::default()
})
}
pub fn audio_url(url: impl Into<String>, media_type: Option<AudioMediaType>) -> Self {
UserContent::Audio(Audio {
data: DocumentSourceKind::Url(url.into()),
media_type,
..Default::default()
})
}
pub fn document(data: impl Into<String>, media_type: Option<DocumentMediaType>) -> Self {
let data: String = data.into();
UserContent::Document(Document {
data: DocumentSourceKind::string(&data),
media_type,
additional_params: None,
})
}
pub fn document_raw(data: impl Into<Vec<u8>>, media_type: Option<DocumentMediaType>) -> Self {
UserContent::Document(Document {
data: DocumentSourceKind::Raw(data.into()),
media_type,
..Default::default()
})
}
pub fn document_url(url: impl Into<String>, media_type: Option<DocumentMediaType>) -> Self {
UserContent::Document(Document {
data: DocumentSourceKind::Url(url.into()),
media_type,
..Default::default()
})
}
pub fn tool_result(id: impl Into<String>, content: OneOrMany<ToolResultContent>) -> Self {
UserContent::ToolResult(ToolResult {
id: id.into(),
call_id: None,
content,
})
}
pub fn tool_result_with_call_id(
id: impl Into<String>,
call_id: String,
content: OneOrMany<ToolResultContent>,
) -> Self {
UserContent::ToolResult(ToolResult {
id: id.into(),
call_id: Some(call_id),
content,
})
}
}
impl AssistantContent {
pub fn text(text: impl Into<String>) -> Self {
AssistantContent::Text(text.into().into())
}
pub fn image_base64(
data: impl Into<String>,
media_type: Option<ImageMediaType>,
detail: Option<ImageDetail>,
) -> Self {
AssistantContent::Image(Image {
data: DocumentSourceKind::Base64(data.into()),
media_type,
detail,
additional_params: None,
})
}
pub fn tool_call(
id: impl Into<String>,
name: impl Into<String>,
arguments: serde_json::Value,
) -> Self {
AssistantContent::ToolCall(ToolCall::new(
id.into(),
ToolFunction {
name: name.into(),
arguments,
},
))
}
pub fn tool_call_with_call_id(
id: impl Into<String>,
call_id: String,
name: impl Into<String>,
arguments: serde_json::Value,
) -> Self {
AssistantContent::ToolCall(
ToolCall::new(
id.into(),
ToolFunction {
name: name.into(),
arguments,
},
)
.with_call_id(call_id),
)
}
pub fn reasoning(reasoning: impl AsRef<str>) -> Self {
AssistantContent::Reasoning(Reasoning::new(reasoning.as_ref()))
}
}
impl ToolResultContent {
pub fn text(text: impl Into<String>) -> Self {
ToolResultContent::Text(text.into().into())
}
pub fn image_base64(
data: impl Into<String>,
media_type: Option<ImageMediaType>,
detail: Option<ImageDetail>,
) -> Self {
ToolResultContent::Image(Image {
data: DocumentSourceKind::Base64(data.into()),
media_type,
detail,
additional_params: None,
})
}
pub fn image_raw(
data: impl Into<Vec<u8>>,
media_type: Option<ImageMediaType>,
detail: Option<ImageDetail>,
) -> Self {
ToolResultContent::Image(Image {
data: DocumentSourceKind::Raw(data.into()),
media_type,
detail,
..Default::default()
})
}
pub fn image_url(
url: impl Into<String>,
media_type: Option<ImageMediaType>,
detail: Option<ImageDetail>,
) -> Self {
ToolResultContent::Image(Image {
data: DocumentSourceKind::Url(url.into()),
media_type,
detail,
additional_params: None,
})
}
}
/// Conversion between a media-type enum and its MIME type string.
pub trait MimeType {
/// Parses a MIME type string, returning `None` when it is not recognized.
fn from_mime_type(mime_type: &str) -> Option<Self>
where
Self: Sized;
/// Returns the MIME type string for this value.
fn to_mime_type(&self) -> &'static str;
}
impl MimeType for MediaType {
    /// Tries each category in turn (image, document, audio, video) and wraps
    /// the first match; returns `None` when no category recognizes the string.
    fn from_mime_type(mime_type: &str) -> Option<Self> {
        if let Some(image) = ImageMediaType::from_mime_type(mime_type) {
            return Some(Self::Image(image));
        }
        if let Some(document) = DocumentMediaType::from_mime_type(mime_type) {
            return Some(Self::Document(document));
        }
        if let Some(audio) = AudioMediaType::from_mime_type(mime_type) {
            return Some(Self::Audio(audio));
        }
        VideoMediaType::from_mime_type(mime_type).map(Self::Video)
    }

    /// Delegates to the wrapped category-specific media type.
    fn to_mime_type(&self) -> &'static str {
        match self {
            Self::Image(inner) => inner.to_mime_type(),
            Self::Audio(inner) => inner.to_mime_type(),
            Self::Document(inner) => inner.to_mime_type(),
            Self::Video(inner) => inner.to_mime_type(),
        }
    }
}
impl MimeType for ImageMediaType {
    /// Parses an exact, case-sensitive image MIME type; anything else is `None`.
    fn from_mime_type(mime_type: &str) -> Option<Self> {
        let parsed = match mime_type {
            "image/jpeg" => Self::JPEG,
            "image/png" => Self::PNG,
            "image/gif" => Self::GIF,
            "image/webp" => Self::WEBP,
            "image/heic" => Self::HEIC,
            "image/heif" => Self::HEIF,
            "image/svg+xml" => Self::SVG,
            _ => return None,
        };
        Some(parsed)
    }

    /// Returns the full MIME type, including the `image/` prefix.
    fn to_mime_type(&self) -> &'static str {
        match self {
            Self::JPEG => "image/jpeg",
            Self::PNG => "image/png",
            Self::GIF => "image/gif",
            Self::WEBP => "image/webp",
            Self::HEIC => "image/heic",
            Self::HEIF => "image/heif",
            Self::SVG => "image/svg+xml",
        }
    }
}
impl MimeType for DocumentMediaType {
    /// Parses an exact, case-sensitive document MIME type; anything else is `None`.
    ///
    /// Markdown, JavaScript and Python each accept two spellings.
    fn from_mime_type(mime_type: &str) -> Option<Self> {
        Some(match mime_type {
            "application/pdf" => Self::PDF,
            "text/plain" => Self::TXT,
            "text/rtf" => Self::RTF,
            "text/html" => Self::HTML,
            "text/css" => Self::CSS,
            "text/md" | "text/markdown" => Self::MARKDOWN,
            "text/csv" => Self::CSV,
            "text/xml" => Self::XML,
            "application/x-javascript" | "text/x-javascript" => Self::Javascript,
            "application/x-python" | "text/x-python" => Self::Python,
            _ => return None,
        })
    }

    /// Returns the MIME type emitted for each format; where parsing accepts two
    /// spellings, a single fixed one is produced.
    fn to_mime_type(&self) -> &'static str {
        match self {
            Self::PDF => "application/pdf",
            Self::TXT => "text/plain",
            Self::RTF => "text/rtf",
            Self::HTML => "text/html",
            Self::CSS => "text/css",
            Self::MARKDOWN => "text/markdown",
            Self::CSV => "text/csv",
            Self::XML => "text/xml",
            Self::Javascript => "application/x-javascript",
            Self::Python => "application/x-python",
        }
    }
}
impl MimeType for AudioMediaType {
    /// Parses an exact, case-sensitive audio MIME type; anything else is `None`.
    ///
    /// MP3 is accepted both as `audio/mp3` and as the IANA-registered
    /// `audio/mpeg`, so standard-compliant inputs are not rejected.
    fn from_mime_type(mime_type: &str) -> Option<Self> {
        match mime_type {
            "audio/wav" => Some(AudioMediaType::WAV),
            // `audio/mpeg` is the registered MIME type for MP3; `audio/mp3`
            // remains accepted and is still what `to_mime_type` emits.
            "audio/mp3" | "audio/mpeg" => Some(AudioMediaType::MP3),
            "audio/aiff" => Some(AudioMediaType::AIFF),
            "audio/aac" => Some(AudioMediaType::AAC),
            "audio/ogg" => Some(AudioMediaType::OGG),
            "audio/flac" => Some(AudioMediaType::FLAC),
            _ => None,
        }
    }

    /// Returns the MIME type emitted for each format.
    ///
    /// MP3 is emitted as `audio/mp3`, unchanged for backward compatibility.
    fn to_mime_type(&self) -> &'static str {
        match self {
            AudioMediaType::WAV => "audio/wav",
            AudioMediaType::MP3 => "audio/mp3",
            AudioMediaType::AIFF => "audio/aiff",
            AudioMediaType::AAC => "audio/aac",
            AudioMediaType::OGG => "audio/ogg",
            AudioMediaType::FLAC => "audio/flac",
        }
    }
}
impl MimeType for VideoMediaType {
    /// Parses an exact, case-sensitive video MIME type; anything else is `None`.
    fn from_mime_type(mime_type: &str) -> Option<Self> {
        let parsed = match mime_type {
            "video/avi" => Self::AVI,
            "video/mp4" => Self::MP4,
            "video/mpeg" => Self::MPEG,
            _ => return None,
        };
        Some(parsed)
    }

    /// Returns the full MIME type, including the `video/` prefix.
    fn to_mime_type(&self) -> &'static str {
        match self {
            Self::AVI => "video/avi",
            Self::MP4 => "video/mp4",
            Self::MPEG => "video/mpeg",
        }
    }
}
impl std::str::FromStr for ImageDetail {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"low" => Ok(ImageDetail::Low),
"high" => Ok(ImageDetail::High),
"auto" => Ok(ImageDetail::Auto),
_ => Err(()),
}
}
}
impl From<String> for Text {
fn from(text: String) -> Self {
Text { text }
}
}
impl From<&String> for Text {
fn from(text: &String) -> Self {
text.to_owned().into()
}
}
impl From<&str> for Text {
fn from(text: &str) -> Self {
text.to_owned().into()
}
}
impl FromStr for Text {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(s.into())
}
}
impl From<String> for Message {
    /// Turns an owned string into a user message with a single text item.
    fn from(text: String) -> Self {
        Message::user(text)
    }
}
impl From<&str> for Message {
    /// Turns a string slice into a user message with a single text item.
    fn from(text: &str) -> Self {
        Message::user(text)
    }
}
impl From<&String> for Message {
    /// Turns a borrowed string into a user message with a single text item.
    fn from(text: &String) -> Self {
        Message::user(text.as_str())
    }
}
impl From<Text> for Message {
    /// Wraps a [`Text`] as the single item of a user message.
    fn from(text: Text) -> Self {
        UserContent::Text(text).into()
    }
}
impl From<Image> for Message {
    /// Wraps an [`Image`] as the single item of a user message.
    fn from(image: Image) -> Self {
        UserContent::Image(image).into()
    }
}
impl From<Audio> for Message {
    /// Wraps an [`Audio`] as the single item of a user message.
    fn from(audio: Audio) -> Self {
        UserContent::Audio(audio).into()
    }
}
impl From<Document> for Message {
    /// Wraps a [`Document`] as the single item of a user message.
    fn from(document: Document) -> Self {
        UserContent::Document(document).into()
    }
}
impl From<String> for ToolResultContent {
fn from(text: String) -> Self {
ToolResultContent::text(text)
}
}
impl From<String> for AssistantContent {
fn from(text: String) -> Self {
AssistantContent::text(text)
}
}
impl From<String> for UserContent {
fn from(text: String) -> Self {
UserContent::text(text)
}
}
impl From<AssistantContent> for Message {
fn from(content: AssistantContent) -> Self {
Message::Assistant {
id: None,
content: OneOrMany::one(content),
}
}
}
impl From<UserContent> for Message {
fn from(content: UserContent) -> Self {
Message::User {
content: OneOrMany::one(content),
}
}
}
impl From<OneOrMany<AssistantContent>> for Message {
fn from(content: OneOrMany<AssistantContent>) -> Self {
Message::Assistant { id: None, content }
}
}
impl From<OneOrMany<UserContent>> for Message {
fn from(content: OneOrMany<UserContent>) -> Self {
Message::User { content }
}
}
impl From<ToolCall> for Message {
    /// Wraps a tool call as the single item of an assistant message without an id.
    fn from(tool_call: ToolCall) -> Self {
        AssistantContent::ToolCall(tool_call).into()
    }
}
impl From<ToolResult> for Message {
    /// Wraps a tool result as the single item of a user message.
    fn from(tool_result: ToolResult) -> Self {
        UserContent::ToolResult(tool_result).into()
    }
}
impl From<ToolResultContent> for Message {
    /// Wraps a single tool-result content item in a user message.
    ///
    /// The enclosing [`ToolResult`] gets an empty `id` and no `call_id`.
    fn from(tool_result_content: ToolResultContent) -> Self {
        let tool_result = ToolResult {
            content: OneOrMany::one(tool_result_content),
            id: String::new(),
            call_id: None,
        };
        tool_result.into()
    }
}
/// How tool use is constrained for a request.
///
/// Serialized with snake_case variant names; defaults to [`ToolChoice::Auto`].
// NOTE(review): the per-variant descriptions below follow the variant names; how each provider honours them is not visible in this file.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
/// Automatic choice; this is the default.
#[default]
Auto,
/// No tool use.
None,
/// Tool use is required.
Required,
/// Restricts the choice to specific functions.
Specific {
/// Names of the functions to choose from.
function_names: Vec<String>,
},
}
/// Errors raised while working with messages.
#[derive(Debug, Error)]
pub enum MessageError {
/// A conversion failed; the string describes why.
/// Displayed as `Message conversion error: <description>`.
#[error("Message conversion error: {0}")]
ConversionError(String),
}
impl From<MessageError> for CompletionError {
fn from(error: MessageError) -> Self {
CompletionError::RequestError(error.into())
}
}