use rayon::prelude::*;
use crate::constants::GAMMA;
use crate::error::{Error, Result};
use crate::vector3::Vector3;
// Edge cells exported by one 2D tile during ghost exchange.
//
// `DomainDecomposition2D::update_ghost_cells` builds this tuple as
// `(top_row, bottom_row, left_col, right_col)`, each copied from the tile's
// interior spins.
type TileBoundary = (
// .0: first interior row of the tile
Vec<Vector3<f64>>,
// .1: last interior row of the tile
Vec<Vector3<f64>>,
// .2: first interior column of the tile
Vec<Vector3<f64>>,
// .3: last interior column of the tile
Vec<Vector3<f64>>,
);
/// One contiguous piece of a 1D spin chain plus copies of the cells that
/// border it on either side (ghost cells).
#[derive(Debug, Clone)]
pub struct Domain {
/// Interior spins owned by this domain.
pub spins: Vec<Vector3<f64>>,
/// Copies of the cells immediately before the interior; left empty for the
/// first domain by `DomainDecomposition::new`.
pub ghost_left: Vec<Vector3<f64>>,
/// Copies of the cells immediately after the interior; left empty for the
/// last domain by `DomainDecomposition::new`.
pub ghost_right: Vec<Vector3<f64>>,
/// Index of the first interior spin within the original (global) array.
pub global_offset: usize,
}
impl Domain {
    /// Returns the number of cells stored in this domain, counting both ghost
    /// layers together with the interior spins.
    pub fn total_cells(&self) -> usize {
        [&self.ghost_left, &self.spins, &self.ghost_right]
            .iter()
            .map(|cells| cells.len())
            .sum()
    }

    /// Looks up a cell by its index in the padded layout
    /// `ghost_left ++ spins ++ ghost_right`.
    ///
    /// Returns `None` when `local_idx` lies beyond the end of `ghost_right`.
    pub fn get_spin(&self, local_idx: usize) -> Option<&Vector3<f64>> {
        let interior_start = self.ghost_left.len();
        let interior_end = interior_start + self.spins.len();
        match local_idx {
            i if i < interior_start => self.ghost_left.get(i),
            i if i < interior_end => self.spins.get(i - interior_start),
            i => self.ghost_right.get(i - interior_end),
        }
    }
}
/// A 1D spin chain split into contiguous [`Domain`]s that can be advanced in
/// parallel.
#[derive(Debug, Clone)]
pub struct DomainDecomposition {
/// Number of domains requested at construction.
pub num_domains: usize,
/// Base domain size, `total_spins / num_domains`; when the split is uneven
/// the leading domains hold one spin more than this.
pub domain_size: usize,
/// Maximum number of neighbouring cells mirrored on each side of a domain.
pub ghost_width: usize,
/// The domains, ordered by their position in the global array.
pub domains: Vec<Domain>,
}
impl DomainDecomposition {
    /// Splits a 1D spin chain into `num_domains` contiguous domains with
    /// ghost layers.
    ///
    /// Every domain receives `spins.len() / num_domains` cells and the first
    /// `spins.len() % num_domains` domains receive one extra cell. Each domain
    /// also stores up to `ghost_width` copies of the cells adjacent to it on
    /// either side; the outer side of the first and of the last domain has no
    /// ghost cells.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidParameter`] when `num_domains` is zero, when
    /// `spins` is empty, or when more than one domain is requested and
    /// `ghost_width` exceeds the smallest domain size.
    pub fn new(spins: &[Vector3<f64>], num_domains: usize, ghost_width: usize) -> Result<Self> {
        if num_domains == 0 {
            return Err(Error::InvalidParameter {
                param: "num_domains".to_string(),
                reason: "must be at least 1".to_string(),
            });
        }
        if spins.is_empty() {
            return Err(Error::InvalidParameter {
                param: "spins".to_string(),
                reason: "spin array must not be empty".to_string(),
            });
        }
        let n = spins.len();
        let base_size = n / num_domains;
        let remainder = n % num_domains;
        // Domains past the first `remainder` ones hold exactly `base_size`
        // cells, so that is the smallest domain a ghost layer must fit into.
        let min_domain = base_size;
        if ghost_width > min_domain && num_domains > 1 {
            return Err(Error::InvalidParameter {
                param: "ghost_width".to_string(),
                reason: format!(
                    "ghost_width {} exceeds smallest domain size {}",
                    ghost_width, min_domain
                ),
            });
        }
        let mut domains = Vec::with_capacity(num_domains);
        let mut offset: usize = 0;
        for d in 0..num_domains {
            let size = base_size + usize::from(d < remainder);
            let interior = spins[offset..offset + size].to_vec();
            let ghost_left = if d == 0 {
                Vec::new()
            } else {
                let start = offset.saturating_sub(ghost_width);
                spins[start..offset].to_vec()
            };
            let ghost_right = if d == num_domains - 1 {
                Vec::new()
            } else {
                let end = (offset + size + ghost_width).min(n);
                spins[offset + size..end].to_vec()
            };
            domains.push(Domain {
                spins: interior,
                ghost_left,
                ghost_right,
                global_offset: offset,
            });
            offset += size;
        }
        Ok(Self {
            num_domains,
            domain_size: base_size,
            ghost_width,
            domains,
        })
    }

