use super::{
AnalyzedChart, ChartAnalysisConfig, ChartAxes, ChartProcessor, ChartType, DataPoint,
Seasonality, TrendAnalysis, TrendDirection,
};
use crate::{RragError, RragResult};
use std::path::Path;
/// Default [`ChartProcessor`] implementation that wires together chart-type classification,
/// data extraction, trend analysis and description generation.
pub struct DefaultChartProcessor {
    /// Caller-supplied analysis options (read for `generate_descriptions` and `analyze_trends`).
    config: ChartAnalysisConfig,
    /// Decides which [`ChartType`] an image contains.
    type_classifier: ChartTypeClassifier,
    /// Extracts data points from a chart image for a given chart type.
    data_extractor: ChartDataExtractor,
    /// Computes a [`TrendAnalysis`] over extracted data points.
    trend_analyzer: TrendAnalyzer,
    /// Renders a textual summary of an analysis result.
    description_generator: ChartDescriptionGenerator,
}
/// Chart-type classifier holding a list of classification model descriptors.
///
/// NOTE(review): `classify` in this file does not consult `models`; it is a filename heuristic.
pub struct ChartTypeClassifier {
    /// Model descriptors configured by `ChartTypeClassifier::new`.
    models: Vec<ClassificationModel>,
}
/// Extracts data points from chart images.
///
/// NOTE(review): none of these flags is read by the extraction methods in this file.
pub struct ChartDataExtractor {
    /// Whether OCR-based extraction is enabled.
    ocr_enabled: bool,
    /// Whether colour analysis is enabled.
    color_analysis: bool,
    /// Whether shape detection is enabled.
    shape_detection: bool,
}
/// Computes trend direction, strength, seasonality, outliers and forecasts over data points.
pub struct TrendAnalyzer {
    /// Minimum number of points `TrendAnalyzer::analyze` accepts; fewer yields a validation error.
    min_points: usize,
    /// Smoothing window size; not read by any analysis step in this file.
    smoothing_window: usize,
    /// Whether `analyze` attempts seasonality detection.
    seasonality_detection: bool,
}
/// Produces natural-language descriptions of a chart analysis from per-type templates.
pub struct ChartDescriptionGenerator {
    /// Description templates keyed by chart type, containing `{placeholder}` tokens.
    templates: std::collections::HashMap<ChartType, String>,
    /// NLG toggle; set to `false` by `new` and not read elsewhere in this file.
    nlg_enabled: bool,
}
/// Descriptor for one chart-type classification model.
#[derive(Debug, Clone)]
pub struct ClassificationModel {
    /// Kind of model.
    model_type: ModelType,
    /// Confidence threshold associated with this model (0.8 / 0.7 for the defaults).
    confidence_threshold: f32,
    /// Feature families the model consumes.
    features: Vec<FeatureType>,
}
/// Kinds of classification model a [`ClassificationModel`] can describe.
#[derive(Debug, Clone, Copy)]
pub enum ModelType {
    CNN,
    SVM,
    RandomForest,
    Ensemble,
}
/// Feature families a classification model can use.
#[derive(Debug, Clone, Copy)]
pub enum FeatureType {
    ColorHistogram,
    EdgeDetection,
    ShapeFeatures,
    TextFeatures,
    LayoutFeatures,
}
/// Full output of `DefaultChartProcessor::analyze_comprehensive`.
#[derive(Debug, Clone)]
pub struct ChartAnalysisResult {
    /// Chart type reported by the classifier.
    pub chart_type: ChartType,
    /// Classifier confidence for `chart_type`.
    pub confidence: f32,
    /// Data points extracted for `chart_type`.
    pub data_points: Vec<DataPoint>,
    /// Structural elements (title, axes, legend, series, annotations).
    pub elements: ChartElements,
    /// Layout, colour, typography and grid properties.
    pub visual_properties: VisualProperties,
}
/// Structural elements detected in a chart.
#[derive(Debug, Clone)]
pub struct ChartElements {
    /// Chart title, if one was found.
    pub title: Option<String>,
    /// Axis labels and ranges.
    pub axes: ChartAxes,
    /// Legend entries.
    pub legend: Vec<LegendEntry>,
    /// Data series, grouped by series.
    pub series: Vec<DataSeries>,
    /// Free-form annotations placed on the chart.
    pub annotations: Vec<ChartAnnotation>,
}
/// One legend entry.
#[derive(Debug, Clone)]
pub struct LegendEntry {
    /// Legend text.
    pub text: String,
    /// Entry colour, presumably as an RGB triple — TODO confirm.
    pub color: Option<(u8, u8, u8)>,
    /// Marker symbol drawn next to the text, if any.
    pub symbol: Option<MarkerType>,
}
/// A named series of data points.
#[derive(Debug, Clone)]
pub struct DataSeries {
    /// Series name.
    pub name: String,
    /// Points belonging to this series.
    pub points: Vec<DataPoint>,
    /// Series colour, presumably RGB — TODO confirm.
    pub color: Option<(u8, u8, u8)>,
    /// Line style for line-like series; `None` for bars, slices and scatter points here.
    pub line_style: Option<LineStyle>,
}
/// A textual annotation placed on the chart.
#[derive(Debug, Clone)]
pub struct ChartAnnotation {
    /// Annotation text.
    pub text: String,
    /// Position, presumably (x, y) in chart coordinates — TODO confirm.
    pub position: (f64, f64),
    /// Kind of annotation.
    pub annotation_type: AnnotationType,
}
/// Visual (non-data) properties of a chart.
#[derive(Debug, Clone)]
pub struct VisualProperties {
    /// Overall and plot area geometry.
    pub chart_area: ChartArea,
    /// Colours used by the chart.
    pub color_scheme: ColorScheme,
    /// Fonts used for title, axes and legend.
    pub typography: Typography,
    /// Grid lines, if the chart has any.
    pub grid: Option<GridProperties>,
}
/// Geometry of the chart image.
#[derive(Debug, Clone)]
pub struct ChartArea {
    /// Whole-chart bounds; tuple order not established here (presumably x, y, width, height) — TODO confirm.
    pub bounds: (f64, f64, f64, f64),
    /// Plot-area rectangle; same (unconfirmed) tuple layout as `bounds`.
    pub plot_area: (f64, f64, f64, f64),
    /// Margins; side order not established here (presumably top, right, bottom, left) — TODO confirm.
    pub margins: (f64, f64, f64, f64),
}
/// Colour palette information.
#[derive(Debug, Clone)]
pub struct ColorScheme {
    /// Main colours, presumably RGB triples — TODO confirm.
    pub primary_colors: Vec<(u8, u8, u8)>,
    /// Palette category.
    pub palette_type: PaletteType,
    /// Accessibility score; range is not enforced here.
    pub accessibility_score: f32,
}
/// Font usage across chart text.
#[derive(Debug, Clone)]
pub struct Typography {
    /// Title font, if detected.
    pub title_font: Option<FontInfo>,
    /// Axis label font, if detected.
    pub axis_font: Option<FontInfo>,
    /// Legend font, if detected.
    pub legend_font: Option<FontInfo>,
    /// Readability score; range is not enforced here.
    pub readability_score: f32,
}
/// Description of one font.
#[derive(Debug, Clone)]
pub struct FontInfo {
    /// Font family name.
    pub family: String,
    /// Font size (unit not established here — presumably points).
    pub size: f32,
    /// Font weight.
    pub weight: FontWeight,
    /// Text colour, presumably RGB — TODO confirm.
    pub color: (u8, u8, u8),
}
/// Grid line properties.
#[derive(Debug, Clone)]
pub struct GridProperties {
    /// Which grid lines are drawn.
    pub grid_type: GridType,
    /// Grid line colour, presumably RGB — TODO confirm.
    pub color: (u8, u8, u8),
    /// Grid line opacity, presumably 0.0–1.0 — TODO confirm.
    pub opacity: f32,
    /// Number of grid lines; axis order not established here (presumably horizontal, vertical).
    pub line_count: (usize, usize),
}
/// Symbol used for a data marker or legend entry.
#[derive(Debug, Clone, Copy)]
pub enum MarkerType {
    Circle,
    Square,
    Triangle,
    Diamond,
    Plus,
    Cross,
    Star,
}
/// Stroke pattern of a line series.
#[derive(Debug, Clone, Copy)]
pub enum LineStyle {
    Solid,
    Dashed,
    Dotted,
    DashDot,
}
/// Kind of chart annotation.
#[derive(Debug, Clone, Copy)]
pub enum AnnotationType {
    Label,
    Arrow,
    Callout,
    Highlight,
}
/// Category of colour palette.
#[derive(Debug, Clone, Copy)]
pub enum PaletteType {
    Sequential,
    Diverging,
    Categorical,
    Monochromatic,
}
/// Font weight, from lightest to heaviest.
#[derive(Debug, Clone, Copy)]
pub enum FontWeight {
    Thin,
    Light,
    Regular,
    Medium,
    Bold,
    ExtraBold,
}
/// Which grid lines a chart draws.
#[derive(Debug, Clone, Copy)]
pub enum GridType {
    Major,
    Minor,
    Both,
    None,
}
impl DefaultChartProcessor {
pub fn new(config: ChartAnalysisConfig) -> RragResult<Self> {
let type_classifier = ChartTypeClassifier::new()?;
let data_extractor = ChartDataExtractor::new(true, true, true);
let trend_analyzer = TrendAnalyzer::new(5, 3, true);
let description_generator = ChartDescriptionGenerator::new();
Ok(Self {
config,
type_classifier,
data_extractor,
trend_analyzer,
description_generator,
})
}
pub fn analyze_comprehensive(&self, image_path: &Path) -> RragResult<ChartAnalysisResult> {
let (chart_type, confidence) = self.type_classifier.classify(image_path)?;
let data_points = self.data_extractor.extract(image_path, chart_type)?;
let elements = self.analyze_elements(image_path, chart_type)?;
let visual_properties = self.analyze_visual_properties(image_path)?;
Ok(ChartAnalysisResult {
chart_type,
confidence,
data_points,
elements,
visual_properties,
})
}
fn analyze_elements(
&self,
image_path: &Path,
chart_type: ChartType,
) -> RragResult<ChartElements> {
let title = self.extract_title(image_path)?;
let axes = self.extract_axes(image_path)?;
let legend = self.extract_legend(image_path)?;
let series = self.extract_series(image_path, chart_type)?;
let annotations = self.extract_annotations(image_path)?;
Ok(ChartElements {
title,
axes,
legend,
series,
annotations,
})
}
fn extract_title(&self, _image_path: &Path) -> RragResult<Option<String>> {
Ok(Some("Sample Chart Title".to_string()))
}
fn extract_axes(&self, _image_path: &Path) -> RragResult<ChartAxes> {
Ok(ChartAxes {
x_label: Some("Time".to_string()),
y_label: Some("Value".to_string()),
x_range: Some((0.0, 100.0)),
y_range: Some((0.0, 50.0)),
})
}
fn extract_legend(&self, _image_path: &Path) -> RragResult<Vec<LegendEntry>> {
Ok(vec![
LegendEntry {
text: "Series 1".to_string(),
color: Some((255, 0, 0)),
symbol: Some(MarkerType::Circle),
},
LegendEntry {
text: "Series 2".to_string(),
color: Some((0, 255, 0)),
symbol: Some(MarkerType::Square),
},
])
}
fn extract_series(
&self,
_image_path: &Path,
chart_type: ChartType,
) -> RragResult<Vec<DataSeries>> {
match chart_type {
ChartType::Line => self.extract_line_series(),
ChartType::Bar => self.extract_bar_series(),
ChartType::Pie => self.extract_pie_series(),
ChartType::Scatter => self.extract_scatter_series(),
_ => Ok(vec![]),
}
}
fn extract_line_series(&self) -> RragResult<Vec<DataSeries>> {
Ok(vec![DataSeries {
name: "Series 1".to_string(),
points: vec![
DataPoint {
x: 0.0,
y: 10.0,
label: None,
series: Some("Series 1".to_string()),
},
DataPoint {
x: 1.0,
y: 15.0,
label: None,
series: Some("Series 1".to_string()),
},
DataPoint {
x: 2.0,
y: 12.0,
label: None,
series: Some("Series 1".to_string()),
},
],
color: Some((255, 0, 0)),
line_style: Some(LineStyle::Solid),
}])
}
fn extract_bar_series(&self) -> RragResult<Vec<DataSeries>> {
Ok(vec![DataSeries {
name: "Categories".to_string(),
points: vec![
DataPoint {
x: 0.0,
y: 20.0,
label: Some("Category A".to_string()),
series: None,
},
DataPoint {
x: 1.0,
y: 35.0,
label: Some("Category B".to_string()),
series: None,
},
DataPoint {
x: 2.0,
y: 25.0,
label: Some("Category C".to_string()),
series: None,
},
],
color: Some((0, 100, 200)),
line_style: None,
}])
}
fn extract_pie_series(&self) -> RragResult<Vec<DataSeries>> {
Ok(vec![DataSeries {
name: "Pie Slices".to_string(),
points: vec![
DataPoint {
x: 0.0,
y: 40.0,
label: Some("Slice A".to_string()),
series: None,
},
DataPoint {
x: 1.0,
y: 30.0,
label: Some("Slice B".to_string()),
series: None,
},
DataPoint {
x: 2.0,
y: 30.0,
label: Some("Slice C".to_string()),
series: None,
},
],
color: None,
line_style: None,
}])
}
fn extract_scatter_series(&self) -> RragResult<Vec<DataSeries>> {
Ok(vec![DataSeries {
name: "Scatter Points".to_string(),
points: vec![
DataPoint {
x: 5.0,
y: 10.0,
label: None,
series: None,
},
DataPoint {
x: 15.0,
y: 25.0,
label: None,
series: None,
},
DataPoint {
x: 25.0,
y: 20.0,
label: None,
series: None,
},
],
color: Some((100, 100, 100)),
line_style: None,
}])
}
fn extract_annotations(&self, _image_path: &Path) -> RragResult<Vec<ChartAnnotation>> {
Ok(vec![])
}
fn analyze_visual_properties(&self, _image_path: &Path) -> RragResult<VisualProperties> {
Ok(VisualProperties {
chart_area: ChartArea {
bounds: (0.0, 0.0, 800.0, 600.0),
plot_area: (100.0, 100.0, 600.0, 400.0),
margins: (50.0, 50.0, 50.0, 100.0),
},
color_scheme: ColorScheme {
primary_colors: vec![(255, 0, 0), (0, 255, 0), (0, 0, 255)],
palette_type: PaletteType::Categorical,
accessibility_score: 0.8,
},
typography: Typography {
title_font: Some(FontInfo {
family: "Arial".to_string(),
size: 16.0,
weight: FontWeight::Bold,
color: (0, 0, 0),
}),
axis_font: Some(FontInfo {
family: "Arial".to_string(),
size: 12.0,
weight: FontWeight::Regular,
color: (100, 100, 100),
}),
legend_font: Some(FontInfo {
family: "Arial".to_string(),
size: 10.0,
weight: FontWeight::Regular,
color: (0, 0, 0),
}),
readability_score: 0.9,
},
grid: Some(GridProperties {
grid_type: GridType::Major,
color: (200, 200, 200),
opacity: 0.3,
line_count: (5, 10),
}),
})
}
}
impl ChartProcessor for DefaultChartProcessor {
    /// Runs the full analysis pipeline and packages the result as an [`AnalyzedChart`].
    ///
    /// A description is generated when `config.generate_descriptions` is set. Trends are
    /// analyzed when `config.analyze_trends` is set and there are at least as many data points
    /// as the trend analyzer's `min_points`, with a minimum of one. Otherwise `trends` is `None`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the comprehensive analysis, description generation and trend
    /// analysis. Insufficient data for trends is not an error.
    fn analyze_chart(&self, image_path: &Path) -> RragResult<AnalyzedChart> {
        let analysis = self.analyze_comprehensive(image_path)?;
        let description = if self.config.generate_descriptions {
            Some(self.description_generator.generate(&analysis)?)
        } else {
            None
        };
        // `TrendAnalyzer::analyze` rejects inputs shorter than `min_points` with a validation
        // error. Trend analysis is optional enrichment, so skip it for short inputs instead of
        // failing the whole chart. With the default analyzer (min 5) and the 4-point
        // extractor output, the old non-empty check made every analysis fail.
        let enough_points_for_trends =
            analysis.data_points.len() >= self.trend_analyzer.min_points.max(1);
        let trends = if self.config.analyze_trends && enough_points_for_trends {
            Some(self.trend_analyzer.analyze(&analysis.data_points)?)
        } else {
            None
        };
        // Short id: the first hyphen-separated group of a random v4 UUID.
        let uuid = uuid::Uuid::new_v4().to_string();
        let short_id = uuid.split('-').next().unwrap_or(&uuid);
        Ok(AnalyzedChart {
            id: format!("chart_{}", short_id),
            chart_type: analysis.chart_type,
            title: analysis.elements.title,
            axes: analysis.elements.axes,
            data_points: analysis.data_points,
            trends,
            description,
            // Embeddings are not computed by this processor.
            embedding: None,
        })
    }

