// scirs2_sparse/distributed/partition.rs
use std::collections::HashSet;

use crate::csr::CsrMatrix;
use crate::error::{SparseError, SparseResult};

/// Strategy used to assign matrix rows to workers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum PartitionMethod {
    /// Contiguous row ranges: worker `w` owns rows `[w*n/p, (w+1)*n/p)`.
    #[default]
    Contiguous,
    /// Row `r` is assigned to worker `r % p`.
    RoundRobin,
    /// Graph-based partitioning; currently handled identically to
    /// `Contiguous` in `partition_rows` (fallback, not yet implemented).
    GraphBased,
}
29
/// Configuration controlling how rows are partitioned across workers.
#[derive(Debug, Clone)]
pub struct PartitionConfig {
    /// Number of workers to partition across (treated as at least 1).
    pub n_workers: usize,
    /// Number of overlap rows between neighboring partitions.
    // NOTE(review): `overlap` is not consumed by any function in this
    // file — confirm where (or whether) it is honored before relying on it.
    pub overlap: usize,
    /// Partitioning strategy to use.
    pub method: PartitionMethod,
}
45
46impl Default for PartitionConfig {
47 fn default() -> Self {
48 Self {
49 n_workers: 4,
50 overlap: 0,
51 method: PartitionMethod::Contiguous,
52 }
53 }
54}
55
/// The set of global row indices owned by a single worker.
#[derive(Debug, Clone)]
pub struct RowPartition {
    /// Identifier of the worker that owns these rows.
    pub worker_id: usize,
    /// Global indices of the rows owned by this worker.
    pub local_rows: Vec<usize>,
    /// Total number of rows in the global matrix.
    pub n_global_rows: usize,
}
70
impl RowPartition {
    /// Number of rows owned by this worker.
    #[inline]
    pub fn n_local(&self) -> usize {
        self.local_rows.len()
    }
}
78
/// One worker's slice of a distributed CSR matrix.
#[derive(Debug, Clone)]
pub struct DistributedCsr {
    /// Local CSR block: the owned rows re-indexed to `0..n_local`,
    /// with column indices kept in the global index space.
    pub local_matrix: CsrMatrix<f64>,
    /// The row partition this block was built from.
    pub partition: RowPartition,
    /// Sorted global row indices that local columns reference but this
    /// worker does not own (data that must come from other workers).
    pub ghost_rows: Vec<usize>,
}
99
100pub fn partition_rows(n_rows: usize, config: &PartitionConfig) -> Vec<RowPartition> {
108 let p = config.n_workers.max(1);
109
110 match config.method {
111 PartitionMethod::Contiguous => (0..p)
112 .map(|w| {
113 let start = w * n_rows / p;
114 let end = (w + 1) * n_rows / p;
115 RowPartition {
116 worker_id: w,
117 local_rows: (start..end).collect(),
118 n_global_rows: n_rows,
119 }
120 })
121 .collect(),
122 PartitionMethod::RoundRobin => {
123 let mut bins: Vec<Vec<usize>> = vec![Vec::new(); p];
124 for r in 0..n_rows {
125 bins[r % p].push(r);
126 }
127 bins.into_iter()
128 .enumerate()
129 .map(|(w, rows)| RowPartition {
130 worker_id: w,
131 local_rows: rows,
132 n_global_rows: n_rows,
133 })
134 .collect()
135 }
136 PartitionMethod::GraphBased => {
137 (0..p)
141 .map(|w| {
142 let start = w * n_rows / p;
143 let end = (w + 1) * n_rows / p;
144 RowPartition {
145 worker_id: w,
146 local_rows: (start..end).collect(),
147 n_global_rows: n_rows,
148 }
149 })
150 .collect()
151 }
152 }
153}
154
155pub fn create_distributed_csr(
165 global_matrix: &CsrMatrix<f64>,
166 partition: &RowPartition,
167) -> SparseResult<DistributedCsr> {
168 let n_local = partition.local_rows.len();
169 let n_cols = global_matrix.cols();
170 let n_global_rows = global_matrix.rows();
171
172 let owned_set: HashSet<usize> = partition.local_rows.iter().copied().collect();
174
175 let mut row_indices: Vec<usize> = Vec::new();
177 let mut col_indices: Vec<usize> = Vec::new();
178 let mut values: Vec<f64> = Vec::new();
179 let mut ghost_set: HashSet<usize> = HashSet::new();
180
181 for (local_row, &global_row) in partition.local_rows.iter().enumerate() {
182 if global_row >= n_global_rows {
183 return Err(SparseError::ValueError(format!(
184 "Global row {global_row} out of bounds (n_rows={n_global_rows})"
185 )));
186 }
187 let row_start = global_matrix.indptr[global_row];
188 let row_end = global_matrix.indptr[global_row + 1];
189
190 for idx in row_start..row_end {
191 let col = global_matrix.indices[idx];
192 let val = global_matrix.data[idx];
193
194 row_indices.push(local_row);
195 col_indices.push(col);
196 values.push(val);
197
198 if col < n_global_rows && !owned_set.contains(&col) {
201 ghost_set.insert(col);
202 }
203 }
204 }
205
206 let local_matrix = CsrMatrix::from_triplets(n_local, n_cols, row_indices, col_indices, values)?;
207
208 let mut ghost_rows: Vec<usize> = ghost_set.into_iter().collect();
209 ghost_rows.sort_unstable();
210
211 Ok(DistributedCsr {
212 local_matrix,
213 partition: partition.clone(),
214 ghost_rows,
215 })
216}
217
218pub fn partition_matrix_nnz(
227 global_matrix: &CsrMatrix<f64>,
228 n_workers: usize,
229) -> SparseResult<Vec<DistributedCsr>> {
230 let n_rows = global_matrix.rows();
231 let p = n_workers.max(1);
232
233 let row_nnz: Vec<usize> = (0..n_rows)
235 .map(|r| global_matrix.indptr[r + 1] - global_matrix.indptr[r])
236 .collect();
237 let total_nnz: usize = row_nnz.iter().sum();
238 let target = (total_nnz + p - 1) / p;
239
240 let mut partitions_rows: Vec<Vec<usize>> = vec![Vec::new(); p];
242 let mut worker = 0usize;
243 let mut acc = 0usize;
244
245 for r in 0..n_rows {
246 partitions_rows[worker].push(r);
247 acc += row_nnz[r];
248 if acc >= target && worker + 1 < p {
249 worker += 1;
250 acc = 0;
251 }
252 }
253
254 let result: SparseResult<Vec<DistributedCsr>> = partitions_rows
255 .into_iter()
256 .enumerate()
257 .map(|(w, rows)| {
258 let rp = RowPartition {
259 worker_id: w,
260 local_rows: rows,
261 n_global_rows: n_rows,
262 };
263 create_distributed_csr(global_matrix, &rp)
264 })
265 .collect();
266
267 result
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// 100x100 symmetric tridiagonal fixture: 2.0 on the diagonal,
    /// -1.0 on the sub- and super-diagonals.
    fn tridiag_100() -> CsrMatrix<f64> {
        let n = 100usize;
        let mut entries: Vec<(usize, usize, f64)> = Vec::new();
        for i in 0..n {
            entries.push((i, i, 2.0));
            if i > 0 {
                entries.push((i, i - 1, -1.0));
                entries.push((i - 1, i, -1.0));
            }
        }
        let (rows, (cols, vals)): (Vec<_>, (Vec<_>, Vec<_>)) =
            entries.into_iter().map(|(r, c, v)| (r, (c, v))).unzip();
        CsrMatrix::from_triplets(n, n, rows, cols, vals).expect("tridiag_100 construction")
    }

    #[test]
    fn test_contiguous_row_count_sums_to_n() {
        let config = PartitionConfig {
            n_workers: 4,
            ..Default::default()
        };
        let parts = partition_rows(100, &config);
        assert_eq!(parts.len(), 4);
        // All 100 rows must be assigned exactly once in total.
        assert_eq!(parts.iter().map(RowPartition::n_local).sum::<usize>(), 100);
    }

    #[test]
    fn test_contiguous_first_partition_rows() {
        let config = PartitionConfig {
            n_workers: 4,
            ..Default::default()
        };
        let parts = partition_rows(100, &config);
        // Each worker should own a 25-row contiguous slice.
        for (w, expected) in [(0usize, 0..25), (1, 25..50), (2, 50..75), (3, 75..100)] {
            assert_eq!(parts[w].local_rows, expected.collect::<Vec<_>>());
        }
    }

    #[test]
    fn test_round_robin_all_rows_assigned() {
        let config = PartitionConfig {
            n_workers: 4,
            method: PartitionMethod::RoundRobin,
            ..Default::default()
        };
        let total: usize = partition_rows(100, &config)
            .iter()
            .map(|p| p.n_local())
            .sum();
        assert_eq!(total, 100);
    }

    #[test]
    fn test_create_distributed_csr_ghost_rows() {
        let mat = tridiag_100();
        let config = PartitionConfig {
            n_workers: 4,
            ..Default::default()
        };
        let partitions = partition_rows(100, &config);
        // Worker 1 owns rows 25..50; the tridiagonal stencil reaches one
        // row past each end of that slice.
        let dcsr =
            create_distributed_csr(&mat, &partitions[1]).expect("create_distributed_csr failed");
        assert!(
            dcsr.ghost_rows.contains(&24),
            "Expected row 24 as ghost, got {:?}",
            dcsr.ghost_rows
        );
        assert!(
            dcsr.ghost_rows.contains(&50),
            "Expected row 50 as ghost, got {:?}",
            dcsr.ghost_rows
        );
    }

    #[test]
    fn test_distributed_csr_local_matrix_nnz() {
        let mat = tridiag_100();
        let config = PartitionConfig {
            n_workers: 4,
            ..Default::default()
        };
        let dcsr = create_distributed_csr(&mat, &partition_rows(100, &config)[0])
            .expect("create_distributed_csr failed");
        // Rows 0..25: row 0 has 2 entries, rows 1..=24 have 3 each.
        assert_eq!(dcsr.local_matrix.nnz(), 2 + 24 * 3);
    }

    #[test]
    fn test_partition_matrix_nnz_balanced() {
        let dcsrs = partition_matrix_nnz(&tridiag_100(), 4).expect("partition_matrix_nnz failed");
        assert_eq!(dcsrs.len(), 4);
        let total_rows: usize = dcsrs.iter().map(|d| d.partition.n_local()).sum();
        assert_eq!(total_rows, 100);
    }
}