use std::sync::Arc;
use manifoldb_core::{EntityId, Value};
use manifoldb_graph::analytics::{
BetweennessCentrality, BetweennessCentralityConfig, CommunityDetection,
CommunityDetectionConfig, PageRank, PageRankConfig,
};
use manifoldb_graph::traversal::Direction;
use manifoldb_storage::Transaction;
use crate::error::ParseError;
use crate::exec::context::ExecutionContext;
use crate::exec::operator::{BoxedOperator, Operator, OperatorBase, OperatorResult, OperatorState};
use crate::exec::row::{Row, Schema};
/// Tunable parameters for [`PageRankOp`].
///
/// Each field is forwarded to the `PageRankConfig` builder method of the same
/// name when the operator runs its computation.
#[derive(Debug, Clone)]
pub struct PageRankOpConfig {
/// Damping factor handed to `PageRankConfig::with_damping_factor`.
pub damping_factor: f64,
/// Iteration limit handed to `PageRankConfig::with_max_iterations`.
pub max_iterations: usize,
/// Tolerance handed to `PageRankConfig::with_tolerance`.
pub tolerance: f64,
/// Normalization flag handed to `PageRankConfig::with_normalize`.
pub normalize: bool,
}
impl Default for PageRankOpConfig {
fn default() -> Self {
let pr_config = PageRankConfig::default();
Self {
damping_factor: pr_config.damping_factor,
max_iterations: pr_config.max_iterations,
tolerance: pr_config.tolerance,
normalize: pr_config.normalize,
}
}
}
impl PageRankOpConfig {
    /// Creates a configuration populated with the values from [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `damping_factor` replaced by `d`.
    ///
    /// The builder consumes `self`; the returned value must be used, otherwise
    /// the change is lost.
    #[must_use]
    pub const fn with_damping_factor(mut self, d: f64) -> Self {
        self.damping_factor = d;
        self
    }

    /// Returns the configuration with `max_iterations` replaced by `n`.
    #[must_use]
    pub const fn with_max_iterations(mut self, n: usize) -> Self {
        self.max_iterations = n;
        self
    }

    /// Returns the configuration with `tolerance` replaced by `t`.
    #[must_use]
    pub const fn with_tolerance(mut self, t: f64) -> Self {
        self.tolerance = t;
        self
    }

    /// Returns the configuration with `normalize` replaced by `n`.
    #[must_use]
    pub const fn with_normalize(mut self, n: bool) -> Self {
        self.normalize = n;
        self
    }

    /// Translates this operator-level configuration into the `PageRankConfig`
    /// consumed by the analytics routines, forwarding every field unchanged.
    fn to_pagerank_config(&self) -> PageRankConfig {
        PageRankConfig::new()
            .with_damping_factor(self.damping_factor)
            .with_max_iterations(self.max_iterations)
            .with_tolerance(self.tolerance)
            .with_normalize(self.normalize)
    }
}
/// Query operator that emits one `(node, score)` row per PageRank result.
///
/// The computation is deferred until the first call to `next`, which needs a
/// transaction to have been supplied through `with_tx`.
pub struct PageRankOp<T> {
// Shared operator bookkeeping: output schema, lifecycle state and row counter.
base: OperatorBase,
// Parameters converted into a `PageRankConfig` when the computation runs.
config: PageRankOpConfig,
// Transaction used to read the graph; `next` fails while this is `None`.
tx: Option<T>,
// Computed (node, score) pairs awaiting emission; `None` until the first `next`.
results: Option<std::vec::IntoIter<(EntityId, f64)>>,
// Optional upstream operator whose rows supply the node ids to rank.
input: Option<BoxedOperator>,
// Index of the input column holding node ids; `next` falls back to 0 when unset.
input_node_column: Option<usize>,
}
impl<T> PageRankOp<T> {
    /// Creates an operator with no upstream input; `next` then calls
    /// `PageRank::compute` rather than `PageRank::compute_for_nodes`.
    #[must_use]
    pub fn new(config: PageRankOpConfig) -> Self {
        Self::build(config, None, None)
    }

    /// Creates an operator that restricts the computation to the node ids read
    /// from column `node_column` of the rows produced by `input`.
    #[must_use]
    pub fn with_input(config: PageRankOpConfig, input: BoxedOperator, node_column: usize) -> Self {
        Self::build(config, Some(input), Some(node_column))
    }

