rexis_rag/multimodal/
chart_processor.rs

1//! # Chart Processing
2//!
3//! Advanced chart analysis, data extraction, and trend analysis.
4
5use super::{
6    AnalyzedChart, ChartAnalysisConfig, ChartAxes, ChartProcessor, ChartType, DataPoint,
7    Seasonality, TrendAnalysis, TrendDirection,
8};
9use crate::{RragError, RragResult};
10use std::path::Path;
11
12/// Default chart processor implementation
13pub struct DefaultChartProcessor {
14    /// Configuration
15    config: ChartAnalysisConfig,
16
17    /// Chart type classifier
18    type_classifier: ChartTypeClassifier,
19
20    /// Data extractor
21    data_extractor: ChartDataExtractor,
22
23    /// Trend analyzer
24    trend_analyzer: TrendAnalyzer,
25
26    /// Description generator
27    description_generator: ChartDescriptionGenerator,
28}
29
30/// Chart type classifier
31pub struct ChartTypeClassifier {
32    /// Classification models
33    models: Vec<ClassificationModel>,
34}
35
/// Chart data extractor.
///
/// Carries the feature flags handed to [`ChartDataExtractor::new`].
#[derive(Debug, Clone)]
pub struct ChartDataExtractor {
    /// Whether OCR-based text extraction is enabled
    ocr_enabled: bool,

    /// Whether color analysis (for telling data series apart) is enabled
    color_analysis: bool,

    /// Whether shape detection (for markers) is enabled
    shape_detection: bool,
}
47
/// Trend analyzer for time series and patterns.
#[derive(Debug, Clone)]
pub struct TrendAnalyzer {
    /// Fewest data points accepted for a trend analysis
    min_points: usize,

    /// Smoothing window size
    smoothing_window: usize,

    /// Whether seasonality detection is attempted during analysis
    seasonality_detection: bool,
}
59
60/// Chart description generator
61pub struct ChartDescriptionGenerator {
62    /// Template-based generation
63    templates: std::collections::HashMap<ChartType, String>,
64
65    /// Natural language generation
66    nlg_enabled: bool,
67}
68
69/// Classification model for chart types
70#[derive(Debug, Clone)]
71pub struct ClassificationModel {
72    /// Model type
73    model_type: ModelType,
74
75    /// Confidence threshold
76    confidence_threshold: f32,
77
78    /// Feature extractors
79    features: Vec<FeatureType>,
80}
81
/// Model types for classification.
///
/// Derives equality and hashing so variants can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Convolutional neural network
    CNN,
    /// Support vector machine
    SVM,
    /// Random forest
    RandomForest,
    /// Ensemble of models
    Ensemble,
}
90
/// Feature types for classification.
///
/// Derives equality and hashing so variants can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Color histogram features
    ColorHistogram,
    /// Edge detection features
    EdgeDetection,
    /// Shape features
    ShapeFeatures,
    /// Text features
    TextFeatures,
    /// Layout features
    LayoutFeatures,
}
100
101/// Chart analysis result
102#[derive(Debug, Clone)]
103pub struct ChartAnalysisResult {
104    /// Identified chart type
105    pub chart_type: ChartType,
106
107    /// Classification confidence
108    pub confidence: f32,
109
110    /// Extracted data points
111    pub data_points: Vec<DataPoint>,
112
113    /// Chart elements
114    pub elements: ChartElements,
115
116    /// Visual properties
117    pub visual_properties: VisualProperties,
118}
119
120/// Chart elements
121#[derive(Debug, Clone)]
122pub struct ChartElements {
123    /// Chart title
124    pub title: Option<String>,
125
126    /// Axis labels
127    pub axes: ChartAxes,
128
129    /// Legend entries
130    pub legend: Vec<LegendEntry>,
131
132    /// Data series
133    pub series: Vec<DataSeries>,
134
135    /// Annotations
136    pub annotations: Vec<ChartAnnotation>,
137}
138
139/// Legend entry
140#[derive(Debug, Clone)]
141pub struct LegendEntry {
142    /// Legend text
143    pub text: String,
144
145    /// Associated color
146    pub color: Option<(u8, u8, u8)>,
147
148    /// Symbol/marker type
149    pub symbol: Option<MarkerType>,
150}
151
152/// Data series in chart
153#[derive(Debug, Clone)]
154pub struct DataSeries {
155    /// Series name
156    pub name: String,
157
158    /// Data points
159    pub points: Vec<DataPoint>,
160
161    /// Series color
162    pub color: Option<(u8, u8, u8)>,
163
164    /// Line style
165    pub line_style: Option<LineStyle>,
166}
167
168/// Chart annotation
169#[derive(Debug, Clone)]
170pub struct ChartAnnotation {
171    /// Annotation text
172    pub text: String,
173
174    /// Position
175    pub position: (f64, f64),
176
177    /// Annotation type
178    pub annotation_type: AnnotationType,
179}
180
181/// Visual properties of chart
182#[derive(Debug, Clone)]
183pub struct VisualProperties {
184    /// Chart area
185    pub chart_area: ChartArea,
186
187    /// Color scheme
188    pub color_scheme: ColorScheme,
189
190    /// Typography
191    pub typography: Typography,
192
193    /// Grid properties
194    pub grid: Option<GridProperties>,
195}
196
/// Geometry of a chart image.
#[derive(Debug, Clone)]
pub struct ChartArea {
    /// Outer bounds of the chart as `(x, y, width, height)`
    pub bounds: (f64, f64, f64, f64),

