use super::{EdgeTraversal, SearchTreeNode};
use crate::algorithm::search::search_pruning;
use crate::model::label::LabelModel;
use crate::model::network::{EdgeId, EdgeListId, Graph, NetworkError, VertexId};
use crate::model::unit::AsF64;
use crate::{algorithm::search::Direction, model::label::Label};
use allocative::Allocative;
use ordered_float::OrderedFloat;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// A search tree keyed by [`Label`], recording for each label the node through which it
/// was reached (its parent label and incoming edge, or a root marker).
#[derive(Clone, Debug, Allocative)]
pub struct SearchTree {
    /// Every node in the tree, keyed by its label.
    nodes: HashMap<Label, SearchTreeNode>,
    /// Secondary index from vertex to the labels at that vertex for which
    /// `needs_vertex_map_storage()` is true. Plain `Label::Vertex` labels are not stored
    /// here (see `get_labels`, which always yields the plain vertex label itself).
    labels: HashMap<VertexId, HashSet<Label>>,
    /// Label of the root node, if one has been set.
    root: Option<Label>,
    /// Direction of the search that built this tree; decides whether
    /// `reconstruct_path` reverses the collected edges.
    direction: Direction,
}
impl Default for SearchTree {
fn default() -> Self {
Self::new(Direction::Forward)
}
}
impl SearchTree {
    /// Creates an empty, rootless tree for a search running in `direction`.
    pub fn new(direction: Direction) -> Self {
        Self {
            nodes: HashMap::new(),
            labels: HashMap::new(),
            root: None,
            direction,
        }
    }

    /// Creates a tree for a search in `orientation` that contains only a root node for
    /// `root_label`.
    pub fn with_root(root_label: Label, orientation: Direction) -> Self {
        let mut tree = Self::new(orientation);
        tree.set_root(root_label);
        tree
    }

    /// Inserts a fresh root node for `root_label` and records it as the tree's root.
    ///
    /// Other nodes are left untouched. If a root was already set, its node stays in the
    /// node map and only the root pointer moves. If `root_label` already had a node, that
    /// node is replaced by the new root node.
    pub fn set_root(&mut self, root_label: Label) {
        let root_node = SearchTreeNode::new_root(self.direction);
        self.nodes.insert(root_label.clone(), root_node);
        self.index_label(&root_label);
        self.root = Some(root_label);
    }

    /// Adds `child_label` to the tree as a child of `parent_label`, reached via
    /// `edge_traversal`.
    ///
    /// `search_pruning::prune_tree` runs first with `label_model`. The parent check that
    /// follows therefore sees the tree as `prune_tree` left it. If the tree is empty,
    /// `parent_label` is installed as the root before the child is attached.
    ///
    /// NOTE(review): if `child_label` already has a node, that node is overwritten and its
    /// previous parent's child count is not decremented here. Confirm that `prune_tree`
    /// removes an existing identical label first.
    ///
    /// # Errors
    ///
    /// Propagates any error from `prune_tree`. Returns [`SearchTreeError::ParentNotFound`]
    /// when the tree is non-empty and has no node for `parent_label`.
    pub fn insert(
        &mut self,
        parent_label: Label,
        edge_traversal: EdgeTraversal,
        child_label: Label,
        label_model: Arc<dyn LabelModel>,
    ) -> Result<(), SearchTreeError> {
        search_pruning::prune_tree(self, &child_label, &edge_traversal, label_model)?;

        if !self.nodes.contains_key(&parent_label) {
            if !self.is_empty() {
                return Err(SearchTreeError::ParentNotFound(parent_label));
            }
            // First insertion into an empty tree: the unknown parent becomes the root.
            self.set_root(parent_label.clone());
        }
        if let Some(parent_node) = self.nodes.get_mut(&parent_label) {
            parent_node.increment_child_count();
        }

        let new_node = SearchTreeNode::new_child(edge_traversal, parent_label, self.direction);
        self.nodes.insert(child_label.clone(), new_node);
        self.index_label(&child_label);
        Ok(())
    }

