1pub mod chart_processor;
162pub mod document_parser;
163pub mod embedding_fusion;
164pub mod image_processor;
165pub mod layout_analysis;
166pub mod ocr;
167pub mod retrieval;
168pub mod table_processor;
169
170use crate::RragResult;
171use serde::{Deserialize, Serialize};
172use std::path::Path;
173
/// Service that owns one boxed implementation of each processing trait declared in
/// this module, together with the configuration it was built from.
///
/// [`MultiModalService::new`] fills every slot with the `Default*` implementation
/// from the matching submodule.
pub struct MultiModalService {
    /// Clone of the configuration handed to [`MultiModalService::new`].
    config: MultiModalConfig,

    /// Image backend; `new` installs `image_processor::DefaultImageProcessor`.
    image_processor: Box<dyn ImageProcessor>,

    /// Table backend; `new` installs `table_processor::DefaultTableProcessor`.
    table_processor: Box<dyn TableProcessor>,

    /// Chart backend; `new` installs `chart_processor::DefaultChartProcessor`.
    chart_processor: Box<dyn ChartProcessor>,

    /// OCR backend; `new` installs `ocr::DefaultOCREngine`.
    ocr_engine: Box<dyn OCREngine>,

    /// Layout backend; `new` installs `layout_analysis::DefaultLayoutAnalyzer`.
    layout_analyzer: Box<dyn LayoutAnalyzer>,

    /// Embedding fusion backend; `new` installs `embedding_fusion::DefaultFusionStrategy`.
    fusion_strategy: Box<dyn EmbeddingFusionStrategy>,
}
197
/// Top-level configuration consumed by [`MultiModalService::new`].
///
/// The `Default` implementation sets all three `process_*` flags to `true`, uses the
/// default value of every nested config, and selects [`FusionStrategy::Weighted`].
///
/// NOTE(review): no non-test code visible in this file reads the `process_*` flags —
/// confirm where they are enforced.
#[derive(Debug, Clone)]
pub struct MultiModalConfig {
    /// Switch for image processing; `true` by default.
    pub process_images: bool,

    /// Switch for table processing; `true` by default.
    pub process_tables: bool,

    /// Switch for chart processing; `true` by default.
    pub process_charts: bool,

    /// Settings handed to `image_processor::DefaultImageProcessor::new`.
    pub image_config: ImageProcessingConfig,

    /// Settings handed to `table_processor::DefaultTableProcessor::new`.
    pub table_config: TableExtractionConfig,

    /// Settings handed to `chart_processor::DefaultChartProcessor::new`.
    pub chart_config: ChartAnalysisConfig,

    /// Settings handed to `ocr::DefaultOCREngine::new`.
    pub ocr_config: OCRConfig,

    /// Settings handed to `layout_analysis::DefaultLayoutAnalyzer::new`.
    pub layout_config: LayoutAnalysisConfig,

    /// Selector handed to `embedding_fusion::DefaultFusionStrategy::new`.
    pub fusion_strategy: FusionStrategy,
}
228
/// Image settings; stored in [`MultiModalConfig::image_config`] and passed to the
/// default image processor's constructor.
#[derive(Debug, Clone)]
pub struct ImageProcessingConfig {
    /// Maximum width; defaults to 1920 (presumably pixels — TODO confirm).
    pub max_width: u32,
    /// Maximum height; defaults to 1080 (presumably pixels — TODO confirm).
    pub max_height: u32,

    /// Image formats listed as supported; defaults to JPEG, PNG and WEBP.
    pub supported_formats: Vec<ImageFormat>,

    /// Switch for CLIP usage; `true` by default.
    pub use_clip: bool,

    /// Switch for caption generation; `true` by default.
    pub generate_captions: bool,

    /// Switch for feature extraction; `true` by default.
    pub extract_features: bool,

    /// Compression quality; defaults to 85 (scale presumably 0-100 — TODO confirm).
    pub compression_quality: u8,
}
251
/// Table settings; stored in [`MultiModalConfig::table_config`] and passed to the
/// default table processor's constructor.
#[derive(Debug, Clone)]
pub struct TableExtractionConfig {
    /// Minimum row count; defaults to 2. How the threshold is applied is decided by
    /// the table processor, outside this file.
    pub min_rows: usize,

    /// Minimum column count; defaults to 2.
    pub min_cols: usize,

    /// Switch for header extraction; `true` by default.
    pub extract_headers: bool,

