1use crate::models::{AccountType, AccountingNetwork};
7use nalgebra::Vector2;
8use std::collections::HashMap;
9
10#[derive(Debug, Clone)]
12pub struct LayoutNode {
13 pub index: u16,
15 pub account_type: AccountType,
17 pub position: Vector2<f32>,
19 pub velocity: Vector2<f32>,
21 pub pinned: bool,
23 pub mass: f32,
25}
26
/// Tuning parameters for the force-directed layout.
#[derive(Clone, Debug)]
pub struct LayoutConfig {
    /// Node-to-node repulsion coefficient; repulsion falls off with the squared distance.
    pub repulsion: f32,
    /// Edge spring coefficient; attraction grows linearly with edge length.
    pub attraction: f32,
    /// Multiplier applied to every node's velocity each step (values below 1 bleed energy).
    pub damping: f32,
    /// Convergence threshold on the largest per-step node movement.
    pub min_velocity: f32,
    /// Iteration budget. NOTE(review): not read by the layout code in this module.
    pub max_iterations: usize,
    /// Preferred edge length. NOTE(review): not read by the layout code in this module.
    pub ideal_length: f32,
    /// Canvas width; node positions are kept inside it.
    pub width: f32,
    /// Canvas height; node positions are kept inside it.
    pub height: f32,
    /// Strength of the pull toward the canvas center (scaled by node mass).
    pub gravity: f32,
    /// Start nodes clustered by account type (`true`) or at random positions (`false`).
    pub group_by_type: bool,
    /// Separation below which repulsion is amplified to push overlapping nodes apart.
    pub min_node_distance: f32,
}