    /// Removes the node for `label` and decrements its parent's child count, if the parent
    /// is still present. The label is also dropped from the per-vertex index.
    ///
    /// The removed node's own children are not touched. The root pointer is not cleared
    /// either, even if `label` is the root.
    ///
    /// # Errors
    ///
    /// Returns [`SearchTreeError::LabelNotFound`] if `label` has no node.
    pub fn remove(&mut self, label: &Label) -> Result<(), SearchTreeError> {
        let node = self
            .nodes
            .remove(label)
            .ok_or_else(|| SearchTreeError::LabelNotFound(label.clone()))?;
        if let Some(parent_label) = node.parent_label() {
            if let Some(parent_node) = self.nodes.get_mut(parent_label) {
                parent_node.decrement_child_count();
            }
        }
        self.unindex_label(label);
        Ok(())
    }

    /// Adds `label` to the per-vertex index when `needs_vertex_map_storage()` is true.
    ///
    /// Uses the same predicate as `unindex_label`, so insert and removal stay symmetric.
    fn index_label(&mut self, label: &Label) {
        if label.needs_vertex_map_storage() {
            self.labels
                .entry(*label.vertex_id())
                .or_default()
                .insert(label.clone());
        }
    }

    /// Removes `label` from the per-vertex index. The vertex entry is dropped once its set
    /// becomes empty.
    fn unindex_label(&mut self, label: &Label) {
        if !label.needs_vertex_map_storage() {
            return;
        }
        let vertex_id = label.vertex_id();
        if let Some(label_set) = self.labels.get_mut(vertex_id) {
            label_set.remove(label);
            if label_set.is_empty() {
                self.labels.remove(vertex_id);
            }
        }
    }

