// leptos_helios/ml_intelligence.rs

1//! ML Intelligence Module for Helios
2//!
3//! This module provides machine learning capabilities for intelligent chart recommendations,
4//! data forecasting, and automatic visualization optimization.
5
6use serde::{Deserialize, Serialize};
7use std::time::{Duration, Instant};
8
/// Time series data point for ML operations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeSeriesPoint {
    /// Sample time as a raw numeric timestamp; trend math only uses
    /// differences between timestamps, so the unit is up to the caller.
    pub timestamp: f64,
    /// Observed value at `timestamp`.
    pub value: f64,
}
15
/// ML forecasting result with confidence metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForecastResult {
    /// Predicted values, one per requested forecast period.
    pub predictions: Vec<f64>,
    /// Confidence score in [0, 1]; lower input variance gives higher confidence.
    pub confidence: f64,
    /// Wall-clock time spent producing this forecast.
    pub inference_time: Duration,
    /// Identifier of the model that produced the forecast (e.g. "LSTM").
    pub model_type: String,
}
24
/// Chart recommendation based on data analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChartRecommendation {
    /// Suggested chart kind, e.g. "line", "bar", "scatter", "combo".
    pub chart_type: String,
    /// Confidence in the suggestion, in [0, 1].
    pub confidence: f64,
    /// Human-readable explanation of why this chart type was chosen.
    pub reasoning: String,
    /// Follow-up display tweaks (outlier filtering, log scale, sampling).
    pub optimization_suggestions: Vec<String>,
}
33
/// ML forecaster for time series prediction
pub struct MLForecaster {
    /// Whether `load_model` has been called; `forecast` errors until it has.
    model_loaded: bool,
    /// Number of forecasts performed so far.
    inference_count: u32,
}
39
40impl MLForecaster {
41    /// Create a new ML forecaster
42    pub fn new() -> Self {
43        Self {
44            model_loaded: false,
45            inference_count: 0,
46        }
47    }
48
49    /// Load ML model for forecasting
50    pub fn load_model(&mut self) -> Result<(), String> {
51        // Simulate model loading
52        self.model_loaded = true;
53        Ok(())
54    }
55
56    /// Generate forecast for time series data
57    pub fn forecast(
58        &mut self,
59        series: &[TimeSeriesPoint],
60        periods: u32,
61    ) -> Result<ForecastResult, String> {
62        if !self.model_loaded {
63            return Err("Model not loaded".to_string());
64        }
65
66        let start = Instant::now();
67
68        // Simulate ML inference
69        let predictions = self.generate_predictions(series, periods);
70        let confidence = self.calculate_confidence(series);
71        let inference_time = start.elapsed();
72
73        self.inference_count += 1;
74
75        Ok(ForecastResult {
76            predictions,
77            confidence,
78            inference_time,
79            model_type: "LSTM".to_string(),
80        })
81    }
82
83    /// Generate predictions using ML model
84    fn generate_predictions(&self, series: &[TimeSeriesPoint], periods: u32) -> Vec<f64> {
85        if series.is_empty() {
86            return vec![0.0; periods as usize];
87        }
88
89        // Check for extreme values in the series
90        let has_extreme_values = series
91            .iter()
92            .any(|point| !point.value.is_finite() || point.value.abs() > 1e10);
93
94        if has_extreme_values {
95            // For extreme values, return simple constant predictions
96            return vec![0.0; periods as usize];
97        }
98
99        // Simple trend-based prediction for normal values
100        let last_value = series.last().unwrap().value;
101        let trend = self.calculate_trend(series);
102
103        (0..periods)
104            .map(|i| {
105                let prediction = last_value + trend * (i as f64 + 1.0);
106                // Ensure predictions are finite
107                if prediction.is_finite() {
108                    prediction.clamp(-1e10, 1e10)
109                } else {
110                    0.0
111                }
112            })
113            .collect()
114    }
115
116    /// Calculate trend from time series
117    fn calculate_trend(&self, series: &[TimeSeriesPoint]) -> f64 {
118        if series.len() < 2 {
119            return 0.0;
120        }
121
122        let first = series[0].value;
123        let last = series.last().unwrap().value;
124        let time_span = series.last().unwrap().timestamp - series[0].timestamp;
125
126        // Handle extreme values that could cause overflow
127        if !first.is_finite() || !last.is_finite() || !time_span.is_finite() {
128            return 0.0;
129        }
130
131        // Check for extreme differences that could cause overflow
132        let diff = last - first;
133        if diff.abs() > 1e10 {
134            return 0.0;
135        }
136
137        if time_span > 0.0 {
138            let trend = diff / time_span;
139            // Clamp extreme values to prevent overflow
140            if trend.is_finite() {
141                trend.clamp(-1e6, 1e6)
142            } else {
143                0.0
144            }
145        } else {
146            0.0
147        }
148    }
149
150    /// Calculate confidence score for predictions
151    fn calculate_confidence(&self, series: &[TimeSeriesPoint]) -> f64 {
152        if series.len() < 3 {
153            return 0.5;
154        }
155
156        // Calculate variance to determine confidence
157        let mean = series.iter().map(|p| p.value).sum::<f64>() / series.len() as f64;
158        let variance =
159            series.iter().map(|p| (p.value - mean).powi(2)).sum::<f64>() / series.len() as f64;
160
161        // Higher variance = lower confidence
162        (1.0 / (1.0 + variance)).min(1.0).max(0.0)
163    }
164
165    /// Get inference statistics
166    pub fn get_inference_count(&self) -> u32 {
167        self.inference_count
168    }
169
170    /// Check if model is loaded
171    pub fn is_model_loaded(&self) -> bool {
172        self.model_loaded
173    }
174}
175
/// Intelligent chart recommendation engine
pub struct ChartRecommendationEngine {
    /// Computes the statistics that drive chart-type selection.
    data_analyzer: DataAnalyzer,
    /// Memoized recommendations, keyed by a data/metadata cache key.
    recommendation_cache: std::collections::HashMap<String, ChartRecommendation>,
}
181
182impl ChartRecommendationEngine {
183    /// Create a new recommendation engine
184    pub fn new() -> Self {
185        Self {
186            data_analyzer: DataAnalyzer::new(),
187            recommendation_cache: std::collections::HashMap::new(),
188        }
189    }
190
191    /// Analyze data and recommend optimal chart type
192    pub fn recommend_chart(&mut self, data: &[f64], metadata: &str) -> ChartRecommendation {
193        let cache_key = format!("{}_{}", data.len(), metadata);
194
195        if let Some(cached) = self.recommendation_cache.get(&cache_key) {
196            return cached.clone();
197        }
198
199        let analysis = self.data_analyzer.analyze(data);
200        let recommendation = self.generate_recommendation(&analysis, metadata);
201
202        self.recommendation_cache
203            .insert(cache_key, recommendation.clone());
204        recommendation
205    }
206
207    /// Generate recommendation based on data analysis
208    fn generate_recommendation(
209        &self,
210        analysis: &DataAnalysis,
211        _metadata: &str,
212    ) -> ChartRecommendation {
213        let (chart_type, confidence, reasoning) = match analysis.data_type {
214            DataType::TimeSeries => ("line", 0.9, "Time series data detected".to_string()),
215            DataType::Categorical => ("bar", 0.85, "Categorical data detected".to_string()),
216            DataType::Continuous => ("scatter", 0.8, "Continuous data detected".to_string()),
217            DataType::Mixed => ("combo", 0.7, "Mixed data types detected".to_string()),
218        };
219
220        let optimization_suggestions = self.generate_optimization_suggestions(analysis);
221
222        ChartRecommendation {
223            chart_type: chart_type.to_string(),
224            confidence,
225            reasoning,
226            optimization_suggestions,
227        }
228    }
229
230    /// Generate optimization suggestions
231    fn generate_optimization_suggestions(&self, analysis: &DataAnalysis) -> Vec<String> {
232        let mut suggestions = Vec::new();
233
234        if analysis.outlier_count > analysis.data_points / 10 {
235            suggestions.push("Consider outlier filtering for better visualization".to_string());
236        }
237
238        if analysis.variance > 1.0 {
239            suggestions.push("High variance detected - consider log scale".to_string());
240        }
241
242        if analysis.data_points >= 10000 {
243            suggestions.push("Large dataset - consider data sampling or aggregation".to_string());
244        }
245
246        suggestions
247    }
248
249    /// Get cache statistics
250    pub fn get_cache_size(&self) -> usize {
251        self.recommendation_cache.len()
252    }
253
254    /// Clear recommendation cache
255    pub fn clear_cache(&mut self) {
256        self.recommendation_cache.clear();
257    }
258}
259
/// Data analysis for intelligent recommendations
#[derive(Debug, Clone)]
pub struct DataAnalysis {
    /// Coarse classification of the dataset.
    pub data_type: DataType,
    /// Number of samples analyzed.
    pub data_points: usize,
    /// Arithmetic mean (0.0 for empty input).
    pub mean: f64,
    /// Sample variance with n-1 denominator (0.0 for fewer than 2 points).
    pub variance: f64,
    /// Count of values more than 2 standard deviations from the mean.
    pub outlier_count: usize,
    /// Overall direction of the data.
    pub trend: TrendType,
}
270
#[derive(Debug, Clone)]
pub enum DataType {
    /// Ordered samples over time; recommended chart is "line".
    TimeSeries,
    /// Few distinct values relative to sample count; recommended chart is "bar".
    Categorical,
    /// Many distinct numeric values; recommended chart is "scatter".
    Continuous,
    /// Combination of the above; recommended chart is "combo".
    Mixed,
}
278
#[derive(Debug, Clone)]
pub enum TrendType {
    /// Second-half mean notably above the first-half mean.
    Increasing,
    /// Second-half mean notably below the first-half mean.
    Decreasing,
    /// No significant change between the two halves.
    Stable,
    /// High fluctuation. NOTE(review): never produced by the analyzer's
    /// current trend heuristic — confirm whether that is intended.
    Volatile,
}
286
/// Data analyzer for ML operations
pub struct DataAnalyzer {
    /// Number of `analyze` calls made on this instance.
    analysis_count: u32,
}
291
292impl DataAnalyzer {
293    /// Create a new data analyzer
294    pub fn new() -> Self {
295        Self { analysis_count: 0 }
296    }
297
298    /// Analyze data characteristics
299    pub fn analyze(&mut self, data: &[f64]) -> DataAnalysis {
300        self.analysis_count += 1;
301
302        let data_points = data.len();
303        let mean = if data_points > 0 {
304            data.iter().sum::<f64>() / data_points as f64
305        } else {
306            0.0
307        };
308
309        let variance = if data_points > 1 {
310            data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (data_points - 1) as f64
311        } else {
312            0.0
313        };
314
315        let outlier_count = self.count_outliers(data, mean, variance.sqrt());
316        let data_type = self.classify_data_type(data);
317        let trend = self.analyze_trend(data);
318
319        DataAnalysis {
320            data_type,
321            data_points,
322            mean,
323            variance,
324            outlier_count,
325            trend,
326        }
327    }
328
329    /// Count outliers using IQR method
330    fn count_outliers(&self, data: &[f64], mean: f64, std_dev: f64) -> usize {
331        if std_dev == 0.0 {
332            return 0;
333        }
334
335        data.iter()
336            .filter(|&&x| (x - mean).abs() > 2.0 * std_dev)
337            .count()
338    }
339
340    /// Classify data type based on characteristics
341    fn classify_data_type(&self, data: &[f64]) -> DataType {
342        if data.is_empty() {
343            return DataType::Continuous;
344        }
345
346        // Simple heuristic: check for discrete values
347        // Use a different approach since f64 doesn't implement Hash
348        let mut unique_count = 0;
349        let total_values = data.len();
350
351        // Count unique values by rounding to avoid floating point precision issues
352        let mut seen = std::collections::HashSet::new();
353        for &value in data {
354            let rounded = (value * 1000.0).round() as i64;
355            if seen.insert(rounded) {
356                unique_count += 1;
357            }
358        }
359
360        if (unique_count as f64 / total_values as f64) < 0.1 {
361            DataType::Categorical
362        } else {
363            DataType::Continuous
364        }
365    }
366
367    /// Analyze trend in data
368    fn analyze_trend(&self, data: &[f64]) -> TrendType {
369        if data.len() < 2 {
370            return TrendType::Stable;
371        }
372
373        let first_half = &data[..data.len() / 2];
374        let second_half = &data[data.len() / 2..];
375
376        let first_mean = first_half.iter().sum::<f64>() / first_half.len() as f64;
377        let second_mean = second_half.iter().sum::<f64>() / second_half.len() as f64;
378
379        let change = (second_mean - first_mean) / first_mean.abs().max(1e-10);
380
381        if change > 0.1 {
382            TrendType::Increasing
383        } else if change < -0.1 {
384            TrendType::Decreasing
385        } else {
386            TrendType::Stable
387        }
388    }
389
390    /// Get analysis statistics
391    pub fn get_analysis_count(&self) -> u32 {
392        self.analysis_count
393    }
394}
395
/// ML performance monitor
pub struct MLPerformanceMonitor {
    /// Number of inferences recorded.
    total_inferences: u32,
    /// Sum of all recorded inference durations.
    total_inference_time: Duration,
    /// Largest single inference duration seen.
    max_inference_time: Duration,
    /// Smallest single inference duration seen; `None` until the first
    /// record. (Bug fix: the old 1-hour sentinel leaked out of `get_stats`
    /// when no inferences had been recorded.)
    min_inference_time: Option<Duration>,
}

