use crate::algo::ProjectionBuilder;
use crate::algo::algorithms::Algorithm;
use crate::algo::procedures::{
AlgoContext, AlgoProcedure, AlgoResultRow, ProcedureSignature, ValueType,
};
use anyhow::{Result, anyhow};
use futures::stream::{self, BoxStream, StreamExt};
use serde_json::Value;
use std::marker::PhantomData;
use uni_common::core::id::Vid;
/// Parses a procedure argument into a [`Vid`].
///
/// Accepts either a JSON non-negative integer (used as the raw id) or a JSON
/// string, which may be a decimal `u64` (e.g. `"17"`) or a `label:offset` pair
/// (e.g. `"3:5"`) as understood by `parse_vid_string`.
///
/// # Errors
/// Returns an error mentioning `name` when the value is a negative or
/// fractional number, a string in neither accepted form, or any other JSON
/// type.
pub fn parse_vid_arg(value: &Value, name: &str) -> Result<Vid> {
    let raw = match value {
        Value::Number(n) => n
            .as_u64()
            .ok_or_else(|| anyhow!("`{name}` must be a non-negative integer"))?,
        // The string form accepts `label:offset` as well as plain integers, so
        // the error must advertise both rather than only "u64".
        Value::String(s) => parse_vid_string(s).ok_or_else(|| {
            anyhow!("`{name}` string must parse as a u64 or `label:offset`, got {s:?}")
        })?,
        other => {
            return Err(anyhow!(
                "`{name}` must be an integer (or integer-string); got {other:?}"
            ));
        }
    };
    Ok(Vid::from(raw))
}
/// Parses a vertex-id string into its raw `u64` form.
///
/// Two forms are accepted:
/// * a plain decimal `u64`, e.g. `"17"`;
/// * a `label:offset` pair, e.g. `"3:5"`. Here `label` must fit in a `u16` and is
///   packed into the top 16 bits, and `offset` must fit in the low 48 bits.
///
/// Returns `None` when the string matches neither form, when `label` overflows
/// a `u16`, or when `offset` needs more than 48 bits. Without that last check,
/// the high offset bits would be OR-ed into the label field and silently
/// produce a different label.
fn parse_vid_string(s: &str) -> Option<u64> {
    const OFFSET_BITS: u32 = 48;
    const MAX_OFFSET: u64 = (1 << OFFSET_BITS) - 1;

    if let Ok(id) = s.parse::<u64>() {
        return Some(id);
    }
    let (label, offset) = s.split_once(':')?;
    let label = label.parse::<u16>().ok()?;
    let offset = offset.parse::<u64>().ok()?;
    if offset > MAX_OFFSET {
        return None;
    }
    Some((u64::from(label) << OFFSET_BITS) | offset)
}
/// Reads positional argument `i` as an `f64`.
///
/// # Errors
/// Fails if `args` has no element at `i` or that element is not a JSON number.
pub fn arg_f64(args: &[Value], i: usize, name: &str) -> Result<f64> {
    match args.get(i).and_then(|v| v.as_f64()) {
        Some(x) => Ok(x),
        None => Err(anyhow!("`{name}` must be a number")),
    }
}
/// Reads positional argument `i` as a `u64`.
///
/// # Errors
/// Fails if the argument is missing or is not a JSON integer representable as
/// a `u64`.
pub fn arg_u64(args: &[Value], i: usize, name: &str) -> Result<u64> {
    let value = args.get(i).and_then(|v| v.as_u64());
    value.ok_or_else(|| anyhow!("`{name}` must be a non-negative integer"))
}
/// Reads positional argument `i` as a `bool`.
///
/// # Errors
/// Fails if the argument is missing or is not a JSON boolean.
pub fn arg_bool(args: &[Value], i: usize, name: &str) -> Result<bool> {
    if let Some(Value::Bool(flag)) = args.get(i) {
        Ok(*flag)
    } else {
        Err(anyhow!("`{name}` must be a boolean"))
    }
}
/// Borrows positional argument `i` as a string slice tied to `args`.
///
/// # Errors
/// Fails if the argument is missing or is not a JSON string.
pub fn arg_str<'a>(args: &'a [Value], i: usize, name: &str) -> Result<&'a str> {
    match args.get(i) {
        Some(Value::String(text)) => Ok(text.as_str()),
        _ => Err(anyhow!("`{name}` must be a string")),
    }
}
/// Reads positional argument `i` as a list of owned strings.
///
/// # Errors
/// Fails if the argument is missing or not a JSON array. It also fails on the
/// first array entry that is not a JSON string.
pub fn arg_string_list(args: &[Value], i: usize, name: &str) -> Result<Vec<String>> {
    let Some(items) = args.get(i).and_then(Value::as_array) else {
        return Err(anyhow!("`{name}` must be an array"));
    };
    items
        .iter()
        .map(|item| match item {
            Value::String(s) => Ok(s.clone()),
            _ => Err(anyhow!("`{name}` entries must be strings")),
        })
        .collect()
}
pub fn err_stream(e: anyhow::Error) -> BoxStream<'static, Result<AlgoResultRow>> {
stream::once(async move { Err(e) }).boxed()
}
pub trait GraphAlgoAdapter: Send + Sync + 'static {
const NAME: &'static str;
type Algo: Algorithm;
fn specific_args() -> Vec<(&'static str, ValueType, Option<Value>)>;
fn yields() -> Vec<(&'static str, ValueType)>;
fn to_config(args: Vec<Value>) -> Result<<Self::Algo as Algorithm>::Config>;
fn map_result(result: <Self::Algo as Algorithm>::Result) -> Result<Vec<AlgoResultRow>>;
fn include_reverse() -> bool {
true
}
fn weight_arg_index() -> Option<usize> {
None
}
fn customize_projection(builder: ProjectionBuilder, args: &[Value]) -> ProjectionBuilder {
let builder = match Self::weight_arg_index().and_then(|i| args.get(i)) {
Some(arg) => match arg.as_str() {
Some(prop) => builder.weight_property(prop),
None => builder,
},
None => builder,
};
builder.include_reverse(Self::include_reverse())
}
}
/// An [`AlgoProcedure`] built generically from a [`GraphAlgoAdapter`].
///
/// It stores no data; all behaviour comes from the adapter type `A`.
pub struct GenericAlgoProcedure<A: GraphAlgoAdapter> {
    _marker: PhantomData<A>,
}

impl<A: GraphAlgoAdapter> GenericAlgoProcedure<A> {
    /// Creates the procedure for adapter `A`.
    pub fn new() -> Self {
        GenericAlgoProcedure {
            _marker: PhantomData,
        }
    }
}

impl<A: GraphAlgoAdapter> Default for GenericAlgoProcedure<A> {
    /// Same as [`GenericAlgoProcedure::new`].
    fn default() -> Self {
        GenericAlgoProcedure::new()
    }
}
impl<A: GraphAlgoAdapter> AlgoProcedure for GenericAlgoProcedure<A>
where
<A::Algo as Algorithm>::Result: Send + 'static,
{
fn name(&self) -> &str {
A::NAME
}
fn signature(&self) -> ProcedureSignature {
let mut args = vec![
("nodeLabels", ValueType::List),
("relationshipTypes", ValueType::List),
];
let mut optional_args = Vec::new();
for (name, ty, default) in A::specific_args() {
if let Some(def) = default {
optional_args.push((name, ty, def));
} else {
args.push((name, ty));
}
}
ProcedureSignature {
args,
optional_args,
yields: A::yields(),
}
}
fn execute_with_projection(
&self,
_ctx: AlgoContext,
args: Vec<Value>,
projection: crate::algo::GraphProjection,
) -> BoxStream<'static, Result<AlgoResultRow>> {
let signature = self.signature();
let args = match signature.validate_args(args) {
Ok(a) => a,
Err(e) => return stream::once(async { Err(e) }).boxed(),
};
let specific_args = args[2..].to_vec();
let stream = async_stream::try_stream! {
let config = A::to_config(specific_args)?;
let result = tokio::task::spawn_blocking(move || {
A::Algo::run(&projection, config)
}).await?;
let rows = A::map_result(result)?;
for row in rows {
yield row;
}
};
Box::pin(stream)
}
fn customize_projection(
&self,
builder: ProjectionBuilder,
args: &[Value],
) -> ProjectionBuilder {
A::customize_projection(builder, args)
}
}
pub async fn build_projection_from_direct_args(
proc: &dyn AlgoProcedure,
ctx: &AlgoContext,
args: &[Value],
) -> Result<crate::algo::GraphProjection> {
let node_labels: Vec<String> = args
.first()
.and_then(Value::as_array)
.ok_or_else(|| anyhow!("args[0] must be an array of node-label names"))?
.iter()
.map(|v| {
v.as_str()
.ok_or_else(|| anyhow!("node-label must be a string"))
.map(str::to_owned)
})
.collect::<Result<Vec<_>>>()?;
let edge_types: Vec<String> = args
.get(1)
.and_then(Value::as_array)
.ok_or_else(|| anyhow!("args[1] must be an array of edge-type names"))?
.iter()
.map(|v| {
v.as_str()
.ok_or_else(|| anyhow!("edge-type must be a string"))
.map(str::to_owned)
})
.collect::<Result<Vec<_>>>()?;
let schema = ctx.storage.schema_manager().schema();
for label in &node_labels {
if !schema.labels.contains_key(label) {
return Err(anyhow!("Label '{label}' not found"));
}
}
for etype in &edge_types {
if !schema.edge_types.contains_key(etype) {
return Err(anyhow!("Edge type '{etype}' not found"));
}
}
let builder = ProjectionBuilder::new(ctx.storage.clone())
.l0_manager(ctx.l0_manager.clone())
.node_labels(&node_labels.iter().map(String::as_str).collect::<Vec<_>>())
.edge_types(&edge_types.iter().map(String::as_str).collect::<Vec<_>>());
let specific_args = &args[2..];
let builder = proc.customize_projection(builder, specific_args);
builder.build().await
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    #[test]
    fn parse_vid_arg_accepts_number() {
        let vid = parse_vid_arg(&json!(42_u64), "node").expect("number must parse");
        assert_eq!(vid.as_u64(), 42);
    }
    #[test]
    fn parse_vid_arg_accepts_numeric_string() {
        let vid = parse_vid_arg(&json!("17"), "node").expect("numeric string must parse");
        assert_eq!(vid.as_u64(), 17);
    }
    #[test]
    fn parse_vid_arg_rejects_non_numeric_string() {
        let err = parse_vid_arg(&json!("abc"), "source").unwrap_err();
        assert!(
            err.to_string().contains("`source`"),
            "error should name the arg: {err}"
        );
    }
    #[test]
    fn parse_vid_arg_rejects_negative_number() {
        let err = parse_vid_arg(&json!(-1_i64), "source").unwrap_err();
        assert!(err.to_string().contains("non-negative"), "error: {err}");
    }
    #[test]
    fn parse_vid_arg_rejects_wrong_type() {
        let err = parse_vid_arg(&json!(true), "source").unwrap_err();
        assert!(err.to_string().contains("`source`"), "error: {err}");
    }
    // The `label:offset` form packs the label into the top 16 bits.
    #[test]
    fn parse_vid_string_packs_label_and_offset() {
        assert_eq!(parse_vid_string("3:5"), Some((3_u64 << 48) | 5));
        assert_eq!(parse_vid_string("65535:0"), Some(0xFFFF_u64 << 48));
    }
    #[test]
    fn parse_vid_string_rejects_malformed_label_offset() {
        // The label does not fit in a u16.
        assert_eq!(parse_vid_string("70000:1"), None);
        assert_eq!(parse_vid_string("1:x"), None);
        assert_eq!(parse_vid_string(":5"), None);
        assert_eq!(parse_vid_string("1:2:3"), None);
    }
    #[test]
    fn arg_helpers_read_values_and_name_bad_args() {
        let args = vec![
            json!(1.5),
            json!(7),
            json!(true),
            json!("x"),
            json!(["a", "b"]),
        ];
        assert_eq!(arg_f64(&args, 0, "f").unwrap(), 1.5);
        assert_eq!(arg_u64(&args, 1, "u").unwrap(), 7);
        assert!(arg_bool(&args, 2, "b").unwrap());
        assert_eq!(arg_str(&args, 3, "s").unwrap(), "x");
        assert_eq!(arg_string_list(&args, 4, "l").unwrap(), vec!["a", "b"]);
        let err = arg_u64(&args, 0, "u").unwrap_err();
        assert!(err.to_string().contains("`u`"), "error: {err}");
        assert!(arg_f64(&args, 99, "f").is_err());
        let err = arg_string_list(&[json!([1])], 0, "l").unwrap_err();
        assert!(err.to_string().contains("entries"), "error: {err}");
    }
}