Skip to main content

oxiphysics_gpu/
gpu_fem_assembly.rs

// Copyright 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! GPU-accelerated FEM matrix assembly (CPU mock implementation).
//!
//! This module provides Finite Element Method (FEM) matrix assembly routines
//! that mirror a GPU implementation. All operations run on the CPU via plain
//! loops for portability.
//!
//! For simplicity the formulation uses 2-node bar/rod elements in 1-D, so
//! every element stiffness matrix is 2×2 and DOF management is
//! straightforward. The same patterns extend to 2-D and 3-D elements.

// ── Data structures ──────────────────────────────────────────────────────────

/// A 1-D FEM mesh of 2-node bar/rod elements together with its global system.
///
/// Each element connects two nodes, and every node carries exactly one degree
/// of freedom, so the global DOF count (`n_dofs`) equals the number of nodes.
///
/// Expected field lengths:
/// - per node (`n_dofs`): `node_coords`, `displacements`, `ext_forces`,
///   `residual`, `dirichlet_flags`;
/// - per element (`elements.len() / 2`): `youngs_modulus`, `area`;
/// - `k_global`: `n_dofs * n_dofs` entries, row-major.
///
/// The kernels in this module index these vectors directly, so inconsistent
/// lengths cause an index-out-of-bounds panic.
// NOTE(review): the reason for this allowance is not visible here; presumably
// the mock is not referenced in every build configuration. Confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct GpuFemMesh {
    /// Node coordinates along the bar axis (one scalar per node).
    pub node_coords: Vec<f64>,
    /// Element connectivity as flattened node-index pairs:
    /// `[node_a0, node_b0, node_a1, node_b1, …]`.
    pub elements: Vec<usize>,
    /// Young's modulus `E` for each element.
    pub youngs_modulus: Vec<f64>,
    /// Cross-sectional area `A` for each element.
    pub area: Vec<f64>,
    /// Global stiffness matrix `K` (`n_dofs × n_dofs`, row-major).
    pub k_global: Vec<f64>,
    /// Global displacement vector `u` (length `n_dofs`).
    pub displacements: Vec<f64>,
    /// Global external force vector `f` (length `n_dofs`).
    pub ext_forces: Vec<f64>,
    /// Residual vector `r = f − K·u` (length `n_dofs`).
    pub residual: Vec<f64>,
    /// Dirichlet (fixed) DOF flags: `true` means the DOF is constrained.
    pub dirichlet_flags: Vec<bool>,
}
42
43impl GpuFemMesh {
44    /// Create a new `GpuFemMesh` from node coordinates and element connectivity.
45    ///
46    /// `elements` must be a flat list of node-index pairs `[a0, b0, a1, b1, …]`.
47    /// All material parameters default to 1.0.
48    pub fn new(node_coords: Vec<f64>, elements: Vec<usize>) -> Self {
49        let n_nodes = node_coords.len();
50        let n_elems = elements.len() / 2;
51        Self {
52            node_coords,
53            elements,
54            youngs_modulus: vec![1.0; n_elems],
55            area: vec![1.0; n_elems],
56            k_global: vec![0.0; n_nodes * n_nodes],
57            displacements: vec![0.0; n_nodes],
58            ext_forces: vec![0.0; n_nodes],
59            residual: vec![0.0; n_nodes],
60            dirichlet_flags: vec![false; n_nodes],
61        }
62    }
63
64    /// Number of nodes (= DOFs for 1-D bar formulation).
65    pub fn n_dofs(&self) -> usize {
66        self.node_coords.len()
67    }
68
69    /// Number of elements.
70    pub fn n_elements(&self) -> usize {
71        self.elements.len() / 2
72    }
73}

// ── GPU kernel mocks ─────────────────────────────────────────────────────────