    /// Switch for column type inference; `true` by default.
    pub infer_types: bool,

    /// Switch for summary generation; `true` by default.
    pub generate_summaries: bool,

    /// Requested output format; defaults to [`TableOutputFormat::JSON`].
    pub output_format: TableOutputFormat,
}
273
/// Chart settings; stored in [`MultiModalConfig::chart_config`] and passed to the
/// default chart processor's constructor.
#[derive(Debug, Clone)]
pub struct ChartAnalysisConfig {
    /// Chart types of interest; defaults to `Line`, `Bar`, `Pie` and `Scatter`.
    pub chart_types: Vec<ChartType>,

    /// Switch for data extraction; `true` by default.
    pub extract_data: bool,

    /// Switch for description generation; `true` by default.
    pub generate_descriptions: bool,

    /// Switch for trend analysis; `true` by default.
    pub analyze_trends: bool,
}
289
/// OCR settings; stored in [`MultiModalConfig::ocr_config`] and passed to
/// `ocr::DefaultOCREngine::new`.
#[derive(Debug, Clone)]
pub struct OCRConfig {
    /// Engine selector; defaults to [`OCREngineType::Tesseract`].
    pub engine: OCREngineType,

    /// Language identifiers; defaults to `["eng"]`. The code format expected by each
    /// engine is not visible here.
    pub languages: Vec<String>,

    /// Confidence threshold; defaults to 0.7 (presumably on a 0.0-1.0 scale — confirm).
    pub confidence_threshold: f32,

    /// Switch for spell correction; `true` by default.
    pub spell_correction: bool,

    /// Switch for formatting preservation; `true` by default.
    pub preserve_formatting: bool,
}
308
/// Layout settings; stored in [`MultiModalConfig::layout_config`] and passed to
/// `layout_analysis::DefaultLayoutAnalyzer::new`. Every switch defaults to `true`.
#[derive(Debug, Clone)]
pub struct LayoutAnalysisConfig {
    /// Switch for structure detection.
    pub detect_structure: bool,

    /// Switch for section identification.
    pub identify_sections: bool,

    /// Switch for reading-order extraction.
    pub extract_reading_order: bool,

