use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Image encodings that can be attached to an image payload.
///
/// Each variant maps to a fixed file extension and MIME type through
/// `ImageFormat::extension` and `ImageFormat::mime_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ImageFormat {
/// JPEG: extension `jpg`, MIME type `image/jpeg`.
Jpeg,
/// PNG: extension `png`, MIME type `image/png`.
Png,
/// WebP: extension `webp`, MIME type `image/webp`.
Webp,
/// BMP: extension `bmp`, MIME type `image/bmp`.
Bmp,
/// TIFF: extension `tiff`, MIME type `image/tiff`.
Tiff,
}
impl ImageFormat {
pub fn extension(&self) -> &'static str {
match self {
ImageFormat::Jpeg => "jpg",
ImageFormat::Png => "png",
ImageFormat::Webp => "webp",
ImageFormat::Bmp => "bmp",
ImageFormat::Tiff => "tiff",
}
}
pub fn mime_type(&self) -> &'static str {
match self {
ImageFormat::Jpeg => "image/jpeg",
ImageFormat::Png => "image/png",
ImageFormat::Webp => "image/webp",
ImageFormat::Bmp => "image/bmp",
ImageFormat::Tiff => "image/tiff",
}
}
}
/// Document encodings that can be attached to a document payload.
///
/// Each variant maps to a fixed file extension and MIME type through
/// `DocumentFormat::extension` and `DocumentFormat::mime_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DocumentFormat {
/// PDF: extension `pdf`, MIME type `application/pdf`.
Pdf,
/// HTML: extension `html`, MIME type `text/html`.
Html,
/// Plain text: extension `txt`, MIME type `text/plain`.
Text,
/// Markdown: extension `md`, MIME type `text/markdown`.
Markdown,
/// Word document: extension `docx`, with the OpenXML wordprocessing MIME type.
Docx,
/// Document delivered as an image; reported as `png` / `image/png`.
Image,
}
impl DocumentFormat {
pub fn extension(&self) -> &'static str {
match self {
DocumentFormat::Pdf => "pdf",
DocumentFormat::Html => "html",
DocumentFormat::Text => "txt",
DocumentFormat::Markdown => "md",
DocumentFormat::Docx => "docx",
DocumentFormat::Image => "png",
}
}
pub fn mime_type(&self) -> &'static str {
match self {
DocumentFormat::Pdf => "application/pdf",
DocumentFormat::Html => "text/html",
DocumentFormat::Text => "text/plain",
DocumentFormat::Markdown => "text/markdown",
DocumentFormat::Docx => {
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
},
DocumentFormat::Image => "image/png",
}
}
}
/// A single input handed to a feature extractor, tagged by modality.
///
/// Every variant carries its payload plus an optional, modality-specific
/// metadata record.
#[derive(Debug, Clone)]
pub enum FeatureInput {
    /// Encoded image bytes together with their declared format.
    Image {
        data: Vec<u8>,
        format: ImageFormat,
        metadata: Option<ImageMetadata>,
    },
    /// Audio samples and the rate they were sampled at.
    Audio {
        samples: Vec<f32>,
        sample_rate: u32,
        metadata: Option<AudioMetadata>,
    },
    /// Text content.
    Text {
        content: String,
        metadata: Option<TextMetadata>,
    },
    /// Raw document bytes together with their declared format.
    Document {
        content: Vec<u8>,
        format: DocumentFormat,
        metadata: Option<DocumentMetadata>,
    },
    /// A group of nested inputs (which may themselves be multimodal).
    Multimodal {
        inputs: Vec<FeatureInput>,
        metadata: Option<MultimodalMetadata>,
    },
}

