use std::collections::HashMap;
use crate::graph::Direction;
use crate::graph::space::{FPoint, FRect};
/// One of four sides (top, bottom, left, right) that has a stable lowercase
/// string name.
///
/// The name is produced by [`PortFace::as_str`] and parsed back by the
/// `FromStr` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PortFace {
    Top,
    Bottom,
    Left,
    Right,
}

impl PortFace {
    /// Returns the lowercase name of this face: `"top"`, `"bottom"`,
    /// `"left"` or `"right"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            PortFace::Top => "top",
            PortFace::Bottom => "bottom",
            PortFace::Left => "left",
            PortFace::Right => "right",
        }
    }
}

/// Parses exactly the names returned by [`PortFace::as_str`]; matching is
/// case-sensitive and does not trim whitespace.
impl std::str::FromStr for PortFace {
    /// Parsing carries no diagnostic: any unrecognized input is `Err(())`.
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Looking the name up through `as_str` keeps both directions of the
        // conversion driven by a single table of names.
        const FACES: [PortFace; 4] = [
            PortFace::Top,
            PortFace::Bottom,
            PortFace::Left,
            PortFace::Right,
        ];
        FACES
            .iter()
            .copied()
            .find(|face| face.as_str() == s)
            .ok_or(())
    }
}
/// Port record for one end of an edge: a face, a fraction, a position and a
/// group size.
///
/// NOTE(review): nothing in this part of the file constructs an `EdgePort`
/// outside the tests, so the field descriptions below are inferred from the
/// field names — confirm against the code that produces these values.
#[derive(Debug, Clone, PartialEq)]
pub struct EdgePort {
/// Face the port is on.
pub face: PortFace,
/// Presumably the relative position along `face` (0.0..=1.0) — TODO confirm.
pub fraction: f64,
/// Presumably the resolved point for the port — TODO confirm coordinate space.
pub position: FPoint,
/// Presumably the number of ports sharing the same node face — TODO confirm.
pub group_size: usize,
}
/// One of four sides (top, bottom, left, right), as used by the face
/// selection and attachment planning helpers in this module.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum Face {
    Top,
    Bottom,
    Left,
    Right,
}