    /// Returns the total number of interior spins over all domains (ghost
    /// cells are not counted).
    pub fn total_spins(&self) -> usize {
        self.domains.iter().map(|d| d.spins.len()).sum()
    }

    /// Concatenates the interior spins of all domains, in domain order, into
    /// one global array.
    pub fn gather(&self) -> Vec<Vector3<f64>> {
        let mut global = Vec::with_capacity(self.total_spins());
        for domain in &self.domains {
            global.extend_from_slice(&domain.spins);
        }
        global
    }

    /// Refreshes the ghost cells between neighbouring domains from the
    /// current interior spins.
    ///
    /// Each domain's `ghost_left` becomes the last `ghost_width` interior
    /// spins of the previous domain and its `ghost_right` the first
    /// `ghost_width` interior spins of the next one (fewer if the neighbour
    /// is shorter). The outer ghosts of the first and last domain are left
    /// untouched. Does nothing when there is at most one domain.
    pub fn update_ghost_cells(&mut self) {
        if self.num_domains <= 1 {
            return;
        }
        let width = self.ghost_width;
        // Snapshot (tail, head) of every domain before writing any ghost so
        // the borrow of `self.domains` is released before mutation.
        let edges: Vec<(Vec<Vector3<f64>>, Vec<Vector3<f64>>)> = self
            .domains
            .iter()
            .map(|domain| {
                let cells = &domain.spins;
                let tail = cells[cells.len().saturating_sub(width)..].to_vec();
                let head = cells[..width.min(cells.len())].to_vec();
                (tail, head)
            })
            .collect();
        for d in 0..self.num_domains {
            if d > 0 {
                self.domains[d].ghost_left.clone_from(&edges[d - 1].0);
            }
            if d + 1 < self.num_domains {
                self.domains[d].ghost_right.clone_from(&edges[d + 1].1);
            }
        }
    }

    /// Advances all domains by one explicit Euler step of the LLG equation
    /// and renormalises every spin.
    ///
    /// Domains are processed in parallel. Each interior spin is updated as
    /// `normalize(m + dt * llg_torque(m, h_ext + exchange_field, alpha))`,
    /// where the exchange field is evaluated on the domain's padded buffer
    /// (ghost cells included) using the pre-step state. Ghost cells are
    /// refreshed once all domains have been updated.
    pub fn parallel_llg_step(
        &mut self,
        h_ext: Vector3<f64>,
        alpha: f64,
        exchange_stiffness: f64,
        cell_size: f64,
        dt: f64,
    ) {
        self.domains.par_iter_mut().for_each(|domain| {
            let padded =
                Self::padded_cells(&domain.ghost_left, &domain.spins, &domain.ghost_right);
            let torques = Self::interior_torques(
                &padded,
                domain.ghost_left.len(),
                domain.spins.len(),
                h_ext,
                alpha,
                exchange_stiffness,
                cell_size,
            );
            for (spin, torque) in domain.spins.iter_mut().zip(torques) {
                *spin = (*spin + torque * dt).normalize();
            }
        });
        self.update_ghost_cells();
    }

