use super::{AnalyzedChart, ChartType, ExtractedTable, MultiModalDocument, ProcessedImage};
use crate::{RragError, RragResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub struct MultiModalRetriever {
config: RetrievalConfig,
text_retriever: TextRetriever,
visual_retriever: VisualRetriever,
table_retriever: TableRetriever,
chart_retriever: ChartRetriever,
cross_modal_retriever: CrossModalRetriever,
result_fusion: ResultFusion,
}
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
pub max_results_per_modality: usize,
pub max_total_results: usize,
pub similarity_thresholds: ModalitySimilarityThresholds,
pub enable_cross_modal: bool,
pub fusion_strategy: ResultFusionStrategy,
pub scoring_weights: ScoringWeights,
}
#[derive(Debug, Clone)]
pub struct ModalitySimilarityThresholds {
pub text_threshold: f32,
pub visual_threshold: f32,
pub table_threshold: f32,
pub chart_threshold: f32,
pub cross_modal_threshold: f32,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ResultFusionStrategy {
WeightedCombination,
RankFusion,
ScoreNormalization,
ReciprocalRankFusion,
}
#[derive(Debug, Clone)]
pub struct ScoringWeights {
pub semantic_weight: f32,
pub visual_weight: f32,
pub structural_weight: f32,
pub temporal_weight: f32,
pub contextual_weight: f32,
}
#[derive(Debug, Clone)]
pub struct MultiModalQuery {
pub text_query: Option<String>,
pub visual_query: Option<VisualQuery>,
pub table_query: Option<TableQuery>,
pub chart_query: Option<ChartQuery>,
pub cross_modal_constraints: Vec<CrossModalConstraint>,
pub metadata: QueryMetadata,
}
#[derive(Debug, Clone)]
pub enum VisualQuery {
ImageExample(String),
FeatureQuery(VisualFeatureQuery),
DescriptionQuery(String),
}
#[derive(Debug, Clone)]
pub struct TableQuery {
pub schema: Option<TableSchema>,
pub content_filters: Vec<ContentFilter>,
pub statistical_constraints: Vec<StatisticalConstraint>,
pub size_constraints: Option<SizeConstraints>,
}
#[derive(Debug, Clone)]
pub struct ChartQuery {
pub chart_types: Vec<ChartType>,
pub data_constraints: Vec<DataConstraint>,
pub trend_requirements: Vec<TrendRequirement>,
pub value_ranges: Vec<ValueRange>,
}
#[derive(Debug, Clone)]
pub struct CrossModalConstraint {
pub source_modality: Modality,
pub target_modality: Modality,
pub constraint_type: ConstraintType,
pub parameters: HashMap<String, String>,
}
#[derive(Debug, Clone, Copy)]
pub enum Modality {
Text,
Visual,
Table,
Chart,
}
#[derive(Debug, Clone)]
pub enum ConstraintType {
ContentAlignment,
SemanticConsistency,
VisualCoherence,
TemporalAlignment,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalRetrievalResult {
pub documents: Vec<RankedDocument>,
pub processing_time_ms: u64,
pub metadata: ResultMetadata,
pub statistics: RetrievalStatistics,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedDocument {
pub document: MultiModalDocument,
pub relevance_score: f32,
pub modality_scores: ModalityScores,
pub rank: usize,
pub explanation: Option<RelevanceExplanation>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityScores {
pub text_score: Option<f32>,
pub visual_score: Option<f32>,
pub table_score: Option<f32>,
pub chart_score: Option<f32>,
pub cross_modal_score: Option<f32>,
}
pub struct TextRetriever {
semantic_searcher: SemanticSearcher,
keyword_searcher: KeywordSearcher,
hybrid_combiner: HybridCombiner,
}
pub struct VisualRetriever {
clip_retriever: CLIPRetriever,
feature_retriever: FeatureBasedRetriever,
similarity_calculator: VisualSimilarityCalculator,
}
pub struct TableRetriever {
schema_matcher: SchemaMatcher,
content_searcher: TableContentSearcher,
statistical_analyzer: TableStatisticalAnalyzer,
}
pub struct ChartRetriever {
type_classifier: ChartTypeClassifier,
pattern_matcher: DataPatternMatcher,
trend_analyzer: ChartTrendAnalyzer,
}
pub struct CrossModalRetriever {
image_text_aligner: ImageTextAligner,
table_text_checker: TableTextConsistencyChecker,
coherence_scorer: CoherenceScorer,
}
pub struct ResultFusion {
strategy: ResultFusionStrategy,
score_normalizers: HashMap<Modality, ScoreNormalizer>,
rank_aggregator: RankAggregator,
}
impl MultiModalRetriever {
pub fn new(config: RetrievalConfig) -> RragResult<Self> {
let text_retriever = TextRetriever::new()?;
let visual_retriever = VisualRetriever::new()?;
let table_retriever = TableRetriever::new()?;
let chart_retriever = ChartRetriever::new()?;
let cross_modal_retriever = CrossModalRetriever::new()?;
let result_fusion = ResultFusion::new(config.fusion_strategy)?;
Ok(Self {
config,
text_retriever,
visual_retriever,
table_retriever,
chart_retriever,
cross_modal_retriever,
result_fusion,
})
}
pub async fn retrieve(
&self,
query: &MultiModalQuery,
documents: &[MultiModalDocument],
) -> RragResult<MultiModalRetrievalResult> {
let start_time = std::time::Instant::now();
let text_results = if let Some(ref text_q) = query.text_query {
self.text_retriever.retrieve(text_q, documents).await?
} else {
vec![]
};
let visual_results = if let Some(ref visual_q) = query.visual_query {
self.visual_retriever.retrieve(visual_q, documents).await?
} else {
vec![]
};
let table_results = if let Some(ref table_q) = query.table_query {
self.table_retriever.retrieve(table_q, documents).await?
} else {
vec![]
};
let chart_results = if let Some(ref chart_q) = query.chart_query {
self.chart_retriever.retrieve(chart_q, documents).await?
} else {
vec![]
};
let cross_modal_results = if self.config.enable_cross_modal {
self.cross_modal_retriever
.retrieve(query, documents)
.await?
} else {
vec![]
};
let fused_results = self.result_fusion.fuse_results(
&text_results,
&visual_results,
&table_results,
&chart_results,
&cross_modal_results,
&self.config.scoring_weights,
)?;
let processing_time = start_time.elapsed().as_millis() as u64;
Ok(MultiModalRetrievalResult {
documents: fused_results,
processing_time_ms: processing_time,
metadata: ResultMetadata {
total_documents_searched: documents.len(),
modalities_used: self.count_modalities_used(query),
fusion_strategy_used: self.config.fusion_strategy,
},
statistics: RetrievalStatistics {
text_results_count: text_results.len(),
visual_results_count: visual_results.len(),
table_results_count: table_results.len(),
chart_results_count: chart_results.len(),
cross_modal_results_count: cross_modal_results.len(),
},
})
}
fn count_modalities_used(&self, query: &MultiModalQuery) -> usize {
let mut count = 0;
if query.text_query.is_some() {
count += 1;
}
if query.visual_query.is_some() {
count += 1;
}
if query.table_query.is_some() {
count += 1;
}
if query.chart_query.is_some() {
count += 1;
}
count
}
pub async fn retrieve_by_embedding(
&self,
embedding: &[f32],
documents: &[MultiModalDocument],
) -> RragResult<Vec<RankedDocument>> {
let mut scored_documents = Vec::new();
for (idx, document) in documents.iter().enumerate() {
let similarity = self
.calculate_embedding_similarity(embedding, &document.embeddings.fused_embedding)?;
if similarity >= self.config.similarity_thresholds.text_threshold {
scored_documents.push(RankedDocument {
document: document.clone(),
relevance_score: similarity,
modality_scores: ModalityScores {
text_score: Some(similarity),
visual_score: None,
table_score: None,
chart_score: None,
cross_modal_score: None,
},
rank: idx,
explanation: None,
});
}
}
scored_documents.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap());
for (idx, doc) in scored_documents.iter_mut().enumerate() {
doc.rank = idx;
}
scored_documents.truncate(self.config.max_total_results);
Ok(scored_documents)
}
fn calculate_embedding_similarity(&self, a: &[f32], b: &[f32]) -> RragResult<f32> {
if a.len() != b.len() {
return Err(RragError::validation(
"embedding_dimensions",
"matching dimensions",
"mismatched dimensions",
));
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
Ok(0.0)
} else {
Ok(dot_product / (norm_a * norm_b))
}
}
}
impl TextRetriever {
pub fn new() -> RragResult<Self> {
Ok(Self {
semantic_searcher: SemanticSearcher::new(),
keyword_searcher: KeywordSearcher::new(),
hybrid_combiner: HybridCombiner::new(),
})
}
pub async fn retrieve(
&self,
query: &str,
documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
let semantic_results = self.semantic_searcher.search(query, documents)?;
let keyword_results = self.keyword_searcher.search(query, documents)?;
let combined_results = self
.hybrid_combiner
.combine(semantic_results, keyword_results)?;
Ok(combined_results)
}
}
impl VisualRetriever {
pub fn new() -> RragResult<Self> {
Ok(Self {
clip_retriever: CLIPRetriever::new(),
feature_retriever: FeatureBasedRetriever::new(),
similarity_calculator: VisualSimilarityCalculator::new(),
})
}
pub async fn retrieve(
&self,
query: &VisualQuery,
documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
match query {
VisualQuery::ImageExample(path) => {
self.clip_retriever
.retrieve_by_example(path, documents)
.await
}
VisualQuery::FeatureQuery(features) => {
self.feature_retriever
.retrieve_by_features(features, documents)
.await
}
VisualQuery::DescriptionQuery(description) => {
self.clip_retriever
.retrieve_by_description(description, documents)
.await
}
}
}
}
impl TableRetriever {
pub fn new() -> RragResult<Self> {
Ok(Self {
schema_matcher: SchemaMatcher::new(),
content_searcher: TableContentSearcher::new(),
statistical_analyzer: TableStatisticalAnalyzer::new(),
})
}
pub async fn retrieve(
&self,
query: &TableQuery,
documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
let mut results = Vec::new();
for (doc_idx, document) in documents.iter().enumerate() {
if !document.tables.is_empty() {
let mut table_score = 0.0;
let mut matching_tables = 0;
for table in &document.tables {
let mut score = 0.0;
if let Some(ref schema) = query.schema {
score += self.schema_matcher.match_schema(schema, table)? * 0.3;
}
for filter in &query.content_filters {
score += self.content_searcher.apply_filter(filter, table)? * 0.4;
}
for constraint in &query.statistical_constraints {
score += self
.statistical_analyzer
.check_constraint(constraint, table)?
* 0.3;
}
if score > 0.0 {
table_score += score;
matching_tables += 1;
}
}
if matching_tables > 0 {
let avg_score = table_score / matching_tables as f32;
results.push((doc_idx, avg_score));
}
}
}
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
Ok(results)
}
}
impl ChartRetriever {
pub fn new() -> RragResult<Self> {
Ok(Self {
type_classifier: ChartTypeClassifier::new(),
pattern_matcher: DataPatternMatcher::new(),
trend_analyzer: ChartTrendAnalyzer::new(),
})
}
pub async fn retrieve(
&self,
query: &ChartQuery,
documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
let mut results = Vec::new();
for (doc_idx, document) in documents.iter().enumerate() {
if !document.charts.is_empty() {
let mut chart_score = 0.0;
let mut matching_charts = 0;
for chart in &document.charts {
let mut score = 0.0;
if query.chart_types.contains(&chart.chart_type) {
score += 0.3;
}
for constraint in &query.data_constraints {
score += self.pattern_matcher.check_constraint(constraint, chart)? * 0.4;
}
if let Some(ref trends) = chart.trends {
for requirement in &query.trend_requirements {
score +=
self.trend_analyzer.check_requirement(requirement, trends)? * 0.3;
}
}
if score > 0.0 {
chart_score += score;
matching_charts += 1;
}
}
if matching_charts > 0 {
let avg_score = chart_score / matching_charts as f32;
results.push((doc_idx, avg_score));
}
}
}
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
Ok(results)
}
}
impl CrossModalRetriever {
pub fn new() -> RragResult<Self> {
Ok(Self {
image_text_aligner: ImageTextAligner::new(),
table_text_checker: TableTextConsistencyChecker::new(),
coherence_scorer: CoherenceScorer::new(),
})
}
pub async fn retrieve(
&self,
query: &MultiModalQuery,
documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
let mut results = Vec::new();
for (doc_idx, document) in documents.iter().enumerate() {
let mut cross_modal_score = 0.0;
let mut constraint_count = 0;
for constraint in &query.cross_modal_constraints {
let score = match constraint.constraint_type {
ConstraintType::ContentAlignment => self
.image_text_aligner
.calculate_alignment(&document.text_content, &document.images)?,
ConstraintType::SemanticConsistency => self
.table_text_checker
.check_consistency(&document.text_content, &document.tables)?,
ConstraintType::VisualCoherence => {
self.coherence_scorer.score_visual_coherence(document)?
}
ConstraintType::TemporalAlignment => {
0.7 }
};
cross_modal_score += score;
constraint_count += 1;
}
if constraint_count > 0 {
let avg_score = cross_modal_score / constraint_count as f32;
results.push((doc_idx, avg_score));
}
}
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
Ok(results)
}
}
impl ResultFusion {
pub fn new(strategy: ResultFusionStrategy) -> RragResult<Self> {
Ok(Self {
strategy,
score_normalizers: HashMap::new(),
rank_aggregator: RankAggregator::new(),
})
}
pub fn fuse_results(
&self,
text_results: &[(usize, f32)],
visual_results: &[(usize, f32)],
table_results: &[(usize, f32)],
chart_results: &[(usize, f32)],
cross_modal_results: &[(usize, f32)],
weights: &ScoringWeights,
) -> RragResult<Vec<RankedDocument>> {
match self.strategy {
ResultFusionStrategy::WeightedCombination => self.weighted_fusion(
text_results,
visual_results,
table_results,
chart_results,
cross_modal_results,
weights,
),
ResultFusionStrategy::RankFusion => self.rank_fusion(
text_results,
visual_results,
table_results,
chart_results,
cross_modal_results,
),
ResultFusionStrategy::ScoreNormalization => self.score_normalization_fusion(
text_results,
visual_results,
table_results,
chart_results,
cross_modal_results,
weights,
),
ResultFusionStrategy::ReciprocalRankFusion => self.reciprocal_rank_fusion(
text_results,
visual_results,
table_results,
chart_results,
cross_modal_results,
),
}
}
fn weighted_fusion(
&self,
text_results: &[(usize, f32)],
visual_results: &[(usize, f32)],
_table_results: &[(usize, f32)],
_chart_results: &[(usize, f32)],
_cross_modal_results: &[(usize, f32)],
weights: &ScoringWeights,
) -> RragResult<Vec<RankedDocument>> {
let mut document_scores: HashMap<usize, f32> = HashMap::new();
let mut modality_scores: HashMap<usize, ModalityScores> = HashMap::new();
for &(doc_idx, score) in text_results {
*document_scores.entry(doc_idx).or_insert(0.0) += score * weights.semantic_weight;
modality_scores
.entry(doc_idx)
.or_insert(ModalityScores {
text_score: None,
visual_score: None,
table_score: None,
chart_score: None,
cross_modal_score: None,
})
.text_score = Some(score);
}
for &(doc_idx, score) in visual_results {
*document_scores.entry(doc_idx).or_insert(0.0) += score * weights.visual_weight;
modality_scores
.entry(doc_idx)
.or_insert(ModalityScores {
text_score: None,
visual_score: None,
table_score: None,
chart_score: None,
cross_modal_score: None,
})
.visual_score = Some(score);
}
let mut ranked_docs: Vec<(usize, f32, ModalityScores)> = document_scores
.into_iter()
.map(|(doc_idx, score)| {
let scores = modality_scores.remove(&doc_idx).unwrap_or(ModalityScores {
text_score: None,
visual_score: None,
table_score: None,
chart_score: None,
cross_modal_score: None,
});
(doc_idx, score, scores)
})
.collect();
ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
Ok(vec![])
}
fn rank_fusion(
&self,
_text: &[(usize, f32)],
_visual: &[(usize, f32)],
_table: &[(usize, f32)],
_chart: &[(usize, f32)],
_cross: &[(usize, f32)],
) -> RragResult<Vec<RankedDocument>> {
Ok(vec![])
}
fn score_normalization_fusion(
&self,
_text: &[(usize, f32)],
_visual: &[(usize, f32)],
_table: &[(usize, f32)],
_chart: &[(usize, f32)],
_cross: &[(usize, f32)],
_weights: &ScoringWeights,
) -> RragResult<Vec<RankedDocument>> {
Ok(vec![])
}
fn reciprocal_rank_fusion(
&self,
_text: &[(usize, f32)],
_visual: &[(usize, f32)],
_table: &[(usize, f32)],
_chart: &[(usize, f32)],
_cross: &[(usize, f32)],
) -> RragResult<Vec<RankedDocument>> {
Ok(vec![])
}
}
impl SemanticSearcher {
pub fn new() -> Self {
Self
}
pub fn search(
&self,
_query: &str,
_documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
Ok(vec![(0, 0.8), (1, 0.6), (2, 0.4)])
}
}
impl KeywordSearcher {
pub fn new() -> Self {
Self
}
pub fn search(
&self,
_query: &str,
_documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
Ok(vec![(0, 0.7), (2, 0.5), (3, 0.3)])
}
}
impl HybridCombiner {
pub fn new() -> Self {
Self
}
pub fn combine(
&self,
semantic: Vec<(usize, f32)>,
keyword: Vec<(usize, f32)>,
) -> RragResult<Vec<(usize, f32)>> {
let mut combined = HashMap::new();
for (idx, score) in semantic {
combined.insert(idx, score * 0.7);
}
for (idx, score) in keyword {
*combined.entry(idx).or_insert(0.0) += score * 0.3;
}
let mut results: Vec<(usize, f32)> = combined.into_iter().collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
Ok(results)
}
}
impl CLIPRetriever {
pub fn new() -> Self {
Self
}
pub async fn retrieve_by_example(
&self,
_path: &str,
_documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
Ok(vec![(0, 0.9), (1, 0.7)])
}
pub async fn retrieve_by_description(
&self,
_description: &str,
_documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
Ok(vec![(0, 0.8), (2, 0.6)])
}
}
impl FeatureBasedRetriever {
pub fn new() -> Self {
Self
}
pub async fn retrieve_by_features(
&self,
_features: &VisualFeatureQuery,
_documents: &[MultiModalDocument],
) -> RragResult<Vec<(usize, f32)>> {
Ok(vec![(1, 0.85), (3, 0.5)])
}
}
impl VisualSimilarityCalculator {
pub fn new() -> Self {
Self
}
}
impl SchemaMatcher {
pub fn new() -> Self {
Self
}
pub fn match_schema(&self, _schema: &TableSchema, _table: &ExtractedTable) -> RragResult<f32> {
Ok(0.8)
}
}
impl TableContentSearcher {
pub fn new() -> Self {
Self
}
pub fn apply_filter(
&self,
_filter: &ContentFilter,
_table: &ExtractedTable,
) -> RragResult<f32> {
Ok(0.7)
}
}
impl TableStatisticalAnalyzer {
pub fn new() -> Self {
Self
}
pub fn check_constraint(
&self,
_constraint: &StatisticalConstraint,
_table: &ExtractedTable,
) -> RragResult<f32> {
Ok(0.6)
}
}
impl ChartTypeClassifier {
pub fn new() -> Self {
Self
}
}
impl DataPatternMatcher {
pub fn new() -> Self {
Self
}
pub fn check_constraint(
&self,
_constraint: &DataConstraint,
_chart: &AnalyzedChart,
) -> RragResult<f32> {
Ok(0.7)
}
}
impl ChartTrendAnalyzer {
pub fn new() -> Self {
Self
}
pub fn check_requirement(
&self,
_requirement: &TrendRequirement,
_trends: &super::TrendAnalysis,
) -> RragResult<f32> {
Ok(0.8)
}
}
impl ImageTextAligner {
pub fn new() -> Self {
Self
}
pub fn calculate_alignment(&self, _text: &str, _images: &[ProcessedImage]) -> RragResult<f32> {
Ok(0.75)
}
}
impl TableTextConsistencyChecker {
pub fn new() -> Self {
Self
}
pub fn check_consistency(&self, _text: &str, _tables: &[ExtractedTable]) -> RragResult<f32> {
Ok(0.8)
}
}
impl CoherenceScorer {
pub fn new() -> Self {
Self
}
pub fn score_visual_coherence(&self, _document: &MultiModalDocument) -> RragResult<f32> {
Ok(0.7)
}
}
impl RankAggregator {
pub fn new() -> Self {
Self
}
}
impl ScoreNormalizer {
pub fn new() -> Self {
Self
}
}
#[derive(Debug, Clone)]
pub struct VisualFeatureQuery {
pub colors: Option<Vec<String>>,
pub objects: Option<Vec<String>>,
pub scene_type: Option<String>,
}
#[derive(Debug, Clone)]
pub struct TableSchema {
pub columns: Vec<ColumnSchema>,
pub constraints: Vec<SchemaConstraint>,
}
#[derive(Debug, Clone)]
pub struct ColumnSchema {
pub name: String,
pub data_type: super::DataType,
pub required: bool,
}
#[derive(Debug, Clone)]
pub struct SchemaConstraint {
pub constraint_type: String,
pub parameters: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub struct ContentFilter {
pub column: String,
pub operator: FilterOperator,
pub value: String,
}
#[derive(Debug, Clone)]
pub enum FilterOperator {
Equals,
Contains,
GreaterThan,
LessThan,
Between,
}
#[derive(Debug, Clone)]
pub struct StatisticalConstraint {
pub metric: StatisticalMetric,
pub operator: FilterOperator,
pub value: f64,
}
#[derive(Debug, Clone)]
pub enum StatisticalMetric {
Mean,
Median,
StandardDeviation,
Count,
}
#[derive(Debug, Clone)]
pub struct SizeConstraints {
pub min_rows: Option<usize>,
pub max_rows: Option<usize>,
pub min_cols: Option<usize>,
pub max_cols: Option<usize>,
}
#[derive(Debug, Clone)]
pub struct DataConstraint {
pub constraint_type: String,
pub parameters: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub struct TrendRequirement {
pub trend_type: String,
pub strength: Option<f32>,
}
#[derive(Debug, Clone)]
pub struct ValueRange {
pub min: f64,
pub max: f64,
}
#[derive(Debug, Clone)]
pub struct QueryMetadata {
pub query_id: String,
pub timestamp: String,
pub user_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultMetadata {
pub total_documents_searched: usize,
pub modalities_used: usize,
pub fusion_strategy_used: ResultFusionStrategy,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalStatistics {
pub text_results_count: usize,
pub visual_results_count: usize,
pub table_results_count: usize,
pub chart_results_count: usize,
pub cross_modal_results_count: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceExplanation {
pub primary_matches: Vec<String>,
pub cross_modal_connections: Vec<String>,
pub confidence_factors: HashMap<String, f32>,
}
pub struct SemanticSearcher;
pub struct KeywordSearcher;
pub struct HybridCombiner;
pub struct CLIPRetriever;
pub struct FeatureBasedRetriever;
pub struct VisualSimilarityCalculator;
pub struct SchemaMatcher;
pub struct TableContentSearcher;
pub struct TableStatisticalAnalyzer;
pub struct ChartTypeClassifier;
pub struct DataPatternMatcher;
pub struct ChartTrendAnalyzer;
pub struct ImageTextAligner;
pub struct TableTextConsistencyChecker;
pub struct CoherenceScorer;
pub struct RankAggregator;
pub struct ScoreNormalizer;
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
max_results_per_modality: 50,
max_total_results: 100,
similarity_thresholds: ModalitySimilarityThresholds {
text_threshold: 0.5,
visual_threshold: 0.6,
table_threshold: 0.4,
chart_threshold: 0.5,
cross_modal_threshold: 0.7,
},
enable_cross_modal: true,
fusion_strategy: ResultFusionStrategy::WeightedCombination,
scoring_weights: ScoringWeights {
semantic_weight: 0.4,
visual_weight: 0.3,
structural_weight: 0.2,
temporal_weight: 0.05,
contextual_weight: 0.05,
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_retriever_creation() {
let config = RetrievalConfig::default();
let retriever = MultiModalRetriever::new(config).unwrap();
assert_eq!(retriever.config.max_total_results, 100);
assert!(retriever.config.enable_cross_modal);
}
#[test]
fn test_embedding_similarity() {
let config = RetrievalConfig::default();
let retriever = MultiModalRetriever::new(config).unwrap();
let emb1 = vec![1.0, 0.0, 0.0];
let emb2 = vec![1.0, 0.0, 0.0];
let emb3 = vec![0.0, 1.0, 0.0];
let sim1 = retriever
.calculate_embedding_similarity(&emb1, &emb2)
.unwrap();
let sim2 = retriever
.calculate_embedding_similarity(&emb1, &emb3)
.unwrap();
assert!((sim1 - 1.0).abs() < 1e-6);
assert!((sim2 - 0.0).abs() < 1e-6);
}
}