use crate::auto::types::{FeatureInput, FeatureOutput};
use crate::error::Result;
use std::collections::HashMap;
mod audio;
mod document;
mod generic;
mod vision;
pub use audio::{AudioFeatureConfig, AudioFeatureExtractor};
pub use document::{DocumentFeatureConfig, DocumentFeatureExtractor};
pub use generic::{GenericFeatureConfig, GenericFeatureExtractor};
pub use vision::{VisionFeatureConfig, VisionFeatureExtractor};
/// Zero-sized entry point that selects and constructs a concrete
/// [`FeatureExtractor`] implementation, either from a model's hub
/// configuration (`from_pretrained`) or from a task name (`for_task`).
#[derive(Debug, Clone)]
pub struct AutoFeatureExtractor;
impl AutoFeatureExtractor {
    /// Builds a feature extractor by inspecting the `model_type` field of the
    /// configuration loaded from the hub for `model_name_or_path`.
    ///
    /// Unrecognized (or missing) model types fall back to the generic extractor.
    ///
    /// # Errors
    /// Returns an error if the hub configuration cannot be loaded, or if the
    /// modality-specific config cannot be parsed from it.
    pub fn from_pretrained(model_name_or_path: &str) -> Result<Box<dyn FeatureExtractor>> {
        let hub_config = crate::hub::load_config_from_hub(model_name_or_path, None)?;
        // A missing or non-string `model_type` routes to the generic fallback arm.
        let model_type = hub_config
            .get("model_type")
            .and_then(|value| value.as_str())
            .unwrap_or("unknown");
        let extractor: Box<dyn FeatureExtractor> = match model_type {
            "clip" | "blip" | "vit" => Box::new(VisionFeatureExtractor::new(
                VisionFeatureConfig::from_config(&hub_config)?,
            )),
            "wav2vec2" | "whisper" | "hubert" => Box::new(AudioFeatureExtractor::new(
                AudioFeatureConfig::from_config(&hub_config)?,
            )),
            "layoutlm" | "donut" => Box::new(DocumentFeatureExtractor::new(
                DocumentFeatureConfig::from_config(&hub_config)?,
            )),
            _ => Box::new(GenericFeatureExtractor::new(
                GenericFeatureConfig::from_config(&hub_config)?,
            )),
        };
        Ok(extractor)
    }

    /// Builds a feature extractor appropriate for a named task, using
    /// `model_config` to parameterize the chosen extractor.
    ///
    /// Unknown tasks fall back to the generic extractor.
    ///
    /// # Errors
    /// Returns an error if the modality-specific config cannot be parsed
    /// from `model_config`.
    pub fn for_task(
        task: &str,
        model_config: &serde_json::Value,
    ) -> Result<Box<dyn FeatureExtractor>> {
        let extractor: Box<dyn FeatureExtractor> = match task {
            "image-classification" | "object-detection" | "image-to-text" => Box::new(
                VisionFeatureExtractor::new(VisionFeatureConfig::from_config(model_config)?),
            ),
            "automatic-speech-recognition" | "audio-classification" => Box::new(
                AudioFeatureExtractor::new(AudioFeatureConfig::from_config(model_config)?),
            ),
            "document-understanding" | "document-question-answering" => Box::new(
                DocumentFeatureExtractor::new(DocumentFeatureConfig::from_config(model_config)?),
            ),
            _ => Box::new(GenericFeatureExtractor::new(
                GenericFeatureConfig::from_config(model_config)?,
            )),
        };
        Ok(extractor)
    }

