use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::LlmError;
use crate::types::{ContentPart, MessageContent};
/// Turns [`MessageContent`] into [`ProcessedContent`], attaching format metadata to media parts.
// The format lists and `config` are filled in by `new()` but nothing in this file reads them yet,
// which is why the `dead_code` lint is silenced here.
#[allow(dead_code)]
pub struct MultimodalProcessor {
    /// Image formats registered at construction (not consulted by processing code yet).
    image_formats: Vec<ImageFormat>,
    /// Audio formats registered at construction (not consulted by processing code yet).
    audio_formats: Vec<AudioFormat>,
    /// Document formats registered at construction (not consulted by processing code yet).
    document_formats: Vec<DocumentFormat>,
    /// Processing limits and compression settings (not consulted by processing code yet).
    config: ProcessingConfig,
}
impl MultimodalProcessor {
pub fn new() -> Self {
Self {
image_formats: ImageFormat::all_supported(),
audio_formats: AudioFormat::all_supported(),
document_formats: DocumentFormat::all_supported(),
config: ProcessingConfig::default(),
}
}
pub fn process_content(&self, content: &MessageContent) -> Result<ProcessedContent, LlmError> {
match content {
MessageContent::Text(text) => Ok(ProcessedContent::Text(text.clone())),
MessageContent::MultiModal(parts) => {
let mut processed_parts = Vec::new();
for part in parts {
let processed_part = self.process_content_part(part)?;
processed_parts.push(processed_part);
}
Ok(ProcessedContent::MultiModal(processed_parts))
}
}
}
fn process_content_part(&self, part: &ContentPart) -> Result<ProcessedContentPart, LlmError> {
match part {
ContentPart::Text { text } => Ok(ProcessedContentPart::Text {
text: text.clone(),
metadata: ContentMetadata::default(),
}),
ContentPart::Image { image_url, detail } => {
let image_info = self.analyze_image(image_url)?;
let format = match image_info.format {
MediaFormat::Image(fmt) => fmt,
_ => ImageFormat::Jpeg, };
Ok(ProcessedContentPart::Image {
data: image_url.clone(),
format,
detail: detail.clone(),
metadata: image_info.metadata,
})
}
ContentPart::Audio { audio_url, format } => {
let audio_info = self.analyze_audio(audio_url, Some(format.as_str()))?;
let format = match audio_info.format {
MediaFormat::Audio(fmt) => fmt,
_ => AudioFormat::Wav, };
Ok(ProcessedContentPart::Audio {
data: audio_url.clone(),
format,
metadata: audio_info.metadata,
})
}
}
}
fn analyze_image(&self, image_data: &str) -> Result<MediaInfo, LlmError> {
let format = if image_data.starts_with("data:image/") {
let mime_type = image_data
.split(';')
.next()
.and_then(|s| s.strip_prefix("data:"))
.unwrap_or("image/jpeg");
ImageFormat::from_mime_type(mime_type)
} else {
ImageFormat::detect_from_base64(image_data)
};
let mut metadata = ContentMetadata::default();
metadata.insert(
"original_format".to_string(),
serde_json::Value::String(format.to_string()),
);
if let Ok(size) = self.estimate_data_size(image_data) {
metadata.insert(
"estimated_size_bytes".to_string(),
serde_json::Value::Number(size.into()),
);
}
Ok(MediaInfo {
format: MediaFormat::Image(format),
metadata,
})
}
fn analyze_audio(
&self,
audio_data: &str,
format_hint: Option<&str>,
) -> Result<MediaInfo, LlmError> {
let format = if let Some(hint) = format_hint {
AudioFormat::from_extension(hint)
} else if audio_data.starts_with("data:audio/") {
let mime_type = audio_data
.split(';')
.next()
.and_then(|s| s.strip_prefix("data:"))
.unwrap_or("audio/wav");
AudioFormat::from_mime_type(mime_type)
} else {
AudioFormat::Wav };
let mut metadata = ContentMetadata::default();
metadata.insert(
"original_format".to_string(),
serde_json::Value::String(format.to_string()),
);
if let Ok(size) = self.estimate_data_size(audio_data) {
metadata.insert(
"estimated_size_bytes".to_string(),
serde_json::Value::Number(size.into()),
);
}
Ok(MediaInfo {
format: MediaFormat::Audio(format),
metadata,
})
}
fn estimate_data_size(&self, data: &str) -> Result<u64, LlmError> {
let base64_data = if data.contains(',') {
data.split(',').nth(1).unwrap_or(data)
} else {
data
};
let base64_len = base64_data.len() as u64;
Ok((base64_len * 3) / 4)
}
}
impl Default for MultimodalProcessor {
fn default() -> Self {
Self::new()
}
}
/// Limits and compression settings for multimodal processing.
///
/// NOTE(review): no code in this file reads these fields yet; the units below are inferred from
/// the default values and should be confirmed with the eventual consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessingConfig {
    /// Maximum image size, presumably in bytes (default is 20 MiB); `None` presumably means no limit.
    pub max_image_size: Option<u64>,
    /// Maximum audio duration, presumably in seconds (default 300); `None` presumably means no limit.
    pub max_audio_duration: Option<u32>,
    /// Whether media should be compressed automatically (default `true`).
    pub auto_compress: bool,
    /// Compression quality, presumably on a 0.0–1.0 scale (default 0.8).
    pub compression_quality: f32,
}

