use std::{cell::RefCell, collections::HashMap};
use anyhow::Result;
use arrow::record_batch::RecordBatch;
use crate::{
dsl::{StateRow, Value, arrow_value::ColumnReader},
graph::{EdgeId, Graph, GraphId, GraphRepo, NodeId},
};
/// Lazily populated cache of [`ColumnReader`]s: one map for node payload columns and one
/// for edge payload columns.
///
/// Entries are keyed by column name only. On a cache hit the `batch` passed to the lookup
/// is ignored, so a cache presumably must not be reused across different batches
/// (TODO confirm with callers). A failed bind is not stored and is retried on the next
/// lookup. The maps sit in `RefCell`s so lookups can populate them through `&self`; as a
/// consequence this type is not `Sync`.
pub(crate) struct PayloadCache {
    /// Readers for node payload columns, keyed by column name.
    node: RefCell<HashMap<String, ColumnReader>>,
    /// Readers for edge payload columns, keyed by column name.
    edge: RefCell<HashMap<String, ColumnReader>>,
}

impl PayloadCache {
    /// Creates an empty cache; readers are bound on first use.
    pub(crate) fn new() -> Self {
        Self {
            node: RefCell::default(),
            edge: RefCell::default(),
        }
    }

    /// Returns the reader for node column `col`, binding it against `batch` on a miss.
    fn node_reader(&self, batch: &RecordBatch, col: &str) -> Result<ColumnReader> {
        Self::reader(&self.node, batch, col)
    }

    /// Returns the reader for edge column `col`, binding it against `batch` on a miss.
    fn edge_reader(&self, batch: &RecordBatch, col: &str) -> Result<ColumnReader> {
        Self::reader(&self.edge, batch, col)
    }

    /// Shared lookup: returns a clone of the cached reader for `col` if present,
    /// otherwise binds a new one against `batch`, stores a clone, and returns it.
    ///
    /// # Errors
    /// Propagates any error from `ColumnReader::bind`; nothing is cached in that case.
    fn reader(
        slots: &RefCell<HashMap<String, ColumnReader>>,
        batch: &RecordBatch,
        col: &str,
    ) -> Result<ColumnReader> {
        // The shared borrow ends with this statement, so it cannot conflict with the
        // `borrow_mut` further down.
        let hit = slots.borrow().get(col).cloned();
        if let Some(cached) = hit {
            return Ok(cached);
        }
        let fresh = ColumnReader::bind(batch, col)?;
        slots.borrow_mut().insert(col.to_owned(), fresh.clone());
        Ok(fresh)
    }
}
/// User-supplied traversal logic whose per-edge hooks receive an [`EdgeCtx`].
///
/// The driver that invokes these hooks is not part of this file; the notes below describe
/// the signatures and the intent suggested by the names (confirm against the driver).
pub trait Kernel {
    /// State threaded through the traversal; must be cloneable.
    type State: Clone;

    /// Builds the state for a traversal starting at `start` in `graph`.
    fn initial_state(&self, graph: &Graph, start: NodeId) -> Self::State;

