Skip to main content

node2vec_rs/cpu/
matrix.rs

1use faer::Mat;
2use rand::rngs::StdRng;
3use rand::SeedableRng;
4use rand_distr::{Distribution, Uniform};
5use std::cell::UnsafeCell;
6
7use crate::cpu::simd::*;
8
/// Dense matrix stored in row-major format
///
/// The matrix data is stored as a flat vector where each row
/// is laid out contiguously in memory for cache efficiency.
/// This layout enables efficient SIMD operations on individual rows.
///
/// ### Fields
///
/// - `n_col` - Number of elements per row (dimension of embeddings)
/// - `n_row` - Number of rows
/// - `data` - Flat storage of matrix data (`n_row * n_col` elements); the
///   element at row `i`, column `j` lives at index `i * n_col + j`
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    n_col: usize,
    n_row: usize,
    data: Vec<f32>,
}
25
26/// Thread-safe wrapper around Matrix for concurrent training
27///
28/// Uses `UnsafeCell` to allow interior mutability across threads.
29/// Safety is ensured by the training algorithm's access patterns where
30/// each thread writes to distinct rows.
31///
32/// ### Fields
33///
34/// * `inner` - The inner matrix data (unsafe cell)
35#[derive(Debug)]
36pub struct MatrixWrapper {
37    pub inner: UnsafeCell<Matrix>,
38}
39
40/// SAFETY: This is intentionally unsound. Multiple threads will concurrently
41/// read and write overlapping rows (e.g. a target in one thread may be a
42/// negative sample in another). This mirrors the deliberate data race in
43/// Mikolov's original word2vec C implementation and the word2vec-rs crate
44/// it was ported from. SGD tolerates stale/torn reads and the resulting
45/// embeddings converge in practice. Do not use MatrixWrapper as a general-
46/// purpose concurrent container.
47unsafe impl Sync for MatrixWrapper {}
48
49impl Matrix {
50    /// Creates a new matrix initialised with zeros
51    ///
52    /// ### Params
53    ///
54    /// * `rows` - Number of rows in the matrix (typically vocabulary size)
55    /// * `n_col` - Number of elements per row (embedding dimension)
56    ///
57    /// ### Returns
58    ///
59    /// A new Matrix with all elements set to 0.0
60    pub fn new(n_row: usize, n_col: usize) -> Matrix {
61        Matrix {
62            data: vec![0f32; n_col * n_row],
63            n_col,
64            n_row,
65        }
66    }
67
68    /// Normalises all rows in the matrix to unit length
69    ///
70    /// Each row vector is divided by its L2 norm, making it a unit vector.
71    pub fn norm_self(&mut self) {
72        let num_rows = self.n_row;
73        for i in 0..num_rows {
74            let n = self.norm(i);
75            if n > 0.0 {
76                let start = i * self.n_col;
77                let end = start + self.n_col;
78                for j in start..end {
79                    self.data[j] /= n;
80                }
81            }
82        }
83    }
84
85    /// Wraps the matrix in a thread-safe wrapper
86    ///
87    /// ### Returns
88    ///
89    /// A `MatrixWrapper` that can be safely shared across threads using Arc
90    pub fn make_send(self) -> MatrixWrapper {
91        MatrixWrapper {
92            inner: UnsafeCell::new(self),
93        }
94    }
95
96    /// Initialises matrix with uniform random values
97    ///
98    /// ### Params
99    ///
100    /// * `bound` - Values will be sampled uniformly from [-bound, bound]
101    /// * `seed` - Seed for reproducibility.
102    pub fn uniform(&mut self, bound: f32, seed: usize) {
103        let between = Uniform::new(-bound, bound).unwrap();
104        let mut rng = StdRng::seed_from_u64(seed as u64);
105        for v in &mut self.data {
106            *v = between.sample(&mut rng);
107        }
108    }
109
110    /// Computes the L2 norm of a matrix row
111    ///
112    /// ### Params
113    ///
114    /// * `i` - Row index
115    ///
116    /// ### Returns
117    ///
118    /// The L2 norm (Euclidean length) of the row vector
119    pub fn norm(&self, i: usize) -> f32 {
120        let start = i * self.n_col;
121        let end = start + self.n_col;
122        norm_l2_simd(&self.data[start..end])
123    }
124
125    /// Sets all matrix elements to zero
126    #[inline(always)]
127    pub fn zero(&mut self) {
128        for v in self.data.iter_mut() {
129            *v = 0f32;
130        }
131    }
132
133    /// Adds a scaled vector to a matrix row (SAXPY operation)
134    ///
135    /// Performs: `row[i] = row[i] + mul * vec`
136    ///
137    /// This is used during gradient updates where we add scaled gradients
138    /// to embedding vectors.
139    ///
140    /// ### Params
141    ///
142    /// * `vec` - Pointer to the vector to add (must have `n_col` elements)
143    /// * `i` - Row index
144    /// * `mul` - Scaling factor for the vector
145    ///
146    /// ### Safety
147    ///
148    /// The caller must ensure `vec` points to at least `n_col` valid f32
149    /// elements.
150    #[inline(always)]
151    pub unsafe fn add_row(&mut self, vec: *const f32, i: usize, mul: f32) {
152        let start = i * self.n_col;
153        unsafe {
154            let row_slice =
155                std::slice::from_raw_parts_mut(self.data.as_mut_ptr().add(start), self.n_col);
156            let vec_slice = std::slice::from_raw_parts(vec, self.n_col);
157            saxpy_simd(row_slice, vec_slice, mul);
158        }
159    }
160
161    /// Computes dot product between a vector and a matrix row
162    ///
163    /// ### Params
164    ///
165    /// * `vec` - Pointer to the vector (must have `row_size` elements)
166    /// * `i` - Row index
167    ///
168    /// ### Returns
169    ///
170    /// The dot product result
171    ///
172    /// ### Safety
173    ///
174    /// The caller must ensure `vec` points to at least `row_size` valid f32
175    /// elements
176    #[inline(always)]
177    pub unsafe fn dot_row(&self, vec: *const f32, i: usize) -> f32 {
178        let start = i * self.n_col;
179        unsafe {
180            let row_slice = std::slice::from_raw_parts(self.data.as_ptr().add(start), self.n_col);
181            let vec_slice = std::slice::from_raw_parts(vec, self.n_col);
182            dot_simd(row_slice, vec_slice)
183        }
184    }
185
186    /// Computes dot product between two matrix rows
187    ///
188    /// ### Params
189    ///
190    /// * `i` - First row index
191    /// * `j` - Second row index
192    ///
193    /// ### Returns
194    ///
195    /// The dot product of row i and row j
196    #[inline(always)]
197    pub fn dot_two_row(&self, i: usize, j: usize) -> f32 {
198        let start_i = i * self.n_col;
199        let start_j = j * self.n_col;
200        unsafe {
201            let row_i = std::slice::from_raw_parts(self.data.as_ptr().add(start_i), self.n_col);
202            let row_j = std::slice::from_raw_parts(self.data.as_ptr().add(start_j), self.n_col);
203            dot_simd(row_i, row_j)
204        }
205    }
206
207    /// Gets a mutable pointer to a matrix row
208    ///
209    /// ### Params
210    ///
211    /// * `i` - Row index
212    ///
213    /// ### Returns
214    ///
215    /// Mutable pointer to the start of row i
216    ///
217    /// ### Safety
218    ///
219    /// The caller must ensure the pointer is used correctly and doesn't
220    /// create aliasing issues. Primarily used for passing to SIMD functions.
221    #[inline(always)]
222    pub fn get_row(&mut self, i: usize) -> *mut f32 {
223        unsafe { self.data.as_mut_ptr().add(i * self.n_col) }
224    }
225
226    /// Returns a shared slice of row i
227    ///
228    /// ### Params
229    ///
230    /// * `i` - Row index
231    ///
232    /// ### Returns
233    ///
234    /// Slice of the row data
235    #[inline(always)]
236    pub fn row_as_slice(&self, i: usize) -> &[f32] {
237        let start = i * self.n_col;
238        &self.data[start..start + self.n_col]
239    }
240
241    /// Gets a const pointer to a matrix row
242    ///
243    /// ### Params
244    ///
245    /// * `i` - Row index
246    ///
247    /// ### Returns
248    ///
249    /// Const pointer to the start of row i
250    #[inline(always)]
251    pub fn get_row_unmod(&self, i: usize) -> *const f32 {
252        unsafe { self.data.as_ptr().add(i * self.n_col) }
253    }
254
255    /// Returns the number of elements per row
256    #[inline(always)]
257    pub fn n_col(&self) -> usize {
258        self.n_col
259    }
260
261    /// Returns the total number of rows
262    #[inline(always)]
263    pub fn n_rows(&self) -> usize {
264        self.n_row
265    }
266
267    /// Converts the matrix to a Faer matrix
268    ///
269    /// ### Returns
270    ///
271    /// Faer matrix
272    pub fn to_faer(&self) -> Mat<f32> {
273        Mat::from_fn(self.n_row, self.n_col, |i, j| self.data[i * self.n_col + j])
274    }
275
276    /// Write the matrix rows to a CSV file
277    ///
278    /// ### Params
279    ///
280    /// * `path` - Path to the output CSV
281    pub fn write_csv(&self, path: &str) -> std::io::Result<()> {
282        use std::io::Write;
283        let mut file = std::fs::File::create(path)?;
284        for i in 0..self.n_row {
285            let start = i * self.n_col;
286            let end = start + self.n_col;
287            let line = self.data[start..end]
288                .iter()
289                .map(|v| v.to_string())
290                .collect::<Vec<_>>()
291                .join(",");
292            writeln!(file, "{}", line)?;
293        }
294        Ok(())
295    }
296
297    /// Compute the element-wise average of two matrices
298    ///
299    /// ### Params
300    ///
301    /// * `other` - The other matrix (must have same dimensions)
302    ///
303    /// ### Returns
304    ///
305    /// A new matrix where each element is (self + other) / 2
306    ///
307    /// ### Panics
308    ///
309    /// Panics if dimensions do not match.
310    pub fn average_with(&self, other: &Matrix) -> Matrix {
311        assert_eq!(self.n_row, other.n_row, "Row count mismatch");
312        assert_eq!(self.n_col, other.n_col, "Column count mismatch");
313        let data = self
314            .data
315            .iter()
316            .zip(other.data.iter())
317            .map(|(a, b)| (a + b) * 0.5)
318            .collect();
319        Matrix {
320            n_row: self.n_row,
321            n_col: self.n_col,
322            data,
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOL: f32 = 1e-5;

    /// Builds a matrix directly from its shape and row-major contents.
    fn matrix_from(n_row: usize, n_col: usize, data: Vec<f32>) -> Matrix {
        Matrix { n_row, n_col, data }
    }

    /// Returns true when `actual` is within `TOL` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_matrix_creation() {
        let m = Matrix::new(10, 5);
        assert_eq!(m.n_rows(), 10);
        assert_eq!(m.n_col(), 5);
        assert_eq!(m.data.len(), 50);
    }

    #[test]
    fn test_matrix_zero() {
        let mut m = Matrix::new(5, 4);
        // Fill with non-zero noise first so that zeroing is observable.
        m.uniform(1.0, 123);
        m.zero();
        assert!(m.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_matrix_uniform() {
        let mut m = Matrix::new(10, 10);
        m.uniform(1.0, 123);

        // Every sample stays within the requested bound ...
        assert!(m.data.iter().all(|v| v.abs() <= 1.0));
        // ... and the matrix is not left at its zero initialisation.
        assert!(m.data.iter().any(|&v| v != 0.0));
    }

    #[test]
    fn test_dot_two_row() {
        let m = matrix_from(
            3,
            4,
            vec![1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 1.0, 1.0, 1.0, 1.0],
        );

        let expected = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0;
        assert!(close(m.dot_two_row(0, 1), expected));
    }

    #[test]
    fn test_add_row() {
        let mut m = matrix_from(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let ones = [1.0, 1.0, 1.0, 1.0];
        unsafe { m.add_row(ones.as_ptr(), 0, 2.0) };

        // Row 0 gains 2.0 * 1.0 in every column.
        for (actual, expected) in m.data[..4].iter().zip([3.0, 4.0, 5.0, 6.0]) {
            assert!(close(*actual, expected));
        }
    }

    #[test]
    fn test_norm() {
        let m = matrix_from(2, 3, vec![3.0, 4.0, 0.0, 1.0, 2.0, 2.0]);

        assert!(close(m.norm(0), 5.0));
        assert!(close(m.norm(1), 3.0));
    }

    #[test]
    fn test_norm_self() {
        let mut m = matrix_from(2, 3, vec![3.0, 4.0, 0.0, 1.0, 2.0, 2.0]);

        m.norm_self();

        // After normalisation every row is a unit vector.
        for row in 0..2 {
            assert!(close(m.norm(row), 1.0));
        }
    }
}