use super::NodeBounds;
use crate::graph::Shape;
use crate::graph::attachment::{
Face as RoutingFace, classify_face_float as shared_classify_face_float,
};
use crate::graph::space::{FPoint, FRect};
/// Minimum spacing between neighbouring attachment points that
/// `spread_points_on_face` aims for. A face whose range is shorter than
/// `(count - 1) * MIN_ATTACHMENT_GAP` is treated as "too narrow" and handled by
/// that function's fallback placement rules.
const MIN_ATTACHMENT_GAP: usize = 2;
/// One of the four sides of a node's bounding box, used when choosing and
/// placing edge attachment points.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeFace {
    /// The side at the smallest y (`bounds.y`).
    Top,
    /// The side at the largest y (`bounds.y + bounds.height - 1`).
    Bottom,
    /// The side at the smallest x (`bounds.x`).
    Left,
    /// The side at the largest x (`bounds.x + bounds.width - 1`).
    Right,
}
/// Returns the inclusive `(start, end)` range of the coordinate that varies
/// along `face`.
///
/// - `Top`/`Bottom`: x values from `bounds.x + 1` to `bounds.x + bounds.width - 2`,
///   so the two corner columns are excluded.
/// - `Left`/`Right`: y values from `bounds.y` to `bounds.y + bounds.height - 1`,
///   so the full height, corner rows included, is used.
///
/// Subtractions saturate at 0 and `end` is raised to `start` when necessary, so
/// the returned range is never inverted, even for tiny boxes.
pub fn face_extent(bounds: &NodeBounds, face: &NodeFace) -> (usize, usize) {
    let (first, last) = if matches!(face, NodeFace::Left | NodeFace::Right) {
        (bounds.y, bounds.y + bounds.height.saturating_sub(1))
    } else {
        (bounds.x + 1, (bounds.x + bounds.width).saturating_sub(2))
    };
    (first, first.max(last))
}
/// Returns the coordinate shared by every point on `face`: the y coordinate for
/// `Top`/`Bottom`, the x coordinate for `Left`/`Right`.
///
/// `Top`/`Left` use the box origin. `Bottom`/`Right` use `origin + size - 1`,
/// where the subtraction saturates so a zero-sized box collapses onto its origin.
pub fn face_fixed_coord(bounds: &NodeBounds, face: &NodeFace) -> usize {
    let (origin, size) = match face {
        NodeFace::Top | NodeFace::Bottom => (bounds.y, bounds.height),
        NodeFace::Left | NodeFace::Right => (bounds.x, bounds.width),
    };
    match face {
        NodeFace::Top | NodeFace::Left => origin,
        NodeFace::Bottom | NodeFace::Right => origin + size.saturating_sub(1),
    }
}
/// Chooses which face of `bounds` an edge arriving from `approach_point` should
/// attach to.
///
/// The decision is delegated to the shared floating-point classifier
/// (`crate::graph::attachment::classify_face_float`). It is given the node's
/// centre, its bounding rectangle and the approach point, and its answer is
/// mapped onto [`NodeFace`].
///
/// `_shape` is currently unused: every shape is classified against its
/// bounding rectangle.
pub fn classify_face(
    bounds: &NodeBounds,
    approach_point: (usize, usize),
    _shape: Shape,
) -> NodeFace {
    let center = FPoint::new(bounds.center_x() as f64, bounds.center_y() as f64);
    let rect = FRect::new(
        bounds.x as f64,
        bounds.y as f64,
        bounds.width as f64,
        bounds.height as f64,
    );
    let (ax, ay) = approach_point;
    let approach = FPoint::new(ax as f64, ay as f64);

    let routing_face = shared_classify_face_float(center, rect, approach);
    match routing_face {
        RoutingFace::Top => NodeFace::Top,
        RoutingFace::Bottom => NodeFace::Bottom,
        RoutingFace::Left => NodeFace::Left,
        RoutingFace::Right => NodeFace::Right,
    }
}
/// Distributes `count` attachment points along one face of a node.
///
/// `fixed_coord` is the coordinate shared by every point on the face (y for
/// `Top`/`Bottom`, x for `Left`/`Right`). `extent` is the inclusive
/// `(start, end)` range of the coordinate that varies along the face. Points
/// are returned as `(x, y)` pairs in increasing order of the varying coordinate.
///
/// Placement rules:
/// - `count == 0` yields no points; `count == 1` yields the midpoint of `extent`.
/// - If the face is wide enough for neighbours to be at least
///   `MIN_ATTACHMENT_GAP` apart, the points are spread evenly from `start` to
///   `end` (integer spacing, rounded down).
/// - Otherwise, on `Top`/`Bottom` the points are placed exactly
///   `MIN_ATTACHMENT_GAP` apart, centred on the face (clamped at 0). They may
///   extend past `extent`.
/// - Otherwise, on `Left`/`Right` the even spread is kept unless it would put
///   two points on the same position. In that case the range is widened by one
///   on each side (not below 0) before spreading.
pub fn spread_points_on_face(
    face: NodeFace,
    fixed_coord: usize,
    extent: (usize, usize),
    count: usize,
) -> Vec<(usize, usize)> {
    if count == 0 {
        return vec![];
    }
    let (start, end) = extent;
    let range = end.saturating_sub(start);
    let to_point = |pos: usize| match face {
        NodeFace::Top | NodeFace::Bottom => (pos, fixed_coord),
        NodeFace::Left | NodeFace::Right => (fixed_coord, pos),
    };
    if count == 1 {
        return vec![to_point(start + range / 2)];
    }

    let needed_span = (count - 1) * MIN_ATTACHMENT_GAP;
    let positions = if range >= needed_span {
        // Even spacing already gives every neighbour a gap of at least
        // floor(range / (count - 1)) >= MIN_ATTACHMENT_GAP. The last point lands
        // exactly on `end`, so no gap or overshoot correction is needed.
        evenly_spaced_positions(start, range, count)
    } else if matches!(face, NodeFace::Top | NodeFace::Bottom) {
        // Too narrow for the minimum gap: keep the gap and let the points spill
        // past the face, centred on it (saturating at 0).
        let center = start + range / 2;
        let extended_start = center.saturating_sub(needed_span / 2);
        evenly_spaced_positions(extended_start, needed_span, count)
    } else if count > range + 1 {
        // Vertical faces tolerate tighter spacing, but with more points than
        // distinct positions an even spread would stack points. Widen the range
        // by one above and below instead.
        let capped_start = start.saturating_sub(1);
        let capped_end = end + 1;
        let capped_range = capped_end.saturating_sub(capped_start);
        evenly_spaced_positions(capped_start, capped_range, count)
            .into_iter()
            // Only bites when `extent` is inverted by more than 2.
            .map(|pos| pos.min(capped_end))
            .collect()
    } else {
        // Vertical face with enough room for distinct positions: plain even spread.
        evenly_spaced_positions(start, range, count)
    };
    positions.into_iter().map(to_point).collect()
}