impl MLPerformanceMonitor {
    /// Create a monitor with no recorded inferences.
    pub fn new() -> Self {
        Self {
            total_inferences: 0,
            total_inference_time: Duration::ZERO,
            max_inference_time: Duration::ZERO,
            min_inference_time: None,
        }
    }

    /// Record the duration of a single inference.
    pub fn record_inference(&mut self, duration: Duration) {
        // saturating_add avoids a debug-build overflow panic on long runs.
        self.total_inferences = self.total_inferences.saturating_add(1);
        self.total_inference_time += duration;
        self.max_inference_time = self.max_inference_time.max(duration);
        self.min_inference_time = Some(match self.min_inference_time {
            Some(current_min) => current_min.min(duration),
            None => duration,
        });
    }

    /// Average inference time, or `Duration::ZERO` with nothing recorded.
    pub fn get_average_inference_time(&self) -> Duration {
        if self.total_inferences > 0 {
            // Duration supports exact division by an integer count.
            self.total_inference_time / self.total_inferences
        } else {
            Duration::ZERO
        }
    }

    /// Snapshot of the collected statistics.
    ///
    /// With no recorded inferences every time field is `Duration::ZERO`
    /// (previously the minimum reported a bogus 1-hour sentinel).
    pub fn get_stats(&self) -> MLPerformanceStats {
        MLPerformanceStats {
            total_inferences: self.total_inferences,
            average_inference_time: self.get_average_inference_time(),
            max_inference_time: self.max_inference_time,
            min_inference_time: self.min_inference_time.unwrap_or(Duration::ZERO),
        }
    }
}

/// ML performance statistics
#[derive(Debug, Clone)]
pub struct MLPerformanceStats {
    pub total_inferences: u32,
    pub average_inference_time: Duration,
    pub max_inference_time: Duration,
    pub min_inference_time: Duration,
}
453
454impl Default for MLForecaster {
455    fn default() -> Self {
456        Self::new()
457    }
458}
459
460impl Default for ChartRecommendationEngine {
461    fn default() -> Self {
462        Self::new()
463    }
464}
465
466impl Default for DataAnalyzer {
467    fn default() -> Self {
468        Self::new()
469    }
470}
471
472impl Default for MLPerformanceMonitor {
473    fn default() -> Self {
474        Self::new()
475    }
476}