pub mod chart_processor;
pub mod document_parser;
pub mod embedding_fusion;
pub mod image_processor;
pub mod layout_analysis;
pub mod ocr;
pub mod retrieval;
pub mod table_processor;
use crate::RragResult;
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Facade that owns one boxed implementation of each processing trait declared in this
/// module, together with the configuration it was built from.
///
/// `MultiModalService::new` fills every slot with the `Default*` implementation exported by
/// the matching submodule.
pub struct MultiModalService {
/// Clone of the configuration handed to `new`.
config: MultiModalConfig,
/// Image backend (`image_processor::DefaultImageProcessor` when built through `new`).
image_processor: Box<dyn ImageProcessor>,
/// Table backend (`table_processor::DefaultTableProcessor` when built through `new`).
table_processor: Box<dyn TableProcessor>,
/// Chart backend (`chart_processor::DefaultChartProcessor` when built through `new`).
chart_processor: Box<dyn ChartProcessor>,
/// OCR backend (`ocr::DefaultOCREngine` when built through `new`).
ocr_engine: Box<dyn OCREngine>,
/// Layout backend (`layout_analysis::DefaultLayoutAnalyzer` when built through `new`).
layout_analyzer: Box<dyn LayoutAnalyzer>,
/// Fusion backend (`embedding_fusion::DefaultFusionStrategy` when built through `new`).
fusion_strategy: Box<dyn EmbeddingFusionStrategy>,
}
/// Top-level configuration consumed by `MultiModalService::new`.
///
/// The `Default` impl switches all three `process_*` flags on, uses each sub-config's own
/// default, and selects `FusionStrategy::Weighted`.
#[derive(Debug, Clone)]
pub struct MultiModalConfig {
/// Toggle for image processing. Not read anywhere in this file.
pub process_images: bool,
/// Toggle for table processing. Not read anywhere in this file.
pub process_tables: bool,
/// Toggle for chart processing. Not read anywhere in this file.
pub process_charts: bool,
/// Passed to `image_processor::DefaultImageProcessor::new`.
pub image_config: ImageProcessingConfig,
/// Passed to `table_processor::DefaultTableProcessor::new`.
pub table_config: TableExtractionConfig,
/// Passed to `chart_processor::DefaultChartProcessor::new`.
pub chart_config: ChartAnalysisConfig,
/// Passed to `ocr::DefaultOCREngine::new`.
pub ocr_config: OCRConfig,
/// Passed to `layout_analysis::DefaultLayoutAnalyzer::new`.
pub layout_config: LayoutAnalysisConfig,
/// Passed to `embedding_fusion::DefaultFusionStrategy::new`.
pub fusion_strategy: FusionStrategy,
}
/// Settings handed to the image processor.
///
/// Defaults: 1920 x 1080 size limits, JPEG/PNG/WEBP formats, every feature flag enabled,
/// compression quality 85.
#[derive(Debug, Clone)]
pub struct ImageProcessingConfig {
/// Width limit (default 1920). Presumably pixels; confirm in `image_processor`.
pub max_width: u32,
/// Height limit (default 1080). Presumably pixels; confirm in `image_processor`.
pub max_height: u32,
/// Image formats listed as supported (default JPEG, PNG, WEBP).
pub supported_formats: Vec<ImageFormat>,
/// CLIP flag (default `true`). Presumably gates `ImageProcessor::generate_clip_embedding`;
/// TODO confirm.
pub use_clip: bool,
/// Caption flag (default `true`).
pub generate_captions: bool,
/// Feature-extraction flag (default `true`).
pub extract_features: bool,
/// Compression quality (default 85). The valid range is not stated in this file.
pub compression_quality: u8,
}
/// Settings handed to the table processor.
///
/// Defaults: at least 2 rows and 2 columns, all flags enabled, `TableOutputFormat::JSON`.
#[derive(Debug, Clone)]
pub struct TableExtractionConfig {
/// Minimum row count (default 2). How the threshold is applied lives in `table_processor`.
pub min_rows: usize,
/// Minimum column count (default 2).
pub min_cols: usize,
/// Header-extraction flag (default `true`).
pub extract_headers: bool,
/// Type-inference flag (default `true`).
pub infer_types: bool,
/// Summary-generation flag (default `true`).
pub generate_summaries: bool,
/// Requested output format (default JSON).
pub output_format: TableOutputFormat,
}
/// Settings handed to the chart processor.
///
/// Defaults: Line, Bar, Pie and Scatter chart types with every flag enabled.
#[derive(Debug, Clone)]
pub struct ChartAnalysisConfig {
/// Chart types of interest (default Line, Bar, Pie, Scatter).
pub chart_types: Vec<ChartType>,
/// Data-extraction flag (default `true`).
pub extract_data: bool,
/// Description-generation flag (default `true`).
pub generate_descriptions: bool,
/// Trend-analysis flag (default `true`).
pub analyze_trends: bool,
}
/// Settings handed to the OCR engine.
///
/// Defaults: Tesseract, languages `["eng"]`, confidence threshold 0.7, spell correction and
/// formatting preservation enabled.
#[derive(Debug, Clone)]
pub struct OCRConfig {
/// Which OCR backend to use (default `OCREngineType::Tesseract`).
pub engine: OCREngineType,
/// Language codes (default `["eng"]`). The code scheme is not specified in this file.
pub languages: Vec<String>,
/// Confidence threshold (default 0.7). Scale and comparison direction are defined by the
/// engine, not here.
pub confidence_threshold: f32,
/// Spell-correction flag (default `true`).
pub spell_correction: bool,
/// Formatting-preservation flag (default `true`).
pub preserve_formatting: bool,
}
/// Settings handed to the layout analyzer. The `Default` impl enables all four flags.
#[derive(Debug, Clone)]
pub struct LayoutAnalysisConfig {
/// Structure-detection flag (default `true`).
pub detect_structure: bool,
/// Section-identification flag (default `true`).
pub identify_sections: bool,
/// Reading-order flag (default `true`).
pub extract_reading_order: bool,
/// Column-detection flag (default `true`).
pub detect_columns: bool,
}
/// Aggregate result type: the value returned by `MultiModalService::process_document` and
/// `MultiModalService::extract_modalities`, and the input to
/// `EmbeddingFusionStrategy::calculate_weights`. Serializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalDocument {
/// Document identifier.
pub id: String,
/// Text content of the document.
pub text_content: String,
/// Images found in the document.
pub images: Vec<ProcessedImage>,
/// Tables found in the document.
pub tables: Vec<ExtractedTable>,
/// Charts found in the document.
pub charts: Vec<AnalyzedChart>,
/// Layout description.
pub layout: DocumentLayout,
/// Per-modality and fused embeddings.
pub embeddings: MultiModalEmbeddings,
/// Document-level metadata.
pub metadata: DocumentMetadata,
}
/// Result record for one image: returned by `ImageProcessor::process_image` and stored in
/// `MultiModalDocument::images`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessedImage {
/// Image identifier.
pub id: String,
/// Where the image came from. Format (path, URL, ...) is not specified in this file.
pub source: String,
/// Caption, when one is available.
pub caption: Option<String>,
/// OCR text, when available.
pub ocr_text: Option<String>,
/// Visual features, when available.
pub features: Option<VisualFeatures>,
/// CLIP embedding, when available. The vector length is not fixed by this type.
pub clip_embedding: Option<Vec<f32>>,
/// Basic file/image properties.
pub metadata: ImageMetadata,
}
/// One extracted table: produced by `TableProcessor::extract_table` / `parse_structure` and
/// stored in `MultiModalDocument::tables`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedTable {
/// Table identifier.
pub id: String,
/// Header labels.
pub headers: Vec<String>,
/// Table body: each inner `Vec` is one row of cells.
pub rows: Vec<Vec<TableCell>>,
/// Summary text, when available.
pub summary: Option<String>,
/// Data types of the columns. Presumably index-aligned with `headers`; TODO confirm.
pub column_types: Vec<DataType>,
/// Embedding of the table, when available.
pub embedding: Option<Vec<f32>>,
/// Statistics for the table, when available.
pub statistics: Option<TableStatistics>,
}
/// One analyzed chart: returned by `ChartProcessor::analyze_chart` and stored in
/// `MultiModalDocument::charts`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalyzedChart {
/// Chart identifier.
pub id: String,
/// Kind of chart.
pub chart_type: ChartType,
/// Chart title, when available.
pub title: Option<String>,
/// Axis labels and ranges.
pub axes: ChartAxes,
/// Data points belonging to the chart.
pub data_points: Vec<DataPoint>,
/// Trend analysis, when available.
pub trends: Option<TrendAnalysis>,
/// Description text, when available.
pub description: Option<String>,
/// Embedding of the chart, when available.
pub embedding: Option<Vec<f32>>,
}
/// Layout description of a document: returned by `LayoutAnalyzer::analyze_layout`, consumed
/// by `LayoutAnalyzer::extract_reading_order`, and stored in `MultiModalDocument::layout`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentLayout {
/// Page count (presumably; the field is a bare `usize`).
pub pages: usize,
/// Sections of the document.
pub sections: Vec<DocumentSection>,
/// Reading order as strings. Presumably `DocumentSection::id` values; TODO confirm.
pub reading_order: Vec<String>,
/// Column layout, when available.
pub columns: Option<ColumnLayout>,
/// Kind of source document.
pub document_type: DocumentType,
}
/// Embedding bundle for a document: input to `EmbeddingFusionStrategy::fuse_embeddings` and
/// stored in `MultiModalDocument::embeddings`.
///
/// NOTE(review): `EmbeddingWeights` has a `chart_weight`, but this struct has no chart
/// embedding field. Confirm whether chart vectors are folded into another field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalEmbeddings {
/// Text embedding (always present).
pub text_embeddings: Vec<f32>,
/// Visual embedding, when available.
pub visual_embeddings: Option<Vec<f32>>,
/// Table embedding, when available.
pub table_embeddings: Option<Vec<f32>>,
/// Combined embedding (always present).
pub fused_embedding: Vec<f32>,
/// Per-modality weights.
pub weights: EmbeddingWeights,
}
/// Contract for the image backend held in `MultiModalService::image_processor`.
///
/// Implementors must be `Send + Sync`. Every method receives an image path and reports
/// failure through `RragResult`.
pub trait ImageProcessor: Send + Sync {
/// Produces the full [`ProcessedImage`] record for the image at `image_path`.
fn process_image(&self, image_path: &Path) -> RragResult<ProcessedImage>;
/// Produces [`VisualFeatures`] for the image at `image_path`.
fn extract_features(&self, image_path: &Path) -> RragResult<VisualFeatures>;
/// Produces a caption string for the image at `image_path`.
fn generate_caption(&self, image_path: &Path) -> RragResult<String>;
/// Produces a CLIP embedding vector for the image at `image_path`. The vector length is
/// not fixed by this signature.
fn generate_clip_embedding(&self, image_path: &Path) -> RragResult<Vec<f32>>;
}
/// Contract for the table backend held in `MultiModalService::table_processor`.
///
/// Implementors must be `Send + Sync`; every method reports failure through `RragResult`.
pub trait TableProcessor: Send + Sync {
/// Extracts zero or more tables from textual `content`. The expected input format is not
/// specified by this signature.
fn extract_table(&self, content: &str) -> RragResult<Vec<ExtractedTable>>;
/// Parses a single table from `table_html` (HTML, judging by the parameter name).
fn parse_structure(&self, table_html: &str) -> RragResult<ExtractedTable>;
/// Produces a textual summary of `table`.
fn generate_summary(&self, table: &ExtractedTable) -> RragResult<String>;
/// Produces [`TableStatistics`] for `table`.
fn calculate_statistics(&self, table: &ExtractedTable) -> RragResult<TableStatistics>;
}
/// Contract for the chart backend held in `MultiModalService::chart_processor`.
///
/// Implementors must be `Send + Sync`; every method reports failure through `RragResult`.
pub trait ChartProcessor: Send + Sync {
/// Produces the full [`AnalyzedChart`] record for the chart image at `image_path`.
fn analyze_chart(&self, image_path: &Path) -> RragResult<AnalyzedChart>;
/// Extracts the data points from the chart image at `chart_image`.
fn extract_data_points(&self, chart_image: &Path) -> RragResult<Vec<DataPoint>>;
/// Classifies the chart image at `chart_image` as a [`ChartType`].
fn identify_type(&self, chart_image: &Path) -> RragResult<ChartType>;
/// Produces a [`TrendAnalysis`] from already-extracted `data_points`.
fn analyze_trends(&self, data_points: &[DataPoint]) -> RragResult<TrendAnalysis>;
}
/// Contract for the OCR backend held in `MultiModalService::ocr_engine`.
///
/// Implementors must be `Send + Sync`; every method reports failure through `RragResult`.
pub trait OCREngine: Send + Sync {
/// Runs OCR on the image at `image_path` and returns the full [`OCRResult`].
fn ocr(&self, image_path: &Path) -> RragResult<OCRResult>;
/// Returns `(text, confidence)` pairs for the image at `image_path`. The granularity of
/// each pair (word, line, block) is up to the implementation.
fn get_text_with_confidence(&self, image_path: &Path) -> RragResult<Vec<(String, f32)>>;
/// Returns the [`TextLayout`] for the image at `image_path`.
fn get_layout(&self, image_path: &Path) -> RragResult<TextLayout>;
}
/// Contract for the layout backend held in `MultiModalService::layout_analyzer`.
///
/// Implementors must be `Send + Sync`; every method reports failure through `RragResult`.
pub trait LayoutAnalyzer: Send + Sync {
/// Produces the [`DocumentLayout`] for the document at `document_path`.
fn analyze_layout(&self, document_path: &Path) -> RragResult<DocumentLayout>;
/// Splits textual `content` into [`DocumentSection`]s.
fn detect_sections(&self, content: &str) -> RragResult<Vec<DocumentSection>>;
/// Derives a reading order (as strings) from an existing `layout`.
fn extract_reading_order(&self, layout: &DocumentLayout) -> RragResult<Vec<String>>;
}
/// Contract for the fusion backend held in `MultiModalService::fusion_strategy`.
///
/// Implementors must be `Send + Sync`; every method reports failure through `RragResult`.
pub trait EmbeddingFusionStrategy: Send + Sync {
/// Combines the vectors in `embeddings` into a single vector.
fn fuse_embeddings(&self, embeddings: &MultiModalEmbeddings) -> RragResult<Vec<f32>>;
/// Produces the per-modality [`EmbeddingWeights`] for `document`.
fn calculate_weights(&self, document: &MultiModalDocument) -> RragResult<EmbeddingWeights>;
}
/// Image file formats that can be listed in `ImageProcessingConfig::supported_formats`.
///
/// Not serde-serializable (no `Serialize`/`Deserialize` derive).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImageFormat {
JPEG,
PNG,
GIF,
BMP,
WEBP,
SVG,
TIFF,
}
/// Output formats selectable through `TableExtractionConfig::output_format`.
///
/// Derives `PartialEq`/`Eq` like the other field-less enums in this module, so configs can
/// be compared with `==` and `assert_eq!` instead of pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableOutputFormat {
    CSV,
    JSON,
    Markdown,
    HTML,
}
/// Chart categories: used by `AnalyzedChart::chart_type`, `ChartAnalysisConfig::chart_types`
/// and `ChartProcessor::identify_type`. Hashable and serde-serializable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ChartType {
Line,
Bar,
Pie,
Scatter,
Area,
Histogram,
HeatMap,
// Presumably a box plot; confirm against `chart_processor`.
Box,
Unknown,
}
/// OCR backends selectable through `OCRConfig::engine`.
///
/// A field-less enum has total equality, so `Eq` is derived alongside `PartialEq`, matching
/// `ImageFormat` and the other plain enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OCREngineType {
    Tesseract,
    EasyOCR,
    PaddleOCR,
    CloudVision,
}
/// Source-document categories: used by `DocumentLayout::document_type` and
/// `DocumentMetadata::format`. Serde-serializable.
///
/// `Eq` is derived alongside `PartialEq` because a field-less enum has total equality; this
/// matches `ChartType`, `DataType` and `SectionType`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum DocumentType {
    PDF,
    Word,
    PowerPoint,
    HTML,
    Markdown,
    PlainText,
    Mixed,
}
/// Value categories for table data: used by `TableCell::data_type` and
/// `ExtractedTable::column_types`. Hashable and serde-serializable.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum DataType {
// NOTE: this variant shares its name with `std::string::String`; always refer to it as
// `DataType::String`.
String,
Number,
Date,
Boolean,
Mixed,
}
/// Fusion modes selectable through `MultiModalConfig::fusion_strategy` and passed to
/// `embedding_fusion::DefaultFusionStrategy::new`. `MultiModalConfig::default()` picks
/// `Weighted`.
///
/// Derives `PartialEq`/`Eq` like the other field-less enums in this module, so the chosen
/// strategy can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionStrategy {
    Average,
    Weighted,
    Concatenate,
    Attention,
    Learned,
}
/// Visual description of an image: returned by `ImageProcessor::extract_features` and stored
/// in `ProcessedImage::features`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualFeatures {
/// Colors found in the image.
pub colors: Vec<Color>,
/// Objects detected in the image.
pub objects: Vec<DetectedObject>,
/// Scene label, when available.
pub scene: Option<String>,
/// Quality measurements.
pub quality: ImageQuality,
/// Spatial composition.
pub layout: SpatialLayout,
}
/// One cell of an `ExtractedTable` row.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableCell {
/// Cell content, always kept as a string regardless of `data_type`.
pub value: String,
/// Data type recorded for this cell.
pub data_type: DataType,
/// Formatting, when available.
pub formatting: Option<CellFormatting>,
}
/// Table-wide statistics: returned by `TableProcessor::calculate_statistics` and stored in
/// `ExtractedTable::statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableStatistics {
/// Number of rows.
pub row_count: usize,
/// Number of columns.
pub column_count: usize,
/// Null percentages. Presumably one entry per column; the scale (0-1 vs 0-100) is not
/// specified here. TODO confirm.
pub null_percentages: Vec<f32>,
/// Per-column statistics.
pub column_stats: Vec<ColumnStatistics>,
}
/// Statistics for a single table column; element type of `TableStatistics::column_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnStatistics {
/// Column name.
pub name: String,
/// Numeric statistics, when available.
pub numeric_stats: Option<NumericStatistics>,
/// Text statistics, when available.
pub text_stats: Option<TextStatistics>,
/// Count of unique values.
pub unique_count: usize,
}
/// Summary statistics for a numeric column; see `ColumnStatistics::numeric_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumericStatistics {
/// Minimum value.
pub min: f64,
/// Maximum value.
pub max: f64,
/// Mean value.
pub mean: f64,
/// Median value.
pub median: f64,
/// Standard deviation. Whether this is the sample or population form is not specified
/// here.
pub std_dev: f64,
}
/// Summary statistics for a text column; see `ColumnStatistics::text_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextStatistics {
/// Minimum length. Unit (bytes vs characters) is not specified here.
pub min_length: usize,
/// Maximum length.
pub max_length: usize,
/// Average length.
pub avg_length: f32,
/// Most common values, presumably as `(value, occurrence count)` pairs; TODO confirm.
pub most_common: Vec<(String, usize)>,
}
/// Axis information for an `AnalyzedChart`; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChartAxes {
/// X-axis label.
pub x_label: Option<String>,
/// Y-axis label.
pub y_label: Option<String>,
/// X-axis range. Presumably `(min, max)`; TODO confirm.
pub x_range: Option<(f64, f64)>,
/// Y-axis range. Presumably `(min, max)`; TODO confirm.
pub y_range: Option<(f64, f64)>,
}
/// A single chart data point: used by `AnalyzedChart::data_points`, by
/// `TrendAnalysis::outliers` / `forecast`, and by the `ChartProcessor` data methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPoint {
/// X coordinate.
pub x: f64,
/// Y coordinate.
pub y: f64,
/// Label, when available.
pub label: Option<String>,
/// Series name, when available.
pub series: Option<String>,
}
/// Trend information for a chart: returned by `ChartProcessor::analyze_trends` and stored in
/// `AnalyzedChart::trends`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrendAnalysis {
/// Overall direction.
pub direction: TrendDirection,
/// Trend strength. The scale is not specified in this file.
pub strength: f32,
/// Seasonality, when available.
pub seasonality: Option<Seasonality>,
/// Points recorded as outliers.
pub outliers: Vec<DataPoint>,
/// Forecast points, when available.
pub forecast: Option<Vec<DataPoint>>,
}
/// Direction labels used by `TrendAnalysis::direction`. Serde-serializable.
///
/// `Eq` is derived alongside `PartialEq` because a field-less enum has total equality; this
/// matches `ChartType`, `DataType` and `SectionType`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum TrendDirection {
    Increasing,
    Decreasing,
    Stable,
    Volatile,
}
/// Seasonality parameters for `TrendAnalysis::seasonality`. The units of all three values
/// are not specified in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Seasonality {
/// Period of the seasonal component.
pub period: f64,
/// Amplitude of the seasonal component.
pub amplitude: f64,
/// Phase of the seasonal component.
pub phase: f64,
}
/// Result of `OCREngine::ocr`. Not serde-serializable (only `Debug` and `Clone` are
/// derived).
#[derive(Debug, Clone)]
pub struct OCRResult {
/// Recognized text.
pub text: String,
/// Overall confidence. The scale is not specified in this file.
pub confidence: f32,
/// Recognized words.
pub words: Vec<OCRWord>,
/// Languages associated with the result.
pub languages: Vec<String>,
}
/// One recognized word; element type of `OCRResult::words`.
#[derive(Debug, Clone)]
pub struct OCRWord {
/// Word text.
pub text: String,
/// Confidence for this word. The scale is not specified in this file.
pub confidence: f32,
/// Location of the word.
pub bounding_box: BoundingBox,
}
/// Integer rectangle used by `OCRWord` and `TextBlock`. The origin corner and units are not
/// specified in this file (presumably pixels from the top-left; TODO confirm).
#[derive(Debug, Clone)]
pub struct BoundingBox {
/// X position.
pub x: u32,
/// Y position.
pub y: u32,
/// Width.
pub width: u32,
/// Height.
pub height: u32,
}
/// Text layout of an image, returned by `OCREngine::get_layout`.
#[derive(Debug, Clone)]
pub struct TextLayout {
/// Text blocks.
pub blocks: Vec<TextBlock>,
/// Reading order as integers. Presumably `TextBlock::id` values or indices into `blocks`;
/// TODO confirm which.
pub reading_order: Vec<usize>,
/// Columns, when available.
pub columns: Option<Vec<Column>>,
}
/// One block of text; element type of `TextLayout::blocks`.
#[derive(Debug, Clone)]
pub struct TextBlock {
/// Block identifier.
pub id: usize,
/// Block text.
pub text: String,
/// Location of the block.
pub bounding_box: BoundingBox,
/// Role of the block.
pub block_type: BlockType,
}
/// Roles a `TextBlock` can have (`TextBlock::block_type`).
///
/// Derives `PartialEq`/`Eq` like the other field-less enums in this module, so blocks can be
/// filtered with `==` rather than `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockType {
    Title,
    Heading,
    Paragraph,
    Caption,
    Footer,
    Header,
}
/// One text column; element type of `TextLayout::columns`.
#[derive(Debug, Clone)]
pub struct Column {
/// Column index.
pub index: usize,
/// Blocks in this column, as integers. Presumably `TextBlock::id` values or indices into
/// `TextLayout::blocks`; TODO confirm which.
pub blocks: Vec<usize>,
/// Column width, in the same (unspecified) units as `BoundingBox`.
pub width: u32,
}
/// One section of a document: returned by `LayoutAnalyzer::detect_sections` and stored in
/// `DocumentLayout::sections`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentSection {
/// Section identifier.
pub id: String,
/// Section title, when available.
pub title: Option<String>,
/// Section content.
pub content: String,
/// Kind of section.
pub section_type: SectionType,
/// Nesting level. The base (0 or 1) is not specified in this file.
pub level: usize,
/// Page range. Presumably `(first, last)`; inclusivity and page-number base are not
/// specified here. TODO confirm.
pub page_range: (usize, usize),
}
/// Section categories used by `DocumentSection::section_type`. Hashable and
/// serde-serializable.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum SectionType {
Title,
Abstract,
Introduction,
Body,
Conclusion,
References,
Appendix,
}
/// Page column layout for `DocumentLayout::columns`. The units of the widths are not
/// specified in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnLayout {
/// Number of columns.
pub column_count: usize,
/// Column widths. Presumably `column_count` entries; TODO confirm.
pub column_widths: Vec<f32>,
/// Gutter width.
pub gutter_width: f32,
}
/// Document-level metadata stored in `MultiModalDocument::metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentMetadata {
/// Title, when available.
pub title: Option<String>,
/// Author, when available.
pub author: Option<String>,
/// Creation date as a string, when available. The date format is not specified here.
pub creation_date: Option<String>,
/// Modification date as a string, when available. The date format is not specified here.
pub modification_date: Option<String>,
/// Number of pages.
pub page_count: usize,
/// Number of words.
pub word_count: usize,
/// Language as a string. The code scheme is not specified here.
pub language: String,
/// Kind of source document.
pub format: DocumentType,
}
/// Basic image properties stored in `ProcessedImage::metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageMetadata {
/// Image width.
pub width: u32,
/// Image height.
pub height: u32,
/// Format as a free-form string (note: not the `ImageFormat` enum).
pub format: String,
/// Size in bytes.
pub size_bytes: usize,
/// DPI, when available.
pub dpi: Option<u32>,
/// Color space name, when available.
pub color_space: Option<String>,
}
/// One color entry; element type of `VisualFeatures::colors`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Color {
/// Color as an `(r, g, b)` byte triple.
pub rgb: (u8, u8, u8),
/// Share of the image attributed to this color. The scale (0-1 vs 0-100) is not
/// specified here; TODO confirm.
pub percentage: f32,
/// Color name, when available.
pub name: Option<String>,
}
/// One detected object; element type of `VisualFeatures::objects`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedObject {
/// Object class label.
pub class: String,
/// Detection confidence. The scale is not specified in this file.
pub confidence: f32,
/// Bounding box as four floats (note: not the integer `BoundingBox` struct). The tuple
/// layout (x, y, w, h vs. corner pairs) and units are not specified here; TODO confirm.
pub bounding_box: (f32, f32, f32, f32),
}
/// Quality measurements stored in `VisualFeatures::quality`. The scale of each metric is not
/// specified in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageQuality {
/// Sharpness metric.
pub sharpness: f32,
/// Contrast metric.
pub contrast: f32,
/// Brightness metric.
pub brightness: f32,
/// Noise-level metric.
pub noise_level: f32,
}
/// Spatial composition stored in `VisualFeatures::layout`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialLayout {
/// Composition category.
pub composition_type: CompositionType,
/// Focal points as float pairs. Presumably `(x, y)`; the coordinate system is not
/// specified here. TODO confirm.
pub focal_points: Vec<(f32, f32)>,
/// Balance metric. The scale is not specified in this file.
pub balance: f32,
}
/// Composition categories used by `SpatialLayout::composition_type`. Serde-serializable.
///
/// Derives `PartialEq`/`Eq` like the other field-less enums in this module, so values can be
/// compared with `==`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum CompositionType {
    RuleOfThirds,
    Centered,
    Diagonal,
    Symmetrical,
    Asymmetrical,
}
/// Formatting of a table cell; see `TableCell::formatting`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CellFormatting {
/// Bold flag.
pub bold: bool,
/// Italic flag.
pub italic: bool,
/// Text color as a string, when available. The notation is not specified here.
pub color: Option<String>,
/// Background color as a string, when available. The notation is not specified here.
pub background: Option<String>,
}
/// Per-modality weights: returned by `EmbeddingFusionStrategy::calculate_weights` and stored
/// in `MultiModalEmbeddings::weights`.
///
/// The type places no constraint on the values (for example, nothing forces them to sum to
/// 1).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingWeights {
/// Weight for text.
pub text_weight: f32,
/// Weight for visual content.
pub visual_weight: f32,
/// Weight for tables.
pub table_weight: f32,
/// Weight for charts.
pub chart_weight: f32,
}
impl MultiModalService {
pub fn new(config: MultiModalConfig) -> RragResult<Self> {
Ok(Self {
config: config.clone(),
image_processor: Box::new(image_processor::DefaultImageProcessor::new(
config.image_config,
)?),
table_processor: Box::new(table_processor::DefaultTableProcessor::new(
config.table_config,
)?),
chart_processor: Box::new(chart_processor::DefaultChartProcessor::new(
config.chart_config,
)?),
ocr_engine: Box::new(ocr::DefaultOCREngine::new(config.ocr_config)?),
layout_analyzer: Box::new(layout_analysis::DefaultLayoutAnalyzer::new(
config.layout_config,
)?),
fusion_strategy: Box::new(embedding_fusion::DefaultFusionStrategy::new(
config.fusion_strategy,
)?),
})
}
pub async fn process_document(&self, _document_path: &Path) -> RragResult<MultiModalDocument> {
todo!("Implement multi-modal document processing")
}
pub async fn extract_modalities(&self, _content: &[u8]) -> RragResult<MultiModalDocument> {
todo!("Implement modality extraction")
}
}
impl Default for MultiModalConfig {
fn default() -> Self {
Self {
process_images: true,
process_tables: true,
process_charts: true,
image_config: ImageProcessingConfig::default(),
table_config: TableExtractionConfig::default(),
chart_config: ChartAnalysisConfig::default(),
ocr_config: OCRConfig::default(),
layout_config: LayoutAnalysisConfig::default(),
fusion_strategy: FusionStrategy::Weighted,
}
}
}
impl Default for ImageProcessingConfig {
fn default() -> Self {
Self {
max_width: 1920,
max_height: 1080,
supported_formats: vec![ImageFormat::JPEG, ImageFormat::PNG, ImageFormat::WEBP],
use_clip: true,
generate_captions: true,
extract_features: true,
compression_quality: 85,
}
}
}
impl Default for TableExtractionConfig {
fn default() -> Self {
Self {
min_rows: 2,
min_cols: 2,
extract_headers: true,
infer_types: true,
generate_summaries: true,
output_format: TableOutputFormat::JSON,
}
}
}
impl Default for ChartAnalysisConfig {
fn default() -> Self {
Self {
chart_types: vec![
ChartType::Line,
ChartType::Bar,
ChartType::Pie,
ChartType::Scatter,
],
extract_data: true,
generate_descriptions: true,
analyze_trends: true,
}
}
}
impl Default for OCRConfig {
fn default() -> Self {
Self {
engine: OCREngineType::Tesseract,
languages: vec!["eng".to_string()],
confidence_threshold: 0.7,
spell_correction: true,
preserve_formatting: true,
}
}
}
impl Default for LayoutAnalysisConfig {
fn default() -> Self {
Self {
detect_structure: true,
identify_sections: true,
extract_reading_order: true,
detect_columns: true,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default top-level config enables all three modalities.
    #[test]
    fn test_multimodal_config() {
        let defaults = MultiModalConfig::default();
        let flags = [
            defaults.process_images,
            defaults.process_tables,
            defaults.process_charts,
        ];
        assert!(flags.iter().all(|&enabled| enabled));
    }

    /// The default image config is capped at 1920 x 1080 and has CLIP switched on.
    #[test]
    fn test_image_config() {
        let ImageProcessingConfig {
            max_width,
            max_height,
            use_clip,
            ..
        } = ImageProcessingConfig::default();
        assert_eq!((max_width, max_height), (1920, 1080));
        assert!(use_clip);
    }
}