    /// Switch for column detection.
    pub detect_columns: bool,
}
324
/// Serializable aggregate of everything extracted from one document.
///
/// This is the success type of [`MultiModalService::process_document`] and
/// [`MultiModalService::extract_modalities`], and the input to
/// [`EmbeddingFusionStrategy::calculate_weights`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalDocument {
    /// Document identifier; its format is not fixed by this file.
    pub id: String,

    /// Text content of the document.
    pub text_content: String,

    /// Processed images belonging to the document.
    pub images: Vec<ProcessedImage>,

    /// Tables extracted from the document.
    pub tables: Vec<ExtractedTable>,

    /// Charts analyzed in the document.
    pub charts: Vec<AnalyzedChart>,

    /// Layout description of the document.
    pub layout: DocumentLayout,

    /// Per-modality and fused embeddings.
    pub embeddings: MultiModalEmbeddings,

    /// Descriptive metadata (title, author, counts, format, ...).
    pub metadata: DocumentMetadata,
}
352
/// Result record for one image; returned by [`ImageProcessor::process_image`] and stored
/// in [`MultiModalDocument::images`]. Optional parts are `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessedImage {
    /// Image identifier.
    pub id: String,

    /// Source of the image as a string; whether it is a path, URL or other reference is
    /// not visible here.
    pub source: String,

    /// Caption text, if any.
    pub caption: Option<String>,

    /// Text recognized by OCR, if any.
    pub ocr_text: Option<String>,

    /// Visual features, if any.
    pub features: Option<VisualFeatures>,

    /// CLIP embedding vector, if any; its dimensionality is not fixed by this type.
    pub clip_embedding: Option<Vec<f32>>,

    /// Basic file/image metadata.
    pub metadata: ImageMetadata,
}
377
/// One extracted table; produced by [`TableProcessor`] methods and stored in
/// [`MultiModalDocument::tables`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedTable {
    /// Table identifier.
    pub id: String,

    /// Header labels.
    pub headers: Vec<String>,

    /// Table body; presumably one inner vector per row — confirm with producers.
    pub rows: Vec<Vec<TableCell>>,

    /// Summary text, if any.
    pub summary: Option<String>,

    /// Column data types; presumably one entry per column — confirm with producers.
    pub column_types: Vec<DataType>,

    /// Embedding vector for the table, if any.
    pub embedding: Option<Vec<f32>>,

    /// Table statistics, if any.
    pub statistics: Option<TableStatistics>,
}
402
/// One analyzed chart; returned by [`ChartProcessor::analyze_chart`] and stored in
/// [`MultiModalDocument::charts`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalyzedChart {
    /// Chart identifier.
    pub id: String,

    /// Kind of chart.
    pub chart_type: ChartType,

    /// Chart title, if any.
    pub title: Option<String>,

    /// Axis labels and ranges.
    pub axes: ChartAxes,

    /// Data points read from the chart.
    pub data_points: Vec<DataPoint>,

    /// Trend analysis, if any.
    pub trends: Option<TrendAnalysis>,

    /// Textual description, if any.
    pub description: Option<String>,

    /// Embedding vector for the chart, if any.
    pub embedding: Option<Vec<f32>>,
}
430
/// Layout description of a document; returned by [`LayoutAnalyzer::analyze_layout`] and
/// stored in [`MultiModalDocument::layout`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentLayout {
    /// Page count.
    pub pages: usize,

    /// Sections found in the document.
    pub sections: Vec<DocumentSection>,

    /// Reading order as a list of strings; presumably section ids — TODO confirm.
    pub reading_order: Vec<String>,

    /// Column layout, if any.
    pub columns: Option<ColumnLayout>,

    /// Document format classification.
    pub document_type: DocumentType,
}
449
/// Embedding vectors for a document; input to
/// [`EmbeddingFusionStrategy::fuse_embeddings`] and stored in
/// [`MultiModalDocument::embeddings`].
///
/// NOTE(review): [`EmbeddingWeights`] carries a `chart_weight`, but there is no chart
/// embedding field here — confirm whether that is intentional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalEmbeddings {
    /// Text embedding vector.
    pub text_embeddings: Vec<f32>,

    /// Visual embedding vector, if any.
    pub visual_embeddings: Option<Vec<f32>>,

    /// Table embedding vector, if any.
    pub table_embeddings: Option<Vec<f32>>,

    /// Fused embedding vector.
    pub fused_embedding: Vec<f32>,

    /// Per-modality weights.
    pub weights: EmbeddingWeights,
}
468
/// Interface of the image backend held by [`MultiModalService`].
///
/// Implementors must be `Send + Sync`. Every method receives the image as a [`Path`] and
/// reports failure through [`RragResult`].
pub trait ImageProcessor: Send + Sync {
    /// Produces the full [`ProcessedImage`] record for the image at `image_path`.
    fn process_image(&self, image_path: &Path) -> RragResult<ProcessedImage>;

    /// Produces only the [`VisualFeatures`] for the image at `image_path`.
    fn extract_features(&self, image_path: &Path) -> RragResult<VisualFeatures>;

    /// Produces a caption string for the image at `image_path`.
    fn generate_caption(&self, image_path: &Path) -> RragResult<String>;

    /// Produces a CLIP embedding vector for the image at `image_path`.
    fn generate_clip_embedding(&self, image_path: &Path) -> RragResult<Vec<f32>>;
}
483
/// Interface of the table backend held by [`MultiModalService`].
///
/// Implementors must be `Send + Sync`; failures are reported through [`RragResult`].
pub trait TableProcessor: Send + Sync {
    /// Returns every [`ExtractedTable`] found in the textual `content`.
    fn extract_table(&self, content: &str) -> RragResult<Vec<ExtractedTable>>;

    /// Parses a single table from `table_html`. The parameter name suggests HTML markup —
    /// confirm accepted input with the implementors.
    fn parse_structure(&self, table_html: &str) -> RragResult<ExtractedTable>;

    /// Produces a summary string for `table`.
    fn generate_summary(&self, table: &ExtractedTable) -> RragResult<String>;

    /// Computes [`TableStatistics`] for `table`.
    fn calculate_statistics(&self, table: &ExtractedTable) -> RragResult<TableStatistics>;
}
498
/// Interface of the chart backend held by [`MultiModalService`].
///
/// Implementors must be `Send + Sync`; failures are reported through [`RragResult`].
pub trait ChartProcessor: Send + Sync {
    /// Produces the full [`AnalyzedChart`] record for the chart image at `image_path`.
    fn analyze_chart(&self, image_path: &Path) -> RragResult<AnalyzedChart>;

