use crate::VectorType;
use crate::node::NodeId;
use crate::storage::memtable::MemTable;
use std::collections::{HashMap, HashSet, VecDeque};
pub type Path = Vec<NodeId>;
pub fn shortest_path<T: VectorType>(
mt: &MemTable<T>,
src: NodeId,
dst: NodeId,
max_depth: usize,
label_filter: Option<&str>,
) -> Option<Path> {
if src == dst {
return Some(vec![src]);
}
if max_depth == 0 {
return None;
}
let mut visited: HashSet<NodeId> = HashSet::new();
visited.insert(src);
let mut queue: VecDeque<(NodeId, Vec<NodeId>)> = VecDeque::new();
queue.push_back((src, vec![src]));
while let Some((current, path)) = queue.pop_front() {
if path.len() > max_depth {
break;
}
if let Some(edges) = mt.get_edges(current) {
for edge in edges {
if let Some(lf) = label_filter
&& edge.label != lf
{
continue;
}
let next = edge.target_id;
if next == dst {
let mut result = path.clone();
result.push(dst);
return Some(result);
}
if !visited.contains(&next) && path.len() < max_depth {
visited.insert(next);
let mut new_path = path.clone();
new_path.push(next);
queue.push_back((next, new_path));
}
}
}
}
None
}
/// Enumerates simple paths (no repeated node) that start at `src` and use between
/// `min_depth` and `max_depth` edges, inclusive.
///
/// When `label_filter` is `Some`, only edges carrying that label are traversed. Each
/// result pairs the path's final node with the full node sequence, which starts at `src`.
/// Results come out in depth-first order, and traversal stops early once the result count
/// reaches `limit`.
pub fn variable_length_paths<T: VectorType>(
    mt: &MemTable<T>,
    src: NodeId,
    min_depth: usize,
    max_depth: usize,
    label_filter: Option<&str>,
    limit: usize,
) -> Vec<(NodeId, Path)> {
    // The search starts from the single-node path [src], with src already marked as on-path.
    let start: Path = vec![src];
    let mut on_path: HashSet<NodeId> = HashSet::from([src]);
    let mut found: Vec<(NodeId, Path)> = Vec::new();
    dfs_variable_length(
        mt,
        src,
        &start,
        min_depth,
        max_depth,
        label_filter,
        &mut on_path,
        &mut found,
        limit,
    );
    found
}
/// Depth-first worker for [`variable_length_paths`].
///
/// `path` is the non-empty node sequence from the search root to `current`, inclusive.
/// `visited` holds exactly the nodes on that path, so every emitted path is simple.
/// Each path whose edge count lies in `min_depth..=max_depth` is appended to `results`
/// together with its last node. Appending stops once `results` holds `limit` entries.
fn dfs_variable_length<T: VectorType>(
    mt: &MemTable<T>,
    current: NodeId,
    path: &[NodeId],
    min_depth: usize,
    max_depth: usize,
    label_filter: Option<&str>,
    visited: &mut HashSet<NodeId>,
    results: &mut Vec<(NodeId, Path)>,
    limit: usize,
) {
    // Check the budget before emitting anything, as `dfs_all_paths` does. Otherwise
    // `limit == 0` would still produce the zero-length path when `min_depth == 0`.
    if results.len() >= limit {
        return;
    }
    let depth = path.len() - 1;
    if depth >= min_depth {
        results.push((current, path.to_vec()));
        if results.len() >= limit {
            return;
        }
    }
    if depth >= max_depth {
        return;
    }
    let Some(edges) = mt.get_edges(current) else {
        return;
    };
    for edge in edges {
        if let Some(lf) = label_filter
            && edge.label != lf
        {
            continue;
        }
        let next = edge.target_id;
        // `insert` returns false when `next` is already on the current path.
        if !visited.insert(next) {
            continue;
        }
        let mut new_path = Vec::with_capacity(path.len() + 1);
        new_path.extend_from_slice(path);
        new_path.push(next);
        dfs_variable_length(
            mt,
            next,
            &new_path,
            min_depth,
            max_depth,
            label_filter,
            visited,
            results,
            limit,
        );
        // Backtrack: `next` leaves the path once its branch has been explored.
        visited.remove(&next);
        if results.len() >= limit {
            return;
        }
    }
}
/// Collects simple paths from `src` to `dst` that use at most `max_depth` edges.
///
/// When `label_filter` is `Some`, only edges with that label are traversed. At most
/// `limit` paths are returned, in depth-first discovery order. Each path lists node ids
/// from `src` to `dst` inclusive.
///
/// When `src == dst`, the results are cycles that leave and return to `src`; the
/// single-node path is not reported.
///
/// NOTE(review): paths record nodes only, so parallel edges between the same nodes can
/// yield identical-looking entries. Confirm whether callers expect that.
pub fn all_paths<T: VectorType>(
    mt: &MemTable<T>,
    src: NodeId,
    dst: NodeId,
    max_depth: usize,
    label_filter: Option<&str>,
    limit: usize,
) -> Vec<Path> {
    // Start from the single-node path [src], with src already marked as on-path.
    let root: Path = vec![src];
    let mut on_path: HashSet<NodeId> = HashSet::from([src]);
    let mut found: Vec<Path> = Vec::new();
    dfs_all_paths(
        mt,
        src,
        dst,
        &root,
        max_depth,
        label_filter,
        &mut on_path,
        &mut found,
        limit,
    );
    found
}
/// Depth-first worker for [`all_paths`].
///
/// `path` is the non-empty node sequence from the search root to `current`, inclusive.
/// `visited` holds the nodes on that path, so intermediate nodes are never repeated.
/// Every path ending at `dst` with at most `max_depth` edges is appended to `results`,
/// until `results` holds `limit` entries.
fn dfs_all_paths<T: VectorType>(
    mt: &MemTable<T>,
    current: NodeId,
    dst: NodeId,
    path: &[NodeId],
    max_depth: usize,
    label_filter: Option<&str>,
    visited: &mut HashSet<NodeId>,
    results: &mut Vec<Path>,
    limit: usize,
) {
    // A non-trivial path that already ends at `dst` is a complete result.
    if current == dst && path.len() > 1 {
        results.push(path.to_vec());
        return;
    }
    let depth = path.len() - 1;
    if depth >= max_depth || results.len() >= limit {
        return;
    }
    let Some(edges) = mt.get_edges(current) else {
        return;
    };
    for edge in edges {
        if let Some(lf) = label_filter
            && edge.label != lf
        {
            continue;
        }
        let next = edge.target_id;
        if next == dst {
            // Reaching `dst` completes a path even when `dst` is already in `visited`,
            // as happens when `dst == src`. That is how cycles through the root are reported.
            let mut result_path = Vec::with_capacity(path.len() + 1);
            result_path.extend_from_slice(path);
            result_path.push(dst);
            results.push(result_path);
            if results.len() >= limit {
                return;
            }
            continue;
        }
        // `insert` returns false when `next` is already on the current path.
        if !visited.insert(next) {
            continue;
        }
        let mut new_path = Vec::with_capacity(path.len() + 1);
        new_path.extend_from_slice(path);
        new_path.push(next);
        dfs_all_paths(
            mt,
            next,
            dst,
            &new_path,
            max_depth,
            label_filter,
            visited,
            results,
            limit,
        );
        // Backtrack: `next` leaves the path once its branch has been explored.
        visited.remove(&next);
        if results.len() >= limit {
            return;
        }
    }
}
/// Computes the hop distance from `src` to every node reachable within `k` edges.
///
/// Traversal is breadth-first. When `label_filter` is `Some`, only edges whose label
/// matches it are followed. The returned map always contains `src` at distance 0. Every
/// other entry holds the fewest edges (between 1 and `k`) needed to reach that node.
pub fn k_hop_neighbors<T: VectorType>(
    mt: &MemTable<T>,
    src: NodeId,
    k: usize,
    label_filter: Option<&str>,
) -> HashMap<NodeId, usize> {
    let mut distances: HashMap<NodeId, usize> = HashMap::from([(src, 0)]);
    // Nodes first reached at the previous hop, kept in discovery order.
    let mut frontier: Vec<NodeId> = vec![src];
    for hop in 1..=k {
        let mut discovered: Vec<NodeId> = Vec::new();
        for &node in &frontier {
            let Some(edges) = mt.get_edges(node) else {
                continue;
            };
            for edge in edges {
                if label_filter.is_some_and(|lf| edge.label != lf) {
                    continue;
                }
                let next = edge.target_id;
                // Only the first time a node is seen fixes its (minimal) distance.
                if let std::collections::hash_map::Entry::Vacant(slot) = distances.entry(next)
                {
                    slot.insert(hop);
                    discovered.push(next);
                }
            }
        }
        if discovered.is_empty() {
            break;
        }
        frontier = discovered;
    }
    distances
}