/// Returns `count` positions spread from `start` to `start + span` inclusive,
/// with spacing rounded down to integers. Requires `count >= 2`.
fn evenly_spaced_positions(start: usize, span: usize, count: usize) -> Vec<usize> {
    debug_assert!(count >= 2, "evenly_spaced_positions needs at least two points");
    (0..count).map(|i| start + (i * span) / (count - 1)).collect()
}
/// A point with `f64` coordinates, used for the boundary-intersection math
/// before results are rounded back to integer `(usize, usize)` coordinates
/// via `to_usize`.
#[derive(Debug, Clone, Copy)]
pub struct FloatPoint {
    /// Horizontal coordinate.
    pub x: f64,
    /// Vertical coordinate.
    pub y: f64,
}
impl FloatPoint {
pub fn new(x: f64, y: f64) -> Self {
Self { x, y }
}
pub fn to_usize(self) -> (usize, usize) {
(self.x.round() as usize, self.y.round() as usize)
}
}
impl From<(usize, usize)> for FloatPoint {
fn from((x, y): (usize, usize)) -> Self {
Self {
x: x as f64,
y: y as f64,
}
}
}
/// Intersects the ray from the centre of `bounds` towards `point` with the
/// node's bounding rectangle and returns the hit point.
///
/// The rectangle is `width` x `height`, centred on
/// `(bounds.center_x(), bounds.center_y())`. If `point` coincides with the
/// centre (within `f64::EPSILON` on both axes), `(cx, cy + height / 2)` is
/// returned.
///
/// Degenerate boxes never produce NaN. A zero-width box hit by a purely
/// vertical ray returns the midpoint of its top/bottom edge.
pub fn intersect_rect(bounds: &NodeBounds, point: FloatPoint) -> FloatPoint {
    let x = bounds.center_x() as f64;
    let y = bounds.center_y() as f64;
    let dx = point.x - x;
    let dy = point.y - y;
    let w = bounds.width as f64 / 2.0;
    let h = bounds.height as f64 / 2.0;
    if dx.abs() < f64::EPSILON && dy.abs() < f64::EPSILON {
        return FloatPoint::new(x, y + h);
    }
    // The ray leaves through the top/bottom edge when it is steeper than the
    // box diagonal. A purely vertical ray (`dx == 0`) always does. Checking it
    // explicitly avoids the `0 > 0` tie for a zero-width box, which would
    // otherwise take the left/right branch and compute `0.0 * dy / 0.0` = NaN.
    let exits_top_or_bottom = dx == 0.0 || dy.abs() * w > dx.abs() * h;
    let (sx, sy) = if exits_top_or_bottom {
        // dy != 0 here: either dx == 0 and the centre check ruled out dy == 0,
        // or |dy| * w > 0.
        let h = if dy < 0.0 { -h } else { h };
        (h * dx / dy, h)
    } else {
        // dx != 0 here, because dx == 0 always takes the branch above.
        let w = if dx < 0.0 { -w } else { w };
        (w, w * dy / dx)
    };
    FloatPoint::new(x + sx, y + sy)
}
/// Intersects the ray from the centre of `bounds` towards `point` with the
/// diamond inscribed in the node's bounding box and returns the hit point.
///
/// The diamond's vertices are at the midpoints of the box edges. Its boundary
/// satisfies `|dx| / w + |dy| / h = 1` (with `w`, `h` the half-width and
/// half-height), so the ray `centre + t * (dx, dy)` meets it at
/// `t = 1 / (|dx| / w + |dy| / h)`. If `point` coincides with the centre
/// (within `f64::EPSILON` on both axes), `(cx, cy + height / 2)` is returned.
///
/// An axis whose offset is exactly zero contributes 0 to the denominator. A
/// zero-width or zero-height box therefore no longer yields `0.0 / 0.0` = NaN
/// for a ray along its degenerate axis. A non-zero offset along a zero-size
/// axis gives `t = 0`, i.e. the centre.
pub fn intersect_diamond(bounds: &NodeBounds, point: FloatPoint) -> FloatPoint {
    let x = bounds.center_x() as f64;
    let y = bounds.center_y() as f64;
    let dx = point.x - x;
    let dy = point.y - y;
    let w = bounds.width as f64 / 2.0;
    let h = bounds.height as f64 / 2.0;
    if dx.abs() < f64::EPSILON && dy.abs() < f64::EPSILON {
        return FloatPoint::new(x, y + h);
    }
    // For a non-degenerate box `0.0 / half == 0.0`, so the guard only changes
    // the result when `half` is 0 as well (where the division would be NaN).
    let axis_term = |offset: f64, half: f64| {
        if offset == 0.0 {
            0.0
        } else {
            offset.abs() / half
        }
    };
    let t = 1.0 / (axis_term(dx, w) + axis_term(dy, h));
    FloatPoint::new(x + t * dx, y + t * dy)
}
/// Returns the integer point where the ray from the centre of `bounds` towards
/// `point` crosses the node's outline.
///
/// `Diamond` and `Hexagon` nodes use the inscribed-diamond outline; every other
/// shape uses the bounding rectangle. The floating-point hit is rounded with
/// `FloatPoint::to_usize`.
pub fn intersect_node(bounds: &NodeBounds, point: (usize, usize), shape: Shape) -> (usize, usize) {
    let target = FloatPoint::from(point);
    let boundary_hit = if matches!(shape, Shape::Diamond | Shape::Hexagon) {
        intersect_diamond(bounds, target)
    } else {
        intersect_rect(bounds, target)
    };
    boundary_hit.to_usize()
}
/// Computes the `(source, target)` attachment points for an edge between two
/// nodes.
///
/// With waypoints, the source end aims at the first waypoint and the target end
/// aims at the last one. Without waypoints, each end aims at the other node's
/// centre. Each aim point is projected onto the node outline with
/// [`intersect_node`].
pub fn calculate_attachment_points(
    source_bounds: &NodeBounds,
    source_shape: Shape,
    target_bounds: &NodeBounds,
    target_shape: Shape,
    waypoints: &[(usize, usize)],
) -> ((usize, usize), (usize, usize)) {
    let source_center = (source_bounds.center_x(), source_bounds.center_y());
    let target_center = (target_bounds.center_x(), target_bounds.center_y());
    let source_aim = waypoints.first().copied().unwrap_or(target_center);
    let target_aim = waypoints.last().copied().unwrap_or(source_center);
    (
        intersect_node(source_bounds, source_aim, source_shape),
        intersect_node(target_bounds, target_aim, target_shape),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 10x5 box with its top-left corner at (10, 5), shared by most tests.
    fn test_bounds() -> NodeBounds {
        NodeBounds {
            x: 10,
            y: 5,
            width: 10,
            height: 5,
            layout_center_x: None,
            layout_center_y: None,
        }
    }

    #[test]
    fn test_intersect_rect_from_below() {
        let bounds = test_bounds();
        let point = FloatPoint::new(15.0, 20.0);
        let result = intersect_rect(&bounds, point);
        assert_eq!(result.x.round() as usize, 15);
        assert_eq!(result.y.round() as usize, 10);
    }

    #[test]
    fn test_intersect_rect_from_above() {
        let bounds = test_bounds();
        let point = FloatPoint::new(15.0, 0.0);
        let result = intersect_rect(&bounds, point);
        assert_eq!(result.x.round() as usize, 15);
        assert!(result.y < bounds.center_y() as f64);
    }

    #[test]
    fn test_intersect_rect_from_right() {
        let bounds = test_bounds();
        let point = FloatPoint::new(30.0, 7.5);
        let result = intersect_rect(&bounds, point);
        assert!(result.x > bounds.center_x() as f64);
        assert_eq!(result.y.round() as usize, bounds.center_y());
    }

    #[test]
    fn test_intersect_rect_from_left() {
        let bounds = test_bounds();
        let point = FloatPoint::new(0.0, 7.5);
        let result = intersect_rect(&bounds, point);
        assert!(result.x < bounds.center_x() as f64);
        assert_eq!(result.y.round() as usize, bounds.center_y());
    }

    #[test]
    fn test_intersect_rect_diagonal() {
        let bounds = test_bounds();
        let point = FloatPoint::new(25.0, 15.0);
        let result = intersect_rect(&bounds, point);
        let on_right = (result.x - (bounds.x + bounds.width) as f64).abs() < 1.0;
        let on_bottom = (result.y - (bounds.y + bounds.height) as f64).abs() < 1.0;
        assert!(on_right || on_bottom);
    }

    #[test]
    fn test_intersect_diamond_from_below() {
        let bounds = test_bounds();
        let point = FloatPoint::new(15.0, 20.0);
        let result = intersect_diamond(&bounds, point);
        assert_eq!(result.x.round() as usize, bounds.center_x());
        assert!(result.y > bounds.center_y() as f64);
    }

    #[test]
    fn test_intersect_diamond_from_right() {
        let bounds = test_bounds();
        let point = FloatPoint::new(30.0, 7.5);
        let result = intersect_diamond(&bounds, point);
        assert!(result.x > bounds.center_x() as f64);
        assert_eq!(result.y.round() as usize, bounds.center_y());
    }

    #[test]
    fn test_intersect_node_rectangle() {
        let bounds = test_bounds();
        let point = (15, 20);
        let result = intersect_node(&bounds, point, Shape::Rectangle);
        assert!(result.1 >= bounds.y);
        assert!(result.1 <= bounds.y + bounds.height);
    }

    #[test]
    fn test_intersect_node_diamond() {
        let bounds = test_bounds();
        let point = (15, 20);
        let result = intersect_node(&bounds, point, Shape::Diamond);
        assert!(result.1 >= bounds.y);
        assert!(result.1 <= bounds.y + bounds.height);
    }

    #[test]
    fn test_calculate_attachment_points_direct() {
        let source = NodeBounds {
            x: 10,
            y: 5,
            width: 10,
            height: 3,
            layout_center_x: None,
            layout_center_y: None,
        };
        let target = NodeBounds {
            x: 10,
            y: 15,
            width: 10,
            height: 3,
            layout_center_x: None,
            layout_center_y: None,
        };
        let (src_attach, tgt_attach) =
            calculate_attachment_points(&source, Shape::Rectangle, &target, Shape::Rectangle, &[]);
        assert!(src_attach.1 > source.y);
        assert!(tgt_attach.1 < target.y + target.height);
    }

    #[test]
    fn test_calculate_attachment_points_with_waypoints() {
        let source = NodeBounds {
            x: 10,
            y: 5,
            width: 10,
            height: 3,
            layout_center_x: None,
            layout_center_y: None,
        };
        let target = NodeBounds {
            x: 30,
            y: 15,
            width: 10,
            height: 3,
            layout_center_x: None,
            layout_center_y: None,
        };
        let waypoints = [(20, 10), (25, 12)];
        let (src_attach, tgt_attach) = calculate_attachment_points(
            &source,
            Shape::Rectangle,
            &target,
            Shape::Rectangle,
            &waypoints,
        );
        assert!(src_attach.0 >= source.x && src_attach.0 <= source.x + source.width);
        assert!(tgt_attach.0 >= target.x && tgt_attach.0 <= target.x + target.width);
    }

    #[test]
    fn test_intersect_diamond_from_above() {
        let bounds = test_bounds();
        let point = FloatPoint::new(15.0, 0.0);
        let result = intersect_diamond(&bounds, point);
        assert_eq!(result.x.round() as usize, bounds.center_x());
        assert!(result.y < bounds.center_y() as f64);
    }

    #[test]
    fn test_intersect_diamond_from_left() {
        let bounds = test_bounds();
        let point = FloatPoint::new(0.0, 7.5);
        let result = intersect_diamond(&bounds, point);
        assert!(result.x < bounds.center_x() as f64);
        assert_eq!(result.y.round() as usize, bounds.center_y());
    }

    #[test]
    fn test_intersect_diamond_diagonal() {
        let bounds = test_bounds();
        let point = FloatPoint::new(25.0, 15.0);
        let result = intersect_diamond(&bounds, point);
        let center_x = bounds.center_x() as f64;
        let center_y = bounds.center_y() as f64;
        let dx = (result.x - center_x).abs();
        let dy = (result.y - center_y).abs();
        let w = bounds.width as f64 / 2.0;
        let h = bounds.height as f64 / 2.0;
        let boundary_check = dx / w + dy / h;
        assert!(
            (boundary_check - 1.0).abs() < 0.1,
            "Point should be on diamond boundary, got {}",
            boundary_check
        );
    }

    #[test]
    fn test_intersect_rect_point_at_center() {
        let bounds = test_bounds();
        let point = FloatPoint::new(bounds.center_x() as f64, bounds.center_y() as f64);
        let result = intersect_rect(&bounds, point);
        assert_eq!(result.x.round() as usize, bounds.center_x());
    }

    #[test]
    fn test_intersect_diamond_point_at_center() {
        let bounds = test_bounds();
        let point = FloatPoint::new(bounds.center_x() as f64, bounds.center_y() as f64);
        let result = intersect_diamond(&bounds, point);
        assert_eq!(result.x.round() as usize, bounds.center_x());
    }

    #[test]
    fn test_intersect_node_round_uses_rect() {
        let bounds = test_bounds();
        let point = (15, 20);
        let rect_result = intersect_node(&bounds, point, Shape::Rectangle);
        let round_result = intersect_node(&bounds, point, Shape::Round);
        assert_eq!(rect_result, round_result);
    }

    #[test]
    fn test_calculate_attachment_points_diamond_source() {
        let source = NodeBounds {
            x: 10,
            y: 5,
            width: 10,
            height: 5,
            layout_center_x: None,
            layout_center_y: None,
        };
        let target = NodeBounds {
            x: 10,
            y: 20,
            width: 10,
            height: 3,
            layout_center_x: None,
            layout_center_y: None,
        };
        let (src_attach, tgt_attach) =
            calculate_attachment_points(&source, Shape::Diamond, &target, Shape::Rectangle, &[]);
        assert_eq!(src_attach.0, source.center_x());
        assert!(tgt_attach.1 < target.y + target.height);
    }

    #[test]
    fn test_float_point_to_usize() {
        let p = FloatPoint::new(10.4, 20.6);
        let (x, y) = p.to_usize();
        assert_eq!(x, 10);
        assert_eq!(y, 21);
    }

    #[test]
    fn test_float_point_from_tuple() {
        let p = FloatPoint::from((15_usize, 25_usize));
        assert_eq!(p.x, 15.0);
        assert_eq!(p.y, 25.0);
    }

    #[test]
    fn test_classify_face_from_above() {
        let bounds = test_bounds();
        let result = classify_face(&bounds, (15, 0), Shape::Rectangle);
        assert_eq!(result, NodeFace::Top);
    }

    #[test]
    fn test_classify_face_from_below() {
        let bounds = test_bounds();
        let result = classify_face(&bounds, (15, 20), Shape::Rectangle);
        assert_eq!(result, NodeFace::Bottom);
    }

    #[test]
    fn test_classify_face_from_left() {
        let bounds = test_bounds();
        let result = classify_face(&bounds, (0, 7), Shape::Rectangle);
        assert_eq!(result, NodeFace::Left);
    }

    #[test]
    fn test_classify_face_from_right() {
        let bounds = test_bounds();
        let result = classify_face(&bounds, (30, 7), Shape::Rectangle);
        assert_eq!(result, NodeFace::Right);
    }

    #[test]
    fn test_classify_face_degenerate_center() {
        let bounds = test_bounds();
        let result = classify_face(&bounds, (15, 7), Shape::Rectangle);
        assert_eq!(result, NodeFace::Bottom);
    }

    #[test]
    fn test_classify_face_diamond_same_as_rect() {
        let bounds = test_bounds();
        assert_eq!(
            classify_face(&bounds, (15, 0), Shape::Diamond),
            NodeFace::Top
        );
        assert_eq!(
            classify_face(&bounds, (15, 20), Shape::Diamond),
            NodeFace::Bottom
        );
    }

    #[test]
    fn test_face_extent_horizontal_excludes_corners() {
        let bounds = test_bounds();
        assert_eq!(face_extent(&bounds, &NodeFace::Top), (11, 18));
        assert_eq!(face_extent(&bounds, &NodeFace::Bottom), (11, 18));
    }

    #[test]
    fn test_face_extent_vertical_includes_corners() {
        let bounds = test_bounds();
        assert_eq!(face_extent(&bounds, &NodeFace::Left), (5, 9));
        assert_eq!(face_extent(&bounds, &NodeFace::Right), (5, 9));
    }

    #[test]
    fn test_face_fixed_coord_each_face() {
        let bounds = test_bounds();
        assert_eq!(face_fixed_coord(&bounds, &NodeFace::Top), 5);
        assert_eq!(face_fixed_coord(&bounds, &NodeFace::Bottom), 9);
        assert_eq!(face_fixed_coord(&bounds, &NodeFace::Left), 10);
        assert_eq!(face_fixed_coord(&bounds, &NodeFace::Right), 19);
    }

    #[test]
    fn test_spread_points_count_zero() {
        let result = spread_points_on_face(NodeFace::Top, 5, (2, 10), 0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_spread_points_count_one() {
        let result = spread_points_on_face(NodeFace::Bottom, 5, (2, 10), 1);
        assert_eq!(result, vec![(6, 5)]);
    }

    #[test]
    fn test_spread_points_endpoint_maximizing() {
        let points = spread_points_on_face(NodeFace::Top, 5, (10, 13), 2);
        assert_eq!(points, vec![(10, 5), (13, 5)]);
    }

    #[test]
    fn test_spread_points_two_on_wide_range() {
        let result = spread_points_on_face(NodeFace::Top, 0, (0, 8), 2);
        assert_eq!(result, vec![(0, 0), (8, 0)]);
    }

    #[test]
    fn test_spread_points_count_three() {
        let result = spread_points_on_face(NodeFace::Bottom, 10, (0, 8), 3);
        assert_eq!(result, vec![(0, 10), (4, 10), (8, 10)]);
    }

    #[test]
    fn test_spread_points_left_right_face() {
        let result = spread_points_on_face(NodeFace::Left, 5, (0, 8), 2);
        assert_eq!(result[0], (5, 0));
        assert_eq!(result[1], (5, 8));
    }

    #[test]
    fn test_spread_points_narrow_range() {
        let result = spread_points_on_face(NodeFace::Top, 0, (0, 2), 3);
        assert_eq!(result, vec![(0, 0), (2, 0), (4, 0)]);
        let xs: Vec<usize> = result.iter().map(|&(x, _)| x).collect();
        for w in xs.windows(2) {
            assert!(w[1] - w[0] >= 2);
        }
    }

    #[test]
    fn test_spread_points_horizontal_narrow_is_centered() {
        let result = spread_points_on_face(NodeFace::Top, 1, (10, 12), 4);
        assert_eq!(result, vec![(8, 1), (10, 1), (12, 1), (14, 1)]);
    }

    #[test]
    fn test_spread_points_vertical_face_widens_when_crowded() {
        let result = spread_points_on_face(NodeFace::Left, 5, (3, 4), 4);
        assert_eq!(result, vec![(5, 2), (5, 3), (5, 4), (5, 5)]);
    }

    #[test]
    fn test_spread_points_vertical_tolerates_tight_spacing() {
        let result = spread_points_on_face(NodeFace::Right, 7, (2, 5), 3);
        assert_eq!(result, vec![(7, 2), (7, 3), (7, 5)]);
    }

    #[test]
    fn test_spread_points_min_gap_sufficient_range() {
        let points = spread_points_on_face(NodeFace::Top, 0, (0, 7), 4);
        let xs: Vec<usize> = points.iter().map(|&(x, _)| x).collect();
        for w in xs.windows(2) {
            assert!(
                w[1] - w[0] >= 2,
                "gap too small between {} and {}",
                w[0],
                w[1]
            );
        }
    }

    #[test]
    fn test_spread_points_min_gap_insufficient_range() {
        let points = spread_points_on_face(NodeFace::Top, 0, (0, 5), 4);
        let xs: Vec<usize> = points.iter().map(|&(x, _)| x).collect();
        assert_eq!(xs, vec![0, 2, 4, 6]);
        for w in xs.windows(2) {
            assert!(w[1] - w[0] >= 2, "gap too small: {} to {}", w[0], w[1]);
        }
    }

    #[test]
    fn test_spread_points_min_gap_barely_sufficient() {
        let points = spread_points_on_face(NodeFace::Top, 0, (0, 8), 4);
        let xs: Vec<usize> = points.iter().map(|&(x, _)| x).collect();
        assert_eq!(xs, vec![0, 2, 5, 8]);
        for w in xs.windows(2) {
            assert!(w[1] - w[0] >= 2, "gap too small: {} to {}", w[0], w[1]);
        }
    }

    #[test]
    fn test_spread_points_min_gap_three_on_three() {
        let points = spread_points_on_face(NodeFace::Top, 0, (0, 3), 3);
        let xs: Vec<usize> = points.iter().map(|&(x, _)| x).collect();
        assert_eq!(xs.len(), 3);
        for w in xs.windows(2) {
            assert!(w[1] - w[0] >= 2, "gap too small: {} to {}", w[0], w[1]);
        }
    }

    #[test]
    fn test_spread_points_min_gap_wide_face() {
        let points = spread_points_on_face(NodeFace::Top, 0, (0, 20), 3);
        assert_eq!(points, vec![(0, 0), (10, 0), (20, 0)]);
    }

    #[test]
    fn test_spread_points_min_gap_exact_fit() {
        let points = spread_points_on_face(NodeFace::Top, 0, (0, 4), 3);
        assert_eq!(points, vec![(0, 0), (2, 0), (4, 0)]);
    }

    #[test]
    fn test_spread_points_five_on_wide_face() {
        let result = spread_points_on_face(NodeFace::Bottom, 0, (0, 12), 5);
        assert_eq!(result, vec![(0, 0), (3, 0), (6, 0), (9, 0), (12, 0)]);
    }
}