    /// Classifies the chart and extracts its data points.
    ///
    /// Element and visual-property analysis is skipped because its output would be
    /// discarded.
    fn extract_data_points(&self, chart_image: &Path) -> RragResult<Vec<DataPoint>> {
        let (chart_type, _confidence) = self.type_classifier.classify(chart_image)?;
        self.data_extractor.extract(chart_image, chart_type)
    }

    /// Returns the classifier's chart type, discarding its confidence.
    fn identify_type(&self, chart_image: &Path) -> RragResult<ChartType> {
        let (chart_type, _confidence) = self.type_classifier.classify(chart_image)?;
        Ok(chart_type)
    }

    /// Delegates to the configured [`TrendAnalyzer`]. This errors when `data_points` is
    /// shorter than its `min_points`.
    fn analyze_trends(&self, data_points: &[DataPoint]) -> RragResult<TrendAnalysis> {
        self.trend_analyzer.analyze(data_points)
    }
}
impl ChartTypeClassifier {
    /// Creates a classifier with the default model descriptors:
    /// - a CNN over colour-histogram, edge and shape features (threshold 0.8);
    /// - an SVM over layout and text features (threshold 0.7).
    ///
    /// # Errors
    ///
    /// Never returns an error in the current implementation.
    pub fn new() -> RragResult<Self> {
        let models = vec![
            ClassificationModel {
                model_type: ModelType::CNN,
                confidence_threshold: 0.8,
                features: vec![
                    FeatureType::ColorHistogram,
                    FeatureType::EdgeDetection,
                    FeatureType::ShapeFeatures,
                ],
            },
            ClassificationModel {
                model_type: ModelType::SVM,
                confidence_threshold: 0.7,
                features: vec![FeatureType::LayoutFeatures, FeatureType::TextFeatures],
            },
        ];
        Ok(Self { models })
    }