    /// Decides whether the edge in `ctx` should be visited (presumably: followed).
    fn visit(&self, ctx: &EdgeCtx<'_, Self::State>) -> Result<bool>;

    /// Computes the state to carry forward past the edge in `ctx`.
    fn next_state(&self, ctx: &EdgeCtx<'_, Self::State>) -> Result<Self::State>;

    /// Decides whether traversal should stop at the edge in `ctx` (exact meaning is the
    /// driver's).
    fn stop(&self, ctx: &EdgeCtx<'_, Self::State>) -> Result<bool>;

    /// Converts `state` into a [`StateRow`] (presumably for emitting results).
    fn state_row(&self, state: &Self::State) -> StateRow;
}
pub struct EdgeCtx<'a, S> {
graph: &'a Graph,
src: NodeId,
dest: NodeId,
edge: EdgeId,
state: &'a S,
cache: &'a PayloadCache,
}
impl<'a, S> EdgeCtx<'a, S> {
pub(crate) fn new(
graph: &'a Graph,
src: NodeId,
dest: NodeId,
edge: EdgeId,
state: &'a S,
cache: &'a PayloadCache,
) -> Self {
Self {
graph,
src,
dest,
edge,
state,
cache,
}
}
pub fn with_state<'b>(&'b self, state: &'b S) -> EdgeCtx<'b, S> {
EdgeCtx {
graph: self.graph,
src: self.src,
dest: self.dest,
edge: self.edge,
state,
cache: self.cache,
}
}
pub fn graph(&self) -> &'a Graph {
self.graph
}
pub fn src(&self) -> NodeId {
self.src
}
pub fn dest(&self) -> NodeId {
self.dest
}
pub fn edge(&self) -> EdgeId {
self.edge
}
pub fn state(&self) -> &S {
self.state
}
pub fn src_id(&self) -> Option<GraphId<'a>> {
self.graph.repo.external_node(self.src)
}
pub fn dest_id(&self) -> Option<GraphId<'a>> {
self.graph.repo.external_node(self.dest)
}
pub fn edge_id(&self) -> Option<GraphId<'a>> {
self.graph.repo.external_edge(self.edge)
}
pub fn src_value(&self, col: &str) -> Result<Value> {
self.cache
.node_reader(self.graph.repo.node_batch(), col)?
.value(self.src as usize)
}
pub fn dest_value(&self, col: &str) -> Result<Value> {
self.cache
.node_reader(self.graph.repo.node_batch(), col)?
.value(self.dest as usize)
}
pub fn edge_value(&self, col: &str) -> Result<Value> {
self.cache
.edge_reader(self.graph.repo.edge_batch(), col)?
.value(self.edge as usize)
}
pub fn src_u64(&self, col: &str) -> Result<Option<u64>> {
as_u64(self.src_value(col)?)
}
pub fn dest_u64(&self, col: &str) -> Result<Option<u64>> {
as_u64(self.dest_value(col)?)
}
pub fn edge_u64(&self, col: &str) -> Result<Option<u64>> {
as_u64(self.edge_value(col)?)
}
pub fn src_i64(&self, col: &str) -> Result<Option<i64>> {
as_i64(self.src_value(col)?)
}
pub fn dest_i64(&self, col: &str) -> Result<Option<i64>> {
as_i64(self.dest_value(col)?)
}
pub fn edge_i64(&self, col: &str) -> Result<Option<i64>> {
as_i64(self.edge_value(col)?)
}
pub fn src_f64(&self, col: &str) -> Result<Option<f64>> {
as_f64(self.src_value(col)?)
}
pub fn dest_f64(&self, col: &str) -> Result<Option<f64>> {
as_f64(self.dest_value(col)?)
}
pub fn edge_f64(&self, col: &str) -> Result<Option<f64>> {
as_f64(self.edge_value(col)?)
}
pub fn src_bool(&self, col: &str) -> Result<Option<bool>> {
as_bool(self.src_value(col)?)
}
pub fn dest_bool(&self, col: &str) -> Result<Option<bool>> {
as_bool(self.dest_value(col)?)
}
pub fn edge_bool(&self, col: &str) -> Result<Option<bool>> {
as_bool(self.edge_value(col)?)
}
pub fn src_str(&self, col: &str) -> Result<Option<String>> {
as_str(self.src_value(col)?)
}
pub fn dest_str(&self, col: &str) -> Result<Option<String>> {
as_str(self.dest_value(col)?)
}
pub fn edge_str(&self, col: &str) -> Result<Option<String>> {
as_str(self.edge_value(col)?)
}
}
/// Reads `value` as an unsigned 64-bit integer.
///
/// * `Null` yields `Ok(None)`.
/// * `U64` is returned unchanged.
/// * `I64` is accepted when it is non-negative.
/// * `F64` is accepted only when it is finite, has no fractional part, and lies in
///   `[0, 2^64)`, i.e. when it converts to `u64` exactly.
///
/// # Errors
/// Returns an error for negative integers, floats that would not convert exactly, and any
/// non-numeric variant.
#[allow(clippy::float_cmp)] // `fract() == 0.0` is an intentional exact integrality test
fn as_u64(value: Value) -> Result<Option<u64>> {
    match value {
        Value::Null => Ok(None),
        Value::U64(v) => Ok(Some(v)),
        Value::I64(v) => {
            if v >= 0 {
                Ok(Some(v as u64))
            } else {
                anyhow::bail!("cannot read negative value {v} as u64")
            }
        }
        Value::F64(v) => {
            // `u64::MAX as f64` rounds up to exactly 2^64, which does not fit in a u64.
            // The upper bound must be strict: with `<=`, 2^64 would pass and the
            // saturating `as` cast would silently turn it into `u64::MAX`.
            if v.is_finite() && v.fract() == 0.0 && v >= 0.0 && v < u64::MAX as f64 {
                Ok(Some(v as u64))
            } else {
                anyhow::bail!("cannot read f64 {v} as u64 without loss")
            }
        }
        other => anyhow::bail!("expected an integer value, got {other:?}"),
    }
}
/// Reads `value` as a signed 64-bit integer.
///
/// * `Null` yields `Ok(None)`.
/// * `I64` is returned unchanged.
/// * `U64` is accepted when it is at most `i64::MAX`.
/// * `F64` is accepted only when it is finite, has no fractional part, and lies in
///   `[-2^63, 2^63)`, i.e. when it converts to `i64` exactly.
///
/// # Errors
/// Returns an error for out-of-range unsigned values, floats that would not convert
/// exactly, and any non-numeric variant.
#[allow(clippy::float_cmp)] // `fract() == 0.0` is an intentional exact integrality test
fn as_i64(value: Value) -> Result<Option<i64>> {
    match value {
        Value::Null => Ok(None),
        Value::I64(v) => Ok(Some(v)),
        Value::U64(v) => {
            if v <= i64::MAX as u64 {
                Ok(Some(v as i64))
            } else {
                anyhow::bail!("cannot read u64 {v} as i64 (out of range)")
            }
        }
        Value::F64(v) => {
            // `i64::MIN as f64` is exactly -2^63 (representable, so `>=` is right), but
            // `i64::MAX as f64` rounds up to 2^63, one past the largest i64. The upper
            // bound must be strict, or 2^63 would pass and the saturating `as` cast would
            // silently yield `i64::MAX`.
            if v.is_finite()
                && v.fract() == 0.0
                && v >= i64::MIN as f64
                && v < i64::MAX as f64
            {
                Ok(Some(v as i64))
            } else {
                anyhow::bail!("cannot read f64 {v} as i64 without loss")
            }
        }
        other => anyhow::bail!("expected an integer value, got {other:?}"),
    }
}
/// Reads `value` as an `f64`.
///
/// `Null` yields `Ok(None)` and `F64` passes through unchanged. `I64` and `U64` are
/// converted with `as`, which rounds to the nearest representable `f64`, so integers with
/// magnitude above 2^53 may lose precision.
///
/// # Errors
/// Returns an error for any non-numeric variant.
fn as_f64(value: Value) -> Result<Option<f64>> {
    let number = match value {
        Value::Null => return Ok(None),
        Value::F64(float) => float,
        Value::I64(signed) => signed as f64,
        Value::U64(unsigned) => unsigned as f64,
        other => anyhow::bail!("expected a numeric value, got {other:?}"),
    };
    Ok(Some(number))
}
/// Reads `value` as a `bool`; `Null` yields `Ok(None)`.
///
/// # Errors
/// Returns an error for any variant other than `Bool` or `Null`. No numeric or string
/// coercion is attempted.
fn as_bool(value: Value) -> Result<Option<bool>> {
    let flag = match value {
        Value::Null => return Ok(None),
        Value::Bool(flag) => flag,
        other => anyhow::bail!("expected bool value, got {other:?}"),
    };
    Ok(Some(flag))
}
/// Reads `value` as an owned `String` (allocates via `to_string`); `Null` yields `Ok(None)`.
///
/// # Errors
/// Returns an error for any variant other than `Str` or `Null`.
fn as_str(value: Value) -> Result<Option<String>> {
    let text = match value {
        Value::Null => return Ok(None),
        Value::Str(s) => s.to_string(),
        other => anyhow::bail!("expected string value, got {other:?}"),
    };
    Ok(Some(text))
}
/// Unit tests for the `Value` coercion helpers.
#[cfg(test)]
mod tests {
    use super::{as_bool, as_f64, as_i64, as_u64};
    use crate::dsl::Value;

