1use eframe::egui::{self, Color32, Pos2, Rect, Response, Sense, Stroke, Vec2};
9use std::collections::HashMap;
10
11use super::theme::AccNetTheme;
12use crate::models::{AccountType, AccountingNetwork};
13
14pub struct HeatmapGradient {
16 pub stops: Vec<(f32, Color32)>,
18}
19
20impl HeatmapGradient {
21 pub fn thermal() -> Self {
23 Self {
24 stops: vec![
25 (0.0, Color32::from_rgb(20, 30, 60)), (0.2, Color32::from_rgb(40, 80, 160)), (0.4, Color32::from_rgb(60, 180, 180)), (0.5, Color32::from_rgb(80, 200, 100)), (0.7, Color32::from_rgb(220, 200, 60)), (0.85, Color32::from_rgb(240, 120, 40)), (1.0, Color32::from_rgb(200, 40, 40)), ],
33 }
34 }
35
36 pub fn correlation() -> Self {
38 Self {
39 stops: vec![
40 (0.0, Color32::from_rgb(40, 80, 200)), (0.35, Color32::from_rgb(80, 120, 180)), (0.5, Color32::from_rgb(60, 60, 70)), (0.65, Color32::from_rgb(180, 100, 80)), (1.0, Color32::from_rgb(200, 50, 50)), ],
46 }
47 }
48
49 pub fn risk() -> Self {
51 Self {
52 stops: vec![
53 (0.0, Color32::from_rgb(60, 160, 80)), (0.3, Color32::from_rgb(120, 180, 80)), (0.5, Color32::from_rgb(200, 200, 60)), (0.7, Color32::from_rgb(220, 140, 50)), (0.85, Color32::from_rgb(200, 80, 50)), (1.0, Color32::from_rgb(180, 40, 40)), ],
60 }
61 }
62
63 pub fn sample(&self, t: f32) -> Color32 {
65 let t = t.clamp(0.0, 1.0);
66
67 let mut prev = (0.0_f32, self.stops[0].1);
69 for &(pos, color) in &self.stops {
70 if t <= pos {
71 let range = pos - prev.0;
73 if range < 0.001 {
74 return color;
75 }
76 let local_t = (t - prev.0) / range;
77 return Self::lerp_color(prev.1, color, local_t);
78 }
79 prev = (pos, color);
80 }
81 self.stops.last().map(|s| s.1).unwrap_or(Color32::WHITE)
82 }
83
84 fn lerp_color(a: Color32, b: Color32, t: f32) -> Color32 {
85 Color32::from_rgb(
86 (a.r() as f32 + (b.r() as f32 - a.r() as f32) * t) as u8,
87 (a.g() as f32 + (b.g() as f32 - a.g() as f32) * t) as u8,
88 (a.b() as f32 + (b.b() as f32 - a.b() as f32) * t) as u8,
89 )
90 }
91}
92
93pub struct ActivityHeatmap {
95 pub data: Vec<Vec<f32>>,
97 pub row_labels: Vec<String>,
99 pub col_labels: Vec<String>,
101 pub title: String,
103 pub gradient: HeatmapGradient,
105 pub cell_size: f32,
107}
108
109impl ActivityHeatmap {
110 pub fn from_network_by_type(network: &AccountingNetwork) -> Self {
112 let type_names = ["Asset", "Liability", "Equity", "Revenue", "Expense"];
113 let mut data = vec![vec![0.0f32; 10]; 5]; let mut max_activity = 1.0f32;
117 for account in &network.accounts {
118 let type_idx = match account.account_type {
119 AccountType::Asset | AccountType::Contra => 0,
120 AccountType::Liability => 1,
121 AccountType::Equity => 2,
122 AccountType::Revenue => 3,
123 AccountType::Expense => 4,
124 };
125
126 let activity = account.transaction_count as f32;
128 let bucket = (account.index as usize) % 10; data[type_idx][bucket] += activity;
130 max_activity = max_activity.max(data[type_idx][bucket]);
131 }
132
133 for row in &mut data {
135 for cell in row {
136 *cell /= max_activity.max(1.0);
137 }
138 }
139
140 Self {
141 data,
142 row_labels: type_names.iter().map(|s| s.to_string()).collect(),
143 col_labels: (1..=10).map(|i| format!("T{}", i)).collect(),
144 title: "Account Type Activity".to_string(),
145 gradient: HeatmapGradient::thermal(),
146 cell_size: 18.0,
147 }
148 }
149
150 pub fn from_network_top_accounts(network: &AccountingNetwork, top_n: usize) -> Self {
152 let mut accounts: Vec<_> = network.accounts.iter().enumerate().collect();
154 accounts.sort_by_key(|a| std::cmp::Reverse(a.1.transaction_count));
155 accounts.truncate(top_n);
156
157 let mut data = vec![vec![0.0f32; 8]; accounts.len()];
158 let mut row_labels = Vec::new();
159
160 for (row_idx, (_, account)) in accounts.iter().enumerate() {
161 row_labels.push(format!("#{}", account.index));
163
164 let base_activity = account.transaction_count as f32;
166 for (col, cell) in data[row_idx].iter_mut().enumerate().take(8) {
167 let variation = ((account.index as f32 + col as f32) * 0.7).sin() * 0.3 + 0.7;
169 *cell = base_activity * variation;
170 }
171 }
172
173 let max_val = data
175 .iter()
176 .flat_map(|row| row.iter())
177 .copied()
178 .fold(0.0f32, f32::max)
179 .max(1.0);
180
181 for row in &mut data {
182 for cell in row {
183 *cell /= max_val;
184 }
185 }
186
187 Self {
188 data,
189 row_labels,
190 col_labels: ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun", "Avg"]
191 .iter()
192 .map(|s| s.to_string())
193 .collect(),
194 title: "Top Account Activity by Day".to_string(),
195 gradient: HeatmapGradient::thermal(),
196 cell_size: 16.0,
197 }
198 }
199
200 pub fn show(&self, ui: &mut egui::Ui, theme: &AccNetTheme) -> Response {
202 let rows = self.data.len();
203 let cols = if rows > 0 { self.data[0].len() } else { 0 };
204
205 let label_width = 60.0;
206 let header_height = 20.0;
207 let width = label_width + cols as f32 * self.cell_size + 10.0;
208 let height = header_height + rows as f32 * self.cell_size + 25.0;
209
210 let (response, painter) = ui.allocate_painter(
211 Vec2::new(width.min(ui.available_width()), height),
212 Sense::hover(),
213 );
214 let rect = response.rect;
215
216 painter.text(
218 Pos2::new(rect.left() + 5.0, rect.top()),
219 egui::Align2::LEFT_TOP,
220 &self.title,
221 egui::FontId::proportional(11.0),
222 theme.text_secondary,
223 );
224
225 if rows == 0 || cols == 0 {
226 return response;
227 }
228
229 let grid_left = rect.left() + label_width;
230 let grid_top = rect.top() + header_height;
231
232 for (i, label) in self.col_labels.iter().enumerate() {
234 let x = grid_left + (i as f32 + 0.5) * self.cell_size;
235 painter.text(
236 Pos2::new(x, grid_top - 3.0),
237 egui::Align2::CENTER_BOTTOM,
238 label,
239 egui::FontId::proportional(7.0),
240 theme.text_secondary,
241 );
242 }
243
244 for (row_idx, row_data) in self.data.iter().enumerate() {
246 let y = grid_top + row_idx as f32 * self.cell_size;
247
248 if row_idx < self.row_labels.len() {
250 painter.text(
251 Pos2::new(grid_left - 3.0, y + self.cell_size / 2.0),
252 egui::Align2::RIGHT_CENTER,
253 &self.row_labels[row_idx],
254 egui::FontId::proportional(8.0),
255 theme.text_secondary,
256 );
257 }
258
259 for (col_idx, &value) in row_data.iter().enumerate() {
261 let x = grid_left + col_idx as f32 * self.cell_size;
262 let cell_rect = Rect::from_min_size(
263 Pos2::new(x + 1.0, y + 1.0),
264 Vec2::new(self.cell_size - 2.0, self.cell_size - 2.0),
265 );
266
267 let color = self.gradient.sample(value);
268 painter.rect_filled(cell_rect, 2.0, color);
269 }
270 }
271
272 let grid_rect = Rect::from_min_size(
274 Pos2::new(grid_left, grid_top),
275 Vec2::new(cols as f32 * self.cell_size, rows as f32 * self.cell_size),
276 );
277 painter.rect_stroke(
278 grid_rect,
279 0.0,
280 Stroke::new(1.0, Color32::from_rgb(60, 60, 70)),
281 );
282
283 response
284 }
285}
286
287pub struct CorrelationHeatmap {
289 pub data: Vec<Vec<f32>>,
291 pub labels: Vec<String>,
293 pub title: String,
295 pub gradient: HeatmapGradient,
297 pub cell_size: f32,
299}
300
301impl CorrelationHeatmap {
302 pub fn from_network(
304 network: &AccountingNetwork,
305 top_n: usize,
306 account_names: &HashMap<u16, String>,
307 ) -> Self {
308 let mut accounts: Vec<_> = network.accounts.iter().enumerate().collect();
310 accounts.sort_by(|a, b| {
311 let deg_a = a.1.in_degree + a.1.out_degree;
312 let deg_b = b.1.in_degree + b.1.out_degree;
313 deg_b.cmp(°_a)
314 });
315 accounts.truncate(top_n);
316
317 let n = accounts.len();
318 let mut data = vec![vec![0.0f32; n]; n];
319 let mut labels = Vec::new();
320
321 let index_map: HashMap<u16, usize> = accounts
323 .iter()
324 .enumerate()
325 .map(|(i, (_, acc))| (acc.index, i))
326 .collect();
327
328 for (_, acc) in &accounts {
329 let name = account_names
330 .get(&acc.index)
331 .cloned()
332 .unwrap_or_else(|| format!("#{}", acc.index));
333 let short_name: String = name.chars().take(8).collect();
335 labels.push(short_name);
336 }
337
338 let mut flow_counts: HashMap<(u16, u16), usize> = HashMap::new();
340 let mut max_flow = 1usize;
341
342 for flow in &network.flows {
343 if index_map.contains_key(&flow.source_account_index)
344 && index_map.contains_key(&flow.target_account_index)
345 {
346 let key = (flow.source_account_index, flow.target_account_index);
347 let count = flow_counts.entry(key).or_insert(0);
348 *count += 1;
349 max_flow = max_flow.max(*count);
350 }
351 }
352
353 for ((from, to), count) in flow_counts {
355 if let (Some(&i), Some(&j)) = (index_map.get(&from), index_map.get(&to)) {
356 let normalized = count as f32 / max_flow as f32;
359 data[i][j] = normalized;
360 data[j][i] = normalized * 0.8; }
363 }
364
365 for (i, row) in data.iter_mut().enumerate().take(n) {
367 row[i] = 1.0;
368 }
369
370 Self {
371 data,
372 labels,
373 title: "Account Flow Correlation".to_string(),
374 gradient: HeatmapGradient::correlation(),
375 cell_size: 14.0,
376 }
377 }
378
379 pub fn show(&self, ui: &mut egui::Ui, theme: &AccNetTheme) -> Response {
381 let n = self.data.len();
382 if n == 0 {
383 let (response, _) = ui.allocate_painter(Vec2::new(100.0, 40.0), Sense::hover());
384 return response;
385 }
386
387 let label_width = 35.0;
388 let header_height = 35.0;
389 let width = label_width + n as f32 * self.cell_size + 10.0;
390 let height = header_height + n as f32 * self.cell_size + 25.0;
391
392 let (response, painter) = ui.allocate_painter(
393 Vec2::new(width.min(ui.available_width()), height),
394 Sense::hover(),
395 );
396 let rect = response.rect;
397
398 painter.text(
400 Pos2::new(rect.left() + 5.0, rect.top()),
401 egui::Align2::LEFT_TOP,
402 &self.title,
403 egui::FontId::proportional(11.0),
404 theme.text_secondary,
405 );
406
407 let grid_left = rect.left() + label_width;
408 let grid_top = rect.top() + header_height;
409
410 for (i, label) in self.labels.iter().enumerate() {
412 let x = grid_left + (i as f32 + 0.5) * self.cell_size;
413 painter.text(
414 Pos2::new(x, grid_top - 3.0),
415 egui::Align2::CENTER_BOTTOM,
416 label,
417 egui::FontId::proportional(7.0),
418 theme.text_secondary,
419 );
420 }
421
422 for (i, row) in self.data.iter().enumerate() {
424 let y = grid_top + i as f32 * self.cell_size;
425
426 if i < self.labels.len() {
428 painter.text(
429 Pos2::new(grid_left - 2.0, y + self.cell_size / 2.0),
430 egui::Align2::RIGHT_CENTER,
431 &self.labels[i],
432 egui::FontId::proportional(7.0),
433 theme.text_secondary,
434 );
435 }
436
437 for (j, &value) in row.iter().enumerate() {
439 let x = grid_left + j as f32 * self.cell_size;
440 let cell_rect = Rect::from_min_size(
441 Pos2::new(x + 0.5, y + 0.5),
442 Vec2::new(self.cell_size - 1.0, self.cell_size - 1.0),
443 );
444
445 let color = self.gradient.sample(value);
446 painter.rect_filled(cell_rect, 1.0, color);
447 }
448 }
449
450 response
451 }
452}
453
454pub struct RiskHeatmap {
456 pub data: Vec<Vec<f32>>,
458 pub account_labels: Vec<String>,
460 pub factor_labels: Vec<String>,
462 pub title: String,
464 pub gradient: HeatmapGradient,
466 pub cell_size: f32,
468}
469
470impl RiskHeatmap {
471 pub fn from_network(
473 network: &AccountingNetwork,
474 top_n: usize,
475 account_names: &HashMap<u16, String>,
476 ) -> Self {
477 let factors = ["Suspense", "Centrality", "Volume", "Balance", "Anomaly"];
478
479 let mut account_risks: Vec<(usize, f32, &crate::models::AccountNode)> = network
481 .accounts
482 .iter()
483 .enumerate()
484 .map(|(i, acc)| {
485 let suspense_risk = if acc
486 .flags
487 .has(crate::models::AccountFlags::IS_SUSPENSE_ACCOUNT)
488 {
489 0.8
490 } else {
491 0.0
492 };
493 let degree_risk = (acc.in_degree + acc.out_degree) as f32 / 100.0;
494 let total = suspense_risk + degree_risk;
495 (i, total, acc)
496 })
497 .collect();
498
499 account_risks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
500 account_risks.truncate(top_n);
501
502 let mut data = Vec::new();
503 let mut account_labels = Vec::new();
504
505 let max_degree = network
506 .accounts
507 .iter()
508 .map(|a| (a.in_degree + a.out_degree) as f32)
509 .fold(1.0f32, f32::max);
510
511 let max_volume = network
512 .accounts
513 .iter()
514 .map(|a| a.transaction_count as f32)
515 .fold(1.0f32, f32::max);
516
517 let max_balance = network
518 .accounts
519 .iter()
520 .map(|a| a.closing_balance.to_f64().abs() as f32)
521 .fold(1.0f32, f32::max);
522
523 for (_, _, acc) in account_risks {
524 let name = account_names
525 .get(&acc.index)
526 .cloned()
527 .unwrap_or_else(|| format!("#{}", acc.index));
528 let short_name: String = name.chars().take(10).collect();
530 account_labels.push(short_name);
531
532 let row = vec![
533 if acc
535 .flags
536 .has(crate::models::AccountFlags::IS_SUSPENSE_ACCOUNT)
537 {
538 0.9
539 } else {
540 0.1
541 },
542 ((acc.in_degree + acc.out_degree) as f32 / max_degree).min(1.0),
544 (acc.transaction_count as f32 / max_volume).min(1.0),
546 (acc.closing_balance.to_f64().abs() as f32 / max_balance).min(1.0),
548 if acc.flags.has(crate::models::AccountFlags::HAS_ANOMALY) {
550 0.85
551 } else {
552 0.15
553 },
554 ];
555 data.push(row);
556 }
557
558 Self {
559 data,
560 account_labels,
561 factor_labels: factors.iter().map(|s| s.to_string()).collect(),
562 title: "Account Risk Factors".to_string(),
563 gradient: HeatmapGradient::risk(),
564 cell_size: 20.0,
565 }
566 }
567
568 pub fn show(&self, ui: &mut egui::Ui, theme: &AccNetTheme) -> Response {
570 let rows = self.data.len();
571 let cols = self.factor_labels.len();
572
573 if rows == 0 || cols == 0 {
574 let (response, _) = ui.allocate_painter(Vec2::new(100.0, 40.0), Sense::hover());
575 return response;
576 }
577
578 let label_width = 40.0;
579 let header_height = 22.0;
580 let width = label_width + cols as f32 * self.cell_size + 10.0;
581 let height = header_height + rows as f32 * self.cell_size + 25.0;
582
583 let (response, painter) = ui.allocate_painter(
584 Vec2::new(width.min(ui.available_width()), height),
585 Sense::hover(),
586 );
587 let rect = response.rect;
588
589 painter.text(
591 Pos2::new(rect.left() + 5.0, rect.top()),
592 egui::Align2::LEFT_TOP,
593 &self.title,
594 egui::FontId::proportional(11.0),
595 theme.text_secondary,
596 );
597
598 let grid_left = rect.left() + label_width;
599 let grid_top = rect.top() + header_height;
600
601 for (i, label) in self.factor_labels.iter().enumerate() {
603 let x = grid_left + (i as f32 + 0.5) * self.cell_size;
604 painter.text(
605 Pos2::new(x, grid_top - 2.0),
606 egui::Align2::CENTER_BOTTOM,
607 &label[..label.len().min(4)], egui::FontId::proportional(7.0),
609 theme.text_secondary,
610 );
611 }
612
613 for (i, row) in self.data.iter().enumerate() {
615 let y = grid_top + i as f32 * self.cell_size;
616
617 if i < self.account_labels.len() {
619 painter.text(
620 Pos2::new(grid_left - 2.0, y + self.cell_size / 2.0),
621 egui::Align2::RIGHT_CENTER,
622 &self.account_labels[i],
623 egui::FontId::proportional(7.0),
624 theme.text_secondary,
625 );
626 }
627
628 for (j, &value) in row.iter().enumerate() {
629 let x = grid_left + j as f32 * self.cell_size;
630 let cell_rect = Rect::from_min_size(
631 Pos2::new(x + 1.0, y + 1.0),
632 Vec2::new(self.cell_size - 2.0, self.cell_size - 2.0),
633 );
634
635 let color = self.gradient.sample(value);
636 painter.rect_filled(cell_rect, 2.0, color);
637 }
638 }
639
640 response
641 }
642}
643
#[cfg(test)]
mod tests {
    use super::*;

    /// The thermal ramp's endpoints are exactly its first and last stops.
    #[test]
    fn test_gradient_thermal() {
        let g = HeatmapGradient::thermal();
        let c0 = g.sample(0.0);
        let c1 = g.sample(1.0);
        assert_eq!(c0, Color32::from_rgb(20, 30, 60));
        assert_eq!(c1, Color32::from_rgb(200, 40, 40));
        assert_ne!(c0, c1);
    }

    /// Out-of-range inputs are clamped to the nearest end rather than
    /// extrapolated, and they must not panic.
    #[test]
    fn test_gradient_bounds() {
        let g = HeatmapGradient::thermal();
        assert_eq!(g.sample(-0.5), g.sample(0.0));
        assert_eq!(g.sample(1.5), g.sample(1.0));
    }

    /// Sampling exactly on an interior stop returns that stop's colour, and a
    /// midpoint between two stops is their channel-wise average.
    #[test]
    fn test_gradient_interpolation() {
        let g = HeatmapGradient::thermal();
        assert_eq!(g.sample(0.5), Color32::from_rgb(80, 200, 100));
        assert_eq!(g.sample(0.1), Color32::from_rgb(30, 55, 110));
    }
}