single_svdlib/lanczos/masked.rs

use crate::{determine_chunk_size, SMat};
use nalgebra_sparse::na::{DMatrix, DVector};
use nalgebra_sparse::CsrMatrix;
use num_traits::Float;
use rayon::prelude::*;
use std::fmt::Debug;
use std::ops::AddAssign;

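/// A read-only, zero-copy view of a [`CsrMatrix`] restricted to a subset of
/// its columns. Selected columns are renumbered densely, so the view acts as
/// an `nrows x k` operator, where `k` is the number of selected columns.
///
/// A minimal usage sketch (marked `ignore` because the exact import path
/// depends on how the crate exposes this module):
///
/// ```ignore
/// use nalgebra_sparse::{CooMatrix, CsrMatrix};
///
/// let mut coo = CooMatrix::<f64>::new(2, 4);
/// coo.push(0, 0, 1.0);
/// coo.push(1, 3, 2.0);
/// let csr = CsrMatrix::from(&coo);
///
/// // View only columns 0 and 3; the view behaves as a 2 x 2 operator.
/// let view = MaskedCSRMatrix::with_columns(&csr, &[0, 3]);
/// ```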
pub struct MaskedCSRMatrix<'a, T: Float> {
    matrix: &'a CsrMatrix<T>,
    column_mask: Vec<bool>,
    masked_to_original: Vec<usize>,
    original_to_masked: Vec<Option<usize>>,
}

impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
    pub fn new(matrix: &'a CsrMatrix<T>, column_mask: Vec<bool>) -> Self {
        assert_eq!(
            column_mask.len(),
            matrix.ncols(),
            "Column mask must have the same length as the number of columns in the matrix"
        );

        let mut masked_to_original = Vec::new();
        let mut original_to_masked = vec![None; column_mask.len()];
        let mut masked_index = 0;

        for (i, &is_included) in column_mask.iter().enumerate() {
            if is_included {
                masked_to_original.push(i);
                original_to_masked[i] = Some(masked_index);
                masked_index += 1;
            }
        }

        Self {
            matrix,
            column_mask,
            masked_to_original,
            original_to_masked,
        }
    }

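    /// Convenience constructor: builds the boolean mask from an explicit
    /// list of column indices. Panics if any index is out of bounds;
    /// duplicate indices are harmless, since they only set the same mask
    /// entry twice.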
    pub fn with_columns(matrix: &'a CsrMatrix<T>, columns: &[usize]) -> Self {
        let mut mask = vec![false; matrix.ncols()];
        for &col in columns {
            assert!(col < matrix.ncols(), "Column index out of bounds");
            mask[col] = true;
        }
        Self::new(matrix, mask)
    }

    pub fn uses_all_columns(&self) -> bool {
        self.masked_to_original.len() == self.matrix.ncols() && self.column_mask.iter().all(|&x| x)
    }

    pub fn ensure_identical_results_mode(&self) -> bool {
        // For very small matrices where precision is critical
        let is_small_matrix = self.matrix.nrows() <= 5 && self.matrix.ncols() <= 5;
        is_small_matrix && self.uses_all_columns()
    }
}

impl<
        T: Float
            + AddAssign
            + Sync
            + Send
            + std::ops::MulAssign
            + Debug
            + 'static
            + std::iter::Sum
            + std::ops::SubAssign
            + num_traits::FromPrimitive,
    > SMat<T> for MaskedCSRMatrix<'_, T>
{
    fn nrows(&self) -> usize {
        self.matrix.nrows()
    }

    fn ncols(&self) -> usize {
        self.masked_to_original.len()
    }

    fn nnz(&self) -> usize {
        let (major_offsets, minor_indices, _) = self.matrix.csr_data();
        let mut count = 0;

        for i in 0..self.matrix.nrows() {
            for j in major_offsets[i]..major_offsets[i + 1] {
                let col = minor_indices[j];
                if self.column_mask[col] {
                    count += 1;
                }
            }
        }
        count
    }

    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed {
            self.ncols()
        } else {
            self.nrows()
        };
        let ncols = if transposed {
            self.nrows()
        } else {
            self.ncols()
        };

        assert_eq!(
            x.len(),
            ncols,
            "svd_opa: x must match op(A).ncols(), x = {}, op(A).ncols = {}",
            x.len(),
            ncols
        );
        assert_eq!(
            y.len(),
            nrows,
            "svd_opa: y must match op(A).nrows(), y = {}, op(A).nrows = {}",
            y.len(),
            nrows
        );

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        if self.uses_all_columns() {
            // Fast path: the mask selects every column, so the view is the
            // underlying matrix and we can delegate directly. This path must
            // only be taken when no column is masked out: x and y are sized
            // for the masked dimensions, so delegating for a genuinely
            // masked view would hand the inner matrix mis-sized buffers.
            self.matrix.svd_opa(x, y, transposed);
            return;
        }

        y.fill(T::zero());

        if !transposed {
            // y = A * x over the masked columns

            // The original-to-masked map already has one entry per original
            // column, so it can be used directly as the lookup table.
            let valid_indices = &self.original_to_masked;

            // Parallelization parameters
            let rows = self.matrix.nrows();
            let chunk_size = std::cmp::max(16, rows / (rayon::current_num_threads() * 2));

            // Process in parallel chunks
            y.par_chunks_mut(chunk_size)
                .enumerate()
                .for_each(|(chunk_idx, y_chunk)| {
                    let start_row = chunk_idx * chunk_size;
                    let end_row = (start_row + y_chunk.len()).min(rows);

                    for i in start_row..end_row {
                        let row_idx = i - start_row;
                        let mut sum = T::zero();

                        let row_start = major_offsets[i];
                        let row_end = major_offsets[i + 1];

                        // Unroll the loop by 4 for better instruction-level parallelism
                        let mut j = row_start;
                        while j + 4 <= row_end {
                            for offset in 0..4 {
                                let idx = j + offset;
                                let col = minor_indices[idx];
                                if let Some(masked_col) = valid_indices[col] {
                                    sum += values[idx] * x[masked_col];
                                }
                            }
                            j += 4;
                        }

                        // Handle remaining elements
                        while j < row_end {
                            let col = minor_indices[j];
                            if let Some(masked_col) = valid_indices[col] {
                                sum += values[j] * x[masked_col];
                            }
                            j += 1;
                        }

                        y_chunk[row_idx] = sum;
                    }
                });
        } else {
            // y = A^T * x over the masked columns
            let nrows = self.matrix.nrows();
            let chunk_size = determine_chunk_size(nrows);

            // Create thread-local partial results and combine at the end
            let results: Vec<Vec<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);
                    let mut local_y = vec![T::zero(); y.len()];

                    // Process a chunk of rows
                    for i in start..end {
                        let row_val = x[i];
                        if row_val.is_zero() {
                            continue; // Skip zero entries of x
                        }

                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            if let Some(masked_col) = self.original_to_masked[col] {
                                local_y[masked_col] += values[j] * row_val;
                            }
                        }
                    }
                    local_y
                })
                .collect();

            // Combine results, touching only non-zero entries to reduce
            // memory traffic
            for local_y in results {
                for (idx, &val) in local_y.iter().enumerate() {
                    if !val.is_zero() {
                        y[idx] += val;
                    }
                }
            }
        }
    }

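    /// Per-column means of the masked view. Entry `j` is the mean of
    /// original column `masked_to_original[j]`, averaged over all `nrows()`
    /// rows, so implicit zeros count toward the denominator.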
    fn compute_column_means(&self) -> Vec<T> {
        let rows = self.nrows();
        let masked_cols = self.ncols();
        let row_count_recip = T::one() / T::from(rows).unwrap();

        let mut col_sums = vec![T::zero(); masked_cols];
        let (row_offsets, col_indices, values) = self.matrix.csr_data();

        for i in 0..rows {
            for j in row_offsets[i]..row_offsets[i + 1] {
                let original_col = col_indices[j];
                if let Some(masked_col) = self.original_to_masked[original_col] {
                    col_sums[masked_col] += values[j];
                }
            }
        }

        // Convert sums to means
        for j in 0..masked_cols {
            col_sums[j] *= row_count_recip;
        }

        col_sums
    }

    fn multiply_with_dense(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
    ) {
        let m_rows = if transpose_self {
            self.ncols()
        } else {
            self.nrows()
        };
        let m_cols = if transpose_self {
            self.nrows()
        } else {
            self.ncols()
        };

        assert_eq!(
            dense.nrows(),
            m_cols,
            "Dense matrix has incompatible row count"
        );
        assert_eq!(
            result.nrows(),
            m_rows,
            "Result matrix has incompatible row count"
        );
        assert_eq!(
            result.ncols(),
            dense.ncols(),
            "Result matrix has incompatible column count"
        );

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        if !transpose_self {
            let rows = self.matrix.nrows();
            let dense_cols = dense.ncols();

            // Pre-filter valid column mappings to avoid repeated lookups
            let valid_cols: Vec<Option<usize>> = (0..self.matrix.ncols())
                .map(|col| self.original_to_masked.get(col).copied().flatten())
                .collect();

            // Compute results in parallel, then apply to result matrix
            let row_results: Vec<(usize, Vec<T>)> = (0..rows)
                .into_par_iter()
                .map(|row| {
                    let mut row_result = vec![T::zero(); dense_cols];

                    // Process the sparse row with a blocked inner loop
                    let row_start = major_offsets[row];
                    let row_end = major_offsets[row + 1];

                    // Unroll the sparse-element loop by 4 for better ILP
                    let mut j = row_start;
                    while j + 4 <= row_end {
                        // Process 4 sparse elements at once
                        for offset in 0..4 {
                            let idx = j + offset;
                            let col = minor_indices[idx];
                            if let Some(masked_col) = valid_cols[col] {
                                let val = values[idx];

                                // Dense column update
                                for c in 0..dense_cols {
                                    row_result[c] += val * dense[(masked_col, c)];
                                }
                            }
                        }
                        j += 4;
                    }

                    // Handle remaining elements
                    while j < row_end {
                        let col = minor_indices[j];
                        if let Some(masked_col) = valid_cols[col] {
                            let val = values[j];

                            for c in 0..dense_cols {
                                row_result[c] += val * dense[(masked_col, c)];
                            }
                        }
                        j += 1;
                    }

                    (row, row_result)
                })
                .collect();

            // Apply results to output matrix
            for (row, row_values) in row_results {
                for c in 0..dense_cols {
                    result[(row, c)] = row_values[c];
                }
            }
        } else {
            let nrows = self.matrix.nrows();
            let ncols = self.ncols();
            let dense_cols = dense.ncols();

            // Clear result matrix once at the beginning
            result.fill(T::zero());

            // Pre-filter valid column mappings
            let valid_cols: Vec<Option<usize>> = (0..self.matrix.ncols())
                .map(|col| self.original_to_masked.get(col).copied().flatten())
                .collect();

            let chunk_size = determine_chunk_size(nrows);

            // Accumulate into thread-local buffers, then reduce
            let partial_results: Vec<Vec<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);

                    // Use a flat row-major buffer for better cache behavior
                    let mut local_result = vec![T::zero(); ncols * dense_cols];

                    // Scratch buffer for one row of `dense`. DMatrix storage
                    // is column-major, so a raw row slice is not contiguous;
                    // gather the row explicitly instead of offsetting a raw
                    // pointer.
                    let mut dense_row = vec![T::zero(); dense_cols];

                    for i in start..end {
                        for c in 0..dense_cols {
                            dense_row[c] = dense[(i, c)];
                        }

                        let row_start = major_offsets[i];
                        let row_end = major_offsets[i + 1];

                        // Process sparse elements in blocks of 8 for better vectorization
                        let mut j = row_start;
                        while j + 8 <= row_end {
                            for offset in 0..8 {
                                let idx = j + offset;
                                let col = minor_indices[idx];
                                if let Some(masked_col) = valid_cols[col] {
                                    let val = values[idx];
                                    let base_offset = masked_col * dense_cols;

                                    // Dense update with manual loop unrolling
                                    let mut c = 0;
                                    while c + 4 <= dense_cols {
                                        local_result[base_offset + c] += val * dense_row[c];
                                        local_result[base_offset + c + 1] += val * dense_row[c + 1];
                                        local_result[base_offset + c + 2] += val * dense_row[c + 2];
                                        local_result[base_offset + c + 3] += val * dense_row[c + 3];
                                        c += 4;
                                    }

                                    // Handle remaining columns
                                    while c < dense_cols {
                                        local_result[base_offset + c] += val * dense_row[c];
                                        c += 1;
                                    }
                                }
                            }
                            j += 8;
                        }

                        // Handle remaining sparse elements
                        while j < row_end {
                            let col = minor_indices[j];
                            if let Some(masked_col) = valid_cols[col] {
                                let val = values[j];
                                let base_offset = masked_col * dense_cols;

                                for c in 0..dense_cols {
                                    local_result[base_offset + c] += val * dense_row[c];
                                }
                            }
                            j += 1;
                        }
                    }

                    local_result
                })
                .collect();

            // Efficient reduction with blocked memory access
            const BLOCK_SIZE: usize = 64;
            for local_result in partial_results {
                // Process in blocks for better cache performance
                for r_block in (0..ncols).step_by(BLOCK_SIZE) {
                    let r_end = (r_block + BLOCK_SIZE).min(ncols);

                    for c_block in (0..dense_cols).step_by(BLOCK_SIZE) {
                        let c_end = (c_block + BLOCK_SIZE).min(dense_cols);

                        // Update result block
                        for r in r_block..r_end {
                            for c in c_block..c_end {
                                let val = local_result[r * dense_cols + c];
                                if !val.is_zero() {
                                    result[(r, c)] += val;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    fn multiply_with_dense_centered(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
        means: &DVector<T>,
    ) {
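        // Centering identity used by both branches below: with per-column
        // means m (in masked order) and the all-ones vector 1,
        //
        //   (A - 1 m^T)   * D = A D   - 1 (m^T D)    (transpose_self = false)
        //   (A - 1 m^T)^T * D = A^T D - m (1^T D)    (transpose_self = true)
        //
        // so centering reduces to a rank-1 correction and the centered
        // matrix is never materialized.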
        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        let dense_cols = dense.ncols();
        let dense_rows = dense.nrows();

        if !transpose_self {
            let rows = self.matrix.nrows();

            // Rank-1 correction (see the identity above): in column c, every
            // output row has sum_k means[k] * dense[(k, c)] subtracted, where
            // k runs over masked columns. `means` is indexed by masked
            // column, consistent with the transposed branch below.
            let mean_adjustments: Vec<T> = (0..dense_cols)
                .map(|c| {
                    means
                        .iter()
                        .enumerate()
                        .map(|(masked_col, &mean_val)| mean_val * dense[(masked_col, c)])
                        .sum()
                })
                .collect();

            let row_updates: Vec<(usize, Vec<T>)> = (0..rows)
                .into_par_iter()
                .map(|row| {
                    let mut row_result = vec![T::zero(); dense_cols];

                    for j in major_offsets[row]..major_offsets[row + 1] {
                        let col = minor_indices[j];
                        if let Some(masked_col) = self.original_to_masked[col] {
                            let val = values[j];

                            for c in 0..dense_cols {
                                row_result[c] += val * dense[(masked_col, c)];
                            }
                        }
                    }

                    for c in 0..dense_cols {
                        row_result[c] -= mean_adjustments[c];
                    }

                    (row, row_result)
                })
                .collect();

            for (row, row_values) in row_updates {
                for c in 0..dense_cols {
                    result[(row, c)] = row_values[c];
                }
            }
        } else {
            let nrows = self.matrix.nrows();
            let ncols = self.ncols();

            // Column sums of the dense factor (1^T D), consumed by the
            // per-chunk mean correction below
            let col_sums: Vec<T> = (0..dense_cols)
                .into_par_iter()
                .map(|c| (0..dense_rows).map(|i| dense[(i, c)]).sum())
                .collect();

            // Clear the result matrix first
            result.fill(T::zero());

            // Choose a chunk size for the parallel sweep
            let chunk_size = determine_chunk_size(nrows);

            // Compute partial results in parallel
            let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = std::cmp::min(start + chunk_size, nrows);

                    let mut local_result = DMatrix::<T>::zeros(ncols, dense_cols);

                    for i in start..end {
                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            if let Some(masked_col) = self.original_to_masked[col] {
                                let sparse_val = values[j];

                                for c in 0..dense_cols {
                                    local_result[(masked_col, c)] += sparse_val * dense[(i, c)];
                                }
                            }
                        }
                    }

                    // Apply this chunk's share of the mean correction
                    let chunk_fraction =
                        T::from_f64((end - start) as f64 / dense_rows as f64).unwrap();

                    for masked_col in 0..ncols {
                        if masked_col < means.len() {
                            let mean = means[masked_col];
                            for c in 0..dense_cols {
                                local_result[(masked_col, c)] -=
                                    mean * col_sums[c] * chunk_fraction;
                            }
                        }
                    }

                    local_result
                })
                .collect();

            for local_result in partial_results {
                const BLOCK_SIZE: usize = 32;

                for r_block in 0..ncols.div_ceil(BLOCK_SIZE) {
                    let r_start = r_block * BLOCK_SIZE;
                    let r_end = std::cmp::min(r_start + BLOCK_SIZE, ncols);

                    for c_block in 0..dense_cols.div_ceil(BLOCK_SIZE) {
                        let c_start = c_block * BLOCK_SIZE;
                        let c_end = std::cmp::min(c_start + BLOCK_SIZE, dense_cols);

                        for r in r_start..r_end {
                            for c in c_start..c_end {
                                result[(r, c)] += local_result[(r, c)];
                            }
                        }
                    }
                }
            }
        }
    }

    fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
        let q_rows = q.nrows();
        let q_cols = q.ncols();
        let masked_cols = self.ncols();

        assert_eq!(
            q_rows,
            self.nrows(),
            "Q matrix has incompatible row count: expected {}, got {}",
            self.nrows(),
            q_rows
        );
        assert_eq!(
            result.nrows(),
            q_cols,
            "Result matrix has incompatible row count: expected {}, got {}",
            q_cols,
            result.nrows()
        );
        assert_eq!(
            result.ncols(),
            masked_cols,
            "Result matrix has incompatible column count: expected {}, got {}",
            masked_cols,
            result.ncols()
        );

        // Clear result matrix
        result.fill(T::zero());

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();
        let nrows = self.matrix.nrows();
        let chunk_size = determine_chunk_size(nrows);

        if self.uses_all_columns() && (nrows < 1000 && self.matrix.ncols() < 1000) {
            // Fast path for small unmasked matrices: no mask lookup needed
            let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);
                    let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);

                    for row in start..end {
                        // Process all non-zeros in this row
                        for idx in major_offsets[row]..major_offsets[row + 1] {
                            let col = minor_indices[idx];
                            let sparse_val = values[idx];

                            // Accumulate: local_result[q_col, col] += q[row, q_col] * sparse_val
                            for q_col in 0..q_cols {
                                local_result[(q_col, col)] += q[(row, q_col)] * sparse_val;
                            }
                        }
                    }

                    local_result
                })
                .collect();

            // Combine partial results efficiently
            for local_result in partial_results {
                for r in 0..q_cols {
                    for c in 0..masked_cols {
                        let val = local_result[(r, c)];
                        if !val.is_zero() {
                            result[(r, c)] += val;
                        }
                    }
                }
            }
        } else {
            // General path for masked matrices
            let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);
                    let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);

                    for row in start..end {
                        // Process all non-zeros in this row
                        for idx in major_offsets[row]..major_offsets[row + 1] {
                            let original_col = minor_indices[idx];

                            // Check if this column is in the mask
                            if let Some(masked_col) = self.original_to_masked[original_col] {
                                let sparse_val = values[idx];

                                // Accumulate: local_result[q_col, masked_col] += q[row, q_col] * sparse_val
                                for q_col in 0..q_cols {
                                    local_result[(q_col, masked_col)] +=
                                        q[(row, q_col)] * sparse_val;
                                }
                            }
                        }
                    }

                    local_result
                })
                .collect();

            // Combine partial results efficiently
            for local_result in partial_results {
                for r in 0..q_cols {
                    for c in 0..masked_cols {
                        let val = local_result[(r, c)];
                        if !val.is_zero() {
                            result[(r, c)] += val;
                        }
                    }
                }
            }
        }
    }

    fn multiply_transposed_by_dense_centered(
        &self,
        q: &DMatrix<T>,
        result: &mut DMatrix<T>,
        means: &DVector<T>,
    ) {
        let q_rows = q.nrows();
        let q_cols = q.ncols();
        let masked_cols = self.ncols();

        assert_eq!(
            q_rows,
            self.nrows(),
            "Q matrix has incompatible row count: expected {}, got {}",
            self.nrows(),
            q_rows
        );
        assert_eq!(
            result.nrows(),
            q_cols,
            "Result matrix has incompatible row count: expected {}, got {}",
            q_cols,
            result.nrows()
        );
        assert_eq!(
            result.ncols(),
            masked_cols,
            "Result matrix has incompatible column count: expected {}, got {}",
            masked_cols,
            result.ncols()
        );
        assert_eq!(
            means.len(),
            masked_cols,
            "Means vector has incompatible length: expected {}, got {}",
            masked_cols,
            means.len()
        );

        // Clear result matrix
        result.fill(T::zero());

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

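        // Same rank-1 centering trick as in multiply_with_dense_centered:
        // Q^T (A - 1 m^T) = Q^T A - (Q^T 1) m^T, so only the column sums of
        // Q and the means are needed to apply the correction.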
        // Column sums of Q, i.e. Q^T 1
        let q_col_sums: Vec<T> = (0..q_cols)
            .into_par_iter()
            .map(|col| (0..q_rows).map(|row| q[(row, col)]).sum())
            .collect();

        let nrows = self.matrix.nrows();
        let chunk_size = determine_chunk_size(nrows);

        // Process sparse matrix rows in chunks, as in the transpose_self
        // branch of multiply_with_dense_centered
        let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
            .into_par_iter()
            .map(|chunk_idx| {
                let start = chunk_idx * chunk_size;
                let end = std::cmp::min(start + chunk_size, nrows);

                let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);

                for row in start..end {
                    // Process all non-zeros in this row
                    for idx in major_offsets[row]..major_offsets[row + 1] {
                        let original_col = minor_indices[idx];

                        // Check if this column is in the mask
                        if let Some(masked_col) = self.original_to_masked[original_col] {
                            let sparse_val = values[idx];

                            // Accumulate: local_result[q_col, masked_col] += q[row, q_col] * sparse_val
                            for q_col in 0..q_cols {
                                local_result[(q_col, masked_col)] += q[(row, q_col)] * sparse_val;
                            }
                        }
                    }
                }

                // Apply this chunk's share of the mean correction
                // ((end - start) / q_rows of the total)
                let chunk_fraction = T::from_f64((end - start) as f64 / q_rows as f64).unwrap();

                for q_col in 0..q_cols {
                    let q_sum = q_col_sums[q_col];
                    for masked_col in 0..masked_cols {
                        local_result[(q_col, masked_col)] -=
                            q_sum * means[masked_col] * chunk_fraction;
                    }
                }

                local_result
            })
            .collect();

        // Combine partial results with block-wise writes for cache locality
        for local_result in partial_results {
            const BLOCK_SIZE: usize = 64;

            for r_block in 0..q_cols.div_ceil(BLOCK_SIZE) {
                let r_start = r_block * BLOCK_SIZE;
                let r_end = std::cmp::min(r_start + BLOCK_SIZE, q_cols);

                for c_block in 0..masked_cols.div_ceil(BLOCK_SIZE) {
                    let c_start = c_block * BLOCK_SIZE;
                    let c_end = std::cmp::min(c_start + BLOCK_SIZE, masked_cols);

                    for r in r_start..r_end {
                        for c in c_start..c_end {
                            result[(r, c)] += local_result[(r, c)];
                        }
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SMat;
    use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    #[test]
    fn test_masked_matrix() {
        // Create a test matrix
        let mut coo = CooMatrix::<f64>::new(3, 5);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(0, 4, 3.0);
        coo.push(1, 1, 4.0);
        coo.push(1, 3, 5.0);
        coo.push(2, 0, 6.0);
        coo.push(2, 2, 7.0);
        coo.push(2, 4, 8.0);

        let csr = CsrMatrix::from(&coo);

        // Create a masked view of columns 0, 2, 4
        let columns = vec![0, 2, 4];
        let masked = MaskedCSRMatrix::with_columns(&csr, &columns);

        // Check dimensions
        assert_eq!(masked.nrows(), 3);
        assert_eq!(masked.ncols(), 3);
        assert_eq!(masked.nnz(), 6); // Only entries in the selected columns

        // Test SVD on the masked matrix
        let svd_result = crate::lanczos::svd(&masked);
        assert!(svd_result.is_ok());
    }

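    // Not part of the original suite: a small hand-checked sanity test that
    // the masked operator's mat-vec (svd_opa) matches the sub-matrix it
    // represents.
    #[test]
    fn test_masked_matvec() {
        let mut coo = CooMatrix::<f64>::new(3, 5);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(1, 1, 4.0);
        coo.push(2, 4, 8.0);
        let csr = CsrMatrix::from(&coo);

        // Masked view of columns {0, 2, 4}:
        // [1 2 0]
        // [0 0 0]   (column 1 is masked out)
        // [0 0 8]
        let masked = MaskedCSRMatrix::with_columns(&csr, &[0, 2, 4]);

        // y = A * x with x = [1, 1, 1]: the row sums of the masked matrix
        let x = [1.0, 1.0, 1.0];
        let mut y = [0.0; 3];
        masked.svd_opa(&x, &mut y, false);
        assert_eq!(y, [3.0, 0.0, 8.0]);

        // y = A^T * x with x = [1, 1, 1]: the column sums of the masked matrix
        let mut yt = [0.0; 3];
        masked.svd_opa(&x, &mut yt, true);
        assert_eq!(yt, [1.0, 2.0, 8.0]);
    }
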
    #[test]
    fn test_masked_vs_physical_subset() {
        // Use a fixed seed for reproducible tests
        let mut rng = StdRng::seed_from_u64(42);

        // Generate a random 14x10 matrix
        let nrows = 14;
        let ncols = 10;
        let nnz = 40; // Number of entries pushed (duplicates are summed on CSR conversion)

        let mut coo = CooMatrix::<f64>::new(nrows, ncols);

        // Fill with random non-zero values
        for _ in 0..nnz {
            let row = rng.gen_range(0..nrows);
            let col = rng.gen_range(0..ncols);
            let val = rng.gen_range(0.1..10.0);

            // Note: CooMatrix allows duplicate positions; they are summed
            // when converting to CSR, identically for both copies below
            coo.push(row, col, val);
        }

        // Convert to CSR, which is what the masked implementation uses
        let csr = CsrMatrix::from(&coo);

        // Select a subset of columns (columns 1, 3, 5, 7)
        let selected_columns = vec![1, 3, 5, 7];

        // Create the masked matrix view
        let masked_matrix = MaskedCSRMatrix::with_columns(&csr, &selected_columns);

        // Create a physical copy with just those columns
        let mut physical_subset = CooMatrix::<f64>::new(nrows, selected_columns.len());

        // Map original column indices to new column indices
        let col_map: std::collections::HashMap<usize, usize> = selected_columns
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        // Copy the values for the selected columns
        for (row, col, val) in coo.triplet_iter() {
            if let Some(&new_col) = col_map.get(&col) {
                physical_subset.push(row, new_col, *val);
            }
        }

        // Convert to CSR for SVD
        let physical_csr = CsrMatrix::from(&physical_subset);

        // Compare dimensions and nnz
        assert_eq!(masked_matrix.nrows(), physical_csr.nrows());
        assert_eq!(masked_matrix.ncols(), physical_csr.ncols());
        assert_eq!(masked_matrix.nnz(), physical_csr.nnz());

        // Perform SVD on both
        let svd_masked = crate::lanczos::svd(&masked_matrix).unwrap();
        let svd_physical = crate::lanczos::svd(&physical_csr).unwrap();

        // Compare SVD results: they should agree to a tight tolerance,
        // though not necessarily bit-for-bit, since floating-point
        // evaluation order may differ between the two paths

        // Check dimension (rank)
        assert_eq!(svd_masked.d, svd_physical.d);

        // Basic tolerance for floating point comparisons
        let epsilon = 1e-10;

        // Check singular values (possibly in different order, so sort them)
        let mut masked_s = svd_masked.s.to_vec();
        let mut physical_s = svd_physical.s.to_vec();
        masked_s.sort_by(|a, b| b.partial_cmp(a).unwrap()); // Sort in descending order
        physical_s.sort_by(|a, b| b.partial_cmp(a).unwrap());

        for (m, p) in masked_s.iter().zip(physical_s.iter()) {
            assert!(
                (m - p).abs() < epsilon,
                "Singular values differ: {} vs {}",
                m,
                p
            );
        }

        // Note: comparing singular vectors is more involved due to possible
        // sign flips and ordering differences, so that level of detailed
        // comparison is skipped here
    }
}
1001}