node2vec_rs/lib.rs
1#![allow(clippy::needless_range_loop)]
2
3pub mod graph;
4pub mod reader;
5
6#[cfg(feature = "burn")]
7pub mod burn;
8
9#[cfg(feature = "cpu")]
10pub mod cpu;
11pub mod prelude;
12
13use clap::Parser;
14
15///////////////
16// Arguments //
17///////////////
18
19/// CLI arguments
20///
21/// ### Fields
22///
23/// * `input` - The input CSV. Needs to be provided.
24/// * `output` - Where to store the outputs. Defaults to `"/tmp/node2vec"`.
25/// * `directed` - Shall the graph be treated as a directed graph. Defaults
26/// to `false`.
27/// * `embedding_dim` - Size of the embedding to create. Defaults to `16`.
28/// * `split` - How much of the data should be in the trainings data vs.
29/// validation data. Defaults to `0.9`.
30/// * `walks_per_node` - Number of random walks to do per node. Defaults to
31/// `20`.
32/// * `walk_length` - Length of the random walks. Defaults to `20`.
33/// * `window_size` - Window size parameter for the skipgram model. Defaults to
34/// `2`.
35/// * `batch_size` - Batch size during training. Defaults to `256`.
36/// * `num_workers` - Number of workers to use during the generation of the
37/// batches. Defaults to `4`.
38/// * `num_epochs` - Number of epochs to train the model for. Defaults to `5`.
39/// * `num_negatives` - Number of negative examples to sample. Defaults to `5`.
40/// * `seed` - Seed for reproducibility. Defaults to `42`.
41/// * `learning_rate` - Learning rate for the Adam optimiser. Defaults to
42/// `1-e3`.
43/// * `p` - p parameter for the node2vec random walks and controls the
44/// probability to return to origin node. Defaults to `1.0`.
45/// * `q` - q parameter for node2vec random walks and controls the probability
46/// to venture on a different node from the origin node. Defaults to `1.0`.
47#[derive(Parser)]
48#[command(name = "node2vec")]
49#[command(about = "Node2Vec implementation using Burn", long_about = None)]
50pub struct Args {
51 #[arg(short, long)]
52 pub input: String,
53
54 #[arg(short, long, default_value = "/tmp/node2vec")]
55 pub output: String,
56
57 #[arg(long, default_value = "cpu")]
58 pub backend: String,
59
60 #[arg(short, long, default_value_t = false)]
61 pub directed: bool,
62
63 #[arg(short, long, default_value_t = 16)]
64 pub embedding_dim: usize,
65
66 #[arg(short, long, default_value_t = 0.9)]
67 pub split: f32,
68
69 #[arg(long, default_value_t = 20)]
70 pub walks_per_node: usize,
71
72 #[arg(long, default_value_t = 20)]
73 pub walk_length: usize,
74
75 #[arg(long, default_value_t = 2)]
76 pub window_size: usize,
77
78 #[arg(long, default_value_t = 256)]
79 pub batch_size: usize,
80
81 #[arg(long, default_value_t = 4)]
82 pub num_workers: usize,
83
84 #[arg(long, default_value_t = 5)]
85 pub num_epochs: usize,
86
87 #[arg(long, default_value_t = 5)]
88 pub num_negatives: usize,
89
90 #[arg(long, default_value_t = 42)]
91 pub seed: usize,
92
93 #[arg(long, default_value_t = 1.0e-3)]
94 pub learning_rate: f64,
95
96 #[arg(long, default_value_t = 1.0)]
97 pub p: f32,
98
99 #[arg(long, default_value_t = 1.0)]
100 pub q: f32,
101
102 #[arg(long, default_value_t = 1.0e-3)]
103 pub sample: f32,
104}