impl FeatureInput {
    /// Returns the lowercase modality name of this input:
    /// `"image"`, `"audio"`, `"text"`, `"document"` or `"multimodal"`.
    pub fn modality(&self) -> &'static str {
        match *self {
            Self::Image { .. } => "image",
            Self::Audio { .. } => "audio",
            Self::Text { .. } => "text",
            Self::Document { .. } => "document",
            Self::Multimodal { .. } => "multimodal",
        }
    }

    /// Returns `true` when this input's own metadata slot is populated.
    ///
    /// For `Multimodal` only the top-level metadata is inspected; nested
    /// inputs are not consulted.
    pub fn has_metadata(&self) -> bool {
        // Each variant stores a different metadata type, so every arm has to
        // query its own `Option`.
        match self {
            Self::Image { metadata, .. } => Option::is_some(metadata),
            Self::Audio { metadata, .. } => Option::is_some(metadata),
            Self::Text { metadata, .. } => Option::is_some(metadata),
            Self::Document { metadata, .. } => Option::is_some(metadata),
            Self::Multimodal { metadata, .. } => Option::is_some(metadata),
        }
    }
}
/// Result of feature extraction: a flat feature buffer plus its logical shape
/// and optional auxiliary data.
#[derive(Debug, Clone)]
pub struct FeatureOutput {
    /// Flat feature values.
    pub features: Vec<f32>,
    /// Logical dimensions of `features`.
    pub shape: Vec<usize>,
    /// Arbitrary JSON metadata keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Optional attention mask.
    pub attention_mask: Option<Vec<u32>>,
    /// Special tokens associated with the output.
    pub special_tokens: Vec<SpecialToken>,
}

impl FeatureOutput {
    /// Creates an output with the given features and shape, no metadata,
    /// no attention mask and no special tokens.
    ///
    /// The shape is stored as given; it is not checked against `features` here.
    pub fn new(features: Vec<f32>, shape: Vec<usize>) -> Self {
        FeatureOutput {
            features,
            shape,
            metadata: HashMap::default(),
            attention_mask: None,
            special_tokens: Vec::default(),
        }
    }

    /// Number of values stored in the flat feature buffer.
    pub fn feature_count(&self) -> usize {
        let Self { features, .. } = self;
        features.len()
    }

    /// Product of all entries of `shape` (an empty shape yields `1`).
    pub fn feature_dimension(&self) -> usize {
        self.shape.iter().copied().product()
    }

