use super::types::{FusionStrategy, Modality};
#[cfg(feature = "serialize")]
use serde::{Deserialize, Serialize};
/// Configuration describing which modalities a multimodal sample is expected
/// to carry, together with a set of processing options.
///
/// Within this file only the two modality lists are read (by
/// `validate_modalities`, `all_modalities` and `is_expected_modality`); the
/// remaining fields are stored here and are presumably consumed by code
/// outside this file.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct MultimodalConfig {
/// Modalities that `validate_modalities` requires to be present.
pub required_modalities: Vec<Modality>,
/// Modalities that count as expected (`is_expected_modality`) but are not
/// checked by `validate_modalities`.
pub optional_modalities: Vec<Modality>,
/// Selected fusion strategy; its semantics are defined by `FusionStrategy`
/// outside this file.
pub fusion_strategy: FusionStrategy,
/// Optional text length limit (`None` means no limit is configured).
/// NOTE(review): unit (tokens vs. characters) is not visible here — confirm.
pub max_text_length: Option<usize>,
/// Optional target image size as a pair of `u32` values.
/// NOTE(review): presumably (width, height) — confirm the order with the consumer.
pub target_image_size: Option<(u32, u32)>,
/// Optional target audio sample rate.
/// NOTE(review): presumably in Hz — confirm.
pub target_audio_sample_rate: Option<u32>,
/// Flag that presumably requests padding for modalities absent from a
/// sample; the padding itself is not implemented in this file.
pub pad_missing_modalities: bool,
/// Flag that presumably toggles stricter validation. Note that
/// `validate_modalities` in this file does not read it.
pub strict_validation: bool,
/// Flag that presumably enables caching of fused results — confirm with the
/// code that performs fusion.
pub cache_fused_results: bool,
/// Cache size limit; the `minimal` and `text_only` presets set it to 0.
/// NOTE(review): unit (entries vs. bytes) is not visible here — confirm.
pub max_cache_size: usize,
}
impl Default for MultimodalConfig {
    /// Returns the baseline configuration: text is required, image is
    /// optional, fusion uses `FusionStrategy::Concatenation`, every
    /// preprocessing target is set (512 / 224x224 / 16000), padding and strict
    /// validation are on, and result caching is off with a cache size of 1000.
    fn default() -> Self {
        // Modality expectations.
        let required_modalities = vec![Modality::Text];
        let optional_modalities = vec![Modality::Image];

        // Preprocessing targets.
        let max_text_length = Some(512);
        let target_image_size = Some((224, 224));
        let target_audio_sample_rate = Some(16000);

        Self {
            required_modalities,
            optional_modalities,
            fusion_strategy: FusionStrategy::Concatenation,
            max_text_length,
            target_image_size,
            target_audio_sample_rate,
            pad_missing_modalities: true,
            strict_validation: true,
            cache_fused_results: false,
            max_cache_size: 1000,
        }
    }
}
impl MultimodalConfig {
    /// Builds the most permissive preset: no modality is required, text and
    /// image are optional, no preprocessing targets are set, and padding,
    /// strict validation and caching are all disabled.
    pub fn minimal() -> Self {
        Self {
            required_modalities: vec![],
            optional_modalities: vec![Modality::Text, Modality::Image],
            fusion_strategy: FusionStrategy::Concatenation,
            max_text_length: None,
            target_image_size: None,
            target_audio_sample_rate: None,
            pad_missing_modalities: false,
            strict_validation: false,
            cache_fused_results: false,
            max_cache_size: 0,
        }
    }

    /// Builds a preset that requires text only, with a text length limit of
    /// 512, strict validation enabled, and no image/audio targets, padding or
    /// caching.
    pub fn text_only() -> Self {
        Self {
            required_modalities: vec![Modality::Text],
            optional_modalities: vec![],
            fusion_strategy: FusionStrategy::Concatenation,
            max_text_length: Some(512),
            target_image_size: None,
            target_audio_sample_rate: None,
            pad_missing_modalities: false,
            strict_validation: true,
            cache_fused_results: false,
            max_cache_size: 0,
        }
    }

    /// Builds a preset that requires both text and image, with a text length
    /// limit of 256, a 224x224 image target, padding and strict validation
    /// enabled, and result caching enabled with a cache size of 5000.
    pub fn vision_language() -> Self {
        Self {
            required_modalities: vec![Modality::Text, Modality::Image],
            optional_modalities: vec![],
            fusion_strategy: FusionStrategy::Concatenation,
            max_text_length: Some(256),
            target_image_size: Some((224, 224)),
            target_audio_sample_rate: None,
            pad_missing_modalities: true,
            strict_validation: true,
            cache_fused_results: true,
            max_cache_size: 5000,
        }
    }

    /// Checks that every required modality appears in `available`.
    ///
    /// Optional modalities are not checked, and `strict_validation` is not
    /// consulted by this method.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first required modality (in
    /// `required_modalities` order) that is missing from `available`.
    pub fn validate_modalities(&self, available: &[Modality]) -> Result<(), String> {
        let first_missing = self
            .required_modalities
            .iter()
            .find(|&required| !available.contains(required));

        match first_missing {
            Some(missing) => Err(format!(
                "Required modality {:?} not found in sample",
                missing
            )),
            None => Ok(()),
        }
    }

    /// Returns every required and optional modality exactly once.
    ///
    /// The result is ordered by each modality's `Debug` rendering, so the
    /// required-before-optional order of the underlying lists is not kept.
    pub fn all_modalities(&self) -> Vec<Modality> {
        let mut all =
            Vec::with_capacity(self.required_modalities.len() + self.optional_modalities.len());
        all.extend_from_slice(&self.required_modalities);
        all.extend_from_slice(&self.optional_modalities);

        // The sort key is an allocated `String`; the cached variant builds it
        // once per element instead of once per comparison. Sorting places equal
        // modalities next to each other so that `dedup` can remove repeats.
        all.sort_by_cached_key(|modality| format!("{:?}", modality));
        all.dedup();
        all
    }

    /// Reports whether `modality` is listed as either required or optional.
    pub fn is_expected_modality(&self, modality: &Modality) -> bool {
        self.required_modalities.contains(modality) || self.optional_modalities.contains(modality)
    }
}