impl Default for ProcessingConfig {
    fn default() -> Self {
        const TWENTY_MIB: u64 = 20 * 1024 * 1024;
        ProcessingConfig {
            compression_quality: 0.8,
            auto_compress: true,
            max_audio_duration: Some(300),
            max_image_size: Some(TWENTY_MIB),
        }
    }
}
/// Output of [`MultimodalProcessor::process_content`]: one variant per `MessageContent` shape.
#[derive(Debug, Clone)]
pub enum ProcessedContent {
    /// Plain text, passed through unchanged.
    Text(String),
    /// Processed parts, in the same order as the input parts.
    MultiModal(Vec<ProcessedContentPart>),
}
/// A single processed content part together with its detected format and metadata.
#[derive(Debug, Clone)]
pub enum ProcessedContentPart {
    /// Text part; its metadata map is produced empty by the processor.
    Text {
        text: String,
        metadata: ContentMetadata,
    },
    /// Image part; `data` is the original URL/base64 string.
    Image {
        data: String,
        format: ImageFormat,
        /// Detail hint copied from the input part.
        detail: Option<String>,
        metadata: ContentMetadata,
    },
    /// Audio part; `data` is the original URL/base64 string.
    Audio {
        data: String,
        format: AudioFormat,
        metadata: ContentMetadata,
    },
    /// Document part (not produced by `MultimodalProcessor` in this file).
    Document {
        data: String,
        format: DocumentFormat,
        metadata: ContentMetadata,
    },
}
/// Free-form per-part metadata, keyed by name with JSON values (e.g. `original_format`).
pub type ContentMetadata = HashMap<String, serde_json::Value>;
/// Result of analyzing a media payload: its detected format plus the metadata gathered for it.
#[derive(Debug, Clone)]
pub struct MediaInfo {
    /// Detected media format.
    pub format: MediaFormat,
    /// Metadata collected during analysis.
    pub metadata: ContentMetadata,
}
/// Format of a media payload, tagged by media kind.
#[derive(Debug, Clone)]
pub enum MediaFormat {
    /// An image encoding.
    Image(ImageFormat),
    /// An audio encoding.
    Audio(AudioFormat),
    /// A document type (not produced by the analysis code in this file).
    Document(DocumentFormat),
}
/// Image encodings the multimodal pipeline can label.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ImageFormat {
    Jpeg,
    Png,
    Gif,
    WebP,
    Bmp,
    Tiff,
    Svg,
}

impl ImageFormat {
    /// Returns every [`ImageFormat`] variant, in declaration order.
    pub fn all_supported() -> Vec<Self> {
        vec![
            Self::Jpeg,
            Self::Png,
            Self::Gif,
            Self::WebP,
            Self::Bmp,
            Self::Tiff,
            Self::Svg,
        ]
    }

    /// Maps a MIME type such as `image/png` to an [`ImageFormat`].
    ///
    /// Matching ignores ASCII case and surrounding whitespace, since MIME type and subtype names
    /// are case-insensitive. Unrecognized types fall back to [`ImageFormat::Jpeg`].
    pub fn from_mime_type(mime_type: &str) -> Self {
        match mime_type.trim().to_ascii_lowercase().as_str() {
            "image/jpeg" | "image/jpg" => Self::Jpeg,
            "image/png" => Self::Png,
            "image/gif" => Self::Gif,
            "image/webp" => Self::WebP,
            "image/bmp" => Self::Bmp,
            "image/tiff" => Self::Tiff,
            "image/svg+xml" => Self::Svg,
            _ => Self::Jpeg,
        }
    }

