node2vec_rs/cpu/matrix.rs
1use faer::Mat;
2use rand::rngs::StdRng;
3use rand::SeedableRng;
4use rand_distr::{Distribution, Uniform};
5use std::cell::UnsafeCell;
6
7use crate::cpu::simd::*;
8
/// Dense matrix stored in row-major format
///
/// The matrix data is stored as a flat vector where each row
/// is laid out contiguously in memory for cache efficiency.
/// This layout enables efficient SIMD operations on individual rows.
///
/// ### Fields
///
/// - `n_col` - Number of elements per row (dimension of embeddings)
/// - `n_row` - Number of rows
/// - `data` - Flat storage of matrix data (`n_row * n_col` elements); row `i`
///   starts at offset `i * n_col`
#[derive(Debug)]
pub struct Matrix {
    n_col: usize,
    n_row: usize,
    data: Vec<f32>,
}
25
/// Wrapper that lets a Matrix be shared between threads for concurrent training
///
/// Uses `UnsafeCell` to allow interior mutability across threads.
/// The wrapper itself performs no locking or other synchronisation: per the
/// SAFETY note on its `Sync` implementation, threads may read and write
/// overlapping rows at the same time, and the training algorithm is expected
/// to tolerate that. It is not a general-purpose concurrent container.
///
/// ### Fields
///
/// * `inner` - The inner matrix data (unsafe cell)
#[derive(Debug)]
pub struct MatrixWrapper {
    pub inner: UnsafeCell<Matrix>,
}
39
/// SAFETY: This is intentionally unsound. `UnsafeCell` is not `Sync` on its
/// own, so this impl is what allows a shared `MatrixWrapper` to be used from
/// several threads. Multiple threads will concurrently read and write
/// overlapping rows (e.g. a target in one thread may be a negative sample in
/// another). This mirrors the deliberate data race in Mikolov's original
/// word2vec C implementation and the word2vec-rs crate it was ported from.
/// SGD tolerates stale/torn reads and the resulting embeddings converge in
/// practice. Do not use MatrixWrapper as a general-purpose concurrent
/// container.
unsafe impl Sync for MatrixWrapper {}
48
49impl Matrix {
50 /// Creates a new matrix initialised with zeros
51 ///
52 /// ### Params
53 ///
54 /// * `rows` - Number of rows in the matrix (typically vocabulary size)
55 /// * `n_col` - Number of elements per row (embedding dimension)
56 ///
57 /// ### Returns
58 ///
59 /// A new Matrix with all elements set to 0.0
60 pub fn new(n_row: usize, n_col: usize) -> Matrix {
61 Matrix {
62 data: vec![0f32; n_col * n_row],
63 n_col,
64 n_row,
65 }
66 }
67
68 /// Normalises all rows in the matrix to unit length
69 ///
70 /// Each row vector is divided by its L2 norm, making it a unit vector.
71 pub fn norm_self(&mut self) {
72 let num_rows = self.n_row;
73 for i in 0..num_rows {
74 let n = self.norm(i);
75 if n > 0.0 {
76 let start = i * self.n_col;
77 let end = start + self.n_col;
78 for j in start..end {
79 self.data[j] /= n;
80 }
81 }
82 }
83 }
84
85 /// Wraps the matrix in a thread-safe wrapper
86 ///
87 /// ### Returns
88 ///
89 /// A `MatrixWrapper` that can be safely shared across threads using Arc
90 pub fn make_send(self) -> MatrixWrapper {
91 MatrixWrapper {
92 inner: UnsafeCell::new(self),
93 }
94 }
95
96 /// Initialises matrix with uniform random values
97 ///
98 /// ### Params
99 ///
100 /// * `bound` - Values will be sampled uniformly from [-bound, bound]
101 /// * `seed` - Seed for reproducibility.
102 pub fn uniform(&mut self, bound: f32, seed: usize) {
103 let between = Uniform::new(-bound, bound).unwrap();
104 let mut rng = StdRng::seed_from_u64(seed as u64);
105 for v in &mut self.data {
106 *v = between.sample(&mut rng);
107 }
108 }
109
110 /// Computes the L2 norm of a matrix row
111 ///
112 /// ### Params
113 ///
114 /// * `i` - Row index
115 ///
116 /// ### Returns
117 ///
118 /// The L2 norm (Euclidean length) of the row vector
119 pub fn norm(&self, i: usize) -> f32 {
120 let start = i * self.n_col;
121 let end = start + self.n_col;
122 norm_l2_simd(&self.data[start..end])
123 }
124
125 /// Sets all matrix elements to zero
126 #[inline(always)]
127 pub fn zero(&mut self) {
128 for v in self.data.iter_mut() {
129 *v = 0f32;
130 }
131 }
132
133 /// Adds a scaled vector to a matrix row (SAXPY operation)
134 ///
135 /// Performs: `row[i] = row[i] + mul * vec`
136 ///
137 /// This is used during gradient updates where we add scaled gradients
138 /// to embedding vectors.
139 ///
140 /// ### Params
141 ///
142 /// * `vec` - Pointer to the vector to add (must have `n_col` elements)
143 /// * `i` - Row index
144 /// * `mul` - Scaling factor for the vector
145 ///
146 /// ### Safety
147 ///
148 /// The caller must ensure `vec` points to at least `n_col` valid f32
149 /// elements.
150 #[inline(always)]
151 pub unsafe fn add_row(&mut self, vec: *const f32, i: usize, mul: f32) {
152 let start = i * self.n_col;
153 unsafe {
154 let row_slice =
155 std::slice::from_raw_parts_mut(self.data.as_mut_ptr().add(start), self.n_col);
156 let vec_slice = std::slice::from_raw_parts(vec, self.n_col);
157 saxpy_simd(row_slice, vec_slice, mul);
158 }
159 }
160
161 /// Computes dot product between a vector and a matrix row
162 ///
163 /// ### Params
164 ///
165 /// * `vec` - Pointer to the vector (must have `row_size` elements)
166 /// * `i` - Row index
167 ///
168 /// ### Returns
169 ///
170 /// The dot product result
171 ///
172 /// ### Safety
173 ///
174 /// The caller must ensure `vec` points to at least `row_size` valid f32
175 /// elements
176 #[inline(always)]
177 pub unsafe fn dot_row(&self, vec: *const f32, i: usize) -> f32 {
178 let start = i * self.n_col;
179 unsafe {
180 let row_slice = std::slice::from_raw_parts(self.data.as_ptr().add(start), self.n_col);
181 let vec_slice = std::slice::from_raw_parts(vec, self.n_col);
182 dot_simd(row_slice, vec_slice)
183 }
184 }
185
186 /// Computes dot product between two matrix rows
187 ///
188 /// ### Params
189 ///
190 /// * `i` - First row index
191 /// * `j` - Second row index
192 ///
193 /// ### Returns
194 ///
195 /// The dot product of row i and row j
196 #[inline(always)]
197 pub fn dot_two_row(&self, i: usize, j: usize) -> f32 {
198 let start_i = i * self.n_col;
199 let start_j = j * self.n_col;
200 unsafe {
201 let row_i = std::slice::from_raw_parts(self.data.as_ptr().add(start_i), self.n_col);
202 let row_j = std::slice::from_raw_parts(self.data.as_ptr().add(start_j), self.n_col);
203 dot_simd(row_i, row_j)
204 }
205 }
206
207 /// Gets a mutable pointer to a matrix row
208 ///
209 /// ### Params
210 ///
211 /// * `i` - Row index
212 ///
213 /// ### Returns
214 ///
215 /// Mutable pointer to the start of row i
216 ///
217 /// ### Safety
218 ///
219 /// The caller must ensure the pointer is used correctly and doesn't
220 /// create aliasing issues. Primarily used for passing to SIMD functions.
221 #[inline(always)]
222 pub fn get_row(&mut self, i: usize) -> *mut f32 {
223 unsafe { self.data.as_mut_ptr().add(i * self.n_col) }
224 }
225
226 /// Returns a shared slice of row i
227 ///
228 /// ### Params
229 ///
230 /// * `i` - Row index
231 ///
232 /// ### Returns
233 ///
234 /// Slice of the row data
235 #[inline(always)]
236 pub fn row_as_slice(&self, i: usize) -> &[f32] {
237 let start = i * self.n_col;
238 &self.data[start..start + self.n_col]
239 }
240
241 /// Gets a const pointer to a matrix row
242 ///
243 /// ### Params
244 ///
245 /// * `i` - Row index
246 ///
247 /// ### Returns
248 ///
249 /// Const pointer to the start of row i
250 #[inline(always)]
251 pub fn get_row_unmod(&self, i: usize) -> *const f32 {
252 unsafe { self.data.as_ptr().add(i * self.n_col) }
253 }
254
255 /// Returns the number of elements per row
256 #[inline(always)]
257 pub fn n_col(&self) -> usize {
258 self.n_col
259 }
260
261 /// Returns the total number of rows
262 #[inline(always)]
263 pub fn n_rows(&self) -> usize {
264 self.n_row
265 }
266
267 /// Converts the matrix to a Faer matrix
268 ///
269 /// ### Returns
270 ///
271 /// Faer matrix
272 pub fn to_faer(&self) -> Mat<f32> {
273 Mat::from_fn(self.n_row, self.n_col, |i, j| self.data[i * self.n_col + j])
274 }
275
276 /// Write the matrix rows to a CSV file
277 ///
278 /// ### Params
279 ///
280 /// * `path` - Path to the output CSV
281 pub fn write_csv(&self, path: &str) -> std::io::Result<()> {
282 use std::io::Write;
283 let mut file = std::fs::File::create(path)?;
284 for i in 0..self.n_row {
285 let start = i * self.n_col;
286 let end = start + self.n_col;
287 let line = self.data[start..end]
288 .iter()
289 .map(|v| v.to_string())
290 .collect::<Vec<_>>()
291 .join(",");
292 writeln!(file, "{}", line)?;
293 }
294 Ok(())
295 }
296
297 /// Compute the element-wise average of two matrices
298 ///
299 /// ### Params
300 ///
301 /// * `other` - The other matrix (must have same dimensions)
302 ///
303 /// ### Returns
304 ///
305 /// A new matrix where each element is (self + other) / 2
306 ///
307 /// ### Panics
308 ///
309 /// Panics if dimensions do not match.
310 pub fn average_with(&self, other: &Matrix) -> Matrix {
311 assert_eq!(self.n_row, other.n_row, "Row count mismatch");
312 assert_eq!(self.n_col, other.n_col, "Column count mismatch");
313 let data = self
314 .data
315 .iter()
316 .zip(other.data.iter())
317 .map(|(a, b)| (a + b) * 0.5)
318 .collect();
319 Matrix {
320 n_row: self.n_row,
321 n_col: self.n_col,
322 data,
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for approximate float comparisons.
    const TOL: f32 = 1e-5;

    /// Returns true when `a` and `b` differ by less than `TOL`.
    fn close(a: f32, b: f32) -> bool {
        (a - b).abs() < TOL
    }

    /// Builds a matrix of the given shape and overwrites its storage with
    /// `values`.
    fn filled(n_row: usize, n_col: usize, values: Vec<f32>) -> Matrix {
        let mut m = Matrix::new(n_row, n_col);
        m.data = values;
        m
    }

    /// A fresh matrix reports the requested shape and owns rows * cols values.
    #[test]
    fn test_matrix_creation() {
        let m = Matrix::new(10, 5);
        assert_eq!(m.n_rows(), 10);
        assert_eq!(m.n_col(), 5);
        assert_eq!(m.data.len(), 50);
    }

    /// `zero` clears every element, even after random initialisation.
    #[test]
    fn test_matrix_zero() {
        let mut m = Matrix::new(5, 4);
        m.uniform(1.0, 123);
        m.zero();
        m.data.iter().for_each(|v| assert_eq!(*v, 0.0));
    }

    /// `uniform` stays within the bound and actually changes the data.
    #[test]
    fn test_matrix_uniform() {
        let mut m = Matrix::new(10, 10);
        m.uniform(1.0, 123);

        // Every sample lies within the requested bound.
        m.data.iter().for_each(|v| assert!(v.abs() <= 1.0));

        // At least one sample is non-zero.
        assert!(m.data.iter().any(|&v| v != 0.0));
    }

    /// The dot product of two rows matches the hand-computed value.
    #[test]
    fn test_dot_two_row() {
        let m = filled(
            3,
            4,
            vec![1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 1.0, 1.0, 1.0, 1.0],
        );

        let expected = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0;
        let got = m.dot_two_row(0, 1);
        assert!((got - expected).abs() < 1e-5);
    }

    /// `add_row` adds the scaled vector to the selected row.
    #[test]
    fn test_add_row() {
        let mut m = filled(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let ones = [1.0, 1.0, 1.0, 1.0];
        unsafe { m.add_row(ones.as_ptr(), 0, 2.0) };

        let want = [3.0, 4.0, 5.0, 6.0];
        for (idx, w) in want.iter().enumerate() {
            assert!(close(m.data[idx], *w));
        }
    }

    /// Row norms match the known values 5 and 3.
    #[test]
    fn test_norm() {
        let m = filled(2, 3, vec![3.0, 4.0, 0.0, 1.0, 2.0, 2.0]);

        assert!(close(m.norm(0), 5.0));
        assert!(close(m.norm(1), 3.0));
    }

    /// After `norm_self` every row has unit length.
    #[test]
    fn test_norm_self() {
        let mut m = filled(2, 3, vec![3.0, 4.0, 0.0, 1.0, 2.0, 2.0]);

        m.norm_self();

        for row in 0..2 {
            assert!(close(m.norm(row), 1.0));
        }
    }
}