    /// Inserts (or overwrites) one metadata entry and returns the updated value.
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.extend(std::iter::once((key, value)));
        self
    }

    /// Sets the attention mask, replacing any previous one.
    pub fn with_attention_mask(self, mask: Vec<u32>) -> Self {
        Self {
            attention_mask: Some(mask),
            ..self
        }
    }

    /// Replaces the list of special tokens.
    pub fn with_special_tokens(self, tokens: Vec<SpecialToken>) -> Self {
        Self {
            special_tokens: tokens,
            ..self
        }
    }
}
/// A special token recorded alongside extracted features.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpecialToken {
/// Kind of token; the `is_*_token` helpers compare it case-insensitively
/// against `cls`, `sep`, `pad` and `mask`.
pub token_type: String,
/// Position of the token (presumably an index into the token sequence; confirm with callers).
pub position: usize,
/// Textual value of the token, e.g. `[CLS]`.
pub value: String,
}
impl SpecialToken {
pub fn new(token_type: impl Into<String>, position: usize, value: impl Into<String>) -> Self {
Self {
token_type: token_type.into(),
position,
value: value.into(),
}
}
pub fn is_cls_token(&self) -> bool {
self.token_type.eq_ignore_ascii_case("cls")
}
pub fn is_sep_token(&self) -> bool {
self.token_type.eq_ignore_ascii_case("sep")
}
pub fn is_pad_token(&self) -> bool {
self.token_type.eq_ignore_ascii_case("pad")
}
pub fn is_mask_token(&self) -> bool {
self.token_type.eq_ignore_ascii_case("mask")
}
}
/// Descriptive properties of an image.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ImageMetadata {
/// Image width; multiplied by `height` to obtain the pixel count.
pub width: u32,
/// Image height; used as the divisor of the aspect ratio.
pub height: u32,
/// Channel count: `1` is treated as grayscale, `2` or `4` as carrying alpha.
pub channels: u32,
/// Optional resolution (presumably dots per inch; confirm with producers).
pub dpi: Option<u32>,
}
impl ImageMetadata {
pub fn new(width: u32, height: u32, channels: u32) -> Self {
Self {
width,
height,
channels,
dpi: None,
}
}
pub fn pixel_count(&self) -> u64 {
self.width as u64 * self.height as u64
}
pub fn aspect_ratio(&self) -> f64 {
self.width as f64 / self.height as f64
}
pub fn is_grayscale(&self) -> bool {
self.channels == 1
}
pub fn has_alpha(&self) -> bool {
self.channels == 2 || self.channels == 4
}
}
/// Descriptive properties of an audio clip.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AudioMetadata {
/// Clip duration; `duration_ms` multiplies it by 1000, which implies the unit is seconds.
pub duration: f64,
/// Channel count: `1` is mono, `2` is stereo.
pub channels: u32,
/// Optional sample bit depth (presumably bits per sample; confirm with producers).
pub bit_depth: Option<u32>,
}
impl AudioMetadata {
pub fn new(duration: f64, channels: u32) -> Self {
Self {
duration,
channels,
bit_depth: None,
}
}
pub fn is_mono(&self) -> bool {
self.channels == 1
}
pub fn is_stereo(&self) -> bool {
self.channels == 2
}
pub fn duration_ms(&self) -> f64 {
self.duration * 1000.0
}
}
/// Optional descriptive properties of a text input.
///
/// All fields start out as `None` and are filled in through the `with_*` builders.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TextMetadata {
/// Language of the text as a free-form string; no particular format is enforced here.
pub language: Option<String>,
/// Character encoding as a free-form string; no particular format is enforced here.
pub encoding: Option<String>,
/// Word count supplied by the caller; it is not computed from any text here.
pub word_count: Option<usize>,
}
impl TextMetadata {
pub fn new() -> Self {
Self {
language: None,
encoding: None,
word_count: None,
}
}
pub fn with_language(mut self, language: impl Into<String>) -> Self {
self.language = Some(language.into());
self
}
pub fn with_encoding(mut self, encoding: impl Into<String>) -> Self {
self.encoding = Some(encoding.into());
self
}
pub fn with_word_count(mut self, count: usize) -> Self {
self.word_count = Some(count);
self
}
}
impl Default for TextMetadata {
fn default() -> Self {
Self::new()
}
}
/// Optional descriptive properties of a document input.
///
/// All fields start out as `None` and are filled in through the `with_*` builders.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DocumentMetadata {
/// Number of pages, as supplied by the caller.
pub page_count: Option<usize>,
/// Author name.
pub author: Option<String>,
/// Document title.
pub title: Option<String>,
/// Creation date kept as a plain string; no date format is enforced here.
pub creation_date: Option<String>,
}
impl DocumentMetadata {
pub fn new() -> Self {
Self {
page_count: None,
author: None,
title: None,
creation_date: None,
}
}
pub fn with_page_count(mut self, count: usize) -> Self {
self.page_count = Some(count);
self
}
pub fn with_author(mut self, author: impl Into<String>) -> Self {
self.author = Some(author.into());
self
}
pub fn with_title(mut self, title: impl Into<String>) -> Self {
self.title = Some(title.into());
self
}
pub fn with_creation_date(mut self, date: impl Into<String>) -> Self {
self.creation_date = Some(date.into());
self
}
}
impl Default for DocumentMetadata {
fn default() -> Self {
Self::new()
}
}
/// Descriptive properties of a multimodal input group.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MultimodalMetadata {
/// Names of the modalities present; `has_modality` matches them exactly (case-sensitively).
pub modalities: Vec<String>,
/// Optional fusion strategy name as a free-form string, e.g. `late_fusion`.
pub fusion_strategy: Option<String>,
}
impl MultimodalMetadata {
    /// Creates metadata listing the given modality names, with no fusion strategy set.
    pub fn new(modalities: Vec<String>) -> Self {
        Self {
            modalities,
            fusion_strategy: None,
        }
    }

    /// Sets the fusion strategy name, replacing any previous value.
    pub fn with_fusion_strategy(mut self, strategy: impl Into<String>) -> Self {
        self.fusion_strategy = Some(strategy.into());
        self
    }

