Skip to main content

oxirs_physics/gpu/
stress_assembly.rs

1//! GPU-accelerated FEM stress assembly dispatcher.
2//!
//! When the `gpu` feature is enabled, [`StressAssemblyDispatcher`] attempts
//! to dispatch element-wise stiffness and consistent-mass matrix assembly to
//! the `scirs2_core` GPU backend. When disabled, every dispatch method returns
//! [`GpuError::BackendUnavailable`] for well-formed input. Malformed input is
//! still rejected with [`GpuError::InvalidInput`], because validation runs
//! before the feature gate. Callers can therefore drop in a CPU fallback
//! without `cfg` plumbing at the call site.
8//!
//! The dispatcher is intentionally a thin scaffold. For now, the "GPU side"
//! computes the same host-side values the CPU path would, so results are
//! bit-identical between the two paths. This follows the SAMM W3-S12 pattern,
//! in which the dispatcher lands ahead of a full kernel implementation.
13
14use super::{GpuError, GpuResult};
15
/// Element type recognised by the stress-assembly dispatcher.
///
/// Each variant documents its node count and DoFs per node. [`Self::nodes`]
/// and [`Self::dofs_per_node`] expose those numbers, and [`Self::dofs`] is
/// their product.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FemElementKind {
    /// 1D bar element — 2 nodes, 1 DoF/node.
    Bar1D,
    /// 1D Euler-Bernoulli beam — 2 nodes, 2 DoF/node.
    Beam1D,
    /// 2D constant-strain triangle — 3 nodes, 2 DoF/node.
    Triangle2D,
    /// 2D bilinear quadrilateral — 4 nodes, 2 DoF/node.
    Quad2D,
}

impl FemElementKind {
    /// Number of nodes in an element of this kind.
    #[inline]
    pub fn nodes(self) -> usize {
        match self {
            Self::Bar1D | Self::Beam1D => 2,
            Self::Triangle2D => 3,
            Self::Quad2D => 4,
        }
    }

    /// Number of degrees of freedom carried by each node of this kind.
    #[inline]
    pub fn dofs_per_node(self) -> usize {
        match self {
            Self::Bar1D => 1,
            Self::Beam1D | Self::Triangle2D | Self::Quad2D => 2,
        }
    }

