//! CPU word2vec (skip-gram + negative sampling) model — node2vec_rs/cpu/word2vec_model.rs

1use std::sync::Arc;
2
3use crate::cpu::matrix::Matrix;
4use crate::cpu::simd::saxpy_simd;
5use crate::cpu::*;
6
7////////////////
// Constants  //
9////////////////
10
// Float mirrors of the table sizes, cast once here so the hot-path
// `sigmoid`/`log` lookups avoid a repeated `as f32` conversion.
const SIGMOID_TABLE_SIZE_F32: f32 = SIGMOID_TABLE_SIZE as f32;
const LOG_TABLE_SIZE_F32: f32 = LOG_TABLE_SIZE as f32;
13
14/////////////
15// Helpers //
16/////////////
17
18/// Initialise the sigmoid table
19///
20/// ### Returns
21///
22/// A table of sigmoid values
23fn init_sigmoid_table() -> [f32; SIGMOID_TABLE_SIZE + 1] {
24    let mut sigmoid_table = [0f32; SIGMOID_TABLE_SIZE + 1];
25    for i in 0..SIGMOID_TABLE_SIZE + 1 {
26        let x = (i as f32 * 2. * MAX_SIGMOID) / SIGMOID_TABLE_SIZE_F32 - MAX_SIGMOID;
27        sigmoid_table[i] = 1.0 / (1.0 + (-x).exp());
28    }
29    sigmoid_table
30}
31
32/// Initialise the log table
33///
34/// ### Returns
35///
36/// A table of log values
37fn init_log_table() -> [f32; LOG_TABLE_SIZE + 1] {
38    let mut log_table = [0f32; LOG_TABLE_SIZE + 1];
39    for i in 0..LOG_TABLE_SIZE + 1 {
40        let x = (i as f32 + 1e-5) / LOG_TABLE_SIZE_F32;
41        log_table[i] = x.ln();
42    }
43    log_table
44}
45
46//////////
47// Main //
48//////////
49
/// The word2vec (skip-gram with negative sampling) model state.
///
/// ### Fields
///
/// * `input`: The input embedding matrix (rows updated in place)
/// * `output`: The output embedding matrix (rows updated in place)
/// * `dim`: The embedding dimensionality
/// * `lr`: The current learning rate
/// * `neg`: The number of negative samples drawn per positive example
/// * `grad`: Scratch gradient accumulator of length `dim`, reused per update
/// * `neg_pos`: Round-robin cursor into `negative_table` (the next draw position)
/// * `sigmoid_table`: Precomputed sigmoid lookup table
/// * `log_table`: Precomputed natural-log lookup table
/// * `negative_table`: Shared table of candidate negative-sample indices
/// * `loss`: Accumulated (summed) loss over all processed samples
/// * `n_samples`: The number of samples processed so far
pub struct Word2Vec<'a> {
    pub input: &'a mut Matrix,
    output: &'a mut Matrix,
    dim: usize,
    lr: f32,
    neg: usize,
    grad: Vec<f32>,
    neg_pos: usize,
    sigmoid_table: [f32; SIGMOID_TABLE_SIZE + 1],
    log_table: [f32; LOG_TABLE_SIZE + 1],
    negative_table: Arc<Vec<usize>>,
    loss: f64,
    n_samples: u64,
}
80
81impl<'a> Word2Vec<'a> {
82    /// Generate a new word2vec model
83    ///
84    /// ### Params
85    ///
86    /// * `input` - The input matrix
87    /// * `output` - The output matrix
88    /// * `dim` - The dimension of the model
89    /// * `lr` - The learning rate
90    /// * `neg` - The number of negative samples
91    /// * `neg_table` - The negative table
92    ///
93    /// ### Returns
94    ///
95    /// A new initialised word2vec model
96    pub fn new(
97        input: &'a mut Matrix,
98        output: &'a mut Matrix,
99        dim: usize,
100        lr: f32,
101        neg: usize,
102        neg_table: Arc<Vec<usize>>,
103        neg_start: usize,
104    ) -> Word2Vec<'a> {
105        Self {
106            input,
107            output,
108            dim,
109            lr,
110            neg,
111            grad: vec![0f32; dim],
112            neg_pos: neg_start % neg_table.len(),
113            sigmoid_table: init_sigmoid_table(),
114            log_table: init_log_table(),
115            negative_table: neg_table,
116            loss: 0.,
117            n_samples: 0,
118        }
119    }
120
121    /// Return the loss value
122    ///
123    /// ### Returns
124    ///
125    /// The loss value
126    #[inline]
127    pub fn get_loss(&self) -> f64 {
128        self.loss / self.n_samples as f64
129    }
130
131    /// Set the learning rate
132    ///
133    /// ### Params
134    ///
135    /// * `lr` - The learning rate
136    #[inline(always)]
137    pub fn set_lr(&mut self, lr: f32) {
138        self.lr = lr;
139    }
140
141    /// Get the learning rate
142    ///
143    /// ### Returns
144    ///
145    /// The learning rate
146    #[inline(always)]
147    pub fn get_lr(&self) -> f32 {
148        self.lr
149    }
150
151    /// Return a negative
152    ///
153    /// ### Params
154    ///
155    /// * `target` - The target value
156    ///
157    /// ### Returns
158    ///
159    /// A negative example
160    fn get_negative(&mut self, target: usize) -> usize {
161        loop {
162            let negative = self.negative_table[self.neg_pos];
163            self.neg_pos = (self.neg_pos + 1) % self.negative_table.len();
164            if target != negative {
165                return negative;
166            }
167        }
168    }
169
170    /// Update the model
171    ///
172    /// ### Params
173    ///
174    /// * `input` - The input value
175    /// * `target` - The target value
176    #[inline(always)]
177    pub fn update(&mut self, input: usize, target: usize) {
178        self.loss += self.negative_sampling(input, target);
179        self.n_samples += 1;
180    }
181
182    /// Negative sampling
183    ///
184    /// ### Params
185    ///
186    /// * `input` - The input value
187    /// * `target` - The target value
188    fn negative_sampling(&mut self, input: usize, target: usize) -> f64 {
189        let input_emb = self.input.get_row(input);
190        let mut loss = 0f32;
191        self.grad_zero();
192        for i in 0..self.neg + 1 {
193            if i == 0 {
194                loss += self.binary_logistic(input_emb, target, 1);
195            } else {
196                let neg_sample = self.get_negative(target);
197                loss += self.binary_logistic(input_emb, neg_sample, 0);
198            }
199        }
200        unsafe { self.input.add_row(self.grad.as_mut_ptr(), input, 1.0) };
201        loss as f64
202    }
203
204    /// Return the log value
205    ///
206    /// ### Params
207    ///
208    /// * `x` - The value to log
209    ///
210    /// ### Returns
211    ///
212    /// The log value
213    #[inline]
214    fn log(&self, x: f32) -> f32 {
215        if x > 1.0 {
216            x
217        } else {
218            let i = (x * (LOG_TABLE_SIZE_F32)) as usize;
219            unsafe { *self.log_table.get_unchecked(i) }
220        }
221    }
222
223    /// Set the gradient to zero
224    #[inline]
225    fn grad_zero(&mut self) {
226        self.grad.fill(0.0);
227    }
228
229    /// Return the sigmoid value
230    ///
231    /// ### Params
232    ///
233    /// * `x` - The value to sigmoid
234    ///
235    /// ### Returns
236    ///
237    /// The sigmoid value
238    #[inline]
239    fn sigmoid(&self, x: f32) -> f32 {
240        if x < -MAX_SIGMOID {
241            0f32
242        } else if x > MAX_SIGMOID {
243            1f32
244        } else {
245            let i = (x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE_F32 / MAX_SIGMOID / 2.;
246            unsafe { *self.sigmoid_table.get_unchecked(i as usize) }
247        }
248    }
249
250    /// Add a multiple of a row to the gradient
251    ///
252    /// ### Params
253    ///
254    /// * `other` - The row to add
255    /// * `a` - The scalar to multiply the row by
256    #[inline(always)]
257    fn add_mul_row(&mut self, other: *const f32, a: f32) {
258        unsafe {
259            let source_slice = std::slice::from_raw_parts(other, self.dim);
260            saxpy_simd(&mut self.grad, source_slice, a);
261        }
262    }
263
264    /// Calculate the binary logistic loss
265    ///
266    /// ### Params
267    ///
268    /// * `input_emb` - The input embedding
269    /// * `target` - The target word index
270    /// * `label` - The label: `1` positive; `-1` = negative
271    ///
272    /// ### Returns
273    ///
274    /// The binary logistic loss
275    #[inline]
276    fn binary_logistic(&mut self, input_emb: *mut f32, target: usize, label: i32) -> f32 {
277        let sum = unsafe { self.output.dot_row(input_emb, target) };
278        let score = self.sigmoid(sum);
279        let alpha = self.lr * (label as f32 - score);
280        let tar_emb = self.output.get_row(target);
281        self.add_mul_row(tar_emb, alpha);
282        unsafe { self.output.add_row(input_emb, target, alpha) };
283        if label == 1 {
284            -self.log(score)
285        } else {
286            -self.log(1.0 - score)
287        }
288    }
289}