    /// Lists the model types this dispatcher recognizes, including the
    /// `"generic"` fallback.
    pub fn supported_model_types() -> Vec<&'static str> {
        [
            "clip", "blip", "vit", "wav2vec2", "whisper", "hubert", "layoutlm", "donut",
            "generic",
        ]
        .to_vec()
    }

    /// Lists the task names this dispatcher recognizes.
    pub fn supported_tasks() -> Vec<&'static str> {
        [
            "image-classification",
            "object-detection",
            "image-to-text",
            "automatic-speech-recognition",
            "audio-classification",
            "document-understanding",
            "document-question-answering",
            "text-classification",
        ]
        .to_vec()
    }
}
/// Common interface implemented by every modality-specific feature extractor.
///
/// Implementors must be thread-safe (`Send + Sync`) so boxed extractors can be
/// shared across threads.
pub trait FeatureExtractor: Send + Sync {
    /// Runs feature extraction on `input` and returns the resulting features.
    fn extract_features(&self, input: &FeatureInput) -> Result<FeatureOutput>;

    /// Returns the configuration backing this extractor.
    fn config(&self) -> &dyn FeatureExtractorConfig;

    /// Optional pre-processing hook; the default returns a clone of the input
    /// unchanged.
    fn preprocess(&self, input: &FeatureInput) -> Result<FeatureInput> {
        Ok(input.clone())
    }

    /// Optional post-processing hook; the default passes the features through
    /// untouched.
    fn postprocess(&self, features: FeatureOutput) -> Result<FeatureOutput> {
        Ok(features)
    }

    /// Whether this extractor can handle `input`; the default accepts anything.
    fn supports_input(&self, _input: &FeatureInput) -> bool {
        true
    }

    /// Advertises capabilities derived from this extractor's configuration:
    /// `feature_size`, `supports_batching`, and (when bounded) `max_batch_size`.
    fn capabilities(&self) -> HashMap<String, serde_json::Value> {
        let cfg = self.config();
        let mut caps: HashMap<String, serde_json::Value> = HashMap::from([
            ("feature_size".to_string(), cfg.feature_size().into()),
            ("supports_batching".to_string(), cfg.supports_batching().into()),
        ]);
        // `max_batch_size` is only advertised when the config imposes a limit.
        if let Some(limit) = cfg.max_batch_size() {
            caps.insert("max_batch_size".to_string(), limit.into());
        }
        caps
    }
}
/// Configuration contract shared by all feature-extractor configs.
pub trait FeatureExtractorConfig: Send + Sync {
    /// Dimensionality of the feature vectors the extractor produces.
    fn feature_size(&self) -> usize;
    /// Whether the extractor can process batched inputs.
    fn supports_batching(&self) -> bool;
    /// Upper bound on batch size; `None` presumably means no fixed limit —
    /// TODO(review): confirm against the concrete config implementations.
    fn max_batch_size(&self) -> Option<usize>;
    /// Extra implementation-specific parameters; empty by default.
    fn additional_params(&self) -> HashMap<String, serde_json::Value> {
        HashMap::new()
    }
    /// Validates the configuration; the default accepts any configuration.
    fn validate(&self) -> Result<()> {
        Ok(())
    }
}
// NOTE(review): this module is gated behind `cfg(test_disabled)`, which is not a
// standard cfg predicate, so these tests are never compiled or run by
// `cargo test`. Confirm whether this was an intentional quarantine or whether
// the gate should be `#[cfg(test)]` — while disabled, the module also escapes
// compile checking, so it may have silently drifted out of sync with the
// config structs it constructs.
#[cfg(test_disabled)]
mod tests {
    use super::*;
    // NOTE(review): DocumentFormat and DocumentMetadata are imported but unused
    // by any test below.
    use crate::auto::types::{
        AudioMetadata, DocumentFormat, DocumentMetadata, FeatureInput, ImageFormat, ImageMetadata,
    };
    use trustformers_core::errors::TrustformersError;

    // The advertised model types should include one representative per modality
    // plus the generic fallback.
    #[test]
    fn test_auto_feature_extractor_supported_types() {
        let types = AutoFeatureExtractor::supported_model_types();
        assert!(types.contains(&"clip"));
        assert!(types.contains(&"wav2vec2"));
        assert!(types.contains(&"layoutlm"));
        assert!(types.contains(&"generic"));
    }

