use serde::{Deserialize, Serialize};
/// Graph algorithms described by this module.
///
/// Each variant maps to a lowercase identifier through `name()` and to a fixed
/// list of result columns through `result_schema()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GraphAlgorithm {
/// Identified as `"pagerank"`; result columns are `node_id` and `rank`.
/// Reported as iterative by `is_iterative()`.
PageRank,
/// Identified as `"wcc"` (presumably weakly connected components — confirm);
/// result columns are `node_id` and `component_id`.
Wcc,
/// Identified as `"label_propagation"`; result columns are `node_id` and
/// `community_id`. Reported as iterative by `is_iterative()`.
LabelPropagation,
/// Identified as `"lcc"` (presumably local clustering coefficient — confirm);
/// result columns are `node_id` and `coefficient`.
Lcc,
/// Identified as `"sssp"` (presumably single-source shortest paths — confirm);
/// result columns are `node_id` and `distance`.
Sssp,
/// Identified as `"betweenness"`; result columns are `node_id` and `centrality`.
Betweenness,
/// Identified as `"closeness"`; result columns are `node_id` and `centrality`.
Closeness,
/// Identified as `"harmonic"`; result columns are `node_id` and `centrality`.
Harmonic,
/// Identified as `"degree"`; result columns are `node_id` and `centrality`.
Degree,
/// Identified as `"louvain"`; result columns are `node_id`, `community_id`
/// and `modularity`. Reported as iterative by `is_iterative()`.
Louvain,
/// Identified as `"triangles"`; result columns are `node_id` and `triangles`.
Triangles,
/// Identified as `"diameter"`; result columns are `diameter` and `radius`.
/// This is the only schema without a `node_id` column.
Diameter,
/// Identified as `"kcore"`; result columns are `node_id` and `coreness`.
KCore,
}
impl GraphAlgorithm {
pub fn name(&self) -> &'static str {
match self {
Self::PageRank => "pagerank",
Self::Wcc => "wcc",
Self::LabelPropagation => "label_propagation",
Self::Lcc => "lcc",
Self::Sssp => "sssp",
Self::Betweenness => "betweenness",
Self::Closeness => "closeness",
Self::Harmonic => "harmonic",
Self::Degree => "degree",
Self::Louvain => "louvain",
Self::Triangles => "triangles",
Self::Diameter => "diameter",
Self::KCore => "kcore",
}
}
pub fn is_iterative(&self) -> bool {
matches!(
self,
Self::PageRank | Self::LabelPropagation | Self::Louvain
)
}
pub fn result_schema(&self) -> &'static [(&'static str, AlgoColumnType)] {
use AlgoColumnType::*;
match self {
Self::PageRank => &[("node_id", Text), ("rank", Float64)],
Self::Wcc => &[("node_id", Text), ("component_id", Int64)],
Self::LabelPropagation => &[("node_id", Text), ("community_id", Int64)],
Self::Lcc => &[("node_id", Text), ("coefficient", Float64)],
Self::Sssp => &[("node_id", Text), ("distance", Float64)],
Self::Betweenness => &[("node_id", Text), ("centrality", Float64)],
Self::Closeness => &[("node_id", Text), ("centrality", Float64)],
Self::Harmonic => &[("node_id", Text), ("centrality", Float64)],
Self::Degree => &[("node_id", Text), ("centrality", Float64)],
Self::Louvain => &[
("node_id", Text),
("community_id", Int64),
("modularity", Float64),
],
Self::Triangles => &[("node_id", Text), ("triangles", Int64)],
Self::Diameter => &[("diameter", Int64), ("radius", Int64)],
Self::KCore => &[("node_id", Text), ("coreness", Int64)],
}
}
}
/// Value type of a single result column, as used in the `(column name, type)`
/// pairs returned by `GraphAlgorithm::result_schema`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlgoColumnType {
/// Textual column; every `node_id` column in the schemas uses this type.
Text,
/// Column type used for scores such as `rank`, `coefficient`, `distance`,
/// `centrality` and `modularity` (named for 64-bit floats).
Float64,
/// Column type used for identifiers and counts such as `component_id`,
/// `community_id`, `triangles`, `diameter`, `radius` and `coreness`
/// (named for 64-bit integers).
Int64,
}
/// Caller-supplied parameters for running a graph algorithm.
///
/// Every field except `collection` is optional. The accessor methods on this
/// type (`damping_factor`, `iterations`, `convergence_tolerance`,
/// `louvain_resolution`) turn the raw optional values into usable numbers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AlgoParams {
/// Presumably the name of the graph collection to run against — not read
/// anywhere in this file; confirm with callers.
pub collection: String,
/// Raw damping value; read through `damping_factor()`, which defaults to
/// `0.85` and clamps to `[0.01, 0.99]`.
pub damping: Option<f64>,
/// Raw iteration cap; read through `iterations()`, which applies a
/// caller-provided default and never returns less than `1`.
pub max_iterations: Option<usize>,
/// Raw tolerance; read through `convergence_tolerance()`, which falls back
/// to `1e-7` when the value is unset or not greater than zero.
pub tolerance: Option<f64>,
/// Presumably the start node for source-based algorithms such as `Sssp` —
/// not read in this file; confirm with callers.
pub source_node: Option<String>,
/// Not read in this file; presumably a sampling bound for approximate
/// algorithms — TODO confirm.
pub sample_size: Option<usize>,
/// Not read in this file; presumably an edge-direction selector whose
/// accepted values are defined by the consumer — TODO confirm.
pub direction: Option<String>,
/// Raw resolution; read through `louvain_resolution()`, which falls back to
/// `1.0` when the value is unset or not greater than zero.
pub resolution: Option<f64>,
/// Not read in this file; accepted values are defined by the consumer —
/// TODO confirm.
pub mode: Option<String>,
}
impl AlgoParams {
    /// Returns the damping factor to use, always a number within `[0.01, 0.99]`.
    ///
    /// Falls back to `0.85` when `damping` is unset or NaN; any other value
    /// (including the infinities) is clamped into the range.
    pub fn damping_factor(&self) -> f64 {
        const DEFAULT: f64 = 0.85;
        // `f64::clamp` returns NaN for a NaN receiver, so NaN must be rejected
        // before clamping. This matches the other accessors, whose `> 0.0`
        // checks already send NaN to the default.
        self.damping
            .filter(|d| !d.is_nan())
            .unwrap_or(DEFAULT)
            .clamp(0.01, 0.99)
    }

