use crate::{core::entities::nodes::node_ref::AsNodeRef, db::api::view::StaticGraphViewOps};
use crate::{
core::entities::nodes::node_ref::NodeRef,
db::{
api::state::{ops::Const, GenericNodeState, Index, NodeStateOutputType, TypedNodeState},
graph::nodes::Nodes,
},
errors::GraphError,
prelude::*,
};
use indexmap::IndexSet;
use raphtory_api::core::{
entities::{
properties::prop::{PropType, PropUnwrap},
VID,
},
Direction,
};
use serde::{Deserialize, Serialize};
use std::{
cmp::Ordering,
collections::{BinaryHeap, HashMap, HashSet},
};
/// Raw per-node result of a shortest-path search: the total cost of the path (converted to
/// `f64`) and the vertex ids along it.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct DistanceState {
    /// Total path cost, as an `f64`.
    pub distance: f64,
    /// Vertex ids on the path, in the order produced by the search that built this state.
    pub path: Vec<VID>,
}
/// User-facing form of a `DistanceState`: the same distance, with the path resolved into a
/// `Nodes` view over the graph instead of raw vertex ids.
#[derive(Debug, Clone)]
pub struct TransformedDistanceState<'graph, G: GraphViewOps<'graph>> {
    /// Total path cost, as an `f64`.
    pub distance: f64,
    /// The nodes on the path, exposed as a view over the graph.
    pub path: Nodes<'graph, G, G>,
}
impl DistanceState {
pub fn node_transform<'graph, G>(
state: &GenericNodeState<'graph, G>,
value: Self,
) -> TransformedDistanceState<'graph, G>
where
G: GraphViewOps<'graph>,
{
TransformedDistanceState {
distance: value.distance,
path: Nodes::new_filtered(
state.base_graph.clone(),
state.base_graph.clone(),
Const(true),
Some(Index::from_iter(value.path)),
),
}
}
}
/// Priority-queue entry for the Dijkstra search: a node and the tentative cost at which it
/// was pushed.
///
/// The ordering is *reversed* on `cost`, so that `BinaryHeap` (a max-heap) pops the cheapest
/// entry first; `node` does not take part in the ordering. Equality is defined through the
/// same ordering (entries with equal cost compare equal), which keeps `PartialEq`, `Eq`,
/// `PartialOrd` and `Ord` mutually consistent as the standard library requires.
struct State {
    cost: Prop,
    node: VID,
}

impl PartialEq for State {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for State {}

impl Ord for State {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` first => reversed, giving min-heap behaviour. Costs that cannot be compared
        // (e.g. NaN) are treated as ties so the heap never panics.
        other
            .cost
            .partial_cmp(&self.cost)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for State {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: the total order lives in `Ord::cmp`.
        Some(self.cmp(other))
    }
}
pub fn dijkstra_single_source_shortest_paths<G: StaticGraphViewOps, T: AsNodeRef>(
g: &G,
source: T,
targets: Vec<T>,
weight: Option<&str>,
direction: Direction,
) -> Result<
TypedNodeState<'static, DistanceState, G, TransformedDistanceState<'static, G>>,
GraphError,
> {
let source_ref = source.as_node_ref();
let source_node = match g.node(source_ref) {
Some(src) => src,
None => {
let gid = match source_ref {
NodeRef::Internal(vid) => g.node_id(vid),
NodeRef::External(gid) => gid.to_owned(),
};
return Err(GraphError::NodeMissingError(gid));
}
};
let mut weight_type = PropType::U8;
if let Some(weight) = weight {
if let Some((_, dtype)) = g.edge_meta().get_prop_id_and_type(weight, false) {
weight_type = dtype;
} else {
return Err(GraphError::PropertyMissingError(weight.to_string()));
}
}
let mut target_nodes = vec![false; g.unfiltered_num_nodes()];
for target in targets {
if let Some(target_node) = g.node(target) {
target_nodes[target_node.node.index()] = true;
}
}
let cost_val = match weight_type {
PropType::F32 => Prop::F32(0f32),
PropType::F64 => Prop::F64(0f64),
PropType::U8 => Prop::U8(0u8),
PropType::U16 => Prop::U16(0u16),
PropType::U32 => Prop::U32(0u32),
PropType::U64 => Prop::U64(0u64),
PropType::I32 => Prop::I32(0i32),
PropType::I64 => Prop::I64(0i64),
p_type => {
return Err(GraphError::InvalidProperty {
reason: format!("Weight type: {:?}, not supported", p_type),
})
}
};
let max_val = match weight_type {
PropType::F32 => Prop::F32(f32::MAX),
PropType::F64 => Prop::F64(f64::MAX),
PropType::U8 => Prop::U8(u8::MAX),
PropType::U16 => Prop::U16(u16::MAX),
PropType::U32 => Prop::U32(u32::MAX),
PropType::U64 => Prop::U64(u64::MAX),
PropType::I32 => Prop::I32(i32::MAX),
PropType::I64 => Prop::I64(i64::MAX),
p_type => {
return Err(GraphError::InvalidProperty {
reason: format!("Weight type: {:?}, not supported", p_type),
})
}
};
let mut heap = BinaryHeap::new();
heap.push(State {
cost: cost_val.clone(),
node: source_node.node,
});
let mut dist: HashMap<VID, Prop> = HashMap::new();
let mut predecessor: HashMap<VID, VID> = HashMap::new();
let mut visited: HashSet<VID> = HashSet::new();
let mut paths: HashMap<VID, (f64, IndexSet<VID, ahash::RandomState>)> = HashMap::new();
dist.insert(source_node.node, cost_val.clone());
while let Some(State {
cost,
node: node_vid,
}) = heap.pop()
{
if target_nodes[node_vid.index()] && !paths.contains_key(&node_vid) {
let mut path = IndexSet::default();
path.insert(node_vid);
let mut current_node_id = node_vid;
while let Some(prev_node) = predecessor.get(¤t_node_id) {
path.insert(*prev_node);
current_node_id = *prev_node;
}
path.reverse();
paths.insert(node_vid, (cost.as_f64().unwrap(), path));
}
if !visited.insert(node_vid) {
continue;
}
let edges = match direction {
Direction::OUT => g.node(node_vid).unwrap().out_edges(),
Direction::IN => g.node(node_vid).unwrap().in_edges(),
Direction::BOTH => g.node(node_vid).unwrap().edges(),
};
for edge in edges {
let next_node_vid = edge.nbr().node;
let edge_val = match weight {
None => Prop::U8(1),
Some(weight) => match edge.properties().get(weight) {
Some(prop) => prop,
_ => continue,
},
};
let next_cost = cost.clone().add(edge_val).unwrap();
if next_cost < *dist.entry(next_node_vid).or_insert(max_val.clone()) {
heap.push(State {
cost: next_cost.clone(),
node: next_node_vid,
});
dist.insert(next_node_vid, next_cost);
predecessor.insert(next_node_vid, node_vid);
}
}
}
Ok(TypedNodeState::new_mapped(
GenericNodeState::new_from_map(
g.clone(),
paths,
|(cost, path)| DistanceState {
distance: cost,
path: path.into_iter().collect(),
},
Some(HashMap::from([(
"path".to_string(),
(NodeStateOutputType::Nodes, None),
)])),
),
DistanceState::node_transform,
))
}