use std::{
collections::{HashMap, HashSet},
hash::Hash,
};
/// Order in which [`sort_nodes_by_parent_depth`] arranges nodes by depth.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SortDirection {
    /// Shallowest first: roots (depth 0) come first.
    Asc,
    /// Deepest first: roots come last.
    Desc,
}
/// Error returned by [`sort_nodes_by_parent_depth`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeSortingError<ID> {
    /// Following parent links revisited a node, so depths are undefined.
    ///
    /// `cycle_node_ids` lists IDs visited while following parent links, ending
    /// with the ID whose revisit closed the loop.
    CyclicParentRelation { cycle_node_ids: Vec<ID> },
}

impl<ID: std::fmt::Debug> std::fmt::Display for NodeSortingError<ID> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            NodeSortingError::CyclicParentRelation { cycle_node_ids } => {
                write!(f, "cyclic parent relation between nodes {:?}", cycle_node_ids)
            }
        }
    }
}

impl<ID: std::fmt::Debug> std::error::Error for NodeSortingError<ID> {}
/// Sorts `nodes` in place by how deep each node sits in its parent hierarchy.
///
/// A node's depth is the number of parent links between it and its root. A node
/// is a root (depth 0) when `extract_parent` returns `None`, or returns an ID that
/// no node in `nodes` carries. The sort is stable, so nodes of equal depth keep
/// their original relative order. `SortDirection::Asc` puts roots first;
/// `SortDirection::Desc` puts the deepest nodes first.
///
/// If several nodes share an ID, parent lookups resolve to the last of them.
///
/// Each node's depth is computed once and reused by its descendants, so deep
/// chains do not cause quadratic work.
///
/// # Errors
///
/// Returns [`NodeSortingError::CyclicParentRelation`] if following parent links
/// from some node revisits a node. `cycle_node_ids` then lists the IDs on the
/// cycle in parent order, with the first ID repeated at the end (e.g.
/// `[1, 2, 3, 1]`); nodes that merely lead into the cycle are omitted. On error
/// `nodes` is left unmodified.
pub fn sort_nodes_by_parent_depth<T, ID, IDExtractor, ParentExtractor>(
    nodes: &mut Vec<T>,
    direction: SortDirection,
    extract_id: IDExtractor,
    extract_parent: ParentExtractor,
) -> Result<(), NodeSortingError<ID>>
where
    IDExtractor: Fn(&T) -> ID,
    ParentExtractor: Fn(&T) -> Option<ID>,
    ID: Eq + Hash + Clone,
{
    let id_to_index: HashMap<ID, usize> = nodes
        .iter()
        .enumerate()
        .map(|(idx, node)| (extract_id(node), idx))
        .collect();

    // Resolve every parent link to an index up front; a parent ID that is not
    // present in `nodes` is treated the same as having no parent.
    let parent_of: Vec<Option<usize>> = nodes
        .iter()
        .map(|node| {
            extract_parent(node).and_then(|parent_id| id_to_index.get(&parent_id).copied())
        })
        .collect();

    // Depths resolved by earlier walks, shared so every node is walked only once.
    let mut known_depths: Vec<Option<usize>> = vec![None; nodes.len()];
    let mut depths: Vec<usize> = Vec::with_capacity(nodes.len());
    for idx in 0..nodes.len() {
        depths.push(calculate_depth::<T, ID, IDExtractor>(
            idx,
            nodes,
            &parent_of,
            &mut known_depths,
            &extract_id,
        )?);
    }

    // All fallible work is done; only now is it safe to move the nodes out.
    let mut keyed: Vec<(usize, T)> = depths.into_iter().zip(nodes.drain(..)).collect();
    // `sort_by_key` is stable; `Reverse` keeps that stability for descending order.
    match direction {
        SortDirection::Asc => keyed.sort_by_key(|&(depth, _)| depth),
        SortDirection::Desc => keyed.sort_by_key(|&(depth, _)| std::cmp::Reverse(depth)),
    }
    nodes.extend(keyed.into_iter().map(|(_, node)| node));
    Ok(())
}