    /// Returns the [`DataPoint`]s read from the chart image at `chart_image`.
    fn extract_data_points(&self, chart_image: &Path) -> RragResult<Vec<DataPoint>>;

    /// Classifies the chart image at `chart_image` as a [`ChartType`].
    fn identify_type(&self, chart_image: &Path) -> RragResult<ChartType>;

    /// Computes a [`TrendAnalysis`] from already-extracted `data_points`.
    fn analyze_trends(&self, data_points: &[DataPoint]) -> RragResult<TrendAnalysis>;
}
513
/// Interface of the OCR backend held by [`MultiModalService`].
///
/// Implementors must be `Send + Sync`. Every method receives the image as a [`Path`] and
/// reports failure through [`RragResult`].
pub trait OCREngine: Send + Sync {
    /// Runs OCR on the image at `image_path` and returns the full [`OCRResult`].
    fn ocr(&self, image_path: &Path) -> RragResult<OCRResult>;

    /// Returns `(text, confidence)` pairs for the image at `image_path`. The granularity of
    /// each pair (word, line, block) is left to the implementor.
    fn get_text_with_confidence(&self, image_path: &Path) -> RragResult<Vec<(String, f32)>>;

    /// Returns the [`TextLayout`] of the image at `image_path`.
    fn get_layout(&self, image_path: &Path) -> RragResult<TextLayout>;
}
525
/// Interface of the layout backend held by [`MultiModalService`].
///
/// Implementors must be `Send + Sync`; failures are reported through [`RragResult`].
pub trait LayoutAnalyzer: Send + Sync {
    /// Produces the [`DocumentLayout`] of the document at `document_path`.
    fn analyze_layout(&self, document_path: &Path) -> RragResult<DocumentLayout>;

    /// Returns the [`DocumentSection`]s found in the textual `content`.
    fn detect_sections(&self, content: &str) -> RragResult<Vec<DocumentSection>>;

    /// Derives a reading order (list of strings) from an existing `layout`.
    fn extract_reading_order(&self, layout: &DocumentLayout) -> RragResult<Vec<String>>;
}
537
/// Interface of the embedding fusion backend held by [`MultiModalService`].
///
/// Implementors must be `Send + Sync`; failures are reported through [`RragResult`].
pub trait EmbeddingFusionStrategy: Send + Sync {
    /// Combines the vectors in `embeddings` into a single embedding vector.
    fn fuse_embeddings(&self, embeddings: &MultiModalEmbeddings) -> RragResult<Vec<f32>>;

