use crate::backend::StorageBackend;
use crate::backend::table_names;
use crate::backend::types::{ScalarIndexType, ScanRequest, WriteMode};
use crate::storage::arrow_convert::build_timestamp_column_from_eid_map;
use anyhow::{Result, anyhow};
use arrow_array::builder::{LargeBinaryBuilder, StringBuilder};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
use std::collections::HashMap;
use std::sync::Arc;
use uni_common::Properties;
use uni_common::core::id::{Eid, Vid};
/// Namespace type for the main edge table: schema definition, record-batch
/// construction, and lookup helpers that run against a `StorageBackend`.
///
/// Every associated function visible in this file other than `new` takes no
/// `self`; the struct only carries the base URI it was built with.
#[derive(Debug)]
pub struct MainEdgeDataset {
/// Base URI handed to `new`. It is stored but not read by any method
/// visible in this file (hence the leading underscore).
_base_uri: String,
}
impl MainEdgeDataset {
/// Creates a dataset handle that keeps an owned copy of `base_uri`.
pub fn new(base_uri: &str) -> Self {
    let _base_uri = String::from(base_uri);
    Self { _base_uri }
}
/// Returns the Arrow schema of the main edge table.
///
/// Columns, in order:
/// `_eid`, `src_vid`, `dst_vid` (`UInt64`, non-null), `type` (`Utf8`, non-null),
/// `props_json` (`LargeBinary`, nullable), `_deleted` (`Boolean`, non-null),
/// `_version` (`UInt64`, non-null), and the nullable `_created_at` /
/// `_updated_at` columns, both nanosecond timestamps tagged `"UTC"`.
pub fn get_arrow_schema() -> Arc<ArrowSchema> {
    // Both audit columns share the same timestamp type.
    let utc_nanos = || DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()));

    let fields = vec![
        Field::new("_eid", DataType::UInt64, false),
        Field::new("src_vid", DataType::UInt64, false),
        Field::new("dst_vid", DataType::UInt64, false),
        Field::new("type", DataType::Utf8, false),
        Field::new("props_json", DataType::LargeBinary, true),
        Field::new("_deleted", DataType::Boolean, false),
        Field::new("_version", DataType::UInt64, false),
        Field::new("_created_at", utc_nanos(), true),
        Field::new("_updated_at", utc_nanos(), true),
    ];

    Arc::new(ArrowSchema::new(fields))
}
/// Returns the literal table name `"edges"`.
///
/// NOTE(review): the read/write helpers in this file resolve their table via
/// `table_names::main_edge_table_name()` rather than this function — confirm
/// that both yield the same name.
pub fn table_name() -> &'static str {
    const MAIN_EDGE_TABLE: &str = "edges";
    MAIN_EDGE_TABLE
}
/// Builds a `RecordBatch` matching [`Self::get_arrow_schema`] from edge tuples.
///
/// Each tuple is `(eid, src_vid, dst_vid, type, props, deleted, version)` and
/// becomes one row. `props` is converted to a `serde_json::Value`, then to a
/// `uni_common::Value`, and stored in `props_json` as the bytes produced by
/// `uni_common::cypher_value_codec::encode`.
///
/// `created_at` / `updated_at` are forwarded to
/// `build_timestamp_column_from_eid_map` together with the edge ids to produce
/// the `_created_at` / `_updated_at` columns (presumably null where the map is
/// `None` or lacks the edge — confirm against that helper).
///
/// # Errors
///
/// Returns an error if an edge's properties cannot be serialized, or if Arrow
/// rejects the assembled columns. Serialization failures used to be replaced
/// by an empty object, which silently dropped the edge's properties.
pub fn build_record_batch(
    edges: &[(Eid, Vid, Vid, String, Properties, bool, u64)],
    created_at: Option<&HashMap<Eid, i64>>,
    updated_at: Option<&HashMap<Eid, i64>>,
) -> Result<RecordBatch> {
    let arrow_schema = Self::get_arrow_schema();
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(arrow_schema.fields().len());

    // Identifier columns.
    let eids: Vec<u64> = edges.iter().map(|(eid, ..)| eid.as_u64()).collect();
    columns.push(Arc::new(UInt64Array::from(eids)));

    let src_vids: Vec<u64> = edges.iter().map(|(_, src, ..)| src.as_u64()).collect();
    columns.push(Arc::new(UInt64Array::from(src_vids)));

    let dst_vids: Vec<u64> = edges.iter().map(|(_, _, dst, ..)| dst.as_u64()).collect();
    columns.push(Arc::new(UInt64Array::from(dst_vids)));

    // Edge type names.
    let mut type_builder = StringBuilder::new();
    for (_, _, _, edge_type, ..) in edges {
        type_builder.append_value(edge_type);
    }
    columns.push(Arc::new(type_builder.finish()));

    // Encoded properties; a failed serialization aborts the batch instead of
    // persisting an empty property map.
    let mut props_builder = LargeBinaryBuilder::new();
    for (eid, _, _, _, props, ..) in edges {
        let json_val = serde_json::to_value(props).map_err(|e| {
            anyhow!(
                "Failed to serialize properties for edge {}: {}",
                eid.as_u64(),
                e
            )
        })?;
        let uni_val: uni_common::Value = json_val.into();
        let encoded = uni_common::cypher_value_codec::encode(&uni_val);
        props_builder.append_value(&encoded);
    }
    columns.push(Arc::new(props_builder.finish()));

    // Tombstone flag and version.
    let deleted: Vec<bool> = edges.iter().map(|(.., deleted, _)| *deleted).collect();
    columns.push(Arc::new(BooleanArray::from(deleted)));

    let versions: Vec<u64> = edges.iter().map(|(.., version)| *version).collect();
    columns.push(Arc::new(UInt64Array::from(versions)));

    // Timestamp columns are looked up per edge id by the shared helper.
    let eid_iter = edges.iter().map(|(eid, ..)| *eid);
    columns.push(build_timestamp_column_from_eid_map(
        eid_iter.clone(),
        created_at,
    ));
    columns.push(build_timestamp_column_from_eid_map(eid_iter, updated_at));

    RecordBatch::try_new(arrow_schema, columns).map_err(|e| anyhow!(e))
}
/// Writes `batch` to the main edge table.
///
/// If the table does not exist yet it is created from the batch; otherwise
/// the batch is appended with `WriteMode::Append`.
///
/// # Errors
///
/// Propagates any error from the backend's existence check, create, or write.
pub async fn write_batch(backend: &dyn StorageBackend, batch: RecordBatch) -> Result<()> {
    let table = table_names::main_edge_table_name();
    let batches = vec![batch];

    if !backend.table_exists(table).await? {
        return backend.create_table(table, batches).await;
    }
    backend.write(table, batches, WriteMode::Append).await
}
/// Requests BTree scalar indexes on `_eid`, `src_vid`, `dst_vid` and `type`.
///
/// Index creation is best-effort: the result of each request is discarded and
/// this function always returns `Ok(())`.
pub async fn ensure_default_indexes(backend: &dyn StorageBackend) -> Result<()> {
    let table = table_names::main_edge_table_name();

    for column in ["_eid", "src_vid", "dst_vid", "type"] {
        // Failures are intentionally ignored, one column at a time.
        let _ = backend
            .create_scalar_index(table, column, ScalarIndexType::BTree)
            .await;
    }

    Ok(())
}
/// Looks up an edge by id and returns `(src_vid, dst_vid, type, properties)`.
///
/// The scan filters on `_eid` only, so rows flagged `_deleted` are not
/// excluded. The first row of the first non-empty batch whose `src_vid`,
/// `dst_vid`, `type` and `props_json` columns have the expected Arrow types is
/// used; batches that are empty or lack those columns are skipped.
///
/// Returns `Ok(None)` when no usable row is found (including when the table
/// does not exist, since `execute_query` then yields no batches).
///
/// # Errors
///
/// Propagates scan errors, and — consistently with the other readers in this
/// file — reports a property payload that cannot be decoded instead of
/// silently returning empty properties.
pub async fn find_by_eid(
    backend: &dyn StorageBackend,
    eid: Eid,
) -> Result<Option<(Vid, Vid, String, Properties)>> {
    let filter = format!("_eid = {}", eid.as_u64());
    let batches = Self::execute_query(backend, &filter, None).await?;

    for batch in &batches {
        if batch.num_rows() == 0 {
            continue;
        }

        let src_arr = batch
            .column_by_name("src_vid")
            .and_then(|c| c.as_any().downcast_ref::<UInt64Array>());
        let dst_arr = batch
            .column_by_name("dst_vid")
            .and_then(|c| c.as_any().downcast_ref::<UInt64Array>());
        let type_arr = batch
            .column_by_name("type")
            .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
        let props_arr = batch
            .column_by_name("props_json")
            .and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>());

        // A batch missing any column (or with an unexpected type) is skipped.
        let (Some(src_arr), Some(dst_arr), Some(type_arr), Some(props_arr)) =
            (src_arr, dst_arr, type_arr, props_arr)
        else {
            continue;
        };

        // Null or empty payloads yield empty properties; undecodable ones error.
        let properties = Self::parse_props_json(props_arr, 0)?;

        return Ok(Some((
            Vid::from(src_arr.value(0)),
            Vid::from(dst_arr.value(0)),
            type_arr.value(0).to_string(),
            properties,
        )));
    }

    Ok(None)
}
/// Reports whether any row with the given `_eid` is stored.
///
/// Only `_eid` is part of the filter, so rows flagged `_deleted` count too.
pub async fn exists_by_eid(backend: &dyn StorageBackend, eid: Eid) -> Result<bool> {
    let filter = format!("_eid = {}", eid.as_u64());
    let projection = Some(vec!["_eid"]);
    let batches = Self::execute_query(backend, &filter, projection).await?;

    // `any` is already false for an empty batch list.
    let found = batches.iter().any(|batch| batch.num_rows() != 0);
    Ok(found)
}
/// Scans the main edge table with `filter`, optionally projecting `columns`.
///
/// Returns an empty vector without scanning when the table does not exist.
/// With `columns == None` no projection is applied to the request.
///
/// # Errors
///
/// Propagates errors from the existence check and from the scan itself.
async fn execute_query(
    backend: &dyn StorageBackend,
    filter: &str,
    columns: Option<Vec<&str>>,
) -> Result<Vec<RecordBatch>> {
    let table = table_names::main_edge_table_name();

    let table_present = backend.table_exists(table).await?;
    if !table_present {
        return Ok(Vec::new());
    }

    let base = ScanRequest::all(table).with_filter(filter);
    let request = match columns {
        Some(cols) => base.with_columns(cols.into_iter().map(String::from).collect()),
        None => base,
    };

    backend.scan(request).await
}
/// Collects every non-null `_eid` value from `batches`, in batch and row order.
///
/// Batches without a `UInt64` `_eid` column contribute nothing.
fn extract_eids(batches: &[RecordBatch]) -> Vec<Eid> {
    batches
        .iter()
        .filter_map(|batch| {
            batch
                .column_by_name("_eid")
                .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
        })
        .flat_map(|ids| {
            (0..ids.len())
                .filter(move |&row| !ids.is_null(row))
                .map(move |row| Eid::new(ids.value(row)))
        })
        .collect()
}
/// Returns the ids of all rows matching `_deleted = false`.
pub async fn find_all_eids(backend: &dyn StorageBackend) -> Result<Vec<Eid>> {
    let live_only = "_deleted = false";
    let id_batches = Self::execute_query(backend, live_only, Some(vec!["_eid"])).await?;
    let eids = Self::extract_eids(&id_batches);
    Ok(eids)
}
/// Returns the ids of non-deleted rows whose `type` equals `type_name`.
///
/// Single quotes in `type_name` are doubled before it is embedded in the
/// filter string.
pub async fn find_eids_by_type_name(
    backend: &dyn StorageBackend,
    type_name: &str,
) -> Result<Vec<Eid>> {
    let quoted = type_name.replace('\'', "''");
    let filter = format!("_deleted = false AND type = '{}'", quoted);

    let id_batches = Self::execute_query(backend, &filter, Some(vec!["_eid"])).await?;
    Ok(Self::extract_eids(&id_batches))
}
/// Returns the properties of the highest-`_version` non-deleted row for `eid`.
///
/// A null `_version` counts as 0. When several rows share the highest version
/// the one encountered last wins. Batches lacking a `LargeBinary` `props_json`
/// or `UInt64` `_version` column are ignored. Returns `Ok(None)` when no row
/// qualifies.
///
/// Only the winning row is decoded, so the payload is parsed once rather than
/// once per candidate, and an undecodable superseded row no longer fails the
/// lookup.
///
/// # Errors
///
/// Propagates scan errors and decode errors for the selected row.
pub async fn find_props_by_eid(
    backend: &dyn StorageBackend,
    eid: Eid,
) -> Result<Option<Properties>> {
    let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
    let batches =
        Self::execute_query(backend, &filter, Some(vec!["props_json", "_version"])).await?;

    // Winning row so far: the array holding it and its index.
    let mut best: Option<(&arrow_array::LargeBinaryArray, usize)> = None;
    let mut best_version: u64 = 0;

    for batch in &batches {
        let props_arr = batch
            .column_by_name("props_json")
            .and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>());
        let ver_arr = batch
            .column_by_name("_version")
            .and_then(|c| c.as_any().downcast_ref::<UInt64Array>());
        let (Some(props_arr), Some(ver_arr)) = (props_arr, ver_arr) else {
            continue;
        };

        for row in 0..batch.num_rows() {
            let version = if ver_arr.is_null(row) {
                0
            } else {
                ver_arr.value(row)
            };
            // `>=` keeps the tie-break: later rows replace earlier equals.
            if version >= best_version {
                best_version = version;
                best = Some((props_arr, row));
            }
        }
    }

    best.map(|(arr, row)| Self::parse_props_json(arr, row))
        .transpose()
}
/// Decodes the property payload stored at `idx` of a `props_json` column.
///
/// A null slot or a zero-length payload yields empty `Properties`. Otherwise
/// the bytes go through `uni_common::cypher_value_codec::decode`, the result
/// is converted to a `serde_json::Value`, and that is deserialized.
///
/// # Errors
///
/// Fails if the bytes cannot be decoded or the decoded value does not
/// deserialize into `Properties`.
fn parse_props_json(arr: &arrow_array::LargeBinaryArray, idx: usize) -> Result<Properties> {
    // Treat a null slot like an empty payload without touching its value.
    let bytes: &[u8] = if arr.is_null(idx) { &[] } else { arr.value(idx) };
    if bytes.is_empty() {
        return Ok(Properties::new());
    }

    match uni_common::cypher_value_codec::decode(bytes) {
        Ok(uni_val) => {
            let json_val: serde_json::Value = uni_val.into();
            serde_json::from_value(json_val)
                .map_err(|e| anyhow!("Failed to parse props_json: {}", e))
        }
        Err(e) => Err(anyhow!("Failed to decode CypherValue: {}", e)),
    }
}
/// Returns the `type` of a non-deleted row with the given `_eid`.
///
/// Only the first row of each batch is inspected; the first batch whose first
/// row has a non-null `type` supplies the answer. Returns `Ok(None)` if no
/// batch does.
pub async fn find_type_by_eid(
    backend: &dyn StorageBackend,
    eid: Eid,
) -> Result<Option<String>> {
    let filter = format!("_eid = {} AND _deleted = false", eid.as_u64());
    let batches = Self::execute_query(backend, &filter, Some(vec!["type"])).await?;

    for batch in &batches {
        if batch.num_rows() == 0 {
            continue;
        }
        let Some(types) = batch
            .column_by_name("type")
            .and_then(|col| col.as_any().downcast_ref::<arrow_array::StringArray>())
        else {
            continue;
        };
        if types.is_null(0) {
            continue;
        }
        return Ok(Some(types.value(0).to_string()));
    }

    Ok(None)
}
/// Returns `(eid, src_vid, dst_vid, properties)` for every non-deleted row
/// whose `type` equals `type_name`.
///
/// Single quotes in `type_name` are doubled before it is embedded in the
/// filter string.
///
/// # Errors
///
/// Propagates scan errors and property decode errors.
pub async fn find_edges_by_type_name(
    backend: &dyn StorageBackend,
    type_name: &str,
) -> Result<Vec<(Eid, Vid, Vid, Properties)>> {
    let quoted = type_name.replace('\'', "''");
    let filter = format!("_deleted = false AND type = '{}'", quoted);
    let batches = Self::execute_query(backend, &filter, None).await?;

    let mut edges = Vec::new();
    batches
        .iter()
        .try_for_each(|batch| Self::extract_edges_from_batch(batch, &mut edges))?;
    Ok(edges)
}
/// Returns `(eid, src_vid, dst_vid, type, properties)` for every non-deleted
/// row whose `type` is one of `type_names`.
///
/// An empty `type_names` returns an empty vector without querying. Each name
/// is single-quoted, with embedded single quotes doubled, and the names are
/// combined into one `IN (...)` filter.
///
/// # Errors
///
/// Propagates scan errors and property decode errors.
pub async fn find_edges_by_type_names(
    backend: &dyn StorageBackend,
    type_names: &[&str],
) -> Result<Vec<(Eid, Vid, Vid, String, Properties)>> {
    if type_names.is_empty() {
        return Ok(Vec::new());
    }

    let in_list = type_names
        .iter()
        .map(|name| format!("'{}'", name.replace('\'', "''")))
        .collect::<Vec<String>>()
        .join(", ");
    let filter = format!("_deleted = false AND type IN ({})", in_list);

    let batches = Self::execute_query(backend, &filter, None).await?;

    let mut edges = Vec::new();
    batches
        .iter()
        .try_for_each(|batch| Self::extract_edges_with_type_from_batch(batch, &mut edges))?;
    Ok(edges)
}
/// Appends `(eid, src_vid, dst_vid, properties)` for each usable row of
/// `batch` to `edges`.
///
/// Delegates to [`Self::extract_edges_with_type_from_batch`] and drops the
/// type name. Nothing is appended if that call fails.
fn extract_edges_from_batch(
    batch: &RecordBatch,
    edges: &mut Vec<(Eid, Vid, Vid, Properties)>,
) -> Result<()> {
    let mut typed = Vec::new();
    Self::extract_edges_with_type_from_batch(batch, &mut typed)?;

    for (eid, src, dst, _edge_type, props) in typed {
        edges.push((eid, src, dst, props));
    }
    Ok(())
}
fn extract_edges_with_type_from_batch(
batch: &RecordBatch,
edges: &mut Vec<(Eid, Vid, Vid, String, Properties)>,
) -> Result<()> {
let Some(eid_arr) = batch
.column_by_name("_eid")
.and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
else {
return Ok(());
};
let Some(src_arr) = batch
.column_by_name("src_vid")
.and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
else {
return Ok(());
};
let Some(dst_arr) = batch
.column_by_name("dst_vid")
.and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
else {
return Ok(());
};
let type_arr = batch
.column_by_name("type")
.and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
let props_arr = batch
.column_by_name("props_json")
.and_then(|c| c.as_any().downcast_ref::<arrow_array::LargeBinaryArray>());
for i in 0..batch.num_rows() {
if eid_arr.is_null(i) || src_arr.is_null(i) || dst_arr.is_null(i) {
continue;
}
let eid = Eid::new(eid_arr.value(i));
let src_vid = Vid::new(src_arr.value(i));
let dst_vid = Vid::new(dst_arr.value(i));
let edge_type = type_arr
.filter(|arr| !arr.is_null(i))
.map(|arr| arr.value(i).to_string())
.unwrap_or_default();
let props = props_arr
.map(|arr| Self::parse_props_json(arr, i))
.transpose()?
.unwrap_or_default();
edges.push((eid, src_vid, dst_vid, edge_type, props));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
// The schema exposes exactly nine columns, each addressable by name.
#[test]
fn test_main_edge_schema() {
    let schema = MainEdgeDataset::get_arrow_schema();
    assert_eq!(schema.fields().len(), 9);

    let expected_columns = [
        "_eid",
        "src_vid",
        "dst_vid",
        "type",
        "props_json",
        "_deleted",
        "_version",
        "_created_at",
        "_updated_at",
    ];
    for column in expected_columns {
        assert!(schema.field_with_name(column).is_ok());
    }
}
// A single edge with one property produces a 1 x 9 batch.
#[test]
fn test_build_record_batch() {
    use uni_common::Value;

    let props = HashMap::from([("weight".to_string(), Value::Float(0.5))]);
    let edge = (
        Eid::new(1),
        Vid::new(1),
        Vid::new(2),
        "KNOWS".to_string(),
        props,
        false,
        1u64,
    );
    let edges = vec![edge];

    let batch = MainEdgeDataset::build_record_batch(&edges, None, None).unwrap();
    assert_eq!(batch.num_rows(), 1);
    assert_eq!(batch.num_columns(), 9);
}
// Three edges (one of them flagged deleted) keep their order and type names.
#[test]
fn test_build_record_batch_multiple_edges() {
    use uni_common::Value;

    let knows_with_props = (
        Eid::new(1),
        Vid::new(1),
        Vid::new(2),
        "KNOWS".to_string(),
        HashMap::from([("since".to_string(), Value::Int(2020))]),
        false,
        1u64,
    );
    let works_at = (
        Eid::new(2),
        Vid::new(2),
        Vid::new(3),
        "WORKS_AT".to_string(),
        HashMap::new(),
        false,
        2u64,
    );
    let knows_deleted = (
        Eid::new(3),
        Vid::new(1),
        Vid::new(3),
        "KNOWS".to_string(),
        HashMap::new(),
        true,
        3u64,
    );
    let edges = vec![knows_with_props, works_at, knows_deleted];

    let batch = MainEdgeDataset::build_record_batch(&edges, None, None).unwrap();
    assert_eq!(batch.num_rows(), 3);
    assert_eq!(batch.num_columns(), 9);

    let type_col = batch
        .column_by_name("type")
        .unwrap()
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .unwrap();
    for (row, expected) in ["KNOWS", "WORKS_AT", "KNOWS"].into_iter().enumerate() {
        assert_eq!(type_col.value(row), expected);
    }
}
// Supplying timestamp maps populates `_created_at` for the matching edge.
#[test]
fn test_build_record_batch_with_timestamps() {
    let edge = (
        Eid::new(1),
        Vid::new(1),
        Vid::new(2),
        "KNOWS".to_string(),
        HashMap::new(),
        false,
        1u64,
    );
    let edges = vec![edge];

    let created_at: HashMap<Eid, i64> = HashMap::from([(Eid::new(1), 1_000_000_000)]);
    let updated_at: HashMap<Eid, i64> = HashMap::from([(Eid::new(1), 2_000_000_000)]);

    let batch =
        MainEdgeDataset::build_record_batch(&edges, Some(&created_at), Some(&updated_at))
            .unwrap();
    assert_eq!(batch.num_rows(), 1);

    let created_col = batch.column_by_name("_created_at").unwrap();
    assert!(!created_col.is_null(0), "created_at should be populated");
}
}