    /// Classifies a chart image and returns the detected type with a confidence score.
    ///
    /// NOTE(review): this is a filename heuristic. Neither the image contents nor `models`
    /// are consulted. The file name is matched case-insensitively against `line`, `bar`,
    /// `pie` and `scatter`, in that order. Anything else is [`ChartType::Unknown`] with
    /// confidence 0.5.
    ///
    /// # Errors
    ///
    /// Never returns an error in the current implementation.
    pub fn classify(&self, image_path: &Path) -> RragResult<(ChartType, f32)> {
        // Lossy conversion so non-UTF-8 names are still matched. Lowercase so that
        // `LineChart.PNG` is recognised the same way as `line_chart.png`.
        let filename = image_path
            .file_name()
            .map(|name| name.to_string_lossy().to_lowercase())
            .unwrap_or_default();
        let (chart_type, confidence) = if filename.contains("line") {
            (ChartType::Line, 0.95)
        } else if filename.contains("bar") {
            (ChartType::Bar, 0.90)
        } else if filename.contains("pie") {
            (ChartType::Pie, 0.85)
        } else if filename.contains("scatter") {
            (ChartType::Scatter, 0.80)
        } else {
            (ChartType::Unknown, 0.5)
        };
        Ok((chart_type, confidence))
    }