    /// Attaches the transaction the operator reads the graph through.
    ///
    /// Without a transaction `next` returns an `InvalidGraphOp` error.
    #[must_use]
    pub fn with_tx(mut self, tx: T) -> Self {
        self.tx = Some(tx);
        self
    }

    /// Shared constructor: both public constructors produce the same
    /// `node`/`score` schema and differ only in the optional input wiring.
    fn build(
        config: PageRankOpConfig,
        input: Option<BoxedOperator>,
        input_node_column: Option<usize>,
    ) -> Self {
        let schema = Arc::new(Schema::new(vec!["node".to_string(), "score".to_string()]));
        Self {
            base: OperatorBase::new(schema),
            config,
            tx: None,
            results: None,
            input,
            input_node_column,
        }
    }
}
impl<T> Operator for PageRankOp<T>
where
T: Transaction + Send,
{
fn open(&mut self, ctx: &ExecutionContext) -> OperatorResult<()> {
if let Some(ref mut input) = self.input {
input.open(ctx)?;
}
self.base.set_open();
Ok(())
}
fn next(&mut self) -> OperatorResult<Option<Row>> {
if self.results.is_none() {
let tx = self.tx.as_ref().ok_or_else(|| {
ParseError::InvalidGraphOp("PageRank requires transaction access".to_string())
})?;
let pr_config = self.config.to_pagerank_config();
let result = if let Some(ref mut input) = self.input {
let column = self.input_node_column.unwrap_or(0);
let mut nodes = Vec::new();
while let Some(row) = input.next()? {
if let Some(Value::Int(id)) = row.get(column) {
nodes.push(EntityId::new(*id as u64));
}
}
PageRank::compute_for_nodes(tx, &nodes, &pr_config)
.map_err(|e| ParseError::InvalidGraphOp(format!("PageRank error: {e}")))?
} else {
PageRank::compute(tx, &pr_config)
.map_err(|e| ParseError::InvalidGraphOp(format!("PageRank error: {e}")))?
};
let sorted = result.sorted();
self.results = Some(sorted.into_iter());
}
if let Some(ref mut iter) = self.results {
if let Some((node, score)) = iter.next() {
let row = Row::new(
self.base.schema(),
vec![Value::Int(node.as_u64() as i64), Value::Float(score)],
);
self.base.inc_rows_produced();
return Ok(Some(row));
}
}
self.base.set_finished();
Ok(None)
}
fn close(&mut self) -> OperatorResult<()> {
if let Some(ref mut input) = self.input {
input.close()?;
}
self.results = None;
self.base.set_closed();
Ok(())
}
fn schema(&self) -> Arc<Schema> {
self.base.schema()
}
fn state(&self) -> OperatorState {
self.base.state()
}
fn name(&self) -> &'static str {
"PageRankOp"
}
}
/// Tunable parameters for [`BetweennessCentralityOp`].
///
/// Both fields are forwarded to the `BetweennessCentralityConfig` builder
/// methods of the same name when the operator runs its computation.
#[derive(Debug, Clone)]
pub struct BetweennessCentralityOpConfig {
/// Normalization flag handed to `BetweennessCentralityConfig::with_normalize`.
pub normalize: bool,
/// Traversal direction handed to `BetweennessCentralityConfig::with_direction`.
pub direction: Direction,
}
impl Default for BetweennessCentralityOpConfig {
fn default() -> Self {
let bc_config = BetweennessCentralityConfig::default();
Self { normalize: bc_config.normalize, direction: bc_config.direction }
}
}
impl BetweennessCentralityOpConfig {
    /// Creates a configuration populated with the values from [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `normalize` replaced by `n`.
    ///
    /// The builder consumes `self`; the returned value must be used, otherwise
    /// the change is lost.
    #[must_use]
    pub const fn with_normalize(mut self, n: bool) -> Self {
        self.normalize = n;
        self
    }

    /// Returns the configuration with `direction` replaced by `d`.
    #[must_use]
    pub const fn with_direction(mut self, d: Direction) -> Self {
        self.direction = d;
        self
    }

