use crate::{
context::{Node, Router},
operations::util::{extract_all_params, normalize, split_path},
types::{MatchedRoute, MethodData, ParamEntry},
};
use std::collections::HashSet;
/// Reports whether the final entry of `md.params_map` is flagged as optional.
///
/// Returns `false` when `params_map` is `None` or has no entries. Both
/// `ParamEntry::Index` and `ParamEntry::Wildcard` carry the optional flag in
/// their third field.
fn is_last_param_optional_for_find_all<T>(md: &MethodData<T>) -> bool {
    let Some(final_entry) = md.params_map.as_ref().and_then(|entries| entries.last()) else {
        return false;
    };
    match final_entry {
        ParamEntry::Index(_, _, optional) | ParamEntry::Wildcard(_, _, optional) => *optional,
    }
}
/// Collects every route whose pattern matches `path` for `method`.
///
/// `path` is normalized and split into segments. The routing tree is then walked while
/// holding the guard returned by `router.root.read()`, and candidate handlers are
/// gathered in the priority order produced by `find_all_recursive_ordered`.
///
/// A candidate whose `data` compares equal to one already emitted is skipped. Each
/// distinct `data` value therefore appears at most once, at its first position.
///
/// When `capture_params` is `true`, each result's `params` comes from
/// `extract_all_params`; otherwise it is `None`.
pub fn find_all_routes<T: Clone + Eq + std::hash::Hash>(
    router: &Router<T>,
    method: &str,
    path: &str,
    capture_params: bool,
) -> Vec<MatchedRoute<T>> {
    let normalized_path = normalize(path);
    let segments: Vec<&str> = split_path(&normalized_path).collect();

    let root = router.root.read();
    let mut candidates: Vec<&MethodData<T>> = Vec::new();
    find_all_recursive_ordered(&*root, method, &segments, 0, &mut candidates);

    let mut results = Vec::with_capacity(candidates.len());
    // Dedup on borrowed `data`. The references point into the tree, which stays alive
    // while `root` is held. This avoids cloning every `T` once for the set and again
    // for the result.
    let mut seen: HashSet<&T> = HashSet::with_capacity(candidates.len());
    for md in candidates {
        // `insert` returns false when an equal `data` was already emitted.
        if !seen.insert(&md.data) {
            continue;
        }
        let params = if capture_params {
            extract_all_params(&segments, &md.params_map)
        } else {
            None
        };
        results.push(MatchedRoute {
            data: md.data.clone(),
            params,
        });
    }
    results
}
fn find_all_recursive_ordered<'a, T: Clone + Eq + std::hash::Hash>(
node: &'a Node<T>,
method: &str,
segments: &[&str],
idx: usize,
matches: &mut Vec<&'a MethodData<T>>,
) {
if let Some(wildcard_child_node) = &node.wildcard_child {
if let Some(handlers) = wildcard_child_node
.methods
.get(method)
.or_else(|| wildcard_child_node.methods.get(""))
{
matches.extend(handlers.iter());
}
}
let current_segment_val = if idx < segments.len() {
Some(segments[idx])
} else {
None
};
if let Some(param_child_node) = &node.param_child {
if current_segment_val.is_some() {
find_all_recursive_ordered(param_child_node, method, segments, idx + 1, matches);
}
if idx == segments.len() {
if let Some(handlers) = param_child_node
.methods
.get(method)
.or_else(|| param_child_node.methods.get(""))
{
if handlers.iter().any(is_last_param_optional_for_find_all) {
matches.extend(handlers.iter());
}
}
}
}
if let Some(segment_val) = current_segment_val {
if let Some(static_child_node) = node.static_children.get(segment_val) {
find_all_recursive_ordered(static_child_node, method, segments, idx + 1, matches);
}
}
if idx == segments.len() {
if let Some(handlers) = node.methods.get(method).or_else(|| node.methods.get("")) {
matches.extend(handlers.iter());
}
}
}