use crate::{FloatVectorFormat, GlobalArgs, GranularityArgs, NumThreadsArg, get_thread_pool};
use anyhow::{Result, bail, ensure};
use clap::{ArgGroup, Args, Parser};
use dsi_bitstream::prelude::*;
use dsi_progress_logger::{ProgressLog, concurrent_progress_logger};
use epserde::deser::{Deserialize, Flags};
use rand::SeedableRng;
use std::path::PathBuf;
use webgraph::{
graphs::bvgraph::get_endianness,
prelude::{BvGraph, DCF, DEG_CUMUL_EXTENSION},
};
use webgraph_algo::distances::hyperball::HyperBallBuilder;
// Output selection for a HyperBall run: every optional path below requests that the
// corresponding result be stored at that path, while `fmt` and `precision` are handed to
// the storing routine.
//
// NOTE(review): this is deliberately a plain comment rather than a `///` doc comment, so it
// cannot influence the clap-generated help of the command that flattens this struct.
// Confirm whether a doc comment would be safe here before converting it.
#[derive(Args, Debug, Clone)]
#[clap(group = ArgGroup::new("centralities"))]
pub struct Centralities {
    /// Format used to store the computed values.
    #[clap(long, value_enum, default_value_t = FloatVectorFormat::Ascii)]
    pub fmt: FloatVectorFormat,

    /// Precision forwarded to the output format when storing values.
    // NOTE(review): presumably the number of decimal digits used by textual formats;
    // confirm against `FloatVectorFormat::store`.
    #[clap(long)]
    pub precision: Option<usize>,

    /// Store the sum of distances at the given path.
    #[clap(long)]
    pub sum_of_distances: Option<PathBuf>,

    /// Store the reachable nodes at the given path.
    #[clap(long)]
    pub reachable_nodes: Option<PathBuf>,

    /// Store the harmonic centralities at the given path.
    #[clap(long)]
    pub harmonic: Option<PathBuf>,

    /// Store the closeness centralities at the given path.
    #[clap(long)]
    pub closeness: Option<PathBuf>,

    /// Store the neighborhood function at the given path.
    #[clap(long)]
    pub neighborhood_function: Option<PathBuf>,
}
impl Centralities {
    /// Tells whether the sum of distances has to be computed.
    ///
    /// This is the case when an output path was given for either the sum of distances
    /// itself or the closeness centralities.
    pub fn should_compute_sum_of_distances(&self) -> bool {
        let Self {
            sum_of_distances,
            closeness,
            ..
        } = self;
        sum_of_distances.as_ref().or(closeness.as_ref()).is_some()
    }

    /// Tells whether the sum of inverse distances has to be computed.
    ///
    /// This is the case when an output path was given for the harmonic centralities.
    pub fn should_compute_sum_of_inverse_distances(&self) -> bool {
        let Self { harmonic, .. } = self;
        harmonic.is_some()
    }
}
// Command-line arguments of the `hyperball` command.
//
// NOTE(review): kept as a plain comment so that the explicit `about`/`long_about` settings
// below remain the only source of the command description.
#[derive(Parser, Debug)]
#[command(
    name = "hyperball",
    about = "Use hyperball to compute centralities.",
    long_about = None
)]
pub struct CliArgs {
    /// Basename of the graph.
    pub basename: PathBuf,

    /// Treat the graph as symmetric and use it as its own transpose; cannot be combined
    /// with --transposed.
    #[clap(long, default_value_t = false)]
    pub symm: bool,

    /// Basename of the transposed graph.
    #[clap(short, long)]
    pub transposed: Option<PathBuf>,

    // Which results to store, where, and in which format.
    #[clap(flatten)]
    pub centralities: Centralities,

    /// Value of log2m passed to the HyperLogLog-based HyperBall builder.
    #[clap(short = 'm', long, default_value_t = 14)]
    pub log2m: usize,