    /// Translates this operator-level configuration into the
    /// `BetweennessCentralityConfig` consumed by the analytics routines,
    /// forwarding both fields unchanged.
    fn to_centrality_config(&self) -> BetweennessCentralityConfig {
        BetweennessCentralityConfig::new()
            .with_normalize(self.normalize)
            .with_direction(self.direction)
    }
}
/// Query operator that emits one `(node, score)` row per betweenness-centrality result.
///
/// The computation is deferred until the first call to `next`, which needs a
/// transaction to have been supplied through `with_tx`.
pub struct BetweennessCentralityOp<T> {
// Shared operator bookkeeping: output schema, lifecycle state and row counter.
base: OperatorBase,
// Parameters converted into a `BetweennessCentralityConfig` when the computation runs.
config: BetweennessCentralityOpConfig,
// Transaction used to read the graph; `next` fails while this is `None`.
tx: Option<T>,
// Computed (node, score) pairs awaiting emission; `None` until the first `next`.
results: Option<std::vec::IntoIter<(EntityId, f64)>>,
// Optional upstream operator whose rows supply the node ids to analyse.
input: Option<BoxedOperator>,
// Index of the input column holding node ids; `next` falls back to 0 when unset.
input_node_column: Option<usize>,
}
impl<T> BetweennessCentralityOp<T> {
    /// Creates an operator with no upstream input; `next` then calls
    /// `BetweennessCentrality::compute` rather than `compute_for_nodes`.
    #[must_use]
    pub fn new(config: BetweennessCentralityOpConfig) -> Self {
        Self::build(config, None, None)
    }

    /// Creates an operator that restricts the computation to the node ids read
    /// from column `node_column` of the rows produced by `input`.
    #[must_use]
    pub fn with_input(
        config: BetweennessCentralityOpConfig,
        input: BoxedOperator,
        node_column: usize,
    ) -> Self {
        Self::build(config, Some(input), Some(node_column))
    }

    /// Attaches the transaction the operator reads the graph through.
    ///
    /// Without a transaction `next` returns an `InvalidGraphOp` error.
    #[must_use]
    pub fn with_tx(mut self, tx: T) -> Self {
        self.tx = Some(tx);
        self
    }

    /// Shared constructor: both public constructors produce the same
    /// `node`/`score` schema and differ only in the optional input wiring.
    fn build(
        config: BetweennessCentralityOpConfig,
        input: Option<BoxedOperator>,
        input_node_column: Option<usize>,
    ) -> Self {
        let schema = Arc::new(Schema::new(vec!["node".to_string(), "score".to_string()]));
        Self {
            base: OperatorBase::new(schema),
            config,
            tx: None,
            results: None,
            input,
            input_node_column,
        }
    }
}
impl<T> Operator for BetweennessCentralityOp<T>
where
T: Transaction + Send,
{
fn open(&mut self, ctx: &ExecutionContext) -> OperatorResult<()> {
if let Some(ref mut input) = self.input {
input.open(ctx)?;
}
self.base.set_open();
Ok(())
}
fn next(&mut self) -> OperatorResult<Option<Row>> {
if self.results.is_none() {
let tx = self.tx.as_ref().ok_or_else(|| {
ParseError::InvalidGraphOp(
"BetweennessCentrality requires transaction access".to_string(),
)
})?;
let bc_config = self.config.to_centrality_config();
let result = if let Some(ref mut input) = self.input {
let column = self.input_node_column.unwrap_or(0);
let mut nodes = Vec::new();
while let Some(row) = input.next()? {
if let Some(Value::Int(id)) = row.get(column) {
nodes.push(EntityId::new(*id as u64));
}
}
BetweennessCentrality::compute_for_nodes(tx, &nodes, &bc_config).map_err(|e| {
ParseError::InvalidGraphOp(format!("BetweennessCentrality error: {e}"))
})?
} else {
BetweennessCentrality::compute(tx, &bc_config).map_err(|e| {
ParseError::InvalidGraphOp(format!("BetweennessCentrality error: {e}"))
})?
};
let sorted = result.sorted();
self.results = Some(sorted.into_iter());
}
if let Some(ref mut iter) = self.results {
if let Some((node, score)) = iter.next() {
let row = Row::new(
self.base.schema(),
vec![Value::Int(node.as_u64() as i64), Value::Float(score)],
);
self.base.inc_rows_produced();
return Ok(Some(row));
}
}
self.base.set_finished();
Ok(None)
}
fn close(&mut self) -> OperatorResult<()> {
if let Some(ref mut input) = self.input {
input.close()?;
}
self.results = None;
self.base.set_closed();
Ok(())
}
fn schema(&self) -> Arc<Schema> {
self.base.schema()
}
fn state(&self) -> OperatorState {
self.base.state()
}
fn name(&self) -> &'static str {
"BetweennessCentralityOp"
}
}
/// Tunable parameters for [`CommunityDetectionOp`].
///
/// The fields are forwarded to the `CommunityDetectionConfig` builder when the
/// operator runs its computation.
#[derive(Debug, Clone)]
pub struct CommunityDetectionOpConfig {
/// Iteration limit handed to `CommunityDetectionConfig::with_max_iterations`.
pub max_iterations: usize,
/// Traversal direction handed to `CommunityDetectionConfig::with_direction`.
pub direction: Direction,
/// Optional seed; handed to `CommunityDetectionConfig::with_seed` only when `Some`.
pub seed: Option<u64>,
}
impl Default for CommunityDetectionOpConfig {
fn default() -> Self {
let cd_config = CommunityDetectionConfig::default();
Self {
max_iterations: cd_config.max_iterations,
direction: cd_config.direction,
seed: cd_config.seed,
}
}
}
impl CommunityDetectionOpConfig {
    /// Creates a configuration populated with the values from [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `max_iterations` replaced by `n`.
    ///
    /// The builder consumes `self`; the returned value must be used, otherwise
    /// the change is lost.
    #[must_use]
    pub const fn with_max_iterations(mut self, n: usize) -> Self {
        self.max_iterations = n;
        self
    }