    /// Rectangle of the plot area
    pub plot_area: (f64, f64, f64, f64),

    /// Margins as `(top, right, bottom, left)`
    pub margins: (f64, f64, f64, f64),
}
209
210/// Color scheme analysis
211#[derive(Debug, Clone)]
212pub struct ColorScheme {
213    /// Primary colors
214    pub primary_colors: Vec<(u8, u8, u8)>,
215
216    /// Color palette type
217    pub palette_type: PaletteType,
218
219    /// Color accessibility score
220    pub accessibility_score: f32,
221}
222
223/// Typography analysis
224#[derive(Debug, Clone)]
225pub struct Typography {
226    /// Title font info
227    pub title_font: Option<FontInfo>,
228
229    /// Axis font info
230    pub axis_font: Option<FontInfo>,
231
232    /// Legend font info
233    pub legend_font: Option<FontInfo>,
234
235    /// Overall readability score
236    pub readability_score: f32,
237}
238
239/// Font information
240#[derive(Debug, Clone)]
241pub struct FontInfo {
242    /// Font family
243    pub family: String,
244
245    /// Font size
246    pub size: f32,
247
248    /// Font weight
249    pub weight: FontWeight,
250
251    /// Font color
252    pub color: (u8, u8, u8),
253}
254
255/// Grid properties
256#[derive(Debug, Clone)]
257pub struct GridProperties {
258    /// Grid type
259    pub grid_type: GridType,
260
261    /// Grid color
262    pub color: (u8, u8, u8),
263
264    /// Grid opacity
265    pub opacity: f32,
266
267    /// Grid line count
268    pub line_count: (usize, usize), // (horizontal, vertical)
269}
270
/// Marker types.
///
/// Derives equality and hashing so variants can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MarkerType {
    /// Circle marker
    Circle,
    /// Square marker
    Square,
    /// Triangle marker
    Triangle,
    /// Diamond marker
    Diamond,
    /// Plus-sign marker
    Plus,
    /// Cross marker
    Cross,
    /// Star marker
    Star,
}
282
/// Line styles.
///
/// Derives equality and hashing so variants can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LineStyle {
    /// Continuous line
    Solid,
    /// Dashed line
    Dashed,
    /// Dotted line
    Dotted,
    /// Alternating dashes and dots
    DashDot,
}
291
/// Annotation types.
///
/// Derives equality and hashing so variants can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnnotationType {
    /// Plain text label
    Label,
    /// Arrow
    Arrow,
    /// Callout
    Callout,
    /// Highlight
    Highlight,
}
300
/// Color palette types.
///
/// Derives equality and hashing so variants can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaletteType {
    /// Sequential palette
    Sequential,
    /// Diverging palette
    Diverging,
    /// Categorical palette
    Categorical,
    /// Monochromatic palette
    Monochromatic,
}
309
/// Font weights.
///
/// Derives equality and hashing so variants can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FontWeight {
    /// Thin weight
    Thin,
    /// Light weight
    Light,
    /// Regular weight
    Regular,
    /// Medium weight
    Medium,
    /// Bold weight
    Bold,
    /// Extra-bold weight
    ExtraBold,
}
320
/// Grid types.
///
/// Derives equality and hashing so variants can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GridType {
    /// Major grid lines only
    Major,
    /// Minor grid lines only
    Minor,
    /// Both major and minor grid lines
    Both,
    /// No grid lines
    None,
}
329
330impl DefaultChartProcessor {
331    /// Create new chart processor
332    pub fn new(config: ChartAnalysisConfig) -> RragResult<Self> {
333        let type_classifier = ChartTypeClassifier::new()?;
334        let data_extractor = ChartDataExtractor::new(true, true, true);
335        let trend_analyzer = TrendAnalyzer::new(5, 3, true);
336        let description_generator = ChartDescriptionGenerator::new();
337
338        Ok(Self {
339            config,
340            type_classifier,
341            data_extractor,
342            trend_analyzer,
343            description_generator,
344        })
345    }
346
347    /// Comprehensive chart analysis
348    pub fn analyze_comprehensive(&self, image_path: &Path) -> RragResult<ChartAnalysisResult> {
349        // Classify chart type
350        let (chart_type, confidence) = self.type_classifier.classify(image_path)?;
351
352        // Extract data points
353        let data_points = self.data_extractor.extract(image_path, chart_type)?;
354
355        // Analyze chart elements
356        let elements = self.analyze_elements(image_path, chart_type)?;
357
358        // Analyze visual properties
359        let visual_properties = self.analyze_visual_properties(image_path)?;
360
361        Ok(ChartAnalysisResult {
362            chart_type,
363            confidence,
364            data_points,
365            elements,
366            visual_properties,
367        })
368    }
369
370    /// Analyze chart elements
371    fn analyze_elements(
372        &self,
373        image_path: &Path,
374        chart_type: ChartType,
375    ) -> RragResult<ChartElements> {
376        // Extract title
377        let title = self.extract_title(image_path)?;
378
379        // Extract axes information
380        let axes = self.extract_axes(image_path)?;
381
382        // Extract legend
383        let legend = self.extract_legend(image_path)?;
384
385        // Extract data series
386        let series = self.extract_series(image_path, chart_type)?;
387
388        // Extract annotations
389        let annotations = self.extract_annotations(image_path)?;
390
391        Ok(ChartElements {
392            title,
393            axes,
394            legend,
395            series,
396            annotations,
397        })
398    }
399
400    /// Extract chart title
401    fn extract_title(&self, _image_path: &Path) -> RragResult<Option<String>> {
402        // Simulate title extraction using OCR
403        Ok(Some("Sample Chart Title".to_string()))
404    }
405
406    /// Extract axis information
407    fn extract_axes(&self, _image_path: &Path) -> RragResult<ChartAxes> {
408        // Simulate axis extraction
409        Ok(ChartAxes {
410            x_label: Some("Time".to_string()),
411            y_label: Some("Value".to_string()),
412            x_range: Some((0.0, 100.0)),
413            y_range: Some((0.0, 50.0)),
414        })
415    }
416
417    /// Extract legend information
418    fn extract_legend(&self, _image_path: &Path) -> RragResult<Vec<LegendEntry>> {
419        // Simulate legend extraction
420        Ok(vec![
421            LegendEntry {
422                text: "Series 1".to_string(),
423                color: Some((255, 0, 0)),
424                symbol: Some(MarkerType::Circle),
425            },
426            LegendEntry {
427                text: "Series 2".to_string(),
428                color: Some((0, 255, 0)),
429                symbol: Some(MarkerType::Square),
430            },
431        ])
432    }
433
434    /// Extract data series
435    fn extract_series(
436        &self,
437        _image_path: &Path,
438        chart_type: ChartType,
439    ) -> RragResult<Vec<DataSeries>> {
440        // Simulate series extraction based on chart type
441        match chart_type {
442            ChartType::Line => self.extract_line_series(),
443            ChartType::Bar => self.extract_bar_series(),
444            ChartType::Pie => self.extract_pie_series(),
445            ChartType::Scatter => self.extract_scatter_series(),
446            _ => Ok(vec![]),
447        }
448    }
449
450    /// Extract line chart series
451    fn extract_line_series(&self) -> RragResult<Vec<DataSeries>> {
452        Ok(vec![DataSeries {
453            name: "Series 1".to_string(),
454            points: vec![
455                DataPoint {
456                    x: 0.0,
457                    y: 10.0,
458                    label: None,
459                    series: Some("Series 1".to_string()),
460                },
461                DataPoint {
462                    x: 1.0,
463                    y: 15.0,
464                    label: None,
465                    series: Some("Series 1".to_string()),
466                },
467                DataPoint {
468                    x: 2.0,
469                    y: 12.0,
470                    label: None,
471                    series: Some("Series 1".to_string()),
472                },
473            ],
474            color: Some((255, 0, 0)),
475            line_style: Some(LineStyle::Solid),
476        }])
477    }
478
479    /// Extract bar chart series
480    fn extract_bar_series(&self) -> RragResult<Vec<DataSeries>> {
481        Ok(vec![DataSeries {
482            name: "Categories".to_string(),
483            points: vec![
484                DataPoint {
485                    x: 0.0,
486                    y: 20.0,
487                    label: Some("Category A".to_string()),
488                    series: None,
489                },
490                DataPoint {
491                    x: 1.0,
492                    y: 35.0,
493                    label: Some("Category B".to_string()),
494                    series: None,
495                },
496                DataPoint {
497                    x: 2.0,
498                    y: 25.0,
499                    label: Some("Category C".to_string()),
500                    series: None,
501                },
502            ],
503            color: Some((0, 100, 200)),
504            line_style: None,
505        }])
506    }
507
508    /// Extract pie chart series
509    fn extract_pie_series(&self) -> RragResult<Vec<DataSeries>> {
510        Ok(vec![DataSeries {
511            name: "Pie Slices".to_string(),
512            points: vec![
513                DataPoint {
514                    x: 0.0,
515                    y: 40.0,
516                    label: Some("Slice A".to_string()),
517                    series: None,
518                },
519                DataPoint {
520                    x: 1.0,
521                    y: 30.0,
522                    label: Some("Slice B".to_string()),
523                    series: None,
524                },
525                DataPoint {
526                    x: 2.0,
527                    y: 30.0,
528                    label: Some("Slice C".to_string()),
529                    series: None,
530                },
531            ],
532            color: None,
533            line_style: None,
534        }])
535    }
536
537    /// Extract scatter plot series
538    fn extract_scatter_series(&self) -> RragResult<Vec<DataSeries>> {
539        Ok(vec![DataSeries {
540            name: "Scatter Points".to_string(),
541            points: vec![
542                DataPoint {
543                    x: 5.0,
544                    y: 10.0,
545                    label: None,
546                    series: None,
547                },
548                DataPoint {
549                    x: 15.0,
550                    y: 25.0,
551                    label: None,
552                    series: None,
553                },
554                DataPoint {
555                    x: 25.0,
556                    y: 20.0,
557                    label: None,
558                    series: None,
559                },
560            ],
561            color: Some((100, 100, 100)),
562            line_style: None,
563        }])
564    }
565
566    /// Extract annotations
567    fn extract_annotations(&self, _image_path: &Path) -> RragResult<Vec<ChartAnnotation>> {
568        // Simulate annotation extraction
569        Ok(vec![])
570    }
571
572    /// Analyze visual properties
573    fn analyze_visual_properties(&self, _image_path: &Path) -> RragResult<VisualProperties> {
574        // Simulate visual property analysis
575        Ok(VisualProperties {
576            chart_area: ChartArea {
577                bounds: (0.0, 0.0, 800.0, 600.0),
578                plot_area: (100.0, 100.0, 600.0, 400.0),
579                margins: (50.0, 50.0, 50.0, 100.0),
580            },
581            color_scheme: ColorScheme {
582                primary_colors: vec![(255, 0, 0), (0, 255, 0), (0, 0, 255)],
583                palette_type: PaletteType::Categorical,
584                accessibility_score: 0.8,
585            },
586            typography: Typography {
587                title_font: Some(FontInfo {
588                    family: "Arial".to_string(),
589                    size: 16.0,
590                    weight: FontWeight::Bold,
591                    color: (0, 0, 0),
592                }),
593                axis_font: Some(FontInfo {
594                    family: "Arial".to_string(),
595                    size: 12.0,
596                    weight: FontWeight::Regular,
597                    color: (100, 100, 100),
598                }),
599                legend_font: Some(FontInfo {
600                    family: "Arial".to_string(),
601                    size: 10.0,
602                    weight: FontWeight::Regular,
603                    color: (0, 0, 0),
604                }),
605                readability_score: 0.9,
606            },
607            grid: Some(GridProperties {
608                grid_type: GridType::Major,
609                color: (200, 200, 200),
610                opacity: 0.3,
611                line_count: (5, 10),
612            }),
613        })
614    }
615}
616
617impl ChartProcessor for DefaultChartProcessor {
618    fn analyze_chart(&self, image_path: &Path) -> RragResult<AnalyzedChart> {
619        let analysis = self.analyze_comprehensive(image_path)?;
620
621        // Generate description
622        let description = if self.config.generate_descriptions {
623            Some(self.description_generator.generate(&analysis)?)
624        } else {
625            None
626        };
627
628        // Analyze trends
629        let trends = if self.config.analyze_trends && !analysis.data_points.is_empty() {
630            Some(self.trend_analyzer.analyze(&analysis.data_points)?)
631        } else {
632            None
633        };
634
635        Ok(AnalyzedChart {
636            id: format!(
637                "chart_{}",
638                uuid::Uuid::new_v4().to_string().split('-').next().unwrap()
639            ),
640            chart_type: analysis.chart_type,
641            title: analysis.elements.title,
642            axes: analysis.elements.axes,
643            data_points: analysis.data_points,
644            trends,
645            description,
646            embedding: None, // Would be generated by embedding service
647        })
648    }
649
650    fn extract_data_points(&self, chart_image: &Path) -> RragResult<Vec<DataPoint>> {
651        let analysis = self.analyze_comprehensive(chart_image)?;
652        Ok(analysis.data_points)
653    }
654
655    fn identify_type(&self, chart_image: &Path) -> RragResult<ChartType> {
656        let (chart_type, _confidence) = self.type_classifier.classify(chart_image)?;
657        Ok(chart_type)
658    }
659
660    fn analyze_trends(&self, data_points: &[DataPoint]) -> RragResult<TrendAnalysis> {
661        self.trend_analyzer.analyze(data_points)
662    }
663}
664
665impl ChartTypeClassifier {
666    /// Create new chart type classifier
667    pub fn new() -> RragResult<Self> {
668        let models = vec![
669            ClassificationModel {
670                model_type: ModelType::CNN,
671                confidence_threshold: 0.8,
672                features: vec![
673                    FeatureType::ColorHistogram,
674                    FeatureType::EdgeDetection,
675                    FeatureType::ShapeFeatures,
676                ],
677            },
678            ClassificationModel {
679                model_type: ModelType::SVM,
680                confidence_threshold: 0.7,
681                features: vec![FeatureType::LayoutFeatures, FeatureType::TextFeatures],
682            },
683        ];
684
685        Ok(Self { models })
686    }
687
688    /// Classify chart type
689    pub fn classify(&self, image_path: &Path) -> RragResult<(ChartType, f32)> {
690        // Simulate classification based on image analysis
691        let filename = image_path
692            .file_name()
693            .and_then(|name| name.to_str())
694            .unwrap_or("");
695
696        // Simple heuristic based on filename for demonstration
697        let (chart_type, confidence) = if filename.contains("line") {
698            (ChartType::Line, 0.95)
699        } else if filename.contains("bar") {
700            (ChartType::Bar, 0.90)
701        } else if filename.contains("pie") {
702            (ChartType::Pie, 0.85)
703        } else if filename.contains("scatter") {
704            (ChartType::Scatter, 0.80)
705        } else {
706            (ChartType::Unknown, 0.5)
707        };
708
709        Ok((chart_type, confidence))
710    }
711
712    /// Extract features for classification
713    pub fn extract_features(&self, _image_path: &Path) -> RragResult<Vec<f32>> {
714        // Simulate feature extraction
715        let mut features = Vec::new();
716
717        // Color histogram features
718        features.extend(vec![0.1, 0.2, 0.3, 0.4]); // RGB histogram
719
720        // Edge detection features
721        features.extend(vec![0.5, 0.6]); // Edge density, direction
722
723        // Shape features
724        features.extend(vec![0.7, 0.8, 0.9]); // Rectangularity, circularity, linearity
725
726        // Layout features
727        features.extend(vec![0.2, 0.4]); // Symmetry, balance
728
729        // Text features
730        features.push(0.3); // Text density
731
732        Ok(features)
733    }
734}
735
736impl ChartDataExtractor {
737    /// Create new data extractor
738    pub fn new(ocr_enabled: bool, color_analysis: bool, shape_detection: bool) -> Self {
739        Self {
740            ocr_enabled,
741            color_analysis,
742            shape_detection,
743        }
744    }
745
746    /// Extract data points from chart
747    pub fn extract(&self, image_path: &Path, chart_type: ChartType) -> RragResult<Vec<DataPoint>> {
748        match chart_type {
749            ChartType::Line => self.extract_line_data(image_path),
750            ChartType::Bar => self.extract_bar_data(image_path),
751            ChartType::Pie => self.extract_pie_data(image_path),
752            ChartType::Scatter => self.extract_scatter_data(image_path),
753            ChartType::Area => self.extract_area_data(image_path),
754            ChartType::Histogram => self.extract_histogram_data(image_path),
755            _ => Ok(vec![]),
756        }
757    }
758
759    /// Extract line chart data
760    fn extract_line_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
761        // Simulate line data extraction
762        Ok(vec![
763            DataPoint {
764                x: 0.0,
765                y: 10.0,
766                label: None,
767                series: Some("Line 1".to_string()),
768            },
769            DataPoint {
770                x: 1.0,
771                y: 15.0,
772                label: None,
773                series: Some("Line 1".to_string()),
774            },
775            DataPoint {
776                x: 2.0,
777                y: 12.0,
778                label: None,
779                series: Some("Line 1".to_string()),
780            },
781            DataPoint {
782                x: 3.0,
783                y: 18.0,
784                label: None,
785                series: Some("Line 1".to_string()),
786            },
787        ])
788    }
789
790    /// Extract bar chart data
791    fn extract_bar_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
792        // Simulate bar data extraction
793        Ok(vec![
794            DataPoint {
795                x: 0.0,
796                y: 25.0,
797                label: Some("Q1".to_string()),
798                series: None,
799            },
800            DataPoint {
801                x: 1.0,
802                y: 30.0,
803                label: Some("Q2".to_string()),
804                series: None,
805            },
806            DataPoint {
807                x: 2.0,
808                y: 35.0,
809                label: Some("Q3".to_string()),
810                series: None,
811            },
812            DataPoint {
813                x: 3.0,
814                y: 40.0,
815                label: Some("Q4".to_string()),
816                series: None,
817            },
818        ])
819    }
820
821    /// Extract pie chart data
822    fn extract_pie_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
823        // Simulate pie data extraction (percentages)
824        Ok(vec![
825            DataPoint {
826                x: 0.0,
827                y: 40.0,
828                label: Some("Category A".to_string()),
829                series: None,
830            },
831            DataPoint {
832                x: 1.0,
833                y: 30.0,
834                label: Some("Category B".to_string()),
835                series: None,
836            },
837            DataPoint {
838                x: 2.0,
839                y: 20.0,
840                label: Some("Category C".to_string()),
841                series: None,
842            },
843            DataPoint {
844                x: 3.0,
845                y: 10.0,
846                label: Some("Category D".to_string()),
847                series: None,
848            },
849        ])
850    }
851
852    /// Extract scatter plot data
853    fn extract_scatter_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
854        // Simulate scatter data extraction
855        Ok(vec![
856            DataPoint {
857                x: 5.0,
858                y: 10.0,
859                label: None,
860                series: None,
861            },
862            DataPoint {
863                x: 15.0,
864                y: 25.0,
865                label: None,
866                series: None,
867            },
868            DataPoint {
869                x: 25.0,
870                y: 20.0,
871                label: None,
872                series: None,
873            },
874            DataPoint {
875                x: 35.0,
876                y: 40.0,
877                label: None,
878                series: None,
879            },
880        ])
881    }
882
883    /// Extract area chart data
884    fn extract_area_data(&self, image_path: &Path) -> RragResult<Vec<DataPoint>> {
885        // Area charts similar to line charts
886        self.extract_line_data(image_path)
887    }
888
889    /// Extract histogram data
890    fn extract_histogram_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
891        // Simulate histogram data extraction
892        Ok(vec![
893            DataPoint {
894                x: 0.0,
895                y: 5.0,
896                label: Some("0-10".to_string()),
897                series: None,
898            },
899            DataPoint {
900                x: 1.0,
901                y: 15.0,
902                label: Some("10-20".to_string()),
903                series: None,
904            },
905            DataPoint {
906                x: 2.0,
907                y: 25.0,
908                label: Some("20-30".to_string()),
909                series: None,
910            },
911            DataPoint {
912                x: 3.0,
913                y: 10.0,
914                label: Some("30-40".to_string()),
915                series: None,
916            },
917        ])
918    }
919}
920
921impl TrendAnalyzer {
922    /// Create new trend analyzer
923    pub fn new(min_points: usize, smoothing_window: usize, seasonality_detection: bool) -> Self {
924        Self {
925            min_points,
926            smoothing_window,
927            seasonality_detection,
928        }
929    }
930
931    /// Analyze trends in data points
932    pub fn analyze(&self, data_points: &[DataPoint]) -> RragResult<TrendAnalysis> {
933        if data_points.len() < self.min_points {
934            return Err(RragError::validation(
935                "data_points",
936                format!("minimum {} points", self.min_points),
937                format!("{} points", data_points.len()),
938            ));
939        }
940
941        // Calculate trend direction
942        let direction = self.calculate_trend_direction(data_points);
943
944        // Calculate trend strength
945        let strength = self.calculate_trend_strength(data_points);
946
947        // Detect seasonality if enabled
948        let seasonality = if self.seasonality_detection {
949            self.detect_seasonality(data_points)
950        } else {
951            None
952        };
953
954        // Detect outliers
955        let outliers = self.detect_outliers(data_points);
956
957        // Generate forecast if enough data
958        let forecast = if data_points.len() >= 10 {
959            Some(self.generate_forecast(data_points, 5)?)
960        } else {
961            None
962        };
963
964        Ok(TrendAnalysis {
965            direction,
966            strength,
967            seasonality,
968            outliers,
969            forecast,
970        })
971    }
972
973    /// Calculate trend direction
974    fn calculate_trend_direction(&self, data_points: &[DataPoint]) -> TrendDirection {
975        if data_points.len() < 2 {
976            return TrendDirection::Stable;
977        }
978
979        let first_y = data_points[0].y;
980        let last_y = data_points[data_points.len() - 1].y;
981        let change = last_y - first_y;
982
983        // Calculate volatility
984        let volatility = self.calculate_volatility(data_points);
985
986        if change.abs() < volatility * 0.5 {
987            TrendDirection::Stable
988        } else if volatility > change.abs() * 2.0 {
989            TrendDirection::Volatile
990        } else if change > 0.0 {
991            TrendDirection::Increasing
992        } else {
993            TrendDirection::Decreasing
994        }
995    }
996
997    /// Calculate trend strength
998    fn calculate_trend_strength(&self, data_points: &[DataPoint]) -> f32 {
999        if data_points.len() < 2 {
1000            return 0.0;
1001        }
1002
1003        // Linear regression coefficient of determination (R²)
1004        let n = data_points.len() as f64;
1005        let sum_x: f64 = data_points.iter().map(|p| p.x).sum();
1006        let sum_y: f64 = data_points.iter().map(|p| p.y as f64).sum();
1007        let sum_xy: f64 = data_points.iter().map(|p| p.x * p.y as f64).sum();
1008        let sum_x2: f64 = data_points.iter().map(|p| p.x * p.x).sum();
1009        let sum_y2: f64 = data_points
1010            .iter()
1011            .map(|p| (p.y as f64) * (p.y as f64))
1012            .sum();
1013
1014        let numerator = n * sum_xy - sum_x * sum_y;
1015        let denominator = ((n * sum_x2 - sum_x * sum_x) * (n * sum_y2 - sum_y * sum_y)).sqrt();
1016
1017        if denominator == 0.0 {
1018            return 0.0;
1019        }
1020
1021        let r = numerator / denominator;
1022        (r * r) as f32 // R²
1023    }
1024
1025    /// Calculate volatility
1026    fn calculate_volatility(&self, data_points: &[DataPoint]) -> f64 {
1027        if data_points.len() < 2 {
1028            return 0.0;
1029        }
1030
1031        let values: Vec<f64> = data_points.iter().map(|p| p.y as f64).collect();
1032        let mean = values.iter().sum::<f64>() / values.len() as f64;
1033        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
1034        variance.sqrt()
1035    }
1036
1037    /// Detect seasonality patterns
1038    fn detect_seasonality(&self, data_points: &[DataPoint]) -> Option<Seasonality> {
1039        if data_points.len() < 12 {
1040            return None; // Need at least 12 points for seasonality detection
1041        }
1042
1043        // Simplified seasonality detection using autocorrelation
1044        // In practice, would use FFT or more sophisticated methods
1045
1046        Some(Seasonality {
1047            period: 12.0, // Assume monthly seasonality
1048            amplitude: 5.0,
1049            phase: 0.0,
1050        })
1051    }
1052
1053    /// Detect outliers using IQR method
1054    fn detect_outliers(&self, data_points: &[DataPoint]) -> Vec<DataPoint> {
1055        if data_points.len() < 4 {
1056            return vec![];
1057        }
1058
1059        let mut y_values: Vec<f32> = data_points.iter().map(|p| p.y as f32).collect();
1060        y_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1061
1062        let q1_idx = y_values.len() / 4;
1063        let q3_idx = 3 * y_values.len() / 4;
1064        let q1 = y_values[q1_idx];
1065        let q3 = y_values[q3_idx];
1066        let iqr = q3 - q1;
1067
1068        let lower_bound = q1 - 1.5 * iqr;
1069        let upper_bound = q3 + 1.5 * iqr;
1070
1071        data_points
1072            .iter()
1073            .filter(|p| (p.y as f32) < lower_bound || (p.y as f32) > upper_bound)
1074            .cloned()
1075            .collect()
1076    }
1077
1078    /// Generate forecast using simple linear regression
1079    fn generate_forecast(
1080        &self,
1081        data_points: &[DataPoint],
1082        num_points: usize,
1083    ) -> RragResult<Vec<DataPoint>> {
1084        if data_points.len() < 2 {
1085            return Ok(vec![]);
1086        }
1087
1088        // Calculate linear regression parameters
1089        let n = data_points.len() as f64;
1090        let sum_x: f64 = data_points.iter().map(|p| p.x).sum();
1091        let sum_y: f64 = data_points.iter().map(|p| p.y as f64).sum();
1092        let sum_xy: f64 = data_points.iter().map(|p| p.x * p.y as f64).sum();
1093        let sum_x2: f64 = data_points.iter().map(|p| p.x * p.x).sum();
1094
1095        let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
1096        let intercept = (sum_y - slope * sum_x) / n;
1097
1098        // Generate forecast points
1099        let last_x = data_points.last().unwrap().x;
1100        let mut forecast = Vec::new();
1101
1102        for i in 1..=num_points {
1103            let x = last_x + i as f64;
1104            let y = (slope * x + intercept) as f32;
1105
1106            forecast.push(DataPoint {
1107                x,
1108                y: y as f64,
1109                label: Some(format!("Forecast {}", i)),
1110                series: Some("Forecast".to_string()),
1111            });
1112        }
1113
1114        Ok(forecast)
1115    }
1116}
1117
1118impl ChartDescriptionGenerator {
1119    /// Create new description generator
1120    pub fn new() -> Self {
1121        let mut templates = std::collections::HashMap::new();
1122
1123        templates.insert(
1124            ChartType::Line,
1125            "This line chart shows {data_description}. The trend is {trend_direction} with a strength of {trend_strength:.2}.".to_string()
1126        );
1127
1128        templates.insert(
1129            ChartType::Bar,
1130            "This bar chart displays {data_description}. The highest value is {max_value} and the lowest is {min_value}.".to_string()
1131        );
1132
1133        templates.insert(
1134            ChartType::Pie,
1135            "This pie chart represents {data_description}. The largest segment is {largest_segment} at {largest_percentage:.1}%.".to_string()
1136        );
1137
1138        Self {
1139            templates,
1140            nlg_enabled: false,
1141        }
1142    }
1143
1144    /// Generate chart description
1145    pub fn generate(&self, analysis: &ChartAnalysisResult) -> RragResult<String> {
1146        if let Some(template) = self.templates.get(&analysis.chart_type) {
1147            let description = self.fill_template(template, analysis)?;
1148            Ok(description)
1149        } else {
1150            Ok(format!(
1151                "Chart of type {:?} with {} data points",
1152                analysis.chart_type,
1153                analysis.data_points.len()
1154            ))
1155        }
1156    }
1157
1158    /// Fill template with analysis data
1159    fn fill_template(&self, template: &str, analysis: &ChartAnalysisResult) -> RragResult<String> {
1160        let mut description = template.to_string();
1161
1162        // Basic substitutions
1163        description = description.replace(
1164            "{data_description}",
1165            &format!("{} data points", analysis.data_points.len()),
1166        );
1167
1168        if !analysis.data_points.is_empty() {
1169            let max_y = analysis
1170                .data_points
1171                .iter()
1172                .map(|p| p.y as f32)
1173                .fold(f32::NEG_INFINITY, |a, b| a.max(b));
1174            let min_y = analysis
1175                .data_points
1176                .iter()
1177                .map(|p| p.y as f32)
1178                .fold(f32::INFINITY, |a, b| a.min(b));
1179
1180            description = description.replace("{max_value}", &max_y.to_string());
1181            description = description.replace("{min_value}", &min_y.to_string());
1182        }
1183
1184        Ok(description)
1185    }
1186}
1187
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Build an unlabeled data point that belongs to no series.
    fn point(x: f64, y: f64) -> DataPoint {
        DataPoint {
            x,
            y,
            label: None,
            series: None,
        }
    }

    /// A processor built from the default configuration keeps the
    /// extraction, description and trend flags enabled.
    #[test]
    fn test_chart_processor_creation() {
        let processor = DefaultChartProcessor::new(ChartAnalysisConfig::default()).unwrap();
        let config = &processor.config;

        assert!(config.extract_data);
        assert!(config.generate_descriptions);
        assert!(config.analyze_trends);
    }

    /// A path whose file name is `line_chart.png` is classified as a line
    /// chart with confidence above 0.9.
    #[test]
    fn test_chart_type_classification() {
        let classifier = ChartTypeClassifier::new().unwrap();

        // Reuse the temp file's directory but swap in a file name carrying a
        // chart-type hint; nothing is created at the resulting path.
        let temp_file = NamedTempFile::new().unwrap();
        let hinted_path = temp_file.path().with_file_name("line_chart.png");

        let (chart_type, confidence) = classifier.classify(&hinted_path).unwrap();
        assert_eq!(chart_type, ChartType::Line);
        assert!(confidence > 0.9);
    }

    /// Four points rising by 5 per step are reported as an increasing trend
    /// with strength above 0.8.
    #[test]
    fn test_trend_analysis() {
        let analyzer = TrendAnalyzer::new(3, 2, false);

        // (0, 10), (1, 15), (2, 20), (3, 25)
        let data_points: Vec<DataPoint> = (0..4)
            .map(|step| point(step as f64, 10.0 + 5.0 * step as f64))
            .collect();

        let trend = analyzer.analyze(&data_points).unwrap();
        assert_eq!(trend.direction, TrendDirection::Increasing);
        assert!(trend.strength > 0.8);
    }

    /// Among 10, 12, 100 and 11 only the value 100 is flagged as an outlier.
    #[test]
    fn test_outlier_detection() {
        let analyzer = TrendAnalyzer::new(3, 2, false);

        // The third value (100.0) is the outlier.
        let data_points: Vec<DataPoint> = [10.0, 12.0, 100.0, 11.0]
            .iter()
            .enumerate()
            .map(|(index, &y)| point(index as f64, y))
            .collect();

        let outliers = analyzer.detect_outliers(&data_points);
        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0].y, 100.0);
    }

    /// Extraction from a temp file, treated as a line chart, yields at least
    /// one data point.
    #[test]
    fn test_data_extraction() {
        let extractor = ChartDataExtractor::new(true, true, true);
        let temp_file = NamedTempFile::new().unwrap();

        let data_points = extractor
            .extract(temp_file.path(), ChartType::Line)
            .unwrap();

        assert!(!data_points.is_empty());
    }

    /// The line-chart description mentions the chart kind and the number of
    /// data points.
    #[test]
    fn test_description_generation() {
        let generator = ChartDescriptionGenerator::new();

        let elements = ChartElements {
            title: None,
            axes: ChartAxes {
                x_label: None,
                y_label: None,
                x_range: None,
                y_range: None,
            },
            legend: vec![],
            series: vec![],
            annotations: vec![],
        };

        let visual_properties = VisualProperties {
            chart_area: ChartArea {
                bounds: (0.0, 0.0, 100.0, 100.0),
                plot_area: (0.0, 0.0, 100.0, 100.0),
                margins: (0.0, 0.0, 0.0, 0.0),
            },
            color_scheme: ColorScheme {
                primary_colors: vec![],
                palette_type: PaletteType::Categorical,
                accessibility_score: 1.0,
            },
            typography: Typography {
                title_font: None,
                axis_font: None,
                legend_font: None,
                readability_score: 1.0,
            },
            grid: None,
        };

        let analysis = ChartAnalysisResult {
            chart_type: ChartType::Line,
            confidence: 0.9,
            data_points: vec![point(0.0, 10.0), point(1.0, 15.0)],
            elements,
            visual_properties,
        };

        let description = generator.generate(&analysis).unwrap();
        assert!(description.contains("line chart"));
        assert!(description.contains("2 data points"));
    }
}