    #[test]
    fn as_u64_coercion() {
        assert_eq!(as_u64(Value::U64(5)).unwrap(), Some(5));
        assert_eq!(as_u64(Value::I64(5)).unwrap(), Some(5));
        assert!(as_u64(Value::I64(-1)).is_err());
        assert_eq!(as_u64(Value::F64(3.0)).unwrap(), Some(3));
        assert!(as_u64(Value::F64(3.5)).is_err());
        assert!(as_u64(Value::F64(-1.0)).is_err());
        assert_eq!(as_u64(Value::Null).unwrap(), None);
        assert!(as_u64(Value::Bool(true)).is_err());
    }

    /// NaN and infinities can never be read losslessly as integers.
    #[test]
    fn as_u64_rejects_non_finite_floats() {
        assert!(as_u64(Value::F64(f64::NAN)).is_err());
        assert!(as_u64(Value::F64(f64::INFINITY)).is_err());
        assert!(as_u64(Value::F64(f64::NEG_INFINITY)).is_err());
    }

    #[test]
    fn as_i64_coercion() {
        assert_eq!(as_i64(Value::I64(-3)).unwrap(), Some(-3));
        assert_eq!(as_i64(Value::U64(7)).unwrap(), Some(7));
        assert!(as_i64(Value::U64(u64::MAX)).is_err());
        assert_eq!(as_i64(Value::F64(4.0)).unwrap(), Some(4));
        assert!(as_i64(Value::F64(4.5)).is_err());
        assert_eq!(as_i64(Value::Null).unwrap(), None);
        assert!(as_i64(Value::Bool(false)).is_err());
    }