    /// Returns `true` when `modality` is listed, using an exact, case-sensitive match.
    pub fn has_modality(&self, modality: &str) -> bool {
        // Compare as `&str` so the lookup does not allocate a temporary `String`.
        self.modalities.iter().any(|known| known.as_str() == modality)
    }

    /// Number of listed modalities (duplicates are counted individually).
    pub fn modality_count(&self) -> usize {
        self.modalities.len()
    }
}
/// How sequences in a batch should be padded.
///
/// The derived `Default` is [`PaddingStrategy::Longest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum PaddingStrategy {
/// No padding; `should_pad` returns `false`.
None,
/// Padding is applied and considered dynamic (presumably to the longest sequence in the batch; confirm with the collator).
#[default]
Longest,
/// Padding is applied but is not dynamic (presumably to a configured maximum length; confirm with the collator).
MaxLength,
/// No padding; treated the same as `None` by the helper methods in this file.
DoNotPad,
}
impl PaddingStrategy {
pub fn should_pad(&self) -> bool {
matches!(self, PaddingStrategy::Longest | PaddingStrategy::MaxLength)
}
pub fn is_dynamic(&self) -> bool {
matches!(self, PaddingStrategy::Longest)
}
}
/// A single (un-batched) example made of token ids plus optional companion sequences.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataExample {
/// Token ids; their count is reported by `sequence_length`.
pub input_ids: Vec<u32>,
/// Optional attention mask; its length is not checked against `input_ids` here.
pub attention_mask: Option<Vec<u32>>,
/// Optional token type ids; their length is not checked against `input_ids` here.
pub token_type_ids: Option<Vec<u32>>,
/// Optional labels; the length is independent of `input_ids`.
pub labels: Option<Vec<i64>>,
/// Arbitrary JSON metadata keyed by name.
pub metadata: HashMap<String, serde_json::Value>,
}
impl DataExample {
pub fn new(input_ids: Vec<u32>) -> Self {
Self {
input_ids,
attention_mask: None,
token_type_ids: None,
labels: None,
metadata: HashMap::new(),
}
}
pub fn sequence_length(&self) -> usize {
self.input_ids.len()
}
pub fn with_attention_mask(mut self, mask: Vec<u32>) -> Self {
self.attention_mask = Some(mask);
self
}
pub fn with_token_type_ids(mut self, type_ids: Vec<u32>) -> Self {
self.token_type_ids = Some(type_ids);
self
}
pub fn with_labels(mut self, labels: Vec<i64>) -> Self {
self.labels = Some(labels);
self
}
pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
self.metadata.insert(key, value);
self
}
pub fn has_labels(&self) -> bool {
self.labels.is_some()
}
}
/// A batch of examples stored as one row per example.
///
/// `batch_size` and `sequence_length` are supplied by the caller and are not
/// derived from the row data; `utils::validate_collated_batch` checks that
/// they agree with the stored rows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollatedBatch {
/// Token ids, one inner vector per example.
pub input_ids: Vec<Vec<u32>>,
/// Attention masks, one inner vector per example.
pub attention_mask: Vec<Vec<u32>>,
/// Optional token type ids, one inner vector per example.
pub token_type_ids: Option<Vec<Vec<u32>>>,
/// Optional labels, one inner vector per example.
pub labels: Option<Vec<Vec<i64>>>,
/// Declared number of examples in the batch.
pub batch_size: usize,
/// Declared length of every row.
pub sequence_length: usize,
/// Arbitrary JSON metadata keyed by name.
pub metadata: HashMap<String, serde_json::Value>,
}
impl CollatedBatch {
pub fn new(
input_ids: Vec<Vec<u32>>,
attention_mask: Vec<Vec<u32>>,
batch_size: usize,
sequence_length: usize,
) -> Self {
Self {
input_ids,
attention_mask,
token_type_ids: None,
labels: None,
batch_size,
sequence_length,
metadata: HashMap::new(),
}
}
pub fn total_tokens(&self) -> usize {
self.batch_size * self.sequence_length
}
pub fn input_shape(&self) -> (usize, usize) {
(self.batch_size, self.sequence_length)
}
pub fn has_token_type_ids(&self) -> bool {
self.token_type_ids.is_some()
}
pub fn has_labels(&self) -> bool {
self.labels.is_some()
}
pub fn with_token_type_ids(mut self, token_type_ids: Vec<Vec<u32>>) -> Self {
self.token_type_ids = Some(token_type_ids);
self
}
pub fn with_labels(mut self, labels: Vec<Vec<i64>>) -> Self {
self.labels = Some(labels);
self
}
pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
self.metadata.insert(key, value);
self
}
}
pub mod utils {
    //! Memory-size estimates and consistency checks for [`FeatureOutput`] and
    //! [`CollatedBatch`] values.
    use super::*;