    /// Computes the [`EmbeddingWeights`] to use for `document`.
    fn calculate_weights(&self, document: &MultiModalDocument) -> RragResult<EmbeddingWeights>;
}
546
/// Image file formats that can be listed in [`ImageProcessingConfig::supported_formats`].
///
/// The default configuration lists `JPEG`, `PNG` and `WEBP`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImageFormat {
    JPEG,
    PNG,
    GIF,
    BMP,
    WEBP,
    SVG,
    TIFF,
}
560
/// Output formats selectable through [`TableExtractionConfig::output_format`].
///
/// The default table configuration selects `JSON`. `PartialEq`/`Eq` are derived so that
/// values can be compared with `==` and `assert_eq!`, matching the other plain enums in
/// this module such as `ImageFormat`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableOutputFormat {
    CSV,
    JSON,
    Markdown,
    HTML,
}
569
/// Chart classification used by [`ChartAnalysisConfig::chart_types`],
/// [`AnalyzedChart::chart_type`] and [`ChartProcessor::identify_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ChartType {
    Line,
    Bar,
    Pie,
    Scatter,
    Area,
    Histogram,
    HeatMap,
    /// Presumably a box plot — TODO confirm.
    Box,
    Unknown,
}
583
/// OCR engine selector used by [`OCRConfig::engine`]; the default configuration selects
/// `Tesseract`. How each variant maps to a backend is decided outside this file.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OCREngineType {
    Tesseract,
    EasyOCR,
    PaddleOCR,
    CloudVision,
}
592
/// Document format classification used by [`DocumentLayout::document_type`] and
/// [`DocumentMetadata::format`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum DocumentType {
    PDF,
    Word,
    PowerPoint,
    HTML,
    Markdown,
    PlainText,
    Mixed,
}
604
/// Value classification used by [`ExtractedTable::column_types`] and
/// [`TableCell::data_type`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum DataType {
    String,
    Number,
    Date,
    Boolean,
    Mixed,
}
614
/// Fusion selector stored in [`MultiModalConfig::fusion_strategy`] and passed to
/// `embedding_fusion::DefaultFusionStrategy::new`.
///
/// `MultiModalConfig::default()` selects `Weighted`. How each variant combines embeddings
/// is defined by the fusion implementation, outside this file. `PartialEq`/`Eq` are derived
/// so that values can be compared with `==` and `assert_eq!`, matching the other plain
/// enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionStrategy {
    Average,

    Weighted,

    Concatenate,

    Attention,

    Learned,
}
633
/// Visual description of an image; returned by [`ImageProcessor::extract_features`] and
/// stored in [`ProcessedImage::features`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualFeatures {
    /// Colors found in the image, each with its share of the image.
    pub colors: Vec<Color>,

    /// Objects detected in the image.
    pub objects: Vec<DetectedObject>,

    /// Scene label, if any.
    pub scene: Option<String>,

    /// Image quality measurements.
    pub quality: ImageQuality,

    /// Spatial composition description.
    pub layout: SpatialLayout,
}
652
/// One cell of an [`ExtractedTable`] row.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableCell {
    /// Cell content as text.
    pub value: String,

    /// Classification of the cell's value.
    pub data_type: DataType,

    /// Formatting information, if any.
    pub formatting: Option<CellFormatting>,
}
665
/// Statistics for a table; returned by [`TableProcessor::calculate_statistics`] and
/// stored in [`ExtractedTable::statistics`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableStatistics {
    /// Number of rows.
    pub row_count: usize,

    /// Number of columns.
    pub column_count: usize,

    /// Null percentages; presumably one entry per column, scale (0-1 vs 0-100) not
    /// visible here — TODO confirm.
    pub null_percentages: Vec<f32>,

    /// Per-column statistics.
    pub column_stats: Vec<ColumnStatistics>,
}
681
/// Statistics for one table column; element type of [`TableStatistics::column_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnStatistics {
    /// Column name.
    pub name: String,

    /// Numeric statistics, if any.
    pub numeric_stats: Option<NumericStatistics>,

    /// Text statistics, if any.
    pub text_stats: Option<TextStatistics>,

    /// Count of unique values.
    pub unique_count: usize,
}
697
/// Summary statistics of a numeric column; see [`ColumnStatistics::numeric_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumericStatistics {
    /// Minimum value.
    pub min: f64,
    /// Maximum value.
    pub max: f64,
    /// Mean value.
    pub mean: f64,
    /// Median value.
    pub median: f64,
    /// Standard deviation (population vs sample is not specified here).
    pub std_dev: f64,
}
707
/// Summary statistics of a text column; see [`ColumnStatistics::text_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextStatistics {
    /// Minimum length (bytes vs characters is not specified here).
    pub min_length: usize,
    /// Maximum length.
    pub max_length: usize,
    /// Average length.
    pub avg_length: f32,
    /// `(value, count)` pairs; presumably the most frequent values — confirm ordering
    /// with the producer.
    pub most_common: Vec<(String, usize)>,
}
716
/// Axis information of an [`AnalyzedChart`]; every part is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChartAxes {
    /// X-axis label, if any.
    pub x_label: Option<String>,
    /// Y-axis label, if any.
    pub y_label: Option<String>,
    /// X-axis range as a pair, presumably `(min, max)` — TODO confirm.
    pub x_range: Option<(f64, f64)>,
    /// Y-axis range as a pair, presumably `(min, max)` — TODO confirm.
    pub y_range: Option<(f64, f64)>,
}
725
/// One chart data point; used by [`AnalyzedChart::data_points`], [`TrendAnalysis`] and the
/// [`ChartProcessor`] methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPoint {
    /// X coordinate.
    pub x: f64,
    /// Y coordinate.
    pub y: f64,
    /// Label of the point, if any.
    pub label: Option<String>,
    /// Series name the point belongs to, if any.
    pub series: Option<String>,
}
734
/// Trend description; returned by [`ChartProcessor::analyze_trends`] and stored in
/// [`AnalyzedChart::trends`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrendAnalysis {
    /// Overall direction of the trend.
    pub direction: TrendDirection,

    /// Trend strength; scale not visible here (presumably 0.0-1.0 — TODO confirm).
    pub strength: f32,

    /// Seasonality parameters, if any.
    pub seasonality: Option<Seasonality>,

    /// Data points classified as outliers.
    pub outliers: Vec<DataPoint>,

    /// Forecast data points, if any.
    pub forecast: Option<Vec<DataPoint>>,
}
753
/// Direction classification used by [`TrendAnalysis::direction`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum TrendDirection {
    Increasing,
    Decreasing,
    Stable,
    Volatile,
}
762
/// Seasonality parameters; see [`TrendAnalysis::seasonality`]. Units of all three values
/// are not specified in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Seasonality {
    /// Period of the seasonal component.
    pub period: f64,
    /// Amplitude of the seasonal component.
    pub amplitude: f64,
    /// Phase of the seasonal component.
    pub phase: f64,
}
770
/// Full OCR output; returned by [`OCREngine::ocr`]. Not serializable (no serde derives).
#[derive(Debug, Clone)]
pub struct OCRResult {
    /// Recognized text.
    pub text: String,

