Skip to main content

ringkernel_accnet/gui/
heatmaps.rs

1//! Heatmap visualizations for account activity and correlations.
2//!
3//! Provides grid-based heatmaps that reveal patterns in:
4//! - Account activity over time
5//! - Account-to-account flow correlations
6//! - Temporal patterns (hour of day, day of week)
7
8use eframe::egui::{self, Color32, Pos2, Rect, Response, Sense, Stroke, Vec2};
9use std::collections::HashMap;
10
11use super::theme::AccNetTheme;
12use crate::models::{AccountType, AccountingNetwork};
13
14/// Color gradient for heatmaps (cold to hot).
15pub struct HeatmapGradient {
16    /// Gradient stops: (position 0-1, color).
17    pub stops: Vec<(f32, Color32)>,
18}
19
20impl HeatmapGradient {
21    /// Default cold-to-hot gradient (blue -> cyan -> green -> yellow -> red).
22    pub fn thermal() -> Self {
23        Self {
24            stops: vec![
25                (0.0, Color32::from_rgb(20, 30, 60)),    // Dark blue
26                (0.2, Color32::from_rgb(40, 80, 160)),   // Blue
27                (0.4, Color32::from_rgb(60, 180, 180)),  // Cyan
28                (0.5, Color32::from_rgb(80, 200, 100)),  // Green
29                (0.7, Color32::from_rgb(220, 200, 60)),  // Yellow
30                (0.85, Color32::from_rgb(240, 120, 40)), // Orange
31                (1.0, Color32::from_rgb(200, 40, 40)),   // Red
32            ],
33        }
34    }
35
36    /// Correlation gradient (negative blue, zero gray, positive red).
37    pub fn correlation() -> Self {
38        Self {
39            stops: vec![
40                (0.0, Color32::from_rgb(40, 80, 200)), // Strong negative (blue)
41                (0.35, Color32::from_rgb(80, 120, 180)), // Weak negative
42                (0.5, Color32::from_rgb(60, 60, 70)),  // Zero (gray)
43                (0.65, Color32::from_rgb(180, 100, 80)), // Weak positive
44                (1.0, Color32::from_rgb(200, 50, 50)), // Strong positive (red)
45            ],
46        }
47    }
48
49    /// Risk gradient (green safe -> yellow caution -> red danger).
50    pub fn risk() -> Self {
51        Self {
52            stops: vec![
53                (0.0, Color32::from_rgb(60, 160, 80)),  // Safe (green)
54                (0.3, Color32::from_rgb(120, 180, 80)), // Low risk
55                (0.5, Color32::from_rgb(200, 200, 60)), // Medium (yellow)
56                (0.7, Color32::from_rgb(220, 140, 50)), // Elevated
57                (0.85, Color32::from_rgb(200, 80, 50)), // High
58                (1.0, Color32::from_rgb(180, 40, 40)),  // Critical (red)
59            ],
60        }
61    }
62
63    /// Get color for a value in [0, 1].
64    pub fn sample(&self, t: f32) -> Color32 {
65        let t = t.clamp(0.0, 1.0);
66
67        // Find surrounding stops
68        let mut prev = (0.0_f32, self.stops[0].1);
69        for &(pos, color) in &self.stops {
70            if t <= pos {
71                // Interpolate between prev and current
72                let range = pos - prev.0;
73                if range < 0.001 {
74                    return color;
75                }
76                let local_t = (t - prev.0) / range;
77                return Self::lerp_color(prev.1, color, local_t);
78            }
79            prev = (pos, color);
80        }
81        self.stops.last().map(|s| s.1).unwrap_or(Color32::WHITE)
82    }
83
84    fn lerp_color(a: Color32, b: Color32, t: f32) -> Color32 {
85        Color32::from_rgb(
86            (a.r() as f32 + (b.r() as f32 - a.r() as f32) * t) as u8,
87            (a.g() as f32 + (b.g() as f32 - a.g() as f32) * t) as u8,
88            (a.b() as f32 + (b.b() as f32 - a.b() as f32) * t) as u8,
89        )
90    }
91}
92
93/// Account activity heatmap - shows activity by account type and time.
94pub struct ActivityHeatmap {
95    /// Activity data: [account_type][time_bucket] = activity_level (0-1).
96    pub data: Vec<Vec<f32>>,
97    /// Row labels (account types or account names).
98    pub row_labels: Vec<String>,
99    /// Column labels (time periods).
100    pub col_labels: Vec<String>,
101    /// Title.
102    pub title: String,
103    /// Gradient for coloring.
104    pub gradient: HeatmapGradient,
105    /// Cell size.
106    pub cell_size: f32,
107}
108
109impl ActivityHeatmap {
110    /// Create from network data, grouping by account type.
111    pub fn from_network_by_type(network: &AccountingNetwork) -> Self {
112        let type_names = ["Asset", "Liability", "Equity", "Revenue", "Expense"];
113        let mut data = vec![vec![0.0f32; 10]; 5]; // 5 types x 10 time buckets
114
115        // Calculate activity by account type
116        let mut max_activity = 1.0f32;
117        for account in &network.accounts {
118            let type_idx = match account.account_type {
119                AccountType::Asset | AccountType::Contra => 0,
120                AccountType::Liability => 1,
121                AccountType::Equity => 2,
122                AccountType::Revenue => 3,
123                AccountType::Expense => 4,
124            };
125
126            // Distribute activity across time buckets based on transaction_count
127            let activity = account.transaction_count as f32;
128            let bucket = (account.index as usize) % 10; // Distribute by index for demo
129            data[type_idx][bucket] += activity;
130            max_activity = max_activity.max(data[type_idx][bucket]);
131        }
132
133        // Normalize to 0-1
134        for row in &mut data {
135            for cell in row {
136                *cell /= max_activity.max(1.0);
137            }
138        }
139
140        Self {
141            data,
142            row_labels: type_names.iter().map(|s| s.to_string()).collect(),
143            col_labels: (1..=10).map(|i| format!("T{}", i)).collect(),
144            title: "Account Type Activity".to_string(),
145            gradient: HeatmapGradient::thermal(),
146            cell_size: 18.0,
147        }
148    }
149
150    /// Create from network showing individual account activity.
151    pub fn from_network_top_accounts(network: &AccountingNetwork, top_n: usize) -> Self {
152        // Get top N accounts by transaction count
153        let mut accounts: Vec<_> = network.accounts.iter().enumerate().collect();
154        accounts.sort_by_key(|a| std::cmp::Reverse(a.1.transaction_count));
155        accounts.truncate(top_n);
156
157        let mut data = vec![vec![0.0f32; 8]; accounts.len()];
158        let mut row_labels = Vec::new();
159
160        for (row_idx, (_, account)) in accounts.iter().enumerate() {
161            // Use account code as label, or index if no metadata
162            row_labels.push(format!("#{}", account.index));
163
164            // Simulate activity distribution (in real app, use actual temporal data)
165            let base_activity = account.transaction_count as f32;
166            for (col, cell) in data[row_idx].iter_mut().enumerate().take(8) {
167                // Create some variation based on account properties
168                let variation = ((account.index as f32 + col as f32) * 0.7).sin() * 0.3 + 0.7;
169                *cell = base_activity * variation;
170            }
171        }
172
173        // Normalize
174        let max_val = data
175            .iter()
176            .flat_map(|row| row.iter())
177            .copied()
178            .fold(0.0f32, f32::max)
179            .max(1.0);
180
181        for row in &mut data {
182            for cell in row {
183                *cell /= max_val;
184            }
185        }
186
187        Self {
188            data,
189            row_labels,
190            col_labels: ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun", "Avg"]
191                .iter()
192                .map(|s| s.to_string())
193                .collect(),
194            title: "Top Account Activity by Day".to_string(),
195            gradient: HeatmapGradient::thermal(),
196            cell_size: 16.0,
197        }
198    }
199
200    /// Render the heatmap.
201    pub fn show(&self, ui: &mut egui::Ui, theme: &AccNetTheme) -> Response {
202        let rows = self.data.len();
203        let cols = if rows > 0 { self.data[0].len() } else { 0 };
204
205        let label_width = 60.0;
206        let header_height = 20.0;
207        let width = label_width + cols as f32 * self.cell_size + 10.0;
208        let height = header_height + rows as f32 * self.cell_size + 25.0;
209
210        let (response, painter) = ui.allocate_painter(
211            Vec2::new(width.min(ui.available_width()), height),
212            Sense::hover(),
213        );
214        let rect = response.rect;
215
216        // Title
217        painter.text(
218            Pos2::new(rect.left() + 5.0, rect.top()),
219            egui::Align2::LEFT_TOP,
220            &self.title,
221            egui::FontId::proportional(11.0),
222            theme.text_secondary,
223        );
224
225        if rows == 0 || cols == 0 {
226            return response;
227        }
228
229        let grid_left = rect.left() + label_width;
230        let grid_top = rect.top() + header_height;
231
232        // Column headers
233        for (i, label) in self.col_labels.iter().enumerate() {
234            let x = grid_left + (i as f32 + 0.5) * self.cell_size;
235            painter.text(
236                Pos2::new(x, grid_top - 3.0),
237                egui::Align2::CENTER_BOTTOM,
238                label,
239                egui::FontId::proportional(7.0),
240                theme.text_secondary,
241            );
242        }
243
244        // Render grid
245        for (row_idx, row_data) in self.data.iter().enumerate() {
246            let y = grid_top + row_idx as f32 * self.cell_size;
247
248            // Row label
249            if row_idx < self.row_labels.len() {
250                painter.text(
251                    Pos2::new(grid_left - 3.0, y + self.cell_size / 2.0),
252                    egui::Align2::RIGHT_CENTER,
253                    &self.row_labels[row_idx],
254                    egui::FontId::proportional(8.0),
255                    theme.text_secondary,
256                );
257            }
258
259            // Cells
260            for (col_idx, &value) in row_data.iter().enumerate() {
261                let x = grid_left + col_idx as f32 * self.cell_size;
262                let cell_rect = Rect::from_min_size(
263                    Pos2::new(x + 1.0, y + 1.0),
264                    Vec2::new(self.cell_size - 2.0, self.cell_size - 2.0),
265                );
266
267                let color = self.gradient.sample(value);
268                painter.rect_filled(cell_rect, 2.0, color);
269            }
270        }
271
272        // Grid border
273        let grid_rect = Rect::from_min_size(
274            Pos2::new(grid_left, grid_top),
275            Vec2::new(cols as f32 * self.cell_size, rows as f32 * self.cell_size),
276        );
277        painter.rect_stroke(
278            grid_rect,
279            0.0,
280            Stroke::new(1.0, Color32::from_rgb(60, 60, 70)),
281        );
282
283        response
284    }
285}
286
287/// Account correlation heatmap - shows flow relationships between accounts.
288pub struct CorrelationHeatmap {
289    /// Correlation matrix: [from_account][to_account] = correlation (-1 to 1).
290    pub data: Vec<Vec<f32>>,
291    /// Account labels.
292    pub labels: Vec<String>,
293    /// Title.
294    pub title: String,
295    /// Gradient for coloring.
296    pub gradient: HeatmapGradient,
297    /// Cell size.
298    pub cell_size: f32,
299}
300
301impl CorrelationHeatmap {
302    /// Create from network flow data.
303    pub fn from_network(
304        network: &AccountingNetwork,
305        top_n: usize,
306        account_names: &HashMap<u16, String>,
307    ) -> Self {
308        // Get top N accounts by degree
309        let mut accounts: Vec<_> = network.accounts.iter().enumerate().collect();
310        accounts.sort_by(|a, b| {
311            let deg_a = a.1.in_degree + a.1.out_degree;
312            let deg_b = b.1.in_degree + b.1.out_degree;
313            deg_b.cmp(&deg_a)
314        });
315        accounts.truncate(top_n);
316
317        let n = accounts.len();
318        let mut data = vec![vec![0.0f32; n]; n];
319        let mut labels = Vec::new();
320
321        // Map original indices to matrix indices
322        let index_map: HashMap<u16, usize> = accounts
323            .iter()
324            .enumerate()
325            .map(|(i, (_, acc))| (acc.index, i))
326            .collect();
327
328        for (_, acc) in &accounts {
329            let name = account_names
330                .get(&acc.index)
331                .cloned()
332                .unwrap_or_else(|| format!("#{}", acc.index));
333            // Truncate long names for display
334            let short_name: String = name.chars().take(8).collect();
335            labels.push(short_name);
336        }
337
338        // Build correlation from flows
339        let mut flow_counts: HashMap<(u16, u16), usize> = HashMap::new();
340        let mut max_flow = 1usize;
341
342        for flow in &network.flows {
343            if index_map.contains_key(&flow.source_account_index)
344                && index_map.contains_key(&flow.target_account_index)
345            {
346                let key = (flow.source_account_index, flow.target_account_index);
347                let count = flow_counts.entry(key).or_insert(0);
348                *count += 1;
349                max_flow = max_flow.max(*count);
350            }
351        }
352
353        // Fill correlation matrix
354        for ((from, to), count) in flow_counts {
355            if let (Some(&i), Some(&j)) = (index_map.get(&from), index_map.get(&to)) {
356                // Normalize to 0-1, then shift to -0.5 to 0.5 for visualization
357                // (showing flow as positive correlation)
358                let normalized = count as f32 / max_flow as f32;
359                data[i][j] = normalized;
360                // Symmetric for visualization
361                data[j][i] = normalized * 0.8; // Slightly lower for reverse
362            }
363        }
364
365        // Diagonal = 1.0 (self-correlation)
366        for (i, row) in data.iter_mut().enumerate().take(n) {
367            row[i] = 1.0;
368        }
369
370        Self {
371            data,
372            labels,
373            title: "Account Flow Correlation".to_string(),
374            gradient: HeatmapGradient::correlation(),
375            cell_size: 14.0,
376        }
377    }
378
379    /// Render the correlation heatmap.
380    pub fn show(&self, ui: &mut egui::Ui, theme: &AccNetTheme) -> Response {
381        let n = self.data.len();
382        if n == 0 {
383            let (response, _) = ui.allocate_painter(Vec2::new(100.0, 40.0), Sense::hover());
384            return response;
385        }
386
387        let label_width = 35.0;
388        let header_height = 35.0;
389        let width = label_width + n as f32 * self.cell_size + 10.0;
390        let height = header_height + n as f32 * self.cell_size + 25.0;
391
392        let (response, painter) = ui.allocate_painter(
393            Vec2::new(width.min(ui.available_width()), height),
394            Sense::hover(),
395        );
396        let rect = response.rect;
397
398        // Title
399        painter.text(
400            Pos2::new(rect.left() + 5.0, rect.top()),
401            egui::Align2::LEFT_TOP,
402            &self.title,
403            egui::FontId::proportional(11.0),
404            theme.text_secondary,
405        );
406
407        let grid_left = rect.left() + label_width;
408        let grid_top = rect.top() + header_height;
409
410        // Column headers (rotated labels)
411        for (i, label) in self.labels.iter().enumerate() {
412            let x = grid_left + (i as f32 + 0.5) * self.cell_size;
413            painter.text(
414                Pos2::new(x, grid_top - 3.0),
415                egui::Align2::CENTER_BOTTOM,
416                label,
417                egui::FontId::proportional(7.0),
418                theme.text_secondary,
419            );
420        }
421
422        // Render grid
423        for (i, row) in self.data.iter().enumerate() {
424            let y = grid_top + i as f32 * self.cell_size;
425
426            // Row label
427            if i < self.labels.len() {
428                painter.text(
429                    Pos2::new(grid_left - 2.0, y + self.cell_size / 2.0),
430                    egui::Align2::RIGHT_CENTER,
431                    &self.labels[i],
432                    egui::FontId::proportional(7.0),
433                    theme.text_secondary,
434                );
435            }
436
437            // Cells
438            for (j, &value) in row.iter().enumerate() {
439                let x = grid_left + j as f32 * self.cell_size;
440                let cell_rect = Rect::from_min_size(
441                    Pos2::new(x + 0.5, y + 0.5),
442                    Vec2::new(self.cell_size - 1.0, self.cell_size - 1.0),
443                );
444
445                let color = self.gradient.sample(value);
446                painter.rect_filled(cell_rect, 1.0, color);
447            }
448        }
449
450        response
451    }
452}
453
454/// Risk heatmap showing account risk levels by various factors.
455pub struct RiskHeatmap {
456    /// Risk data: [account][risk_factor] = risk_level (0-1).
457    pub data: Vec<Vec<f32>>,
458    /// Account labels.
459    pub account_labels: Vec<String>,
460    /// Risk factor labels.
461    pub factor_labels: Vec<String>,
462    /// Title.
463    pub title: String,
464    /// Gradient.
465    pub gradient: HeatmapGradient,
466    /// Cell size.
467    pub cell_size: f32,
468}
469
470impl RiskHeatmap {
471    /// Create from network analyzing multiple risk factors.
472    pub fn from_network(
473        network: &AccountingNetwork,
474        top_n: usize,
475        account_names: &HashMap<u16, String>,
476    ) -> Self {
477        let factors = ["Suspense", "Centrality", "Volume", "Balance", "Anomaly"];
478
479        // Get accounts with highest total risk indicators
480        let mut account_risks: Vec<(usize, f32, &crate::models::AccountNode)> = network
481            .accounts
482            .iter()
483            .enumerate()
484            .map(|(i, acc)| {
485                let suspense_risk = if acc
486                    .flags
487                    .has(crate::models::AccountFlags::IS_SUSPENSE_ACCOUNT)
488                {
489                    0.8
490                } else {
491                    0.0
492                };
493                let degree_risk = (acc.in_degree + acc.out_degree) as f32 / 100.0;
494                let total = suspense_risk + degree_risk;
495                (i, total, acc)
496            })
497            .collect();
498
499        account_risks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
500        account_risks.truncate(top_n);
501
502        let mut data = Vec::new();
503        let mut account_labels = Vec::new();
504
505        let max_degree = network
506            .accounts
507            .iter()
508            .map(|a| (a.in_degree + a.out_degree) as f32)
509            .fold(1.0f32, f32::max);
510
511        let max_volume = network
512            .accounts
513            .iter()
514            .map(|a| a.transaction_count as f32)
515            .fold(1.0f32, f32::max);
516
517        let max_balance = network
518            .accounts
519            .iter()
520            .map(|a| a.closing_balance.to_f64().abs() as f32)
521            .fold(1.0f32, f32::max);
522
523        for (_, _, acc) in account_risks {
524            let name = account_names
525                .get(&acc.index)
526                .cloned()
527                .unwrap_or_else(|| format!("#{}", acc.index));
528            // Truncate long names for display
529            let short_name: String = name.chars().take(10).collect();
530            account_labels.push(short_name);
531
532            let row = vec![
533                // Suspense risk
534                if acc
535                    .flags
536                    .has(crate::models::AccountFlags::IS_SUSPENSE_ACCOUNT)
537                {
538                    0.9
539                } else {
540                    0.1
541                },
542                // Centrality risk (high degree = higher risk)
543                ((acc.in_degree + acc.out_degree) as f32 / max_degree).min(1.0),
544                // Volume risk (high transaction count = watch closely)
545                (acc.transaction_count as f32 / max_volume).min(1.0),
546                // Balance concentration risk
547                (acc.closing_balance.to_f64().abs() as f32 / max_balance).min(1.0),
548                // Anomaly flag risk
549                if acc.flags.has(crate::models::AccountFlags::HAS_ANOMALY) {
550                    0.85
551                } else {
552                    0.15
553                },
554            ];
555            data.push(row);
556        }
557
558        Self {
559            data,
560            account_labels,
561            factor_labels: factors.iter().map(|s| s.to_string()).collect(),
562            title: "Account Risk Factors".to_string(),
563            gradient: HeatmapGradient::risk(),
564            cell_size: 20.0,
565        }
566    }
567
568    /// Render the risk heatmap.
569    pub fn show(&self, ui: &mut egui::Ui, theme: &AccNetTheme) -> Response {
570        let rows = self.data.len();
571        let cols = self.factor_labels.len();
572
573        if rows == 0 || cols == 0 {
574            let (response, _) = ui.allocate_painter(Vec2::new(100.0, 40.0), Sense::hover());
575            return response;
576        }
577
578        let label_width = 40.0;
579        let header_height = 22.0;
580        let width = label_width + cols as f32 * self.cell_size + 10.0;
581        let height = header_height + rows as f32 * self.cell_size + 25.0;
582
583        let (response, painter) = ui.allocate_painter(
584            Vec2::new(width.min(ui.available_width()), height),
585            Sense::hover(),
586        );
587        let rect = response.rect;
588
589        // Title
590        painter.text(
591            Pos2::new(rect.left() + 5.0, rect.top()),
592            egui::Align2::LEFT_TOP,
593            &self.title,
594            egui::FontId::proportional(11.0),
595            theme.text_secondary,
596        );
597
598        let grid_left = rect.left() + label_width;
599        let grid_top = rect.top() + header_height;
600
601        // Column headers
602        for (i, label) in self.factor_labels.iter().enumerate() {
603            let x = grid_left + (i as f32 + 0.5) * self.cell_size;
604            painter.text(
605                Pos2::new(x, grid_top - 2.0),
606                egui::Align2::CENTER_BOTTOM,
607                &label[..label.len().min(4)], // Truncate to 4 chars
608                egui::FontId::proportional(7.0),
609                theme.text_secondary,
610            );
611        }
612
613        // Render cells
614        for (i, row) in self.data.iter().enumerate() {
615            let y = grid_top + i as f32 * self.cell_size;
616
617            // Row label
618            if i < self.account_labels.len() {
619                painter.text(
620                    Pos2::new(grid_left - 2.0, y + self.cell_size / 2.0),
621                    egui::Align2::RIGHT_CENTER,
622                    &self.account_labels[i],
623                    egui::FontId::proportional(7.0),
624                    theme.text_secondary,
625                );
626            }
627
628            for (j, &value) in row.iter().enumerate() {
629                let x = grid_left + j as f32 * self.cell_size;
630                let cell_rect = Rect::from_min_size(
631                    Pos2::new(x + 1.0, y + 1.0),
632                    Vec2::new(self.cell_size - 2.0, self.cell_size - 2.0),
633                );
634
635                let color = self.gradient.sample(value);
636                painter.rect_filled(cell_rect, 2.0, color);
637            }
638        }
639
640        response
641    }
642}
643
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_thermal() {
        let g = HeatmapGradient::thermal();
        let c0 = g.sample(0.0);
        let c1 = g.sample(1.0);
        assert_ne!(c0, c1);
    }

    #[test]
    fn test_gradient_bounds() {
        let g = HeatmapGradient::thermal();
        // Out-of-range inputs must not panic and must clamp to the end stops.
        assert_eq!(g.sample(-0.5), g.sample(0.0));
        assert_eq!(g.sample(1.5), g.sample(1.0));
    }

    #[test]
    fn test_gradient_hits_stop_colors() {
        // Sampling exactly at a stop position returns that stop's color.
        let g = HeatmapGradient::thermal();
        assert_eq!(g.sample(0.0), Color32::from_rgb(20, 30, 60));
        assert_eq!(g.sample(0.5), Color32::from_rgb(80, 200, 100));
        assert_eq!(g.sample(1.0), Color32::from_rgb(200, 40, 40));
    }

    #[test]
    fn test_correlation_gradient_midpoint_is_neutral() {
        // Position 0.5 is the "zero correlation" gray.
        let g = HeatmapGradient::correlation();
        assert_eq!(g.sample(0.5), Color32::from_rgb(60, 60, 70));
    }
}