use super::types::{GpuPdeConfig, PdeSolverError, PdeSolverResult, SolverStats};
/// Two-dimensional triangular finite-element mesh.
///
/// The mesh is stored as a flat list of node coordinates plus a list of
/// triangles that refer to those nodes by index.
#[derive(Debug, Clone)]
pub struct FemMesh {
/// Node coordinates, one `[x, y]` pair per node.
pub nodes: Vec<[f64; 2]>,
/// Triangular elements, each holding three indices into `nodes`.
pub elements: Vec<[usize; 3]>,
}
impl FemMesh {
/// Returns the number of nodes stored in the mesh.
pub fn num_nodes(&self) -> usize {
self.nodes.len()
}
/// Returns the number of triangular elements stored in the mesh.
pub fn num_elements(&self) -> usize {
self.elements.len()
}
}
/// Builds a structured triangular mesh of the rectangle spanned by `lx` × `ly`.
///
/// The rectangle is divided into `nx` × `ny` cells; every cell is split into
/// two triangles along the diagonal joining its lower-right and upper-left
/// corners. Nodes are stored row by row (x varies fastest), so node `(i, j)`
/// has index `j * (nx + 1) + i`.
///
/// # Errors
///
/// Returns [`PdeSolverError::InvalidGrid`] when `nx` or `ny` is zero, or when
/// `lx` or `ly` is zero or not finite. Such extents would collapse every
/// triangle to zero area (or produce NaN coordinates) and make the element
/// stiffness computation divide by zero.
pub fn uniform_rect_mesh(nx: usize, ny: usize, lx: f64, ly: f64) -> PdeSolverResult<FemMesh> {
    if nx < 1 || ny < 1 {
        return Err(PdeSolverError::InvalidGrid);
    }
    // Reject extents that cannot produce triangles with a usable area.
    let usable = |extent: f64| extent.is_finite() && extent != 0.0;
    if !usable(lx) || !usable(ly) {
        return Err(PdeSolverError::InvalidGrid);
    }

    let dx = lx / nx as f64;
    let dy = ly / ny as f64;
    // Number of nodes per mesh row; also the index distance between a node
    // and the node directly above it.
    let stride = nx + 1;

    let mut nodes = Vec::with_capacity(stride * (ny + 1));
    for j in 0..=ny {
        let y = j as f64 * dy;
        nodes.extend((0..=nx).map(|i| [i as f64 * dx, y]));
    }

    let mut elements = Vec::with_capacity(2 * nx * ny);
    for j in 0..ny {
        for i in 0..nx {
            let n00 = j * stride + i;
            let n10 = n00 + 1;
            let n01 = n00 + stride;
            let n11 = n01 + 1;
            // Lower-left triangle, then upper-right triangle of the cell.
            elements.push([n00, n10, n01]);
            elements.push([n10, n11, n01]);
        }
    }
    Ok(FemMesh { nodes, elements })
}
/// Computes the 3×3 stiffness matrix of one linear triangle and its area.
///
/// `el` holds the three node indices of the triangle, `nodes` the coordinate
/// table they index into, and `kappa` the scalar coefficient multiplying the
/// gradient product. The matrix is returned flattened in row-major order and
/// wrapped in a one-element array, together with the (unsigned) triangle area.
///
/// Panics if any index in `el` is out of range for `nodes`. A triangle whose
/// vertices are collinear has a zero determinant, so the divisions below
/// yield non-finite entries.
fn element_stiffness(nodes: &[[f64; 2]], el: &[usize; 3], kappa: f64) -> ([[f64; 9]; 1], f64) {
    let [p0, p1, p2] = el.map(|idx| nodes[idx]);

    // Twice the signed area of the triangle.
    let det = (p1[0] - p0[0]) * (p2[1] - p0[1]) - (p2[0] - p0[0]) * (p1[1] - p0[1]);
    let area = 0.5 * det.abs();

    // x- and y-derivatives of the three linear shape functions.
    let grad_x = [p1[1] - p2[1], p2[1] - p0[1], p0[1] - p1[1]].map(|num| num / det);
    let grad_y = [p2[0] - p1[0], p0[0] - p2[0], p1[0] - p0[0]].map(|num| num / det);

    let local: [f64; 9] = std::array::from_fn(|flat| {
        let (r, c) = (flat / 3, flat % 3);
        kappa * area * (grad_x[r] * grad_x[c] + grad_y[r] * grad_y[c])
    });
    ([local], area)
}
/// Assembles the global stiffness matrix of `mesh` as a dense matrix.
///
/// The result has `n * n` entries in row-major order, where `n` is the number
/// of mesh nodes. Each element contributes its 3×3 local matrix (scaled by
/// `diffusivity`) to the rows and columns of its three nodes. Despite the
/// name, the assembly shown here is a sequential loop over the elements.
///
/// # Errors
///
/// Returns [`PdeSolverError::InvalidGrid`] when
/// - the mesh has no nodes or no elements,
/// - an element refers to a node index outside the node table, or
/// - an element is degenerate (zero or non-finite area), which would
///   otherwise inject NaN/inf entries into the matrix.
pub fn assemble_stiffness_gpu(mesh: &FemMesh, diffusivity: f64) -> PdeSolverResult<Vec<f64>> {
    let n = mesh.num_nodes();
    if n == 0 || mesh.num_elements() == 0 {
        return Err(PdeSolverError::InvalidGrid);
    }
    let mut k_global = vec![0.0_f64; n * n];
    for el in &mesh.elements {
        // Validate connectivity up front so a malformed mesh is reported as
        // an error instead of panicking on an out-of-bounds index.
        if el.iter().any(|&node| node >= n) {
            return Err(PdeSolverError::InvalidGrid);
        }
        let ([ke], area) = element_stiffness(&mesh.nodes, el, diffusivity);
        // Collinear or non-finite vertices give a zero/NaN area and a local
        // matrix full of NaN/inf; refuse to assemble those.
        if !area.is_finite() || area <= 0.0 {
            return Err(PdeSolverError::InvalidGrid);
        }
        for (a, &row) in el.iter().enumerate() {
            for (b, &col) in el.iter().enumerate() {
                k_global[row * n + col] += ke[a * 3 + b];
            }
        }
    }
    Ok(k_global)
}
/// Imposes the Dirichlet condition `u[dof] = value` on the dense system `k u = f`.
///
/// `k` is the row-major `n × n` matrix and `f` the right-hand side. The known
/// value is moved to the right-hand side of every other equation, then row
/// and column `dof` are zeroed, the diagonal entry is set to one and
/// `f[dof]` is set to `value`. Zeroing the column as well as the row keeps a
/// symmetric matrix symmetric.
///
/// # Errors
///
/// Returns [`PdeSolverError::BoundaryMismatch`] when `dof >= n`, or when `k`
/// holds fewer than `n * n` entries or `f` fewer than `n`. The size check
/// runs before any mutation, so the system is never left half-modified.
pub fn apply_dirichlet_gpu(
    k: &mut Vec<f64>,
    f: &mut Vec<f64>,
    n: usize,
    dof: usize,
    value: f64,
) -> PdeSolverResult<()> {
    if dof >= n {
        return Err(PdeSolverError::BoundaryMismatch);
    }
    let k_too_short = n.checked_mul(n).map_or(true, |needed| k.len() < needed);
    if k_too_short || f.len() < n {
        return Err(PdeSolverError::BoundaryMismatch);
    }

    for row in 0..n {
        let coupling = row * n + dof;
        if row != dof {
            // Move the contribution of the known value to the right-hand side.
            f[row] -= k[coupling] * value;
        }
        // Clear the column entry once it has been consumed.
        k[coupling] = 0.0;
    }
    // Clear the constrained row and replace it with the identity equation.
    k[dof * n..(dof + 1) * n].fill(0.0);
    k[dof * n + dof] = 1.0;
    f[dof] = value;
    Ok(())
}
/// Computes `y = k * x` for a dense row-major `n × n` matrix using scoped threads.
///
/// The output rows are partitioned into tiles of `tile_size` rows (a value of
/// zero is treated as one) and one scoped thread is spawned per tile. Each
/// row is accumulated left to right starting from `0.0`, so the result is
/// identical to a sequential row-by-row product.
///
/// # Panics
///
/// Panics if `k` has fewer than `n * n` entries or `x` fewer than `n`.
fn matvec_parallel(k: &[f64], x: &[f64], n: usize, tile_size: usize) -> Vec<f64> {
    let mut y = vec![0.0_f64; n];
    if n == 0 {
        return y;
    }
    // A tile larger than the matrix behaves like a single tile; clamping also
    // keeps `rows_per_tile * n` from overflowing for huge `tile_size` values.
    let rows_per_tile = tile_size.clamp(1, n);
    // Slice once so undersized inputs fail here rather than inside a worker.
    let k = &k[..n * n];
    let x = &x[..n];

    std::thread::scope(|s| {
        let tiles = y
            .chunks_mut(rows_per_tile)
            .zip(k.chunks(rows_per_tile * n));
        for (y_tile, k_tile) in tiles {
            s.spawn(move || {
                for (yi, k_row) in y_tile.iter_mut().zip(k_tile.chunks_exact(n)) {
                    *yi = k_row
                        .iter()
                        .zip(x)
                        .fold(0.0_f64, |acc, (kij, xj)| acc + kij * xj);
                }
            });
        }
    });
    y
}
/// Computes `y = k * x` sequentially for a dense row-major `n × n` matrix.
///
/// Each output entry is accumulated left to right starting from `0.0`.
/// Panics if `k` or `x` is too short for the requested dimension.
fn matvec_seq(k: &[f64], x: &[f64], n: usize) -> Vec<f64> {
    (0..n)
        .map(|row| {
            let base = row * n;
            (0..n).fold(0.0_f64, |acc, col| acc + k[base + col] * x[col])
        })
        .collect()
}
/// Returns the dot product of `a` and `b`.
///
/// Pairs are formed up to the length of the shorter slice; any surplus
/// entries of the longer one are ignored.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    std::iter::zip(a, b).map(|(lhs, rhs)| lhs * rhs).sum::<f64>()
}
/// Returns the infinity norm (largest absolute entry) of `v`.
///
/// An empty slice has norm `0.0`. If any entry is NaN the result is NaN:
/// `f64::max` silently discards NaN operands, which would make a vector of
/// NaNs look like it has norm zero and pass a `norm < tol` convergence test.
fn linf_norm(v: &[f64]) -> f64 {
    let mut norm = 0.0_f64;
    for &entry in v {
        if entry.is_nan() {
            return f64::NAN;
        }
        norm = norm.max(entry.abs());
    }
    norm
}
/// Solves the dense system `k x = b` with Jacobi-preconditioned conjugate gradients.
///
/// `k` is a row-major `n × n` matrix (expected to be symmetric positive
/// definite for the method to converge) and `b` the right-hand side of length
/// `n`. The iteration starts from `x = 0` and stops as soon as the infinity
/// norm of the residual `b - k x` drops below `tol`; the same criterion is
/// used for the initial guess and for every later iterate. Despite the name,
/// the matrix-vector products here go through the sequential `matvec_seq`.
///
/// On success returns the solution together with the statistics produced by
/// `SolverStats::converged(iterations, residual_norm)`.
///
/// # Errors
///
/// - [`PdeSolverError::InvalidGrid`] when `b.len() != n` or `k` holds fewer
///   than `n * n` entries.
/// - [`PdeSolverError::SingularSystem`] when a diagonal entry is (nearly) zero
///   or non-finite, or when the iteration breaks down: `pᵀ k p` is (nearly)
///   zero, or a NaN/inf appears in the data or the recurrences.
/// - [`PdeSolverError::NotConverged`] when `max_iter` iterations do not reach
///   the tolerance.
pub fn conjugate_gradient_gpu(
    k: &[f64],
    b: &[f64],
    n: usize,
    max_iter: usize,
    tol: f64,
) -> PdeSolverResult<(Vec<f64>, SolverStats)> {
    // Mismatched sizes would otherwise panic on indexing or leave stray
    // residual entries that can never be reduced.
    let k_too_short = n.checked_mul(n).map_or(true, |needed| k.len() < needed);
    if b.len() != n || k_too_short {
        return Err(PdeSolverError::InvalidGrid);
    }

    // Jacobi preconditioner: the diagonal of `k`.
    let mut diag = Vec::with_capacity(n);
    for i in 0..n {
        let d = k[i * n + i];
        if !d.is_finite() || d.abs() < f64::EPSILON * 1e6 {
            return Err(PdeSolverError::SingularSystem);
        }
        diag.push(d);
    }

    let mut x = vec![0.0_f64; n];
    let mut r: Vec<f64> = b.to_vec();
    let mut z: Vec<f64> = r.iter().zip(&diag).map(|(ri, di)| ri / di).collect();
    let mut p = z.clone();
    let mut rz = dot(&r, &z);
    // A NaN/inf right-hand side shows up here; comparisons against NaN are
    // always false, so it has to be tested for explicitly.
    if !rz.is_finite() {
        return Err(PdeSolverError::SingularSystem);
    }
    let initial_res = linf_norm(&r);
    if initial_res < tol {
        return Ok((x, SolverStats::converged(0, initial_res)));
    }

    for iter in 0..max_iter {
        let ap = matvec_seq(k, &p, n);
        let pap = dot(&p, &ap);
        // Non-finite curvature means NaN/inf reached the search direction or
        // the matrix; continuing would only spread it into `x`.
        if !pap.is_finite() || pap.abs() < f64::EPSILON {
            return Err(PdeSolverError::SingularSystem);
        }
        let alpha = rz / pap;
        for ((xi, ri), (pi, api)) in x.iter_mut().zip(r.iter_mut()).zip(p.iter().zip(&ap)) {
            *xi += alpha * *pi;
            *ri -= alpha * *api;
        }

        let res_norm = linf_norm(&r);
        if res_norm < tol {
            return Ok((x, SolverStats::converged(iter + 1, res_norm)));
        }

        for ((zi, ri), di) in z.iter_mut().zip(&r).zip(&diag) {
            *zi = *ri / *di;
        }
        let rz_new = dot(&r, &z);
        if !rz_new.is_finite() {
            return Err(PdeSolverError::SingularSystem);
        }
        let beta = rz_new / rz;
        rz = rz_new;
        for (pi, zi) in p.iter_mut().zip(&z) {
            *pi = *zi + beta * *pi;
        }
    }
    Err(PdeSolverError::NotConverged { iterations: max_iter })
}
/// Solves a Poisson problem on `mesh` with linear triangular elements.
///
/// `source` supplies one source value per mesh node and `bc_nodes` lists
/// `(node index, prescribed value)` Dirichlet conditions. The stiffness matrix
/// is assembled with unit diffusivity; the load vector gives each vertex of a
/// triangle one third of the triangle area times the source value at that
/// vertex. After the Dirichlet conditions are applied, the system is solved
/// with `conjugate_gradient_gpu` using `config.max_iterations` and
/// `config.tolerance`.
///
/// # Errors
///
/// Returns [`PdeSolverError::InvalidGrid`] when `source` does not have one
/// entry per node, and forwards any error from assembly, boundary-condition
/// application, or the linear solve.
pub fn solve_fem_poisson(
    mesh: &FemMesh,
    source: &[f64],
    bc_nodes: &[(usize, f64)],
    config: &GpuPdeConfig,
) -> PdeSolverResult<Vec<f64>> {
    let dof_count = mesh.num_nodes();
    if source.len() != dof_count {
        return Err(PdeSolverError::InvalidGrid);
    }

    let mut stiffness = assemble_stiffness_gpu(mesh, 1.0)?;

    // Load vector: every vertex receives a third of its triangle's area.
    let mut load = vec![0.0_f64; dof_count];
    for tri in mesh.elements.iter() {
        let (_, tri_area) = element_stiffness(&mesh.nodes, tri, 1.0);
        let share = tri_area / 3.0;
        tri.iter()
            .for_each(|&vertex| load[vertex] += share * source[vertex]);
    }

    bc_nodes.iter().try_for_each(|&(dof, val)| {
        apply_dirichlet_gpu(&mut stiffness, &mut load, dof_count, dof, val)
    })?;

    conjugate_gradient_gpu(
        &stiffness,
        &load,
        dof_count,
        config.max_iterations,
        config.tolerance,
    )
    .map(|(solution, _stats)| solution)
}
// Unit tests for mesh generation, stiffness assembly, the CG solver and the
// end-to-end Poisson solve.
#[cfg(test)]
mod tests {
use super::*;
// A 2x2 cell grid must produce 8 triangles (two per cell) and 9 nodes.
#[test]
fn test_fem_uniform_mesh_triangle_count() {
let mesh = uniform_rect_mesh(2, 2, 1.0, 1.0).expect("mesh");
assert_eq!(mesh.num_elements(), 8, "expected 8 triangles");
assert_eq!(mesh.num_nodes(), 9, "expected 9 nodes");
}
// The assembled dense stiffness matrix must equal its transpose.
#[test]
fn test_fem_assemble_stiffness_symmetric() {
let mesh = uniform_rect_mesh(3, 3, 1.0, 1.0).expect("mesh");
let n = mesh.num_nodes();
let k = assemble_stiffness_gpu(&mesh, 1.0).expect("assemble");
for i in 0..n {
for j in 0..n {
let diff = (k[i * n + j] - k[j * n + i]).abs();
assert!(diff < 1e-12, "K not symmetric at ({i},{j}): diff={diff}");
}
}
}
// With the identity matrix the solution must equal the right-hand side.
#[test]
fn test_fem_cg_solver_identity() {
let n = 3usize;
// Row-major identity: entry is 1 where row index equals column index.
let k: Vec<f64> = (0..n * n)
.map(|idx| if idx / n == idx % n { 1.0 } else { 0.0 })
.collect();
let b = vec![1.0, 2.0, 3.0];
let (x, stats) = conjugate_gradient_gpu(&k, &b, n, 100, 1e-10).expect("cg");
assert!(stats.converged);
for i in 0..n {
assert!((x[i] - b[i]).abs() < 1e-8, "x[{i}]={} expected {}", x[i], b[i]);
}
}
// With a diagonal matrix each unknown is the right-hand side divided by
// the matching diagonal entry.
#[test]
fn test_fem_cg_solver_diag_matrix() {
let n = 3usize;
let mut k = vec![0.0_f64; n * n];
// Flat indices 0, 4 and 8 are the diagonal of a 3x3 row-major matrix.
k[0] = 2.0;
k[4] = 3.0;
k[8] = 4.0;
let b = vec![4.0, 9.0, 8.0];
let expected = vec![2.0, 3.0, 2.0];
let (x, _) = conjugate_gradient_gpu(&k, &b, n, 100, 1e-10).expect("cg");
for i in 0..n {
assert!(
(x[i] - expected[i]).abs() < 1e-8,
"x[{i}]={} expected {}",
x[i],
expected[i]
);
}
}
// Tridiagonal matrix (3 on the diagonal, -1 off it): the returned solution
// must reproduce the right-hand side when multiplied back.
#[test]
fn test_cg_symmetric_pd_system() {
let n = 4usize;
let mut k = vec![0.0_f64; n * n];
for i in 0..n {
k[i * n + i] = 3.0;
if i + 1 < n {
k[i * n + i + 1] = -1.0;
k[(i + 1) * n + i] = -1.0;
}
}
let b = vec![1.0, 0.0, 0.0, 1.0];
let (x, stats) =
conjugate_gradient_gpu(&k, &b, n, 200, 1e-10).expect("cg");
assert!(stats.converged, "CG did not converge");
let ax = matvec_seq(&k, &x, n);
for i in 0..n {
assert!((ax[i] - b[i]).abs() < 1e-8, "residual at {i}: {}", (ax[i] - b[i]).abs());
}
}
// Zero source with boundary values u = x on the unit square: the computed
// solution is compared against u = x at every node.
#[test]
fn test_fem_poisson_1d_analytic() {
let nx = 4usize;
let ny = 4usize;
let mesh = uniform_rect_mesh(nx, ny, 1.0, 1.0).expect("mesh");
let n = mesh.num_nodes();
let source = vec![0.0_f64; n];
let mut bc_nodes = Vec::new();
for (k, node) in mesh.nodes.iter().enumerate() {
let [x, y] = *node;
// A node is on the boundary when either coordinate is 0 or 1.
let on_boundary = x.abs() < 1e-12
|| (x - 1.0).abs() < 1e-12
|| y.abs() < 1e-12
|| (y - 1.0).abs() < 1e-12;
if on_boundary {
bc_nodes.push((k, x)); }
}
let config = GpuPdeConfig { max_iterations: 5000, tolerance: 1e-10, ..Default::default() };
let u = solve_fem_poisson(&mesh, &source, &bc_nodes, &config).expect("fem");
for (idx, node) in mesh.nodes.iter().enumerate() {
let [x, _y] = *node;
let err = (u[idx] - x).abs();
assert!(err < 0.05, "node {idx} x={x:.3} u={:.3} err={err:.4}", u[idx]);
}
}
// Zero source with the constant boundary value 1: every nodal value of the
// solution must be close to 1.
#[test]
fn test_fem_poisson_convergence() {
let mesh = uniform_rect_mesh(4, 4, 1.0, 1.0).expect("mesh");
let n = mesh.num_nodes();
let source = vec![0.0_f64; n];
let mut bc_nodes = Vec::new();
for (k, node) in mesh.nodes.iter().enumerate() {
let [x, y] = *node;
if x.abs() < 1e-12
|| (x - 1.0).abs() < 1e-12
|| y.abs() < 1e-12
|| (y - 1.0).abs() < 1e-12
{
bc_nodes.push((k, 1.0));
}
}
let config = GpuPdeConfig { max_iterations: 5000, tolerance: 1e-10, ..Default::default() };
let u = solve_fem_poisson(&mesh, &source, &bc_nodes, &config).expect("fem");
for (idx, &val) in u.iter().enumerate() {
assert!((val - 1.0).abs() < 0.05, "node {idx} val={val:.4} not close to 1");
}
}
}