    /// Returns a fixed placeholder vector of 12 feature values; the image is not read.
    pub fn extract_features(&self, _image_path: &Path) -> RragResult<Vec<f32>> {
        let mut features = Vec::with_capacity(12);
        features.extend_from_slice(&[0.1, 0.2, 0.3, 0.4]);
        features.extend_from_slice(&[0.5, 0.6]);
        features.extend_from_slice(&[0.7, 0.8, 0.9]);
        features.extend_from_slice(&[0.2, 0.4]);
        features.push(0.3);
        Ok(features)
    }
}
impl ChartDataExtractor {
pub fn new(ocr_enabled: bool, color_analysis: bool, shape_detection: bool) -> Self {
Self {
ocr_enabled,
color_analysis,
shape_detection,
}
}
pub fn extract(&self, image_path: &Path, chart_type: ChartType) -> RragResult<Vec<DataPoint>> {
match chart_type {
ChartType::Line => self.extract_line_data(image_path),
ChartType::Bar => self.extract_bar_data(image_path),
ChartType::Pie => self.extract_pie_data(image_path),
ChartType::Scatter => self.extract_scatter_data(image_path),
ChartType::Area => self.extract_area_data(image_path),
ChartType::Histogram => self.extract_histogram_data(image_path),
_ => Ok(vec![]),
}
}
fn extract_line_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
Ok(vec![
DataPoint {
x: 0.0,
y: 10.0,
label: None,
series: Some("Line 1".to_string()),
},
DataPoint {
x: 1.0,
y: 15.0,
label: None,
series: Some("Line 1".to_string()),
},
DataPoint {
x: 2.0,
y: 12.0,
label: None,
series: Some("Line 1".to_string()),
},
DataPoint {
x: 3.0,
y: 18.0,
label: None,
series: Some("Line 1".to_string()),
},
])
}
fn extract_bar_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
Ok(vec![
DataPoint {
x: 0.0,
y: 25.0,
label: Some("Q1".to_string()),
series: None,
},
DataPoint {
x: 1.0,
y: 30.0,
label: Some("Q2".to_string()),
series: None,
},
DataPoint {
x: 2.0,
y: 35.0,
label: Some("Q3".to_string()),
series: None,
},
DataPoint {
x: 3.0,
y: 40.0,
label: Some("Q4".to_string()),
series: None,
},
])
}
fn extract_pie_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
Ok(vec![
DataPoint {
x: 0.0,
y: 40.0,
label: Some("Category A".to_string()),
series: None,
},
DataPoint {
x: 1.0,
y: 30.0,
label: Some("Category B".to_string()),
series: None,
},
DataPoint {
x: 2.0,
y: 20.0,
label: Some("Category C".to_string()),
series: None,
},
DataPoint {
x: 3.0,
y: 10.0,
label: Some("Category D".to_string()),
series: None,
},
])
}
fn extract_scatter_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
Ok(vec![
DataPoint {
x: 5.0,
y: 10.0,
label: None,
series: None,
},
DataPoint {
x: 15.0,
y: 25.0,
label: None,
series: None,
},
DataPoint {
x: 25.0,
y: 20.0,
label: None,
series: None,
},
DataPoint {
x: 35.0,
y: 40.0,
label: None,
series: None,
},
])
}
fn extract_area_data(&self, image_path: &Path) -> RragResult<Vec<DataPoint>> {
self.extract_line_data(image_path)
}
fn extract_histogram_data(&self, _image_path: &Path) -> RragResult<Vec<DataPoint>> {
Ok(vec![
DataPoint {
x: 0.0,
y: 5.0,
label: Some("0-10".to_string()),
series: None,
},
DataPoint {
x: 1.0,
y: 15.0,
label: Some("10-20".to_string()),
series: None,
},
DataPoint {
x: 2.0,
y: 25.0,
label: Some("20-30".to_string()),
series: None,
},
DataPoint {
x: 3.0,
y: 10.0,
label: Some("30-40".to_string()),
series: None,
},
])
}
}
impl TrendAnalyzer {
    /// Minimum R² (see `calculate_trend_strength`) for a clear net change to be reported as
    /// increasing or decreasing. Below this, the series is [`TrendDirection::Volatile`].
    const MIN_DIRECTIONAL_STRENGTH: f32 = 0.5;

