use crate::{
db::{
api::{
state::{GenericNodeState, TypedNodeState},
view::StaticGraphViewOps,
},
graph::node::NodeView,
},
errors::GraphError,
prelude::*,
};
use raphtory_api::core::Direction;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// Per-node result produced by [`balance`].
#[derive(Clone, PartialEq, Serialize, Deserialize, Debug, Default)]
pub struct BalanceState {
/// Signed sum of the selected edge property over the node's edges:
/// incoming edges add to the value, outgoing edges subtract from it.
pub balance: f64,
}
/// Computes the signed total of one temporal edge property around a single node.
///
/// * `v` - the node whose edges are inspected.
/// * `prop_id` - id of the temporal edge property to accumulate.
/// * `direction` - which edges to include:
///   - `Direction::IN`: the sum over incoming edges (positive contribution),
///   - `Direction::OUT`: the negated sum over outgoing edges,
///   - `Direction::BOTH`: the `IN` result plus the `OUT` result.
///
/// Edges that do not carry the property are skipped. For edges that do, every
/// recorded value of the property is added; a value for which `as_f64()`
/// returns `None` is counted as `1.0`.
fn balance_per_node<'graph, G: GraphViewOps<'graph>>(
    v: &NodeView<'graph, G>,
    prop_id: usize,
    direction: Direction,
) -> f64 {
    match direction {
        // Net balance: inflow first, then the (already negated) outflow.
        Direction::BOTH => {
            balance_per_node(v, prop_id, Direction::IN)
                + balance_per_node(v, prop_id, Direction::OUT)
        }
        Direction::IN => {
            let inflow: f64 = v
                .in_edges()
                .properties()
                .filter_map(|edge_props| {
                    // `None` (property absent on this edge) is dropped by `filter_map`.
                    edge_props.temporal().get_by_id(prop_id).map(|history| {
                        history
                            .values()
                            .map(|entry| entry.as_f64().unwrap_or(1.0f64))
                            .sum::<f64>()
                    })
                })
                .sum();
            inflow
        }
        Direction::OUT => {
            let outflow: f64 = v
                .out_edges()
                .properties()
                .filter_map(|edge_props| {
                    // `None` (property absent on this edge) is dropped by `filter_map`.
                    edge_props.temporal().get_by_id(prop_id).map(|history| {
                        history
                            .values()
                            .map(|entry| entry.as_f64().unwrap_or(1.0f64))
                            .sum::<f64>()
                    })
                })
                .sum();
            // Outgoing weight leaves the node, so it counts negatively.
            -outflow
        }
    }
}
/// Computes the balance of every node in `graph` with respect to a numeric edge property.
///
/// For each node the values of the edge property `name` are summed over the
/// node's edges according to `direction` (see `balance_per_node`): incoming
/// edges contribute positively, outgoing edges negatively, and
/// `Direction::BOTH` yields the sum of the two. Nodes are processed in
/// parallel via `par_iter`.
///
/// # Arguments
///
/// * `graph` - the graph view to analyse; it is cloned into the returned node state.
/// * `name` - name of the edge property to accumulate.
/// * `direction` - which edges to take into account for each node.
///
/// # Errors
///
/// Returns [`GraphError::InvalidProperty`] when no edge property called `name`
/// is found, or when the property exists but its type is not numeric. The
/// error message contains the offending property name.
pub fn balance<G: StaticGraphViewOps>(
    graph: &G,
    name: String,
    direction: Direction,
) -> Result<TypedNodeState<'static, BalanceState, G>, GraphError> {
    // NOTE(review): the `false` flag presumably selects temporal (non-constant)
    // properties, matching the `temporal()` lookup in `balance_per_node` - confirm.
    let (weight_id, weight_type) = graph
        .edge_meta()
        .get_prop_id_and_type(&name, false)
        .ok_or_else(|| GraphError::InvalidProperty {
            // `format!` is required here: a plain string literal would emit the
            // text "{name}" verbatim instead of the property name.
            reason: format!("Edge property {name} does not exist"),
        })?;

    if !weight_type.is_numeric() {
        return Err(GraphError::InvalidProperty {
            reason: format!("Edge property {name} is not numeric"),
        });
    }

    let values: Vec<_> = graph
        .nodes()
        .par_iter()
        .map(|n| BalanceState {
            balance: balance_per_node(&n, weight_id, direction),
        })
        .collect();

    Ok(TypedNodeState::new(GenericNodeState::new_from_eval(
        graph.clone(),
        values,
        None,
    )))
}