impl Face {
    /// Converts this face into the [`PortFace`] variant of the same name.
    pub fn to_port_face(self) -> PortFace {
        match self {
            Self::Top => PortFace::Top,
            Self::Bottom => PortFace::Bottom,
            Self::Left => PortFace::Left,
            Self::Right => PortFace::Right,
        }
    }
}
/// Direction-independent name for one of the two overflow slots.
///
/// `fan_in_overflow_face_for_slot` turns a slot into a concrete `Face` for a
/// given layout direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverflowSide {
/// Resolves to the left face in top-down/bottom-top layouts and to the top
/// face in left-right/right-left layouts.
LeftOrTop,
/// Resolves to the right face in top-down/bottom-top layouts and to the
/// bottom face in left-right/right-left layouts.
RightOrBottom,
}
/// Value returned by `fan_in_primary_face_capacity` for top-down and
/// bottom-top layouts.
pub const FAN_IN_PRIMARY_FACE_CAPACITY_TD_BT: usize = 4;
/// Value returned by `fan_in_primary_face_capacity` for left-right and
/// right-left layouts.
pub const FAN_IN_PRIMARY_FACE_CAPACITY_LR_RL: usize = 2;
/// Rank span at or above which `prefer_backward_side_channel` returns `true`
/// for a backward edge even when layout waypoints exist.
pub const BACKWARD_SIDE_CHANNEL_LONG_RANK_SPAN: usize = 6;
/// NOTE(review): not referenced in this part of the file. The name suggests a
/// cutoff for "large" horizontal offsets; units and users are not visible
/// here — confirm at the call sites.
pub const LARGE_HORIZONTAL_OFFSET_THRESHOLD: usize = 30;
/// Smallest width and height the source and target rects may have for
/// `can_apply_td_bt_backward_hint_parity` to return `true`.
const TD_BT_PARITY_MIN_RECT_SPAN: f64 = 20.0;
/// Returns the fan-in capacity of the primary face for a layout direction.
///
/// Top-down and bottom-top layouts use
/// [`FAN_IN_PRIMARY_FACE_CAPACITY_TD_BT`]; left-right and right-left layouts
/// use [`FAN_IN_PRIMARY_FACE_CAPACITY_LR_RL`].
///
/// NOTE(review): presumably the number of incoming edges the primary face
/// accepts before overflow faces are used — confirm with callers.
pub fn fan_in_primary_face_capacity(direction: Direction) -> usize {
    let vertical_flow = match direction {
        Direction::TopDown | Direction::BottomTop => true,
        Direction::LeftRight | Direction::RightLeft => false,
    };
    if vertical_flow {
        FAN_IN_PRIMARY_FACE_CAPACITY_TD_BT
    } else {
        FAN_IN_PRIMARY_FACE_CAPACITY_LR_RL
    }
}
/// Maps an overflow slot to a concrete face for the given layout direction.
///
/// In top-down and bottom-top layouts the slots are the left and right
/// faces; in left-right and right-left layouts they are the top and bottom
/// faces.
pub fn fan_in_overflow_face_for_slot(direction: Direction, slot: OverflowSide) -> Face {
    let vertical_flow = match direction {
        Direction::TopDown | Direction::BottomTop => true,
        Direction::LeftRight | Direction::RightLeft => false,
    };
    match (vertical_flow, slot) {
        (true, OverflowSide::LeftOrTop) => Face::Left,
        (true, OverflowSide::RightOrBottom) => Face::Right,
        (false, OverflowSide::LeftOrTop) => Face::Top,
        (false, OverflowSide::RightOrBottom) => Face::Bottom,
    }
}
/// Returns the canonical face for the backward-edge channel.
///
/// Left-right and right-left layouts use the bottom face; top-down and
/// bottom-top layouts use the right face.
pub fn canonical_backward_channel_face(direction: Direction) -> Face {
    match direction {
        Direction::LeftRight | Direction::RightLeft => Face::Bottom,
        Direction::TopDown | Direction::BottomTop => Face::Right,
    }
}
/// Returns the primary target face for fan-in in the given layout direction:
/// top for top-down, bottom for bottom-top, left for left-right and right
/// for right-left.
pub fn fan_in_primary_target_face(direction: Direction) -> Face {
    // Split the direction into its axis and whether the flow runs against
    // that axis, then pick the face from the pair.
    let (vertical_flow, reversed) = match direction {
        Direction::TopDown => (true, false),
        Direction::BottomTop => (true, true),
        Direction::LeftRight => (false, false),
        Direction::RightLeft => (false, true),
    };
    match (vertical_flow, reversed) {
        (true, false) => Face::Top,
        (true, true) => Face::Bottom,
        (false, false) => Face::Left,
        (false, true) => Face::Right,
    }
}
/// Returns the overflow face that is *not* the canonical backward channel
/// face: left for top-down/bottom-top layouts (the canonical face is right)
/// and top for left-right/right-left layouts (the canonical face is bottom).
fn fan_in_non_canonical_overflow_face(direction: Direction) -> Face {
    match direction {
        Direction::TopDown => Face::Left,
        Direction::BottomTop => Face::Left,
        Direction::LeftRight => Face::Top,
        Direction::RightLeft => Face::Top,
    }
}
/// Picks the final face for an edge whose fan-in overflow placement may
/// collide with the backward-edge channel.
///
/// Outcomes, checked in this order:
///
/// 1. `overflow_face` is `None`: `proposed_face` is returned unchanged.
/// 2. The edge is backward: the canonical backward channel face for
///    `direction` is returned.
/// 3. The edge is forward: if `target_has_backward_conflict` is set and
///    `proposed_face` is the canonical backward channel face, the
///    non-canonical overflow face is returned; otherwise `proposed_face`.
///
/// Only the presence of `overflow_face` is inspected; its value is unused.
///
/// NOTE(review): a backward edge with no overflow face keeps `proposed_face`
/// rather than the canonical channel face — confirm that this is intended.
pub fn resolve_overflow_backward_channel_conflict(
    direction: Direction,
    is_backward: bool,
    target_has_backward_conflict: bool,
    overflow_face: Option<Face>,
    proposed_face: Face,
) -> Face {
    if overflow_face.is_none() {
        return proposed_face;
    }

    let channel_face = canonical_backward_channel_face(direction);
    if is_backward {
        return channel_face;
    }

    let collides_with_channel = target_has_backward_conflict && proposed_face == channel_face;
    if collides_with_channel {
        fan_in_non_canonical_overflow_face(direction)
    } else {
        proposed_face
    }
}
/// Decides whether a backward edge should use the side channel.
///
/// Forward edges never do. A backward edge does when either
///
/// * `rank_span` is known and is at least
///   [`BACKWARD_SIDE_CHANNEL_LONG_RANK_SPAN`], or
/// * the edge has no layout waypoints (`has_layout_waypoints` is `false`).
pub fn prefer_backward_side_channel(
    is_backward: bool,
    has_layout_waypoints: bool,
    rank_span: Option<usize>,
) -> bool {
    let spans_many_ranks = matches!(
        rank_span,
        Some(span) if span >= BACKWARD_SIDE_CHANNEL_LONG_RANK_SPAN
    );
    is_backward && (spans_many_ranks || !has_layout_waypoints)
}
/// Reports whether the top-down/bottom-top backward hint parity rule may be
/// applied to an edge.
///
/// Returns `false` when any of the following holds:
///
/// * `direction` is not top-down or bottom-top;
/// * `has_subgraph_endpoint` is set;
/// * [`prefer_backward_side_channel`] selects the side channel for the edge.
///   It is called with `has_layout_waypoints = true`, so only its long rank
///   span rule can trigger here;
/// * the width or height of either rect is below
///   `TD_BT_PARITY_MIN_RECT_SPAN`.
///
/// Otherwise the result is whether `source_center_x` lies at or to the left
/// of the target's right edge (`target_rect.x + target_rect.width`).
///
/// `is_backward` is only forwarded to `prefer_backward_side_channel`; this
/// function does not itself reject forward edges.
pub fn can_apply_td_bt_backward_hint_parity(
    direction: Direction,
    is_backward: bool,
    has_subgraph_endpoint: bool,
    rank_span: usize,
    source_rect: FRect,
    target_rect: FRect,
    source_center_x: f64,
) -> bool {
    let vertical_flow = matches!(direction, Direction::TopDown | Direction::BottomTop);
    if !vertical_flow || has_subgraph_endpoint {
        return false;
    }

    if prefer_backward_side_channel(is_backward, true, Some(rank_span)) {
        return false;
    }

    // Kept as a `<` test so a NaN span is not treated as too small.
    let spans = [
        source_rect.width,
        source_rect.height,
        target_rect.width,
        target_rect.height,
    ];
    if spans.iter().any(|&span| span < TD_BT_PARITY_MIN_RECT_SPAN) {
        return false;
    }

    source_center_x <= target_rect.x + target_rect.width
}
/// Classifies which face of `rect` the point `approach` lies towards, as
/// seen from `center`.
///
/// * When `approach` is within 0.5 of `center` on both axes the result is
///   [`Face::Bottom`].
/// * Otherwise the offset is compared against the rect's aspect ratio: if
///   `|dy| * (width / 2)` is strictly greater than `|dx| * (height / 2)` the
///   result is `Top` (negative `dy`) or `Bottom`; if not, it is `Left`
///   (negative `dx`) or `Right`.
///
/// Only `rect.width` and `rect.height` are read from `rect`; the reference
/// point is always `center`.
pub fn classify_face_float(center: FPoint, rect: FRect, approach: FPoint) -> Face {
    let (dx, dy) = (approach.x - center.x, approach.y - center.y);
    let (reach_x, reach_y) = (dx.abs(), dy.abs());

    if reach_x < 0.5 && reach_y < 0.5 {
        return Face::Bottom;
    }

    // Cross-multiplied comparison of the offset slope with the rect diagonal,
    // which avoids dividing by a zero-sized dimension.
    let vertical_dominates = reach_y * (rect.width / 2.0) > reach_x * (rect.height / 2.0);
    match (vertical_dominates, dy < 0.0, dx < 0.0) {
        (true, true, _) => Face::Top,
        (true, false, _) => Face::Bottom,
        (false, _, true) => Face::Left,
        (false, _, false) => Face::Right,
    }
}
/// Returns the point at `fraction` along the given face of `rect`.
///
/// `fraction` is clamped to `0.0..=1.0`. For `Top` and `Bottom` it scales
/// `rect.width` starting from `rect.x`; for `Left` and `Right` it scales
/// `rect.height` starting from `rect.y`. `Top` lies at `rect.y`, `Bottom` at
/// `rect.y + rect.height`, `Left` at `rect.x` and `Right` at
/// `rect.x + rect.width`.
pub fn point_on_face_float(rect: FRect, face: Face, fraction: f64) -> FPoint {
    let t = fraction.clamp(0.0, 1.0);
    let (x, y) = match face {
        Face::Top => (rect.x + rect.width * t, rect.y),
        Face::Bottom => (rect.x + rect.width * t, rect.y + rect.height),
        Face::Left => (rect.x, rect.y + rect.height * t),
        Face::Right => (rect.x + rect.width, rect.y + rect.height * t),
    };
    FPoint::new(x, y)
}
/// A face together with a relative position along it, as assigned to one
/// end of an edge by `plan_attachment_candidates`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EdgeAttachment {
/// Face the edge end attaches to.
pub face: Face,
/// Relative position along `face`. `plan_attachment_candidates` assigns
/// 0.5 to a lone attachment and evenly spaced values from 0.0 to 1.0 to a
/// group of several.
pub fraction: f64,
}
/// Planned attachments for the two ends of a single edge.
#[derive(Debug, Clone, PartialEq)]
pub struct PlannedEdgeAttachments {
/// Attachment for the source end; `None` when no source-side candidate was
/// planned for the edge.
pub source: Option<EdgeAttachment>,
/// Attachment for the target end; `None` when no target-side candidate was
/// planned for the edge.
pub target: Option<EdgeAttachment>,
}
/// Result of `plan_attachment_candidates`: per-edge attachments plus
/// per-`(node id, face)` bookkeeping.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AttachmentPlan {
    /// Planned source/target attachments keyed by edge index.
    edge_attachments: HashMap<usize, PlannedEdgeAttachments>,
    /// Number of candidates sharing each `(node id, face)` pair.
    group_sizes: HashMap<(String, Face), usize>,
    /// Fractions assigned to source-side candidates of each `(node id, face)`
    /// pair, in planned order.
    source_fractions: HashMap<(String, Face), Vec<f64>>,
    /// Fractions assigned to target-side candidates of each `(node id, face)`
    /// pair, in planned order.
    target_fractions: HashMap<(String, Face), Vec<f64>>,
}

