1use super::{
6 AnalyzedChart, ChartAnalysisConfig, ChartAxes, ChartProcessor, ChartType, DataPoint,
7 Seasonality, TrendAnalysis, TrendDirection,
8};
9use crate::{RragError, RragResult};
10use std::path::Path;
11
12pub struct DefaultChartProcessor {
14 config: ChartAnalysisConfig,
16
17 type_classifier: ChartTypeClassifier,
19
20 data_extractor: ChartDataExtractor,
22
23 trend_analyzer: TrendAnalyzer,
25
26 description_generator: ChartDescriptionGenerator,
28}
29
30pub struct ChartTypeClassifier {
32 models: Vec<ClassificationModel>,
34}
35
/// Extracts data points from chart images.
///
/// The three flags are stored by `ChartDataExtractor::new`; the extraction
/// routines in this file do not consult them.
#[derive(Debug, Clone)]
pub struct ChartDataExtractor {
    /// Whether OCR is enabled.
    ocr_enabled: bool,
    /// Whether colour analysis is enabled.
    color_analysis: bool,
    /// Whether shape detection is enabled.
    shape_detection: bool,
}
47
/// Computes trend statistics (direction, strength, outliers, forecast) for a
/// sequence of data points.
#[derive(Debug, Clone)]
pub struct TrendAnalyzer {
    /// Minimum number of points `analyze` accepts; fewer is a validation error.
    min_points: usize,
    /// Smoothing window size. Stored by `new`; not read by the analysis code
    /// in this file.
    smoothing_window: usize,
    /// When `false`, `analyze` skips seasonality detection and reports `None`.
    seasonality_detection: bool,
}
59
60pub struct ChartDescriptionGenerator {
62 templates: std::collections::HashMap<ChartType, String>,
64
65 nlg_enabled: bool,
67}
68
69#[derive(Debug, Clone)]
71pub struct ClassificationModel {
72 model_type: ModelType,
74
75 confidence_threshold: f32,
77
78 features: Vec<FeatureType>,
80}
81
/// Kinds of classification model a [`ClassificationModel`] can describe.
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared and used as map
/// or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Convolutional neural network.
    CNN,
    /// Support vector machine.
    SVM,
    /// Random forest.
    RandomForest,
    /// Ensemble of models.
    Ensemble,
}
90
/// Feature families a classification model can be associated with.
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared and used as map
/// or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Colour-histogram features.
    ColorHistogram,
    /// Edge-detection features.
    EdgeDetection,
    /// Shape features.
    ShapeFeatures,
    /// Text features.
    TextFeatures,
    /// Layout features.
    LayoutFeatures,
}
100
101#[derive(Debug, Clone)]
103pub struct ChartAnalysisResult {
104 pub chart_type: ChartType,
106
107 pub confidence: f32,
109
110 pub data_points: Vec<DataPoint>,
112
113 pub elements: ChartElements,
115
116 pub visual_properties: VisualProperties,
118}
119
120#[derive(Debug, Clone)]
122pub struct ChartElements {
123 pub title: Option<String>,
125
126 pub axes: ChartAxes,
128
129 pub legend: Vec<LegendEntry>,
131
132 pub series: Vec<DataSeries>,
134
135 pub annotations: Vec<ChartAnnotation>,
137}
138
139#[derive(Debug, Clone)]
141pub struct LegendEntry {
142 pub text: String,
144
145 pub color: Option<(u8, u8, u8)>,
147
148 pub symbol: Option<MarkerType>,
150}
151
152#[derive(Debug, Clone)]
154pub struct DataSeries {
155 pub name: String,
157
158 pub points: Vec<DataPoint>,
160
161 pub color: Option<(u8, u8, u8)>,
163
164 pub line_style: Option<LineStyle>,
166}
167
168#[derive(Debug, Clone)]
170pub struct ChartAnnotation {
171 pub text: String,
173
174 pub position: (f64, f64),
176
177 pub annotation_type: AnnotationType,
179}
180
181#[derive(Debug, Clone)]
183pub struct VisualProperties {
184 pub chart_area: ChartArea,
186
187 pub color_scheme: ColorScheme,
189
190 pub typography: Typography,
192
193 pub grid: Option<GridProperties>,
195}
196
/// Geometry of a chart, expressed as three four-component tuples.
// NOTE(review): the component order of each tuple is not established in this
// file. `bounds`/`plot_area` look like (x, y, width, height) and `margins`
// like one value per side; confirm before relying on either.
#[derive(Debug, Clone)]
pub struct ChartArea {
    /// Outer bounds of the whole chart.
    pub bounds: (f64, f64, f64, f64),
    /// Region in which the data is plotted.
    pub plot_area: (f64, f64, f64, f64),
    /// Margins around the plot area.
    pub margins: (f64, f64, f64, f64),
}
209
210#[derive(Debug, Clone)]
212pub struct ColorScheme {
213 pub primary_colors: Vec<(u8, u8, u8)>,
215
216 pub palette_type: PaletteType,
218
219 pub accessibility_score: f32,
221}
222
223#[derive(Debug, Clone)]
225pub struct Typography {
226 pub title_font: Option<FontInfo>,
228
229 pub axis_font: Option<FontInfo>,
231
232 pub legend_font: Option<FontInfo>,
234
235 pub readability_score: f32,
237}
238
239#[derive(Debug, Clone)]
241pub struct FontInfo {
242 pub family: String,
244
245 pub size: f32,
247
248 pub weight: FontWeight,
250
251 pub color: (u8, u8, u8),
253}
254
255#[derive(Debug, Clone)]
257pub struct GridProperties {
258 pub grid_type: GridType,
260
261 pub color: (u8, u8, u8),
263
264 pub opacity: f32,
266
267 pub line_count: (usize, usize), }
270
/// Marker symbols that can be attached to a legend entry.
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared and used as map
/// or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MarkerType {
    /// Circular marker.
    Circle,
    /// Square marker.
    Square,
    /// Triangular marker.
    Triangle,
    /// Diamond-shaped marker.
    Diamond,
    /// Plus-shaped marker.
    Plus,
    /// Cross-shaped marker.
    Cross,
    /// Star-shaped marker.
    Star,
}
282
/// Line styles a data series can be drawn with.
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared and used as map
/// or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LineStyle {
    /// Continuous line.
    Solid,
    /// Dashed line.
    Dashed,
    /// Dotted line.
    Dotted,
    /// Alternating dashes and dots.
    DashDot,
}
291
/// Kinds of annotation that can appear on a chart.
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared and used as map
/// or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnnotationType {
    /// Plain text label.
    Label,
    /// Arrow.
    Arrow,
    /// Callout.
    Callout,
    /// Highlight.
    Highlight,
}
300
/// Kinds of colour palette a [`ColorScheme`] can form.
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared and used as map
/// or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaletteType {
    /// Sequential palette.
    Sequential,
    /// Diverging palette.
    Diverging,
    /// Categorical palette.
    Categorical,
    /// Monochromatic palette.
    Monochromatic,
}
309
/// Font weights, declared from lightest to heaviest.
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared and used as map
/// or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FontWeight {
    /// Thin weight.
    Thin,
    /// Light weight.
    Light,
    /// Regular weight.
    Regular,
    /// Medium weight.
    Medium,
    /// Bold weight.
    Bold,
    /// Extra-bold weight.
    ExtraBold,
}
320
/// Which grid lines a chart shows.
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared and used as map
/// or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GridType {
    /// Major grid lines only.
    Major,
    /// Minor grid lines only.
    Minor,
    /// Both major and minor grid lines.
    Both,
    /// No grid lines. Always write this as `GridType::None`: the variant
    /// shares its name with `Option::None`.
    None,
}
329
330impl DefaultChartProcessor {
331 pub fn new(config: ChartAnalysisConfig) -> RragResult<Self> {
333 let type_classifier = ChartTypeClassifier::new()?;
334 let data_extractor = ChartDataExtractor::new(true, true, true);
335 let trend_analyzer = TrendAnalyzer::new(5, 3, true);
336 let description_generator = ChartDescriptionGenerator::new();
337
338 Ok(Self {
339 config,
340 type_classifier,
341 data_extractor,
342 trend_analyzer,
343 description_generator,
344 })
345 }
346
347 pub fn analyze_comprehensive(&self, image_path: &Path) -> RragResult<ChartAnalysisResult> {
349 let (chart_type, confidence) = self.type_classifier.classify(image_path)?;
351
352 let data_points = self.data_extractor.extract(image_path, chart_type)?;
354
355 let elements = self.analyze_elements(image_path, chart_type)?;
357
358 let visual_properties = self.analyze_visual_properties(image_path)?;
360
361 Ok(ChartAnalysisResult {
362 chart_type,
363 confidence,
364 data_points,
365 elements,
366 visual_properties,
367 })
368 }
369
370 fn analyze_elements(
372 &self,
373 image_path: &Path,
374 chart_type: ChartType,
375 ) -> RragResult<ChartElements> {
376 let title = self.extract_title(image_path)?;
378
379 let axes = self.extract_axes(image_path)?;
381
382 let legend = self.extract_legend(image_path)?;
384
385 let series = self.extract_series(image_path, chart_type)?;
387
388 let annotations = self.extract_annotations(image_path)?;
390
391 Ok(ChartElements {
392 title,
393 axes,
394 legend,
395 series,
396 annotations,
397 })
398 }
399
400 fn extract_title(&self, _image_path: &Path) -> RragResult<Option<String>> {
402 Ok(Some("Sample Chart Title".to_string()))
404 }
405
406 fn extract_axes(&self, _image_path: &Path) -> RragResult<ChartAxes> {
408 Ok(ChartAxes {
410 x_label: Some("Time".to_string()),
411 y_label: Some("Value".to_string()),
412 x_range: Some((0.0, 100.0)),
413 y_range: Some((0.0, 50.0)),
414 })
415 }
416
417 fn extract_legend(&self, _image_path: &Path) -> RragResult<Vec<LegendEntry>> {
419 Ok(vec![
421 LegendEntry {
422 text: "Series 1".to_string(),
423 color: Some((255, 0, 0)),
424 symbol: Some(MarkerType::Circle),
425 },
426 LegendEntry {
427 text: "Series 2".to_string(),
428 color: Some((0, 255, 0)),
429 symbol: Some(MarkerType::Square),
430 },
431 ])
432 }
433
434 fn extract_series(
436 &self,
437 _image_path: &Path,
438 chart_type: ChartType,
439 ) -> RragResult<Vec<DataSeries>> {
440 match chart_type {
442 ChartType::Line => self.extract_line_series(),
443 ChartType::Bar => self.extract_bar_series(),
444 ChartType::Pie => self.extract_pie_series(),
445 ChartType::Scatter => self.extract_scatter_series(),
446 _ => Ok(vec![]),
447 }
448 }
449
450 fn extract_line_series(&self) -> RragResult<Vec<DataSeries>> {
452 Ok(vec![DataSeries {
453 name: "Series 1".to_string(),
454 points: vec![
455 DataPoint {
456 x: 0.0,
457 y: 10.0,
458 label: None,
459 series: Some("Series 1".to_string()),
460 },
461 DataPoint {
462 x: 1.0,
463 y: 15.0,
464 label: None,
465 series: Some("Series 1".to_string()),
466 },
467 DataPoint {
468 x: 2.0,
469 y: 12.0,
470 label: None,
471 series: Some("Series 1".to_string()),
472 },
473 ],
474 color: Some((255, 0, 0)),
475 line_style: Some(LineStyle::Solid),
476 }])
477 }
478
479 fn extract_bar_series(&self) -> RragResult<Vec<DataSeries>> {
481 Ok(vec![DataSeries {
482 name: "Categories".to_string(),
483 points: vec![
484 DataPoint {
485 x: 0.0,
486 y: 20.0,
487 label: Some("Category A".to_string()),
488 series: None,
489 },
490 DataPoint {
491 x: 1.0,
492 y: 35.0,
493 label: Some("Category B".to_string()),
494 series: None,
495 },
496 DataPoint {
497 x: 2.0,
498 y: 25.0,
499 label: Some("Category C".to_string()),
500 series: None,
501 },
502 ],
503 color: Some((0, 100, 200)),
504 line_style: None,
505 }])
506 }
507
508 fn extract_pie_series(&self) -> RragResult<Vec<DataSeries>> {
510 Ok(vec![DataSeries {
511 name: "Pie Slices".to_string(),
512 points: vec![
513 DataPoint {
514 x: 0.0,
515 y: 40.0,
516 label: Some("Slice A".to_string()),
517 series: None,
518 },
519 DataPoint {
520 x: 1.0,
521 y: 30.0,
522 label: Some("Slice B".to_string()),
523 series: None,
524 },
525 DataPoint {
526 x: 2.0,
527 y: 30.0,
528 label: Some("Slice C".to_string()),
529 series: None,
530 },
531 ],
532 color: None,
533 line_style: None,
534 }])
535 }
536
537 fn extract_scatter_series(&self) -> RragResult<Vec<DataSeries>> {
539 Ok(vec![DataSeries {
540 name: "Scatter Points".to_string(),
541 points: vec![
542 DataPoint {
543 x: 5.0,
544 y: 10.0,
545 label: None,
546 series: None,
547 },
548 DataPoint {
549 x: 15.0,
550 y: 25.0,
551 label: None,
552 series: None,
553 },
554 DataPoint {
555 x: 25.0,
556 y: 20.0,
557 label: None,
558 series: None,
559 },
560 ],
561 color: Some((100, 100, 100)),
562 line_style: None,
563 }])
564 }
565
566 fn extract_annotations(&self, _image_path: &Path) -> RragResult<Vec<ChartAnnotation>> {
568 Ok(vec![])
570 }
571
572 fn analyze_visual_properties(&self, _image_path: &Path) -> RragResult<VisualProperties> {
574 Ok(VisualProperties {
576 chart_area: ChartArea {
577 bounds: (0.0, 0.0, 800.0, 600.0),
578 plot_area: (100.0, 100.0, 600.0, 400.0),
579 margins: (50.0, 50.0, 50.0, 100.0),
580 },
581 color_scheme: ColorScheme {
582 primary_colors: vec![(255, 0, 0), (0, 255, 0), (0, 0, 255)],
583 palette_type: PaletteType::Categorical,
584 accessibility_score: 0.8,
585 },
586 typography: Typography {
587 title_font: Some(FontInfo {
588 family: "Arial".to_string(),
589 size: 16.0,
590 weight: FontWeight::Bold,
591 color: (0, 0, 0),
592 }),
593 axis_font: Some(FontInfo {
594 family: "Arial".to_string(),
595 size: 12.0,
596 weight: FontWeight::Regular,
597 color: (100, 100, 100),
598 }),
599 legend_font: Some(FontInfo {
600 family: "Arial".to_string(),
601 size: 10.0,
602 weight: FontWeight::Regular,
603 color: (0, 0, 0),
604 }),
605 readability_score: 0.9,
606 },
607 grid: Some(GridProperties {
608 grid_type: GridType::Major,
609 color: (200, 200, 200),
610 opacity: 0.3,
611 line_count: (5, 10),
612 }),
613 })
614 }
615}
616
617impl ChartProcessor for DefaultChartProcessor {
618 fn analyze_chart(&self, image_path: &Path) -> RragResult<AnalyzedChart> {
619 let analysis = self.analyze_comprehensive(image_path)?;
620
621 let description = if self.config.generate_descriptions {
623 Some(self.description_generator.generate(&analysis)?)
624 } else {
625 None
626 };
627
628 let trends = if self.config.analyze_trends && !analysis.data_points.is_empty() {
630 Some(self.trend_analyzer.analyze(&analysis.data_points)?)
631 } else {
632 None
633 };
634
635 Ok(AnalyzedChart {
636 id: format!(
637 "chart_{}",
638 uuid::Uuid::new_v4().to_string().split('-').next().unwrap()
639 ),
640 chart_type: analysis.chart_type,
641 title: analysis.elements.title,
642 axes: analysis.elements.axes,
643 data_points: analysis.data_points,
644 trends,
645 description,
646 embedding: None, })
648 }
649
650 fn extract_data_points(&self, chart_image: &Path) -> RragResult<Vec<DataPoint>> {
651 let analysis = self.analyze_comprehensive(chart_image)?;
652 Ok(analysis.data_points)
653 }
654
655 fn identify_type(&self, chart_image: &Path) -> RragResult<ChartType> {
656 let (chart_type, _confidence) = self.type_classifier.classify(chart_image)?;
657 Ok(chart_type)
658 }
659
660 fn analyze_trends(&self, data_points: &[DataPoint]) -> RragResult<TrendAnalysis> {
661 self.trend_analyzer.analyze(data_points)
662 }
663}
664
665impl ChartTypeClassifier {
666 pub fn new() -> RragResult<Self> {
668 let models = vec![
669 ClassificationModel {
670 model_type: ModelType::CNN,
671 confidence_threshold: 0.8,
672 features: vec![
673 FeatureType::ColorHistogram,
674 FeatureType::EdgeDetection,
675 FeatureType::ShapeFeatures,
676 ],
677 },
678 ClassificationModel {
679 model_type: ModelType::SVM,
680 confidence_threshold: 0.7,
681 features: vec![FeatureType::LayoutFeatures, FeatureType::TextFeatures],
682 },
683 ];
684
685 Ok(Self { models })
686 }
687
688 pub fn classify(&self, image_path: &Path) -> RragResult<(ChartType, f32)> {
690 let filename = image_path
692 .file_name()
693 .and_then(|name| name.to_str())
694 .unwrap_or("");
695
696 let (chart_type, confidence) = if filename.contains("line") {
698 (ChartType::Line, 0.95)
699 } else if filename.contains("bar") {
700 (ChartType::Bar, 0.90)
701 } else if filename.contains("pie") {
702 (ChartType::Pie, 0.85)
703 } else if filename.contains("scatter") {
704 (ChartType::Scatter, 0.80)
705 } else {
706 (ChartType::Unknown, 0.5)
707 };
708
709 Ok((chart_type, confidence))
710 }
711
712 pub fn extract_features(&self, _image_path: &Path) -> RragResult<Vec<f32>> {
714 let mut features = Vec::new();
716
717 features.extend(vec![0.1, 0.2, 0.3, 0.4]); features.extend(vec![0.5, 0.6]); features.extend(vec![0.7, 0.8, 0.9]); features.extend(vec![0.2, 0.4]); features.push(0.3); Ok(features)
733 }
734}
735
736impl ChartDataExtractor {
737 pub fn new(ocr_enabled: bool, color_analysis: bool, shape_detection: bool) -> Self {
739 Self {
740 ocr_enabled,
741 color_analysis,
742 shape_detection,
743 }
744 }
745
746 pub fn extract(&self, image_path: &Path, chart_type: ChartType) -> RragResult<Vec<DataPoint>> {
748 match chart_type {
749 ChartType::Line => self.extract_line_data(image_path),
750 ChartType::Bar => self.extract_bar_data(image_path),
751 ChartType::Pie => self.extract_pie_data(image_path),
752 ChartType::Scatter => self.extract_scatter_data(image_path),
753 ChartType::Area => self.extract_area_data(image_path),
754 ChartType::Histogram => self.extract_histogram_data(image_path),
755 _ => Ok(vec![]),
756 }
757 }
758
759 fn extract_line_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
761 Ok(vec![
763 DataPoint {
764 x: 0.0,
765 y: 10.0,
766 label: None,
767 series: Some("Line 1".to_string()),
768 },
769 DataPoint {
770 x: 1.0,
771 y: 15.0,
772 label: None,
773 series: Some("Line 1".to_string()),
774 },
775 DataPoint {
776 x: 2.0,
777 y: 12.0,
778 label: None,
779 series: Some("Line 1".to_string()),
780 },
781 DataPoint {
782 x: 3.0,
783 y: 18.0,
784 label: None,
785 series: Some("Line 1".to_string()),
786 },
787 ])
788 }
789
790 fn extract_bar_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
792 Ok(vec![
794 DataPoint {
795 x: 0.0,
796 y: 25.0,
797 label: Some("Q1".to_string()),
798 series: None,
799 },
800 DataPoint {
801 x: 1.0,
802 y: 30.0,
803 label: Some("Q2".to_string()),
804 series: None,
805 },
806 DataPoint {
807 x: 2.0,
808 y: 35.0,
809 label: Some("Q3".to_string()),
810 series: None,
811 },
812 DataPoint {
813 x: 3.0,
814 y: 40.0,
815 label: Some("Q4".to_string()),
816 series: None,
817 },
818 ])
819 }
820
821 fn extract_pie_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
823 Ok(vec![
825 DataPoint {
826 x: 0.0,
827 y: 40.0,
828 label: Some("Category A".to_string()),
829 series: None,
830 },
831 DataPoint {
832 x: 1.0,
833 y: 30.0,
834 label: Some("Category B".to_string()),
835 series: None,
836 },
837 DataPoint {
838 x: 2.0,
839 y: 20.0,
840 label: Some("Category C".to_string()),
841 series: None,
842 },
843 DataPoint {
844 x: 3.0,
845 y: 10.0,
846 label: Some("Category D".to_string()),
847 series: None,
848 },
849 ])
850 }
851
852 fn extract_scatter_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
854 Ok(vec![
856 DataPoint {
857 x: 5.0,
858 y: 10.0,
859 label: None,
860 series: None,
861 },
862 DataPoint {
863 x: 15.0,
864 y: 25.0,
865 label: None,
866 series: None,
867 },
868 DataPoint {
869 x: 25.0,
870 y: 20.0,
871 label: None,
872 series: None,
873 },
874 DataPoint {
875 x: 35.0,
876 y: 40.0,
877 label: None,
878 series: None,
879 },
880 ])
881 }
882
883 fn extract_area_data(&self, image_path: &Path) -> RragResult<Vec<DataPoint>> {
885 self.extract_line_data(image_path)
887 }
888
889 fn extract_histogram_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
891 Ok(vec![
893 DataPoint {
894 x: 0.0,
895 y: 5.0,
896 label: Some("0-10".to_string()),
897 series: None,
898 },
899 DataPoint {
900 x: 1.0,
901 y: 15.0,
902 label: Some("10-20".to_string()),
903 series: None,
904 },
905 DataPoint {
906 x: 2.0,
907 y: 25.0,
908 label: Some("20-30".to_string()),
909 series: None,
910 },
911 DataPoint {
912 x: 3.0,
913 y: 10.0,
914 label: Some("30-40".to_string()),
915 series: None,
916 },
917 ])
918 }
919}
920
921impl TrendAnalyzer {
922 pub fn new(min_points: usize, smoothing_window: usize, seasonality_detection: bool) -> Self {
924 Self {
925 min_points,
926 smoothing_window,
927 seasonality_detection,
928 }
929 }
930
931 pub fn analyze(&self, data_points: &[DataPoint]) -> RragResult<TrendAnalysis> {
933 if data_points.len() < self.min_points {
934 return Err(RragError::validation(
935 "data_points",
936 format!("minimum {} points", self.min_points),
937 format!("{} points", data_points.len()),
938 ));
939 }
940
941 let direction = self.calculate_trend_direction(data_points);
943
944 let strength = self.calculate_trend_strength(data_points);
946
947 let seasonality = if self.seasonality_detection {
949 self.detect_seasonality(data_points)
950 } else {
951 None
952 };
953
954 let outliers = self.detect_outliers(data_points);
956
957 let forecast = if data_points.len() >= 10 {
959 Some(self.generate_forecast(data_points, 5)?)
960 } else {
961 None
962 };
963
964 Ok(TrendAnalysis {
965 direction,
966 strength,
967 seasonality,
968 outliers,
969 forecast,
970 })
971 }
972
973 fn calculate_trend_direction(&self, data_points: &[DataPoint]) -> TrendDirection {
975 if data_points.len() < 2 {
976 return TrendDirection::Stable;
977 }
978
979 let first_y = data_points[0].y;
980 let last_y = data_points[data_points.len() - 1].y;
981 let change = last_y - first_y;
982
983 let volatility = self.calculate_volatility(data_points);
985
986 if change.abs() < volatility * 0.5 {
987 TrendDirection::Stable
988 } else if volatility > change.abs() * 2.0 {
989 TrendDirection::Volatile
990 } else if change > 0.0 {
991 TrendDirection::Increasing
992 } else {
993 TrendDirection::Decreasing
994 }
995 }
996
997 fn calculate_trend_strength(&self, data_points: &[DataPoint]) -> f32 {
999 if data_points.len() < 2 {
1000 return 0.0;
1001 }
1002
1003 let n = data_points.len() as f64;
1005 let sum_x: f64 = data_points.iter().map(|p| p.x).sum();
1006 let sum_y: f64 = data_points.iter().map(|p| p.y as f64).sum();
1007 let sum_xy: f64 = data_points.iter().map(|p| p.x * p.y as f64).sum();
1008 let sum_x2: f64 = data_points.iter().map(|p| p.x * p.x).sum();
1009 let sum_y2: f64 = data_points
1010 .iter()
1011 .map(|p| (p.y as f64) * (p.y as f64))
1012 .sum();
1013
1014 let numerator = n * sum_xy - sum_x * sum_y;
1015 let denominator = ((n * sum_x2 - sum_x * sum_x) * (n * sum_y2 - sum_y * sum_y)).sqrt();
1016
1017 if denominator == 0.0 {
1018 return 0.0;
1019 }
1020
1021 let r = numerator / denominator;
1022 (r * r) as f32 }
1024
1025 fn calculate_volatility(&self, data_points: &[DataPoint]) -> f64 {
1027 if data_points.len() < 2 {
1028 return 0.0;
1029 }
1030
1031 let values: Vec<f64> = data_points.iter().map(|p| p.y as f64).collect();
1032 let mean = values.iter().sum::<f64>() / values.len() as f64;
1033 let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
1034 variance.sqrt()
1035 }
1036
1037 fn detect_seasonality(&self, data_points: &[DataPoint]) -> Option<Seasonality> {
1039 if data_points.len() < 12 {
1040 return None; }
1042
1043 Some(Seasonality {
1047 period: 12.0, amplitude: 5.0,
1049 phase: 0.0,
1050 })
1051 }
1052
1053 fn detect_outliers(&self, data_points: &[DataPoint]) -> Vec<DataPoint> {
1055 if data_points.len() < 4 {
1056 return vec![];
1057 }
1058
1059 let mut y_values: Vec<f32> = data_points.iter().map(|p| p.y as f32).collect();
1060 y_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1061
1062 let q1_idx = y_values.len() / 4;
1063 let q3_idx = 3 * y_values.len() / 4;
1064 let q1 = y_values[q1_idx];
1065 let q3 = y_values[q3_idx];
1066 let iqr = q3 - q1;
1067
1068 let lower_bound = q1 - 1.5 * iqr;
1069 let upper_bound = q3 + 1.5 * iqr;
1070
1071 data_points
1072 .iter()
1073 .filter(|p| (p.y as f32) < lower_bound || (p.y as f32) > upper_bound)
1074 .cloned()
1075 .collect()
1076 }
1077
1078 fn generate_forecast(
1080 &self,
1081 data_points: &[DataPoint],
1082 num_points: usize,
1083 ) -> RragResult<Vec<DataPoint>> {
1084 if data_points.len() < 2 {
1085 return Ok(vec![]);
1086 }
1087
1088 let n = data_points.len() as f64;
1090 let sum_x: f64 = data_points.iter().map(|p| p.x).sum();
1091 let sum_y: f64 = data_points.iter().map(|p| p.y as f64).sum();
1092 let sum_xy: f64 = data_points.iter().map(|p| p.x * p.y as f64).sum();
1093 let sum_x2: f64 = data_points.iter().map(|p| p.x * p.x).sum();
1094
1095 let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
1096 let intercept = (sum_y - slope * sum_x) / n;
1097
1098 let last_x = data_points.last().unwrap().x;
1100 let mut forecast = Vec::new();
1101
1102 for i in 1..=num_points {
1103 let x = last_x + i as f64;
1104 let y = (slope * x + intercept) as f32;
1105
1106 forecast.push(DataPoint {
1107 x,
1108 y: y as f64,
1109 label: Some(format!("Forecast {}", i)),
1110 series: Some("Forecast".to_string()),
1111 });
1112 }
1113
1114 Ok(forecast)
1115 }
1116}
1117
1118impl ChartDescriptionGenerator {
1119 pub fn new() -> Self {
1121 let mut templates = std::collections::HashMap::new();
1122
1123 templates.insert(
1124 ChartType::Line,
1125 "This line chart shows {data_description}. The trend is {trend_direction} with a strength of {trend_strength:.2}.".to_string()
1126 );
1127
1128 templates.insert(
1129 ChartType::Bar,
1130 "This bar chart displays {data_description}. The highest value is {max_value} and the lowest is {min_value}.".to_string()
1131 );
1132
1133 templates.insert(
1134 ChartType::Pie,
1135 "This pie chart represents {data_description}. The largest segment is {largest_segment} at {largest_percentage:.1}%.".to_string()
1136 );
1137
1138 Self {
1139 templates,
1140 nlg_enabled: false,
1141 }
1142 }
1143
1144 pub fn generate(&self, analysis: &ChartAnalysisResult) -> RragResult<String> {
1146 if let Some(template) = self.templates.get(&analysis.chart_type) {
1147 let description = self.fill_template(template, analysis)?;
1148 Ok(description)
1149 } else {
1150 Ok(format!(
1151 "Chart of type {:?} with {} data points",
1152 analysis.chart_type,
1153 analysis.data_points.len()
1154 ))
1155 }
1156 }
1157
1158 fn fill_template(&self, template: &str, analysis: &ChartAnalysisResult) -> RragResult<String> {
1160 let mut description = template.to_string();
1161
1162 description = description.replace(
1164 "{data_description}",
1165 &format!("{} data points", analysis.data_points.len()),
1166 );
1167
1168 if !analysis.data_points.is_empty() {
1169 let max_y = analysis
1170 .data_points
1171 .iter()
1172 .map(|p| p.y as f32)
1173 .fold(f32::NEG_INFINITY, |a, b| a.max(b));
1174 let min_y = analysis
1175 .data_points
1176 .iter()
1177 .map(|p| p.y as f32)
1178 .fold(f32::INFINITY, |a, b| a.min(b));
1179
1180 description = description.replace("{max_value}", &max_y.to_string());
1181 description = description.replace("{min_value}", &min_y.to_string());
1182 }
1183
1184 Ok(description)
1185 }
1186}
1187
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Builds an unlabelled point that belongs to no series.
    fn point(x: f64, y: f64) -> DataPoint {
        DataPoint {
            x,
            y,
            label: None,
            series: None,
        }
    }

    #[test]
    fn test_chart_processor_creation() {
        let config = ChartAnalysisConfig::default();
        let processor = DefaultChartProcessor::new(config).unwrap();

        assert!(processor.config.extract_data);
        assert!(processor.config.generate_descriptions);
        assert!(processor.config.analyze_trends);
    }

    #[test]
    fn test_chart_type_classification() {
        let classifier = ChartTypeClassifier::new().unwrap();

        // Classification only looks at the file name, so no file is needed.
        let (chart_type, confidence) = classifier.classify(Path::new("line_chart.png")).unwrap();
        assert_eq!(chart_type, ChartType::Line);
        assert!(confidence > 0.9);
    }

    #[test]
    fn test_trend_analysis() {
        let analyzer = TrendAnalyzer::new(3, 2, false);
        let data_points = vec![
            point(0.0, 10.0),
            point(1.0, 15.0),
            point(2.0, 20.0),
            point(3.0, 25.0),
        ];

        let trend = analyzer.analyze(&data_points).unwrap();
        assert_eq!(trend.direction, TrendDirection::Increasing);
        assert!(trend.strength > 0.8);
    }

    #[test]
    fn test_outlier_detection() {
        let analyzer = TrendAnalyzer::new(3, 2, false);
        // The third point is the outlier.
        let data_points = vec![
            point(0.0, 10.0),
            point(1.0, 12.0),
            point(2.0, 100.0),
            point(3.0, 11.0),
        ];

        let outliers = analyzer.detect_outliers(&data_points);
        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0].y, 100.0);
    }

    #[test]
    fn test_data_extraction() {
        let extractor = ChartDataExtractor::new(true, true, true);

        // The extractor returns placeholder data and never opens the path.
        let data_points = extractor
            .extract(Path::new("chart.png"), ChartType::Line)
            .unwrap();

        assert!(!data_points.is_empty());
    }

    #[test]
    fn test_description_generation() {
        let generator = ChartDescriptionGenerator::new();

        let full_canvas = (0.0, 0.0, 100.0, 100.0);
        let analysis = ChartAnalysisResult {
            chart_type: ChartType::Line,
            confidence: 0.9,
            data_points: vec![point(0.0, 10.0), point(1.0, 15.0)],
            elements: ChartElements {
                title: None,
                axes: ChartAxes {
                    x_label: None,
                    y_label: None,
                    x_range: None,
                    y_range: None,
                },
                legend: vec![],
                series: vec![],
                annotations: vec![],
            },
            visual_properties: VisualProperties {
                chart_area: ChartArea {
                    bounds: full_canvas,
                    plot_area: full_canvas,
                    margins: (0.0, 0.0, 0.0, 0.0),
                },
                color_scheme: ColorScheme {
                    primary_colors: vec![],
                    palette_type: PaletteType::Categorical,
                    accessibility_score: 1.0,
                },
                typography: Typography {
                    title_font: None,
                    axis_font: None,
                    legend_font: None,
                    readability_score: 1.0,
                },
                grid: None,
            },
        };

        let description = generator.generate(&analysis).unwrap();
        assert!(description.contains("line chart"));
        assert!(description.contains("2 data points"));
    }
}