use super::{FeatureExtractor, FeatureExtractorConfig};
use crate::auto::types::{FeatureInput, FeatureOutput, ImageFormat};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Feature extractor for image inputs.
///
/// Wraps a [`VisionFeatureConfig`] and turns raw image bytes into a
/// fixed-size feature vector via the [`FeatureExtractor`] trait.
#[derive(Debug, Clone)]
pub struct VisionFeatureExtractor {
    // Preprocessing switches and output-size settings for this extractor.
    config: VisionFeatureConfig,
}
impl VisionFeatureExtractor {
    /// Creates a new extractor with the given configuration.
    pub fn new(config: VisionFeatureConfig) -> Self {
        Self { config }
    }

    /// Decodes and preprocesses raw image bytes into a flat float buffer of
    /// `image_size * image_size * 3` elements (3 = RGB channels).
    ///
    /// NOTE(review): placeholder implementation — the input bytes and format
    /// are currently ignored and a zero-filled buffer is returned.
    /// Parameters are underscore-prefixed to silence unused warnings until
    /// real decoding is implemented.
    fn preprocess_image(&self, _data: &[u8], _format: ImageFormat) -> Result<Vec<f32>> {
        // Square image, 3 channels per pixel.
        let processed_size = self.config.image_size * self.config.image_size * 3;
        Ok(vec![0.0; processed_size])
    }

    /// Runs the vision model over a preprocessed image buffer and returns a
    /// feature vector of `feature_size` elements.
    ///
    /// NOTE(review): placeholder implementation — the pixel buffer is
    /// ignored and a zero vector is returned.
    fn extract_visual_features(&self, _image: &[f32]) -> Result<Vec<f32>> {
        Ok(vec![0.0; self.config.feature_size])
    }
}
impl FeatureExtractor for VisionFeatureExtractor {
    /// Runs the full vision pipeline (preprocess, then feature extraction)
    /// and packages the result together with image/config metadata.
    ///
    /// Returns an invalid-input error for any non-image [`FeatureInput`].
    fn extract_features(&self, input: &FeatureInput) -> Result<FeatureOutput> {
        // Guard clause: this extractor only understands image inputs.
        let (data, format, metadata) = match input {
            FeatureInput::Image {
                data,
                format,
                metadata,
            } => (data, format, metadata),
            _ => {
                return Err(TrustformersError::invalid_input_simple(
                    "Vision feature extractor requires image input".to_string(),
                ))
            },
        };

        let pixels = self.preprocess_image(data, *format)?;
        let features = self.extract_visual_features(&pixels)?;

        let mut output_metadata = HashMap::new();
        // Propagate the source image dimensions when the caller supplied them.
        if let Some(meta) = metadata {
            output_metadata.extend([
                (
                    "width".to_string(),
                    serde_json::Value::Number(meta.width.into()),
                ),
                (
                    "height".to_string(),
                    serde_json::Value::Number(meta.height.into()),
                ),
                (
                    "channels".to_string(),
                    serde_json::Value::Number(meta.channels.into()),
                ),
            ]);
            // DPI is optional on the metadata itself.
            if let Some(dpi) = meta.dpi {
                output_metadata
                    .insert("dpi".to_string(), serde_json::Value::Number(dpi.into()));
            }
        }
        // Record how the image was processed, regardless of input metadata.
        output_metadata.insert(
            "processed_image_size".to_string(),
            serde_json::Value::Number(self.config.image_size.into()),
        );
        output_metadata.insert(
            "normalized".to_string(),
            serde_json::Value::Bool(self.config.normalize),
        );

        Ok(FeatureOutput {
            features,
            shape: vec![self.config.feature_size],
            metadata: output_metadata,
            attention_mask: None,
            special_tokens: vec![],
        })
    }

    /// Exposes the extractor configuration behind the config trait object.
    fn config(&self) -> &dyn FeatureExtractorConfig {
        &self.config
    }

