node2vec_rs/cpu/
word2vec_model.rs1use std::sync::Arc;
2
3use crate::cpu::matrix::Matrix;
4use crate::cpu::simd::saxpy_simd;
5use crate::cpu::*;
6
7const SIGMOID_TABLE_SIZE_F32: f32 = SIGMOID_TABLE_SIZE as f32;
12const LOG_TABLE_SIZE_F32: f32 = LOG_TABLE_SIZE as f32;
13
14fn init_sigmoid_table() -> [f32; SIGMOID_TABLE_SIZE + 1] {
24 let mut sigmoid_table = [0f32; SIGMOID_TABLE_SIZE + 1];
25 for i in 0..SIGMOID_TABLE_SIZE + 1 {
26 let x = (i as f32 * 2. * MAX_SIGMOID) / SIGMOID_TABLE_SIZE_F32 - MAX_SIGMOID;
27 sigmoid_table[i] = 1.0 / (1.0 + (-x).exp());
28 }
29 sigmoid_table
30}
31
32fn init_log_table() -> [f32; LOG_TABLE_SIZE + 1] {
38 let mut log_table = [0f32; LOG_TABLE_SIZE + 1];
39 for i in 0..LOG_TABLE_SIZE + 1 {
40 let x = (i as f32 + 1e-5) / LOG_TABLE_SIZE_F32;
41 log_table[i] = x.ln();
42 }
43 log_table
44}
45
46pub struct Word2Vec<'a> {
67 pub input: &'a mut Matrix,
68 output: &'a mut Matrix,
69 dim: usize,
70 lr: f32,
71 neg: usize,
72 grad: Vec<f32>,
73 neg_pos: usize,
74 sigmoid_table: [f32; SIGMOID_TABLE_SIZE + 1],
75 log_table: [f32; LOG_TABLE_SIZE + 1],
76 negative_table: Arc<Vec<usize>>,
77 loss: f64,
78 n_samples: u64,
79}
80
81impl<'a> Word2Vec<'a> {
82 pub fn new(
97 input: &'a mut Matrix,
98 output: &'a mut Matrix,
99 dim: usize,
100 lr: f32,
101 neg: usize,
102 neg_table: Arc<Vec<usize>>,
103 neg_start: usize,
104 ) -> Word2Vec<'a> {
105 Self {
106 input,
107 output,
108 dim,
109 lr,
110 neg,
111 grad: vec![0f32; dim],
112 neg_pos: neg_start % neg_table.len(),
113 sigmoid_table: init_sigmoid_table(),
114 log_table: init_log_table(),
115 negative_table: neg_table,
116 loss: 0.,
117 n_samples: 0,
118 }
119 }
120
121 #[inline]
127 pub fn get_loss(&self) -> f64 {
128 self.loss / self.n_samples as f64
129 }
130
131 #[inline(always)]
137 pub fn set_lr(&mut self, lr: f32) {
138 self.lr = lr;
139 }
140
141 #[inline(always)]
147 pub fn get_lr(&self) -> f32 {
148 self.lr
149 }
150
151 fn get_negative(&mut self, target: usize) -> usize {
161 loop {
162 let negative = self.negative_table[self.neg_pos];
163 self.neg_pos = (self.neg_pos + 1) % self.negative_table.len();
164 if target != negative {
165 return negative;
166 }
167 }
168 }
169
170 #[inline(always)]
177 pub fn update(&mut self, input: usize, target: usize) {
178 self.loss += self.negative_sampling(input, target);
179 self.n_samples += 1;
180 }
181
182 fn negative_sampling(&mut self, input: usize, target: usize) -> f64 {
189 let input_emb = self.input.get_row(input);
190 let mut loss = 0f32;
191 self.grad_zero();
192 for i in 0..self.neg + 1 {
193 if i == 0 {
194 loss += self.binary_logistic(input_emb, target, 1);
195 } else {
196 let neg_sample = self.get_negative(target);
197 loss += self.binary_logistic(input_emb, neg_sample, 0);
198 }
199 }
200 unsafe { self.input.add_row(self.grad.as_mut_ptr(), input, 1.0) };
201 loss as f64
202 }
203
204 #[inline]
214 fn log(&self, x: f32) -> f32 {
215 if x > 1.0 {
216 x
217 } else {
218 let i = (x * (LOG_TABLE_SIZE_F32)) as usize;
219 unsafe { *self.log_table.get_unchecked(i) }
220 }
221 }
222
223 #[inline]
225 fn grad_zero(&mut self) {
226 self.grad.fill(0.0);
227 }
228
229 #[inline]
239 fn sigmoid(&self, x: f32) -> f32 {
240 if x < -MAX_SIGMOID {
241 0f32
242 } else if x > MAX_SIGMOID {
243 1f32
244 } else {
245 let i = (x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE_F32 / MAX_SIGMOID / 2.;
246 unsafe { *self.sigmoid_table.get_unchecked(i as usize) }
247 }
248 }
249
250 #[inline(always)]
257 fn add_mul_row(&mut self, other: *const f32, a: f32) {
258 unsafe {
259 let source_slice = std::slice::from_raw_parts(other, self.dim);
260 saxpy_simd(&mut self.grad, source_slice, a);
261 }
262 }
263
264 #[inline]
276 fn binary_logistic(&mut self, input_emb: *mut f32, target: usize, label: i32) -> f32 {
277 let sum = unsafe { self.output.dot_row(input_emb, target) };
278 let score = self.sigmoid(sum);
279 let alpha = self.lr * (label as f32 - score);
280 let tar_emb = self.output.get_row(target);
281 self.add_mul_row(tar_emb, alpha);
282 unsafe { self.output.add_row(input_emb, target, alpha) };
283 if label == 1 {
284 -self.log(score)
285 } else {
286 -self.log(1.0 - score)
287 }
288 }
289}