1use crate::sparse::CsrMatrix;
22use crate::traits::{ComplexField, Preconditioner};
23use ndarray::Array1;
24use num_traits::{FromPrimitive, One};
25
26#[cfg(feature = "rayon")]
27use rayon::prelude::*;
28
29#[derive(Debug, Clone)]
31pub struct AdditiveSchwarzPreconditioner<T: ComplexField> {
32 subdomains: Vec<Subdomain<T>>,
34 weights: Vec<T>,
36 n: usize,
38}
39
40#[derive(Debug, Clone)]
42#[allow(dead_code)]
43struct Subdomain<T: ComplexField> {
44 global_indices: Vec<usize>,
46 local_values: Vec<T>,
48 local_col_indices: Vec<usize>,
49 local_row_ptrs: Vec<usize>,
50 l_values: Vec<T>,
52 l_col_indices: Vec<usize>,
53 l_row_ptrs: Vec<usize>,
54 u_values: Vec<T>,
55 u_col_indices: Vec<usize>,
56 u_row_ptrs: Vec<usize>,
57 u_diag: Vec<T>,
58}
59
60impl<T: ComplexField> AdditiveSchwarzPreconditioner<T> {
61 pub fn from_csr(matrix: &CsrMatrix<T>, num_subdomains: usize, overlap: usize) -> Self {
68 let n = matrix.num_rows;
69 let num_subdomains = num_subdomains.max(1).min(n);
70
71 let base_size = n / num_subdomains;
73 let remainder = n % num_subdomains;
74
75 let mut partitions: Vec<Vec<usize>> = Vec::with_capacity(num_subdomains);
76 let mut start = 0;
77 for i in 0..num_subdomains {
78 let size = base_size + if i < remainder { 1 } else { 0 };
79 partitions.push((start..start + size).collect());
80 start += size;
81 }
82
83 let adjacency = build_adjacency(matrix);
85
86 let extended_partitions: Vec<Vec<usize>> = partitions
88 .iter()
89 .map(|partition| extend_partition(partition, &adjacency, overlap, n))
90 .collect();
91
92 let mut overlap_count = vec![0usize; n];
94 for partition in &extended_partitions {
95 for &idx in partition {
96 overlap_count[idx] += 1;
97 }
98 }
99 let weights: Vec<T> = overlap_count
100 .iter()
101 .map(|&c| {
102 if c > 0 {
103 T::from_real(T::Real::one() / T::Real::from_usize(c).unwrap())
104 } else {
105 T::one()
106 }
107 })
108 .collect();
109
110 let subdomains: Vec<Subdomain<T>> = extended_partitions
112 .into_iter()
113 .map(|indices| build_subdomain(matrix, indices))
114 .collect();
115
116 Self {
117 subdomains,
118 weights,
119 n,
120 }
121 }
122
123 #[cfg(feature = "rayon")]
125 pub fn from_csr_auto(matrix: &CsrMatrix<T>, overlap: usize) -> Self {
126 let num_threads = rayon::current_num_threads();
127 Self::from_csr(matrix, num_threads, overlap)
128 }
129
130 pub fn stats(&self) -> (usize, usize, usize, f64) {
132 let num_subdomains = self.subdomains.len();
133 let min_size = self
134 .subdomains
135 .iter()
136 .map(|s| s.global_indices.len())
137 .min()
138 .unwrap_or(0);
139 let max_size = self
140 .subdomains
141 .iter()
142 .map(|s| s.global_indices.len())
143 .max()
144 .unwrap_or(0);
145 let avg_size = self
146 .subdomains
147 .iter()
148 .map(|s| s.global_indices.len())
149 .sum::<usize>() as f64
150 / num_subdomains as f64;
151 (num_subdomains, min_size, max_size, avg_size)
152 }
153}
154
155fn build_adjacency<T: ComplexField>(matrix: &CsrMatrix<T>) -> Vec<Vec<usize>> {
157 let n = matrix.num_rows;
158 let mut adjacency = vec![Vec::new(); n];
159
160 for (i, row_adj) in adjacency.iter_mut().enumerate().take(n) {
161 for idx in matrix.row_ptrs[i]..matrix.row_ptrs[i + 1] {
162 let j = matrix.col_indices[idx];
163 if i != j {
164 row_adj.push(j);
165 }
166 }
167 }
168
169 adjacency
170}
171
/// Grows `partition` by up to `overlap` breadth-first layers of neighbours.
///
/// * `partition` - seed indices, each `< n`.
/// * `adjacency` - neighbour lists indexed by node.
/// * `overlap` - maximum number of layers to add; `0` returns the seed set.
/// * `n` - total number of nodes.
///
/// Returns the indices of the grown set in ascending order, without
/// duplicates.
///
/// # Panics
///
/// Panics if a seed index or a visited neighbour index is `>= n`, or if a
/// visited index has no entry in `adjacency`.
fn extend_partition(
    partition: &[usize],
    adjacency: &[Vec<usize>],
    overlap: usize,
    n: usize,
) -> Vec<usize> {
    let mut in_partition = vec![false; n];
    for &idx in partition {
        in_partition[idx] = true;
    }

    let mut frontier: Vec<usize> = partition.to_vec();

    for _ in 0..overlap {
        // Once a layer adds nothing the set is closed under adjacency; stop
        // instead of spinning through the remaining layers.
        if frontier.is_empty() {
            break;
        }

        let mut next_frontier = Vec::new();
        for &idx in &frontier {
            for &neighbor in &adjacency[idx] {
                if !in_partition[neighbor] {
                    in_partition[neighbor] = true;
                    next_frontier.push(neighbor);
                }
            }
        }
        frontier = next_frontier;
    }

    // Scanning the membership mask in index order yields an already sorted
    // result, so no explicit sort is required.
    in_partition
        .iter()
        .enumerate()
        .filter_map(|(idx, &inside)| inside.then_some(idx))
        .collect()
}
204
205fn build_subdomain<T: ComplexField>(
207 matrix: &CsrMatrix<T>,
208 global_indices: Vec<usize>,
209) -> Subdomain<T> {
210 let local_n = global_indices.len();
211
212 let mut global_to_local = vec![usize::MAX; matrix.num_rows];
214 for (local_idx, &global_idx) in global_indices.iter().enumerate() {
215 global_to_local[global_idx] = local_idx;
216 }
217
218 let mut local_values = Vec::new();
220 let mut local_col_indices = Vec::new();
221 let mut local_row_ptrs = vec![0];
222
223 for &global_row in &global_indices {
224 for idx in matrix.row_ptrs[global_row]..matrix.row_ptrs[global_row + 1] {
225 let global_col = matrix.col_indices[idx];
226 let local_col = global_to_local[global_col];
227 if local_col != usize::MAX {
228 local_values.push(matrix.values[idx]);
229 local_col_indices.push(local_col);
230 }
231 }
232 local_row_ptrs.push(local_values.len());
233 }
234
235 let (l_values, l_col_indices, l_row_ptrs, u_values, u_col_indices, u_row_ptrs, u_diag) =
237 ilu_factorize(&local_values, &local_col_indices, &local_row_ptrs, local_n);
238
239 Subdomain {
240 global_indices,
241 local_values,
242 local_col_indices,
243 local_row_ptrs,
244 l_values,
245 l_col_indices,
246 l_row_ptrs,
247 u_values,
248 u_col_indices,
249 u_row_ptrs,
250 u_diag,
251 }
252}
253
254#[allow(clippy::type_complexity)]
256fn ilu_factorize<T: ComplexField>(
257 values: &[T],
258 col_indices: &[usize],
259 row_ptrs: &[usize],
260 n: usize,
261) -> (
262 Vec<T>,
263 Vec<usize>,
264 Vec<usize>,
265 Vec<T>,
266 Vec<usize>,
267 Vec<usize>,
268 Vec<T>,
269) {
270 let mut values = values.to_vec();
272
273 for i in 0..n {
275 for idx in row_ptrs[i]..row_ptrs[i + 1] {
276 let k = col_indices[idx];
277 if k >= i {
278 break;
279 }
280
281 let mut u_kk = T::zero();
283 for k_idx in row_ptrs[k]..row_ptrs[k + 1] {
284 if col_indices[k_idx] == k {
285 u_kk = values[k_idx];
286 break;
287 }
288 }
289
290 if u_kk.norm() < T::Real::from_f64(1e-30).unwrap() {
291 continue;
292 }
293
294 let l_ik = values[idx] * u_kk.inv();
295 values[idx] = l_ik;
296
297 for j_idx in row_ptrs[i]..row_ptrs[i + 1] {
298 let j = col_indices[j_idx];
299 if j <= k {
300 continue;
301 }
302
303 for k_j_idx in row_ptrs[k]..row_ptrs[k + 1] {
304 if col_indices[k_j_idx] == j {
305 values[j_idx] = values[j_idx] - l_ik * values[k_j_idx];
306 break;
307 }
308 }
309 }
310 }
311 }
312
313 let mut l_values = Vec::new();
315 let mut l_col_indices = Vec::new();
316 let mut l_row_ptrs = vec![0];
317
318 let mut u_values = Vec::new();
319 let mut u_col_indices = Vec::new();
320 let mut u_row_ptrs = vec![0];
321 let mut u_diag = vec![T::one(); n];
322
323 for i in 0..n {
324 for idx in row_ptrs[i]..row_ptrs[i + 1] {
325 let j = col_indices[idx];
326 let val = values[idx];
327
328 if j < i {
329 l_values.push(val);
330 l_col_indices.push(j);
331 } else {
332 u_values.push(val);
333 u_col_indices.push(j);
334 if j == i {
335 u_diag[i] = val;
336 }
337 }
338 }
339 l_row_ptrs.push(l_values.len());
340 u_row_ptrs.push(u_values.len());
341 }
342
343 (
344 l_values,
345 l_col_indices,
346 l_row_ptrs,
347 u_values,
348 u_col_indices,
349 u_row_ptrs,
350 u_diag,
351 )
352}
353
354impl<T: ComplexField> Subdomain<T> {
355 fn solve(&self, local_rhs: &[T]) -> Vec<T> {
357 let n = self.global_indices.len();
358 let mut y = local_rhs.to_vec();
359
360 for i in 0..n {
362 for idx in self.l_row_ptrs[i]..self.l_row_ptrs[i + 1] {
363 let j = self.l_col_indices[idx];
364 let l_ij = self.l_values[idx];
365 y[i] = y[i] - l_ij * y[j];
366 }
367 }
368
369 let mut x = y;
371 for i in (0..n).rev() {
372 for idx in self.u_row_ptrs[i]..self.u_row_ptrs[i + 1] {
373 let j = self.u_col_indices[idx];
374 if j > i {
375 let u_ij = self.u_values[idx];
376 x[i] = x[i] - u_ij * x[j];
377 }
378 }
379
380 let u_ii = self.u_diag[i];
381 if u_ii.norm() > T::Real::from_f64(1e-30).unwrap() {
382 x[i] *= u_ii.inv();
383 }
384 }
385
386 x
387 }
388}
389
390impl<T: ComplexField + Send + Sync> Preconditioner<T> for AdditiveSchwarzPreconditioner<T> {
391 fn apply(&self, r: &Array1<T>) -> Array1<T> {
392 #[cfg(feature = "rayon")]
393 {
394 if self.n >= 1000 && self.subdomains.len() > 1 {
395 return self.apply_parallel(r);
396 }
397 }
398 self.apply_sequential(r)
399 }
400}
401
402impl<T: ComplexField + Send + Sync> AdditiveSchwarzPreconditioner<T> {
403 fn apply_sequential(&self, r: &Array1<T>) -> Array1<T> {
404 let mut result = Array1::from_elem(self.n, T::zero());
405
406 for subdomain in &self.subdomains {
407 let local_rhs: Vec<T> = subdomain.global_indices.iter().map(|&i| r[i]).collect();
409
410 let local_solution = subdomain.solve(&local_rhs);
412
413 for (local_idx, &global_idx) in subdomain.global_indices.iter().enumerate() {
415 result[global_idx] += local_solution[local_idx] * self.weights[global_idx];
416 }
417 }
418
419 result
420 }
421
422 #[cfg(feature = "rayon")]
423 fn apply_parallel(&self, r: &Array1<T>) -> Array1<T> {
424 use std::cell::UnsafeCell;
425
426 struct UnsafeVec<U>(UnsafeCell<Vec<U>>);
428 unsafe impl<U: Send> Sync for UnsafeVec<U> {}
429
430 impl<U: ComplexField> UnsafeVec<U> {
431 fn add(&self, i: usize, val: U) {
432 unsafe {
433 let vec = &mut (*self.0.get());
434 vec[i] = vec[i] + val;
435 }
436 }
437 }
438
439 let result = UnsafeVec(UnsafeCell::new(vec![T::zero(); self.n]));
440 let r_slice = r.as_slice().expect("Array should be contiguous");
441 let weights = &self.weights;
442
443 self.subdomains.par_iter().for_each(|subdomain| {
445 let local_rhs: Vec<T> = subdomain
447 .global_indices
448 .iter()
449 .map(|&i| r_slice[i])
450 .collect();
451
452 let local_solution = subdomain.solve(&local_rhs);
454
455 for (local_idx, &global_idx) in subdomain.global_indices.iter().enumerate() {
460 result.add(global_idx, local_solution[local_idx] * weights[global_idx]);
461 }
462 });
463
464 Array1::from_vec(result.0.into_inner())
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::iterative::{GmresConfig, gmres_preconditioned};
    use num_complex::Complex64;

    /// Shorthand for a complex number with zero imaginary part.
    fn real(re: f64) -> Complex64 {
        Complex64::new(re, 0.0)
    }

    /// 20x20 test matrix: 4 on the diagonal, -1 on the first off-diagonals
    /// and -0.5 on the fifth off-diagonals.
    fn create_test_matrix() -> CsrMatrix<Complex64> {
        let n = 20;
        let mut dense = ndarray::Array2::from_elem((n, n), real(0.0));

        for i in 0..n {
            dense[[i, i]] = real(4.0);
            if i >= 1 {
                dense[[i, i - 1]] = real(-1.0);
            }
            if i + 1 < n {
                dense[[i, i + 1]] = real(-1.0);
            }
            if i >= 5 {
                dense[[i, i - 5]] = real(-0.5);
            }
            if i + 5 < n {
                dense[[i, i + 5]] = real(-0.5);
            }
        }

        CsrMatrix::from_dense(&dense, 1e-15)
    }

    /// Right-hand side with entries `sin(i)` for `i` in `0..20`.
    fn sine_rhs() -> Array1<Complex64> {
        Array1::from_iter((0..20).map(|i| real((i as f64).sin())))
    }

    /// Runs preconditioned GMRES on the test system and reports whether it
    /// converged.
    fn gmres_converges(
        matrix: &CsrMatrix<Complex64>,
        precond: &AdditiveSchwarzPreconditioner<Complex64>,
    ) -> bool {
        let b = sine_rhs();
        let config = GmresConfig {
            max_iterations: 100,
            restart: 20,
            tolerance: 1e-8,
            print_interval: 0,
        };
        gmres_preconditioned(matrix, precond, &b, &config).converged
    }

    #[test]
    fn test_schwarz_basic() {
        let matrix = create_test_matrix();
        let precond = AdditiveSchwarzPreconditioner::from_csr(&matrix, 4, 1);

        let result = precond.apply(&sine_rhs());

        assert_eq!(result.len(), 20);
        assert!(result.iter().all(|x| x.norm() < 100.0));
    }

    #[test]
    fn test_schwarz_stats() {
        let matrix = create_test_matrix();
        let precond = AdditiveSchwarzPreconditioner::from_csr(&matrix, 4, 2);

        let (num_subdomains, min_size, max_size, avg_size) = precond.stats();

        assert_eq!(num_subdomains, 4);
        assert!(min_size > 0);
        assert!(max_size >= min_size);
        assert!(avg_size > 0.0);
        // With overlap, the average subdomain exceeds the 5-row base block.
        assert!(avg_size > 5.0);
    }

    #[test]
    fn test_schwarz_with_gmres() {
        let matrix = create_test_matrix();
        let precond = AdditiveSchwarzPreconditioner::from_csr(&matrix, 4, 2);

        assert!(
            gmres_converges(&matrix, &precond),
            "GMRES with Schwarz should converge"
        );
    }

    #[test]
    fn test_schwarz_overlap_effect() {
        let matrix = create_test_matrix();

        let precond_no_overlap = AdditiveSchwarzPreconditioner::from_csr(&matrix, 4, 0);
        let precond_overlap_1 = AdditiveSchwarzPreconditioner::from_csr(&matrix, 4, 1);
        let precond_overlap_2 = AdditiveSchwarzPreconditioner::from_csr(&matrix, 4, 2);

        let no_overlap_ok = gmres_converges(&matrix, &precond_no_overlap);
        let overlap_1_ok = gmres_converges(&matrix, &precond_overlap_1);
        let overlap_2_ok = gmres_converges(&matrix, &precond_overlap_2);

        assert!(no_overlap_ok);
        assert!(overlap_1_ok);
        assert!(overlap_2_ok);
    }
}