impl AttachmentPlan {
    /// Test-only accessor: the fractions assigned to source-side candidates
    /// on `face` of `node_id`, or an empty vector when there are none.
    #[cfg(test)]
    #[allow(dead_code)]
    pub fn source_fractions_for(&self, node_id: &str, face: Face) -> Vec<f64> {
        let key = (node_id.to_string(), face);
        match self.source_fractions.get(&key) {
            Some(fractions) => fractions.clone(),
            None => Vec::new(),
        }
    }

    /// Returns the planned attachments for `edge_index`, if the edge had any
    /// candidate.
    pub fn edge(&self, edge_index: usize) -> Option<&PlannedEdgeAttachments> {
        self.edge_attachments.get(&edge_index)
    }

    /// Returns how many candidates share `face` of `node_id`; 0 when the
    /// pair is unknown to the plan.
    pub fn group_size(&self, node_id: &str, face: Face) -> usize {
        let key = (node_id.to_string(), face);
        self.group_sizes.get(&key).map_or(0, |size| *size)
    }

    /// Iterates over `(edge index, planned attachments)` pairs in the
    /// unspecified order of the underlying hash map.
    pub fn attachments(&self) -> impl Iterator<Item = (&usize, &PlannedEdgeAttachments)> + '_ {
        self.edge_attachments.iter()
    }
}
/// Which end of an edge an attachment candidate belongs to.
///
/// The derived `Ord` follows declaration order (`Source` sorts before
/// `Target`); `compare_attachment_candidates` uses it as its final
/// tie-breaker, so the variant order must not be changed casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum AttachmentSide {
/// The candidate is the edge's source end.
Source,
/// The candidate is the edge's target end.
Target,
}
/// One edge end that wants to attach to a face of a node; the input unit of
/// `plan_attachment_candidates`.
#[derive(Debug, Clone, PartialEq)]
pub struct AttachmentCandidate {
/// Index of the edge; the key under which the planned attachment is stored.
pub edge_index: usize,
/// Id of the node being attached to; candidates are grouped by
/// `(node_id, face)`.
pub node_id: String,
/// Whether this is the edge's source or target end.
pub side: AttachmentSide,
/// Face of the node the edge end attaches to.
pub face: Face,
/// Primary sort key within a `(node_id, face)` group — presumably the
/// coordinate of the edge's other end along the face; TODO confirm with the
/// code that builds candidates.
pub cross_axis: f64,
}
/// Builds an [`AttachmentPlan`] from a list of attachment candidates.
///
/// Candidates are grouped by `(node_id, face)`. Each group is sorted with
/// `compare_attachment_candidates` and its members are spread evenly along
/// the face: a single member gets fraction 0.5, otherwise member `i` of `n`
/// gets `i / (n - 1)`, so the first and last sit at 0.0 and 1.0.
///
/// For every candidate the plan records
///
/// * the group's size under its `(node_id, face)` key,
/// * the attachment on the source or target end of the candidate's edge
///   (a later candidate for the same edge end overwrites an earlier one),
/// * the fraction in the per-group source or target fraction list.
pub fn plan_attachment_candidates(candidates: Vec<AttachmentCandidate>) -> AttachmentPlan {
    // Bucket the candidates by the node face they compete for.
    let mut buckets: HashMap<(String, Face), Vec<AttachmentCandidate>> = HashMap::new();
    for candidate in candidates {
        let key = (candidate.node_id.clone(), candidate.face);
        buckets.entry(key).or_default().push(candidate);
    }

    let mut plan = AttachmentPlan::default();
    for ((node_id, face), mut members) in buckets {
        members.sort_by(compare_attachment_candidates);
        let count = members.len();
        plan.group_sizes.insert((node_id, face), count);

        for (position, member) in members.iter().enumerate() {
            let fraction = match count {
                0 | 1 => 0.5,
                _ => position as f64 / (count - 1) as f64,
            };
            let attachment = EdgeAttachment { face, fraction };

            let planned = plan
                .edge_attachments
                .entry(member.edge_index)
                .or_insert(PlannedEdgeAttachments {
                    source: None,
                    target: None,
                });

            // Record the attachment on the matching end and pick the
            // fraction list that belongs to that end.
            let fractions_by_group = match member.side {
                AttachmentSide::Source => {
                    planned.source = Some(attachment);
                    &mut plan.source_fractions
                }
                AttachmentSide::Target => {
                    planned.target = Some(attachment);
                    &mut plan.target_fractions
                }
            };
            fractions_by_group
                .entry((member.node_id.clone(), member.face))
                .or_default()
                .push(fraction);
        }
    }
    plan
}
/// Total order used to place candidates along a face.
///
/// Candidates are ordered by `cross_axis` using `f64::total_cmp` (so NaN and
/// signed zeros have a defined position), then by `edge_index`, then by
/// `side` (`Source` before `Target`).
fn compare_attachment_candidates(
    a: &AttachmentCandidate,
    b: &AttachmentCandidate,
) -> std::cmp::Ordering {
    let by_cross_axis = a.cross_axis.total_cmp(&b.cross_axis);
    if by_cross_axis != std::cmp::Ordering::Equal {
        return by_cross_axis;
    }
    // Lexicographic tuple comparison: edge index first, then side.
    (a.edge_index, a.side).cmp(&(b.edge_index, b.side))
}
/// Returns the `(source face, target face)` pair for an edge in the given
/// layout direction.
///
/// A forward edge leaves the source on the face pointing along the flow and
/// enters the target on the opposite face: `(Bottom, Top)` for top-down,
/// `(Top, Bottom)` for bottom-top, `(Right, Left)` for left-right and
/// `(Left, Right)` for right-left. A backward edge uses the same pair
/// swapped.
pub fn edge_faces(direction: Direction, is_backward: bool) -> (Face, Face) {
    let forward = match direction {
        Direction::TopDown => (Face::Bottom, Face::Top),
        Direction::BottomTop => (Face::Top, Face::Bottom),
        Direction::LeftRight => (Face::Right, Face::Left),
        Direction::RightLeft => (Face::Left, Face::Right),
    };
    match is_backward {
        false => forward,
        true => (forward.1, forward.0),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for face naming, side-channel preference, hint parity,
    //! face geometry and attachment planning.
    //!
    //! Rects are built with all-equal arguments or adjusted through their
    //! fields wherever the assertion would otherwise depend on the argument
    //! order of `FRect::new`.

    use super::*;

    #[test]
    fn port_face_round_trips() {
        assert_eq!(PortFace::Top.as_str(), "top");
        assert_eq!(PortFace::Bottom.as_str(), "bottom");
        assert_eq!(PortFace::Left.as_str(), "left");
        assert_eq!(PortFace::Right.as_str(), "right");
        assert_eq!("top".parse::<PortFace>(), Ok(PortFace::Top));
        assert_eq!("bottom".parse::<PortFace>(), Ok(PortFace::Bottom));
        assert_eq!("left".parse::<PortFace>(), Ok(PortFace::Left));
        assert_eq!("right".parse::<PortFace>(), Ok(PortFace::Right));
        assert_eq!("invalid".parse::<PortFace>(), Err(()));
    }

    #[test]
    fn edge_port_construction() {
        let port = EdgePort {
            face: PortFace::Top,
            fraction: 0.5,
            position: FPoint::new(50.0, 10.0),
            group_size: 1,
        };
        assert_eq!(port.face, PortFace::Top);
        assert!((port.fraction - 0.5).abs() < f64::EPSILON);
        assert_eq!(port.position, FPoint::new(50.0, 10.0));
        assert_eq!(port.group_size, 1);
    }

    #[test]
    fn prefer_backward_side_channel_uses_no_waypoint_fallback() {
        assert!(prefer_backward_side_channel(true, false, None));
        assert!(!prefer_backward_side_channel(true, true, None));
    }

    #[test]
    fn prefer_backward_side_channel_uses_long_span_override() {
        assert!(prefer_backward_side_channel(
            true,
            true,
            Some(BACKWARD_SIDE_CHANNEL_LONG_RANK_SPAN)
        ));
        assert!(!prefer_backward_side_channel(
            true,
            true,
            Some(BACKWARD_SIDE_CHANNEL_LONG_RANK_SPAN - 1)
        ));
    }

    #[test]
    fn prefer_backward_side_channel_ignores_forward_edges() {
        assert!(!prefer_backward_side_channel(false, false, Some(10)));
    }

    #[test]
    fn td_bt_backward_hint_parity_requires_safe_geometry() {
        let source_rect = FRect::new(10.0, 10.0, 40.0, 40.0);
        let target_rect = FRect::new(20.0, 0.0, 40.0, 40.0);
        let source_center_x = source_rect.x + source_rect.width / 2.0;
        assert!(can_apply_td_bt_backward_hint_parity(
            Direction::TopDown,
            true,
            false,
            2,
            source_rect,
            target_rect,
            source_center_x
        ));

        // Starting from the accepted configuration above, each guard must
        // reject on its own.
        assert!(!can_apply_td_bt_backward_hint_parity(
            Direction::LeftRight,
            true,
            false,
            2,
            source_rect,
            target_rect,
            source_center_x
        ));
        assert!(!can_apply_td_bt_backward_hint_parity(
            Direction::TopDown,
            true,
            true,
            2,
            source_rect,
            target_rect,
            source_center_x
        ));

        let mut narrow_source = source_rect;
        narrow_source.width = TD_BT_PARITY_MIN_RECT_SPAN - 1.0;
        assert!(!can_apply_td_bt_backward_hint_parity(
            Direction::TopDown,
            true,
            false,
            2,
            narrow_source,
            target_rect,
            source_center_x
        ));

        let mut short_target = target_rect;
        short_target.height = TD_BT_PARITY_MIN_RECT_SPAN - 1.0;
        assert!(!can_apply_td_bt_backward_hint_parity(
            Direction::TopDown,
            true,
            false,
            2,
            source_rect,
            short_target,
            source_center_x
        ));
    }

    #[test]
    fn td_bt_backward_hint_parity_rejects_long_span_and_crossing_topology() {
        let source_rect = FRect::new(80.0, 10.0, 40.0, 40.0);
        let target_rect = FRect::new(10.0, 0.0, 40.0, 40.0);
        let source_center_x = source_rect.x + source_rect.width / 2.0;
        assert!(!can_apply_td_bt_backward_hint_parity(
            Direction::TopDown,
            true,
            false,
            BACKWARD_SIDE_CHANNEL_LONG_RANK_SPAN,
            source_rect,
            target_rect,
            source_center_x
        ));
        assert!(!can_apply_td_bt_backward_hint_parity(
            Direction::TopDown,
            true,
            false,
            2,
            source_rect,
            target_rect,
            source_center_x
        ));
    }

    #[test]
    fn edge_faces_swap_for_backward_edges() {
        assert_eq!(
            edge_faces(Direction::TopDown, false),
            (Face::Bottom, Face::Top)
        );
        assert_eq!(
            edge_faces(Direction::TopDown, true),
            (Face::Top, Face::Bottom)
        );
        assert_eq!(
            edge_faces(Direction::BottomTop, false),
            (Face::Top, Face::Bottom)
        );
        assert_eq!(
            edge_faces(Direction::LeftRight, false),
            (Face::Right, Face::Left)
        );
        assert_eq!(
            edge_faces(Direction::RightLeft, false),
            (Face::Left, Face::Right)
        );
        assert_eq!(
            edge_faces(Direction::RightLeft, true),
            (Face::Right, Face::Left)
        );
    }

    #[test]
    fn point_on_face_float_clamps_fraction() {
        let rect = FRect::new(10.0, 20.0, 100.0, 50.0);
        let right = rect.x + rect.width;
        let bottom = rect.y + rect.height;
        assert_eq!(
            point_on_face_float(rect, Face::Top, 0.5),
            FPoint::new(rect.x + rect.width * 0.5, rect.y)
        );
        assert_eq!(
            point_on_face_float(rect, Face::Bottom, 0.0),
            FPoint::new(rect.x, bottom)
        );
        // Out-of-range fractions are clamped to the ends of the face.
        assert_eq!(
            point_on_face_float(rect, Face::Left, -1.0),
            FPoint::new(rect.x, rect.y)
        );
        assert_eq!(
            point_on_face_float(rect, Face::Right, 2.0),
            FPoint::new(right, bottom)
        );
    }

    #[test]
    fn classify_face_float_picks_dominant_axis() {
        let rect = FRect::new(40.0, 40.0, 40.0, 40.0);
        let center = || FPoint::new(60.0, 60.0);
        assert_eq!(
            classify_face_float(center(), rect, FPoint::new(60.0, 10.0)),
            Face::Top
        );
        assert_eq!(
            classify_face_float(center(), rect, FPoint::new(60.0, 110.0)),
            Face::Bottom
        );
        assert_eq!(
            classify_face_float(center(), rect, FPoint::new(10.0, 60.0)),
            Face::Left
        );
        assert_eq!(
            classify_face_float(center(), rect, FPoint::new(110.0, 60.0)),
            Face::Right
        );
        // An approach point on top of the center falls back to the bottom face.
        assert_eq!(
            classify_face_float(center(), rect, FPoint::new(60.2, 59.8)),
            Face::Bottom
        );
    }

    #[test]
    fn plan_attachment_candidates_spreads_groups_along_face() {
        let plan = plan_attachment_candidates(vec![
            AttachmentCandidate {
                edge_index: 1,
                node_id: "n".to_string(),
                side: AttachmentSide::Target,
                face: Face::Top,
                cross_axis: 30.0,
            },
            AttachmentCandidate {
                edge_index: 0,
                node_id: "n".to_string(),
                side: AttachmentSide::Target,
                face: Face::Top,
                cross_axis: 10.0,
            },
            AttachmentCandidate {
                edge_index: 0,
                node_id: "m".to_string(),
                side: AttachmentSide::Source,
                face: Face::Bottom,
                cross_axis: 5.0,
            },
        ]);

        assert_eq!(plan.group_size("n", Face::Top), 2);
        assert_eq!(plan.group_size("m", Face::Bottom), 1);
        assert_eq!(plan.group_size("n", Face::Left), 0);

        // A lone candidate sits in the middle of its face; a pair is ordered
        // by cross_axis and spread to the two ends.
        let edge0 = plan.edge(0).expect("edge 0 is planned");
        assert_eq!(
            edge0.source,
            Some(EdgeAttachment {
                face: Face::Bottom,
                fraction: 0.5
            })
        );
        assert_eq!(
            edge0.target,
            Some(EdgeAttachment {
                face: Face::Top,
                fraction: 0.0
            })
        );
        let edge1 = plan.edge(1).expect("edge 1 is planned");
        assert_eq!(edge1.source, None);
        assert_eq!(
            edge1.target,
            Some(EdgeAttachment {
                face: Face::Top,
                fraction: 1.0
            })
        );

        assert_eq!(plan.source_fractions_for("m", Face::Bottom), vec![0.5]);
        assert!(plan.source_fractions_for("n", Face::Top).is_empty());
        assert!(plan.edge(2).is_none());
        assert_eq!(plan.attachments().count(), 2);
    }
}