/// Returns the depth of `nodes[start_idx]`, recording it in `known_depths` along
/// with the depth of every ancestor resolved on the way.
///
/// `parent_of[i]` is the index of node `i`'s parent, or `None` for a root. The walk
/// stops early at any ancestor whose depth is already in `known_depths`.
///
/// # Errors
///
/// Returns [`NodeSortingError::CyclicParentRelation`] if the walk revisits a node.
/// `cycle_node_ids` holds only the IDs on the cycle, in parent order, with the
/// first one repeated at the end.
fn calculate_depth<T, ID, IDExtractor>(
    start_idx: usize,
    nodes: &[T],
    parent_of: &[Option<usize>],
    known_depths: &mut [Option<usize>],
    extract_id: &IDExtractor,
) -> Result<usize, NodeSortingError<ID>>
where
    IDExtractor: Fn(&T) -> ID,
{
    if let Some(depth) = known_depths[start_idx] {
        return Ok(depth);
    }

    // Unresolved nodes from `start_idx` upwards, in visit order.
    let mut path: Vec<usize> = Vec::new();
    let mut on_path: HashSet<usize> = HashSet::new();
    // Depth of the first already-resolved ancestor above `path`, if the walk hit one.
    let mut resolved_ancestor_depth: Option<usize> = None;
    let mut next = Some(start_idx);
    while let Some(idx) = next {
        if let Some(depth) = known_depths[idx] {
            resolved_ancestor_depth = Some(depth);
            break;
        }
        if !on_path.insert(idx) {
            // `idx` is already on the path, so everything from its first
            // occurrence onwards forms the cycle.
            let cycle_start = path
                .iter()
                .position(|&visited| visited == idx)
                .expect("every index in `on_path` is also in `path`");
            let cycle_node_ids = path[cycle_start..]
                .iter()
                .chain(std::iter::once(&idx))
                .map(|&cycle_idx| extract_id(&nodes[cycle_idx]))
                .collect();
            return Err(NodeSortingError::CyclicParentRelation { cycle_node_ids });
        }
        path.push(idx);
        next = parent_of[idx];
    }

    // The topmost node on `path` is either a root (depth 0) or the child of a
    // resolved ancestor; each node below it is one level deeper. `path` is
    // non-empty because `start_idx` was unresolved and got pushed first.
    let top_depth = resolved_ancestor_depth.map_or(0, |depth| depth + 1);
    let last = path.len() - 1;
    for (pos, &idx) in path.iter().enumerate() {
        known_depths[idx] = Some(top_depth + (last - pos));
    }
    Ok(top_depth + last)
}
#[cfg(test)]
mod tests {
    use super::*;

    struct DemoNode {
        id: usize,
        parent_node: Option<usize>,
    }

    impl DemoNode {
        fn build_parent_relation(id: usize, parent: Option<usize>) -> Self {
            Self { id, parent_node: parent }
        }
    }

    /// Node IDs in their current order, for compact ordering assertions.
    fn ids(nodes: &[DemoNode]) -> Vec<usize> {
        nodes.iter().map(|node| node.id).collect()
    }

    /// Runs the sort with the demo ID/parent extractors.
    fn sort_demo(
        nodes: &mut Vec<DemoNode>,
        direction: SortDirection,
    ) -> Result<(), NodeSortingError<usize>> {
        sort_nodes_by_parent_depth(nodes, direction, |n| n.id, |n| n.parent_node)
    }