    /// Number of DoFs per element of this kind (`nodes() * dofs_per_node()`).
    #[inline]
    pub fn dofs(self) -> usize {
        // Derived from the per-node layout so the two can never drift apart.
        self.nodes() * self.dofs_per_node()
    }
}
41
42/// Description of a single FEM element used during GPU dispatch.
43///
44/// The dispatcher does not carry node geometry; it only carries the element
45/// kind and a per-element scaling factor (typically the product of Young's
46/// modulus and characteristic length used to scale the element stiffness).
47#[derive(Debug, Clone, PartialEq)]
48pub struct GpuElementDescriptor {
49    /// Element kind.
50    pub kind: FemElementKind,
51    /// Scaling factor applied to the element stiffness (Pa·m).
52    pub stiffness_scale: f64,
53    /// Scaling factor applied to the consistent mass matrix (kg).
54    pub mass_scale: f64,
55}
56
57impl Default for GpuElementDescriptor {
58    fn default() -> Self {
59        Self {
60            kind: FemElementKind::Bar1D,
61            stiffness_scale: 1.0,
62            mass_scale: 1.0,
63        }
64    }
65}
66
/// Per-element result of a stress-assembly dispatch.
///
/// A dispatch yields one of these for every input element.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuElementContribution {
    /// Position of the source element within the input slice.
    pub element_index: usize,
    /// Trace of the element stiffness matrix, exposed for sanity checks.
    pub stiffness_trace: f64,
    /// Trace of the element mass matrix.
    pub mass_trace: f64,
    /// How many DoFs this element contributes.
    pub dofs: usize,
}
79
80/// GPU-accelerated FEM stress / mass assembly dispatcher.
81#[derive(Debug, Default)]
82pub struct StressAssemblyDispatcher {
83    /// Whether the underlying GPU backend is ready to accept work.
84    backend_ready: bool,
85}
86
87impl StressAssemblyDispatcher {
88    /// Create a new dispatcher.
89    ///
90    /// With the `gpu` feature enabled the backend is reported ready; without
91    /// it the dispatcher always reports unavailable.
92    pub fn new() -> Self {
93        Self {
94            backend_ready: super::backend_available(),
95        }
96    }
97
98    /// Returns `true` when a usable GPU backend is available.
99    pub fn is_available(&self) -> bool {
100        self.backend_ready
101    }
102
103    /// Dispatch element stiffness assembly for `elements`.
104    ///
105    /// # Errors
106    ///
107    /// Returns [`GpuError::BackendUnavailable`] when compiled without the
108    /// `gpu` feature, or [`GpuError::InvalidInput`] when any element has a
109    /// non-finite scaling factor.
110    pub fn dispatch_stiffness_assembly(
111        &self,
112        elements: &[GpuElementDescriptor],
113    ) -> GpuResult<Vec<GpuElementContribution>> {
114        validate_elements(elements)?;
115        #[cfg(feature = "gpu")]
116        {
117            if !self.backend_ready {
118                return Err(GpuError::BackendUnavailable);
119            }
120            // Backend ready: produce the same per-element contributions the
121            // CPU path would. The trace of an element stiffness matrix is
122            // proportional to (stiffness_scale * dofs); we expose that scalar
123            // so callers can validate against the CPU reference assembly.
124            Ok(elements
125                .iter()
126                .enumerate()
127                .map(|(idx, e)| GpuElementContribution {
128                    element_index: idx,
129                    stiffness_trace: e.stiffness_scale * e.kind.dofs() as f64,
130                    mass_trace: e.mass_scale * e.kind.dofs() as f64,
131                    dofs: e.kind.dofs(),
132                })
133                .collect())
134        }
135        #[cfg(not(feature = "gpu"))]
136        {
137            Err(GpuError::BackendUnavailable)
138        }
139    }
140
141    /// Convenience wrapper around [`Self::dispatch_stiffness_assembly`] that
142    /// returns just the per-element DoF count (used to size the global
143    /// stiffness scatter-add).
144    ///
145    /// # Errors
146    ///
147    /// Returns [`GpuError::BackendUnavailable`] when the backend is not ready.
148    pub fn element_dof_layout(&self, elements: &[GpuElementDescriptor]) -> GpuResult<Vec<usize>> {
149        let contribs = self.dispatch_stiffness_assembly(elements)?;
150        Ok(contribs.into_iter().map(|c| c.dofs).collect())
151    }
152
153    /// Dispatch the consistent-mass matrix assembly for `elements`.
154    ///
155    /// # Errors
156    ///
157    /// Returns [`GpuError::BackendUnavailable`] when compiled without the
158    /// `gpu` feature.
159    pub fn dispatch_mass_assembly(&self, elements: &[GpuElementDescriptor]) -> GpuResult<Vec<f64>> {
160        let contribs = self.dispatch_stiffness_assembly(elements)?;
161        Ok(contribs.into_iter().map(|c| c.mass_trace).collect())
162    }
163}
164
165fn validate_elements(elements: &[GpuElementDescriptor]) -> GpuResult<()> {
166    for (idx, e) in elements.iter().enumerate() {
167        if !e.stiffness_scale.is_finite() {
168            return Err(GpuError::InvalidInput(format!(
169                "element {idx}: stiffness_scale is not finite"
170            )));
171        }
172        if !e.mass_scale.is_finite() {
173            return Err(GpuError::InvalidInput(format!(
174                "element {idx}: mass_scale is not finite"
175            )));
176        }
177    }
178    Ok(())
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Three well-formed elements covering a 1D kind and both 2D kinds.
    fn sample_elements() -> Vec<GpuElementDescriptor> {
        vec![
            GpuElementDescriptor {
                kind: FemElementKind::Bar1D,
                stiffness_scale: 1.0e6,
                mass_scale: 2.0,
            },
            GpuElementDescriptor {
                kind: FemElementKind::Triangle2D,
                stiffness_scale: 5.0e5,
                mass_scale: 1.0,
            },
            GpuElementDescriptor {
                kind: FemElementKind::Quad2D,
                stiffness_scale: 7.0e5,
                mass_scale: 0.5,
            },
        ]
    }

    /// Relative-tolerance float comparison for trace checks.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() <= 1e-12 * expected.abs().max(1.0)
    }

    #[test]
    fn dofs_match_element_kind() {
        assert_eq!(FemElementKind::Bar1D.dofs(), 2);
        assert_eq!(FemElementKind::Beam1D.dofs(), 4);
        assert_eq!(FemElementKind::Triangle2D.dofs(), 6);
        assert_eq!(FemElementKind::Quad2D.dofs(), 8);
    }

    #[test]
    fn dispatcher_availability_matches_feature() {
        let d = StressAssemblyDispatcher::new();
        #[cfg(feature = "gpu")]
        assert!(d.is_available());
        #[cfg(not(feature = "gpu"))]
        assert!(!d.is_available());
    }

    #[test]
    fn stiffness_assembly_respects_feature_gate() {
        let d = StressAssemblyDispatcher::new();
        let elements = sample_elements();
        let result = d.dispatch_stiffness_assembly(&elements);
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(result, Err(GpuError::BackendUnavailable)));
        #[cfg(feature = "gpu")]
        {
            let contribs = result.expect("dispatch should succeed under gpu feature");
            assert_eq!(contribs.len(), elements.len());
            assert_eq!(contribs[0].dofs, 2);
            assert_eq!(contribs[1].dofs, 6);
            assert_eq!(contribs[2].dofs, 8);
            // stiffness_trace = stiffness_scale * dofs, in input order.
            let expected = [2.0e6, 3.0e6, 5.6e6];
            for (idx, (c, want)) in contribs.iter().zip(expected).enumerate() {
                assert_eq!(c.element_index, idx);
                assert!(approx_eq(c.stiffness_trace, want));
            }
        }
    }

    #[test]
    fn mass_assembly_respects_feature_gate() {
        let d = StressAssemblyDispatcher::new();
        let elements = sample_elements();
        let result = d.dispatch_mass_assembly(&elements);
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(result, Err(GpuError::BackendUnavailable)));
        #[cfg(feature = "gpu")]
        {
            let traces = result.expect("dispatch should succeed under gpu feature");
            assert_eq!(traces.len(), elements.len());
            // mass_trace = mass_scale * dofs: 2*2, 1*6, 0.5*8.
            for (got, want) in traces.iter().zip([4.0, 6.0, 4.0]) {
                assert!(approx_eq(*got, want));
            }
        }
    }

    #[test]
    fn empty_batch_is_valid_input() {
        let d = StressAssemblyDispatcher::new();
        let result = d.dispatch_stiffness_assembly(&[]);
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(result, Err(GpuError::BackendUnavailable)));
        #[cfg(feature = "gpu")]
        assert!(matches!(result, Ok(ref v) if v.is_empty()));
    }

    #[test]
    fn invalid_input_caught_eagerly() {
        let d = StressAssemblyDispatcher::new();
        let bad = vec![GpuElementDescriptor {
            kind: FemElementKind::Bar1D,
            stiffness_scale: f64::NAN,
            mass_scale: 1.0,
        }];
        // Validation runs before the feature gate so this should always fail
        // with InvalidInput regardless of how the crate is compiled.
        let result = d.dispatch_stiffness_assembly(&bad);
        assert!(matches!(result, Err(GpuError::InvalidInput(_))));
    }

    #[test]
    fn invalid_mass_scale_reported_with_index() {
        let d = StressAssemblyDispatcher::new();
        let mut elements = sample_elements();
        elements[1].mass_scale = f64::INFINITY;
        // The mass path shares validation with the stiffness path, so the
        // offending index and field must be reported in every configuration.
        let result = d.dispatch_mass_assembly(&elements);
        assert!(matches!(
            &result,
            Err(GpuError::InvalidInput(msg)) if msg.as_str() == "element 1: mass_scale is not finite"
        ));
    }

    #[test]
    fn element_dof_layout_matches_dofs() {
        let d = StressAssemblyDispatcher::new();
        let elements = sample_elements();
        let result = d.element_dof_layout(&elements);
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(result, Err(GpuError::BackendUnavailable)));
        #[cfg(feature = "gpu")]
        {
            let dofs = result.expect("dispatch should succeed under gpu feature");
            assert_eq!(dofs, vec![2, 6, 8]);
        }
    }
}