    /// Estimates the number of bytes held by the element buffers of `output`.
    ///
    /// Counts the feature values, the shape entries, the attention mask (when
    /// present) and the inline size of each special token. The metadata map
    /// and the heap storage behind each token's strings are not included.
    pub fn feature_output_memory_size(output: &FeatureOutput) -> usize {
        let features_size = output.features.len() * std::mem::size_of::<f32>();
        let shape_size = output.shape.len() * std::mem::size_of::<usize>();
        let attention_mask_size = output
            .attention_mask
            .as_ref()
            .map_or(0, |mask| mask.len() * std::mem::size_of::<u32>());
        let special_tokens_size =
            output.special_tokens.len() * std::mem::size_of::<SpecialToken>();
        features_size + shape_size + attention_mask_size + special_tokens_size
    }

    /// Estimates the number of bytes held by the element buffers of `batch`.
    ///
    /// Input ids, attention masks and token type ids (when present) are sized
    /// from the declared `batch_size * sequence_length`; labels are sized from
    /// the rows actually stored. The metadata map is not included.
    pub fn collated_batch_memory_size(batch: &CollatedBatch) -> usize {
        // Every `u32` plane is assumed to have the declared dimensions.
        let u32_plane_size =
            batch.batch_size * batch.sequence_length * std::mem::size_of::<u32>();
        let input_ids_size = u32_plane_size;
        let attention_mask_size = u32_plane_size;
        let token_type_ids_size = if batch.token_type_ids.is_some() {
            u32_plane_size
        } else {
            0
        };
        let labels_size = batch.labels.as_ref().map_or(0, |labels| {
            labels.iter().map(Vec::len).sum::<usize>() * std::mem::size_of::<i64>()
        });
        input_ids_size + attention_mask_size + token_type_ids_size + labels_size
    }

    /// Multiplies all dimensions of `shape`, returning `None` on `usize` overflow.
    ///
    /// A shape containing a zero dimension has zero elements regardless of the
    /// other dimensions, so that case is answered before any multiplication.
    fn checked_element_count(shape: &[usize]) -> Option<usize> {
        if shape.contains(&0) {
            return Some(0);
        }
        shape
            .iter()
            .try_fold(1usize, |count, &dim| count.checked_mul(dim))
    }

    /// Checks that `output` is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns a description of the first problem found:
    /// * the number of feature values differs from the product of `shape`
    ///   (an empty shape has a product of `1`);
    /// * the product of `shape` does not fit in `usize`;
    /// * an attention mask is present, `shape` is non-empty, and the mask
    ///   length differs from the first dimension.
    pub fn validate_feature_output(output: &FeatureOutput) -> Result<(), String> {
        // Use checked arithmetic: a wrapped product could accidentally equal
        // the buffer length and let a nonsensical shape pass validation.
        let expected_size = checked_element_count(&output.shape).ok_or_else(|| {
            format!(
                "Shape {:?} is too large: its element count overflows usize",
                output.shape
            )
        })?;
        if output.features.len() != expected_size {
            return Err(format!(
                "Feature vector size {} does not match shape {:?} (expected {})",
                output.features.len(),
                output.shape,
                expected_size
            ));
        }
        if let (Some(mask), Some(&first_dim)) = (&output.attention_mask, output.shape.first()) {
            if mask.len() != first_dim {
                return Err(format!(
                    "Attention mask length {} does not match first dimension of shape {:?}",
                    mask.len(),
                    output.shape
                ));
            }
        }
        Ok(())
    }

