use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use crate::protocol::v2::manifest::MultimodalConfig;
/// A kind of content that can appear in a request or a response.
///
/// Variants are (de)serialized by serde under their snake_case names
/// (`"text"`, `"image"`, `"audio"`, `"video"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Modality {
/// Textual content.
Text,
/// Image content.
Image,
/// Audio content.
Audio,
/// Video content.
Video,
}
impl Modality {
pub fn as_str(&self) -> &'static str {
match self {
Self::Text => "text",
Self::Image => "image",
Self::Audio => "audio",
Self::Video => "video",
}
}
}
impl std::str::FromStr for Modality {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"text" => Ok(Self::Text),
"image" => Ok(Self::Image),
"audio" => Ok(Self::Audio),
"video" => Ok(Self::Video),
_ => Err(format!("Unknown modality: {}", s)),
}
}
}
/// Flattened summary of what a model can accept and produce, built from a
/// `MultimodalConfig` by `MultimodalCapabilities::from_config`.
#[derive(Debug, Clone)]
pub struct MultimodalCapabilities {
/// Modalities accepted as input. `from_config` always includes `Modality::Text`.
pub input_modalities: HashSet<Modality>,
/// Modalities that can be produced. `from_config` always includes `Modality::Text`.
pub output_modalities: HashSet<Modality>,
/// Image formats listed by the vision input config. An empty list makes
/// `validate_image_format` accept every format.
pub image_formats: Vec<String>,
/// Audio formats listed by the audio input config. An empty list makes
/// `validate_audio_format` accept every format.
pub audio_formats: Vec<String>,
/// Video formats listed by the video input config, or by the video output
/// config when the input side lists none. An empty list makes
/// `validate_video_format` accept every format.
pub video_formats: Vec<String>,
/// Copy of the vision input config's `max_file_size`, when vision input is supported.
pub max_image_size: Option<String>,
/// Maximum audio duration as a string, when one is available.
/// NOTE(review): the expected format/units are not defined in this file.
pub max_audio_duration: Option<String>,
/// `supported` flag of the config's `omni_mode`; `false` when `omni_mode` is absent.
pub supports_omni: bool,
/// `real_time_voice_chat` flag of the config's `omni_mode`; `false` when `omni_mode` is absent.
pub supports_realtime_voice: bool,
}
impl MultimodalCapabilities {
    /// Builds the capability summary from a manifest's multimodal section.
    ///
    /// * `Modality::Text` is always registered for both input and output.
    /// * An input/output sub-section only contributes when it is present
    ///   **and** its `supported` flag is `true`.
    /// * `video_formats` comes from the video input section; if that leaves
    ///   the list empty, the video output section's formats are used instead.
    /// * `max_audio_duration` is left as `None`: the audio input section
    ///   exposes no duration limit, and an empty-string placeholder would be
    ///   indistinguishable from a real (but malformed) value.
    /// * `supports_omni` / `supports_realtime_voice` default to `false` when
    ///   `omni_mode` is absent.
    pub fn from_config(config: &MultimodalConfig) -> Self {
        let mut input_modalities = HashSet::new();
        let mut output_modalities = HashSet::new();
        let mut image_formats = Vec::new();
        let mut audio_formats = Vec::new();
        let mut video_formats = Vec::new();
        let mut max_image_size = None;

        // Text is the baseline modality in both directions.
        input_modalities.insert(Modality::Text);
        output_modalities.insert(Modality::Text);

        if let Some(input) = &config.input {
            if let Some(vision) = input.vision.as_ref().filter(|v| v.supported) {
                input_modalities.insert(Modality::Image);
                image_formats = vision.formats.clone();
                max_image_size = vision.max_file_size.clone();
            }
            if let Some(audio) = input.audio.as_ref().filter(|a| a.supported) {
                input_modalities.insert(Modality::Audio);
                audio_formats = audio.formats.clone();
            }
            if let Some(video) = input.video.as_ref().filter(|v| v.supported) {
                input_modalities.insert(Modality::Video);
                video_formats = video.formats.clone();
            }
        }

        if let Some(output) = &config.output {
            if matches!(&output.audio, Some(audio_out) if audio_out.supported) {
                output_modalities.insert(Modality::Audio);
            }
            if matches!(&output.image, Some(image_out) if image_out.supported) {
                output_modalities.insert(Modality::Image);
            }
            if let Some(video_out) = output.video.as_ref().filter(|v| v.supported) {
                output_modalities.insert(Modality::Video);
                // Only fall back to the output formats when the input side
                // did not declare any.
                if video_formats.is_empty() {
                    video_formats = video_out.formats.clone();
                }
            }
        }

        let omni = config.omni_mode.as_ref();
        let supports_omni = omni.map(|o| o.supported).unwrap_or(false);
        let supports_realtime_voice = omni.map(|o| o.real_time_voice_chat).unwrap_or(false);

        Self {
            input_modalities,
            output_modalities,
            image_formats,
            audio_formats,
            video_formats,
            max_image_size,
            max_audio_duration: None,
            supports_omni,
            supports_realtime_voice,
        }
    }