    /// Creates a trend analyzer.
    ///
    /// * `min_points` – smallest input [`analyze`](Self::analyze) accepts.
    /// * `smoothing_window` – stored but not used by any analysis step below.
    /// * `seasonality_detection` – whether `analyze` runs seasonality detection.
    pub fn new(min_points: usize, smoothing_window: usize, seasonality_detection: bool) -> Self {
        Self {
            min_points,
            smoothing_window,
            seasonality_detection,
        }
    }

    /// Computes direction, strength (R² of a linear fit), optional seasonality, IQR outliers
    /// and, for 10 or more points, a 5-step linear forecast.
    ///
    /// NOTE(review): points are presumably ordered by `x`. Direction and forecast use the
    /// first and last elements as given.
    ///
    /// # Errors
    ///
    /// Returns a validation error when `data_points` has fewer than `min_points` entries.
    pub fn analyze(&self, data_points: &[DataPoint]) -> RragResult<TrendAnalysis> {
        if data_points.len() < self.min_points {
            return Err(RragError::validation(
                "data_points",
                format!("minimum {} points", self.min_points),
                format!("{} points", data_points.len()),
            ));
        }
        let direction = self.calculate_trend_direction(data_points);
        let strength = self.calculate_trend_strength(data_points);
        let seasonality = if self.seasonality_detection {
            self.detect_seasonality(data_points)
        } else {
            None
        };
        let outliers = self.detect_outliers(data_points);
        let forecast = if data_points.len() >= 10 {
            Some(self.generate_forecast(data_points, 5)?)
        } else {
            None
        };
        Ok(TrendAnalysis {
            direction,
            strength,
            seasonality,
            outliers,
            forecast,
        })
    }