    /// Returns the configuration with `direction` replaced by `d`.
    #[must_use]
    pub const fn with_direction(mut self, d: Direction) -> Self {
        self.direction = d;
        self
    }

    /// Returns the configuration with the seed set to `Some(s)`.
    #[must_use]
    pub const fn with_seed(mut self, s: u64) -> Self {
        self.seed = Some(s);
        self
    }

    /// Translates this operator-level configuration into the
    /// `CommunityDetectionConfig` consumed by the analytics routines.
    ///
    /// The seed is only forwarded when one was set, leaving the library's own
    /// seed handling untouched otherwise.
    fn to_community_config(&self) -> CommunityDetectionConfig {
        let config = CommunityDetectionConfig::new()
            .with_max_iterations(self.max_iterations)
            .with_direction(self.direction);
        match self.seed {
            Some(seed) => config.with_seed(seed),
            None => config,
        }
    }
}
/// Query operator that emits one `(node, community)` row per community assignment.
///
/// The computation is deferred until the first call to `next`, which needs a
/// transaction to have been supplied through `with_tx`.
pub struct CommunityDetectionOp<T> {
// Shared operator bookkeeping: output schema, lifecycle state and row counter.
base: OperatorBase,
// Parameters converted into a `CommunityDetectionConfig` when the computation runs.
config: CommunityDetectionOpConfig,
// Transaction used to read the graph; `next` fails while this is `None`.
tx: Option<T>,
// Computed (node, community) pairs awaiting emission; `None` until the first `next`.
results: Option<std::vec::IntoIter<(EntityId, u64)>>,
// Optional upstream operator whose rows supply the node ids to analyse.
input: Option<BoxedOperator>,
// Index of the input column holding node ids; `next` falls back to 0 when unset.
input_node_column: Option<usize>,
}
impl<T> CommunityDetectionOp<T> {
    /// Creates an operator with no upstream input; `next` then calls
    /// `CommunityDetection::label_propagation` rather than
    /// `label_propagation_for_nodes`.
    #[must_use]
    pub fn new(config: CommunityDetectionOpConfig) -> Self {
        Self::build(config, None, None)
    }

    /// Creates an operator that restricts the computation to the node ids read
    /// from column `node_column` of the rows produced by `input`.
    #[must_use]
    pub fn with_input(
        config: CommunityDetectionOpConfig,
        input: BoxedOperator,
        node_column: usize,
    ) -> Self {
        Self::build(config, Some(input), Some(node_column))
    }

    /// Attaches the transaction the operator reads the graph through.
    ///
    /// Without a transaction `next` returns an `InvalidGraphOp` error.
    #[must_use]
    pub fn with_tx(mut self, tx: T) -> Self {
        self.tx = Some(tx);
        self
    }