    /// Exact range edges that must still be accepted.
    #[test]
    fn as_i64_boundaries() {
        assert_eq!(as_i64(Value::U64(i64::MAX as u64)).unwrap(), Some(i64::MAX));
        // -2^63 is exactly representable as f64 and as i64.
        assert_eq!(as_i64(Value::F64(i64::MIN as f64)).unwrap(), Some(i64::MIN));
        assert!(as_i64(Value::F64(f64::NAN)).is_err());
        assert!(as_i64(Value::F64(f64::INFINITY)).is_err());
    }

    #[test]
    fn as_f64_coercion() {
        assert_eq!(as_f64(Value::F64(1.5)).unwrap(), Some(1.5));
        assert_eq!(as_f64(Value::I64(-2)).unwrap(), Some(-2.0));
        assert_eq!(as_f64(Value::U64(9)).unwrap(), Some(9.0));
        assert_eq!(as_f64(Value::Null).unwrap(), None);
        assert!(as_f64(Value::Bool(true)).is_err());
    }

    #[test]
    fn as_bool_coercion() {
        assert_eq!(as_bool(Value::Bool(true)).unwrap(), Some(true));
        assert_eq!(as_bool(Value::Bool(false)).unwrap(), Some(false));
        assert_eq!(as_bool(Value::Null).unwrap(), None);
        // No numeric-to-bool coercion.
        assert!(as_bool(Value::U64(1)).is_err());
        assert!(as_bool(Value::I64(0)).is_err());
    }
}