use crate::auto::feature_extractors::{FeatureExtractor, FeatureExtractorConfig};
use crate::auto::types::{FeatureInput, FeatureOutput};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Feature extractor that maps text to a fixed-size, hashed bag-of-words
/// vector.
///
/// Only `FeatureInput::Text` is supported; any other input variant is
/// rejected with an invalid-input error.
#[derive(Debug, Clone)]
pub struct GenericFeatureExtractor {
// Supplies `feature_size` (length of the produced vector) and the optional
// batch-size limit reported through `capabilities`.
config: GenericFeatureConfig,
}
impl GenericFeatureExtractor {
    /// Creates an extractor backed by `config`.
    ///
    /// The configuration is stored as given; no validation is performed here.
    pub fn new(config: GenericFeatureConfig) -> Self {
        Self { config }
    }

    /// Converts `text` into a hashed bag-of-words vector of length
    /// `config.feature_size`.
    ///
    /// The text is lowercased and split on whitespace. Each word is hashed
    /// into a bucket and adds `1 / (position + 1)` to it, so earlier words
    /// weigh more. The vector is then scaled to unit L2 norm; text that
    /// contains no words yields an all-zero vector.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when `config.feature_size` is 0, because
    /// there is no bucket to hash into.
    fn extract_text_features(&self, text: &str) -> Result<Vec<f32>> {
        let feature_size = self.config.feature_size;
        // `hash % 0` panics, and `new` accepts unvalidated configurations, so a
        // zero-width feature space has to be rejected explicitly.
        if feature_size == 0 {
            return Err(TrustformersError::invalid_input_simple(
                "Generic feature extractor requires a feature size greater than 0".to_string(),
            ));
        }

        let text_lower = text.to_lowercase();
        let mut features = vec![0.0f32; feature_size];

        for (i, word) in text_lower.split_whitespace().enumerate() {
            let bucket = self.simple_hash(word) % feature_size;
            let position_weight = 1.0 / (i + 1) as f32;
            features[bucket] += position_weight;
        }

        // Skip the division when nothing was accumulated to avoid 0/0 = NaN.
        let norm: f32 = features.iter().map(|&x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for feat in &mut features {
                *feat /= norm;
            }
        }
        Ok(features)
    }

    /// Polynomial rolling hash (multiplier 31) over the UTF-8 bytes of `word`.
    ///
    /// Uses wrapping arithmetic, so long words overflow silently instead of
    /// panicking.
    fn simple_hash(&self, word: &str) -> usize {
        word.bytes().fold(0usize, |hash, byte| {
            hash.wrapping_mul(31).wrapping_add(usize::from(byte))
        })
    }

    /// Accepts `FeatureInput::Text` and rejects every other variant.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error for non-text input.
    fn validate_input(&self, input: &FeatureInput) -> Result<()> {
        match input {
            FeatureInput::Text { .. } => Ok(()),
            _ => Err(TrustformersError::invalid_input_simple(
                "Generic feature extractor only supports text input".to_string(),
            )),
        }
    }
}
impl FeatureExtractor for GenericFeatureExtractor {
    /// Extracts a feature vector from text input.
    ///
    /// The returned `FeatureOutput` carries the vector produced by
    /// `extract_text_features`, a one-dimensional `shape` of
    /// `[config.feature_size]`, no attention mask, no special tokens, and
    /// metadata with these keys:
    ///
    /// * `"input_type"`: always `"text"`
    /// * `"word_count"`: number of whitespace-separated words in `content`
    /// * `"character_count"`: number of Unicode scalar values in `content`
    /// * `"language"` / `"encoding"`: copied from the input metadata when set
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error for non-text input, and propagates any
    /// error from `extract_text_features`.
    fn extract_features(&self, input: &FeatureInput) -> Result<FeatureOutput> {
        self.validate_input(input)?;
        match input {
            FeatureInput::Text { content, metadata } => {
                let features = self.extract_text_features(content)?;

                let mut output_metadata = HashMap::new();
                output_metadata.insert(
                    "input_type".to_string(),
                    serde_json::Value::String("text".to_string()),
                );
                output_metadata.insert(
                    "word_count".to_string(),
                    serde_json::Value::Number(content.split_whitespace().count().into()),
                );
                // `str::len` is the UTF-8 byte length and over-counts any
                // non-ASCII text; count characters instead.
                output_metadata.insert(
                    "character_count".to_string(),
                    serde_json::Value::Number(content.chars().count().into()),
                );

                if let Some(text_meta) = metadata {
                    if let Some(lang) = &text_meta.language {
                        output_metadata.insert(
                            "language".to_string(),
                            serde_json::Value::String(lang.clone()),
                        );
                    }
                    if let Some(encoding) = &text_meta.encoding {
                        output_metadata.insert(
                            "encoding".to_string(),
                            serde_json::Value::String(encoding.clone()),
                        );
                    }
                }

                Ok(FeatureOutput {
                    features,
                    shape: vec![self.config.feature_size],
                    metadata: output_metadata,
                    attention_mask: None,
                    special_tokens: vec![],
                })
            },
            // `validate_input` has already rejected non-text input above; this
            // arm only keeps the match exhaustive.
            _ => Err(TrustformersError::invalid_input_simple(
                "Generic feature extractor requires text input".to_string(),
            )),
        }
    }

