Skip to main content

node2vec_rs/
lib.rs

1#![allow(clippy::needless_range_loop)]
2
3pub mod graph;
4pub mod reader;
5
6#[cfg(feature = "burn")]
7pub mod burn;
8
9#[cfg(feature = "cpu")]
10pub mod cpu;
11pub mod prelude;
12
13use clap::Parser;
14
15///////////////
16// Arguments //
17///////////////
18
19/// CLI arguments
20///
21/// ### Fields
22///
23/// * `input` - The input CSV. Needs to be provided.
24/// * `output` - Where to store the outputs. Defaults to `"/tmp/node2vec"`.
25/// * `directed` - Shall the graph be treated as a directed graph. Defaults
26///   to `false`.
27/// * `embedding_dim` - Size of the embedding to create. Defaults to `16`.
28/// * `split` - How much of the data should be in the trainings data vs.
29///   validation data. Defaults to `0.9`.
30/// * `walks_per_node` - Number of random walks to do per node. Defaults to
31///   `20`.
32/// * `walk_length` - Length of the random walks. Defaults to `20`.
33/// * `window_size` - Window size parameter for the skipgram model. Defaults to
34///   `2`.
35/// * `batch_size` - Batch size during training. Defaults to `256`.
36/// * `num_workers` - Number of workers to use during the generation of the
37///   batches. Defaults to `4`.
38/// * `num_epochs` - Number of epochs to train the model for. Defaults to `5`.
39/// * `num_negatives` - Number of negative examples to sample. Defaults to `5`.
40/// * `seed` - Seed for reproducibility. Defaults to `42`.
41/// * `learning_rate` - Learning rate for the Adam optimiser. Defaults to
42///   `1-e3`.
43/// * `p` - p parameter for the node2vec random walks and controls the
44///   probability to return to origin node. Defaults to `1.0`.
45/// * `q` - q parameter for node2vec random walks and controls the probability
46///   to venture on a different node from the origin node. Defaults to `1.0`.
47#[derive(Parser)]
48#[command(name = "node2vec")]
49#[command(about = "Node2Vec implementation using Burn", long_about = None)]
50pub struct Args {
51    #[arg(short, long)]
52    pub input: String,
53
54    #[arg(short, long, default_value = "/tmp/node2vec")]
55    pub output: String,
56
57    #[arg(long, default_value = "cpu")]
58    pub backend: String,
59
60    #[arg(short, long, default_value_t = false)]
61    pub directed: bool,
62
63    #[arg(short, long, default_value_t = 16)]
64    pub embedding_dim: usize,
65
66    #[arg(short, long, default_value_t = 0.9)]
67    pub split: f32,
68
69    #[arg(long, default_value_t = 20)]
70    pub walks_per_node: usize,
71
72    #[arg(long, default_value_t = 20)]
73    pub walk_length: usize,
74
75    #[arg(long, default_value_t = 2)]
76    pub window_size: usize,
77
78    #[arg(long, default_value_t = 256)]
79    pub batch_size: usize,
80
81    #[arg(long, default_value_t = 4)]
82    pub num_workers: usize,
83
84    #[arg(long, default_value_t = 5)]
85    pub num_epochs: usize,
86
87    #[arg(long, default_value_t = 5)]
88    pub num_negatives: usize,
89
90    #[arg(long, default_value_t = 42)]
91    pub seed: usize,
92
93    #[arg(long, default_value_t = 1.0e-3)]
94    pub learning_rate: f64,
95
96    #[arg(long, default_value_t = 1.0)]
97    pub p: f32,
98
99    #[arg(long, default_value_t = 1.0)]
100    pub q: f32,
101
102    #[arg(long, default_value_t = 1.0e-3)]
103    pub sample: f32,
104}