    /// Guesses the format of raw base64 image data from the encoding of its leading magic bytes.
    ///
    /// Recognized signatures:
    /// - JPEG: `FF D8 FF`, encoded as `/9j/`.
    /// - PNG: `89 'P' 'N' 'G'`, encoded as `iVBOR`.
    /// - GIF: `"GIF8"`, encoded as `R0lGOD`.
    /// - WebP: the RIFF container header `"RIFF"`, encoded as `UklGR`.
    ///
    /// Anything else falls back to [`ImageFormat::Jpeg`], including strings that are not base64.
    pub fn detect_from_base64(data: &str) -> Self {
        // Each prefix is tested on its own. Previously the JPEG branch also matched "iVBOR",
        // so PNG data was reported as JPEG and the PNG branch was unreachable.
        if data.starts_with("/9j/") {
            Self::Jpeg
        } else if data.starts_with("iVBOR") {
            Self::Png
        } else if data.starts_with("R0lGOD") {
            Self::Gif
        } else if data.starts_with("UklGR") {
            Self::WebP
        } else {
            Self::Jpeg
        }
    }

    /// Returns the canonical MIME type for this format.
    pub const fn mime_type(&self) -> &'static str {
        match self {
            Self::Jpeg => "image/jpeg",
            Self::Png => "image/png",
            Self::Gif => "image/gif",
            Self::WebP => "image/webp",
            Self::Bmp => "image/bmp",
            Self::Tiff => "image/tiff",
            Self::Svg => "image/svg+xml",
        }
    }
}

impl std::fmt::Display for ImageFormat {
    /// Renders the variant name (e.g. `Png`), reusing the derived `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}
/// Audio encodings the multimodal pipeline can label.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AudioFormat {
    Mp3,
    Wav,
    Flac,
    Ogg,
    M4a,
    Webm,
}

impl AudioFormat {
    /// Returns every [`AudioFormat`] variant, in declaration order.
    pub fn all_supported() -> Vec<Self> {
        vec![
            Self::Mp3,
            Self::Wav,
            Self::Flac,
            Self::Ogg,
            Self::M4a,
            Self::Webm,
        ]
    }

    /// Maps a MIME type such as `audio/mpeg` to an [`AudioFormat`].
    ///
    /// Matching ignores ASCII case and surrounding whitespace, since MIME type and subtype names
    /// are case-insensitive. Widely used aliases are accepted: `audio/x-wav` and `audio/vnd.wave`
    /// for WAV, `audio/mp4` and `audio/x-m4a` for M4A, and `audio/x-flac` for FLAC. Unrecognized
    /// types fall back to [`AudioFormat::Wav`].
    pub fn from_mime_type(mime_type: &str) -> Self {
        match mime_type.trim().to_ascii_lowercase().as_str() {
            "audio/mpeg" | "audio/mp3" => Self::Mp3,
            "audio/wav" | "audio/wave" | "audio/x-wav" | "audio/vnd.wave" => Self::Wav,
            "audio/flac" | "audio/x-flac" => Self::Flac,
            "audio/ogg" => Self::Ogg,
            "audio/m4a" | "audio/mp4" | "audio/x-m4a" => Self::M4a,
            "audio/webm" => Self::Webm,
            _ => Self::Wav,
        }
    }

    /// Maps a file extension such as `mp3` or `.MP3` to an [`AudioFormat`].
    ///
    /// Matching is case-insensitive and tolerates a leading dot. Unrecognized extensions fall back
    /// to [`AudioFormat::Wav`].
    pub fn from_extension(ext: &str) -> Self {
        match ext.trim().trim_start_matches('.').to_lowercase().as_str() {
            "mp3" => Self::Mp3,
            "wav" | "wave" => Self::Wav,
            "flac" => Self::Flac,
            "ogg" | "oga" => Self::Ogg,
            "m4a" => Self::M4a,
            "webm" => Self::Webm,
            _ => Self::Wav,
        }
    }