    /// Confidence value for the result; scale not visible here (presumably 0.0-1.0 —
    /// TODO confirm).
    pub confidence: f32,

    /// Recognized words with their own confidence and bounding box.
    pub words: Vec<OCRWord>,

    /// Language identifiers associated with the result.
    pub languages: Vec<String>,
}
786
/// One recognized word; element type of [`OCRResult::words`].
#[derive(Debug, Clone)]
pub struct OCRWord {
    /// Text of the word.
    pub text: String,
    /// Confidence value for this word.
    pub confidence: f32,
    /// Location of the word.
    pub bounding_box: BoundingBox,
}
794
/// Rectangle given as position plus size, used by [`OCRWord`] and [`TextBlock`].
///
/// NOTE(review): units and origin are not stated here (presumably pixels with a top-left
/// origin — confirm).
#[derive(Debug, Clone)]
pub struct BoundingBox {
    /// X position.
    pub x: u32,
    /// Y position.
    pub y: u32,
    /// Width of the rectangle.
    pub width: u32,
    /// Height of the rectangle.
    pub height: u32,
}
803
/// Text layout of an image; returned by [`OCREngine::get_layout`].
#[derive(Debug, Clone)]
pub struct TextLayout {
    /// Text blocks found.
    pub blocks: Vec<TextBlock>,

    /// Reading order as a list of numbers; whether these are [`TextBlock::id`] values or
    /// indices into `blocks` is not visible here — TODO confirm.
    pub reading_order: Vec<usize>,

