use crate::core::traits::{Model, Tokenizer};
use crate::error::{Result, TrustformersError};
use crate::pipeline::{BasePipeline, Pipeline};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::Tensor;
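/// Configuration for the visual question answering (VQA) pipeline, covering
/// input limits, image preprocessing, multimodal fusion, answer generation,
/// and the optional attention-visualization and reasoning features.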
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VisualQuestionAnsweringConfig {
pub max_question_length: usize,
pub max_answer_length: usize,
pub image_config: ImageConfig,
pub fusion_strategy: FusionStrategy,
pub answer_generation: AnswerGenerationStrategy,
pub confidence_threshold: f32,
pub top_k_answers: usize,
pub enable_attention_viz: bool,
pub enable_reasoning: bool,
}
impl Default for VisualQuestionAnsweringConfig {
fn default() -> Self {
Self {
max_question_length: 512,
max_answer_length: 256,
image_config: ImageConfig::default(),
fusion_strategy: FusionStrategy::CrossAttention,
answer_generation: AnswerGenerationStrategy::Generative,
confidence_threshold: 0.1,
top_k_answers: 5,
enable_attention_viz: false,
enable_reasoning: false,
}
}
}
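/// Image preprocessing settings: target size, ImageNet-style channel
/// normalization, and optional ViT-style patching parameters.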
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageConfig {
pub image_size: (u32, u32),
pub normalize_mean: [f32; 3],
pub normalize_std: [f32; 3],
pub enable_augmentation: bool,
pub patch_size: Option<u32>,
pub num_patches: Option<usize>,
}
impl Default for ImageConfig {
fn default() -> Self {
Self {
image_size: (224, 224),
normalize_mean: [0.485, 0.456, 0.406],
normalize_std: [0.229, 0.224, 0.225],
enable_augmentation: false,
patch_size: Some(16),
            num_patches: Some(196),
        }
}
}
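/// Strategy used to fuse visual and textual features into a single representation.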
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum FusionStrategy {
#[default]
CrossAttention,
Concatenation,
Addition,
BilinearPooling,
TransformerFusion,
GraphFusion,
}
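/// Strategy used to turn fused features into an answer string.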
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum AnswerGenerationStrategy {
#[default]
Generative,
Extractive,
Classification,
Hybrid,
}
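/// A single VQA request: an image, a natural-language question, and optional
/// context, candidate answers, and metadata.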
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VisualQuestionAnsweringInput {
pub image: ImageInput,
pub question: String,
pub context: Option<String>,
pub answer_candidates: Option<Vec<String>>,
pub metadata: Option<HashMap<String, String>>,
}
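/// Supported image sources for a VQA request.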
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ImageInput {
Bytes(Vec<u8>),
Tensor(Vec<f32>),
Path(String),
Url(String),
Base64(String),
}
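/// The pipeline's answer together with ranked alternatives, optional attention
/// maps, an optional reasoning chain, extracted image features, and processing metadata.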
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VisualQuestionAnsweringOutput {
pub answer: String,
pub confidence: f32,
pub alternative_answers: Vec<AnswerCandidate>,
pub attention_visualization: Option<AttentionVisualization>,
pub reasoning_chain: Option<Vec<ReasoningStep>>,
pub image_features: Option<ImageFeatures>,
pub metadata: ProcessingMetadata,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AnswerCandidate {
pub answer: String,
pub confidence: f32,
pub evidence: Option<String>,
pub bbox: Option<BoundingBox>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BoundingBox {
pub x: f32,
pub y: f32,
pub width: f32,
pub height: f32,
pub confidence: f32,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AttentionVisualization {
pub cross_attention_weights: Vec<Vec<f32>>,
pub question_self_attention: Vec<Vec<f32>>,
pub visual_attention_heatmap: Vec<f32>,
pub attention_heads: Vec<AttentionHead>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AttentionHead {
pub head_id: usize,
pub layer_id: usize,
pub pattern_type: String,
pub avg_attention: f32,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReasoningStep {
pub description: String,
pub step_type: ReasoningStepType,
pub confidence: f32,
pub evidence: Option<String>,
pub grounding: Option<BoundingBox>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ReasoningStepType {
ObjectDetection,
SpatialReasoning,
Counting,
AttributeRecognition,
RelationshipReasoning,
TemporalReasoning,
CausalReasoning,
LogicalInference,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageFeatures {
pub global_features: Vec<f32>,
pub patch_features: Vec<Vec<f32>>,
pub detected_objects: Vec<DetectedObject>,
pub scene_description: Option<String>,
pub image_classification: Option<Vec<ClassificationResult>>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DetectedObject {
pub class: String,
pub confidence: f32,
pub bbox: BoundingBox,
pub attributes: Option<HashMap<String, String>>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ClassificationResult {
pub label: String,
pub confidence: f32,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProcessingMetadata {
pub processing_time_ms: u64,
pub model_name: String,
pub config: String,
pub tokens_processed: usize,
pub memory_usage_mb: Option<f32>,
}
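/// End-to-end visual question answering pipeline: preprocesses the image,
/// tokenizes the question, fuses the two modalities, and generates an answer,
/// optionally with attention visualization and a reasoning chain.
///
/// Illustrative usage (a sketch; `model` and `tokenizer` stand for any types
/// implementing this crate's `Model<Input = Tensor, Output = Tensor>` and
/// `Tokenizer` traits):
///
/// ```ignore
/// let pipeline = VisualQuestionAnsweringPipeline::new(model, tokenizer)?
///     .with_fusion_strategy(FusionStrategy::CrossAttention)?
///     .with_top_k_answers(3);
/// let output = pipeline.answer_question(VisualQuestionAnsweringInput {
///     image: ImageInput::Path("photo.jpg".to_string()),
///     question: "How many people are in the picture?".to_string(),
///     context: None,
///     answer_candidates: None,
///     metadata: None,
/// })?;
/// println!("{} ({:.2})", output.answer, output.confidence);
/// ```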
pub struct VisualQuestionAnsweringPipeline<M, T>
where
M: Model + Clone + Send + Sync + 'static,
T: Tokenizer + Clone + Send + Sync + 'static,
{
base: BasePipeline<M, T>,
config: VisualQuestionAnsweringConfig,
image_processor: ImageProcessor,
fusion_module: FusionModule,
answer_generator: AnswerGenerator,
attention_visualizer: Option<AttentionVisualizer>,
reasoning_engine: Option<ReasoningEngine>,
}
impl<M, T> VisualQuestionAnsweringPipeline<M, T>
where
M: Model<Input = Tensor, Output = Tensor> + Clone + Send + Sync + 'static,
T: Tokenizer + Clone + Send + Sync + 'static,
{
pub fn new(model: M, tokenizer: T) -> Result<Self> {
let base = BasePipeline::new(model, tokenizer);
let config = VisualQuestionAnsweringConfig::default();
let image_processor = ImageProcessor::new(config.image_config.clone())?;
let fusion_module = FusionModule::new(config.fusion_strategy.clone())?;
let answer_generator = AnswerGenerator::new(config.answer_generation.clone())?;
Ok(Self {
base,
config,
image_processor,
fusion_module,
answer_generator,
attention_visualizer: None,
reasoning_engine: None,
})
}
pub fn with_config(mut self, config: VisualQuestionAnsweringConfig) -> Self {
self.config = config;
self
}
pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Result<Self> {
self.config.fusion_strategy = strategy.clone();
self.fusion_module = FusionModule::new(strategy)?;
Ok(self)
}
pub fn with_answer_generation(mut self, strategy: AnswerGenerationStrategy) -> Result<Self> {
self.config.answer_generation = strategy.clone();
self.answer_generator = AnswerGenerator::new(strategy)?;
Ok(self)
}
pub fn with_attention_visualization(mut self, enable: bool) -> Self {
self.config.enable_attention_viz = enable;
if enable && self.attention_visualizer.is_none() {
self.attention_visualizer = Some(AttentionVisualizer::new());
}
self
}
pub fn with_reasoning(mut self, enable: bool) -> Self {
self.config.enable_reasoning = enable;
if enable && self.reasoning_engine.is_none() {
self.reasoning_engine = Some(ReasoningEngine::new());
}
self
}
pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
self.config.confidence_threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn with_top_k_answers(mut self, k: usize) -> Self {
self.config.top_k_answers = k;
self
}
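    /// Answers a question about an image by preprocessing the image, tokenizing
    /// the question, fusing both modalities, and generating ranked answers
    /// (plus attention maps and a reasoning chain when those features are enabled).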
pub fn answer_question(
&self,
input: VisualQuestionAnsweringInput,
) -> Result<VisualQuestionAnsweringOutput> {
let start_time = std::time::Instant::now();
if input.question.trim().is_empty() {
return Err(TrustformersError::invalid_input_simple(
"Question cannot be empty".to_string(),
));
}
let image_tensor = self.image_processor.process_image(&input.image)?;
let image_features = self.extract_image_features(&image_tensor)?;
let question_tokens = self.base.tokenizer.encode(&input.question)?;
let question_ids_f32: Vec<f32> =
question_tokens.input_ids.iter().map(|&x| x as f32).collect();
let question_tensor =
Tensor::from_vec(question_ids_f32, &[1, question_tokens.input_ids.len()])?;
let fused_features = self.fusion_module.fuse(&image_tensor, &question_tensor)?;
let answer_output = self.answer_generator.generate_answer(
&fused_features,
&input.question,
&input.answer_candidates,
&self.config,
)?;
let attention_visualization = if self.config.enable_attention_viz {
self.attention_visualizer
.as_ref()
.map(|viz| {
viz.visualize_attention(&fused_features, &image_tensor, &question_tensor)
})
.transpose()?
} else {
None
};
let reasoning_chain = if self.config.enable_reasoning {
self.reasoning_engine
.as_ref()
.map(|engine| {
engine.generate_reasoning_chain(
&input.question,
&answer_output.answer,
&image_features,
)
})
.transpose()?
} else {
None
};
let processing_time = start_time.elapsed().as_millis() as u64;
Ok(VisualQuestionAnsweringOutput {
answer: answer_output.answer,
confidence: answer_output.confidence,
alternative_answers: answer_output.alternatives,
attention_visualization,
reasoning_chain,
image_features: Some(image_features),
metadata: ProcessingMetadata {
processing_time_ms: processing_time,
model_name: "vqa-model".to_string(),
config: serde_json::to_string(&self.config).unwrap_or_default(),
tokens_processed: question_tokens.input_ids.len(),
memory_usage_mb: None,
},
})
}
fn extract_image_features(&self, image_tensor: &Tensor) -> Result<ImageFeatures> {
let image_output = self.base.model.forward(image_tensor.clone())?;
let image_data = image_output.data()?;
let global_features = self.extract_global_features(&image_data);
let patch_features = self.extract_patch_features(&image_data);
let detected_objects = self.simulate_object_detection();
let scene_description = Some("A scene containing various objects".to_string());
let image_classification = Some(vec![
ClassificationResult {
label: "indoor".to_string(),
confidence: 0.8,
},
ClassificationResult {
label: "outdoor".to_string(),
confidence: 0.2,
},
]);
Ok(ImageFeatures {
global_features,
patch_features,
detected_objects,
scene_description,
image_classification,
})
}
    fn extract_global_features(&self, image_data: &[f32]) -> Vec<f32> {
        let chunk_size = 64;
        let mut global_features = vec![0.0; chunk_size];
        if image_data.is_empty() {
            return global_features;
        }
        for (i, &value) in image_data.iter().enumerate() {
            global_features[i % chunk_size] += value;
        }
        // Average each bucket; clamp the divisor so inputs shorter than one chunk
        // do not divide by zero or amplify the values.
        let count = (image_data.len() as f32 / chunk_size as f32).max(1.0);
        for feature in &mut global_features {
            *feature /= count;
        }
        global_features
    }
    fn extract_patch_features(&self, image_data: &[f32]) -> Vec<Vec<f32>> {
        let patch_size = 64;
        let num_patches = self.config.image_config.num_patches.unwrap_or(196);
        // Guard against an empty feature buffer, which would otherwise cause a
        // modulo-by-zero panic below.
        if image_data.is_empty() {
            return vec![vec![0.0; patch_size]; num_patches];
        }
        let mut patch_features = Vec::with_capacity(num_patches);
        for i in 0..num_patches {
            let start_idx = (i * patch_size) % image_data.len();
            let end_idx = ((i + 1) * patch_size).min(image_data.len());
            let patch = if end_idx > start_idx {
                image_data[start_idx..end_idx].to_vec()
            } else {
                vec![0.0; patch_size]
            };
            patch_features.push(patch);
        }
        patch_features
    }
fn simulate_object_detection(&self) -> Vec<DetectedObject> {
vec![
DetectedObject {
class: "person".to_string(),
confidence: 0.9,
bbox: BoundingBox {
x: 0.2,
y: 0.3,
width: 0.3,
height: 0.6,
confidence: 0.9,
},
attributes: Some(
[
("age".to_string(), "adult".to_string()),
("gender".to_string(), "unknown".to_string()),
]
.iter()
.cloned()
.collect(),
),
},
DetectedObject {
class: "car".to_string(),
confidence: 0.8,
bbox: BoundingBox {
x: 0.6,
y: 0.4,
width: 0.3,
height: 0.3,
confidence: 0.8,
},
attributes: Some(
[
("color".to_string(), "red".to_string()),
("type".to_string(), "sedan".to_string()),
]
.iter()
.cloned()
.collect(),
),
},
]
}
}
impl<M, T> Pipeline for VisualQuestionAnsweringPipeline<M, T>
where
M: Model<Input = Tensor, Output = Tensor> + Clone + Send + Sync + 'static,
T: Tokenizer + Clone + Send + Sync + 'static,
{
type Input = VisualQuestionAnsweringInput;
type Output = VisualQuestionAnsweringOutput;
fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
self.answer_question(input)
}
}
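/// Converts the supported [`ImageInput`] variants into a normalized `[1, 3, H, W]` tensor.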
pub struct ImageProcessor {
config: ImageConfig,
}
impl ImageProcessor {
pub fn new(config: ImageConfig) -> Result<Self> {
Ok(Self { config })
}
pub fn process_image(&self, image: &ImageInput) -> Result<Tensor> {
match image {
ImageInput::Bytes(bytes) => self.process_image_bytes(bytes),
ImageInput::Tensor(tensor_data) => self.process_tensor_data(tensor_data),
ImageInput::Path(path) => self.process_image_path(path),
ImageInput::Url(url) => self.process_image_url(url),
ImageInput::Base64(base64) => self.process_base64_image(base64),
}
}
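    /// Placeholder decoder: the raw bytes are currently ignored and a synthetic,
    /// normalized tensor of the configured size is produced instead.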
fn process_image_bytes(&self, _bytes: &[u8]) -> Result<Tensor> {
let (width, height) = self.config.image_size;
let channels = 3;
let size = (width * height * channels) as usize;
let data: Vec<f32> = (0..size)
.map(|i| {
(i as f32 / size as f32 - self.config.normalize_mean[i % 3])
/ self.config.normalize_std[i % 3]
})
.collect();
Tensor::from_vec(
data,
&[1, channels as usize, height as usize, width as usize],
)
.map_err(Into::into)
}
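    /// Normalizes caller-provided pixel data, assuming interleaved RGB ordering
    /// (channel index `i % 3`), then reshapes it to `[1, 3, H, W]`.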
fn process_tensor_data(&self, tensor_data: &[f32]) -> Result<Tensor> {
let (width, height) = self.config.image_size;
let channels = 3;
let normalized_data: Vec<f32> = tensor_data
.iter()
.enumerate()
.map(|(i, &val)| {
(val - self.config.normalize_mean[i % 3]) / self.config.normalize_std[i % 3]
})
.collect();
Tensor::from_vec(
normalized_data,
&[1, channels, height as usize, width as usize],
)
.map_err(Into::into)
}
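    // Path, URL, and base64 inputs are not decoded yet; they fall back to the
    // synthetic tensor produced by `process_image_bytes`.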
fn process_image_path(&self, _path: &str) -> Result<Tensor> {
self.process_image_bytes(&[])
}
fn process_image_url(&self, _url: &str) -> Result<Tensor> {
self.process_image_bytes(&[])
}
fn process_base64_image(&self, _base64: &str) -> Result<Tensor> {
self.process_image_bytes(&[])
}
}
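/// Fuses image and question tensors according to the configured [`FusionStrategy`].
/// The individual strategies below are lightweight stand-ins rather than learned modules.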
pub struct FusionModule {
strategy: FusionStrategy,
}
impl FusionModule {
pub fn new(strategy: FusionStrategy) -> Result<Self> {
Ok(Self { strategy })
}
pub fn fuse(&self, image_tensor: &Tensor, question_tensor: &Tensor) -> Result<Tensor> {
match self.strategy {
FusionStrategy::CrossAttention => {
self.cross_attention_fusion(image_tensor, question_tensor)
},
FusionStrategy::Concatenation => {
self.concatenation_fusion(image_tensor, question_tensor)
},
FusionStrategy::Addition => self.addition_fusion(image_tensor, question_tensor),
FusionStrategy::BilinearPooling => {
self.bilinear_pooling_fusion(image_tensor, question_tensor)
},
FusionStrategy::TransformerFusion => {
self.transformer_fusion(image_tensor, question_tensor)
},
FusionStrategy::GraphFusion => self.graph_fusion(image_tensor, question_tensor),
}
}
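    /// Simplified stand-in for cross-attention: an element-wise product over the
    /// overlapping portion of the two feature vectors.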
fn cross_attention_fusion(
&self,
image_tensor: &Tensor,
question_tensor: &Tensor,
) -> Result<Tensor> {
let image_data = image_tensor.data()?;
let question_data = question_tensor.data()?;
let attention_dim = image_data.len().min(question_data.len());
let mut fused_data = Vec::with_capacity(attention_dim);
for i in 0..attention_dim {
let img_val = image_data[i % image_data.len()];
let q_val = question_data[i % question_data.len()];
fused_data.push(img_val * q_val);
}
Tensor::from_vec(fused_data, &[1, attention_dim]).map_err(Into::into)
}
fn concatenation_fusion(
&self,
image_tensor: &Tensor,
question_tensor: &Tensor,
) -> Result<Tensor> {
let mut fused_data = Vec::new();
fused_data.extend(image_tensor.data()?);
fused_data.extend(question_tensor.data()?);
let fused_len = fused_data.len();
Tensor::from_vec(fused_data, &[1, fused_len]).map_err(Into::into)
}
fn addition_fusion(&self, image_tensor: &Tensor, question_tensor: &Tensor) -> Result<Tensor> {
let image_data = image_tensor.data()?;
let question_data = question_tensor.data()?;
let min_len = image_data.len().min(question_data.len());
let fused_data: Vec<f32> = (0..min_len).map(|i| image_data[i] + question_data[i]).collect();
Tensor::from_vec(fused_data, &[1, min_len]).map_err(Into::into)
}
fn bilinear_pooling_fusion(
&self,
image_tensor: &Tensor,
question_tensor: &Tensor,
) -> Result<Tensor> {
let image_data = image_tensor.data()?;
let question_data = question_tensor.data()?;
        let output_dim = 256;
        let mut fused_data = vec![0.0; output_dim];
for i in 0..output_dim {
let img_idx = i % image_data.len();
let q_idx = i % question_data.len();
fused_data[i] = image_data[img_idx] * question_data[q_idx];
}
Tensor::from_vec(fused_data, &[1, output_dim]).map_err(Into::into)
}
fn transformer_fusion(
&self,
image_tensor: &Tensor,
question_tensor: &Tensor,
) -> Result<Tensor> {
self.cross_attention_fusion(image_tensor, question_tensor)
}
fn graph_fusion(&self, image_tensor: &Tensor, question_tensor: &Tensor) -> Result<Tensor> {
self.concatenation_fusion(image_tensor, question_tensor)
}
}
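/// Produces an answer (plus ranked alternatives) from fused features using the
/// configured [`AnswerGenerationStrategy`]; the current strategies are heuristic.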
pub struct AnswerGenerator {
strategy: AnswerGenerationStrategy,
}
#[derive(Debug, Clone)]
pub struct AnswerOutput {
pub answer: String,
pub confidence: f32,
pub alternatives: Vec<AnswerCandidate>,
}
impl AnswerGenerator {
pub fn new(strategy: AnswerGenerationStrategy) -> Result<Self> {
Ok(Self { strategy })
}
pub fn generate_answer(
&self,
features: &Tensor,
question: &str,
candidates: &Option<Vec<String>>,
config: &VisualQuestionAnsweringConfig,
) -> Result<AnswerOutput> {
match self.strategy {
AnswerGenerationStrategy::Generative => {
self.generative_answer(features, question, config)
},
AnswerGenerationStrategy::Extractive => {
self.extractive_answer(features, question, candidates, config)
},
AnswerGenerationStrategy::Classification => {
self.classification_answer(features, question, config)
},
AnswerGenerationStrategy::Hybrid => {
self.hybrid_answer(features, question, candidates, config)
},
}
}
fn generative_answer(
&self,
features: &Tensor,
question: &str,
config: &VisualQuestionAnsweringConfig,
) -> Result<AnswerOutput> {
        let question_lower = question.to_lowercase();
        let answer = if question_lower.contains("what") {
            "An object or scene element"
        } else if question_lower.contains("how many") {
            "2"
        } else if question_lower.contains("where") {
            "In the center of the image"
        } else if question_lower.contains("who") {
            "A person"
        } else if question_lower.contains("when") {
            "During the day"
        } else if question_lower.contains("why") {
            "Due to the context of the scene"
        } else if question_lower.contains("is") || question_lower.contains("are") {
            "Yes"
        } else {
            "I cannot determine the answer from the image"
        };
        let features_data = features.data()?;
        // Guard against an empty feature vector, which would yield a NaN confidence.
        let mean_activation = if features_data.is_empty() {
            0.0
        } else {
            features_data.iter().sum::<f32>() / features_data.len() as f32
        };
        let confidence = (0.7 + mean_activation.abs() * 0.3).clamp(0.0, 1.0);
let alternatives = vec![
AnswerCandidate {
answer: "Alternative answer 1".to_string(),
confidence: confidence * 0.8,
evidence: Some("Based on visual features".to_string()),
bbox: None,
},
AnswerCandidate {
answer: "Alternative answer 2".to_string(),
confidence: confidence * 0.6,
evidence: Some("Based on question context".to_string()),
bbox: None,
},
];
Ok(AnswerOutput {
answer: answer.to_string(),
confidence,
alternatives,
})
}
fn extractive_answer(
&self,
_features: &Tensor,
_question: &str,
candidates: &Option<Vec<String>>,
_config: &VisualQuestionAnsweringConfig,
) -> Result<AnswerOutput> {
let default_candidates = vec![
"yes".to_string(),
"no".to_string(),
"person".to_string(),
"car".to_string(),
"building".to_string(),
];
        let candidates = candidates.as_ref().unwrap_or(&default_candidates);
        let answer = candidates.first().cloned().unwrap_or_else(|| "unknown".to_string());
let confidence = 0.8;
let alternatives = candidates
.iter()
.enumerate()
.map(|(i, candidate)| AnswerCandidate {
answer: candidate.clone(),
confidence: 0.9 - (i as f32 * 0.1),
evidence: Some("Extracted from candidates".to_string()),
bbox: None,
})
.collect();
Ok(AnswerOutput {
answer,
confidence,
alternatives,
})
}
fn classification_answer(
&self,
features: &Tensor,
question: &str,
_config: &VisualQuestionAnsweringConfig,
) -> Result<AnswerOutput> {
        let question_lower = question.to_lowercase();
        let classes = if question_lower.contains("color") {
            vec!["red", "blue", "green", "yellow", "black", "white"]
        } else if question_lower.contains("animal") {
            vec!["cat", "dog", "bird", "horse", "cow", "sheep"]
        } else {
            vec!["yes", "no", "maybe"]
        };
let feature_sum = features.data()?.iter().sum::<f32>();
let class_idx = (feature_sum.abs() as usize) % classes.len();
let answer = classes[class_idx].to_string();
let confidence = 0.8;
let alternatives = classes
.iter()
.enumerate()
.map(|(i, &class)| AnswerCandidate {
answer: class.to_string(),
confidence: if i == class_idx { confidence } else { confidence * 0.5 },
evidence: Some("Classification result".to_string()),
bbox: None,
})
.collect();
Ok(AnswerOutput {
answer,
confidence,
alternatives,
})
}
fn hybrid_answer(
&self,
features: &Tensor,
question: &str,
candidates: &Option<Vec<String>>,
config: &VisualQuestionAnsweringConfig,
) -> Result<AnswerOutput> {
let generative_result = self.generative_answer(features, question, config)?;
let classification_result = self.classification_answer(features, question, config)?;
let answer = if generative_result.confidence > classification_result.confidence {
generative_result.answer
} else {
classification_result.answer
};
let confidence = (generative_result.confidence + classification_result.confidence) / 2.0;
let mut alternatives = generative_result.alternatives;
alternatives.extend(classification_result.alternatives);
alternatives.sort_by(|a, b| {
b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
});
alternatives.truncate(config.top_k_answers);
Ok(AnswerOutput {
answer,
confidence,
alternatives,
})
}
}
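/// Produces attention visualizations; the current implementation returns fixed
/// demonstration weights rather than real model attention.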
pub struct AttentionVisualizer;
impl Default for AttentionVisualizer {
fn default() -> Self {
Self::new()
}
}
impl AttentionVisualizer {
pub fn new() -> Self {
Self
}
pub fn visualize_attention(
&self,
_features: &Tensor,
_image_tensor: &Tensor,
_question_tensor: &Tensor,
) -> Result<AttentionVisualization> {
let cross_attention_weights = vec![
vec![0.1, 0.2, 0.3, 0.4],
vec![0.2, 0.3, 0.4, 0.1],
vec![0.3, 0.4, 0.1, 0.2],
];
let question_self_attention = vec![
vec![0.8, 0.1, 0.1],
vec![0.1, 0.8, 0.1],
vec![0.1, 0.1, 0.8],
];
let visual_attention_heatmap = (0..196).map(|i| (i as f32 / 196.0) * 0.5 + 0.5).collect();
let attention_heads = vec![
AttentionHead {
head_id: 0,
layer_id: 0,
pattern_type: "object-focused".to_string(),
avg_attention: 0.7,
},
AttentionHead {
head_id: 1,
layer_id: 0,
pattern_type: "spatial-reasoning".to_string(),
avg_attention: 0.6,
},
];
Ok(AttentionVisualization {
cross_attention_weights,
question_self_attention,
visual_attention_heatmap,
attention_heads,
})
}
}
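/// Builds a step-by-step reasoning chain from simple question-type heuristics
/// (counting, spatial, and attribute questions).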
pub struct ReasoningEngine;
impl Default for ReasoningEngine {
fn default() -> Self {
Self::new()
}
}
impl ReasoningEngine {
pub fn new() -> Self {
Self
}
pub fn generate_reasoning_chain(
&self,
question: &str,
answer: &str,
_image_features: &ImageFeatures,
) -> Result<Vec<ReasoningStep>> {
let mut reasoning_steps = Vec::new();
if question.to_lowercase().contains("how many") {
reasoning_steps.push(ReasoningStep {
description: "Detecting objects in the image".to_string(),
step_type: ReasoningStepType::ObjectDetection,
confidence: 0.9,
evidence: Some("Multiple objects detected".to_string()),
grounding: Some(BoundingBox {
x: 0.1,
y: 0.1,
width: 0.8,
height: 0.8,
confidence: 0.8,
}),
});
reasoning_steps.push(ReasoningStep {
description: "Counting detected objects".to_string(),
step_type: ReasoningStepType::Counting,
confidence: 0.8,
evidence: Some(format!("Counted objects to determine answer: {}", answer)),
grounding: None,
});
} else if question.to_lowercase().contains("where") {
reasoning_steps.push(ReasoningStep {
description: "Analyzing spatial relationships".to_string(),
step_type: ReasoningStepType::SpatialReasoning,
confidence: 0.8,
evidence: Some("Located object position in image".to_string()),
grounding: Some(BoundingBox {
x: 0.3,
y: 0.3,
width: 0.4,
height: 0.4,
confidence: 0.7,
}),
});
} else if question.to_lowercase().contains("what") {
reasoning_steps.push(ReasoningStep {
description: "Identifying objects and attributes".to_string(),
step_type: ReasoningStepType::AttributeRecognition,
confidence: 0.9,
evidence: Some("Recognized object attributes".to_string()),
grounding: None,
});
}
reasoning_steps.push(ReasoningStep {
description: format!("Concluding that the answer is: {}", answer),
step_type: ReasoningStepType::LogicalInference,
confidence: 0.7,
evidence: Some("Based on visual analysis and reasoning".to_string()),
grounding: None,
});
Ok(reasoning_steps)
}
}
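// The types below form a lightweight, self-contained VQA API that scores a fixed
// answer vocabulary without loading a model; it is convenient for wiring and tests.

/// Interleaved RGB pixel buffer used by the lightweight VQA API.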
#[derive(Clone, Debug)]
pub struct VqaImageInput {
pub pixels: Vec<u8>,
pub width: usize,
pub height: usize,
}
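/// An image/question pair for the lightweight VQA pipeline.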
#[derive(Clone, Debug)]
pub struct VqaInput {
pub image: VqaImageInput,
pub question: String,
}
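/// A single scored answer candidate produced by [`VisualQaPipeline::score_answers`].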
#[derive(Clone, Debug)]
pub struct VqaResult {
pub answer: String,
pub score: f32,
pub answer_id: usize,
}
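/// Configuration for the lightweight VQA pipeline; `model_id` is not consumed by
/// [`VisualQaPipeline`] itself.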
#[derive(Clone, Debug)]
pub struct VqaConfig {
pub model_id: String,
pub max_answer_length: usize,
pub top_k: usize,
pub image_size: usize,
}
impl Default for VqaConfig {
fn default() -> Self {
Self {
model_id: "dandelin/vilt-b32-finetuned-vqa".to_string(),
max_answer_length: 30,
top_k: 5,
image_size: 384,
}
}
}
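/// Errors returned by the lightweight VQA pipeline.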
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
#[error("Empty question")]
EmptyQuestion,
#[error("Empty image")]
EmptyImage,
#[error("Empty answer vocabulary")]
EmptyVocabulary,
}
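/// Deterministic feature extractor for the lightweight pipeline: hashes question
/// tokens, normalizes pixels, and concatenates the two modalities.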
pub struct VqaProcessor {
vocab: std::collections::HashMap<String, u32>,
}
impl VqaProcessor {
pub fn new() -> Self {
Self {
vocab: std::collections::HashMap::new(),
}
}
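    /// Splits the question on whitespace and maps each lowercased, punctuation-trimmed
    /// word to a vocabulary id, falling back to a djb2-style hash bucketed into
    /// 30_000 ids (offset by 1) for out-of-vocabulary words.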
pub fn encode_question(&self, question: &str) -> Vec<u32> {
question
.split_whitespace()
.map(|word| {
let lower = word.trim_matches(|c: char| !c.is_alphanumeric()).to_lowercase();
if let Some(&id) = self.vocab.get(&lower) {
id
} else {
let mut h: u64 = 5381;
for b in lower.bytes() {
h = h.wrapping_mul(33).wrapping_add(b as u64);
}
(h % 30_000) as u32 + 1
}
})
.collect()
}
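    /// Scales pixels to `[0, 1]` and applies per-channel ImageNet normalization,
    /// assuming interleaved RGB ordering.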
pub fn encode_image_features(image: &VqaImageInput) -> Vec<f32> {
if image.pixels.is_empty() {
return Vec::new();
}
let as_f32: Vec<f32> = image.pixels.iter().map(|&b| b as f32 / 255.0).collect();
let means = [0.485_f32, 0.456, 0.406];
let stds = [0.229_f32, 0.224, 0.225];
let num_channels = 3;
as_f32
.iter()
.enumerate()
.map(|(i, &v)| {
let ch = i % num_channels;
(v - means[ch]) / stds[ch]
})
.collect()
}
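    /// Concatenates text and image features and L2-normalizes the result
    /// (an all-zero vector is left unchanged).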
pub fn combine_modalities(text_features: &[f32], image_features: &[f32]) -> Vec<f32> {
let mut combined: Vec<f32> =
text_features.iter().chain(image_features.iter()).copied().collect();
let norm: f32 = combined.iter().map(|v| v * v).sum::<f32>().sqrt();
if norm > 1e-8 {
for v in &mut combined {
*v /= norm;
}
}
combined
}
}
impl Default for VqaProcessor {
fn default() -> Self {
Self::new()
}
}
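/// Lightweight, self-contained VQA pipeline that ranks a caller-supplied answer
/// vocabulary against deterministically derived question/image features. It does
/// not load or run a learned model.
///
/// Illustrative usage (a sketch with made-up pixel data):
///
/// ```ignore
/// let pipeline = VisualQaPipeline::new(VqaConfig::default())?;
/// let input = VqaInput {
///     image: VqaImageInput { pixels: vec![128u8; 16 * 16 * 3], width: 16, height: 16 },
///     question: "What is this?".to_string(),
/// };
/// let vocab = vec!["yes".to_string(), "no".to_string(), "cat".to_string()];
/// let results = pipeline.answer(input, &vocab)?;
/// assert!(results.len() <= pipeline.config.top_k);
/// ```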
pub struct VisualQaPipeline {
pub config: VqaConfig,
processor: VqaProcessor,
}
impl VisualQaPipeline {
pub fn new(config: VqaConfig) -> std::result::Result<Self, PipelineError> {
Ok(Self {
config,
processor: VqaProcessor::new(),
})
}
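    /// Applies a numerically stable softmax (max-subtracted) to the logits, pairs
    /// each probability with its answer string, and returns the candidates sorted
    /// by descending score.
    ///
    /// Illustrative usage (not compiled here):
    ///
    /// ```ignore
    /// let vocab = vec!["yes".to_string(), "no".to_string()];
    /// let ranked = VisualQaPipeline::score_answers(&[2.0, 0.5], &vocab);
    /// assert_eq!(ranked[0].answer, "yes");
    /// ```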
pub fn score_answers(logits: &[f32], answer_vocab: &[String]) -> Vec<VqaResult> {
let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits.iter().map(|&l| (l - max_l).exp()).collect();
let sum_exp: f32 = exps.iter().sum();
let probs: Vec<f32> = if sum_exp > 1e-8 {
exps.iter().map(|e| e / sum_exp).collect()
} else {
vec![1.0 / logits.len() as f32; logits.len()]
};
let mut results: Vec<VqaResult> = probs
.iter()
.enumerate()
.zip(answer_vocab.iter())
.map(|((id, &score), answer)| VqaResult {
answer: answer.clone(),
score,
answer_id: id,
})
.collect();
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
results
}
pub fn answer(
&self,
input: VqaInput,
answer_vocab: &[String],
) -> std::result::Result<Vec<VqaResult>, PipelineError> {
if input.question.trim().is_empty() {
return Err(PipelineError::EmptyQuestion);
}
if input.image.pixels.is_empty() {
return Err(PipelineError::EmptyImage);
}
if answer_vocab.is_empty() {
return Err(PipelineError::EmptyVocabulary);
}
let q_tokens = self.processor.encode_question(&input.question);
let img_feats = VqaProcessor::encode_image_features(&input.image);
let q_text_feats: Vec<f32> = q_tokens.iter().map(|&t| t as f32 / 30_000.0).collect();
let combined = VqaProcessor::combine_modalities(&q_text_feats, &img_feats);
let logits: Vec<f32> = answer_vocab
.iter()
.enumerate()
.map(|(i, _)| {
let seed = combined.get(i % combined.len().max(1)).copied().unwrap_or(0.0);
seed + (i as f32 * 0.01)
})
.collect();
let mut results = Self::score_answers(&logits, answer_vocab);
results.truncate(self.config.top_k);
Ok(results)
}
pub fn answer_batch(
&self,
inputs: Vec<VqaInput>,
answer_vocab: &[String],
) -> std::result::Result<Vec<Vec<VqaResult>>, PipelineError> {
inputs.into_iter().map(|inp| self.answer(inp, answer_vocab)).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::traits::{Model, TokenizedInput, Tokenizer};
use crate::AutoConfig;
use trustformers_core::Tensor;
#[derive(Clone)]
struct MockModel {
config: AutoConfig,
}
impl MockModel {
fn new() -> Self {
MockModel {
config: {
#[cfg(feature = "bert")]
{
AutoConfig::Bert(Default::default())
}
#[cfg(all(not(feature = "bert"), feature = "roberta"))]
{
AutoConfig::Roberta(Default::default())
}
#[cfg(all(not(feature = "bert"), not(feature = "roberta"), feature = "gpt2"))]
{
AutoConfig::Gpt2(Default::default())
}
#[cfg(all(
not(feature = "bert"),
not(feature = "roberta"),
not(feature = "gpt2"),
feature = "gpt_neo"
))]
{
AutoConfig::GptNeo(Default::default())
}
#[cfg(all(
not(feature = "bert"),
not(feature = "roberta"),
not(feature = "gpt2"),
not(feature = "gpt_neo"),
feature = "gpt_j"
))]
{
AutoConfig::GptJ(Default::default())
}
#[cfg(all(
not(feature = "bert"),
not(feature = "roberta"),
not(feature = "gpt2"),
not(feature = "gpt_neo"),
not(feature = "gpt_j"),
feature = "t5"
))]
{
AutoConfig::T5(Default::default())
}
#[cfg(all(
not(feature = "bert"),
not(feature = "roberta"),
not(feature = "gpt2"),
not(feature = "gpt_neo"),
not(feature = "gpt_j"),
not(feature = "t5"),
feature = "albert"
))]
{
AutoConfig::Albert(Default::default())
}
#[cfg(not(any(
feature = "bert",
feature = "roberta",
feature = "gpt2",
feature = "gpt_neo",
feature = "gpt_j",
feature = "t5",
feature = "albert"
)))]
{
compile_error!("At least one model feature must be enabled for tests (bert, roberta, gpt2, gpt_neo, gpt_j, t5, or albert)")
}
},
}
}
}
impl Model for MockModel {
type Input = Tensor;
type Output = Tensor;
type Config = AutoConfig;
fn forward(&self, _input: Self::Input) -> trustformers_core::errors::Result<Self::Output> {
Tensor::zeros(&[1, 10])
}
fn num_parameters(&self) -> usize {
            1000
        }
fn load_pretrained(
&mut self,
_reader: &mut dyn std::io::Read,
) -> trustformers_core::errors::Result<()> {
            Ok(())
        }
fn get_config(&self) -> &Self::Config {
&self.config
}
}
#[derive(Clone)]
struct MockTokenizer;
impl MockTokenizer {
fn new() -> Self {
MockTokenizer
}
}
impl Tokenizer for MockTokenizer {
fn encode(&self, _text: &str) -> trustformers_core::errors::Result<TokenizedInput> {
Ok(TokenizedInput {
                input_ids: vec![1, 2, 3],
                attention_mask: vec![1, 1, 1],
token_type_ids: Some(vec![0, 0, 0]),
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
})
}
fn encode_pair(
&self,
_text_a: &str,
_text_b: &str,
) -> trustformers_core::errors::Result<TokenizedInput> {
Ok(TokenizedInput {
                input_ids: vec![1, 2, 3, 4, 5],
                attention_mask: vec![1, 1, 1, 1, 1],
token_type_ids: Some(vec![0, 0, 0, 1, 1]),
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
})
}
fn decode(&self, _token_ids: &[u32]) -> trustformers_core::errors::Result<String> {
Ok("mock decoded text".to_string())
}
fn vocab_size(&self) -> usize {
1000
}
fn get_vocab(&self) -> std::collections::HashMap<String, u32> {
let mut vocab = std::collections::HashMap::new();
vocab.insert("test".to_string(), 1);
vocab.insert("mock".to_string(), 2);
vocab.insert("token".to_string(), 3);
vocab
}
fn token_to_id(&self, token: &str) -> Option<u32> {
match token {
"test" => Some(1),
"mock" => Some(2),
"token" => Some(3),
_ => None,
}
}
fn id_to_token(&self, id: u32) -> Option<String> {
match id {
1 => Some("test".to_string()),
2 => Some("mock".to_string()),
3 => Some("token".to_string()),
_ => None,
}
}
}
#[test]
fn test_vqa_pipeline_creation() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = VisualQuestionAnsweringPipeline::new(model, tokenizer);
assert!(pipeline.is_ok());
}
#[test]
fn test_vqa_config() {
let config = VisualQuestionAnsweringConfig::default();
assert_eq!(config.max_question_length, 512);
assert_eq!(config.max_answer_length, 256);
assert_eq!(config.top_k_answers, 5);
}
#[test]
fn test_image_processor() {
let config = ImageConfig::default();
let processor = ImageProcessor::new(config).expect("operation failed in test");
let image = ImageInput::Tensor(vec![0.5; 224 * 224 * 3]);
let result = processor.process_image(&image);
assert!(result.is_ok());
}
#[test]
fn test_fusion_strategies() {
let fusion =
FusionModule::new(FusionStrategy::Concatenation).expect("operation failed in test");
let img_tensor =
Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).expect("tensor operation failed");
let q_tensor = Tensor::from_vec(vec![4.0, 5.0], &[1, 2]).expect("tensor operation failed");
let result = fusion.fuse(&img_tensor, &q_tensor);
assert!(result.is_ok());
}
#[test]
fn test_answer_generator() {
let generator = AnswerGenerator::new(AnswerGenerationStrategy::Generative)
.expect("operation failed in test");
let features =
Tensor::from_vec(vec![0.1, 0.2, 0.3], &[1, 3]).expect("tensor operation failed");
let config = VisualQuestionAnsweringConfig::default();
let result = generator.generate_answer(&features, "What is in the image?", &None, &config);
assert!(result.is_ok());
}
#[test]
fn test_reasoning_engine() {
let engine = ReasoningEngine::new();
let image_features = ImageFeatures {
global_features: vec![0.1, 0.2, 0.3],
patch_features: vec![],
detected_objects: vec![],
scene_description: None,
image_classification: None,
};
let result =
engine.generate_reasoning_chain("How many people are there?", "2", &image_features);
assert!(result.is_ok());
assert!(!result.expect("operation failed in test").is_empty());
}
#[test]
fn test_attention_visualizer() {
let visualizer = AttentionVisualizer::new();
let features =
Tensor::from_vec(vec![0.1, 0.2, 0.3], &[1, 3]).expect("tensor operation failed");
let image_tensor =
Tensor::from_vec(vec![0.5; 100], &[1, 100]).expect("tensor operation failed");
let question_tensor =
Tensor::from_vec(vec![0.3; 50], &[1, 50]).expect("tensor operation failed");
let result = visualizer.visualize_attention(&features, &image_tensor, &question_tensor);
assert!(result.is_ok());
}
#[test]
fn test_pipeline_configuration() {
let model = MockModel::new();
let tokenizer = MockTokenizer::new();
let pipeline = VisualQuestionAnsweringPipeline::new(model, tokenizer)
.expect("operation failed in test")
.with_fusion_strategy(FusionStrategy::CrossAttention)
.expect("operation failed in test")
.with_answer_generation(AnswerGenerationStrategy::Classification)
.expect("operation failed in test")
.with_confidence_threshold(0.5)
.with_top_k_answers(3);
assert!(matches!(
pipeline.config.fusion_strategy,
FusionStrategy::CrossAttention
));
assert!(matches!(
pipeline.config.answer_generation,
AnswerGenerationStrategy::Classification
));
assert_eq!(pipeline.config.confidence_threshold, 0.5);
assert_eq!(pipeline.config.top_k_answers, 3);
}
fn dummy_image(width: usize, height: usize) -> VqaImageInput {
VqaImageInput {
pixels: (0..(width * height * 3)).map(|i| (i % 256) as u8).collect(),
width,
height,
}
}
fn default_vocab() -> Vec<String> {
vec![
"yes".to_string(),
"no".to_string(),
"dog".to_string(),
"cat".to_string(),
"2".to_string(),
"3".to_string(),
"red".to_string(),
"blue".to_string(),
]
}
fn default_vqa_pipeline() -> VisualQaPipeline {
VisualQaPipeline::new(VqaConfig::default()).expect("pipeline creation ok")
}
#[test]
fn test_vqa_input_construction() {
let img = dummy_image(4, 4);
let input = VqaInput {
image: img.clone(),
question: "What color is the object?".to_string(),
};
assert_eq!(input.question, "What color is the object?");
assert_eq!(input.image.width, 4);
assert_eq!(input.image.height, 4);
assert_eq!(input.image.pixels.len(), 4 * 4 * 3);
}
#[test]
fn test_vqa_config_defaults() {
let cfg = VqaConfig::default();
assert_eq!(cfg.top_k, 5);
assert_eq!(cfg.image_size, 384);
assert!(cfg.max_answer_length > 0);
assert!(!cfg.model_id.is_empty());
}
#[test]
fn test_encode_question_nonempty() {
let proc = VqaProcessor::new();
let tokens = proc.encode_question("What is in the image?");
assert!(!tokens.is_empty());
}
#[test]
fn test_encode_question_token_count() {
let proc = VqaProcessor::new();
let question = "how many cats are there";
let tokens = proc.encode_question(question);
assert_eq!(tokens.len(), question.split_whitespace().count());
}
#[test]
fn test_encode_question_empty() {
let proc = VqaProcessor::new();
let tokens = proc.encode_question("");
assert!(tokens.is_empty());
}
#[test]
fn test_encode_image_features_length() {
let img = dummy_image(8, 8);
let feats = VqaProcessor::encode_image_features(&img);
assert_eq!(feats.len(), 8 * 8 * 3);
}
#[test]
fn test_encode_image_features_empty() {
let empty = VqaImageInput {
pixels: vec![],
width: 0,
height: 0,
};
let feats = VqaProcessor::encode_image_features(&empty);
assert!(feats.is_empty());
}
#[test]
fn test_encode_image_features_normalised() {
let img = VqaImageInput {
pixels: vec![128u8; 6 * 3],
width: 6,
height: 1,
};
let feats = VqaProcessor::encode_image_features(&img);
for &v in &feats {
assert!(v.is_finite(), "feature value {v} is not finite");
}
}
#[test]
fn test_combine_modalities_length() {
let text = vec![0.1_f32, 0.2, 0.3];
let image = vec![0.4_f32, 0.5, 0.6, 0.7];
let combined = VqaProcessor::combine_modalities(&text, &image);
assert_eq!(combined.len(), text.len() + image.len());
}
#[test]
fn test_combine_modalities_normalised() {
let text = vec![1.0_f32, 2.0, 3.0];
let image = vec![4.0_f32, 5.0];
let combined = VqaProcessor::combine_modalities(&text, &image);
let norm: f32 = combined.iter().map(|v| v * v).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5, "norm was {norm}, expected ~1.0");
}
#[test]
fn test_combine_modalities_zeros() {
let combined = VqaProcessor::combine_modalities(&[0.0; 4], &[0.0; 3]);
assert_eq!(combined.len(), 7);
for &v in &combined {
assert_eq!(v, 0.0);
}
}
#[test]
fn test_score_answers_probability_sum() {
let logits = vec![1.0_f32, 2.0, 0.5, 3.0];
let vocab: Vec<String> = ["a", "b", "c", "d"].iter().map(|s| s.to_string()).collect();
let results = VisualQaPipeline::score_answers(&logits, &vocab);
let total: f32 = results.iter().map(|r| r.score).sum();
assert!(
(total - 1.0).abs() < 1e-5,
"scores sum to {total}, expected 1.0"
);
}
#[test]
fn test_score_answers_sorted() {
let logits = vec![0.1_f32, 5.0, 2.0, 0.5];
let vocab: Vec<String> = ["a", "b", "c", "d"].iter().map(|s| s.to_string()).collect();
let results = VisualQaPipeline::score_answers(&logits, &vocab);
for i in 1..results.len() {
assert!(
results[i - 1].score >= results[i].score,
"scores not sorted at index {i}: {} < {}",
results[i - 1].score,
results[i].score
);
}
}
#[test]
fn test_score_answers_answer_id() {
let logits = vec![1.0_f32, 2.0, 3.0];
let vocab: Vec<String> = ["cat", "dog", "bird"].iter().map(|s| s.to_string()).collect();
let results = VisualQaPipeline::score_answers(&logits, &vocab);
assert_eq!(results[0].answer, "bird");
assert_eq!(results[0].answer_id, 2);
}
#[test]
fn test_vqa_pipeline_answer_ok() {
let pipeline = default_vqa_pipeline();
let input = VqaInput {
image: dummy_image(16, 16),
question: "What is this?".to_string(),
};
let vocab = default_vocab();
let results = pipeline.answer(input, &vocab).expect("answer failed");
assert!(!results.is_empty());
assert!(results.len() <= pipeline.config.top_k);
}
#[test]
fn test_vqa_pipeline_empty_question_error() {
let pipeline = default_vqa_pipeline();
let input = VqaInput {
image: dummy_image(8, 8),
question: " ".to_string(),
};
let err = pipeline.answer(input, &default_vocab()).unwrap_err();
assert!(matches!(err, PipelineError::EmptyQuestion));
}
#[test]
fn test_vqa_pipeline_empty_image_error() {
let pipeline = default_vqa_pipeline();
let input = VqaInput {
image: VqaImageInput {
pixels: vec![],
width: 0,
height: 0,
},
question: "Is there anything?".to_string(),
};
let err = pipeline.answer(input, &default_vocab()).unwrap_err();
assert!(matches!(err, PipelineError::EmptyImage));
}
#[test]
fn test_vqa_pipeline_empty_vocab_error() {
let pipeline = default_vqa_pipeline();
let input = VqaInput {
image: dummy_image(4, 4),
question: "What is this?".to_string(),
};
let err = pipeline.answer(input, &[]).unwrap_err();
assert!(matches!(err, PipelineError::EmptyVocabulary));
}
#[test]
fn test_vqa_pipeline_answer_batch_count() {
let pipeline = default_vqa_pipeline();
let vocab = default_vocab();
let inputs: Vec<VqaInput> = (0..3)
.map(|i| VqaInput {
image: dummy_image(4 + i, 4 + i),
question: format!("Question {i}?"),
})
.collect();
let batch_results = pipeline.answer_batch(inputs, &vocab).expect("batch failed");
assert_eq!(batch_results.len(), 3);
for results in &batch_results {
assert!(!results.is_empty());
}
}
#[test]
fn test_score_answers_single_vocab() {
let logits = vec![0.5_f32];
let vocab = vec!["yes".to_string()];
let results = VisualQaPipeline::score_answers(&logits, &vocab);
assert_eq!(results.len(), 1);
assert!((results[0].score - 1.0).abs() < 1e-5);
}
#[test]
fn test_vqa_pipeline_top_k_respected() {
let pipeline = VisualQaPipeline::new(VqaConfig {
top_k: 2,
..Default::default()
})
.expect("ok");
        let vocab = default_vocab();
        let input = VqaInput {
image: dummy_image(4, 4),
question: "What animal?".to_string(),
};
let results = pipeline.answer(input, &vocab).expect("answer ok");
assert!(results.len() <= 2);
}
#[test]
fn test_encode_image_features_scale_with_size() {
let small = dummy_image(4, 4);
let large = dummy_image(8, 8);
let f_small = VqaProcessor::encode_image_features(&small);
let f_large = VqaProcessor::encode_image_features(&large);
assert_eq!(f_large.len(), 4 * f_small.len());
}
}