    #[test]
    fn test_sort_no_parents() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(3, None),
            DemoNode::build_parent_relation(1, None),
            DemoNode::build_parent_relation(2, None),
        ];
        sort_demo(&mut nodes, SortDirection::Asc).unwrap();
        // All depths are equal, so the stable sort must keep the input order.
        assert_eq!(ids(&nodes), vec![3, 1, 2]);
    }

    #[test]
    fn test_sort_simple_hierarchy() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(3, Some(2)),
            DemoNode::build_parent_relation(1, None),
            DemoNode::build_parent_relation(2, Some(1)),
        ];
        sort_demo(&mut nodes, SortDirection::Asc).unwrap();
        assert_eq!(ids(&nodes), vec![1, 2, 3]);
    }

    #[test]
    fn test_sort_simple_hierarchy_desc() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(3, Some(2)),
            DemoNode::build_parent_relation(1, None),
            DemoNode::build_parent_relation(2, Some(1)),
        ];
        sort_demo(&mut nodes, SortDirection::Desc).unwrap();
        assert_eq!(ids(&nodes), vec![3, 2, 1]);
    }

    #[test]
    fn test_sort_multiple_roots() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(4, Some(3)),
            DemoNode::build_parent_relation(2, Some(1)),
            DemoNode::build_parent_relation(1, None),
            DemoNode::build_parent_relation(3, None),
        ];
        sort_demo(&mut nodes, SortDirection::Asc).unwrap();
        // Roots first (in input order), then their children (in input order).
        assert_eq!(ids(&nodes), vec![1, 3, 4, 2]);
    }

    #[test]
    fn test_sort_multiple_roots_desc() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(4, Some(3)),
            DemoNode::build_parent_relation(2, Some(1)),
            DemoNode::build_parent_relation(1, None),
            DemoNode::build_parent_relation(3, None),
        ];
        sort_demo(&mut nodes, SortDirection::Desc).unwrap();
        assert_eq!(ids(&nodes), vec![4, 2, 1, 3]);
    }

    #[test]
    fn test_cyclic_relation_self_reference() {
        let mut nodes = vec![DemoNode::build_parent_relation(1, Some(1))];
        match sort_demo(&mut nodes, SortDirection::Asc) {
            Err(NodeSortingError::CyclicParentRelation { cycle_node_ids }) => {
                assert!(cycle_node_ids.contains(&1));
            }
            _ => panic!("Expected CyclicParentRelation error"),
        }
    }

    #[test]
    fn test_cyclic_relation_two_nodes() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(1, Some(2)),
            DemoNode::build_parent_relation(2, Some(1)),
        ];
        match sort_demo(&mut nodes, SortDirection::Asc) {
            Err(NodeSortingError::CyclicParentRelation { cycle_node_ids }) => {
                assert!(cycle_node_ids.contains(&1));
                assert!(cycle_node_ids.contains(&2));
            }
            _ => panic!("Expected CyclicParentRelation error"),
        }
    }

    #[test]
    fn test_cyclic_relation_three_nodes() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(1, Some(2)),
            DemoNode::build_parent_relation(2, Some(3)),
            DemoNode::build_parent_relation(3, Some(1)),
        ];
        match sort_demo(&mut nodes, SortDirection::Asc) {
            Err(NodeSortingError::CyclicParentRelation { cycle_node_ids }) => {
                // Three cycle members plus the repeated ID that closes the loop.
                assert_eq!(cycle_node_ids.len(), 4);
                assert!(cycle_node_ids.contains(&1));
                assert!(cycle_node_ids.contains(&2));
                assert!(cycle_node_ids.contains(&3));
            }
            _ => panic!("Expected CyclicParentRelation error"),
        }
    }

    #[test]
    fn test_cycle_error_leaves_nodes_untouched() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(10, Some(1)),
            DemoNode::build_parent_relation(1, Some(2)),
            DemoNode::build_parent_relation(2, Some(1)),
        ];
        match sort_demo(&mut nodes, SortDirection::Asc) {
            Err(NodeSortingError::CyclicParentRelation { cycle_node_ids }) => {
                assert!(cycle_node_ids.contains(&1));
                assert!(cycle_node_ids.contains(&2));
                // The reported list ends with the ID whose revisit closed the loop.
                let (last, rest) = cycle_node_ids.split_last().expect("non-empty cycle");
                assert!(rest.contains(last));
            }
            _ => panic!("Expected CyclicParentRelation error"),
        }
        assert_eq!(ids(&nodes), vec![10, 1, 2]);
    }

    #[test]
    fn test_missing_parent() {
        let mut nodes = vec![DemoNode::build_parent_relation(1, Some(99))];
        // A parent ID that no node carries is treated as "no parent", not an error.
        sort_demo(&mut nodes, SortDirection::Asc).unwrap();
        assert_eq!(ids(&nodes), vec![1]);
        assert_eq!(nodes[0].parent_node, Some(99));
    }

    #[test]
    fn test_complex_hierarchy() {
        let mut nodes = vec![
            DemoNode::build_parent_relation(7, Some(5)),
            DemoNode::build_parent_relation(1, None),
            DemoNode::build_parent_relation(5, Some(3)),
            DemoNode::build_parent_relation(3, Some(1)),
            DemoNode::build_parent_relation(2, None),
            DemoNode::build_parent_relation(4, Some(2)),
            DemoNode::build_parent_relation(6, Some(4)),
        ];
        sort_demo(&mut nodes, SortDirection::Asc).unwrap();
        // Depths: 1,2 -> 0; 3,4 -> 1; 5,6 -> 2; 7 -> 3. Ties keep input order.
        assert_eq!(ids(&nodes), vec![1, 2, 3, 4, 5, 6, 7]);
    }
}