    /// Only image inputs are supported by this extractor.
    fn supports_input(&self, input: &FeatureInput) -> bool {
        matches!(input, FeatureInput::Image { .. })
    }

    /// Advertises this extractor's abilities: output size, preprocessing
    /// switches, batching limits and the accepted image formats.
    fn capabilities(&self) -> HashMap<String, serde_json::Value> {
        let mut caps: HashMap<String, serde_json::Value> = [
            (
                "feature_size",
                serde_json::Value::Number(self.config.feature_size.into()),
            ),
            (
                "supports_batching",
                serde_json::Value::Bool(self.config.supports_batching()),
            ),
            (
                "modality",
                serde_json::Value::String("vision".to_string()),
            ),
            (
                "image_size",
                serde_json::Value::Number(self.config.image_size.into()),
            ),
            (
                "supports_resize",
                serde_json::Value::Bool(self.config.do_resize),
            ),
            (
                "supports_center_crop",
                serde_json::Value::Bool(self.config.do_center_crop),
            ),
            (
                "normalize",
                serde_json::Value::Bool(self.config.normalize),
            ),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect();

        // Only advertised when the config actually bounds the batch size.
        if let Some(max_batch) = self.config.max_batch_size {
            caps.insert(
                "max_batch_size".to_string(),
                serde_json::Value::Number(max_batch.into()),
            );
        }

        caps.insert(
            "supported_formats".to_string(),
            serde_json::Value::Array(
                ["jpeg", "jpg", "png", "webp", "bmp", "tiff"]
                    .iter()
                    .map(|f| serde_json::Value::String((*f).to_string()))
                    .collect(),
            ),
        );
        caps
    }
}
/// Preprocessing and output configuration for the vision feature extractor.
///
/// Field names mirror HuggingFace-style image-processor configs (see
/// `from_config` for the JSON keys and their defaults).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisionFeatureConfig {
    /// Target square edge length in pixels; must be > 0 (see `validate`).
    pub image_size: usize,
    /// Dimensionality of the extracted feature vector; must be > 0.
    pub feature_size: usize,
    /// Whether per-channel mean/std normalization is enabled.
    pub normalize: bool,
    /// Resize switch (consumed by preprocessing; currently only reported
    /// in capabilities/params — preprocessing is stubbed).
    pub do_resize: bool,
    /// Center-crop switch (same caveat as `do_resize`).
    pub do_center_crop: bool,
    /// Optional crop edge length; when set, must be > 0 and <= `image_size`.
    pub crop_size: Option<usize>,
    /// Per-channel normalization means; exactly 3 elements (RGB) when
    /// `normalize` is true.
    pub mean: Vec<f32>,
    /// Per-channel normalization standard deviations; exactly 3 positive
    /// elements (RGB) when `normalize` is true.
    pub std: Vec<f32>,
    /// Optional upper bound on batch size; `None` means unbounded.
    pub max_batch_size: Option<usize>,
}
impl VisionFeatureConfig {
    /// Builds a config from a HuggingFace-style JSON preprocessing config,
    /// falling back to ImageNet-style defaults for any missing field.
    ///
    /// # Errors
    /// Currently never fails; the `Result` return is kept for forward
    /// compatibility with stricter parsing.
    pub fn from_config(config: &serde_json::Value) -> Result<Self> {
        // Shared helper: read an optional f32 array (image_mean / image_std),
        // silently skipping non-numeric entries.
        let f32_array = |key: &str| {
            config.get(key).and_then(|v| v.as_array()).map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_f64())
                    .map(|v| v as f32)
                    .collect::<Vec<f32>>()
            })
        };
        Ok(Self {
            // Accept both "image_size" and the legacy "size" key.
            image_size: config
                .get("image_size")
                .or_else(|| config.get("size"))
                .and_then(|v| v.as_u64())
                .unwrap_or(224) as usize,
            feature_size: config
                .get("hidden_size")
                .and_then(|v| v.as_u64())
                .unwrap_or(768) as usize,
            normalize: config
                .get("do_normalize")
                .and_then(|v| v.as_bool())
                .unwrap_or(true),
            do_resize: config
                .get("do_resize")
                .and_then(|v| v.as_bool())
                .unwrap_or(true),
            do_center_crop: config
                .get("do_center_crop")
                .and_then(|v| v.as_bool())
                .unwrap_or(true),
            crop_size: config
                .get("crop_size")
                .and_then(|v| v.as_u64())
                .map(|v| v as usize),
            // ImageNet normalization statistics as defaults.
            mean: f32_array("image_mean").unwrap_or_else(|| vec![0.485, 0.456, 0.406]),
            std: f32_array("image_std").unwrap_or_else(|| vec![0.229, 0.224, 0.225]),
            max_batch_size: config
                .get("max_batch_size")
                .and_then(|v| v.as_u64())
                .map(|v| v as usize),
        })
    }
}