    /// Column grouping, if any.
    pub columns: Option<Vec<Column>>,
}
816
/// One block of text; element type of [`TextLayout::blocks`].
#[derive(Debug, Clone)]
pub struct TextBlock {
    /// Numeric identifier of the block.
    pub id: usize,
    /// Text of the block.
    pub text: String,
    /// Location of the block.
    pub bounding_box: BoundingBox,
    /// Classification of the block.
    pub block_type: BlockType,
}
825
/// Classification of a text block, used by `TextBlock::block_type`.
///
/// `PartialEq`/`Eq` are derived so that values can be compared with `==` and `assert_eq!`,
/// matching the other plain enums in this module such as `SectionType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockType {
    Title,
    Heading,
    Paragraph,
    Caption,
    Footer,
    Header,
}
836
/// One text column; element type of [`TextLayout::columns`].
#[derive(Debug, Clone)]
pub struct Column {
    /// Index of the column.
    pub index: usize,
    /// Numbers referring to the blocks in this column; whether they are [`TextBlock::id`]
    /// values or positions in [`TextLayout::blocks`] is not visible here — TODO confirm.
    pub blocks: Vec<usize>,
    /// Width of the column (units not stated here).
    pub width: u32,
}
844
/// One section of a document; returned by [`LayoutAnalyzer::detect_sections`] and stored
/// in [`DocumentLayout::sections`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentSection {
    /// Section identifier.
    pub id: String,
    /// Section title, if any.
    pub title: Option<String>,
    /// Section text.
    pub content: String,
    /// Classification of the section.
    pub section_type: SectionType,
    /// Nesting level; the numbering base is not stated here.
    pub level: usize,
    /// Page range as a pair, presumably `(first, last)` — confirm inclusivity and base
    /// with the producer.
    pub page_range: (usize, usize),
}
855
/// Section classification used by [`DocumentSection::section_type`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum SectionType {
    Title,
    Abstract,
    Introduction,
    Body,
    Conclusion,
    References,
    Appendix,
}
867
/// Column layout of a document; see [`DocumentLayout::columns`]. Units of the width
/// values are not stated in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnLayout {
    /// Number of columns.
    pub column_count: usize,
    /// Column widths; presumably one entry per column — TODO confirm.
    pub column_widths: Vec<f32>,
    /// Width of the gutter between columns.
    pub gutter_width: f32,
}
875
/// Descriptive metadata of a document; stored in [`MultiModalDocument::metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentMetadata {
    /// Title, if any.
    pub title: Option<String>,
    /// Author, if any.
    pub author: Option<String>,
    /// Creation date as a string, if any; the date format is not fixed by this type.
    pub creation_date: Option<String>,
    /// Modification date as a string, if any; the date format is not fixed by this type.
    pub modification_date: Option<String>,
    /// Number of pages.
    pub page_count: usize,
    /// Number of words.
    pub word_count: usize,
    /// Language of the document as a string; the code format is not fixed by this type.
    pub language: String,
    /// Document format.
    pub format: DocumentType,
}
888
/// Basic metadata of an image; stored in [`ProcessedImage::metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageMetadata {
    /// Image width (presumably pixels — TODO confirm).
    pub width: u32,
    /// Image height (presumably pixels — TODO confirm).
    pub height: u32,
    /// Format name as a free-form string (note: not an [`ImageFormat`]).
    pub format: String,
    /// Size in bytes.
    pub size_bytes: usize,
    /// DPI value, if any.
    pub dpi: Option<u32>,
    /// Color space name, if any.
    pub color_space: Option<String>,
}
899
/// One color entry of [`VisualFeatures::colors`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Color {
    /// Red, green and blue components.
    pub rgb: (u8, u8, u8),
    /// Share of the image covered by this color; scale (0-1 vs 0-100) is not visible
    /// here — TODO confirm.
    pub percentage: f32,
    /// Color name, if any.
    pub name: Option<String>,
}
907
/// One detected object; element type of [`VisualFeatures::objects`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedObject {
    /// Class label of the object.
    pub class: String,
    /// Detection confidence; scale not visible here (presumably 0.0-1.0 — TODO confirm).
    pub confidence: f32,
    /// Bounding box as four floats. NOTE(review): component order and units (for example
    /// `x, y, w, h` vs corner pairs, pixels vs normalized) are not visible here — confirm.
    pub bounding_box: (f32, f32, f32, f32),
}
915
/// Image quality measurements; see [`VisualFeatures::quality`]. The scale of each value
/// is not stated in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageQuality {
    /// Sharpness measurement.
    pub sharpness: f32,
    /// Contrast measurement.
    pub contrast: f32,
    /// Brightness measurement.
    pub brightness: f32,
    /// Noise level measurement.
    pub noise_level: f32,
}
924
/// Spatial composition of an image; see [`VisualFeatures::layout`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialLayout {
    /// Composition classification.
    pub composition_type: CompositionType,
    /// Focal points as coordinate pairs; presumably `(x, y)`, units not visible here —
    /// TODO confirm.
    pub focal_points: Vec<(f32, f32)>,
    /// Balance measurement; scale not stated here.
    pub balance: f32,
}
932
933#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
935pub enum CompositionType {
936 RuleOfThirds,
937 Centered,
938 Diagonal,
939 Symmetrical,
940 Asymmetrical,
941}
942
/// Formatting of a table cell; see [`TableCell::formatting`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CellFormatting {
    /// Whether the cell is bold.
    pub bold: bool,
    /// Whether the cell is italic.
    pub italic: bool,
    /// Text color as a string, if any; the notation (name, hex, ...) is not fixed here.
    pub color: Option<String>,
    /// Background color as a string, if any; the notation is not fixed here.
    pub background: Option<String>,
}
951
/// Per-modality weights; returned by [`EmbeddingFusionStrategy::calculate_weights`] and
/// stored in [`MultiModalEmbeddings::weights`].
///
/// NOTE(review): whether the four weights are expected to sum to 1.0 is not visible
/// here — confirm with the fusion implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingWeights {
    /// Weight of the text modality.
    pub text_weight: f32,
    /// Weight of the visual modality.
    pub visual_weight: f32,
    /// Weight of the table modality.
    pub table_weight: f32,
    /// Weight of the chart modality.
    pub chart_weight: f32,
}
960
961impl MultiModalService {
962 pub fn new(config: MultiModalConfig) -> RragResult<Self> {
964 Ok(Self {
965 config: config.clone(),
966 image_processor: Box::new(image_processor::DefaultImageProcessor::new(
967 config.image_config,
968 )?),
969 table_processor: Box::new(table_processor::DefaultTableProcessor::new(
970 config.table_config,
971 )?),
972 chart_processor: Box::new(chart_processor::DefaultChartProcessor::new(
973 config.chart_config,
974 )?),
975 ocr_engine: Box::new(ocr::DefaultOCREngine::new(config.ocr_config)?),
976 layout_analyzer: Box::new(layout_analysis::DefaultLayoutAnalyzer::new(
977 config.layout_config,
978 )?),
979 fusion_strategy: Box::new(embedding_fusion::DefaultFusionStrategy::new(
980 config.fusion_strategy,
981 )?),
982 })
983 }
984
985 pub async fn process_document(&self, _document_path: &Path) -> RragResult<MultiModalDocument> {
987 todo!("Implement multi-modal document processing")
989 }
990
991 pub async fn extract_modalities(&self, _content: &[u8]) -> RragResult<MultiModalDocument> {
993 todo!("Implement modality extraction")
995 }
996}
997
998impl Default for MultiModalConfig {
999 fn default() -> Self {
1000 Self {
1001 process_images: true,
1002 process_tables: true,
1003 process_charts: true,
1004 image_config: ImageProcessingConfig::default(),
1005 table_config: TableExtractionConfig::default(),
1006 chart_config: ChartAnalysisConfig::default(),
1007 ocr_config: OCRConfig::default(),
1008 layout_config: LayoutAnalysisConfig::default(),
1009 fusion_strategy: FusionStrategy::Weighted,
1010 }
1011 }
1012}
1013
1014impl Default for ImageProcessingConfig {
1015 fn default() -> Self {
1016 Self {
1017 max_width: 1920,
1018 max_height: 1080,
1019 supported_formats: vec![ImageFormat::JPEG, ImageFormat::PNG, ImageFormat::WEBP],
1020 use_clip: true,
1021 generate_captions: true,
1022 extract_features: true,
1023 compression_quality: 85,
1024 }
1025 }
1026}
1027
1028impl Default for TableExtractionConfig {
1029 fn default() -> Self {
1030 Self {
1031 min_rows: 2,
1032 min_cols: 2,
1033 extract_headers: true,
1034 infer_types: true,
1035 generate_summaries: true,
1036 output_format: TableOutputFormat::JSON,
1037 }
1038 }
1039}
1040
1041impl Default for ChartAnalysisConfig {
1042 fn default() -> Self {
1043 Self {
1044 chart_types: vec![
1045 ChartType::Line,
1046 ChartType::Bar,
1047 ChartType::Pie,
1048 ChartType::Scatter,
1049 ],
1050 extract_data: true,
1051 generate_descriptions: true,
1052 analyze_trends: true,
1053 }
1054 }
1055}
1056
1057impl Default for OCRConfig {
1058 fn default() -> Self {
1059 Self {
1060 engine: OCREngineType::Tesseract,
1061 languages: vec!["eng".to_string()],
1062 confidence_threshold: 0.7,
1063 spell_correction: true,
1064 preserve_formatting: true,
1065 }
1066 }
1067}
1068
1069impl Default for LayoutAnalysisConfig {
1070 fn default() -> Self {
1071 Self {
1072 detect_structure: true,
1073 identify_sections: true,
1074 extract_reading_order: true,
1075 detect_columns: true,
1076 }
1077 }
1078}
1079
#[cfg(test)]
mod tests {
    use super::*;

    /// The top-level defaults switch on all three `process_*` flags.
    #[test]
    fn test_multimodal_config() {
        let MultiModalConfig {
            process_images,
            process_tables,
            process_charts,
            ..
        } = MultiModalConfig::default();

        assert!(process_images);
        assert!(process_tables);
        assert!(process_charts);
    }

    /// The image defaults carry the 1920x1080 limits and have CLIP switched on.
    #[test]
    fn test_image_config() {
        let ImageProcessingConfig {
            max_width,
            max_height,
            use_clip,
            ..
        } = ImageProcessingConfig::default();

        assert_eq!(max_width, 1920);
        assert_eq!(max_height, 1080);
        assert!(use_clip);
    }
}