use super::types::Modality;
use std::collections::HashMap;
use tenflowers_core::Tensor;
/// A labelled example that may carry several input modalities.
///
/// Each built-in modality is optional. Extra named modalities live in
/// `custom`, and free-form string annotations live in `metadata`. The
/// `label` is always present.
#[derive(Debug, Clone)]
pub struct MultimodalSample<T> {
/// Text modality, if present.
pub text: Option<Tensor<T>>,
/// Image modality, if present.
pub image: Option<Tensor<T>>,
/// Audio modality, if present.
pub audio: Option<Tensor<T>>,
/// Video modality, if present.
pub video: Option<Tensor<T>>,
/// Embeddings modality, if present.
pub embeddings: Option<Tensor<T>>,
/// Additional named modalities. The impl for this type exposes each key as
/// `Modality::Custom(key)`.
pub custom: HashMap<String, Tensor<T>>,
/// Free-form string key/value annotations. These are not counted as
/// modalities.
pub metadata: HashMap<String, String>,
/// Label tensor. Unlike the modality fields it is required, not optional.
pub label: Tensor<T>,
}
impl<T> MultimodalSample<T>
where
    T: Clone + Default,
{
    /// Creates a sample holding only `label`: no modalities, no custom
    /// entries and no metadata.
    pub fn new(label: Tensor<T>) -> Self {
        Self {
            text: None,
            image: None,
            audio: None,
            video: None,
            embeddings: None,
            custom: HashMap::new(),
            metadata: HashMap::new(),
            label,
        }
    }

    /// Sets the text modality, replacing any previous value, and returns the sample.
    pub fn with_text(mut self, text: Tensor<T>) -> Self {
        self.text = Some(text);
        self
    }

    /// Sets the image modality, replacing any previous value, and returns the sample.
    pub fn with_image(mut self, image: Tensor<T>) -> Self {
        self.image = Some(image);
        self
    }

    /// Sets the audio modality, replacing any previous value, and returns the sample.
    pub fn with_audio(mut self, audio: Tensor<T>) -> Self {
        self.audio = Some(audio);
        self
    }

    /// Sets the video modality, replacing any previous value, and returns the sample.
    pub fn with_video(mut self, video: Tensor<T>) -> Self {
        self.video = Some(video);
        self
    }

    /// Sets the embeddings modality, replacing any previous value, and returns the sample.
    pub fn with_embeddings(mut self, embeddings: Tensor<T>) -> Self {
        self.embeddings = Some(embeddings);
        self
    }

    /// Stores `data` as a custom modality under `key` and returns the sample.
    ///
    /// Any previous tensor stored under the same key is replaced.
    pub fn with_custom(mut self, key: String, data: Tensor<T>) -> Self {
        self.custom.insert(key, data);
        self
    }

    /// Stores a metadata entry and returns the sample.
    ///
    /// Any previous value for `key` is replaced. Metadata is not treated
    /// as a modality.
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Lists the modalities present in this sample.
    ///
    /// Built-in modalities come first, in the fixed order text, image,
    /// audio, video, embeddings. They are followed by one `Modality::Custom`
    /// entry per custom key, sorted by key. Metadata entries and the label
    /// are not included.
    pub fn available_modalities(&self) -> Vec<Modality> {
        // At most five built-ins plus one entry per custom key.
        let mut modalities = Vec::with_capacity(5 + self.custom.len());
        if self.text.is_some() {
            modalities.push(Modality::Text);
        }
        if self.image.is_some() {
            modalities.push(Modality::Image);
        }
        if self.audio.is_some() {
            modalities.push(Modality::Audio);
        }
        if self.video.is_some() {
            modalities.push(Modality::Video);
        }
        if self.embeddings.is_some() {
            modalities.push(Modality::Embeddings);
        }
        // HashMap iteration order is unspecified and may differ between two
        // maps holding the same keys. Sort so that samples with the same
        // custom keys always report them in the same order.
        let mut custom_keys: Vec<&String> = self.custom.keys().collect();
        custom_keys.sort_unstable();
        modalities.extend(custom_keys.into_iter().map(|key| Modality::Custom(key.clone())));
        modalities
    }

    /// Returns `true` if `modality` is present in this sample.
    pub fn has_modality(&self, modality: &Modality) -> bool {
        match modality {
            Modality::Text => self.text.is_some(),
            Modality::Image => self.image.is_some(),
            Modality::Audio => self.audio.is_some(),
            Modality::Video => self.video.is_some(),
            Modality::Embeddings => self.embeddings.is_some(),
            Modality::Custom(key) => self.custom.contains_key(key),
        }
    }

    /// Borrows the tensor stored for `modality`, or returns `None` if it is absent.
    pub fn get_modality(&self, modality: &Modality) -> Option<&Tensor<T>> {
        match modality {
            Modality::Text => self.text.as_ref(),
            Modality::Image => self.image.as_ref(),
            Modality::Audio => self.audio.as_ref(),
            Modality::Video => self.video.as_ref(),
            Modality::Embeddings => self.embeddings.as_ref(),
            Modality::Custom(key) => self.custom.get(key),
        }
    }

    /// Mutably borrows the tensor stored for `modality`, or returns `None`
    /// if it is absent.
    pub fn get_modality_mut(&mut self, modality: &Modality) -> Option<&mut Tensor<T>> {
        match modality {
            Modality::Text => self.text.as_mut(),
            Modality::Image => self.image.as_mut(),
            Modality::Audio => self.audio.as_mut(),
            Modality::Video => self.video.as_mut(),
            Modality::Embeddings => self.embeddings.as_mut(),
            Modality::Custom(key) => self.custom.get_mut(key),
        }
    }

    /// Removes the tensor for `modality` and returns it.
    ///
    /// Returns `None` if the modality was not present. Afterwards the
    /// modality is absent either way.
    pub fn remove_modality(&mut self, modality: &Modality) -> Option<Tensor<T>> {
        match modality {
            Modality::Text => self.text.take(),
            Modality::Image => self.image.take(),
            Modality::Audio => self.audio.take(),
            Modality::Video => self.video.take(),
            Modality::Embeddings => self.embeddings.take(),
            Modality::Custom(key) => self.custom.remove(key),
        }
    }

    /// Number of modalities present, counting built-in and custom ones.
    ///
    /// Equals `available_modalities().len()`, but does not allocate a
    /// vector or clone custom keys.
    pub fn modality_count(&self) -> usize {
        let builtin_present = [
            self.text.is_some(),
            self.image.is_some(),
            self.audio.is_some(),
            self.video.is_some(),
            self.embeddings.is_some(),
        ];
        builtin_present.iter().filter(|&&present| present).count() + self.custom.len()
    }

    /// Returns `true` when no modality, built-in or custom, is present.
    ///
    /// The label and metadata entries are not considered.
    pub fn is_empty(&self) -> bool {
        self.modality_count() == 0
    }
}