    /// Classifies the overall movement of the series.
    ///
    /// - `Stable`: the net change (last minus first `y`) is at most half the standard
    ///   deviation. This includes perfectly flat series.
    /// - `Volatile`: the net change is larger, but the points fit a line poorly (low R²).
    /// - Otherwise, `Increasing` or `Decreasing` by the sign of the net change.
    ///
    /// The previous `Volatile` test was algebraically identical to the `Stable` test, so it
    /// could never fire. A flat series also fell through to `Decreasing`.
    fn calculate_trend_direction(&self, data_points: &[DataPoint]) -> TrendDirection {
        let (first_y, last_y) = match (data_points.first(), data_points.last()) {
            (Some(first), Some(last)) if data_points.len() >= 2 => (first.y, last.y),
            _ => return TrendDirection::Stable,
        };
        let change = last_y - first_y;
        let volatility = self.calculate_volatility(data_points);
        if change.abs() <= volatility * 0.5 {
            TrendDirection::Stable
        } else if self.calculate_trend_strength(data_points) < Self::MIN_DIRECTIONAL_STRENGTH {
            TrendDirection::Volatile
        } else if change > 0.0 {
            TrendDirection::Increasing
        } else {
            TrendDirection::Decreasing
        }
    }

    /// Returns `(mean_x, mean_y, sxx, syy, sxy)`: the means plus the sums of squares and
    /// cross-products of deviations from them.
    ///
    /// Centring first is numerically stable, unlike the raw-sum form, which can produce a
    /// small negative variance.
    ///
    /// `data_points` must be non-empty.
    fn centered_sums(data_points: &[DataPoint]) -> (f64, f64, f64, f64, f64) {
        let n = data_points.len() as f64;
        let mean_x = data_points.iter().map(|p| p.x).sum::<f64>() / n;
        let mean_y = data_points.iter().map(|p| p.y).sum::<f64>() / n;
        let (sxx, syy, sxy) =
            data_points
                .iter()
                .fold((0.0, 0.0, 0.0), |(sxx, syy, sxy), p| {
                    let dx = p.x - mean_x;
                    let dy = p.y - mean_y;
                    (sxx + dx * dx, syy + dy * dy, sxy + dx * dy)
                });
        (mean_x, mean_y, sxx, syy, sxy)
    }