    /// Returns the extractor's configuration as a trait object.
    fn config(&self) -> &dyn FeatureExtractorConfig {
        &self.config
    }

    /// Returns `true` only for `FeatureInput::Text`.
    fn supports_input(&self, input: &FeatureInput) -> bool {
        matches!(input, FeatureInput::Text { .. })
    }

    /// Cleans text input before extraction.
    ///
    /// For text, every character that is not alphanumeric, whitespace, or
    /// ASCII punctuation is dropped, and the result is trimmed of leading and
    /// trailing whitespace; the metadata is cloned unchanged. Any other input
    /// variant is returned as an unmodified clone.
    fn preprocess(&self, input: &FeatureInput) -> Result<FeatureInput> {
        match input {
            FeatureInput::Text { content, metadata } => {
                let cleaned_content = content
                    .chars()
                    .filter(|c| {
                        c.is_alphanumeric() || c.is_whitespace() || c.is_ascii_punctuation()
                    })
                    .collect::<String>()
                    .trim()
                    .to_string();
                Ok(FeatureInput::Text {
                    content: cleaned_content,
                    metadata: metadata.clone(),
                })
            },
            _ => Ok(input.clone()),
        }
    }

    /// Describes what this extractor can do as a JSON-valued map.
    ///
    /// Always contains `"supported_modalities"`, `"feature_size"`,
    /// `"supports_batching"`, `"extraction_method"` and `"normalization"`;
    /// `"max_batch_size"` is present only when the configuration sets one.
    fn capabilities(&self) -> HashMap<String, serde_json::Value> {
        let mut caps = HashMap::new();
        caps.insert(
            "supported_modalities".to_string(),
            serde_json::Value::Array(vec![serde_json::Value::String("text".to_string())]),
        );
        caps.insert(
            "feature_size".to_string(),
            serde_json::Value::Number(self.config.feature_size.into()),
        );
        caps.insert(
            "supports_batching".to_string(),
            serde_json::Value::Bool(self.config.supports_batching()),
        );
        if let Some(max_batch) = self.config.max_batch_size() {
            caps.insert(
                "max_batch_size".to_string(),
                serde_json::Value::Number(max_batch.into()),
            );
        }
        caps.insert(
            "extraction_method".to_string(),
            serde_json::Value::String("bag_of_words_with_position_weighting".to_string()),
        );
        caps.insert(
            "normalization".to_string(),
            serde_json::Value::String("l2".to_string()),
        );
        caps
    }
}
/// Configuration for `GenericFeatureExtractor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenericFeatureConfig {
/// Length of the produced feature vector. `validate` rejects 0.
pub feature_size: usize,
/// Optional batch-size limit; `None` when no limit is configured.
/// `validate` rejects `Some(0)`.
pub max_batch_size: Option<usize>,
}
impl GenericFeatureConfig {
    /// Builds a configuration from a JSON object.
    ///
    /// `feature_size` is taken from `"hidden_size"`, falling back to
    /// `"feature_size"` and finally to 768. A key only counts when its value
    /// is an unsigned integer, so a malformed `"hidden_size"` no longer hides
    /// a usable `"feature_size"`. `max_batch_size` is read from
    /// `"max_batch_size"` and is `None` when that key is missing or not an
    /// unsigned integer.
    ///
    /// The result is not validated; call `validate` to reject zero sizes.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` return type is kept for callers.
    pub fn from_config(config: &serde_json::Value) -> Result<Self> {
        // Parse each key fully before falling back to the next one, so a key
        // that is present with the wrong type is treated like a missing key.
        let read_usize =
            |key: &str| config.get(key).and_then(|v| v.as_u64()).map(|v| v as usize);

        let feature_size = read_usize("hidden_size")
            .or_else(|| read_usize("feature_size"))
            .unwrap_or(768);
        let max_batch_size = read_usize("max_batch_size");

        Ok(Self {
            feature_size,
            max_batch_size,
        })
    }