impl Default for LayoutConfig {
    /// An 800 x 600 canvas with strong repulsion and weak springs and gravity.
    fn default() -> Self {
        let (width, height) = (800.0, 600.0);
        Self {
            repulsion: 50_000.0,
            attraction: 0.000_8,
            damping: 0.85,
            min_velocity: 0.3,
            max_iterations: 50,
            ideal_length: 300.0,
            width,
            height,
            gravity: 0.008,
            group_by_type: true,
            min_node_distance: 300.0,
        }
    }
}
71
72pub struct ForceDirectedLayout {
74 pub config: LayoutConfig,
76 nodes: HashMap<u16, LayoutNode>,
78 edges: Vec<(u16, u16, f32)>,
80 pub converged: bool,
82 temperature: f32,
84}
85
86impl ForceDirectedLayout {
87 pub fn new(config: LayoutConfig) -> Self {
89 Self {
90 temperature: config.width / 10.0,
91 config,
92 nodes: HashMap::new(),
93 edges: Vec::new(),
94 converged: false,
95 }
96 }
97
98 pub fn initialize(&mut self, network: &AccountingNetwork) {
100 self.nodes.clear();
101 self.edges.clear();
102 self.converged = false;
103 self.temperature = self.config.width / 10.0;
104
105 let center = Vector2::new(self.config.width / 2.0, self.config.height / 2.0);
107
108 for (i, account) in network.accounts.iter().enumerate() {
109 let angle = self.type_angle(account.account_type) + (i as f32 * 0.1);
111 let radius = self.config.width.min(self.config.height) * 0.3;
112
113 let position = if self.config.group_by_type {
114 center + Vector2::new(angle.cos() * radius, angle.sin() * radius)
115 } else {
116 center
118 + Vector2::new(
119 (rand::random::<f32>() - 0.5) * self.config.width * 0.8,
120 (rand::random::<f32>() - 0.5) * self.config.height * 0.8,
121 )
122 };
123
124 self.nodes.insert(
125 account.index,
126 LayoutNode {
127 index: account.index,
128 account_type: account.account_type,
129 position,
130 velocity: Vector2::zeros(),
131 pinned: false,
132 mass: 1.0 + (account.risk_score * 2.0),
133 },
134 );
135 }
136
137 for flow in &network.flows {
139 let weight = flow.amount.to_f64().abs() as f32;
140 self.edges
141 .push((flow.source_account_index, flow.target_account_index, weight));
142 }
143 }
144
145 fn type_angle(&self, account_type: AccountType) -> f32 {
147 use std::f32::consts::PI;
148 match account_type {
149 AccountType::Asset => 0.0,
150 AccountType::Liability => PI * 2.0 / 5.0,
151 AccountType::Equity => PI * 4.0 / 5.0,
152 AccountType::Revenue => PI * 6.0 / 5.0,
153 AccountType::Expense => PI * 8.0 / 5.0,
154 AccountType::Contra => PI,
155 }
156 }
157
158 pub fn step(&mut self) -> bool {
160 if self.converged || self.nodes.is_empty() {
161 return false;
162 }
163
164 let mut forces: HashMap<u16, Vector2<f32>> = HashMap::new();
165 for &idx in self.nodes.keys() {
166 forces.insert(idx, Vector2::zeros());
167 }
168
169 let node_indices: Vec<u16> = self.nodes.keys().copied().collect();
171 let min_dist = self.config.min_node_distance;
172
173 for i in 0..node_indices.len() {
174 for j in (i + 1)..node_indices.len() {
175 let idx_i = node_indices[i];
176 let idx_j = node_indices[j];
177
178 let node_i = &self.nodes[&idx_i];
179 let node_j = &self.nodes[&idx_j];
180
181 let delta = node_j.position - node_i.position;
182 let distance = delta.magnitude().max(1.0);
183
184 let force_magnitude = if distance < min_dist {
186 let overlap = min_dist / distance;
188 (self.config.repulsion * overlap * overlap / (distance * distance)).min(200.0)
189 } else {
190 (self.config.repulsion / (distance * distance)).min(100.0)
192 };
193
194 let force = if distance > 0.01 {
196 delta / distance * force_magnitude
197 } else {
198 Vector2::new(
200 (rand::random::<f32>() - 0.5) * force_magnitude,
201 (rand::random::<f32>() - 0.5) * force_magnitude,
202 )
203 };
204
205 if let Some(f) = forces.get_mut(&idx_i) {
206 *f -= force;
207 }
208 if let Some(f) = forces.get_mut(&idx_j) {
209 *f += force;
210 }
211 }
212 }
213
214 for (source, target, weight) in &self.edges {
216 if let (Some(node_s), Some(node_t)) = (self.nodes.get(source), self.nodes.get(target)) {
217 let delta = node_t.position - node_s.position;
218 let distance = delta.magnitude().max(1.0);
219
220 let weight_factor = 1.0 + weight.abs().max(1.0).ln();
222 let force_magnitude = (self.config.attraction * distance * weight_factor).min(50.0);
223
224 let force = if distance > 0.01 {
226 delta / distance * force_magnitude
227 } else {
228 Vector2::zeros()
229 };
230
231 if let Some(f) = forces.get_mut(source) {
232 *f += force;
233 }
234 if let Some(f) = forces.get_mut(target) {
235 *f -= force;
236 }
237 }
238 }
239
240 let center = Vector2::new(self.config.width / 2.0, self.config.height / 2.0);
242 for (idx, node) in &self.nodes {
243 let delta = center - node.position;
244 let gravity_force = delta * self.config.gravity * node.mass;
245 if let Some(f) = forces.get_mut(idx) {
246 *f += gravity_force;
247 }
248 }
249
250 let mut max_displacement = 0.0f32;
252
253 for (idx, force) in &forces {
254 if let Some(node) = self.nodes.get_mut(idx) {
255 if node.pinned {
256 continue;
257 }
258
259 if !force.x.is_finite() || !force.y.is_finite() {
261 continue;
262 }
263
264 let new_velocity =
266 (node.velocity + *force / node.mass.max(0.1)) * self.config.damping;
267
268 if !new_velocity.x.is_finite() || !new_velocity.y.is_finite() {
270 node.velocity = Vector2::zeros();
271 continue;
272 }
273
274 node.velocity = new_velocity;
275
276 let speed = node.velocity.magnitude();
278 if speed > self.temperature {
279 node.velocity = node.velocity / speed * self.temperature;
280 }
281
282 let max_speed = 50.0;
284 if speed > max_speed {
285 node.velocity = node.velocity / speed * max_speed;
286 }
287
288 let displacement = node.velocity;
290 node.position += displacement;
291
292 let padding = 50.0;
294 node.position.x = node.position.x.clamp(padding, self.config.width - padding);
295 node.position.y = node.position.y.clamp(padding, self.config.height - padding);
296
297 max_displacement = max_displacement.max(displacement.magnitude());
298 }
299 }
300
301 self.temperature *= 0.995;
303
304 if max_displacement < self.config.min_velocity && self.temperature < 1.0 {
306 self.converged = true;
307 }
308
309 true
310 }
311
312 pub fn iterate(&mut self, iterations: usize) {
314 for _ in 0..iterations {
315 if !self.step() {
316 break;
317 }
318 }
319 }
320
321 pub fn get_position(&self, index: u16) -> Option<Vector2<f32>> {
323 self.nodes.get(&index).map(|n| n.position)
324 }
325
326 pub fn nodes(&self) -> impl Iterator<Item = &LayoutNode> {
328 self.nodes.values()
329 }
330
331 pub fn edges(&self) -> &[(u16, u16, f32)] {
333 &self.edges
334 }
335
336 pub fn pin_node(&mut self, index: u16) {
338 if let Some(node) = self.nodes.get_mut(&index) {
339 node.pinned = true;
340 }
341 }
342
343 pub fn unpin_node(&mut self, index: u16) {
345 if let Some(node) = self.nodes.get_mut(&index) {
346 node.pinned = false;
347 }
348 }
349
350 pub fn set_position(&mut self, index: u16, position: Vector2<f32>) {
352 if let Some(node) = self.nodes.get_mut(&index) {
353 node.position = position;
354 node.velocity = Vector2::zeros();
355 }
356 }
357
358 pub fn resize(&mut self, width: f32, height: f32) {
360 let scale_x = width / self.config.width;
361 let scale_y = height / self.config.height;
362
363 for node in self.nodes.values_mut() {
364 node.position.x *= scale_x;
365 node.position.y *= scale_y;
366 }
367
368 self.config.width = width;
369 self.config.height = height;
370 }
371
372 pub fn reset(&mut self, network: &AccountingNetwork) {
374 self.initialize(network);
375 }
376
377 pub fn update_edges(&mut self, network: &AccountingNetwork) {
380 use std::collections::HashMap as StdHashMap;
382 let mut edge_weights: StdHashMap<(u16, u16), f32> = StdHashMap::new();
383
384 for flow in &network.flows {
385 let key = (flow.source_account_index, flow.target_account_index);
386 *edge_weights.entry(key).or_insert(0.0) += flow.amount.to_f64().abs() as f32;
387 }
388
389 self.edges = edge_weights
391 .into_iter()
392 .map(|((s, t), w)| (s, t, w))
393 .collect();
394
395 for account in &network.accounts {
397 if !self.nodes.contains_key(&account.index) {
398 let center = Vector2::new(self.config.width / 2.0, self.config.height / 2.0);
399 let angle = self.type_angle(account.account_type);
400 let radius = self.config.width.min(self.config.height) * 0.3;
401 let position = center + Vector2::new(angle.cos() * radius, angle.sin() * radius);
402
403 self.nodes.insert(
404 account.index,
405 LayoutNode {
406 index: account.index,
407 account_type: account.account_type,
408 position,
409 velocity: Vector2::zeros(),
410 pinned: false,
411 mass: 1.0 + (account.risk_score * 2.0),
412 },
413 );
414 }
415 }
416 }
417
418 pub fn warm_up(&mut self) {
421 self.converged = false;
422 self.temperature = (self.config.width / 20.0).max(20.0); }
424
425 pub fn has_edges(&self) -> bool {
427 !self.edges.is_empty()
428 }
429
430 pub fn edge_count(&self) -> usize {
432 self.edges.len()
433 }
434}
435
436impl Default for ForceDirectedLayout {
437 fn default() -> Self {
438 Self::new(LayoutConfig::default())
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed layout is empty and not yet converged.
    #[test]
    fn test_layout_creation() {
        let layout = ForceDirectedLayout::default();
        assert!(!layout.converged);
        assert!(layout.nodes.is_empty());
    }

    /// The default configuration has positive repulsion and attraction coefficients.
    #[test]
    fn test_layout_config() {
        let LayoutConfig {
            repulsion,
            attraction,
            ..
        } = LayoutConfig::default();
        assert!(repulsion > 0.0);
        assert!(attraction > 0.0);
    }
}