    /// Fails when a field's outer length differs from the declared batch size.
    fn check_batch_len(label: &str, actual: usize, expected: usize) -> Result<(), String> {
        if actual == expected {
            Ok(())
        } else {
            Err(format!(
                "{} batch size {} does not match expected batch size {}",
                label, actual, expected
            ))
        }
    }

    /// Fails on the first row whose length differs from the declared sequence length.
    fn check_row_lens<T>(label: &str, rows: &[Vec<T>], expected: usize) -> Result<(), String> {
        match rows
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != expected)
        {
            Some((index, row)) => Err(format!(
                "{} sequence {} has length {} but expected {}",
                label,
                index,
                row.len(),
                expected
            )),
            None => Ok(()),
        }
    }

    /// Checks that the rows stored in `batch` agree with its declared dimensions.
    ///
    /// # Errors
    ///
    /// Returns a description of the first problem found, checking in order:
    /// the row counts of input ids and attention masks, the row lengths of
    /// input ids and attention masks, the row count and row lengths of token
    /// type ids (when present), and the row count of labels (when present).
    /// Label row lengths are not constrained, so both one label per example
    /// and one label per token are accepted.
    pub fn validate_collated_batch(batch: &CollatedBatch) -> Result<(), String> {
        check_batch_len("Input IDs", batch.input_ids.len(), batch.batch_size)?;
        check_batch_len("Attention mask", batch.attention_mask.len(), batch.batch_size)?;
        check_row_lens("Input IDs", &batch.input_ids, batch.sequence_length)?;
        check_row_lens("Attention mask", &batch.attention_mask, batch.sequence_length)?;
        if let Some(token_type_ids) = &batch.token_type_ids {
            check_batch_len("Token type IDs", token_type_ids.len(), batch.batch_size)?;
            check_row_lens("Token type IDs", token_type_ids, batch.sequence_length)?;
        }
        if let Some(labels) = &batch.labels {
            check_batch_len("Labels", labels.len(), batch.batch_size)?;
        }
        Ok(())
    }
}
// Unit tests for the data types and the `utils` helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Extension and MIME lookups for two image formats.
#[test]
fn test_image_format_properties() {
assert_eq!(ImageFormat::Jpeg.extension(), "jpg");
assert_eq!(ImageFormat::Jpeg.mime_type(), "image/jpeg");
assert_eq!(ImageFormat::Png.extension(), "png");
assert_eq!(ImageFormat::Png.mime_type(), "image/png");
}
// Extension and MIME lookups for two document formats.
#[test]
fn test_document_format_properties() {
assert_eq!(DocumentFormat::Pdf.extension(), "pdf");
assert_eq!(DocumentFormat::Pdf.mime_type(), "application/pdf");
assert_eq!(DocumentFormat::Html.extension(), "html");
assert_eq!(DocumentFormat::Html.mime_type(), "text/html");
}
// `modality` reports the variant name in lowercase.
#[test]
fn test_feature_input_modality() {
let image_input = FeatureInput::Image {
data: vec![1, 2, 3],
format: ImageFormat::Jpeg,
metadata: None,
};
assert_eq!(image_input.modality(), "image");
let text_input = FeatureInput::Text {
content: "Hello world".to_string(),
metadata: None,
};
assert_eq!(text_input.modality(), "text");
}
// The constructor stores its arguments; type checks ignore ASCII case ("CLS" matches cls).
#[test]
fn test_special_token_creation() {
let token = SpecialToken::new("CLS", 0, "[CLS]");
assert_eq!(token.token_type, "CLS");
assert_eq!(token.position, 0);
assert_eq!(token.value, "[CLS]");
assert!(token.is_cls_token());
assert!(!token.is_sep_token());
}
// Pixel count, aspect ratio (compared with a tolerance) and channel-based predicates.
#[test]
fn test_image_metadata_properties() {
let metadata = ImageMetadata::new(640, 480, 3);
assert_eq!(metadata.pixel_count(), 307200);
assert!((metadata.aspect_ratio() - 1.333333).abs() < 0.000001);
assert!(!metadata.is_grayscale());
assert!(!metadata.has_alpha());
let grayscale = ImageMetadata::new(100, 100, 1);
assert!(grayscale.is_grayscale());
let rgba = ImageMetadata::new(100, 100, 4);
assert!(rgba.has_alpha());
}
// Mono/stereo predicates and the duration conversion to milliseconds.
#[test]
fn test_audio_metadata_properties() {
let metadata = AudioMetadata::new(5.5, 2);
assert!(metadata.is_stereo());
assert!(!metadata.is_mono());
assert_eq!(metadata.duration_ms(), 5500.0);
let mono = AudioMetadata::new(3.0, 1);
assert!(mono.is_mono());
assert!(!mono.is_stereo());
}
// Only `Longest` and `MaxLength` pad; only `Longest` is dynamic.
#[test]
fn test_padding_strategy_properties() {
assert!(PaddingStrategy::Longest.should_pad());
assert!(PaddingStrategy::MaxLength.should_pad());
assert!(!PaddingStrategy::None.should_pad());
assert!(!PaddingStrategy::DoNotPad.should_pad());
assert!(PaddingStrategy::Longest.is_dynamic());
assert!(!PaddingStrategy::MaxLength.is_dynamic());
}
// Builder methods attach the optional fields; labels may be shorter than the sequence.
#[test]
fn test_data_example_creation() {
let example = DataExample::new(vec![1, 2, 3, 4])
.with_attention_mask(vec![1, 1, 1, 1])
.with_labels(vec![0]);
assert_eq!(example.sequence_length(), 4);
assert!(example.has_labels());
assert_eq!(example.attention_mask, Some(vec![1, 1, 1, 1]));
}
// Declared dimensions drive `total_tokens` and `input_shape`; optional fields start unset.
#[test]
fn test_collated_batch_properties() {
let batch = CollatedBatch::new(
vec![vec![1, 2, 3], vec![4, 5, 6]],
vec![vec![1, 1, 1], vec![1, 1, 1]],
2,
3,
);
assert_eq!(batch.batch_size, 2);
assert_eq!(batch.sequence_length, 3);
assert_eq!(batch.total_tokens(), 6);
assert_eq!(batch.input_shape(), (2, 3));
assert!(!batch.has_labels());
assert!(!batch.has_token_type_ids());
}
// A 2x2 shape needs exactly four feature values.
#[test]
fn test_feature_output_validation() {
let output = FeatureOutput::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
assert!(utils::validate_feature_output(&output).is_ok());
let invalid_output = FeatureOutput::new(vec![1.0, 2.0, 3.0], vec![2, 2]);
assert!(utils::validate_feature_output(&invalid_output).is_err());
}
// Rows must match the declared sequence length; the second batch has a row of length 3.
#[test]
fn test_collated_batch_validation() {
let valid_batch = CollatedBatch::new(
vec![vec![1, 2], vec![3, 4]],
vec![vec![1, 1], vec![1, 1]],
2,
2,
);
assert!(utils::validate_collated_batch(&valid_batch).is_ok());
let invalid_batch = CollatedBatch::new(
vec![vec![1, 2, 3], vec![4, 5]],
vec![vec![1, 1], vec![1, 1]],
2,
2,
);
assert!(utils::validate_collated_batch(&invalid_batch).is_err());
}
// Modality lookup is by exact name; the fusion strategy is stored as given.
#[test]
fn test_multimodal_metadata() {
let metadata = MultimodalMetadata::new(vec!["text".to_string(), "image".to_string()])
.with_fusion_strategy("late_fusion");
assert_eq!(metadata.modality_count(), 2);
assert!(metadata.has_modality("text"));
assert!(metadata.has_modality("image"));
assert!(!metadata.has_modality("audio"));
assert_eq!(metadata.fusion_strategy, Some("late_fusion".to_string()));
}
}