    /// Returns the MIME type this crate emits for the format.
    pub const fn mime_type(&self) -> &'static str {
        match self {
            Self::Mp3 => "audio/mpeg",
            Self::Wav => "audio/wav",
            Self::Flac => "audio/flac",
            Self::Ogg => "audio/ogg",
            Self::M4a => "audio/m4a",
            Self::Webm => "audio/webm",
        }
    }
}

impl std::fmt::Display for AudioFormat {
    /// Renders the variant name (e.g. `Mp3`), reusing the derived `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}
/// Document types the multimodal pipeline can label.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DocumentFormat {
    Pdf,
    Docx,
    Txt,
    Md,
    Html,
    Csv,
    Json,
    Xml,
}

impl DocumentFormat {
    /// Returns every [`DocumentFormat`] variant, in declaration order.
    pub fn all_supported() -> Vec<Self> {
        vec![
            Self::Pdf,
            Self::Docx,
            Self::Txt,
            Self::Md,
            Self::Html,
            Self::Csv,
            Self::Json,
            Self::Xml,
        ]
    }

    /// Maps a MIME type such as `application/pdf` to a [`DocumentFormat`].
    ///
    /// Matching ignores ASCII case and surrounding whitespace, since MIME type and subtype names
    /// are case-insensitive. `text/x-markdown` is accepted as an alias for Markdown. Unrecognized
    /// types fall back to [`DocumentFormat::Txt`].
    pub fn from_mime_type(mime_type: &str) -> Self {
        match mime_type.trim().to_ascii_lowercase().as_str() {
            "application/pdf" => Self::Pdf,
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document" => Self::Docx,
            "text/plain" => Self::Txt,
            "text/markdown" | "text/x-markdown" => Self::Md,
            "text/html" => Self::Html,
            "text/csv" => Self::Csv,
            "application/json" => Self::Json,
            "application/xml" | "text/xml" => Self::Xml,
            _ => Self::Txt,
        }
    }

    /// Maps a file extension such as `pdf` or `.MD` to a [`DocumentFormat`].
    ///
    /// Mirrors [`AudioFormat::from_extension`]: matching is case-insensitive and tolerates a
    /// leading dot. Unrecognized extensions fall back to [`DocumentFormat::Txt`].
    pub fn from_extension(ext: &str) -> Self {
        match ext.trim().trim_start_matches('.').to_lowercase().as_str() {
            "pdf" => Self::Pdf,
            "docx" => Self::Docx,
            "txt" | "text" => Self::Txt,
            "md" | "markdown" => Self::Md,
            "html" | "htm" => Self::Html,
            "csv" => Self::Csv,
            "json" => Self::Json,
            "xml" => Self::Xml,
            _ => Self::Txt,
        }
    }

    /// Returns the canonical MIME type for this format.
    pub const fn mime_type(&self) -> &'static str {
        match self {
            Self::Pdf => "application/pdf",
            Self::Docx => "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            Self::Txt => "text/plain",
            Self::Md => "text/markdown",
            Self::Html => "text/html",
            Self::Csv => "text/csv",
            Self::Json => "application/json",
            Self::Xml => "application/xml",
        }
    }
}

impl std::fmt::Display for DocumentFormat {
    /// Renders the variant name (e.g. `Pdf`), reusing the derived `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_format_detection() {
        let expectations = [
            ("image/jpeg", ImageFormat::Jpeg),
            ("image/png", ImageFormat::Png),
            ("image/gif", ImageFormat::Gif),
        ];
        for (mime, expected) in expectations.iter() {
            assert_eq!(ImageFormat::from_mime_type(mime), *expected);
        }
    }

    #[test]
    fn test_audio_format_detection() {
        assert_eq!(AudioFormat::from_mime_type("audio/mpeg"), AudioFormat::Mp3);
        let by_extension = [("wav", AudioFormat::Wav), ("flac", AudioFormat::Flac)];
        for (ext, expected) in by_extension.iter() {
            assert_eq!(AudioFormat::from_extension(ext), *expected);
        }
    }

    #[test]
    fn test_multimodal_processor() {
        let processor = MultimodalProcessor::new();
        let input = MessageContent::Text("Hello world".to_string());
        let output = processor.process_content(&input).unwrap();
        if let ProcessedContent::Text(text) = output {
            assert_eq!(text, "Hello world");
        } else {
            panic!("Expected text content");
        }
    }
}