    /// Advances all domains by one Heun (predictor-corrector) step of the LLG
    /// equation and renormalises every spin.
    ///
    /// The step runs in three phases so that the result does not depend on
    /// how the chain is split into domains:
    ///
    /// 1. Predictor: every domain computes the slope `k1` from the current
    ///    state and an Euler-predicted, renormalised interior.
    /// 2. Exchange: the predicted edge cells of each domain become the ghost
    ///    cells of its neighbours for the corrector. Without this the
    ///    corrector would mix predicted interior spins with pre-step ghost
    ///    spins at every domain seam.
    /// 3. Corrector: every domain computes `k2` on the predicted state and
    ///    applies `normalize(m + dt * (k1 + k2) / 2)`.
    ///
    /// The outer ghost layers of the first and last domain have no neighbour
    /// to be refreshed from and are reused unchanged in both stages. Ghost
    /// cells are refreshed from the final state before returning.
    pub fn parallel_heun_step(
        &mut self,
        h_ext: Vector3<f64>,
        alpha: f64,
        exchange_stiffness: f64,
        cell_size: f64,
        dt: f64,
    ) {
        let ghost_width = self.ghost_width;

        // Phase 1: per-domain (k1, predicted interior), computed in parallel.
        let stage_one: Vec<(Vec<Vector3<f64>>, Vec<Vector3<f64>>)> = self
            .domains
            .par_iter()
            .map(|domain| {
                let padded =
                    Self::padded_cells(&domain.ghost_left, &domain.spins, &domain.ghost_right);
                let k1 = Self::interior_torques(
                    &padded,
                    domain.ghost_left.len(),
                    domain.spins.len(),
                    h_ext,
                    alpha,
                    exchange_stiffness,
                    cell_size,
                );
                let predicted: Vec<Vector3<f64>> = domain
                    .spins
                    .iter()
                    .zip(&k1)
                    .map(|(m, k)| (*m + *k * dt).normalize())
                    .collect();
                (k1, predicted)
            })
            .collect();

        // Phase 2: build the ghost layers of the predicted state, mirroring
        // what `update_ghost_cells` does for the real state.
        let last = self.domains.len().saturating_sub(1);
        let predicted_ghosts: Vec<(Vec<Vector3<f64>>, Vec<Vector3<f64>>)> = self
            .domains
            .iter()
            .enumerate()
            .map(|(d, domain)| {
                let left = if d == 0 {
                    domain.ghost_left.clone()
                } else {
                    let neighbour = &stage_one[d - 1].1;
                    neighbour[neighbour.len().saturating_sub(ghost_width)..].to_vec()
                };
                let right = if d == last {
                    domain.ghost_right.clone()
                } else {
                    let neighbour = &stage_one[d + 1].1;
                    neighbour[..ghost_width.min(neighbour.len())].to_vec()
                };
                (left, right)
            })
            .collect();

        // Phase 3: corrector slope on the predicted state, then the update.
        self.domains
            .par_iter_mut()
            .enumerate()
            .for_each(|(d, domain)| {
                let (k1, predicted_interior) = &stage_one[d];
                let (ghost_left, ghost_right) = &predicted_ghosts[d];
                let padded = Self::padded_cells(ghost_left, predicted_interior, ghost_right);
                let k2 = Self::interior_torques(
                    &padded,
                    ghost_left.len(),
                    predicted_interior.len(),
                    h_ext,
                    alpha,
                    exchange_stiffness,
                    cell_size,
                );
                for ((spin, k1_i), k2_i) in domain.spins.iter_mut().zip(k1).zip(k2) {
                    let dm_dt = (*k1_i + k2_i) * 0.5;
                    *spin = (*spin + dm_dt * dt).normalize();
                }
            });

        self.update_ghost_cells();
    }

