use crate::{
core::entities::{nodes::node_ref::AsNodeRef, VID},
db::{
api::{
state::{node_state_ops::NodeStateOps, ops::Const, Index},
view::{DynamicGraph, IntoDynBoxed, IntoDynamic},
},
graph::{node::NodeView, nodes::Nodes},
},
errors::GraphError,
prelude::{GraphViewOps, NodeViewOps},
};
use arrow::{
array::AsArray,
compute::{cast_with_options, interleave_record_batch, CastOptions},
datatypes::UInt64Type,
row::{RowConverter, SortField},
};
use arrow_array::{Array, ArrayRef, RecordBatch, StringArray, UInt32Array};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SortOptions};
use dashmap::DashMap;
use indexmap::{IndexMap, IndexSet};
use parquet::{
arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter},
basic::Compression,
file::properties::WriterProperties,
};
use arrow_array::{builder::UInt64Builder, UInt64Array};
use arrow_select::{concat::concat, take::take};
use datafusion_expr_common::groups_accumulator::EmitTo;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use datafusion_physical_plan::aggregates::{group_values::new_group_values, order::GroupOrdering};
use raphtory_api::core::entities::properties::prop::{Prop, PropType, PropUntagged, PropUnwrap};
use rayon::{iter::Either, prelude::*};
use serde::{de::DeserializeOwned, Serialize};
use serde_arrow::{
from_record_batch,
schema::{SchemaLike, TracingOptions},
to_record_batch, Deserializer,
};
use crate::db::graph::views::filter::model::node_state_filter::NodeStateBoolColOp;
use std::{
cmp::{Ordering, PartialEq},
collections::{BinaryHeap, HashMap},
fmt::{Debug, Formatter},
fs::File,
hash::BuildHasher,
marker::PhantomData,
path::Path,
sync::Arc,
};
/// Marker trait for values that can live in a node state. Values must be serde
/// round-trippable (the state is backed by Arrow record batches), cloneable,
/// comparable and thread-safe.
pub trait NodeStateValue:
    Clone + PartialEq + Serialize + DeserializeOwned + Send + Sync + Debug
{
}
// Blanket impl: any type satisfying the bounds is automatically a NodeStateValue.
impl<T> NodeStateValue for T where
    T: Clone + PartialEq + Serialize + DeserializeOwned + Send + Sync + Debug
{
}
/// A [`NodeStateValue`] that can additionally be constructed from `V`
/// (used when ingesting raw `V` values into a state).
pub trait InputNodeStateValue<V>: NodeStateValue + From<V> {}
// Blanket impl mirroring the trait bounds.
impl<T, V> InputNodeStateValue<V> for T where T: NodeStateValue + From<V> {}
/// Which side wins when merging two node states (see `GenericNodeState::merge`).
#[derive(Clone, PartialEq)]
pub enum MergePriority {
    /// Prefer keys/values from the left-hand (self) state.
    Left,
    /// Prefer keys/values from the right-hand (other) state.
    Right,
    /// Exclude this column (or, for the index, use the union) from the result.
    Exclude,
}
/// A single resolved output value of a node-state column: a plain property,
/// a single node view, or a set of nodes.
#[derive(Clone, PartialEq, Debug)]
pub enum NodeStateOutput<'graph, G: GraphViewOps<'graph>> {
    Node(NodeView<'graph, G>),
    Nodes(Nodes<'graph, G, G>),
    Prop(Option<PropUntagged>),
}
/// Tag describing how a column's raw values should be interpreted when
/// transforming them into [`NodeStateOutput`] (see `GenericNodeState::get_nodes`).
#[derive(Clone, PartialEq, Debug)]
pub enum NodeStateOutputType {
    Node,
    Nodes,
    Prop,
}
/// Row representation of a node state: column name -> optional untagged property.
pub type PropMap = IndexMap<String, Option<PropUntagged>>;
pub fn convert_prop_map<A, B>(map: IndexMap<String, Option<A>>) -> IndexMap<String, Option<B>>
where
B: From<A>,
{
map.into_iter()
.map(|(k, v)| (k, v.map(Into::into)))
.collect()
}
/// Row representation after node-column resolution: column name -> resolved output.
pub type TransformedPropMap<'graph, G> = IndexMap<String, NodeStateOutput<'graph, G>>;
/// A typed state over raw `PropMap` rows whose converter resolves node columns.
pub type OutputTypedNodeState<'graph, G> =
    TypedNodeState<'graph, PropMap, G, TransformedPropMap<'graph, G>>;
/// Conversion hook from a stored value (`Input`) to a user-facing value
/// (`Output`), given access to the owning state (e.g. to resolve node ids).
pub trait NodeTransform {
    type Input;
    type Output;
    fn transform<'graph, G>(
        state: &GenericNodeState<'graph, G>,
        value: Self::Input,
    ) -> Self::Output
    where
        G: GraphViewOps<'graph>;
}
// Identity transform: every NodeStateValue converts to itself without
// consulting the state.
impl<T> NodeTransform for T
where
    T: NodeStateValue,
{
    type Input = Self;
    type Output = Self;
    fn transform<'graph, G>(
        _state: &GenericNodeState<'graph, G>,
        value: Self::Input,
    ) -> Self::Output
    where
        G: GraphViewOps<'graph>,
    {
        value
    }
}
/// A row's encoded sort-key bytes plus its original row index; used by `top_k`
/// to keep the k smallest rows in a max-heap.
struct HeapRow {
    row: Vec<u8>,
    index: usize,
}
// Equality and ordering compare only the encoded sort key, never `index`, so
// heap order follows the row-encoded sort order.
impl PartialEq for HeapRow {
    fn eq(&self, other: &Self) -> bool {
        self.row == other.row
    }
}
impl Eq for HeapRow {}
impl PartialOrd for HeapRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for HeapRow {
    fn cmp(&self, other: &Self) -> Ordering {
        // Arrow row encoding is designed so byte-wise comparison matches the
        // configured sort order.
        self.row.cmp(&other.row)
    }
}
/// Untyped node state: one Arrow record-batch row per node.
#[derive(Clone, Debug)]
pub struct GenericNodeState<'graph, G> {
    pub base_graph: G,
    // One row per node; columns are the state's value columns.
    values: RecordBatch,
    // When Some, maps row position -> VID; when None, row i corresponds to VID(i).
    pub(crate) keys: Option<Index<VID>>,
    // Columns whose raw values encode node ids, plus the graph they resolve
    // against (None = use base_graph).
    pub node_cols: HashMap<String, (NodeStateOutputType, Option<G>)>,
    _marker: PhantomData<&'graph ()>,
}
impl<'graph, G> GenericNodeState<'graph, G> {
    /// Borrows the backing record batch.
    #[inline]
    pub fn values_ref(&self) -> &RecordBatch {
        &self.values
    }
    /// Borrows the row -> VID index, if one is set.
    #[inline]
    pub fn keys_ref(&self) -> Option<&Index<VID>> {
        self.keys.as_ref()
    }
}
/// A typed view over a [`GenericNodeState`]: rows deserialize to `V` and are
/// optionally converted to `T` (defaults to `V`) via `converter`.
#[derive(Clone)]
pub struct TypedNodeState<'graph, V: NodeStateValue, G, T: Clone + Sync + Send = V> {
    pub state: GenericNodeState<'graph, G>,
    // Converts a deserialized row into the user-facing output type.
    pub converter: fn(&GenericNodeState<'graph, G>, V) -> T,
    _v_marker: PhantomData<V>,
    _t_marker: PhantomData<T>,
}
/// Iterator that deserializes rows of a record batch into `T`, one at a time.
pub struct RecordBatchIterator<'a, T> {
    deserializer: Deserializer<'a>,
    // Cursor of the next row to yield.
    idx: usize,
    _phantom: std::marker::PhantomData<T>,
}
impl<'a, T> RecordBatchIterator<'a, T>
where
    T: NodeStateValue,
{
    /// Builds an iterator over `record_batch`.
    /// NOTE(review): panics if the batch cannot be wrapped by serde_arrow —
    /// assumed to be impossible for batches produced by this module.
    pub fn new(record_batch: &'a RecordBatch) -> Self {
        let deserializer: Deserializer<'a> = Deserializer::from_record_batch(record_batch).unwrap();
        let idx: usize = 0;
        Self {
            deserializer,
            idx,
            _phantom: std::marker::PhantomData,
        }
    }
    /// Random access: deserializes row `idx`, or `None` if out of bounds.
    /// Panics if the row fails to deserialize into `T`.
    pub fn get(&self, idx: usize) -> Option<T> {
        Some(T::deserialize(self.deserializer.get(idx)?).unwrap())
    }
}
impl<'a, T> Iterator for RecordBatchIterator<'a, T>
where
    T: NodeStateValue,
{
    type Item = T;
    /// Deserializes the row at the cursor and advances; `None` once exhausted.
    /// Panics if a row fails to deserialize into `T` (same as `get`).
    fn next(&mut self) -> Option<Self::Item> {
        // `?` replaces the manual is_none()/unwrap() sequence.
        let row = self.deserializer.get(self.idx)?;
        let item = T::deserialize(row).unwrap();
        self.idx += 1;
        Some(item)
    }
}
impl<'graph, G: IntoDynamic> GenericNodeState<'graph, G> {
pub fn into_dyn(self) -> GenericNodeState<'graph, DynamicGraph> {
let node_cols = Some(
self.node_cols
.into_iter()
.map(|(k, (output_type, bg))| (k, (output_type, bg.map(|bg| bg.into_dynamic()))))
.collect(),
);
GenericNodeState::new(
self.base_graph.into_dynamic(),
self.values,
self.keys,
node_cols,
)
}
}
impl<'graph, G: GraphViewOps<'graph>> GenericNodeState<'graph, G> {
/// Creates a state from its parts; a missing `node_cols` map defaults to empty.
pub fn new(
    base_graph: G,
    values: RecordBatch,
    keys: Option<Index<VID>>,
    node_cols: Option<HashMap<String, (NodeStateOutputType, Option<G>)>>,
) -> Self {
    Self {
        base_graph,
        values,
        keys,
        // unwrap_or_default avoids eagerly building the empty map.
        node_cols: node_cols.unwrap_or_default(),
        _marker: PhantomData,
    }
}
/// Resolves a raw row (`PropMap`) into user-facing outputs: columns listed in
/// `state.node_cols` have their U64/array/list values turned into node views
/// or node sets; all other columns pass through as plain properties.
pub fn get_nodes(
    state: &GenericNodeState<'graph, G>,
    value_map: PropMap,
) -> TransformedPropMap<'graph, G> {
    value_map
        .into_iter()
        .map(|(key, value)| {
            let node_col_entry = state.node_cols.get(&key);
            let mut result = None;
            // Missing values and non-node columns are passed through unchanged.
            if value.is_none() || node_col_entry.is_none() {
                return (key, NodeStateOutput::Prop(value));
            } else {
                let (node_state_type, base_graph) = node_col_entry.unwrap();
                if node_state_type == &NodeStateOutputType::Node {
                    // Single node encoded as a U64 VID.
                    if let Some(PropUntagged(Prop::U64(vid))) = value {
                        result = Some((
                            key,
                            NodeStateOutput::Node(NodeView::new_internal(
                                // Column-specific graph wins over the state's base graph.
                                base_graph.as_ref().unwrap_or(&state.base_graph).clone(),
                                VID(vid as usize),
                            )),
                        ));
                    }
                } else if node_state_type == &NodeStateOutputType::Nodes {
                    // Node set encoded as an arrow array of VIDs.
                    #[cfg(feature = "arrow")]
                    if let Some(PropUntagged(Prop::Array(vid_arr))) = value {
                        return (
                            key,
                            NodeStateOutput::Nodes(Nodes::new_filtered(
                                base_graph.as_ref().unwrap_or(&state.base_graph).clone(),
                                base_graph.as_ref().unwrap_or(&state.base_graph).clone(),
                                Const(true),
                                Some(Index::from_iter(
                                    vid_arr
                                        .iter_prop()
                                        .map(|vid| VID(vid.into_u64().unwrap() as usize)),
                                )),
                            )),
                        );
                    }
                    // Node set encoded as a Prop::List of castable-to-U64 values.
                    if let Some(PropType::List(_)) = value.as_ref().map(|value| value.0.dtype())
                    {
                        if let Some(PropUntagged(Prop::List(vid_list))) = value {
                            // Only succeeds when this Arc is the sole owner of the list.
                            if let Some(vid_list) = Arc::into_inner(vid_list) {
                                result = Some((
                                    key,
                                    NodeStateOutput::Nodes(Nodes::new_filtered(
                                        base_graph
                                            .as_ref()
                                            .unwrap_or(&state.base_graph)
                                            .clone(),
                                        base_graph
                                            .as_ref()
                                            .unwrap_or(&state.base_graph)
                                            .clone(),
                                        Const(true),
                                        Some(Index::from_iter(vid_list.into_iter().map(
                                            |vid| {
                                                VID(vid
                                                    .try_cast(PropType::U64)
                                                    .unwrap()
                                                    .into_u64()
                                                    .unwrap()
                                                    as usize)
                                            },
                                        ))),
                                    )),
                                ));
                            }
                        }
                    }
                }
            }
            // NOTE(review): panics if a node column holds a value shape none of
            // the branches above accept (e.g. a Node column with a non-U64
            // value, or a shared list Arc) — TODO confirm this is an invariant.
            result.unwrap()
        })
        .collect()
}
/// Wraps this state in a typed view whose converter resolves node columns.
pub fn to_output_nodestate(self) -> OutputTypedNodeState<'graph, G> {
    TypedNodeState::new_mapped(self, Self::get_nodes)
}
/// Consumes the state, returning the backing batch and the optional key index.
pub fn into_inner(self) -> (RecordBatch, Option<Index<VID>>) {
    (self.values, self.keys)
}
/// Borrows the backing record batch.
pub fn values(&self) -> &RecordBatch {
    &self.values
}
/// Resolves a node reference to its row position in `values`: via the key
/// index when one exists, otherwise the VID itself is the row position.
fn get_index_by_node<N: AsNodeRef>(&self, node: &N) -> Option<usize> {
    let id = self.base_graph.internalise_node(node.as_node_ref())?;
    self.keys
        .as_ref()
        .map_or(Some(id.0), |index| index.index(&id))
}
/// The set of nodes this state covers (filtered by the key index when present).
fn nodes(&self) -> Nodes<'graph, G> {
    Nodes::new_filtered(
        self.base_graph.clone(),
        self.base_graph.clone(),
        Const(true),
        self.keys.clone(),
    )
}
/// Number of rows (= number of nodes with state).
fn len(&self) -> usize {
    self.values.num_rows()
}
/// Loads a node state from a parquet file. When `id_column` is given, that
/// (u64) column supplies the VID for each row and is removed from the values;
/// otherwise rows map positionally to VIDs 0..num_rows.
///
/// Errors when the file is empty/unreadable, the id column is missing, not
/// u64, contains nulls, or references a node id outside the graph, or when
/// there are more rows than nodes.
pub fn from_parquet<P: AsRef<Path>>(
    &self,
    file_path: P,
    id_column: Option<String>,
) -> Result<GenericNodeState<'graph, G>, GraphError> {
    let num_nodes = self.base_graph.unfiltered_num_nodes();
    let file = File::open(file_path)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    let schema = builder.schema().clone();
    if let Some(ref col_name) = id_column {
        if schema.column_with_name(col_name).is_none() {
            return Err(GraphError::IOErrorMsg(format!(
                "Column {} does not exist.",
                col_name
            )));
        }
    }
    let reader = builder
        .build()
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    let batches: Vec<RecordBatch> = reader
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    if batches.is_empty() {
        return Err(GraphError::IOErrorMsg("Parquet file is empty.".to_string()));
    }
    let mut batch = arrow::compute::concat_batches(&schema, &batches)
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    if batch.num_rows() > num_nodes {
        return Err(GraphError::IOErrorMsg(format!(
            "Number of rows ({}) exceeds order of graph ({}).",
            batch.num_rows(),
            num_nodes,
        )));
    }
    let mut index = self.keys.clone();
    if let Some(ref col_name) = id_column {
        if let Some(arr) = batch
            .column_by_name(col_name)
            .unwrap()
            .as_primitive_opt::<UInt64Type>()
        {
            // BUG FIX: reject null ids explicitly. The previous
            // `iter().max().unwrap_or(Some(0)).unwrap()` panicked on an
            // all-null column and silently mapped null ids to node 0.
            if arr.null_count() > 0 {
                return Err(GraphError::IOErrorMsg(format!(
                    "Column {} contains null values.",
                    col_name
                )));
            }
            let max_node_id = arr.iter().flatten().max().unwrap_or(0) as usize;
            if max_node_id >= num_nodes {
                return Err(GraphError::IOErrorMsg(format!(
                    "Max Node ID ({}) exceeds order of graph ({}).",
                    max_node_id, num_nodes,
                )));
            }
            // Nulls were rejected above, so unwrap_or(0) can never trigger.
            index = Some(Index::from_iter(
                arr.iter().map(|v| VID(v.unwrap_or(0) as usize)),
            ));
        } else {
            return Err(GraphError::IOErrorMsg(format!(
                "Column {} is not unsigned integer type.",
                col_name
            )));
        }
        // Drop the id column from the stored values.
        batch.remove_column(schema.column_with_name(col_name).unwrap().0);
    } else if batch.num_rows() < num_nodes {
        // Positional mapping over a prefix of the node ids.
        index = Some(Index::from_iter((0..batch.num_rows()).map(VID)));
    }
    Ok(GenericNodeState {
        base_graph: self.base_graph.clone(),
        values: batch,
        keys: index,
        node_cols: self.node_cols.clone(),
        _marker: PhantomData,
    })
}
/// Writes the state to a snappy-compressed parquet file. When `id_column` is
/// given, an extra Utf8 column with the string ids of the covered nodes is
/// appended under that name.
/// NOTE(review): panics on any I/O or schema error — TODO consider returning
/// `Result` like `from_parquet`.
pub fn to_parquet<P: AsRef<Path>>(&self, file_path: P, id_column: Option<String>) {
    let mut batch: Option<RecordBatch> = None;
    let mut schema = self.values.schema();
    // `if let` replaces the is_some()/unwrap() pair.
    if let Some(col_name) = id_column {
        let ids: Vec<String> = self
            .nodes()
            .id()
            .iter()
            .map(|(_, gid)| gid.to_string())
            .collect();
        let ids_array = Arc::new(StringArray::from(ids)) as ArrayRef;
        let mut builder = SchemaBuilder::new();
        // Existing fields first, then the id column.
        for field in self.values.schema().fields().iter() {
            builder.push(field.clone())
        }
        builder.push(Arc::new(Field::new(col_name, DataType::Utf8, false)));
        schema = Arc::new(Schema::new(builder.finish().fields));
        let mut columns = self.values.columns().to_vec();
        columns.push(ids_array);
        batch = Some(RecordBatch::try_new(schema.clone(), columns).unwrap());
    }
    let file = File::create(file_path).unwrap();
    let props = WriterProperties::builder()
        .set_compression(Compression::SNAPPY)
        .build();
    let mut writer = ArrowWriter::try_new(file, schema, Some(props)).unwrap();
    writer
        .write(batch.as_ref().unwrap_or(&self.values))
        .expect("Writing batch");
    writer.close().unwrap();
}
/// Builds the merged version of one column: concatenates the left and right
/// source columns, then gathers one value per target node — from the left
/// side when it covers the node, else the right side, else null.
///
/// `tgt_idx_set`/`tgt_idx_set_len` describe the target node set (None means
/// all nodes 0..len positionally); `lh`/`rh` are the key indices of the two
/// source states (None = positional).
fn merge_columns(
    col_name: &String,
    tgt_idx_set: Option<&Index<VID>>,
    tgt_idx_set_len: usize,
    lh: Option<&Index<VID>>,
    rh: Option<&Index<VID>>,
    lh_batch: &RecordBatch,
    rh_batch: &RecordBatch,
) -> (Arc<dyn Array>, Field) {
    // Iterate the target VIDs, either from the explicit index or positionally.
    let iter = if tgt_idx_set.is_some() {
        Either::Left(tgt_idx_set.as_ref().unwrap().iter())
    } else {
        Either::Right((0..tgt_idx_set_len).map(VID))
    };
    let mut idx_builder = UInt64Builder::with_capacity(tgt_idx_set_len);
    let left_col = lh_batch.column_by_name(col_name);
    // Offset of the right column's rows inside the concatenated array.
    let left_len = if left_col.is_some() {
        left_col.unwrap().len()
    } else {
        0
    };
    let right_col = rh_batch.column_by_name(col_name);
    let cat = match (left_col, right_col) {
        (Some(l), Some(r)) => &concat(&[l.as_ref(), r.as_ref()]).unwrap(),
        (Some(l), None) => l,
        (None, Some(r)) => r,
        (None, None) => unreachable!("at least one is guaranteed to be Some"),
    };
    for i in iter {
        // Left side wins when it has a row for this node.
        if left_col.is_some() && (lh.is_none() || lh.unwrap().contains(&i)) {
            if lh.is_none() {
                // Positional: VID is the row number.
                idx_builder.append_value(i.0 as u64);
            } else {
                idx_builder.append_value(lh.unwrap().index(&i).unwrap() as u64);
            }
        } else if right_col.is_some() && (rh.is_none() || rh.unwrap().contains(&i)) {
            if rh.is_none() {
                idx_builder.append_value((left_len + i.0) as u64);
            } else {
                idx_builder.append_value((left_len + rh.unwrap().index(&i).unwrap()) as u64);
            }
        } else {
            // Neither side covers this node.
            idx_builder.append_null();
        }
    }
    let take_idx: UInt64Array = idx_builder.finish();
    // Take the field definition from whichever batch has the column.
    let mut field = lh_batch.schema_ref().field_with_name(col_name);
    if field.is_err() {
        field = rh_batch.schema_ref().field_with_name(col_name);
    }
    (
        take(cat.as_ref(), &take_idx, None).unwrap(),
        field.unwrap().clone(),
    )
}
/// Merges two node states column by column.
///
/// * `index_merge_priority` — which state's key index the result uses
///   (`Exclude` = union of both indices).
/// * `default_column_merge_priority` — which side wins for columns not listed
///   in `column_merge_priority_map` (`Exclude` = drop all unlisted columns).
/// * `column_merge_priority_map` — per-column overrides.
pub fn merge(
    &self,
    other: &GenericNodeState<'graph, G>,
    index_merge_priority: MergePriority,
    default_column_merge_priority: MergePriority,
    column_merge_priority_map: Option<HashMap<String, MergePriority>>,
) -> Self {
    let mut merge_node_cols: HashMap<String, (NodeStateOutputType, Option<G>)> =
        HashMap::default();
    // node_cols metadata for unlisted columns comes from the default side.
    let default_node_cols: &HashMap<String, (NodeStateOutputType, Option<G>)> =
        if default_column_merge_priority == MergePriority::Left {
            &self.node_cols
        } else {
            &other.node_cols
        };
    // Union of both key indices; only defined when both states have one.
    let new_idx_set = if self.keys.is_none() || other.keys.is_none() {
        None
    } else {
        Some(
            IndexSet::union(
                self.keys.as_ref().unwrap().index.as_ref(),
                other.keys.as_ref().unwrap().index.as_ref(),
            )
            .clone()
            .into_iter()
            .map(|v| v.to_owned())
            .collect(),
        )
    };
    // The key index the merged state will carry.
    let tgt_idx_set = match index_merge_priority {
        MergePriority::Left => self.keys.as_ref(),
        MergePriority::Right => other.keys.as_ref(),
        MergePriority::Exclude => new_idx_set.as_ref(),
    };
    // No index means positional coverage of every node in the graph.
    let tgt_index_set_len = tgt_idx_set
        .map_or(self.base_graph.unfiltered_num_nodes(), |idx_set| {
            idx_set.len()
        });
    let mut cols: Vec<Arc<dyn Array>> = vec![];
    let mut fields: Vec<Field> = vec![];
    // First merge the explicitly prioritised columns.
    if column_merge_priority_map.as_ref().is_some() {
        for (col_name, priority) in column_merge_priority_map.as_ref().unwrap() {
            let (lh, rh, lh_batch, rh_batch) = match priority {
                MergePriority::Left => {
                    if self.node_cols.contains_key(col_name) {
                        merge_node_cols.insert(
                            col_name.to_string(),
                            self.node_cols.get(col_name).unwrap().clone(),
                        );
                    }
                    (
                        self.keys.as_ref(),
                        other.keys.as_ref(),
                        &self.values,
                        &other.values,
                    )
                }
                MergePriority::Right => {
                    if other.node_cols.contains_key(col_name) {
                        merge_node_cols.insert(
                            col_name.to_string(),
                            other.node_cols.get(col_name).unwrap().clone(),
                        );
                    }
                    (
                        other.keys.as_ref(),
                        self.keys.as_ref(),
                        &other.values,
                        &self.values,
                    )
                }
                // Explicitly excluded columns are dropped from the result.
                MergePriority::Exclude => continue,
            };
            let (col, field) = GenericNodeState::<'graph, G>::merge_columns(
                col_name,
                tgt_idx_set,
                tgt_index_set_len,
                lh,
                rh,
                lh_batch,
                rh_batch,
            );
            cols.push(col);
            fields.push(field);
        }
    }
    // Orient the default merge; Exclude means no further columns are added.
    let (lh, rh, lh_batch, rh_batch) = match default_column_merge_priority {
        MergePriority::Left => (
            self.keys.as_ref(),
            other.keys.as_ref(),
            &self.values,
            &other.values,
        ),
        MergePriority::Right => (
            other.keys.as_ref(),
            self.keys.as_ref(),
            &other.values,
            &self.values,
        ),
        MergePriority::Exclude => {
            return GenericNodeState::new(
                self.base_graph.clone(),
                RecordBatch::try_new(Schema::new(fields).into(), cols).unwrap(),
                tgt_idx_set.cloned(),
                Some(merge_node_cols),
            );
        }
    };
    // Merge every column of the default side that has no explicit priority.
    // NOTE(review): columns that exist only on the non-default side and are
    // not listed in the priority map are dropped — TODO confirm intended.
    for column in lh_batch.schema().fields().iter() {
        let col_name = column.name();
        if column_merge_priority_map
            .as_ref()
            .map_or(true, |map| map.contains_key(col_name) == false)
        {
            if default_node_cols.contains_key(col_name) {
                merge_node_cols.insert(
                    col_name.to_string(),
                    default_node_cols.get(col_name).unwrap().clone(),
                );
            }
            let (col, field) = GenericNodeState::<'graph, G>::merge_columns(
                col_name,
                tgt_idx_set,
                tgt_index_set_len,
                lh,
                rh,
                lh_batch,
                rh_batch,
            );
            cols.push(col);
            fields.push(field);
        }
    }
    GenericNodeState::new(
        self.base_graph.clone(),
        RecordBatch::try_new(Schema::new(fields).into(), cols).unwrap(),
        tgt_idx_set.cloned(),
        Some(merge_node_cols),
    )
}
/// Normalises a record batch: re-casts every column to its own type (a no-op
/// cast that canonicalises the data) and rebuilds the schema with all fields
/// marked nullable, so batches from different producers compare consistently.
fn convert_recordbatch(
    recordbatch: RecordBatch,
) -> Result<RecordBatch, arrow_schema::ArrowError> {
    let new_columns: Vec<Arc<dyn Array>> = recordbatch
        .columns()
        .iter()
        // data_type() already returns a reference; no clone needed.
        .map(|col| cast_with_options(col, col.data_type(), &CastOptions::default()))
        .collect::<arrow::error::Result<_>>()?;
    let new_fields: Vec<Field> = recordbatch
        .schema()
        .fields()
        .iter()
        .map(|f| Field::new(f.name(), f.data_type().clone(), true))
        .collect();
    RecordBatch::try_new(Arc::new(Schema::new(new_fields)), new_columns)
}
}
impl<'graph, V, G> TypedNodeState<'graph, V, G>
where
    V: NodeStateValue,
    G: GraphViewOps<'graph>,
{
    /// Wraps a generic state with the identity converter (`T = V`).
    pub fn new(state: GenericNodeState<'graph, G>) -> Self {
        Self {
            state,
            converter: V::transform,
            _v_marker: PhantomData,
            _t_marker: PhantomData,
        }
    }
}
impl<
'graph,
V: NodeStateValue + 'graph,
T: Clone + Send + Sync + 'graph,
G: GraphViewOps<'graph>,
> TypedNodeState<'graph, V, G, T>
{
/// Wraps a generic state with an explicit `V -> T` converter.
pub fn new_mapped(
    state: GenericNodeState<'graph, G>,
    converter: fn(&GenericNodeState<'graph, G>, V) -> T,
) -> Self {
    TypedNodeState {
        state,
        converter,
        _v_marker: PhantomData,
        _t_marker: PhantomData,
    }
}
/// Consumes the state, producing node-name -> `f(value)` for every node.
pub fn to_hashmap<Func, TransformedType>(self, f: Func) -> HashMap<String, TransformedType>
where
    Func: Fn(V) -> TransformedType,
{
    self.into_iter()
        .map(|(node, value)| (node.name(), f(value)))
        .collect()
}
/// Like `to_hashmap` but applies the state's own converter.
pub fn to_transformed_hashmap(&self) -> HashMap<String, T> {
    self.iter()
        .map(|(node, value)| (node.name(), self.convert(value.clone())))
        .collect()
}
/// Rewraps the underlying state with the node-resolving converter,
/// discarding this view's converter.
pub fn to_output_nodestate(self) -> OutputTypedNodeState<'graph, G> {
    TypedNodeState::new_mapped(self.state, GenericNodeState::get_nodes)
}
/// Deserializes the backing record batch into one `V` per row.
/// Panics if the batch does not match `V`'s serde schema.
pub fn values_to_rows(&self) -> Vec<V> {
    // Return the expression directly; the intermediate binding added nothing.
    from_record_batch(&self.state.values).unwrap()
}
/// Replaces the backing record batch with one serialized from `rows`.
/// Panics if `V` cannot be traced into an Arrow schema or serialization fails.
pub fn values_from_rows(&mut self, rows: Vec<V>) {
    let fields = Vec::<FieldRef>::from_type::<V>(TracingOptions::default()).unwrap();
    self.state.values = to_record_batch(&fields, &rows).unwrap();
}
/// Applies this view's converter to a single deserialized value.
pub fn convert(&self, value: V) -> T {
    (self.converter)(&self.state, value)
}
/// Groups nodes by the values of the given columns, returning one
/// `(group value, nodes in group)` pair per distinct combination.
/// Group values are deserialized as `V` rows and passed through the converter.
pub fn get_groups(&self, cols: Vec<String>) -> Result<Vec<(T, Nodes<'graph, G>)>, GraphError> {
    let num_rows = self.state.values().num_rows();
    if num_rows == 0 {
        // No rows -> no groups (previous `vec![].into()` was a needless no-op).
        return Ok(Vec::new());
    }
    // DataFusion group-values table assigns a dense group id per distinct row.
    let mut group_values = new_group_values(self.state.values().schema(), &GroupOrdering::None)
        .map_err(|e| ArrowError::ParseError(e.to_string()))
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    let group_arrays: Vec<ArrayRef> = cols
        .iter()
        .map(|name| {
            let idx = self
                .state
                .values()
                .schema()
                .index_of(name)
                .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
            Ok(self.state.values().column(idx).clone())
        })
        .collect::<Result<_, GraphError>>()?;
    // group_indices[row] = group id of that row after interning.
    let mut group_indices = vec![0usize; num_rows];
    group_values
        .intern(&group_arrays, &mut group_indices)
        .map_err(|e| ArrowError::ComputeError(e.to_string()))
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    // Emit the distinct group keys as arrays, one row per group id.
    let group_arrays = group_values
        .emit(EmitTo::All)
        .map_err(|e| ArrowError::ComputeError(e.to_string()))
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    let fields: Vec<Field> = cols
        .iter()
        .map(|name| {
            let idx = self
                .state
                .values()
                .schema()
                .index_of(name)
                .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
            Ok(self.state.values().schema().field(idx).clone())
        })
        .collect::<Result<_, GraphError>>()?;
    let schema = Arc::new(Schema::new(fields));
    let group_batch = RecordBatch::try_new(schema, group_arrays)
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    let group_batch_des: RecordBatchIterator<'_, V> = RecordBatchIterator::new(&group_batch);
    // Collect the VIDs belonging to each group id, in parallel.
    let groups: DashMap<usize, IndexSet<VID, ahash::RandomState>, ahash::RandomState> =
        DashMap::default();
    group_indices
        .into_par_iter()
        .enumerate()
        .for_each(|(node_idx, group_value_idx)| {
            // Row position -> VID via the key index (or positionally).
            let vid = self.state.keys.as_ref().map_or(VID(node_idx), |idx| {
                // VID is Copy: copied() instead of unwrap().clone().
                idx.index.get_index(node_idx).copied().unwrap()
            });
            groups.entry(group_value_idx).or_default().insert(vid);
        });
    let result = groups
        .into_par_iter()
        .map(|(group_value_idx, nodes)| {
            (
                self.convert(group_batch_des.get(group_value_idx).unwrap()),
                Nodes::new_filtered(
                    self.graph().clone(),
                    self.graph().clone(),
                    Const(true),
                    Some(Index::new(nodes)),
                ),
            )
        })
        .collect();
    Ok(result)
}
}
impl<'graph, V, G, T> TypedNodeState<'graph, V, G, T>
where
    V: NodeStateValue + 'graph,
    T: Clone + Send + Sync + 'graph,
{
    /// Builds a node filter from a boolean column of this state.
    /// Errors are propagated from `NodeStateBoolColOp::new` (e.g. when `col`
    /// is missing or not boolean — presumably; verify against its impl).
    pub fn bool_col_filter(&self, col: &str) -> Result<NodeStateBoolColOp, GraphError> {
        NodeStateBoolColOp::new(self, col)
    }
}
// The previous impl declared an unused lifetime parameter `'a`; it is removed.
impl<
        'graph,
        V: NodeStateValue + 'graph,
        T: Clone + Sync + Send + 'graph,
        G: GraphViewOps<'graph>,
    > PartialEq<TypedNodeState<'graph, V, G, T>> for TypedNodeState<'graph, V, G, T>
{
    /// Two states are equal when they cover the same number of nodes and every
    /// node of `self` maps to an equal value in `other` (row order irrelevant).
    fn eq(&self, other: &TypedNodeState<'graph, V, G, T>) -> bool {
        self.len() == other.len()
            && self
                .par_iter()
                .all(|(node, value)| other.get_by_node(node).map(|v| v == value).unwrap_or(false))
    }
}
// Row-order-sensitive comparison against a plain vector of values.
impl<
        'graph,
        RHS: NodeStateValue + Send + Sync + 'graph,
        T: Clone + Sync + Send + 'graph,
        G: GraphViewOps<'graph>,
    > PartialEq<Vec<RHS>> for TypedNodeState<'graph, RHS, G, T>
{
    fn eq(&self, other: &Vec<RHS>) -> bool {
        // Deserializes all rows, then compares element-wise in parallel.
        self.values_to_rows().par_iter().eq(other)
    }
}
// Compares PropMap rows against tagged-Prop rows by converting each untagged
// value to `Prop` first.
impl<'graph, T: Clone + Sync + Send + 'graph, G: GraphViewOps<'graph>>
    PartialEq<Vec<IndexMap<String, Option<Prop>>>> for TypedNodeState<'graph, PropMap, G, T>
{
    fn eq(&self, other: &Vec<IndexMap<String, Option<Prop>>>) -> bool {
        let rows: Vec<_> = self.values_to_rows();
        rows.len() == other.len()
            && rows
                .into_par_iter()
                .zip(other.par_iter())
                .all(|(a, b)| convert_prop_map::<PropUntagged, Prop>(a) == *b)
    }
}
// Order-insensitive comparison against a node -> value map.
impl<
        'graph,
        K: AsNodeRef,
        RHS: NodeStateValue + 'graph,
        T: Clone + Send + Sync + 'graph,
        G: GraphViewOps<'graph>,
        S,
    > PartialEq<HashMap<K, RHS, S>> for TypedNodeState<'graph, RHS, G, T>
{
    fn eq(&self, other: &HashMap<K, RHS, S>) -> bool {
        other.len() == self.len()
            && other
                .iter()
                // `.is_some()` replaces the double-negative `.is_none() == false`.
                .all(|(k, rhs)| self.get_by_node(k).filter(|lhs| lhs == rhs).is_some())
    }
}
impl<'graph, G: GraphViewOps<'graph>> OutputTypedNodeState<'graph, G> {
    /// Compares this state against a node -> property-map table, casting each
    /// right-hand value to the left-hand value's dtype before comparing.
    ///
    /// Returns `Ok(false)` on any mismatch, and `Err` when a node or property
    /// key is missing from one side or a value cannot be cast.
    pub fn try_eq_hashmap<K: AsNodeRef>(
        &self,
        other: HashMap<K, HashMap<String, Option<Prop>>>,
    ) -> Result<bool, &'static str> {
        if self.len() != other.len() {
            return Ok(false);
        }
        for (node, mut rhs_map) in other {
            let lhs_map = self.get_by_node(&node).ok_or("Key missing in lhs map")?;
            if rhs_map.len() != lhs_map.len() {
                return Ok(false);
            }
            for (prop_name, lhs_val) in lhs_map {
                let rhs_val = rhs_map.remove(&prop_name).ok_or("Key missing in rhs map")?;
                // Matching on the pair makes the four presence cases explicit.
                match (lhs_val, rhs_val) {
                    (None, None) => continue,
                    (None, Some(_)) | (Some(_), None) => return Ok(false),
                    (Some(lhs), Some(rhs)) => {
                        let casted_rhs = rhs
                            .try_cast(lhs.0.dtype())
                            .map_err(|_| "Failed to cast rhs value")?;
                        if casted_rhs != lhs.0 {
                            return Ok(false);
                        }
                    }
                }
            }
        }
        Ok(true)
    }
}
// Debug output renders the state as a node -> value map.
impl<
        'graph,
        V: NodeStateValue + 'graph,
        T: Clone + Send + Sync + 'graph,
        G: GraphViewOps<'graph>,
    > Debug for TypedNodeState<'graph, V, G, T>
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_map().entries(self.iter()).finish()
    }
}
impl<
        'graph,
        V: NodeStateValue + 'graph,
        T: Clone + Sync + Send + 'graph,
        G: GraphViewOps<'graph>,
    > IntoIterator for TypedNodeState<'graph, V, G, T>
{
    type Item = (NodeView<'graph, G>, V);
    type IntoIter = Box<dyn Iterator<Item = Self::Item> + 'graph>;
    /// Yields `(node, value)` pairs in row order.
    fn into_iter(self) -> Self::IntoIter {
        // `nodes()` already returns an owned value; the former `.clone()`
        // duplicated it needlessly.
        self.nodes()
            .into_iter()
            .zip(self.into_iter_values())
            .into_dyn_boxed()
    }
}
// Core node-state trait implementation: iteration, lookup and reconstruction
// over the Arrow-backed typed state.
impl<
        'a,
        'graph: 'a,
        V: NodeStateValue + 'graph,
        T: Clone + Sync + Send + 'graph,
        G: GraphViewOps<'graph>,
    > NodeStateOps<'a, 'graph> for TypedNodeState<'graph, V, G, T>
{
    type Graph = G;
    type BaseGraph = G;
    type Select = Const<bool>;
    type Value = V;
    type OwnedValue = V;
    type OutputType = Self;
    fn graph(&self) -> &Self::Graph {
        &self.state.base_graph
    }
    fn base_graph(&self) -> &Self::BaseGraph {
        &self.state.base_graph
    }
    /// Sequential iteration by deserializing rows in order.
    fn iter_values(&'a self) -> impl Iterator<Item = Self::Value> + 'a {
        RecordBatchIterator::new(&self.state.values)
    }
    #[allow(refining_impl_trait)]
    fn par_iter_values(&'a self) -> impl IndexedParallelIterator<Item = Self::Value> + 'a {
        // Random-access `get` lets each rayon task deserialize its own row.
        let iter = RecordBatchIterator::<'a, Self::Value>::new(&self.state.values);
        (0..self.len())
            .into_par_iter()
            .map(move |i| RecordBatchIterator::get(&iter, i).unwrap())
    }
    fn into_iter_values(self) -> impl Iterator<Item = Self::OwnedValue> + Send + Sync {
        (0..self.len()).map(move |i| self.get_by_index(i).unwrap().1)
    }
    #[allow(refining_impl_trait)]
    fn into_par_iter_values(self) -> impl IndexedParallelIterator<Item = Self::OwnedValue> {
        (0..self.len())
            .into_par_iter()
            .map(move |i| self.get_by_index(i).unwrap().1)
    }
    /// `(node, value)` pairs: VIDs come from the key index when present,
    /// otherwise row position i maps to VID(i).
    fn iter(&'a self) -> impl Iterator<Item = (NodeView<'a, &'a Self::Graph>, Self::Value)> + 'a {
        match &self.state.keys {
            Some(index) => index
                .iter()
                .zip(self.iter_values())
                .map(|(n, v)| (NodeView::new_internal(&self.state.base_graph, n), v))
                .into_dyn_boxed(),
            None => self
                .iter_values()
                .enumerate()
                .map(|(i, v)| (NodeView::new_internal(&self.state.base_graph, VID(i)), v))
                .into_dyn_boxed(),
        }
    }
    fn nodes(&self) -> Nodes<'graph, Self::Graph> {
        self.state.nodes()
    }
    /// Parallel counterpart of `iter`, with the same key handling.
    fn par_iter(
        &'a self,
    ) -> impl ParallelIterator<Item = (NodeView<'a, &'a Self::Graph>, Self::Value)> {
        match &self.state.keys {
            Some(index) => Either::Left(
                index
                    .par_iter()
                    .zip(self.par_iter_values())
                    .map(|(n, v)| (NodeView::new_internal(&self.state.base_graph, n), v)),
            ),
            None => Either::Right(
                self.par_iter_values()
                    .enumerate()
                    .map(|(i, v)| (NodeView::new_internal(&self.state.base_graph, VID(i)), v)),
            ),
        }
    }
    /// Row-position lookup; resolves the VID first, then reads the value.
    fn get_by_index(
        &'a self,
        index: usize,
    ) -> Option<(NodeView<'a, &'a Self::Graph>, Self::Value)> {
        let vid = match &self.state.keys {
            Some(node_index) => node_index.key(index),
            None => Some(VID(index)),
        };
        if let Some(vid) = vid {
            Some((
                NodeView::new_internal(&self.state.base_graph, vid),
                self.get_by_node(vid).unwrap(),
            ))
        } else {
            return None;
        }
    }
    /// Node lookup: maps the node to its row, then deserializes that row.
    fn get_by_node<N: AsNodeRef>(&'a self, node: N) -> Option<Self::Value> {
        let index = self.state.get_index_by_node(&node)?;
        let deserializer = Deserializer::from_record_batch(&self.state.values).unwrap();
        let item = V::deserialize(
            deserializer
                .get(index)
                .ok_or_else(|| tracing::error!("Could not get item"))
                .unwrap(),
        )
        .unwrap();
        Some(item)
    }
    fn len(&self) -> usize {
        self.state.len()
    }
    /// Rebuilds a state of the same shape from explicit keys and values,
    /// keeping this view's converter and node-column metadata.
    fn construct(
        &self,
        base_graph: Self::BaseGraph,
        _graph: Self::Graph,
        keys: IndexSet<VID, ahash::RandomState>,
        values: Vec<Self::OwnedValue>,
    ) -> Self::OutputType
    where
        Self::BaseGraph: 'graph,
        Self::Graph: 'graph,
    {
        let state = GenericNodeState::new_from_eval_with_index(
            base_graph,
            values,
            Some(Index::new(keys)),
            Some(self.state.node_cols.clone()),
        );
        TypedNodeState::<'graph, V, Self::Graph, T> {
            state: state,
            converter: self.converter,
            _v_marker: PhantomData,
            _t_marker: PhantomData,
        }
    }
}
impl<'graph, G: GraphViewOps<'graph>> GenericNodeState<'graph, G> {
/// Reorders `values` (one entry per VID, positionally) into the order given
/// by `index`; with no index the values are returned unchanged.
/// Panics if the index references a position outside `values`; each position
/// may be taken at most once (values are moved out, not cloned).
fn take_values<V: NodeStateValue>(index: &Option<Index<VID>>, values: Vec<V>) -> Vec<V> {
    let Some(index) = index else {
        return values;
    };
    // Wrap in Option so entries can be moved out without cloning.
    let mut values: Vec<Option<V>> = values.into_iter().map(Some).collect();
    index
        .iter()
        .map(|vid| {
            values
                .get_mut(vid.0)
                .and_then(Option::take)
                .expect("index out of bounds")
        })
        .collect()
}
/// Builds a state from per-VID values and an optional key index
/// (identity mapping of `new_from_eval_with_index_mapped`).
pub fn new_from_eval_with_index<V: NodeStateValue>(
    graph: G,
    values: Vec<V>,
    index: Option<Index<VID>>,
    node_cols: Option<HashMap<String, (NodeStateOutputType, Option<G>)>>,
) -> Self {
    Self::new_from_eval_with_index_mapped(graph, values, index, |v| v, node_cols)
}
/// Builds a state from per-VID raw values: maps each `R` to `V`, reorders by
/// `index`, serializes into a record batch and normalises it.
/// Panics if `V` cannot be traced/serialized into an Arrow schema.
pub fn new_from_eval_with_index_mapped<R: Clone, V: NodeStateValue>(
    graph: G,
    values: Vec<R>,
    index: Option<Index<VID>>,
    map: impl Fn(R) -> V,
    node_cols: Option<HashMap<String, (NodeStateOutputType, Option<G>)>>,
) -> Self {
    let values: Vec<V> = values.into_iter().map(map).collect();
    let fields = Vec::<FieldRef>::from_type::<V>(TracingOptions::default()).unwrap();
    // Rows must follow the key index order before serialization.
    let values = Self::take_values(&index, values);
    let values = Self::convert_recordbatch(to_record_batch(&fields, &values).unwrap()).unwrap();
    Self::new(graph, values, index, node_cols)
}
/// Creates an empty state (no rows, empty index) over `graph`.
pub fn new_empty(graph: G) -> Self {
    // `graph` is consumed directly; the former `.clone()` was redundant.
    Self::new(
        graph,
        RecordBatch::new_empty(Schema::empty().into()),
        Some(Index::default()),
        None,
    )
}
/// Builds a state from per-node values using the graph's own node index
/// (identity mapping of `new_from_eval_mapped`).
pub fn new_from_eval<V: NodeStateValue>(
    graph: G,
    values: Vec<V>,
    node_cols: Option<HashMap<String, (NodeStateOutputType, Option<G>)>>,
) -> Self {
    Self::new_from_eval_mapped(graph, values, |v| v, node_cols)
}
/// Builds a state from per-node raw values, deriving the key index from the
/// graph and mapping each `R` to `V`.
pub fn new_from_eval_mapped<R: Clone, V: NodeStateValue>(
    graph: G,
    values: Vec<R>,
    map: impl Fn(R) -> V,
    node_cols: Option<HashMap<String, (NodeStateOutputType, Option<G>)>>,
) -> Self {
    let index = Index::for_graph(graph.clone());
    // `graph` is moved on its last use; the former second clone was redundant.
    Self::new_from_eval_with_index_mapped(graph, values, index, map, node_cols)
}
/// Builds a state directly from a record batch, deriving the key index from
/// the graph and normalising the batch. Panics if normalisation fails.
pub fn new_from_values(
    graph: G,
    values: impl Into<RecordBatch>,
    node_cols: Option<HashMap<String, (NodeStateOutputType, Option<G>)>>,
) -> Self {
    let index = Index::for_graph(&graph);
    // `graph` is moved on its last use; the former `.clone()` was redundant.
    Self::new(
        graph,
        Self::convert_recordbatch(values.into()).unwrap(),
        index,
        node_cols,
    )
}
/// Builds a state from a VID -> value map. When the map covers every node the
/// rows follow graph node order with no explicit index; otherwise only the
/// covered nodes are kept, with an index recording which they are.
/// Panics if a full-coverage map is missing a node, or serialization fails.
pub fn new_from_map<R: NodeStateValue, S: BuildHasher, V: NodeStateValue>(
    graph: G,
    mut values: HashMap<VID, R, S>,
    map: impl Fn(R) -> V,
    node_cols: Option<HashMap<String, (NodeStateOutputType, Option<G>)>>,
) -> Self {
    let fields = Vec::<FieldRef>::from_type::<V>(TracingOptions::default()).unwrap();
    if values.len() == graph.count_nodes() {
        // Full coverage: positional rows in graph node order.
        let values: Vec<_> = graph
            .nodes()
            .iter()
            .map(|node| map(values.remove(&node.node).unwrap()))
            .collect();
        let values = to_record_batch(&fields, &values).unwrap();
        Self::new_from_values(graph, values, node_cols)
    } else {
        // Partial coverage: keep only the nodes present in the map.
        let (index, values): (IndexSet<VID, ahash::RandomState>, Vec<_>) = graph
            .nodes()
            .iter()
            .flat_map(|node| Some((node.node, map(values.remove(&node.node)?))))
            .unzip();
        let values = to_record_batch(&fields, &values).unwrap();
        // `graph` is moved on its last use; the former `.clone()` was redundant.
        Self::new(graph, values, Some(Index::new(index)), node_cols)
    }
}
/// Turns `column name -> optional "asc"/"desc"` parameters into DataFusion
/// physical sort expressions against `schema`.
/// NOTE(review): panics (via the final unwrap) when a column name does not
/// exist in the schema — TODO consider propagating the error instead.
fn get_sort_exprs(
    sort_params: IndexMap<String, Option<String>>,
    schema: &Schema,
) -> Vec<PhysicalSortExpr> {
    let sort_exprs: Result<Vec<PhysicalSortExpr>, ArrowError> = sort_params
        .into_iter()
        .map(|(name, sort_opt)| {
            // "desc" sorts nulls last, "asc" nulls first; anything else
            // (including None) falls back to arrow's default options.
            let options = match sort_opt.as_deref().map(|s| s.to_lowercase()).as_deref() {
                Some("desc") => SortOptions {
                    descending: true,
                    nulls_first: false,
                },
                Some("asc") => SortOptions {
                    descending: false,
                    nulls_first: true,
                },
                _ => SortOptions::default(),
            };
            Ok(PhysicalSortExpr {
                expr: col(&name, schema)?,
                options,
            })
        })
        .collect();
    sort_exprs.unwrap()
}
/// Returns a new state with rows (and their key index) sorted by the given
/// columns. `sort_params` maps column name -> optional "asc"/"desc".
pub fn sort_by(
    &self,
    sort_params: IndexMap<String, Option<String>>,
) -> Result<GenericNodeState<'graph, G>, GraphError> {
    if self.values.num_rows() == 0 {
        return Ok(GenericNodeState::new_empty(self.base_graph.clone()));
    }
    let sort_exprs: Vec<PhysicalSortExpr> =
        Self::get_sort_exprs(sort_params, &self.values().schema());
    // Describe each sort column to the arrow row converter.
    let sort_fields: Vec<SortField> = sort_exprs
        .iter()
        .map(|expr| {
            let col = expr
                .expr
                .evaluate(self.values())
                .map_err(|e| {
                    GraphError::IOErrorMsg(format!("Failed to evaluate sort expr: {}", e))
                })?
                .into_array(self.values.num_rows())
                .map_err(|e| {
                    GraphError::IOErrorMsg(format!("Failed to convert to array: {}", e))
                })?;
            Ok(SortField::new_with_options(
                col.data_type().clone(),
                arrow::compute::SortOptions {
                    descending: expr.options.descending,
                    nulls_first: expr.options.nulls_first,
                },
            ))
        })
        .collect::<Result<Vec<_>, GraphError>>()?;
    let converter = RowConverter::new(sort_fields)
        .map_err(|e| GraphError::IOErrorMsg(format!("Failed to create RowConverter: {}", e)))?;
    let sort_columns: Vec<_> = sort_exprs
        .iter()
        .map(|expr| {
            expr.expr
                .evaluate(self.values())
                .map_err(|e| GraphError::IOErrorMsg(format!("Failed to evaluate: {}", e)))?
                .into_array(self.values.num_rows())
                .map_err(|e| GraphError::IOErrorMsg(format!("Failed to convert: {}", e)))
        })
        .collect::<Result<Vec<_>, GraphError>>()?;
    // Encode the sort key of every row so byte comparison gives sort order.
    let rows = converter
        .convert_columns(&sort_columns)
        .map_err(|e| GraphError::IOErrorMsg(format!("Failed to convert to rows: {}", e)))?;
    let mut indices: Vec<u32> = (0..self.values.num_rows() as u32).collect();
    // BUG FIX: sort the permutation *before* deriving the new key index so
    // the keys are reordered together with the rows. Previously the keys were
    // captured from the identity permutation and stayed in the old order.
    indices.par_sort_by(|&a, &b| rows.row(a as usize).cmp(&rows.row(b as usize)));
    let new_keys: Option<Index<VID>> = match &self.keys {
        Some(keys) => Some(
            indices
                .iter()
                .filter_map(|&i| keys.index.get_index(i as usize).copied())
                .collect(),
        ),
        None => Some(indices.iter().map(|&i| VID(i as usize)).collect()),
    };
    // Gather every value column into the sorted order.
    let indices_array = UInt32Array::from(indices);
    let sorted_columns: Vec<_> = self
        .values
        .columns()
        .iter()
        .map(|col| {
            take(col, &indices_array, None).map_err(|e| GraphError::IOErrorMsg(e.to_string()))
        })
        .collect::<Result<Vec<_>, GraphError>>()?;
    Ok(GenericNodeState::new(
        self.base_graph.clone(),
        RecordBatch::try_new(self.values.schema(), sorted_columns)
            .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?,
        new_keys,
        Some(self.node_cols.clone()),
    ))
}
/// Returns a new state containing the k first rows under the given sort
/// order (and their reordered key index), computed with a bounded max-heap
/// instead of a full sort.
pub fn top_k(
    &self,
    sort_params: IndexMap<String, Option<String>>,
    k: usize,
) -> Result<GenericNodeState<'graph, G>, GraphError> {
    if self.values().num_rows() == 0 || k == 0 {
        return Ok(GenericNodeState::new_empty(self.base_graph.clone()));
    }
    let sort_exprs: Vec<PhysicalSortExpr> =
        Self::get_sort_exprs(sort_params, &self.values().schema());
    // Describe each sort column to the arrow row converter.
    let sort_fields: Vec<SortField> = sort_exprs
        .iter()
        .map(|e| {
            Ok(SortField::new_with_options(
                e.expr
                    .data_type(self.values().schema().as_ref())
                    .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?,
                e.options,
            ))
        })
        .collect::<Result<_, GraphError>>()
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    let converter =
        RowConverter::new(sort_fields).map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    let sort_columns: Vec<ArrayRef> = sort_exprs
        .iter()
        .map(|e| {
            e.expr
                .evaluate(self.values())
                .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?
                .into_array(self.values().num_rows())
                .map_err(|e| GraphError::IOErrorMsg(e.to_string()))
        })
        .collect::<Result<_, GraphError>>()
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    // Row-encode the sort keys so byte comparison matches the sort order.
    let rows = converter
        .convert_columns(&sort_columns)
        .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?;
    // Max-heap of the k smallest encoded rows seen so far: once full, a new
    // row replaces the current maximum only if it is strictly smaller.
    let mut heap = BinaryHeap::<HeapRow>::with_capacity(k + 1);
    for i in 0..self.values().num_rows() {
        let row = rows.row(i);
        if heap.len() < k {
            heap.push(HeapRow {
                row: row.as_ref().to_vec(),
                index: i,
            });
        } else if let Some(max) = heap.peek() {
            if row.as_ref() < max.row.as_slice() {
                heap.pop();
                heap.push(HeapRow {
                    row: row.as_ref().to_vec(),
                    index: i,
                });
            }
        }
    }
    // Ascending order of the k winners.
    let sorted: Vec<HeapRow> = heap.into_sorted_vec();
    let batches = [self.values()];
    // (batch, row) pairs for interleave; there is only one source batch.
    let indices: Vec<(usize, usize)> = sorted.iter().map(|r| (0, r.index)).collect();
    // Reorder the key index to match the selected rows.
    let new_keys: Option<Index<VID>> = match &self.keys {
        Some(keys) => Some(
            indices
                .iter()
                .filter_map(|(_, i)| keys.index.get_index(*i as usize).copied())
                .collect(),
        ),
        None => Some(indices.iter().map(|(_, i)| VID(*i as usize)).collect()),
    };
    Ok(GenericNodeState::new(
        self.base_graph.clone(),
        interleave_record_batch(&batches, &indices)
            .map_err(|e| GraphError::IOErrorMsg(e.to_string()))?,
        new_keys,
        Some(self.node_cols.clone()),
    ))
}
}