    /// Returns the iteration cap: `max_iterations` if set, otherwise `default`,
    /// and in either case at least `1`.
    pub fn iterations(&self, default: usize) -> usize {
        self.max_iterations.unwrap_or(default).max(1)
    }

    /// Returns the convergence tolerance.
    ///
    /// Falls back to `1e-7` when `tolerance` is unset, zero, negative or NaN.
    /// Positive values, including `+inf`, are returned unchanged.
    pub fn convergence_tolerance(&self) -> f64 {
        const DEFAULT: f64 = 1e-7;
        // `t > 0.0` is false for NaN, so NaN takes the default as well.
        self.tolerance.filter(|t| *t > 0.0).unwrap_or(DEFAULT)
    }

    /// Returns the Louvain resolution.
    ///
    /// Falls back to `1.0` when `resolution` is unset, zero, negative or NaN.
    /// Positive values, including `+inf`, are returned unchanged.
    pub fn louvain_resolution(&self) -> f64 {
        const DEFAULT: f64 = 1.0;
        // `r > 0.0` is false for NaN, so NaN takes the default as well.
        self.resolution.filter(|r| *r > 0.0).unwrap_or(DEFAULT)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every algorithm variant, used by the table-driven tests below.
    const ALL: [GraphAlgorithm; 13] = [
        GraphAlgorithm::PageRank,
        GraphAlgorithm::Wcc,
        GraphAlgorithm::LabelPropagation,
        GraphAlgorithm::Lcc,
        GraphAlgorithm::Sssp,
        GraphAlgorithm::Betweenness,
        GraphAlgorithm::Closeness,
        GraphAlgorithm::Harmonic,
        GraphAlgorithm::Degree,
        GraphAlgorithm::Louvain,
        GraphAlgorithm::Triangles,
        GraphAlgorithm::Diameter,
        GraphAlgorithm::KCore,
    ];

    #[test]
    fn algorithm_names() {
        assert_eq!(GraphAlgorithm::PageRank.name(), "pagerank");
        assert_eq!(GraphAlgorithm::Wcc.name(), "wcc");
        assert_eq!(GraphAlgorithm::LabelPropagation.name(), "label_propagation");
        assert_eq!(GraphAlgorithm::Lcc.name(), "lcc");
        assert_eq!(GraphAlgorithm::Sssp.name(), "sssp");
        assert_eq!(GraphAlgorithm::Betweenness.name(), "betweenness");
        assert_eq!(GraphAlgorithm::Closeness.name(), "closeness");
        assert_eq!(GraphAlgorithm::Harmonic.name(), "harmonic");
        assert_eq!(GraphAlgorithm::Degree.name(), "degree");
        assert_eq!(GraphAlgorithm::Louvain.name(), "louvain");
        assert_eq!(GraphAlgorithm::Triangles.name(), "triangles");
        assert_eq!(GraphAlgorithm::Diameter.name(), "diameter");
        assert_eq!(GraphAlgorithm::KCore.name(), "kcore");
    }

    #[test]
    fn algorithm_names_are_unique() {
        // Names act as identifiers, so two variants must never share one.
        let names: std::collections::HashSet<&str> = ALL.iter().map(|a| a.name()).collect();
        assert_eq!(names.len(), ALL.len());
    }

    #[test]
    fn iterative_algorithms() {
        assert!(GraphAlgorithm::PageRank.is_iterative());
        assert!(GraphAlgorithm::LabelPropagation.is_iterative());
        assert!(GraphAlgorithm::Louvain.is_iterative());
        assert!(!GraphAlgorithm::Wcc.is_iterative());
        assert!(!GraphAlgorithm::Sssp.is_iterative());
    }

    #[test]
    fn exactly_three_algorithms_are_iterative() {
        let iterative = ALL.iter().filter(|a| a.is_iterative()).count();
        assert_eq!(iterative, 3);
    }

    #[test]
    fn result_schema_columns() {
        let schema = GraphAlgorithm::PageRank.result_schema();
        assert_eq!(schema.len(), 2);
        assert_eq!(schema[0], ("node_id", AlgoColumnType::Text));
        assert_eq!(schema[1], ("rank", AlgoColumnType::Float64));
    }

    #[test]
    fn louvain_schema_has_three_columns() {
        let schema = GraphAlgorithm::Louvain.result_schema();
        assert_eq!(schema.len(), 3);
        assert_eq!(schema[2], ("modularity", AlgoColumnType::Float64));
    }

    #[test]
    fn per_node_schemas_lead_with_node_id() {
        for algo in ALL.iter().filter(|a| **a != GraphAlgorithm::Diameter) {
            let schema = algo.result_schema();
            assert!(schema.len() >= 2, "schema too short for {:?}", algo);
            assert_eq!(
                schema[0],
                ("node_id", AlgoColumnType::Text),
                "unexpected first column for {:?}",
                algo
            );
        }
    }

    #[test]
    fn diameter_schema_is_graph_level() {
        let schema = GraphAlgorithm::Diameter.result_schema();
        assert_eq!(
            schema,
            &[
                ("diameter", AlgoColumnType::Int64),
                ("radius", AlgoColumnType::Int64)
            ][..]
        );
    }

    #[test]
    fn params_defaults() {
        let p = AlgoParams::default();
        assert_eq!(p.damping_factor(), 0.85);
        assert_eq!(p.iterations(20), 20);
        assert_eq!(p.convergence_tolerance(), 1e-7);
        assert_eq!(p.louvain_resolution(), 1.0);
    }

    #[test]
    fn params_clamping() {
        let p = AlgoParams {
            damping: Some(2.0),
            tolerance: Some(-1.0),
            resolution: Some(0.0),
            ..Default::default()
        };
        assert_eq!(p.damping_factor(), 0.99);
        assert_eq!(p.convergence_tolerance(), 1e-7);
        assert_eq!(p.louvain_resolution(), 1.0);
    }

    #[test]
    fn params_lower_bounds() {
        let p = AlgoParams {
            damping: Some(-1.0),
            max_iterations: Some(0),
            resolution: Some(-1.0),
            ..Default::default()
        };
        assert_eq!(p.damping_factor(), 0.01);
        assert_eq!(p.iterations(20), 1);
        assert_eq!(p.louvain_resolution(), 1.0);
        // The floor also applies to the caller-provided default.
        assert_eq!(AlgoParams::default().iterations(0), 1);
    }

    #[test]
    fn params_valid_values_pass_through() {
        let p = AlgoParams {
            damping: Some(0.5),
            max_iterations: Some(50),
            tolerance: Some(1e-3),
            resolution: Some(2.5),
            ..Default::default()
        };
        assert_eq!(p.damping_factor(), 0.5);
        assert_eq!(p.iterations(20), 50);
        assert_eq!(p.convergence_tolerance(), 1e-3);
        assert_eq!(p.louvain_resolution(), 2.5);
    }

    #[test]
    fn params_serde_roundtrip() {
        let p = AlgoParams {
            collection: "users".into(),
            damping: Some(0.9),
            max_iterations: Some(30),
            source_node: Some("alice".into()),
            ..Default::default()
        };
        let json = serde_json::to_string(&p).unwrap();
        let p2: AlgoParams = serde_json::from_str(&json).unwrap();
        assert_eq!(p2.collection, "users");
        assert_eq!(p2.damping, Some(0.9));
        assert_eq!(p2.max_iterations, Some(30));
        assert_eq!(p2.source_node, Some("alice".into()));
        // Fields left unset must come back unset.
        assert_eq!(p2.tolerance, None);
        assert_eq!(p2.sample_size, None);
        assert_eq!(p2.direction, None);
        assert_eq!(p2.resolution, None);
        assert_eq!(p2.mode, None);
    }

    #[test]
    fn params_deserialize_with_optional_fields_omitted() {
        let p: AlgoParams = serde_json::from_str(r#"{"collection":"users"}"#).unwrap();
        assert_eq!(p.collection, "users");
        assert_eq!(p.damping, None);
        assert_eq!(p.max_iterations, None);
        assert_eq!(p.damping_factor(), 0.85);
    }

    #[test]
    fn algorithm_serde_roundtrip() {
        for algo in ALL {
            let json = serde_json::to_string(&algo).unwrap();
            let back: GraphAlgorithm = serde_json::from_str(&json).unwrap();
            assert_eq!(back, algo);
        }
    }
}