/// ViT-base-style defaults: 224px input, 768-dim features, ImageNet
/// normalization statistics, batches capped at 32.
///
/// Implemented as the standard [`Default`] trait rather than an inherent
/// `default()` method (which shadows the trait — clippy
/// `should_implement_trait`); `VisionFeatureConfig::default()` call sites
/// keep working through the prelude trait.
impl Default for VisionFeatureConfig {
    fn default() -> Self {
        Self {
            image_size: 224,
            feature_size: 768,
            normalize: true,
            do_resize: true,
            do_center_crop: true,
            crop_size: None,
            mean: vec![0.485, 0.456, 0.406],
            std: vec![0.229, 0.224, 0.225],
            max_batch_size: Some(32),
        }
    }
}
impl FeatureExtractorConfig for VisionFeatureConfig {
    /// Size of the feature vectors this config produces.
    fn feature_size(&self) -> usize {
        self.feature_size
    }

    /// Batching is always supported for vision inputs.
    fn supports_batching(&self) -> bool {
        true
    }

    /// Optional upper bound on batch size.
    fn max_batch_size(&self) -> Option<usize> {
        self.max_batch_size
    }

    /// Returns the vision-specific parameters as a JSON-valued map.
    fn additional_params(&self) -> HashMap<String, serde_json::Value> {
        // Encode an f32 slice as a JSON array; non-finite values fall back
        // to their string representation since JSON has no NaN/Inf.
        let float_array = |vals: &[f32]| {
            serde_json::Value::Array(
                vals.iter()
                    .map(|&v| {
                        serde_json::Number::from_f64(f64::from(v))
                            .map(serde_json::Value::Number)
                            .unwrap_or_else(|| serde_json::Value::String(format!("{}", v)))
                    })
                    .collect(),
            )
        };

        let mut params: HashMap<String, serde_json::Value> = [
            (
                "image_size",
                serde_json::Value::Number(self.image_size.into()),
            ),
            ("normalize", serde_json::Value::Bool(self.normalize)),
            ("do_resize", serde_json::Value::Bool(self.do_resize)),
            (
                "do_center_crop",
                serde_json::Value::Bool(self.do_center_crop),
            ),
            ("mean", float_array(&self.mean)),
            ("std", float_array(&self.std)),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect();

        // Only present when a crop size is configured.
        if let Some(crop_size) = self.crop_size {
            params.insert(
                "crop_size".to_string(),
                serde_json::Value::Number(crop_size.into()),
            );
        }
        params
    }

    /// Checks the internal consistency of the configuration.
    ///
    /// # Errors
    /// Returns a config error when sizes are zero, normalization statistics
    /// are malformed, or the crop size exceeds the image size.
    fn validate(&self) -> Result<()> {
        // Shared error constructor to keep the guard clauses compact.
        let fail = |msg: &str| Err(TrustformersError::lconfig_error(msg.to_string()));

        if self.image_size == 0 {
            return fail("Image size must be greater than 0");
        }
        if self.feature_size == 0 {
            return fail("Feature size must be greater than 0");
        }
        // Normalization statistics are only checked when normalization is on.
        if self.normalize {
            if self.mean.len() != 3 {
                return fail("Mean values must have exactly 3 elements (RGB)");
            }
            if self.std.len() != 3 {
                return fail("Standard deviation values must have exactly 3 elements (RGB)");
            }
            if self.std.iter().any(|&s| s <= 0.0) {
                return fail("Standard deviation values must be positive");
            }
        }
        if let Some(crop_size) = self.crop_size {
            if crop_size == 0 {
                return fail("Crop size must be greater than 0");
            }
            if crop_size > self.image_size {
                return fail("Crop size cannot be larger than image size");
            }
        }
        if self.max_batch_size == Some(0) {
            return fail("Maximum batch size must be greater than 0");
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auto::types::{ImageFormat, ImageMetadata};

    // Construction exposes config values through the FeatureExtractorConfig
    // trait object.
    #[test]
    fn test_vision_feature_extractor_creation() {
        let config = VisionFeatureConfig {
            image_size: 224,
            feature_size: 768,
            normalize: true,
            do_resize: true,
            do_center_crop: true,
            crop_size: Some(224),
            mean: vec![0.485, 0.456, 0.406],
            std: vec![0.229, 0.224, 0.225],
            max_batch_size: Some(32),
        };
        let extractor = VisionFeatureExtractor::new(config);
        assert_eq!(extractor.config().feature_size(), 768);
        assert!(extractor.config().supports_batching());
        assert_eq!(extractor.config().max_batch_size(), Some(32));
    }

    // Happy path: an image input yields a feature vector of the configured
    // size, and the supplied image metadata is echoed into the output.
    #[test]
    fn test_vision_feature_extraction() {
        let config = VisionFeatureConfig::default();
        let extractor = VisionFeatureExtractor::new(config);
        let input = FeatureInput::Image {
            data: vec![0u8; 1024],
            format: ImageFormat::Jpeg,
            metadata: Some(ImageMetadata {
                width: 640,
                height: 480,
                channels: 3,
                dpi: Some(96),
            }),
        };
        let result = extractor.extract_features(&input);
        assert!(result.is_ok());
        let output = result.expect("operation failed in test");
        // Default config: 768-dim features.
        assert_eq!(output.features.len(), 768);
        assert_eq!(output.shape, vec![768]);
        // Input metadata must be propagated into the output metadata map.
        assert_eq!(
            output
                .metadata
                .get("width")
                .expect("expected value not found")
                .as_u64()
                .expect("expected u64 value"),
            640
        );
        assert_eq!(
            output
                .metadata
                .get("height")
                .expect("expected value not found")
                .as_u64()
                .expect("expected u64 value"),
            480
        );
        assert_eq!(
            output
                .metadata
                .get("channels")
                .expect("expected value not found")
                .as_u64()
                .expect("expected u64 value"),
            3
        );
        assert_eq!(
            output
                .metadata
                .get("dpi")
                .expect("expected value not found")
                .as_u64()
                .expect("expected u64 value"),
            96
        );
    }

    // All fields parsed from a fully-specified HuggingFace-style JSON config.
    #[test]
    fn test_vision_config_from_json() {
        let config_json = serde_json::json!({
            "image_size": 224,
            "hidden_size": 768,
            "do_normalize": true,
            "do_resize": true,
            "do_center_crop": true,
            "crop_size": 224,
            "image_mean": [0.485, 0.456, 0.406],
            "image_std": [0.229, 0.224, 0.225],
            "max_batch_size": 32
        });
        let config =
            VisionFeatureConfig::from_config(&config_json).expect("operation failed in test");
        assert_eq!(config.image_size, 224);
        assert_eq!(config.feature_size, 768);
        assert!(config.normalize);
        assert!(config.do_resize);
        assert!(config.do_center_crop);
        assert_eq!(config.crop_size, Some(224));
        assert_eq!(config.mean, vec![0.485, 0.456, 0.406]);
        assert_eq!(config.std, vec![0.229, 0.224, 0.225]);
        assert_eq!(config.max_batch_size, Some(32));
    }

    // An empty JSON object falls back to the documented defaults
    // (224px, 768 features, ImageNet mean/std).
    #[test]
    fn test_vision_config_defaults() {
        let minimal_config = serde_json::json!({});
        let config =
            VisionFeatureConfig::from_config(&minimal_config).expect("operation failed in test");
        assert_eq!(config.image_size, 224);
        assert_eq!(config.feature_size, 768);
        assert!(config.normalize);
        assert!(config.do_resize);
        assert!(config.do_center_crop);
        assert_eq!(config.mean, vec![0.485, 0.456, 0.406]);
        assert_eq!(config.std, vec![0.229, 0.224, 0.225]);
    }

    // Each validation rule is exercised in turn by mutating one field,
    // checking the error, then restoring a valid value.
    #[test]
    fn test_vision_config_validation() {
        let mut config = VisionFeatureConfig::default();
        assert!(config.validate().is_ok());
        // Zero image size is rejected.
        config.image_size = 0;
        assert!(config.validate().is_err());
        config.image_size = 224;
        // Zero feature size is rejected.
        config.feature_size = 0;
        assert!(config.validate().is_err());
        config.feature_size = 768;
        // Mean must have exactly 3 channels.
        config.mean = vec![0.5, 0.5];
        assert!(config.validate().is_err());
        config.mean = vec![0.485, 0.456, 0.406];
        // Std must have exactly 3 channels.
        config.std = vec![0.2];
        assert!(config.validate().is_err());
        config.std = vec![0.229, 0.224, 0.225];
        // Std values must be strictly positive.
        config.std = vec![-0.1, 0.224, 0.225];
        assert!(config.validate().is_err());
        config.std = vec![0.229, 0.224, 0.225];
        // Crop size may not exceed image size.
        config.crop_size = Some(300);
        assert!(config.validate().is_err());
        config.crop_size = Some(224);
        assert!(config.validate().is_ok());
    }

    // supports_input accepts images and rejects other modalities.
    #[test]
    fn test_input_type_validation() {
        let config = VisionFeatureConfig::default();
        let extractor = VisionFeatureExtractor::new(config);
        let image_input = FeatureInput::Image {
            data: vec![0u8; 1024],
            format: ImageFormat::Png,
            metadata: None,
        };
        assert!(extractor.supports_input(&image_input));
        let audio_input = FeatureInput::Audio {
            samples: vec![0.0; 1000],
            sample_rate: 16000,
            metadata: None,
        };
        assert!(!extractor.supports_input(&audio_input));
    }

    // The capabilities map advertises modality, sizes, batching and the
    // supported image formats.
    #[test]
    fn test_extractor_capabilities() {
        let config = VisionFeatureConfig::default();
        let extractor = VisionFeatureExtractor::new(config);
        let caps = extractor.capabilities();
        assert_eq!(
            caps.get("modality")
                .expect("expected value not found")
                .as_str()
                .expect("expected str value"),
            "vision"
        );
        assert_eq!(
            caps.get("feature_size")
                .expect("expected value not found")
                .as_u64()
                .expect("expected u64 value"),
            768
        );
        assert_eq!(
            caps.get("image_size")
                .expect("expected value not found")
                .as_u64()
                .expect("expected u64 value"),
            224
        );
        assert!(caps
            .get("supports_batching")
            .expect("expected value not found")
            .as_bool()
            .expect("operation failed in test"));
        assert!(caps.contains_key("supported_formats"));
    }

    // Non-image inputs must produce an InvalidInput error, not a panic.
    #[test]
    fn test_invalid_input_handling() {
        let config = VisionFeatureConfig::default();
        let extractor = VisionFeatureExtractor::new(config);
        let input = FeatureInput::Text {
            content: "This is text, not an image".to_string(),
            metadata: None,
        };
        let result = extractor.extract_features(&input);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TrustformersError::InvalidInput { .. }
        ));
    }
}