    /// Returns the coefficient of determination (R², in `[0, 1]`) of a least-squares line
    /// through the points.
    ///
    /// The result is 0 for fewer than two points, or when x or y is constant.
    fn calculate_trend_strength(&self, data_points: &[DataPoint]) -> f32 {
        if data_points.len() < 2 {
            return 0.0;
        }
        let (_, _, sxx, syy, sxy) = Self::centered_sums(data_points);
        // Constant x or constant y: there is no linear relationship to measure.
        if sxx <= 0.0 || syy <= 0.0 {
            return 0.0;
        }
        // Clamp away tiny rounding excursions outside [0, 1].
        ((sxy * sxy) / (sxx * syy)).clamp(0.0, 1.0) as f32
    }

    /// Returns the population standard deviation of the `y` values, or 0 for fewer than two
    /// points.
    fn calculate_volatility(&self, data_points: &[DataPoint]) -> f64 {
        if data_points.len() < 2 {
            return 0.0;
        }
        let n = data_points.len() as f64;
        let mean = data_points.iter().map(|p| p.y).sum::<f64>() / n;
        let variance = data_points
            .iter()
            .map(|p| (p.y - mean).powi(2))
            .sum::<f64>()
            / n;
        variance.sqrt()
    }

    /// Placeholder seasonality detection.
    ///
    /// Returns `None` below 12 points. Otherwise it returns a fixed period of 12, amplitude 5
    /// and phase 0, without inspecting the values.
    fn detect_seasonality(&self, data_points: &[DataPoint]) -> Option<Seasonality> {
        if data_points.len() < 12 {
            return None;
        }
        Some(Seasonality {
            period: 12.0,
            amplitude: 5.0,
            phase: 0.0,
        })
    }

    /// Returns the points whose `y` lies outside Tukey's fences `[Q1 - 1.5·IQR, Q3 + 1.5·IQR]`.
    ///
    /// Quartiles use linear interpolation between order statistics, the same as NumPy's
    /// default. Inputs with fewer than four non-NaN values produce no outliers. NaN values
    /// are never reported.
    fn detect_outliers(&self, data_points: &[DataPoint]) -> Vec<DataPoint> {
        if data_points.len() < 4 {
            return vec![];
        }
        let mut sorted: Vec<f64> = data_points
            .iter()
            .map(|p| p.y)
            .filter(|y| !y.is_nan())
            .collect();
        if sorted.len() < 4 {
            return vec![];
        }
        // `total_cmp` cannot panic, unlike `partial_cmp().unwrap()`.
        sorted.sort_by(f64::total_cmp);
        let q1 = Self::interpolated_quantile(&sorted, 0.25);
        let q3 = Self::interpolated_quantile(&sorted, 0.75);
        let iqr = q3 - q1;
        let lower_bound = q1 - 1.5 * iqr;
        let upper_bound = q3 + 1.5 * iqr;
        data_points
            .iter()
            .filter(|p| p.y < lower_bound || p.y > upper_bound)
            .cloned()
            .collect()
    }

    /// Linearly interpolated `q`-quantile (`q` in `[0, 1]`) of a non-empty, ascending slice.
    fn interpolated_quantile(sorted: &[f64], q: f64) -> f64 {
        let position = q * (sorted.len() - 1) as f64;
        let lower = position.floor() as usize;
        let upper = position.ceil() as usize;
        sorted[lower] + (sorted[upper] - sorted[lower]) * (position - lower as f64)
    }