    /// Shared constructor: both public constructors produce the same
    /// `node`/`community` schema and differ only in the optional input wiring.
    fn build(
        config: CommunityDetectionOpConfig,
        input: Option<BoxedOperator>,
        input_node_column: Option<usize>,
    ) -> Self {
        let schema = Arc::new(Schema::new(vec!["node".to_string(), "community".to_string()]));
        Self {
            base: OperatorBase::new(schema),
            config,
            tx: None,
            results: None,
            input,
            input_node_column,
        }
    }
}
impl<T> Operator for CommunityDetectionOp<T>
where
T: Transaction + Send,
{
fn open(&mut self, ctx: &ExecutionContext) -> OperatorResult<()> {
if let Some(ref mut input) = self.input {
input.open(ctx)?;
}
self.base.set_open();
Ok(())
}
fn next(&mut self) -> OperatorResult<Option<Row>> {
if self.results.is_none() {
let tx = self.tx.as_ref().ok_or_else(|| {
ParseError::InvalidGraphOp(
"CommunityDetection requires transaction access".to_string(),
)
})?;
let cd_config = self.config.to_community_config();
let result = if let Some(ref mut input) = self.input {
let column = self.input_node_column.unwrap_or(0);
let mut nodes = Vec::new();
while let Some(row) = input.next()? {
if let Some(Value::Int(id)) = row.get(column) {
nodes.push(EntityId::new(*id as u64));
}
}
CommunityDetection::label_propagation_for_nodes(tx, &nodes, &cd_config).map_err(
|e| ParseError::InvalidGraphOp(format!("CommunityDetection error: {e}")),
)?
} else {
CommunityDetection::label_propagation(tx, &cd_config).map_err(|e| {
ParseError::InvalidGraphOp(format!("CommunityDetection error: {e}"))
})?
};
let mut pairs: Vec<_> = result.assignments.into_iter().collect();
pairs.sort_by_key(|(id, _)| id.as_u64());
self.results = Some(pairs.into_iter());
}
if let Some(ref mut iter) = self.results {
if let Some((node, community)) = iter.next() {
let row = Row::new(
self.base.schema(),
vec![Value::Int(node.as_u64() as i64), Value::Int(community as i64)],
);
self.base.inc_rows_produced();
return Ok(Some(row));
}
}
self.base.set_finished();
Ok(None)
}
fn close(&mut self) -> OperatorResult<()> {
if let Some(ref mut input) = self.input {
input.close()?;
}
self.results = None;
self.base.set_closed();
Ok(())
}
fn schema(&self) -> Arc<Schema> {
self.base.schema()
}
fn state(&self) -> OperatorState {
self.base.state()
}
fn name(&self) -> &'static str {
"CommunityDetectionOp"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The operator defaults must match the values the analytics library ships with.
    #[test]
    fn pagerank_config_defaults() {
        let config = PageRankOpConfig::default();
        assert!((config.damping_factor - 0.85).abs() < f64::EPSILON);
        assert_eq!(config.max_iterations, 100);
        assert!(config.normalize);
    }

    /// Every PageRank builder method must overwrite exactly its own field.
    #[test]
    fn pagerank_config_builder() {
        let config = PageRankOpConfig::new()
            .with_damping_factor(0.9)
            .with_max_iterations(50)
            .with_tolerance(1e-8)
            .with_normalize(false);
        assert!((config.damping_factor - 0.9).abs() < f64::EPSILON);
        assert_eq!(config.max_iterations, 50);
        assert!((config.tolerance - 1e-8).abs() < f64::EPSILON);
        assert!(!config.normalize);
    }

    #[test]
    fn betweenness_centrality_config_defaults() {
        let config = BetweennessCentralityOpConfig::default();
        assert!(config.normalize);
        assert_eq!(config.direction, Direction::Both);
    }

    /// Builder coverage for betweenness centrality, mirroring the PageRank and
    /// community-detection builder tests; both values differ from the defaults
    /// asserted above so a no-op builder would be caught.
    #[test]
    fn betweenness_centrality_config_builder() {
        let config = BetweennessCentralityOpConfig::new()
            .with_normalize(false)
            .with_direction(Direction::Outgoing);
        assert!(!config.normalize);
        assert_eq!(config.direction, Direction::Outgoing);
    }

    #[test]
    fn community_detection_config_defaults() {
        let config = CommunityDetectionOpConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.direction, Direction::Both);
        assert!(config.seed.is_none());
    }

    /// Every community-detection builder method must overwrite exactly its own field.
    #[test]
    fn community_detection_config_builder() {
        let config = CommunityDetectionOpConfig::new()
            .with_max_iterations(50)
            .with_direction(Direction::Outgoing)
            .with_seed(42);
        assert_eq!(config.max_iterations, 50);
        assert_eq!(config.direction, Direction::Outgoing);
        assert_eq!(config.seed, Some(42));
    }
}