Skip to main content

ringkernel_accnet/gui/
layout.rs

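//! Force-directed graph layout algorithm.
//!
//! Implements a Fruchterman-Reingold-style simulation (pairwise repulsion,
//! edge attraction, centre gravity, and a cooling "temperature" that bounds
//! per-step movement) for smooth, aesthetically pleasing network layouts.
//! Repulsion is computed exactly over every node pair (O(n²) per step); no
//! Barnes-Hut approximation is used.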
5
6use crate::models::{AccountType, AccountingNetwork};
7use nalgebra::Vector2;
8use std::collections::HashMap;
9
10/// A node in the layout.
11#[derive(Debug, Clone)]
12pub struct LayoutNode {
13    /// Account index.
14    pub index: u16,
15    /// Account type.
16    pub account_type: AccountType,
17    /// Current position.
18    pub position: Vector2<f32>,
19    /// Current velocity.
20    pub velocity: Vector2<f32>,
21    /// Whether position is pinned.
22    pub pinned: bool,
23    /// Node mass (affects repulsion).
24    pub mass: f32,
25}
26
/// Tunable parameters for the force-directed layout engine.
#[derive(Clone, Debug)]
pub struct LayoutConfig {
    /// Pairwise repulsion constant; larger values push nodes further apart.
    pub repulsion: f32,
    /// Spring constant along edges; larger values pull connected nodes closer.
    pub attraction: f32,
    /// Velocity retention factor in [0, 1], multiplied into every node's
    /// velocity each step. Values closer to 1 keep more momentum (less
    /// friction); values closer to 0 bleed motion off faster.
    pub damping: f32,
    /// Convergence threshold: once no node moves farther than this in a single
    /// step (and the temperature has cooled below 1) the layout is settled.
    pub min_velocity: f32,
    /// Maximum iterations per frame. The layout engine itself does not read
    /// this; presumably the caller uses it to decide how many steps to run.
    pub max_iterations: usize,
    /// Ideal edge length. Not currently read by the force computation.
    pub ideal_length: f32,
    /// Canvas width.
    pub width: f32,
    /// Canvas height.
    pub height: f32,
    /// Strength of the pull toward the canvas centre (scaled by node mass).
    pub gravity: f32,
    /// When `true`, nodes start on a circle grouped by account type;
    /// otherwise they start at random positions.
    pub group_by_type: bool,
    /// Separation below which pairwise repulsion is strongly amplified.
    pub min_node_distance: f32,
}

impl Default for LayoutConfig {
    /// Defaults tuned for a sparse, widely spaced layout on an 800×600 canvas.
    fn default() -> Self {
        LayoutConfig {
            // Canvas and initial arrangement.
            width: 800.0,
            height: 600.0,
            group_by_type: true,
            // Forces: very strong repulsion plus a large exclusion radius keep
            // nodes far apart; gentle attraction and gravity only loosely tie
            // connected nodes together and to the centre.
            repulsion: 50000.0,
            min_node_distance: 300.0,
            attraction: 0.0008,
            gravity: 0.008,
            ideal_length: 300.0,
            // Integration and convergence.
            damping: 0.85,
            min_velocity: 0.3,
            max_iterations: 50,
        }
    }
}
71
72/// Force-directed graph layout engine.
73pub struct ForceDirectedLayout {
74    /// Configuration.
75    pub config: LayoutConfig,
76    /// Layout nodes.
77    nodes: HashMap<u16, LayoutNode>,
78    /// Edges (source, target, weight).
79    edges: Vec<(u16, u16, f32)>,
80    /// Whether layout has converged.
81    pub converged: bool,
82    /// Current temperature (for simulated annealing).
83    temperature: f32,
84}
85
86impl ForceDirectedLayout {
87    /// Create a new layout engine.
88    pub fn new(config: LayoutConfig) -> Self {
89        Self {
90            temperature: config.width / 10.0,
91            config,
92            nodes: HashMap::new(),
93            edges: Vec::new(),
94            converged: false,
95        }
96    }
97
98    /// Initialize layout from accounting network.
99    pub fn initialize(&mut self, network: &AccountingNetwork) {
100        self.nodes.clear();
101        self.edges.clear();
102        self.converged = false;
103        self.temperature = self.config.width / 10.0;
104
105        // Create nodes with initial positions
106        let center = Vector2::new(self.config.width / 2.0, self.config.height / 2.0);
107
108        for (i, account) in network.accounts.iter().enumerate() {
109            // Initial position based on account type (circular layout by type)
110            let angle = self.type_angle(account.account_type) + (i as f32 * 0.1);
111            let radius = self.config.width.min(self.config.height) * 0.3;
112
113            let position = if self.config.group_by_type {
114                center + Vector2::new(angle.cos() * radius, angle.sin() * radius)
115            } else {
116                // Random initial position
117                center
118                    + Vector2::new(
119                        (rand::random::<f32>() - 0.5) * self.config.width * 0.8,
120                        (rand::random::<f32>() - 0.5) * self.config.height * 0.8,
121                    )
122            };
123
124            self.nodes.insert(
125                account.index,
126                LayoutNode {
127                    index: account.index,
128                    account_type: account.account_type,
129                    position,
130                    velocity: Vector2::zeros(),
131                    pinned: false,
132                    mass: 1.0 + (account.risk_score * 2.0),
133                },
134            );
135        }
136
137        // Create edges from flows
138        for flow in &network.flows {
139            let weight = flow.amount.to_f64().abs() as f32;
140            self.edges
141                .push((flow.source_account_index, flow.target_account_index, weight));
142        }
143    }
144
145    /// Get base angle for account type grouping.
146    fn type_angle(&self, account_type: AccountType) -> f32 {
147        use std::f32::consts::PI;
148        match account_type {
149            AccountType::Asset => 0.0,
150            AccountType::Liability => PI * 2.0 / 5.0,
151            AccountType::Equity => PI * 4.0 / 5.0,
152            AccountType::Revenue => PI * 6.0 / 5.0,
153            AccountType::Expense => PI * 8.0 / 5.0,
154            AccountType::Contra => PI,
155        }
156    }
157
158    /// Run one iteration of the layout algorithm.
159    pub fn step(&mut self) -> bool {
160        if self.converged || self.nodes.is_empty() {
161            return false;
162        }
163
164        let mut forces: HashMap<u16, Vector2<f32>> = HashMap::new();
165        for &idx in self.nodes.keys() {
166            forces.insert(idx, Vector2::zeros());
167        }
168
169        // Calculate repulsion forces (all pairs) with minimum distance enforcement
170        let node_indices: Vec<u16> = self.nodes.keys().copied().collect();
171        let min_dist = self.config.min_node_distance;
172
173        for i in 0..node_indices.len() {
174            for j in (i + 1)..node_indices.len() {
175                let idx_i = node_indices[i];
176                let idx_j = node_indices[j];
177
178                let node_i = &self.nodes[&idx_i];
179                let node_j = &self.nodes[&idx_j];
180
181                let delta = node_j.position - node_i.position;
182                let distance = delta.magnitude().max(1.0);
183
184                // Strong repulsion when below minimum distance
185                let force_magnitude = if distance < min_dist {
186                    // Much stronger force when too close - push apart aggressively
187                    let overlap = min_dist / distance;
188                    (self.config.repulsion * overlap * overlap / (distance * distance)).min(200.0)
189                } else {
190                    // Normal Fruchterman-Reingold repulsion
191                    (self.config.repulsion / (distance * distance)).min(100.0)
192                };
193
194                // Safe normalization
195                let force = if distance > 0.01 {
196                    delta / distance * force_magnitude
197                } else {
198                    // Random direction if too close
199                    Vector2::new(
200                        (rand::random::<f32>() - 0.5) * force_magnitude,
201                        (rand::random::<f32>() - 0.5) * force_magnitude,
202                    )
203                };
204
205                if let Some(f) = forces.get_mut(&idx_i) {
206                    *f -= force;
207                }
208                if let Some(f) = forces.get_mut(&idx_j) {
209                    *f += force;
210                }
211            }
212        }
213
214        // Calculate attraction forces (edges)
215        for (source, target, weight) in &self.edges {
216            if let (Some(node_s), Some(node_t)) = (self.nodes.get(source), self.nodes.get(target)) {
217                let delta = node_t.position - node_s.position;
218                let distance = delta.magnitude().max(1.0);
219
220                // Attraction proportional to distance, with weight factor capped
221                let weight_factor = 1.0 + weight.abs().max(1.0).ln();
222                let force_magnitude = (self.config.attraction * distance * weight_factor).min(50.0);
223
224                // Safe normalization
225                let force = if distance > 0.01 {
226                    delta / distance * force_magnitude
227                } else {
228                    Vector2::zeros()
229                };
230
231                if let Some(f) = forces.get_mut(source) {
232                    *f += force;
233                }
234                if let Some(f) = forces.get_mut(target) {
235                    *f -= force;
236                }
237            }
238        }
239
240        // Apply gravity toward center
241        let center = Vector2::new(self.config.width / 2.0, self.config.height / 2.0);
242        for (idx, node) in &self.nodes {
243            let delta = center - node.position;
244            let gravity_force = delta * self.config.gravity * node.mass;
245            if let Some(f) = forces.get_mut(idx) {
246                *f += gravity_force;
247            }
248        }
249
250        // Apply forces and update positions
251        let mut max_displacement = 0.0f32;
252
253        for (idx, force) in &forces {
254            if let Some(node) = self.nodes.get_mut(idx) {
255                if node.pinned {
256                    continue;
257                }
258
259                // Skip if force is NaN or Inf
260                if !force.x.is_finite() || !force.y.is_finite() {
261                    continue;
262                }
263
264                // Update velocity with force and damping
265                let new_velocity =
266                    (node.velocity + *force / node.mass.max(0.1)) * self.config.damping;
267
268                // Skip if velocity becomes NaN
269                if !new_velocity.x.is_finite() || !new_velocity.y.is_finite() {
270                    node.velocity = Vector2::zeros();
271                    continue;
272                }
273
274                node.velocity = new_velocity;
275
276                // Limit velocity by temperature
277                let speed = node.velocity.magnitude();
278                if speed > self.temperature {
279                    node.velocity = node.velocity / speed * self.temperature;
280                }
281
282                // Cap maximum velocity
283                let max_speed = 50.0;
284                if speed > max_speed {
285                    node.velocity = node.velocity / speed * max_speed;
286                }
287
288                // Update position
289                let displacement = node.velocity;
290                node.position += displacement;
291
292                // Keep within bounds with padding
293                let padding = 50.0;
294                node.position.x = node.position.x.clamp(padding, self.config.width - padding);
295                node.position.y = node.position.y.clamp(padding, self.config.height - padding);
296
297                max_displacement = max_displacement.max(displacement.magnitude());
298            }
299        }
300
301        // Cool down
302        self.temperature *= 0.995;
303
304        // Check convergence
305        if max_displacement < self.config.min_velocity && self.temperature < 1.0 {
306            self.converged = true;
307        }
308
309        true
310    }
311
312    /// Run multiple iterations.
313    pub fn iterate(&mut self, iterations: usize) {
314        for _ in 0..iterations {
315            if !self.step() {
316                break;
317            }
318        }
319    }
320
321    /// Get node position.
322    pub fn get_position(&self, index: u16) -> Option<Vector2<f32>> {
323        self.nodes.get(&index).map(|n| n.position)
324    }
325
326    /// Get all nodes.
327    pub fn nodes(&self) -> impl Iterator<Item = &LayoutNode> {
328        self.nodes.values()
329    }
330
331    /// Get all edges.
332    pub fn edges(&self) -> &[(u16, u16, f32)] {
333        &self.edges
334    }
335
336    /// Pin a node at its current position.
337    pub fn pin_node(&mut self, index: u16) {
338        if let Some(node) = self.nodes.get_mut(&index) {
339            node.pinned = true;
340        }
341    }
342
343    /// Unpin a node.
344    pub fn unpin_node(&mut self, index: u16) {
345        if let Some(node) = self.nodes.get_mut(&index) {
346            node.pinned = false;
347        }
348    }
349
350    /// Set node position (for dragging).
351    pub fn set_position(&mut self, index: u16, position: Vector2<f32>) {
352        if let Some(node) = self.nodes.get_mut(&index) {
353            node.position = position;
354            node.velocity = Vector2::zeros();
355        }
356    }
357
358    /// Resize the layout area.
359    pub fn resize(&mut self, width: f32, height: f32) {
360        let scale_x = width / self.config.width;
361        let scale_y = height / self.config.height;
362
363        for node in self.nodes.values_mut() {
364            node.position.x *= scale_x;
365            node.position.y *= scale_y;
366        }
367
368        self.config.width = width;
369        self.config.height = height;
370    }
371
372    /// Reset the layout.
373    pub fn reset(&mut self, network: &AccountingNetwork) {
374        self.initialize(network);
375    }
376
377    /// Update edge weights from network without resetting positions.
378    /// This allows the layout to adapt to new flow data while preserving node positions.
379    pub fn update_edges(&mut self, network: &AccountingNetwork) {
380        // Aggregate edges by source-target pair
381        use std::collections::HashMap as StdHashMap;
382        let mut edge_weights: StdHashMap<(u16, u16), f32> = StdHashMap::new();
383
384        for flow in &network.flows {
385            let key = (flow.source_account_index, flow.target_account_index);
386            *edge_weights.entry(key).or_insert(0.0) += flow.amount.to_f64().abs() as f32;
387        }
388
389        // Convert to edge list
390        self.edges = edge_weights
391            .into_iter()
392            .map(|((s, t), w)| (s, t, w))
393            .collect();
394
395        // Add any missing nodes
396        for account in &network.accounts {
397            if !self.nodes.contains_key(&account.index) {
398                let center = Vector2::new(self.config.width / 2.0, self.config.height / 2.0);
399                let angle = self.type_angle(account.account_type);
400                let radius = self.config.width.min(self.config.height) * 0.3;
401                let position = center + Vector2::new(angle.cos() * radius, angle.sin() * radius);
402
403                self.nodes.insert(
404                    account.index,
405                    LayoutNode {
406                        index: account.index,
407                        account_type: account.account_type,
408                        position,
409                        velocity: Vector2::zeros(),
410                        pinned: false,
411                        mass: 1.0 + (account.risk_score * 2.0),
412                    },
413                );
414            }
415        }
416    }
417
418    /// Warm up the layout to allow it to readjust.
419    /// Call this periodically to let the layout adapt to weight changes.
420    pub fn warm_up(&mut self) {
421        self.converged = false;
422        self.temperature = (self.config.width / 20.0).max(20.0); // Moderate warm-up
423    }
424
425    /// Check if layout has any edges.
426    pub fn has_edges(&self) -> bool {
427        !self.edges.is_empty()
428    }
429
430    /// Get edge count.
431    pub fn edge_count(&self) -> usize {
432        self.edges.len()
433    }
434}
435
436impl Default for ForceDirectedLayout {
437    fn default() -> Self {
438        Self::new(LayoutConfig::default())
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed engine has no nodes and is not yet converged.
    #[test]
    fn test_layout_creation() {
        let layout = ForceDirectedLayout::default();
        assert!(!layout.converged);
        assert!(layout.nodes.is_empty());
    }

    /// The default configuration uses positive force constants.
    #[test]
    fn test_layout_config() {
        let LayoutConfig {
            repulsion,
            attraction,
            ..
        } = LayoutConfig::default();
        assert!(repulsion > 0.0);
        assert!(attraction > 0.0);
    }
}