    /// Extrapolates a least-squares line `num_points` steps past the last point.
    ///
    /// NOTE(review): this assumes unit x spacing. Forecasts are placed at `last.x + 1`,
    /// `last.x + 2`, and so on.
    /// When every x is identical the slope is undefined, so a flat line through the mean `y`
    /// is used instead of producing NaN/infinite values. Fewer than two points yields an
    /// empty forecast.
    fn generate_forecast(
        &self,
        data_points: &[DataPoint],
        num_points: usize,
    ) -> RragResult<Vec<DataPoint>> {
        let last_x = match data_points.last() {
            Some(last) if data_points.len() >= 2 => last.x,
            _ => return Ok(vec![]),
        };
        let (mean_x, mean_y, sxx, _, sxy) = Self::centered_sums(data_points);
        let slope = if sxx > 0.0 { sxy / sxx } else { 0.0 };
        let intercept = mean_y - slope * mean_x;
        let forecast = (1..=num_points)
            .map(|step| {
                let x = last_x + step as f64;
                DataPoint {
                    x,
                    y: slope * x + intercept,
                    label: Some(format!("Forecast {}", step)),
                    series: Some("Forecast".to_string()),
                }
            })
            .collect();
        Ok(forecast)
    }
}
impl ChartDescriptionGenerator {
pub fn new() -> Self {
let mut templates = std::collections::HashMap::new();
templates.insert(
ChartType::Line,
"This line chart shows {data_description}. The trend is {trend_direction} with a strength of {trend_strength:.2}.".to_string()
);
templates.insert(
ChartType::Bar,
"This bar chart displays {data_description}. The highest value is {max_value} and the lowest is {min_value}.".to_string()
);
templates.insert(
ChartType::Pie,
"This pie chart represents {data_description}. The largest segment is {largest_segment} at {largest_percentage:.1}%.".to_string()
);
Self {
templates,
nlg_enabled: false,
}
}
pub fn generate(&self, analysis: &ChartAnalysisResult) -> RragResult<String> {
if let Some(template) = self.templates.get(&analysis.chart_type) {
let description = self.fill_template(template, analysis)?;
Ok(description)
} else {
Ok(format!(
"Chart of type {:?} with {} data points",
analysis.chart_type,
analysis.data_points.len()
))
}
}
fn fill_template(&self, template: &str, analysis: &ChartAnalysisResult) -> RragResult<String> {
let mut description = template.to_string();
description = description.replace(
"{data_description}",
&format!("{} data points", analysis.data_points.len()),
);
if !analysis.data_points.is_empty() {
let max_y = analysis
.data_points
.iter()
.map(|p| p.y as f32)
.fold(f32::NEG_INFINITY, |a, b| a.max(b));
let min_y = analysis
.data_points
.iter()
.map(|p| p.y as f32)
.fold(f32::INFINITY, |a, b| a.min(b));
description = description.replace("{max_value}", &max_y.to_string());
description = description.replace("{min_value}", &min_y.to_string());
}
Ok(description)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Builds an unlabelled, series-less data point.
    fn point(x: f64, y: f64) -> DataPoint {
        DataPoint {
            x,
            y,
            label: None,
            series: None,
        }
    }

    #[test]
    fn test_chart_processor_creation() {
        let config = ChartAnalysisConfig::default();
        let processor = DefaultChartProcessor::new(config).unwrap();
        assert!(processor.config.extract_data);
        assert!(processor.config.generate_descriptions);
        assert!(processor.config.analyze_trends);
    }

    #[test]
    fn test_chart_type_classification() {
        // Classification only inspects the file name, so no file needs to exist on disk.
        let classifier = ChartTypeClassifier::new().unwrap();
        let (chart_type, confidence) = classifier.classify(Path::new("line_chart.png")).unwrap();
        assert_eq!(chart_type, ChartType::Line);
        assert!(confidence > 0.9);
    }

    #[test]
    fn test_trend_analysis() {
        let analyzer = TrendAnalyzer::new(3, 2, false);
        let data_points = vec![
            point(0.0, 10.0),
            point(1.0, 15.0),
            point(2.0, 20.0),
            point(3.0, 25.0),
        ];
        let trend = analyzer.analyze(&data_points).unwrap();
        assert_eq!(trend.direction, TrendDirection::Increasing);
        assert!(trend.strength > 0.8);
    }

    #[test]
    fn test_outlier_detection() {
        let analyzer = TrendAnalyzer::new(3, 2, false);
        let data_points = vec![
            point(0.0, 10.0),
            point(1.0, 12.0),
            point(2.0, 100.0),
            point(3.0, 11.0),
        ];
        let outliers = analyzer.detect_outliers(&data_points);
        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0].y, 100.0);
    }

    #[test]
    fn test_data_extraction() {
        // The extractor does not read the image, so a non-existent path is sufficient.
        let extractor = ChartDataExtractor::new(true, true, true);
        let data_points = extractor
            .extract(Path::new("chart.png"), ChartType::Line)
            .unwrap();
        assert!(!data_points.is_empty());
    }

    #[test]
    fn test_description_generation() {
        let generator = ChartDescriptionGenerator::new();
        let analysis = ChartAnalysisResult {
            chart_type: ChartType::Line,
            confidence: 0.9,
            data_points: vec![point(0.0, 10.0), point(1.0, 15.0)],
            elements: ChartElements {
                title: None,
                axes: ChartAxes {
                    x_label: None,
                    y_label: None,
                    x_range: None,
                    y_range: None,
                },
                legend: vec![],
                series: vec![],
                annotations: vec![],
            },
            visual_properties: VisualProperties {
                chart_area: ChartArea {
                    bounds: (0.0, 0.0, 100.0, 100.0),
                    plot_area: (0.0, 0.0, 100.0, 100.0),
                    margins: (0.0, 0.0, 0.0, 0.0),
                },
                color_scheme: ColorScheme {
                    primary_colors: vec![],
                    palette_type: PaletteType::Categorical,
                    accessibility_score: 1.0,
                },
                typography: Typography {
                    title_font: None,
                    axis_font: None,
                    legend_font: None,
                    readability_score: 1.0,
                },
                grid: None,
            },
        };
        let description = generator.generate(&analysis).unwrap();
        assert!(description.contains("line chart"));
        assert!(description.contains("2 data points"));
    }
}