    /// Concatenates `ghost_left ++ interior ++ ghost_right` into one buffer
    /// so the exchange stencil can address neighbours by plain index.
    fn padded_cells(
        ghost_left: &[Vector3<f64>],
        interior: &[Vector3<f64>],
        ghost_right: &[Vector3<f64>],
    ) -> Vec<Vector3<f64>> {
        let mut padded = Vec::with_capacity(ghost_left.len() + interior.len() + ghost_right.len());
        padded.extend_from_slice(ghost_left);
        padded.extend_from_slice(interior);
        padded.extend_from_slice(ghost_right);
        padded
    }

    /// Evaluates the LLG right-hand side for the `interior_len` cells that
    /// start at index `ghost_left_len` of `padded`.
    fn interior_torques(
        padded: &[Vector3<f64>],
        ghost_left_len: usize,
        interior_len: usize,
        h_ext: Vector3<f64>,
        alpha: f64,
        exchange_stiffness: f64,
        cell_size: f64,
    ) -> Vec<Vector3<f64>> {
        (0..interior_len)
            .map(|i| {
                let idx = ghost_left_len + i;
                let exchange_field =
                    compute_exchange_field(padded, idx, exchange_stiffness, cell_size);
                llg_torque(padded[idx], h_ext + exchange_field, alpha)
            })
            .collect()
    }
}
/// One rectangular tile of a 2D spin lattice plus copies of the cells that
/// border it on each of its four sides.
#[derive(Debug, Clone)]
pub struct Tile2D {
/// Interior spins in row-major order (`rows * cols` entries).
pub spins: Vec<Vector3<f64>>,
/// Number of interior rows.
pub rows: usize,
/// Number of interior columns.
pub cols: usize,
/// Copy of the lattice row directly above the tile; empty for tiles in the
/// first grid row.
pub ghost_top: Vec<Vector3<f64>>,
/// Copy of the lattice row directly below the tile; empty for tiles in the
/// last grid row.
pub ghost_bottom: Vec<Vector3<f64>>,
/// Copy of the lattice column directly left of the tile; empty for tiles in
/// the first grid column.
pub ghost_left: Vec<Vector3<f64>>,
/// Copy of the lattice column directly right of the tile; empty for tiles
/// in the last grid column.
pub ghost_right: Vec<Vector3<f64>>,
/// Row index of this tile within the tile grid.
pub grid_row: usize,
/// Column index of this tile within the tile grid.
pub grid_col: usize,
}
impl Tile2D {
    /// Returns the interior spin at (`row`, `col`) of this tile, or `None`
    /// when either coordinate falls outside the `rows` x `cols` interior.
    ///
    /// Ghost cells are not reachable through this accessor.
    pub fn get(&self, row: usize, col: usize) -> Option<&Vector3<f64>> {
        if row >= self.rows || col >= self.cols {
            return None;
        }
        // Interior storage is row-major.
        let flat = row * self.cols + col;
        Some(&self.spins[flat])
    }
}
/// A 2D spin lattice split into a `grid_rows` x `grid_cols` grid of
/// [`Tile2D`]s.
#[derive(Debug, Clone)]
pub struct DomainDecomposition2D {
/// Number of tile rows in the grid.
pub grid_rows: usize,
/// Number of tile columns in the grid.
pub grid_cols: usize,
/// Number of rows in the full lattice.
pub global_rows: usize,
/// Number of columns in the full lattice.
pub global_cols: usize,
/// Tiles in row-major grid order, i.e. the tile at (`grid_row`, `grid_col`)
/// is stored at index `grid_row * grid_cols + grid_col`.
pub tiles: Vec<Tile2D>,
}
impl DomainDecomposition2D {
pub fn new(
spins: &[Vector3<f64>],
global_rows: usize,
global_cols: usize,
grid_rows: usize,
grid_cols: usize,
) -> Result<Self> {
if grid_rows == 0 || grid_cols == 0 {
return Err(Error::InvalidParameter {
param: "grid dimensions".to_string(),
reason: "grid_rows and grid_cols must be at least 1".to_string(),
});
}
if spins.len() != global_rows * global_cols {
return Err(Error::DimensionMismatch {
expected: format!(
"{}x{} = {}",
global_rows,
global_cols,
global_rows * global_cols
),
actual: format!("{}", spins.len()),
});
}
let base_tile_rows = global_rows / grid_rows;
let rem_rows = global_rows % grid_rows;
let base_tile_cols = global_cols / grid_cols;
let rem_cols = global_cols % grid_cols;
let idx = |r: usize, c: usize| -> usize { r * global_cols + c };
let mut tiles = Vec::with_capacity(grid_rows * grid_cols);
let mut row_offset: usize = 0;
for gr in 0..grid_rows {
let tile_rows = base_tile_rows + if gr < rem_rows { 1 } else { 0 };
let mut col_offset: usize = 0;
for gc in 0..grid_cols {
let tile_cols = base_tile_cols + if gc < rem_cols { 1 } else { 0 };
let mut interior = Vec::with_capacity(tile_rows * tile_cols);
for r in row_offset..row_offset + tile_rows {
for c in col_offset..col_offset + tile_cols {
interior.push(spins[idx(r, c)]);
}
}
let ghost_top = if gr == 0 {
Vec::new()
} else {
let r = row_offset - 1;
(col_offset..col_offset + tile_cols)
.map(|c| spins[idx(r, c)])
.collect()
};
let ghost_bottom = if gr == grid_rows - 1 {
Vec::new()
} else {
let r = row_offset + tile_rows;
(col_offset..col_offset + tile_cols)
.map(|c| spins[idx(r, c)])
.collect()
};
let ghost_left = if gc == 0 {
Vec::new()
} else {
let c = col_offset - 1;
(row_offset..row_offset + tile_rows)
.map(|r| spins[idx(r, c)])
.collect()
};
let ghost_right = if gc == grid_cols - 1 {
Vec::new()
} else {
let c = col_offset + tile_cols;
(row_offset..row_offset + tile_rows)
.map(|r| spins[idx(r, c)])
.collect()
};
tiles.push(Tile2D {
spins: interior,
rows: tile_rows,
cols: tile_cols,
ghost_top,
ghost_bottom,
ghost_left,
ghost_right,
grid_row: gr,
grid_col: gc,
});
col_offset += tile_cols;
}
row_offset += tile_rows;
}
Ok(Self {
grid_rows,
grid_cols,
global_rows,
global_cols,
tiles,
})
}
pub fn total_spins(&self) -> usize {
self.tiles.iter().map(|t| t.spins.len()).sum()
}
pub fn gather(&self) -> Vec<Vector3<f64>> {
let mut global = vec![Vector3::zero(); self.global_rows * self.global_cols];
for tile in &self.tiles {
let row_offset = self.row_offset(tile.grid_row);
let col_offset = self.col_offset(tile.grid_col);
for r in 0..tile.rows {
for c in 0..tile.cols {
let gr = row_offset + r;
let gc = col_offset + c;
global[gr * self.global_cols + gc] = tile.spins[r * tile.cols + c];
}
}
}
global
}
pub fn update_ghost_cells(&mut self) {
let boundary: Vec<TileBoundary> = self
.tiles
.iter()
.map(|t| {
let top_row: Vec<_> = t.spins.iter().take(t.cols).copied().collect();
let bottom_row: Vec<_> = t
.spins
.iter()
.skip((t.rows - 1) * t.cols)
.take(t.cols)
.copied()
.collect();
let left_col: Vec<_> = (0..t.rows).map(|r| t.spins[r * t.cols]).collect();
let right_col: Vec<_> = (0..t.rows)
.map(|r| t.spins[r * t.cols + t.cols - 1])
.collect();
(top_row, bottom_row, left_col, right_col)
})
.collect();
let gc = self.grid_cols;
for tile in self.tiles.iter_mut() {
let gr = tile.grid_row;
let gcol = tile.grid_col;
if gr > 0 {
let above_idx = (gr - 1) * gc + gcol;
tile.ghost_top = boundary[above_idx].1.clone();
}
if gr < self.grid_rows - 1 {
let below_idx = (gr + 1) * gc + gcol;
tile.ghost_bottom = boundary[below_idx].0.clone();
}
if gcol > 0 {
let left_idx = gr * gc + (gcol - 1);
tile.ghost_left = boundary[left_idx].3.clone();
}
if gcol < self.grid_cols - 1 {
let right_idx = gr * gc + (gcol + 1);
tile.ghost_right = boundary[right_idx].2.clone();
}
}
}
fn row_offset(&self, gr: usize) -> usize {
let base = self.global_rows / self.grid_rows;
let rem = self.global_rows % self.grid_rows;
let full = gr.min(rem) * (base + 1);
let rest = gr.saturating_sub(rem) * base;
full + rest
}
fn col_offset(&self, gc: usize) -> usize {
let base = self.global_cols / self.grid_cols;
let rem = self.global_cols % self.grid_cols;
let full = gc.min(rem) * (base + 1);
let rest = gc.saturating_sub(rem) * base;
full + rest
}
}
/// Evaluates the nearest-neighbour exchange field at cell `idx` of a 1D
/// chain as `(left + right - 2 * centre) * exchange_stiffness / cell_size^2`.
///
/// At either end of `spins` the missing neighbour is replaced by the centre
/// cell itself, so that side contributes nothing to the difference. Returns
/// the zero vector when `spins` holds fewer than two cells.
///
/// Panics if `idx` is out of bounds for a slice of two or more cells.
fn compute_exchange_field(
    spins: &[Vector3<f64>],
    idx: usize,
    exchange_stiffness: f64,
    cell_size: f64,
) -> Vector3<f64> {
    if spins.len() < 2 {
        return Vector3::zero();
    }
    let centre = spins[idx];
    let left = idx.checked_sub(1).map_or(centre, |j| spins[j]);
    let right = spins.get(idx + 1).copied().unwrap_or(centre);
    let scale = exchange_stiffness / (cell_size * cell_size);
    (left + right - centre * 2.0) * scale
}
/// Right-hand side of the LLG equation for a single spin:
/// `-GAMMA / (1 + alpha^2) * (m x h_eff + alpha * m x (m x h_eff))`.
///
/// `GAMMA` comes from `crate::constants` (presumably the gyromagnetic
/// ratio; confirm units against that module).
fn llg_torque(m: Vector3<f64>, h_eff: Vector3<f64>, alpha: f64) -> Vector3<f64> {
    let precession = m.cross(&h_eff);
    let damping = m.cross(&precession) * alpha;
    let scale = -GAMMA / (1.0 + alpha * alpha);
    (precession + damping) * scale
}
// Unit tests for the 1D and 2D domain decompositions.
#[cfg(test)]
mod tests {
use super::*;
// Builds `n` identical spins pointing along +z.
fn uniform_z_spins(n: usize) -> Vec<Vector3<f64>> {
vec![Vector3::new(0.0, 0.0, 1.0); n]
}
// Splitting and gathering must round-trip the spin array unchanged.
#[test]
fn test_decomposition_total_size() {
let spins = uniform_z_spins(100);
let decomp = DomainDecomposition::new(&spins, 4, 1).expect("decomposition should succeed");
assert_eq!(decomp.total_spins(), 100);
assert_eq!(decomp.num_domains, 4);
let gathered = decomp.gather();
assert_eq!(gathered.len(), 100);
for (a, b) in gathered.iter().zip(spins.iter()) {
assert!((a.x - b.x).abs() < 1e-15);
assert!((a.y - b.y).abs() < 1e-15);
assert!((a.z - b.z).abs() < 1e-15);
}
}
// 103 = 4 * 25 + 3, so the first three domains carry one extra spin.
#[test]
fn test_decomposition_uneven_split() {
let spins = uniform_z_spins(103);
let decomp = DomainDecomposition::new(&spins, 4, 1).expect("decomposition should succeed");
assert_eq!(decomp.total_spins(), 103);
assert_eq!(decomp.domains[0].spins.len(), 26);
assert_eq!(decomp.domains[1].spins.len(), 26);
assert_eq!(decomp.domains[2].spins.len(), 26);
assert_eq!(decomp.domains[3].spins.len(), 25);
}
// 20 spins over 4 domains gives 5 per domain; spin i has x == i, so the
// ghost values identify which global cell was mirrored.
#[test]
fn test_ghost_cells_width_1() {
let spins: Vec<Vector3<f64>> = (0..20).map(|i| Vector3::new(i as f64, 0.0, 0.0)).collect();
let decomp = DomainDecomposition::new(&spins, 4, 1).expect("decomposition should succeed");
assert!(decomp.domains[0].ghost_left.is_empty());
assert_eq!(decomp.domains[0].ghost_right.len(), 1);
assert!((decomp.domains[0].ghost_right[0].x - 5.0).abs() < 1e-15);
assert_eq!(decomp.domains[1].ghost_left.len(), 1);
assert!((decomp.domains[1].ghost_left[0].x - 4.0).abs() < 1e-15);
let last = decomp.domains.last().expect("should have domains");
assert!(last.ghost_right.is_empty());
assert!(!last.ghost_left.is_empty());
}
// After modifying domain 1 (global cells 5..=9), its neighbours' ghosts must
// pick up the new edge values once `update_ghost_cells` runs.
#[test]
fn test_ghost_cell_update() {
let spins: Vec<Vector3<f64>> = (0..20).map(|i| Vector3::new(i as f64, 0.0, 0.0)).collect();
let mut decomp =
DomainDecomposition::new(&spins, 4, 1).expect("decomposition should succeed");
for s in &mut decomp.domains[1].spins {
s.x += 100.0;
}
decomp.update_ghost_cells();
// Domain 0's right ghost mirrors the first cell of domain 1 (global cell 5).
let expected_x = 5.0 + 100.0; assert!(
(decomp.domains[0].ghost_right[0].x - expected_x).abs() < 1e-15,
"ghost_right[0] = {}, expected {}",
decomp.domains[0].ghost_right[0].x,
expected_x,
);
// Domain 2's left ghost mirrors the last cell of domain 1 (global cell 9).
let d1_last_x = 9.0 + 100.0; assert!(
(decomp.domains[2].ghost_left[0].x - d1_last_x).abs() < 1e-15,
"ghost_left[0] = {}, expected {}",
decomp.domains[2].ghost_left[0].x,
d1_last_x,
);
}
// One Euler step on 4 domains must agree with the same step on a single
// domain, which serves as the serial reference.
#[test]
fn test_parallel_llg_step_matches_serial() {
let n = 40;
let spins: Vec<Vector3<f64>> = (0..n)
.map(|i| {
let angle = 0.1 * (i as f64);
Vector3::new(angle.sin(), 0.0, angle.cos()).normalize()
})
.collect();
let h_ext = Vector3::new(0.0, 0.0, 1.0);
let alpha = 0.01;
let a_ex = 1e-11;
let cell_size = 1e-9;
let dt = 1e-14;
let mut serial_decomp =
DomainDecomposition::new(&spins, 1, 0).expect("single domain decomposition");
serial_decomp.parallel_llg_step(h_ext, alpha, a_ex, cell_size, dt);
let serial_result = serial_decomp.gather();
let mut par_decomp =
DomainDecomposition::new(&spins, 4, 1).expect("4-domain decomposition");
par_decomp.parallel_llg_step(h_ext, alpha, a_ex, cell_size, dt);
let par_result = par_decomp.gather();
assert_eq!(serial_result.len(), par_result.len());
let mut max_diff = 0.0_f64;
for (s, p) in serial_result.iter().zip(par_result.iter()) {
let diff = (*s - *p).magnitude();
if diff > max_diff {
max_diff = diff;
}
}
assert!(
max_diff < 1e-6,
"max difference between serial and parallel: {:.2e}",
max_diff,
);
}
// A 12x15 lattice on a 3x3 tile grid must keep every cell exactly once.
#[test]
fn test_2d_decomposition_total_size() {
let rows = 12;
let cols = 15;
let spins: Vec<Vector3<f64>> = (0..rows * cols)
.map(|_| Vector3::new(0.0, 0.0, 1.0))
.collect();
let decomp = DomainDecomposition2D::new(&spins, rows, cols, 3, 3)
.expect("2D decomposition should succeed");
assert_eq!(decomp.total_spins(), rows * cols);
let gathered = decomp.gather();
assert_eq!(gathered.len(), rows * cols);
}
// Modifying tile (0,0) must show up in the left ghost column of tile (0,1)
// after a ghost update.
#[test]
fn test_2d_ghost_update() {
let rows = 6;
let cols = 6;
let spins: Vec<Vector3<f64>> = (0..rows * cols)
.map(|i| Vector3::new(i as f64, 0.0, 0.0))
.collect();
let mut decomp = DomainDecomposition2D::new(&spins, rows, cols, 2, 2)
.expect("2D decomposition should succeed");
for s in &mut decomp.tiles[0].spins {
s.x += 1000.0;
}
decomp.update_ghost_cells();
// Tiles are stored row-major, so index 1 is grid position (0, 1).
let tile_01 = &decomp.tiles[1];
assert!(
!tile_01.ghost_left.is_empty(),
"tile (0,1) should have ghost_left"
);
for g in &tile_01.ghost_left {
assert!(
g.x >= 1000.0,
"ghost should reflect modified tile, got {}",
g.x,
);
}
}
// Zero domains is rejected as an invalid parameter.
#[test]
fn test_error_on_zero_domains() {
let spins = uniform_z_spins(10);
let result = DomainDecomposition::new(&spins, 0, 1);
assert!(result.is_err());
}
// An empty spin array is rejected as an invalid parameter.
#[test]
fn test_error_on_empty_spins() {
let spins: Vec<Vector3<f64>> = Vec::new();
let result = DomainDecomposition::new(&spins, 2, 1);
assert!(result.is_err());
}
// Every spin must still have unit length after one Heun step.
#[test]
fn test_heun_step_preserves_normalization() {
let n = 20;
let spins: Vec<Vector3<f64>> = (0..n)
.map(|i| {
let angle = 0.2 * (i as f64);
Vector3::new(angle.sin(), 0.0, angle.cos()).normalize()
})
.collect();
let mut decomp =
DomainDecomposition::new(&spins, 2, 1).expect("decomposition should succeed");
decomp.parallel_heun_step(Vector3::new(0.0, 0.0, 1.0), 0.01, 1e-11, 1e-9, 1e-14);
for domain in &decomp.domains {
for s in &domain.spins {
let mag = s.magnitude();
assert!(
(mag - 1.0).abs() < 1e-10,
"spin magnitude {} deviates from 1.0",
mag,
);
}
}
}
}