use std::sync::{Arc, OnceLock};
use std::time::SystemTime;
use arrow_array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::scalar::ScalarValue;
use uni_algo::{ProjectionInput, parse_graph_ref};
use uni_plugin::traits::procedure::{
NamedArgType, ProcedureContext, ProcedureMode, ProcedurePlugin, ProcedureSignature,
};
use uni_plugin::traits::scalar::ArgType;
use uni_plugin::{FnError, PluginError, PluginRegistrar, QName, SideEffects};
use crate::projection_store::{ProjectionEntry, ProjectionSourceKind, estimate_bytes, for_storage};
use crate::query::executor::procedure_host::QueryProcedureHost;
pub fn register_into(r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
r.procedure(
QName::new("uni", "graph.project"),
ProjectProcedure::signature_static().clone(),
Arc::new(ProjectProcedure),
)?;
r.procedure(
QName::new("uni", "graph.drop"),
DropProcedure::signature_static().clone(),
Arc::new(DropProcedure),
)?;
r.procedure(
QName::new("uni", "graph.list"),
ListProcedure::signature_static().clone(),
Arc::new(ListProcedure),
)?;
r.procedure(
QName::new("uni", "graph.exists"),
ExistsProcedure::signature_static().clone(),
Arc::new(ExistsProcedure),
)?;
Ok(())
}
fn require_host<'a>(ctx: &ProcedureContext<'a>) -> Result<&'a QueryProcedureHost, FnError> {
ctx.host
.and_then(|h| h.as_any().downcast_ref::<QueryProcedureHost>())
.ok_or_else(|| FnError::new(0x701, "uni.graph.*: requires QueryProcedureHost"))
}
fn arg_to_json(cv: &ColumnarValue) -> serde_json::Value {
match cv {
ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(b)))
| ColumnarValue::Scalar(ScalarValue::Binary(Some(b))) => {
serde_json::from_slice::<serde_json::Value>(b).unwrap_or(serde_json::Value::Null)
}
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
serde_json::Value::String(s.clone())
}
ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => serde_json::Value::Bool(*b),
ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => {
serde_json::Value::Number((*i).into())
}
_ => serde_json::Value::Null,
}
}
fn arg_as_string(cv: &ColumnarValue) -> Option<String> {
match cv {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => Some(s.clone()),
_ => None,
}
}
fn one_row_stream(
schema: SchemaRef,
cols: Vec<ArrayRef>,
) -> Result<SendableRecordBatchStream, FnError> {
let batch = RecordBatch::try_new(Arc::clone(&schema), cols)
.map_err(|e| FnError::new(0x830, format!("RecordBatch build: {e}")))?;
let stream =
futures::stream::once(async move { Ok::<_, datafusion::error::DataFusionError>(batch) });
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}
#[derive(Debug)]
pub struct ProjectProcedure;
impl ProjectProcedure {
fn signature_static() -> &'static ProcedureSignature {
static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
SIG.get_or_init(|| ProcedureSignature {
args: vec![
NamedArgType {
name: smol_str::SmolStr::new("name"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Name to register the materialised projection under.".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("graphRef"),
ty: ArgType::Primitive(DataType::LargeBinary),
default: None,
doc: "Native or Cypher projection descriptor (Map).".to_owned(),
},
NamedArgType {
name: smol_str::SmolStr::new("config"),
ty: ArgType::Primitive(DataType::LargeBinary),
default: Some(ScalarValue::LargeBinary(Some(b"{}".to_vec()))),
doc: "Materialisation options (currently unused).".to_owned(),
},
],
yields: vec![
Field::new("name", DataType::Utf8, false),
Field::new("node_count", DataType::Int64, false),
Field::new("edge_count", DataType::Int64, false),
Field::new("bytes", DataType::Int64, false),
],
mode: ProcedureMode::Read, side_effects: SideEffects::ReadOnly,
retry_contract: None,
batch_input: None,
docs: "uni.graph.project(name, graphRef, config) — materialise \
a named graph projection from a Native or Cypher graphRef \
(no Named-of-Named). v1: in-memory, restart-clears."
.to_owned(),
})
}
}
impl ProcedurePlugin for ProjectProcedure {
fn signature(&self) -> &ProcedureSignature {
Self::signature_static()
}
fn invoke(
&self,
ctx: ProcedureContext<'_>,
args: &[ColumnarValue],
) -> Result<SendableRecordBatchStream, FnError> {
let host = require_host(&ctx)?.clone();
let name = args
.first()
.and_then(arg_as_string)
.ok_or_else(|| FnError::new(0x824, "uni.graph.project: name (String) required"))?;
let graph_ref = args
.get(1)
.map(arg_to_json)
.ok_or_else(|| FnError::new(0x824, "uni.graph.project: graphRef (Map) required"))?;
let projection_input = parse_graph_ref(&graph_ref)
.map_err(|e| FnError::new(0x820, format!("graphRef parse: {e}")))?;
let store = for_storage(host.storage());
if store.contains(&name) {
return Err(FnError::new(
0x824,
format!("uni.graph.project: projection `{name}` already exists; drop first"),
));
}
if let ProjectionInput::Named { .. } = &projection_input {
return Err(FnError::new(
0x824,
"uni.graph.project: graphRef cannot itself be Named \
(no projection-of-a-projection in v1)",
));
}
let schema: SchemaRef = Arc::new(Schema::new(Self::signature_static().yields.clone()));
let name_for_async = name.clone();
let store_for_async = Arc::clone(&store);
let schema_in_fut = Arc::clone(&schema);
let fut = async move {
let (projection, source_kind) = match projection_input {
ProjectionInput::Native {
node_labels,
edge_types,
weight_property,
include_reverse,
} => {
let storage = Arc::clone(host.storage());
let l0 = build_l0_manager(&host);
let mut builder = uni_algo::ProjectionBuilder::new(storage)
.l0_manager(l0)
.node_labels(&node_labels.iter().map(String::as_str).collect::<Vec<_>>())
.edge_types(&edge_types.iter().map(String::as_str).collect::<Vec<_>>())
.include_reverse(include_reverse);
if let Some(wp) = weight_property {
builder = builder.weight_property(&wp);
}
let projection = builder.build().await.map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"uni.graph.project (Native): {e}"
))
})?;
(projection, ProjectionSourceKind::Native)
}
ProjectionInput::Cypher {
node_query,
edge_query,
weight_column,
include_reverse,
} => {
let inner_params = std::collections::HashMap::new();
let node_rows = host
.execute_inner_query(
&node_query,
&inner_params,
uni_plugin::traits::procedure::ProcedureMode::Read,
)
.await
.map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"uni.graph.project node query: {e}"
))
})?;
let edge_rows = host
.execute_inner_query(
&edge_query,
&inner_params,
uni_plugin::traits::procedure::ProcedureMode::Read,
)
.await
.map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"uni.graph.project edge query: {e}"
))
})?;
let projection = uni_algo::algo::projection::GraphProjection::from_rows(
&node_rows,
&edge_rows,
weight_column.as_deref(),
include_reverse,
)
.map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"uni.graph.project (Cypher): {e}"
))
})?;
(projection, ProjectionSourceKind::Cypher)
}
ProjectionInput::Named { .. } => unreachable!("filtered above"),
};
let node_count = projection.vertex_count();
let edge_count = projection.edge_count();
let bytes = estimate_bytes(&projection);
let entry = ProjectionEntry {
projection: Arc::new(projection),
node_count,
edge_count,
bytes,
created_at: SystemTime::now(),
source_kind,
};
store_for_async
.insert(name_for_async.clone(), entry)
.map_err(|n| {
datafusion::error::DataFusionError::Execution(format!(
"uni.graph.project: projection `{n}` already exists"
))
})?;
let cols: Vec<ArrayRef> = vec![
Arc::new(StringArray::from(vec![name_for_async])),
Arc::new(Int64Array::from(vec![node_count as i64])),
Arc::new(Int64Array::from(vec![edge_count as i64])),
Arc::new(Int64Array::from(vec![bytes as i64])),
];
RecordBatch::try_new(schema_in_fut, cols).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("RecordBatch: {e}"))
})
};
let stream = futures::stream::once(fut);
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}
}
#[derive(Debug)]
pub struct DropProcedure;
impl DropProcedure {
fn signature_static() -> &'static ProcedureSignature {
static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
SIG.get_or_init(|| ProcedureSignature {
args: vec![NamedArgType {
name: smol_str::SmolStr::new("name"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Projection name to evict from the store.".to_owned(),
}],
yields: vec![Field::new("dropped", DataType::Boolean, false)],
mode: ProcedureMode::Read,
side_effects: SideEffects::ReadOnly,
retry_contract: None,
batch_input: None,
docs: "uni.graph.drop(name) — remove a named projection. Returns \
`false` if no projection by that name existed."
.to_owned(),
})
}
}
impl ProcedurePlugin for DropProcedure {
fn signature(&self) -> &ProcedureSignature {
Self::signature_static()
}
fn invoke(
&self,
ctx: ProcedureContext<'_>,
args: &[ColumnarValue],
) -> Result<SendableRecordBatchStream, FnError> {
let host = require_host(&ctx)?;
let name = args
.first()
.and_then(arg_as_string)
.ok_or_else(|| FnError::new(0x824, "uni.graph.drop: name (String) required"))?;
let dropped = for_storage(host.storage()).drop_by_name(&name);
let schema: SchemaRef = Arc::new(Schema::new(Self::signature_static().yields.clone()));
let cols: Vec<ArrayRef> = vec![Arc::new(BooleanArray::from(vec![dropped]))];
one_row_stream(schema, cols)
}
}
#[derive(Debug)]
pub struct ListProcedure;
impl ListProcedure {
fn signature_static() -> &'static ProcedureSignature {
static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
SIG.get_or_init(|| ProcedureSignature {
args: vec![],
yields: vec![
Field::new("name", DataType::Utf8, false),
Field::new("node_count", DataType::Int64, false),
Field::new("edge_count", DataType::Int64, false),
Field::new("bytes", DataType::Int64, false),
Field::new("created_at_ms", DataType::Int64, false),
Field::new("source_kind", DataType::Utf8, false),
],
mode: ProcedureMode::Read,
side_effects: SideEffects::ReadOnly,
retry_contract: None,
batch_input: None,
docs: "uni.graph.list — one row per stored projection. \
`source_kind` is `Native` or `Cypher`."
.to_owned(),
})
}
}
impl ProcedurePlugin for ListProcedure {
fn signature(&self) -> &ProcedureSignature {
Self::signature_static()
}
fn invoke(
&self,
ctx: ProcedureContext<'_>,
_args: &[ColumnarValue],
) -> Result<SendableRecordBatchStream, FnError> {
let host = require_host(&ctx)?;
let entries = for_storage(host.storage()).list();
let mut names = Vec::with_capacity(entries.len());
let mut nodes = Vec::with_capacity(entries.len());
let mut edges = Vec::with_capacity(entries.len());
let mut bytes = Vec::with_capacity(entries.len());
let mut created = Vec::with_capacity(entries.len());
let mut kinds = Vec::with_capacity(entries.len());
for (name, e) in entries {
names.push(name);
nodes.push(e.node_count as i64);
edges.push(e.edge_count as i64);
bytes.push(e.bytes as i64);
let ms = e
.created_at
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as i64)
.unwrap_or(0);
created.push(ms);
kinds.push(e.source_kind.as_str().to_owned());
}
let schema: SchemaRef = Arc::new(Schema::new(Self::signature_static().yields.clone()));
let cols: Vec<ArrayRef> = vec![
Arc::new(StringArray::from(names)),
Arc::new(Int64Array::from(nodes)),
Arc::new(Int64Array::from(edges)),
Arc::new(Int64Array::from(bytes)),
Arc::new(Int64Array::from(created)),
Arc::new(StringArray::from(kinds)),
];
let batch = RecordBatch::try_new(Arc::clone(&schema), cols)
.map_err(|e| FnError::new(0x830, format!("RecordBatch build: {e}")))?;
let stream =
futures::stream::once(
async move { Ok::<_, datafusion::error::DataFusionError>(batch) },
);
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}
}
#[derive(Debug)]
pub struct ExistsProcedure;
impl ExistsProcedure {
fn signature_static() -> &'static ProcedureSignature {
static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
SIG.get_or_init(|| ProcedureSignature {
args: vec![NamedArgType {
name: smol_str::SmolStr::new("name"),
ty: ArgType::Primitive(DataType::Utf8),
default: None,
doc: "Projection name to probe.".to_owned(),
}],
yields: vec![Field::new("exists", DataType::Boolean, false)],
mode: ProcedureMode::Read,
side_effects: SideEffects::ReadOnly,
retry_contract: None,
batch_input: None,
docs: "uni.graph.exists(name) — `true` iff a projection by that \
name is currently in the store."
.to_owned(),
})
}
}
impl ProcedurePlugin for ExistsProcedure {
fn signature(&self) -> &ProcedureSignature {
Self::signature_static()
}
fn invoke(
&self,
ctx: ProcedureContext<'_>,
args: &[ColumnarValue],
) -> Result<SendableRecordBatchStream, FnError> {
let host = require_host(&ctx)?;
let name = args
.first()
.and_then(arg_as_string)
.ok_or_else(|| FnError::new(0x824, "uni.graph.exists: name (String) required"))?;
let exists = for_storage(host.storage()).contains(&name);
let schema: SchemaRef = Arc::new(Schema::new(Self::signature_static().yields.clone()));
let cols: Vec<ArrayRef> = vec![Arc::new(BooleanArray::from(vec![exists]))];
one_row_stream(schema, cols)
}
}
fn build_l0_manager(host: &QueryProcedureHost) -> Option<Arc<uni_store::runtime::L0Manager>> {
use uni_store::runtime::L0Manager;
let l0_ctx = host.l0_context();
l0_ctx.current_l0.as_ref().map(|current| {
let mut pending = l0_ctx.pending_flush_l0s.clone();
if let Some(tx_l0) = &l0_ctx.transaction_l0 {
pending.push(tx_l0.clone());
}
Arc::new(L0Manager::from_snapshot(current.clone(), pending))
})
}