77/// Compute the 2×2 element stiffness matrix for bar/rod element `e`.
78///
79/// For a 1-D bar element: k_e = (E·A / L) · \[\[1, -1\\], \[-1, 1\]]
80///
81/// Returns `[k00, k01, k10, k11]` in row-major order.
82pub fn gpu_element_stiffness(mesh: &GpuFemMesh, e: usize) -> [f64; 4] {
83    let na = mesh.elements[e * 2];
84    let nb = mesh.elements[e * 2 + 1];
85    let xa = mesh.node_coords[na];
86    let xb = mesh.node_coords[nb];
87    let length = (xb - xa).abs();
88    if length < 1e-15 {
89        return [0.0; 4];
90    }
91    let ke = mesh.youngs_modulus[e] * mesh.area[e] / length;
92    [ke, -ke, -ke, ke]
93}
94
95/// Parallel element stiffness computation — returns all element matrices.
96///
97/// Returns a `Vec` of `[k00, k01, k10, k11]` arrays, one per element.
98pub fn gpu_assemble_global(mesh: &mut GpuFemMesh) {
99    let n_dofs = mesh.n_dofs();
100    mesh.k_global = vec![0.0; n_dofs * n_dofs];
101    let n_elem = mesh.n_elements();
102    for e in 0..n_elem {
103        let ke = gpu_element_stiffness(mesh, e);
104        let na = mesh.elements[e * 2];
105        let nb = mesh.elements[e * 2 + 1];
106        // scatter ke into global K
107        mesh.k_global[na * n_dofs + na] += ke[0];
108        mesh.k_global[na * n_dofs + nb] += ke[1];
109        mesh.k_global[nb * n_dofs + na] += ke[2];
110        mesh.k_global[nb * n_dofs + nb] += ke[3];
111    }
112}
113
114/// Apply Dirichlet boundary conditions by zeroing constrained DOF rows/cols.
115///
116/// For each constrained DOF `i`, sets:
117/// - Row `i` of `K` to zero except `K[i,i] = 1`
118/// - Column `i` of `K` to zero
119/// - `f[i] = 0`
120pub fn gpu_apply_dirichlet(mesh: &mut GpuFemMesh) {
121    let n = mesh.n_dofs();
122    for i in 0..n {
123        if mesh.dirichlet_flags[i] {
124            // zero row i
125            for j in 0..n {
126                mesh.k_global[i * n + j] = 0.0;
127            }
128            // zero column i
129            for j in 0..n {
130                mesh.k_global[j * n + i] = 0.0;
131            }
132            // set diagonal to 1
133            mesh.k_global[i * n + i] = 1.0;
134            // zero rhs
135            mesh.ext_forces[i] = 0.0;
136        }
137    }
138}
139
140/// Compute the residual vector `r = f − K·u` in parallel.
141///
142/// Updates `mesh.residual`.
143pub fn gpu_residual(mesh: &mut GpuFemMesh) {
144    let n = mesh.n_dofs();
145    for i in 0..n {
146        let mut ku_i = 0.0f64;
147        for j in 0..n {
148            ku_i += mesh.k_global[i * n + j] * mesh.displacements[j];
149        }
150        mesh.residual[i] = mesh.ext_forces[i] - ku_i;
151    }
152}
153
/// Dot product `a · b` (sequential stand-in for a GPU parallel reduction).
///
/// Both slices must have the same length. The precondition is checked with
/// `debug_assert_eq!`. In release builds a mismatch is not detected and only
/// the common prefix is summed.
///
/// The result for two empty slices is always `+0.0`.
///
/// # Panics
///
/// In debug builds, panics if `a.len() != b.len()`.
pub fn gpu_dot_product(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len(), "gpu_dot_product: slice lengths differ");
    // Explicit +0.0 seed so the empty sum does not depend on the toolchain's
    // `Sum` neutral element.
    a.iter().zip(b).fold(0.0, |acc, (ai, bi)| acc + ai * bi)
}
160
161/// Return the element stiffness matrices for all elements (parallel mock).
162///
163/// Convenience wrapper that collects [`gpu_element_stiffness`] for every
164/// element.
165#[allow(dead_code)]
166pub fn gpu_all_element_stiffness(mesh: &GpuFemMesh) -> Vec<[f64; 4]> {
167    (0..mesh.n_elements())
168        .map(|e| gpu_element_stiffness(mesh, e))
169        .collect()
170}

// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a simple 3-node bar mesh: nodes at 0.0, 1.0, 2.0 with 2 elements.
    fn make_bar_mesh() -> GpuFemMesh {
        let coords = vec![0.0, 1.0, 2.0];
        let elems = vec![0, 1, 1, 2];
        GpuFemMesh::new(coords, elems)
    }

    #[test]
    fn test_new_mesh_n_dofs() {
        let m = make_bar_mesh();
        assert_eq!(m.n_dofs(), 3);
    }

    #[test]
    fn test_new_mesh_n_elements() {
        let m = make_bar_mesh();
        assert_eq!(m.n_elements(), 2);
    }

    #[test]
    fn test_new_mesh_default_youngs() {
        let m = make_bar_mesh();
        assert!((m.youngs_modulus[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_element_stiffness_unit_bar() {
        // E=1, A=1, L=1 → ke = [[1,-1],[-1,1]]
        let m = make_bar_mesh();
        let ke = gpu_element_stiffness(&m, 0);
        assert!((ke[0] - 1.0).abs() < 1e-12);
        assert!((ke[1] + 1.0).abs() < 1e-12);
        assert!((ke[2] + 1.0).abs() < 1e-12);
        assert!((ke[3] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_element_stiffness_scaled() {
        // E=2, A=3, L=1 → ke_diag = 6
        let mut m = make_bar_mesh();
        m.youngs_modulus[0] = 2.0;
        m.area[0] = 3.0;
        let ke = gpu_element_stiffness(&m, 0);
        assert!((ke[0] - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_element_stiffness_zero_length() {
        let coords = vec![0.0, 0.0];
        let elems = vec![0, 1];
        let m = GpuFemMesh::new(coords, elems);
        let ke = gpu_element_stiffness(&m, 0);
        assert_eq!(ke, [0.0; 4]);
    }

    #[test]
    fn test_assemble_global_dimensions() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        assert_eq!(m.k_global.len(), 9); // 3×3
    }

    #[test]
    fn test_assemble_global_values() {
        // Two unit bars in series: K = [[1,-1,0],[-1,2,-1],[0,-1,1]].
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        let expected = [1.0, -1.0, 0.0, -1.0, 2.0, -1.0, 0.0, -1.0, 1.0];
        assert_eq!(m.k_global.len(), expected.len());
        for (idx, (&got, &want)) in m.k_global.iter().zip(expected.iter()).enumerate() {
            assert!((got - want).abs() < 1e-12, "K[{idx}] = {got}, expected {want}");
        }
    }

    #[test]
    fn test_assemble_global_is_idempotent() {
        // Re-assembling must reset K rather than accumulate on top of it.
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        let first = m.k_global.clone();
        gpu_assemble_global(&mut m);
        assert_eq!(m.k_global, first);
    }

    #[test]
    fn test_assemble_global_diagonal_positive() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        let n = m.n_dofs();
        for i in 0..n {
            assert!(m.k_global[i * n + i] >= 0.0);
        }
    }

    #[test]
    fn test_assemble_global_symmetric() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        let n = m.n_dofs();
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (m.k_global[i * n + j] - m.k_global[j * n + i]).abs() < 1e-12,
                    "K[{i},{j}] != K[{j},{i}]"
                );
            }
        }
    }

    #[test]
    fn test_assemble_global_row_sum_zero() {
        // For a free structure (no BCs) each row should sum to ~0
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        let n = m.n_dofs();
        for i in 0..n {
            let row_sum: f64 = (0..n).map(|j| m.k_global[i * n + j]).sum();
            assert!(row_sum.abs() < 1e-10, "row {i} sum = {row_sum}");
        }
    }

    #[test]
    fn test_apply_dirichlet_zeroes_row() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        m.dirichlet_flags[0] = true;
        gpu_apply_dirichlet(&mut m);
        let n = m.n_dofs();
        // Off-diagonal entries of row 0 should be zero
        for j in 1..n {
            assert!((m.k_global[j]).abs() < 1e-12);
        }
        // Diagonal should be exactly the unit placeholder. (The previous form,
        // `|K00| - 1 < eps`, also passed for K00 = 0.)
        assert!((m.k_global[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_apply_dirichlet_zeroes_column() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        m.dirichlet_flags[0] = true;
        gpu_apply_dirichlet(&mut m);
        let n = m.n_dofs();
        for i in 1..n {
            assert!((m.k_global[i * n]).abs() < 1e-12);
        }
    }

    #[test]
    fn test_apply_dirichlet_zeroes_rhs() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        m.ext_forces[0] = 99.0;
        m.dirichlet_flags[0] = true;
        gpu_apply_dirichlet(&mut m);
        assert!((m.ext_forces[0]).abs() < 1e-12);
    }

    #[test]
    fn test_residual_zero_displacement() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        m.ext_forces[2] = 1.0;
        // u = 0 → r = f
        gpu_residual(&mut m);
        assert!((m.residual[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_residual_equilibrium() {
        // If K*u = f exactly, residual should be zero
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        m.dirichlet_flags[0] = true;
        gpu_apply_dirichlet(&mut m);
        m.ext_forces[2] = 1.0;
        // Two-element unit bar, node 0 fixed, unit load at node 2.
        // K after BCs = [[1,0,0],[0,2,-1],[0,-1,1]].
        // Exact solution: u = [0, 1, 2] (displacement = x * force for a unit bar).
        m.displacements = vec![0.0, 1.0, 2.0];
        gpu_residual(&mut m);
        // K·u = [0, 0, 1] = f, so every residual entry vanishes.
        for (i, &r) in m.residual.iter().enumerate() {
            assert!(r.abs() < 1e-12, "residual[{i}] = {r}");
        }
    }

    #[test]
    fn test_gpu_dot_product_basic() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        assert!((gpu_dot_product(&a, &b) - 32.0).abs() < 1e-12);
    }

    #[test]
    fn test_gpu_dot_product_empty() {
        assert!((gpu_dot_product(&[], &[])).abs() < 1e-12);
    }

    #[test]
    fn test_gpu_dot_product_unit_vectors() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert!((gpu_dot_product(&a, &b)).abs() < 1e-12);
    }

    #[test]
    fn test_gpu_all_element_stiffness_count() {
        let m = make_bar_mesh();
        let all_ke = gpu_all_element_stiffness(&m);
        assert_eq!(all_ke.len(), m.n_elements());
    }

    #[test]
    fn test_gpu_all_element_stiffness_values() {
        let m = make_bar_mesh();
        let all_ke = gpu_all_element_stiffness(&m);
        // Both elements identical → same stiffness matrix
        assert_eq!(all_ke[0], all_ke[1]);
    }

    #[test]
    fn test_fem_mesh_clone() {
        let m = make_bar_mesh();
        let m2 = m.clone();
        assert_eq!(m2.n_dofs(), 3);
    }

    #[test]
    fn test_fem_mesh_debug() {
        let m = make_bar_mesh();
        let s = format!("{m:?}");
        assert!(s.contains("GpuFemMesh"));
    }

    #[test]
    fn test_assemble_then_apply_dirichlet_both_ends() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        m.dirichlet_flags[0] = true;
        m.dirichlet_flags[2] = true;
        gpu_apply_dirichlet(&mut m);
        let n = m.n_dofs();
        // Diagonals of constrained nodes should be 1
        assert!((m.k_global[0] - 1.0).abs() < 1e-12);
        assert!((m.k_global[2 * n + 2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_residual_updates_all_entries() {
        let mut m = make_bar_mesh();
        gpu_assemble_global(&mut m);
        m.ext_forces = vec![1.0, 0.0, -1.0];
        gpu_residual(&mut m);
        assert_eq!(m.residual.len(), m.n_dofs());
        // u = 0, so r must equal f entry by entry.
        assert_eq!(m.residual, [1.0, 0.0, -1.0]);
    }
}