    /// Upper bound handed to the HyperBall run.
    // NOTE(review): presumably the maximum number of iterations; confirm against the
    // signature of the HyperBall `run` method.
    #[clap(long, default_value_t = usize::MAX)]
    pub upper_bound: usize,

    /// Optional threshold handed to the HyperBall run.
    // NOTE(review): presumably a stopping criterion; confirm against the signature of the
    // HyperBall `run` method.
    #[clap(long)]
    pub threshold: Option<f64>,

    // Thread-pool size: `num_threads.num_threads` is handed to `get_thread_pool`.
    #[clap(flatten)]
    pub num_threads: NumThreadsArg,

    // Granularity settings, converted with `into_granularity()` and passed to the builder.
    #[clap(flatten)]
    pub granularity: GranularityArgs,

    /// Seed of the pseudorandom number generator passed to the HyperBall run.
    #[clap(long, default_value_t = 0)]
    pub seed: u64,
}
pub fn main(global_args: GlobalArgs, args: CliArgs) -> Result<()> {
ensure!(
!args.symm || args.transposed.is_none(),
"If the graph is symmetric, you should not pass the transpose."
);
match get_endianness(&args.basename)?.as_str() {
#[cfg(feature = "be_bins")]
BE::NAME => hyperball::<BE>(global_args, args),
#[cfg(feature = "le_bins")]
LE::NAME => hyperball::<LE>(global_args, args),
e => panic!("Unknown endianness: {}", e),
}
}
pub fn hyperball<E: Endianness>(global_args: GlobalArgs, args: CliArgs) -> Result<()> {
let mut pl = concurrent_progress_logger![];
if let Some(log_interval) = global_args.log_interval {
pl.log_interval(log_interval);
}
let thread_pool = get_thread_pool(args.num_threads.num_threads);
let graph = BvGraph::with_basename(&args.basename).load()?;
log::info!("Loading DCF...");
if !args.basename.with_extension(DEG_CUMUL_EXTENSION).exists() {
bail!(
"Missing DCF file. Please run `webgraph build dcf {}`.",
args.basename.display()
);
}
let deg_cumul = unsafe {
DCF::mmap(
args.basename.with_extension(DEG_CUMUL_EXTENSION),
Flags::RANDOM_ACCESS,
)
}?;
log::info!("Loading Transposed graph...");
let mut transposed = None;
if let Some(transposed_path) = args.transposed.as_ref() {
transposed = Some(BvGraph::with_basename(transposed_path).load()?);
}
let mut transposed_ref = transposed.as_ref();
if args.symm {
transposed_ref = Some(&graph);
}
let mut hb = HyperBallBuilder::with_hyper_log_log(
&graph,
transposed_ref,
deg_cumul.uncase(),
args.log2m,
None,
)?
.granularity(args.granularity.into_granularity())
.sum_of_distances(args.centralities.should_compute_sum_of_distances())
.sum_of_inverse_distances(args.centralities.should_compute_sum_of_inverse_distances())
.build(&mut pl);
log::info!("Starting Hyperball...");
let rng = rand::rngs::SmallRng::seed_from_u64(args.seed);
thread_pool.install(|| hb.run(args.upper_bound, args.threshold, rng, &mut pl))?;
log::info!("Storing the results...");
macro_rules! store_centrality {
($flag:ident, $method:ident, $description:expr) => {{
if let Some(path) = args.centralities.$flag {
log::info!("Saving {} to {}", $description, path.display());
let value = hb.$method()?;
args.centralities
.fmt
.store(path, &value, args.centralities.precision)?;
}
}};
}
store_centrality!(sum_of_distances, sum_of_distances, "sum of distances");
store_centrality!(harmonic, harmonic_centralities, "harmonic centralities");
store_centrality!(closeness, closeness_centrality, "closeness centralities");
store_centrality!(reachable_nodes, reachable_nodes, "reachable nodes");
store_centrality!(
neighborhood_function,
neighborhood_function,
"neighborhood function"
);
Ok(())
}