    // The advertised tasks should cover vision, audio, and document use cases.
    #[test]
    fn test_auto_feature_extractor_supported_tasks() {
        let tasks = AutoFeatureExtractor::supported_tasks();
        assert!(tasks.contains(&"image-classification"));
        assert!(tasks.contains(&"automatic-speech-recognition"));
        assert!(tasks.contains(&"document-understanding"));
    }

    // Smoke test: a vision extractor configured for 768-dim features should
    // produce a flat 768-element feature vector from raw image bytes.
    #[test]
    fn test_vision_feature_extractor() {
        let config = VisionFeatureConfig {
            feature_size: 768,
            image_size: 224,
            normalize: true,
            do_resize: true,
            do_center_crop: true,
            crop_size: None,
            // Standard ImageNet channel normalization statistics.
            mean: vec![0.485, 0.456, 0.406],
            std: vec![0.229, 0.224, 0.225],
            max_batch_size: Some(32),
        };
        let extractor = VisionFeatureExtractor::new(config);
        let input = FeatureInput::Image {
            data: vec![0u8; 1024],
            format: ImageFormat::Jpeg,
            metadata: Some(ImageMetadata {
                width: 640,
                height: 480,
                channels: 3,
                dpi: Some(96),
            }),
        };
        let result = extractor.extract_features(&input);
        assert!(result.is_ok());
        let output = result.expect("Feature extraction should succeed");
        assert_eq!(output.features.len(), 768);
        assert_eq!(output.shape, vec![768]);
    }

    // A well-formed vision config should validate and report its settings
    // through the FeatureExtractorConfig trait accessors.
    #[test]
    fn test_feature_extractor_config_validation() {
        let config = VisionFeatureConfig {
            feature_size: 768,
            image_size: 224,
            normalize: true,
            do_resize: true,
            do_center_crop: true,
            crop_size: None,
            mean: vec![0.485, 0.456, 0.406],
            std: vec![0.229, 0.224, 0.225],
            max_batch_size: Some(32),
        };
        assert!(config.validate().is_ok());
        assert_eq!(config.feature_size(), 768);
        assert!(config.supports_batching());
        assert_eq!(config.max_batch_size(), Some(32));
    }

    // The default `capabilities()` map should mirror the config's feature size,
    // batching support, and batch-size limit.
    #[test]
    fn test_extractor_capabilities() {
        let config = AudioFeatureConfig {
            sampling_rate: 16000,
            feature_size: 80,
            n_fft: 512,
            hop_length: 160,
            normalize: true,
            max_batch_size: Some(16),
        };
        let extractor = AudioFeatureExtractor::new(config);
        let caps = extractor.capabilities();
        assert_eq!(
            caps.get("feature_size")
                .expect("missing feature_size capability")
                .as_u64()
                .expect("expected u64 value"),
            80
        );
        assert_eq!(
            caps.get("supports_batching")
                .expect("missing supports_batching capability")
                .as_bool()
                .expect("expected bool value"),
            true
        );
        assert_eq!(
            caps.get("max_batch_size")
                .expect("missing max_batch_size capability")
                .as_u64()
                .expect("expected u64 value"),
            16
        );
    }

    // Feeding audio input to a vision extractor should be rejected.
    // NOTE(review): the match below assumes `crate::error::Result`'s error type
    // is `TrustformersError` — confirm the alias before re-enabling this module.
    #[test]
    fn test_invalid_input_handling() {
        let config = VisionFeatureConfig {
            feature_size: 768,
            image_size: 224,
            normalize: true,
            do_resize: true,
            do_center_crop: true,
            crop_size: None,
            mean: vec![0.485, 0.456, 0.406],
            std: vec![0.229, 0.224, 0.225],
            max_batch_size: Some(32),
        };
        let extractor = VisionFeatureExtractor::new(config);
        let input = FeatureInput::Audio {
            samples: vec![0.0; 1000],
            sample_rate: 16000,
            metadata: Some(AudioMetadata {
                duration: 1.0,
                channels: 1,
                bit_depth: Some(16),
            }),
        };
        let result = extractor.extract_features(&input);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TrustformersError::InvalidInput { .. }
        ));
    }
}