    /// Returns `true` if `modality` is accepted as input.
    pub fn supports_input(&self, modality: Modality) -> bool {
        self.input_modalities.contains(&modality)
    }

    /// Returns `true` if `modality` can be produced as output.
    pub fn supports_output(&self, modality: Modality) -> bool {
        self.output_modalities.contains(&modality)
    }

    /// Returns `true` if `format` is an allowed image format
    /// (ASCII case-insensitive). An empty `image_formats` list allows everything.
    pub fn validate_image_format(&self, format: &str) -> bool {
        Self::format_allowed(&self.image_formats, format)
    }

    /// Returns `true` if `format` is an allowed audio format
    /// (ASCII case-insensitive). An empty `audio_formats` list allows everything.
    pub fn validate_audio_format(&self, format: &str) -> bool {
        Self::format_allowed(&self.audio_formats, format)
    }

    /// Returns `true` if `format` is an allowed video format
    /// (ASCII case-insensitive). An empty `video_formats` list allows everything.
    pub fn validate_video_format(&self, format: &str) -> bool {
        Self::format_allowed(&self.video_formats, format)
    }

    /// Shared allow-list check: an empty list places no restriction,
    /// otherwise `format` must match one entry ignoring ASCII case.
    fn format_allowed(allowed: &[String], format: &str) -> bool {
        allowed.is_empty() || allowed.iter().any(|f| f.eq_ignore_ascii_case(format))
    }
}
/// Collects the modalities present in a list of JSON content blocks.
///
/// Each block's string `"type"` field is inspected:
/// `"text"` → text, `"image"` / `"image_url"` → image,
/// `"audio"` / `"input_audio"` → audio, `"video"` → video.
/// Blocks without a string `"type"`, or with an unrecognized one, are ignored.
/// If nothing is recognized the result is `{Text}`.
pub fn detect_modalities(content_blocks: &[serde_json::Value]) -> HashSet<Modality> {
    let mut found: HashSet<Modality> = content_blocks
        .iter()
        .filter_map(|block| block.get("type")?.as_str())
        .filter_map(|kind| match kind {
            "text" => Some(Modality::Text),
            "image" | "image_url" => Some(Modality::Image),
            "audio" | "input_audio" => Some(Modality::Audio),
            "video" => Some(Modality::Video),
            _ => None,
        })
        .collect();

    // Content with no recognizable blocks is treated as plain text.
    if found.is_empty() {
        found.insert(Modality::Text);
    }
    found
}
/// Checks that every modality detected in `blocks` is accepted as input by `caps`.
///
/// Modalities are detected with [`detect_modalities`], so content with no
/// recognizable blocks is validated as text.
///
/// # Errors
///
/// Returns the unsupported modalities, each at most once, ordered by the
/// declaration order of [`Modality`] so the result is stable across runs.
pub fn validate_content_modalities(
    blocks: &[serde_json::Value],
    caps: &MultimodalCapabilities,
) -> Result<(), Vec<Modality>> {
    let mut unsupported: Vec<Modality> = detect_modalities(blocks)
        .into_iter()
        .filter(|m| !caps.supports_input(*m))
        .collect();

    if unsupported.is_empty() {
        return Ok(());
    }

    // `HashSet` iteration order is unspecified; sort by discriminant so
    // callers (and error messages built from this list) see a stable order.
    // The cast is lossless: `Modality` is a small fieldless enum.
    unsupported.sort_unstable_by_key(|m| *m as u8);
    Err(unsupported)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v2::manifest::*;

    /// Fixture: image + audio input, audio + video output, no omni mode.
    fn sample_config() -> MultimodalConfig {
        let vision = VisionConfig {
            supported: true,
            formats: vec!["jpeg".into(), "png".into(), "webp".into()],
            encoding_methods: vec!["base64_inline".into(), "url".into()],
            document_understanding: false,
            max_file_size: Some("20MB".into()),
            max_resolution: None,
        };
        let audio_in = AudioInputConfig {
            supported: true,
            formats: vec!["mp3".into(), "wav".into()],
            real_time_streaming: false,
            speech_recognition: true,
        };
        let audio_out = AudioOutputConfig {
            supported: true,
            real_time_tts: false,
            natural_voice: true,
            voice_selection: true,
        };
        let video_out = VideoOutputConfig {
            supported: true,
            formats: vec!["mp4".into()],
            max_duration: None,
            max_resolution: None,
        };

        MultimodalConfig {
            input: Some(MultimodalInput {
                vision: Some(vision),
                audio: Some(audio_in),
                video: None,
            }),
            output: Some(MultimodalOutput {
                text: true,
                audio: Some(audio_out),
                image: None,
                video: Some(video_out),
            }),
            omni_mode: None,
        }
    }

    /// Capabilities derived from [`sample_config`].
    fn sample_caps() -> MultimodalCapabilities {
        MultimodalCapabilities::from_config(&sample_config())
    }

    #[test]
    fn test_from_config() {
        let caps = sample_caps();

        for accepted in [Modality::Text, Modality::Image, Modality::Audio] {
            assert!(caps.supports_input(accepted));
        }
        assert!(!caps.supports_input(Modality::Video));

        for produced in [Modality::Audio, Modality::Video] {
            assert!(caps.supports_output(produced));
        }
        assert!(!caps.supports_output(Modality::Image));
    }

    #[test]
    fn test_validate_image_format() {
        let caps = sample_caps();

        assert!(caps.validate_image_format("jpeg"));
        // Matching ignores ASCII case.
        assert!(caps.validate_image_format("PNG"));
        assert!(!caps.validate_image_format("bmp"));
    }

    #[test]
    fn test_validate_audio_format() {
        let caps = sample_caps();

        assert!(caps.validate_audio_format("mp3"));
        assert!(!caps.validate_audio_format("flac"));
    }

    #[test]
    fn test_detect_modalities() {
        let blocks = [
            serde_json::json!({"type": "text", "text": "Hello"}),
            serde_json::json!({"type": "image", "source": {}}),
        ];

        let found = detect_modalities(&blocks);

        assert!(found.contains(&Modality::Text));
        assert!(found.contains(&Modality::Image));
        assert!(!found.contains(&Modality::Audio));
    }

    #[test]
    fn test_validate_content_modalities_ok() {
        let caps = sample_caps();
        let blocks = [
            serde_json::json!({"type": "text", "text": "Describe this image"}),
            serde_json::json!({"type": "image", "source": {"type": "url", "data": "http://..."}}),
        ];

        assert!(validate_content_modalities(&blocks, &caps).is_ok());
    }

    #[test]
    fn test_validate_content_modalities_fail() {
        let caps = sample_caps();
        let blocks = [serde_json::json!({"type": "video", "source": {}})];

        let rejected = validate_content_modalities(&blocks, &caps).unwrap_err();

        assert!(rejected.contains(&Modality::Video));
    }
}