    /// Iterates over every `(label, node)` pair in the tree, in arbitrary order.
    pub fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (&'a Label, &'a SearchTreeNode)> + 'a> {
        Box::new(self.nodes.iter())
    }

    /// Iterates over every label in the tree, in arbitrary order.
    pub fn keys<'a>(&'a self) -> Box<dyn Iterator<Item = &'a Label> + 'a> {
        Box::new(self.nodes.keys())
    }

    /// Iterates over every node in the tree, in arbitrary order.
    pub fn values<'a>(&'a self) -> Box<dyn Iterator<Item = &'a SearchTreeNode> + 'a> {
        Box::new(self.nodes.values())
    }

    /// Returns the node stored for `label`, if any.
    pub fn get(&self, label: &Label) -> Option<&SearchTreeNode> {
        self.nodes.get(label)
    }

    /// Returns the label at `vertex` whose incoming edge has the lowest total cost.
    ///
    /// Root labels (no incoming edge) rank as `f64::MAX`, so they are chosen only when no
    /// cheaper reached label exists at `vertex`.
    pub fn get_min_cost_label(&self, vertex: VertexId) -> Option<&Label> {
        self.get_label_by(vertex, min_cost_ordering, true)
    }

    /// Yields candidate labels for `vertex`.
    ///
    /// The plain `Label::Vertex(vertex)` comes first, followed by clones of any indexed
    /// labels for that vertex. The plain label is yielded even when it has no node in the
    /// tree, so callers must check membership themselves.
    pub fn get_labels(&self, vertex: VertexId) -> Box<dyn Iterator<Item = Label> + '_> {
        let vertex_label = Label::Vertex(vertex);
        let vertex_iter = std::iter::once(vertex_label);
        match self.labels.get(&vertex) {
            Some(labels) => Box::new(vertex_iter.chain(labels.iter().cloned())),
            None => Box::new(vertex_iter),
        }
    }

    /// Returns an owned iterator over the indexed labels for `vertex`.
    ///
    /// Unlike [`get_labels`](Self::get_labels), this does not include the plain
    /// `Label::Vertex(vertex)`. The set is cloned, so the iterator does not borrow the tree.
    pub fn get_labels_iter(&self, vertex: VertexId) -> Box<dyn Iterator<Item = Label>> {
        match self.labels.get(&vertex) {
            Some(labels) => Box::new(labels.clone().into_iter()),
            None => Box::new(std::iter::empty()),
        }
    }

    /// Mutable access to the indexed label set for `vertex`.
    ///
    /// NOTE(review): edits made through this reference bypass the node map, so the caller
    /// is responsible for keeping the two consistent.
    pub fn get_labels_mut(&mut self, vertex: VertexId) -> Option<&mut HashSet<Label>> {
        self.labels.get_mut(&vertex)
    }

    /// Selects one label at `vertex` by the sort key `compare`.
    ///
    /// Candidates come from [`get_labels`](Self::get_labels) and are filtered to those that
    /// have a node in the tree. Each is paired with its node's incoming edge (`None` for a
    /// root). `min = true` picks the smallest key; otherwise the largest is picked. The
    /// returned reference is the key stored in the node map.
    pub fn get_label_by<F>(&self, vertex: VertexId, mut compare: F, min: bool) -> Option<&Label>
    where
        F: FnMut(&(&Label, Option<&EdgeTraversal>)) -> OrderedFloat<f64>,
    {
        let label_edge_iter = self.get_labels(vertex).filter_map(|label| {
            let (stored_label, node) = self.nodes.get_key_value(&label)?;
            let edge_traversal = node.incoming_edge();
            Some((stored_label, edge_traversal))
        });
        let found = if min {
            label_edge_iter.min_by_key(|item| compare(item))
        } else {
            label_edge_iter.max_by_key(|item| compare(item))
        };
        found.map(|(label, _)| label)
    }

    /// Returns mutable access to the node stored for `label`, if any.
    pub fn get_mut(&mut self, label: &Label) -> Option<&mut SearchTreeNode> {
        self.nodes.get_mut(label)
    }

    /// Returns the root label, if one has been set.
    pub fn root(&self) -> Option<&Label> {
        self.root.as_ref()
    }

    /// Returns the parent node of `label`.
    ///
    /// Returns `None` when `label` is unknown, is a root, or its parent is no longer in the
    /// tree.
    pub fn get_parent(&self, label: &Label) -> Option<&SearchTreeNode> {
        let node = self.get(label)?;
        let parent_label = node.parent_label()?;
        self.get(parent_label)
    }

    /// Returns `true` if `label` has a node in the tree.
    pub fn contains(&self, label: &Label) -> bool {
        self.nodes.contains_key(label)
    }

    /// Number of nodes in the tree, including the root.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if the tree has no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Direction of the search this tree was built for.
    pub fn direction(&self) -> Direction {
        self.direction
    }

    /// Reconstructs at most `depth` edges ending at the minimum-cost label of `leaf_vertex`.
    ///
    /// # Errors
    ///
    /// Returns [`SearchTreeError::VertexNotFound`] if `leaf_vertex` has no label in the
    /// tree. Otherwise propagates errors from [`reconstruct_path`](Self::reconstruct_path).
    pub fn backtrack_with_depth(
        &self,
        leaf_vertex: VertexId,
        depth: u64,
    ) -> Result<Vec<EdgeTraversal>, SearchTreeError> {
        let target_label = self
            .get_min_cost_label(leaf_vertex)
            .ok_or(SearchTreeError::VertexNotFound(leaf_vertex))?;
        self.reconstruct_path(target_label, Some(depth))
    }

    /// Reconstructs the full root-to-leaf path for the minimum-cost label of `leaf_vertex`.
    ///
    /// # Errors
    ///
    /// Returns [`SearchTreeError::VertexNotFound`] if `leaf_vertex` has no label in the
    /// tree. Otherwise propagates errors from [`reconstruct_path`](Self::reconstruct_path).
    pub fn backtrack(&self, leaf_vertex: VertexId) -> Result<Vec<EdgeTraversal>, SearchTreeError> {
        let target_label = self
            .get_min_cost_label(leaf_vertex)
            .ok_or(SearchTreeError::VertexNotFound(leaf_vertex))?;
        self.reconstruct_path(target_label, None)
    }

    /// Backtracks from the source vertex of the `target` edge, as resolved by
    /// `graph.src_vertex_id`.
    ///
    /// # Errors
    ///
    /// Returns [`SearchTreeError::NetworkError`] if the graph lookup fails. Otherwise
    /// returns any error from [`backtrack`](Self::backtrack).
    pub fn backtrack_edge_oriented_route(
        &self,
        target: (EdgeListId, EdgeId),
        graph: Arc<Graph>,
    ) -> Result<Vec<EdgeTraversal>, SearchTreeError> {
        let (edge_list_id, edge_id) = target;
        let src_vertex = graph.src_vertex_id(&edge_list_id, &edge_id)?;
        self.backtrack(src_vertex)
    }

    /// Walks parent pointers from `target_label` toward the root, collecting each branch's
    /// incoming edge.
    ///
    /// With `depth = Some(d)`, at most `d` edges are collected. For a
    /// [`Direction::Forward`] tree the edges are returned in root-to-target order. For a
    /// [`Direction::Reverse`] tree they are returned in the order collected, i.e.
    /// target-to-root.
    ///
    /// # Errors
    ///
    /// - [`SearchTreeError::LabelNotFound`] if `target_label` or an ancestor has no node.
    /// - [`SearchTreeError::InvalidBranchStructure`] if a label repeats (a cycle) or the
    ///   walk grows longer than the tree itself.
    pub fn reconstruct_path(
        &self,
        target_label: &Label,
        depth: Option<u64>,
    ) -> Result<Vec<EdgeTraversal>, SearchTreeError> {
        let mut path = Vec::new();
        // Borrowed labels are enough for cycle detection, which avoids cloning a label per step.
        let mut visited: HashSet<&Label> = HashSet::new();
        let mut current_label = target_label;
        let mut steps: u64 = 0;
        loop {
            if !visited.insert(current_label) {
                return Err(SearchTreeError::InvalidBranchStructure(format!(
                    "Cycle detected at label: {}",
                    current_label
                )));
            }
            // Defensive bound: a well-formed branch never has more edges than the tree has nodes.
            if steps > self.nodes.len() as u64 {
                return Err(SearchTreeError::InvalidBranchStructure(format!(
                    "Exceeded tree size {} while backtracking from {}",
                    self.nodes.len(),
                    target_label
                )));
            }
            if matches!(depth, Some(max_depth) if steps >= max_depth) {
                break;
            }
            let current_node = self
                .get(current_label)
                .ok_or_else(|| SearchTreeError::LabelNotFound(current_label.clone()))?;
            match current_node {
                SearchTreeNode::Root { .. } => break,
                SearchTreeNode::Branch {
                    incoming_edge,
                    parent,
                    ..
                } => {
                    path.push(incoming_edge.clone());
                    current_label = parent;
                }
            }
            steps += 1;
        }
        // Edges were collected from the target back toward the root.
        match self.direction {
            Direction::Forward => path.reverse(),
            Direction::Reverse => {}
        }
        Ok(path)
    }

    /// Iterates over every label in the tree, in arbitrary order.
    pub fn labels(&self) -> impl Iterator<Item = &Label> {
        self.nodes.keys()
    }

    /// Iterates over every node in the tree, in arbitrary order.
    pub fn nodes(&self) -> impl Iterator<Item = &SearchTreeNode> {
        self.nodes.values()
    }

    /// Returns the incoming edge of the minimum-cost label at `vertex`.
    ///
    /// Returns `None` if `vertex` is unknown or the selected label is a root.
    pub fn get_incoming_edge(&self, vertex: VertexId) -> Option<&EdgeTraversal> {
        let label = self.get_min_cost_label(vertex)?;
        self.get(label)?.incoming_edge()
    }
}
/// Sort key that ranks a `(label, incoming edge)` pair by the edge's total cost.
///
/// Pairs with no incoming edge (root nodes) get `f64::MAX`, so a minimum search only
/// selects them when nothing cheaper is available.
fn min_cost_ordering(pair: &(&Label, Option<&EdgeTraversal>)) -> OrderedFloat<f64> {
    let incoming = pair.1;
    let cost = match incoming {
        Some(edge) => edge.cost.total_cost.as_f64(),
        None => f64::MAX,
    };
    OrderedFloat(cost)
}
/// Errors produced while building or querying a [`SearchTree`].
#[derive(Debug, thiserror::Error)]
pub enum SearchTreeError {
    /// An insert named a parent label that has no node, and the tree was not empty.
    #[error("parent not found for label {0}")]
    ParentNotFound(Label),
    /// A label was looked up or removed but has no node in the tree.
    #[error("Label not found in tree: {0}")]
    LabelNotFound(Label),
    /// A label is referenced by the tree without a corresponding node.
    /// NOTE(review): not constructed in this file; presumably raised elsewhere.
    #[error("Label '{0}' exists in tree without matching SearchTreeNode")]
    MissingNodeForLabel(Label),
    /// A node that should reference a parent does not.
    /// NOTE(review): not constructed in this file; presumably raised elsewhere.
    #[error("Node is missing parent reference: {0}")]
    MissingParent(Label),
    /// The parent chain is malformed. `reconstruct_path` reports both repeated labels
    /// (cycles) and walks longer than the tree size with this variant.
    #[error("Invalid branch structure: {0}")]
    InvalidBranchStructure(String),
    /// No label for the requested vertex is present in the tree.
    #[error("Vertex not found in tree: {0}")]
    VertexNotFound(VertexId),
    /// A cycle was found in the tree.
    /// NOTE(review): not constructed in this file; `reconstruct_path` reports cycles as
    /// `InvalidBranchStructure` instead.
    #[error("Cycle detected: {0}")]
    CycleDetected(String),
    /// A graph lookup failed (e.g. resolving an edge's source vertex); converted via `From`.
    #[error("Search tree error while interacting with Graph: {source}")]
    NetworkError {
        #[from]
        source: NetworkError,
    },
    /// Tree pruning failed.
    /// NOTE(review): not constructed in this file; presumably raised by `search_pruning`.
    #[error("Failure while pruning tree: {0}")]
    PruningError(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{
        cost::TraversalCost,
        label::default::vertex_label_model::VertexLabelModel,
        network::{EdgeId, EdgeListId, VertexId},
        unit::Cost,
    };

    #[test]
    fn test_new_empty_tree() {
        let tree = SearchTree::new(Direction::Forward);
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert_eq!(tree.direction(), Direction::Forward);
        assert!(tree.root().is_none());
    }

    #[test]
    fn test_tree_with_root() {
        let root_label = create_test_label(0);
        let tree = SearchTree::with_root(root_label.clone(), Direction::Forward);
        assert!(!tree.is_empty());
        assert_eq!(tree.len(), 1);
        assert_eq!(tree.root(), Some(&root_label));
        assert!(tree.contains(&root_label));
        let root_node = tree.get(&root_label).unwrap();
        assert!(root_node.is_root());
    }

    #[test]
    fn test_insert_child_nodes() {
        let root_label = create_test_label(0);
        let mut tree = SearchTree::with_root(root_label.clone(), Direction::Forward);
        insert_test_edge(&mut tree, 0, 1);
        insert_test_edge(&mut tree, 0, 2);
        assert_eq!(tree.len(), 3);
        for child in [1, 2] {
            let child_node = tree.get(&create_test_label(child)).unwrap();
            assert!(!child_node.is_root());
            assert_eq!(child_node.parent_label(), Some(&root_label));
            assert_eq!(child_node.incoming_edge().unwrap().edge_id, EdgeId(child));
        }
    }

    #[test]
    fn test_insert_with_nonexistent_parent() {
        let mut tree = SearchTree::with_root(create_test_label(0), Direction::Forward);
        let result = tree.insert(
            create_test_label(99),
            create_test_edge_traversal(1, 10.0),
            create_test_label(1),
            mock_label_model(),
        );
        assert!(matches!(result, Err(SearchTreeError::ParentNotFound(_))));
    }

    #[test]
    fn test_get_parent() {
        let tree = build_chain(Direction::Forward, 1);
        let root_label = create_test_label(0);
        let child_label = create_test_label(1);
        assert!(tree.get_parent(&root_label).is_none());
        let parent = tree.get(&child_label).unwrap().parent_label().unwrap();
        assert_eq!(parent, &root_label);
        // get_parent must resolve the child's parent to the root node itself.
        let parent_node = tree
            .get_parent(&child_label)
            .expect("child node should resolve its parent");
        assert!(parent_node.is_root());
    }

    #[test]
    fn test_reconstruct_path_forward_orientation() {
        let tree = build_chain(Direction::Forward, 3);
        let path = tree.reconstruct_path(&create_test_label(3), None).unwrap();
        assert_edge_ids(&path, &[1, 2, 3]);
    }

    #[test]
    fn test_reconstruct_path_reverse_orientation() {
        let tree = build_chain(Direction::Reverse, 3);
        let path = tree.reconstruct_path(&create_test_label(3), None).unwrap();
        assert_edge_ids(&path, &[3, 2, 1]);
    }

    #[test]
    fn test_reconstruct_path_nonexistent_label() {
        let tree = SearchTree::with_root(create_test_label(0), Direction::Forward);
        let result = tree.reconstruct_path(&create_test_label(99), None);
        assert!(matches!(result, Err(SearchTreeError::LabelNotFound(_))));
    }

    #[test]
    fn test_iterators() {
        let mut tree = SearchTree::with_root(create_test_label(0), Direction::Forward);
        insert_test_edge(&mut tree, 0, 1);
        insert_test_edge(&mut tree, 0, 2);
        let labels: HashSet<_> = tree.labels().cloned().collect();
        assert_eq!(labels.len(), 3);
        for v in 0..3 {
            assert!(labels.contains(&create_test_label(v)));
        }
        assert_eq!(tree.nodes().count(), 3);
        let vertex_ids: HashSet<_> = tree.labels().map(|l| l.vertex_id()).collect();
        assert_eq!(vertex_ids.len(), 3);
        for v in 0..3 {
            assert!(vertex_ids.contains(&VertexId(v)));
        }
    }

    #[test]
    fn test_backtrack_forward_tree() {
        let tree = build_chain(Direction::Forward, 3);
        let path = tree.backtrack(VertexId(3)).unwrap();
        assert_edge_ids(&path, &[1, 2, 3]);
    }

    #[test]
    fn test_backtrack_reverse_tree() {
        let tree = build_chain(Direction::Reverse, 3);
        let path = tree.backtrack(VertexId(3)).unwrap();
        assert_edge_ids(&path, &[3, 2, 1]);
    }

    #[test]
    fn test_backtrack_nonexistent_vertex() {
        let tree = SearchTree::with_root(create_test_label(0), Direction::Forward);
        let result = tree.backtrack(VertexId(99));
        assert!(matches!(
            result,
            Err(SearchTreeError::VertexNotFound(VertexId(99)))
        ));
    }

    #[test]
    fn test_backtrack_root_vertex() {
        let tree = SearchTree::with_root(create_test_label(0), Direction::Forward);
        let path = tree.backtrack(VertexId(0)).unwrap();
        assert_eq!(path.len(), 0);
    }

    #[test]
    fn test_find_label_for_vertex() {
        let tree = build_chain(Direction::Forward, 1);
        assert_eq!(
            tree.get_min_cost_label(VertexId(1)),
            Some(&create_test_label(1))
        );
        assert_eq!(tree.get_min_cost_label(VertexId(99)), None);
    }

    #[test]
    fn test_auto_root_creation() {
        let mut tree = SearchTree::new(Direction::Forward);
        assert!(tree.is_empty());
        assert!(tree.root().is_none());
        let parent_label = create_test_label(0);
        let child_label = create_test_label(1);
        insert_test_edge(&mut tree, 0, 1);
        assert!(!tree.is_empty());
        assert_eq!(tree.len(), 2);
        assert_eq!(tree.root(), Some(&parent_label));
        let root_node = tree.get(&parent_label).unwrap();
        assert!(root_node.is_root());
        assert!(tree.nodes.contains_key(&parent_label));
        let child_node = tree.get(&child_label).unwrap();
        assert!(!child_node.is_root());
        assert_eq!(child_node.parent_label(), Some(&parent_label));
        assert_eq!(child_node.incoming_edge().unwrap().edge_id, EdgeId(1));
    }

    #[test]
    fn test_auto_root_creation_chain() {
        let mut tree = SearchTree::new(Direction::Forward);
        extend_chain(&mut tree, 3);
        assert_eq!(tree.len(), 4);
        assert_eq!(tree.root(), Some(&create_test_label(0)));
        let path = tree.backtrack(VertexId(3)).unwrap();
        assert_edge_ids(&path, &[1, 2, 3]);
    }

    #[test]
    fn test_insert_without_auto_root_when_parent_exists() {
        let mut tree = SearchTree::new(Direction::Forward);
        let root_label = create_test_label(0);
        tree.set_root(root_label.clone());
        insert_test_edge(&mut tree, 0, 1);
        assert_eq!(tree.len(), 2);
        assert_eq!(tree.root(), Some(&root_label));
        // A non-empty tree must reject an unknown parent rather than adopt it as a new root.
        let orphan_label = create_test_label(99);
        let nonexistent_parent = create_test_label(999);
        let result = tree.insert(
            nonexistent_parent.clone(),
            create_test_edge_traversal(99, 5.0),
            orphan_label.clone(),
            mock_label_model(),
        );
        assert!(matches!(
            &result,
            Err(SearchTreeError::ParentNotFound(label)) if *label == nonexistent_parent
        ));
        assert_eq!(tree.len(), 2);
        assert!(!tree.contains(&orphan_label));
        assert_eq!(tree.root(), Some(&root_label));
    }

    #[test]
    fn test_backtrack_with_depth_forward_tree_full_path() {
        let tree = build_chain(Direction::Forward, 4);
        let path = tree.backtrack_with_depth(VertexId(4), 4).unwrap();
        assert_edge_ids(&path, &[1, 2, 3, 4]);
    }

    #[test]
    fn test_backtrack_with_depth_forward_tree_limited_depth() {
        let tree = build_chain(Direction::Forward, 4);
        let path = tree.backtrack_with_depth(VertexId(4), 2).unwrap();
        assert_edge_ids(&path, &[3, 4]);
    }

    #[test]
    fn test_backtrack_with_depth_forward_tree_depth_one() {
        let tree = build_chain(Direction::Forward, 3);
        let path = tree.backtrack_with_depth(VertexId(3), 1).unwrap();
        assert_edge_ids(&path, &[3]);
    }

    #[test]
    fn test_backtrack_with_depth_reverse_tree_full_path() {
        let tree = build_chain(Direction::Reverse, 4);
        let path = tree.backtrack_with_depth(VertexId(4), 4).unwrap();
        assert_edge_ids(&path, &[4, 3, 2, 1]);
    }

    #[test]
    fn test_backtrack_with_depth_reverse_tree_limited_depth() {
        let tree = build_chain(Direction::Reverse, 4);
        let path = tree.backtrack_with_depth(VertexId(4), 2).unwrap();
        assert_edge_ids(&path, &[4, 3]);
    }

    #[test]
    fn test_backtrack_with_depth_from_root() {
        let tree = SearchTree::with_root(create_test_label(0), Direction::Forward);
        let path = tree.backtrack_with_depth(VertexId(0), 5).unwrap();
        assert_eq!(path.len(), 0);
    }

    #[test]
    fn test_backtrack_with_depth_nonexistent_vertex() {
        let tree = SearchTree::with_root(create_test_label(0), Direction::Forward);
        let result = tree.backtrack_with_depth(VertexId(99), 1);
        assert!(matches!(
            result,
            Err(SearchTreeError::VertexNotFound(VertexId(99)))
        ));
    }

    #[test]
    fn test_backtrack_with_depth_exceeds_available_path() {
        let tree = build_chain(Direction::Forward, 2);
        let path = tree.backtrack_with_depth(VertexId(2), 10).unwrap();
        assert_edge_ids(&path, &[1, 2]);
    }

    #[test]
    fn test_backtrack_with_depth_branching_tree() {
        // 0 -> 1 -> 3
        // 0 -> 2 -> 4 -> 5
        let mut tree = SearchTree::with_root(create_test_label(0), Direction::Forward);
        for (parent, child) in [(0, 1), (0, 2), (1, 3), (2, 4), (4, 5)] {
            insert_test_edge(&mut tree, parent, child);
        }
        let path = tree.backtrack_with_depth(VertexId(3), 1).unwrap();
        assert_edge_ids(&path, &[3]);
        let path = tree.backtrack_with_depth(VertexId(5), 2).unwrap();
        assert_edge_ids(&path, &[4, 5]);
        let path = tree.backtrack_with_depth(VertexId(5), 3).unwrap();
        assert_edge_ids(&path, &[2, 4, 5]);
    }

    /// Inserts the branch `parent -> child` into `tree`. The edge id is `child` and the
    /// cost is `5.0 * (child + 1)`, so edges 1, 2, 3, ... cost 10, 15, 20, ...
    fn insert_test_edge(tree: &mut SearchTree, parent: usize, child: usize) {
        tree.insert(
            create_test_label(parent),
            create_test_edge_traversal(child, 5.0 * (child as f64 + 1.0)),
            create_test_label(child),
            mock_label_model(),
        )
        .unwrap();
    }

    /// Inserts the linear chain `0 -> 1 -> ... -> len` into `tree`.
    fn extend_chain(tree: &mut SearchTree, len: usize) {
        for child in 1..=len {
            insert_test_edge(tree, child - 1, child);
        }
    }

    /// Builds a tree rooted at vertex 0 that holds the linear chain `0 -> 1 -> ... -> len`.
    fn build_chain(direction: Direction, len: usize) -> SearchTree {
        let mut tree = SearchTree::with_root(create_test_label(0), direction);
        extend_chain(&mut tree, len);
        tree
    }

    /// Asserts that `path` consists of exactly the edges with ids `expected`, in order.
    fn assert_edge_ids(path: &[EdgeTraversal], expected: &[usize]) {
        assert_eq!(path.len(), expected.len());
        for (edge, &id) in path.iter().zip(expected) {
            assert_eq!(edge.edge_id, EdgeId(id));
        }
    }

    fn create_test_edge_traversal(edge_id: usize, cost: f64) -> EdgeTraversal {
        EdgeTraversal {
            edge_id: EdgeId(edge_id),
            edge_list_id: EdgeListId(0),
            cost: TraversalCost {
                total_cost: Cost::new(cost),
                objective_cost: Cost::new(cost),
                #[cfg(feature = "detailed_costs")]
                cost_component: std::collections::HashMap::new(),
            },
            result_state: vec![],
        }
    }

    fn create_test_label(vertex_id: usize) -> Label {
        Label::Vertex(VertexId(vertex_id))
    }

    #[test]
    fn test_backtrack_mixed_labels_bug() {
        let mut tree = SearchTree::new(Direction::Forward);
        let root_label = Label::Vertex(VertexId(0));
        let child_label = Label::VertexWithIntState {
            vertex_id: VertexId(1),
            state: 1,
        };
        tree.insert(
            root_label.clone(),
            create_test_edge_traversal(1, 10.0),
            child_label.clone(),
            mock_label_model(),
        )
        .unwrap();
        let result = tree.backtrack(VertexId(0));
        assert!(
            result.is_ok(),
            "Backtracking from root Vertex label should succeed even if tree has mixed labels"
        );
    }

    #[test]
    fn test_vertex_label_model_optimization_correctness() {
        let mut tree = SearchTree::new(Direction::Forward);
        let root_id = VertexId(0);
        let child_id = VertexId(1);
        let root_label = Label::Vertex(root_id);
        let child_label = Label::Vertex(child_id);
        let et = create_test_edge_traversal(1, 10.0);
        tree.insert(
            root_label.clone(),
            et,
            child_label.clone(),
            mock_label_model(),
        )
        .unwrap();
        assert!(
            tree.labels.is_empty(),
            "Tree labels map should be empty for pure Vertex labels"
        );
        let result = tree.backtrack(child_id);
        assert!(
            result.is_ok(),
            "Backtracking failed for Vertex label: {:?}",
            result.err()
        );
        let path = result.unwrap();
        assert_eq!(path.len(), 1);
        assert_eq!(path[0].edge_id, EdgeId(1));
        let root_result = tree.backtrack(root_id);
        assert!(root_result.is_ok());
        assert_eq!(root_result.unwrap().len(), 0);
    }

    #[test]
    fn test_get_incoming_edge() {
        let tree = build_chain(Direction::Forward, 3);
        for v in 1..=3 {
            let edge = tree
                .get_incoming_edge(VertexId(v))
                .expect("non-root vertex should have an incoming edge");
            assert_eq!(edge.edge_id, EdgeId(v));
        }
        assert!(tree.get_incoming_edge(VertexId(0)).is_none());
        assert!(tree.get_incoming_edge(VertexId(99)).is_none());
    }

    fn mock_label_model() -> Arc<dyn LabelModel> {
        Arc::new(VertexLabelModel)
    }
}