    /// Returns the default configuration: 768 features, batches of up to 32.
    pub fn default() -> Self {
        Self {
            feature_size: 768,
            max_batch_size: Some(32),
        }
    }

    /// Checks that the configured sizes are usable.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when `feature_size` is 0 or when
    /// `max_batch_size` is `Some(0)`.
    pub fn validate(&self) -> Result<()> {
        // NOTE(review): `lconfig_error` is kept exactly as this file spells it;
        // confirm that constructor exists on `TrustformersError`.
        if self.feature_size == 0 {
            return Err(TrustformersError::lconfig_error(
                "Feature size must be greater than 0".to_string(),
            ));
        }
        if let Some(batch_size) = self.max_batch_size {
            if batch_size == 0 {
                return Err(TrustformersError::lconfig_error(
                    "Max batch size must be greater than 0".to_string(),
                ));
            }
        }
        Ok(())
    }
}
impl FeatureExtractorConfig for GenericFeatureConfig {
fn feature_size(&self) -> usize {
self.feature_size
}
fn supports_batching(&self) -> bool {
true
}
fn max_batch_size(&self) -> Option<usize> {
self.max_batch_size
}
fn validate(&self) -> Result<()> {
self.validate()
}
fn additional_params(&self) -> HashMap<String, serde_json::Value> {
let mut params = HashMap::new();
params.insert(
"extraction_method".to_string(),
serde_json::Value::String("bag_of_words".to_string()),
);
params.insert(
"normalization".to_string(),
serde_json::Value::String("l2".to_string()),
);
params.insert(
"position_weighting".to_string(),
serde_json::Value::Bool(true),
);
params
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auto::types::TextMetadata;

    /// Builds an extractor from explicit configuration values.
    fn extractor_with(
        feature_size: usize,
        max_batch_size: Option<usize>,
    ) -> GenericFeatureExtractor {
        GenericFeatureExtractor::new(GenericFeatureConfig {
            feature_size,
            max_batch_size,
        })
    }

    /// Builds an extractor from the default configuration.
    fn default_extractor() -> GenericFeatureExtractor {
        GenericFeatureExtractor::new(GenericFeatureConfig::default())
    }

    /// Text input without metadata.
    fn plain_text(content: &str) -> FeatureInput {
        FeatureInput::Text {
            content: content.to_string(),
            metadata: None,
        }
    }

    /// Silent 16 kHz audio input with `sample_count` samples.
    fn silent_audio(sample_count: usize) -> FeatureInput {
        FeatureInput::Audio {
            samples: vec![0.0; sample_count],
            sample_rate: 16000,
            metadata: None,
        }
    }

    #[test]
    fn test_generic_feature_extractor_creation() {
        let extractor = extractor_with(512, Some(16));
        let cfg = extractor.config();
        assert_eq!(cfg.feature_size(), 512);
        assert!(cfg.supports_batching());
        assert_eq!(cfg.max_batch_size(), Some(16));
    }

    #[test]
    fn test_text_feature_extraction() {
        let extractor = extractor_with(128, Some(8));
        let input = FeatureInput::Text {
            content: "This is a test sentence for feature extraction.".to_string(),
            metadata: Some(TextMetadata::new().with_language("en")),
        };

        let result = extractor.extract_features(&input);
        assert!(result.is_ok());

        let output = result.expect("operation failed in test");
        assert_eq!(output.features.len(), 128);
        assert_eq!(output.shape, vec![128]);
        assert!(output.metadata.contains_key("word_count"));
        assert!(output.metadata.contains_key("language"));
    }

    #[test]
    fn test_empty_text_handling() {
        let extractor = extractor_with(64, None);

        let result = extractor.extract_features(&plain_text(""));
        assert!(result.is_ok());

        let output = result.expect("operation failed in test");
        assert_eq!(output.features.len(), 64);
        assert!(output.features.iter().all(|&value| value == 0.0));
    }

    #[test]
    fn test_invalid_input_type() {
        let extractor = default_extractor();

        let result = extractor.extract_features(&silent_audio(1000));
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TrustformersError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_supports_input() {
        let extractor = default_extractor();
        assert!(extractor.supports_input(&plain_text("test")));
        assert!(!extractor.supports_input(&silent_audio(1)));
    }

    #[test]
    fn test_config_from_json() {
        let config_json = serde_json::json!({
            "hidden_size": 1024,
            "max_batch_size": 64
        });

        let parsed = GenericFeatureConfig::from_config(&config_json);
        assert!(parsed.is_ok());

        let parsed = parsed.expect("operation failed in test");
        assert_eq!(parsed.feature_size, 1024);
        assert_eq!(parsed.max_batch_size, Some(64));
    }

    #[test]
    fn test_config_validation() {
        let accepted = GenericFeatureConfig {
            feature_size: 768,
            max_batch_size: Some(32),
        };
        assert!(accepted.validate().is_ok());

        let rejected = GenericFeatureConfig {
            feature_size: 0,
            max_batch_size: Some(32),
        };
        assert!(rejected.validate().is_err());
    }

    #[test]
    fn test_simple_hash_consistency() {
        let extractor = default_extractor();

        let first = extractor.simple_hash("test");
        let repeated = extractor.simple_hash("test");
        assert_eq!(first, repeated);

        let other = extractor.simple_hash("different");
        assert_ne!(first, other);
    }

    #[test]
    fn test_feature_normalization() {
        let extractor = extractor_with(100, None);

        let result = extractor.extract_features(&plain_text("word1 word2 word3"));
        assert!(result.is_ok());

        let output = result.expect("operation failed in test");
        let norm_squared: f32 = output.features.iter().map(|&value| value * value).sum();
        assert!((norm_squared - 1.0).abs() < 1e-6 || norm_squared == 0.0);
    }

    #[test]
    fn test_text_preprocessing() {
        let extractor = default_extractor();

        let preprocessed = extractor.preprocess(&plain_text(" Hello, World! \n\t "));
        assert!(preprocessed.is_ok());

        match preprocessed.expect("operation failed in test") {
            FeatureInput::Text { content, .. } => assert_eq!(content, "Hello, World!"),
            _ => panic!("Expected text input after preprocessing"),
        }
    }

    #[test]
    fn test_extractor_capabilities() {
        let extractor = extractor_with(256, Some(16));
        let caps = extractor.capabilities();

        assert!(caps.contains_key("supported_modalities"));
        assert!(caps.contains_key("extraction_method"));
        assert!(caps.contains_key("normalization"));

        let reported_size = caps
            .get("feature_size")
            .expect("expected value not found")
            .as_u64()
            .expect("expected u64 value");
        assert_eq!(reported_size, 256);
    }
}