use super::{GpuError, GpuResult};
/// Element families accepted by the assembly dispatcher.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FemElementKind {
    /// One-dimensional bar element.
    Bar1D,
    /// One-dimensional beam element.
    Beam1D,
    /// Two-dimensional triangular element.
    Triangle2D,
    /// Two-dimensional quadrilateral element.
    Quad2D,
}

impl FemElementKind {
    /// Degrees of freedom carried by one element of this kind:
    /// 2 for `Bar1D`, 4 for `Beam1D`, 6 for `Triangle2D`, 8 for `Quad2D`.
    #[inline]
    pub fn dofs(self) -> usize {
        // Factored as (node count, DOFs per node).
        // NOTE(review): this split follows the usual convention for these
        // element names; the visible callers only use the product.
        let (nodes, dofs_per_node) = match self {
            FemElementKind::Bar1D => (2, 1),
            FemElementKind::Beam1D => (2, 2),
            FemElementKind::Triangle2D => (3, 2),
            FemElementKind::Quad2D => (4, 2),
        };
        nodes * dofs_per_node
    }
}
/// Input description of a single element submitted for assembly.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuElementDescriptor {
    /// Element family; determines the DOF count via [`FemElementKind::dofs`].
    pub kind: FemElementKind,
    /// Multiplier applied to the element's stiffness contribution.
    /// Must be finite to pass validation.
    pub stiffness_scale: f64,
    /// Multiplier applied to the element's mass contribution.
    /// Must be finite to pass validation.
    pub mass_scale: f64,
}

impl Default for GpuElementDescriptor {
    /// A `Bar1D` element with both scale factors set to one.
    fn default() -> Self {
        let unit_scale = 1.0;
        Self {
            kind: FemElementKind::Bar1D,
            stiffness_scale: unit_scale,
            mass_scale: unit_scale,
        }
    }
}
/// Per-element output of an assembly dispatch.
#[derive(Clone, Debug, PartialEq)]
pub struct GpuElementContribution {
    /// Position of the source element in the slice passed to the dispatcher.
    pub element_index: usize,
    /// The element's `stiffness_scale` multiplied by its DOF count.
    pub stiffness_trace: f64,
    /// The element's `mass_scale` multiplied by its DOF count.
    pub mass_trace: f64,
    /// Degrees of freedom of the element's kind.
    pub dofs: usize,
}
#[derive(Debug, Default)]
pub struct StressAssemblyDispatcher {
backend_ready: bool,
}
impl StressAssemblyDispatcher {
    /// Builds a dispatcher and records whether the GPU backend is usable.
    ///
    /// `super::backend_available()` is called once here. Later calls consult
    /// the cached answer rather than probing again.
    pub fn new() -> Self {
        let backend_ready = super::backend_available();
        Self { backend_ready }
    }

    /// Reports the backend availability captured at construction time.
    pub fn is_available(&self) -> bool {
        self.backend_ready
    }

    /// Produces one [`GpuElementContribution`] per input element, in input order.
    ///
    /// Each trace is the element's scale factor multiplied by its DOF count.
    /// `element_index` is the element's position in `elements`.
    ///
    /// NOTE(review): in this file the contributions are computed by a
    /// host-side loop; confirm whether a device kernel is meant to replace it.
    ///
    /// # Errors
    ///
    /// - [`GpuError::InvalidInput`] if any element has a non-finite
    ///   `stiffness_scale` or `mass_scale`. Validation runs before the backend
    ///   check, so bad input is reported even when no backend is present.
    /// - [`GpuError::BackendUnavailable`] if the crate was built without the
    ///   `gpu` feature, or the backend was unavailable at construction.
    pub fn dispatch_stiffness_assembly(
        &self,
        elements: &[GpuElementDescriptor],
    ) -> GpuResult<Vec<GpuElementContribution>> {
        validate_elements(elements)?;
        #[cfg(feature = "gpu")]
        {
            if !self.backend_ready {
                return Err(GpuError::BackendUnavailable);
            }
            let mut contributions = Vec::with_capacity(elements.len());
            for (element_index, descriptor) in elements.iter().enumerate() {
                let dofs = descriptor.kind.dofs();
                let weight = dofs as f64;
                contributions.push(GpuElementContribution {
                    element_index,
                    stiffness_trace: descriptor.stiffness_scale * weight,
                    mass_trace: descriptor.mass_scale * weight,
                    dofs,
                });
            }
            Ok(contributions)
        }
        #[cfg(not(feature = "gpu"))]
        {
            Err(GpuError::BackendUnavailable)
        }
    }

    /// Returns the DOF count of each element, in input order.
    ///
    /// Runs a full [`Self::dispatch_stiffness_assembly`] and keeps only the
    /// `dofs` field, so it fails under exactly the same conditions.
    pub fn element_dof_layout(&self, elements: &[GpuElementDescriptor]) -> GpuResult<Vec<usize>> {
        self.dispatch_stiffness_assembly(elements)
            .map(|contributions| contributions.iter().map(|c| c.dofs).collect())
    }

    /// Returns each element's mass trace, in input order.
    ///
    /// Runs a full [`Self::dispatch_stiffness_assembly`] and keeps only the
    /// `mass_trace` field, so it fails under exactly the same conditions.
    pub fn dispatch_mass_assembly(&self, elements: &[GpuElementDescriptor]) -> GpuResult<Vec<f64>> {
        self.dispatch_stiffness_assembly(elements)
            .map(|contributions| contributions.iter().map(|c| c.mass_trace).collect())
    }
}
/// Rejects descriptors whose scale factors are NaN or infinite.
///
/// Elements are checked in order. Within one element, `stiffness_scale` is
/// checked before `mass_scale`. The first offending field produces a
/// [`GpuError::InvalidInput`] whose message names the element's index.
fn validate_elements(elements: &[GpuElementDescriptor]) -> GpuResult<()> {
    let first_problem = elements.iter().enumerate().find_map(|(idx, e)| {
        if !e.stiffness_scale.is_finite() {
            Some(format!(
                "element {idx}: stiffness_scale is not finite"
            ))
        } else if !e.mass_scale.is_finite() {
            Some(format!(
                "element {idx}: mass_scale is not finite"
            ))
        } else {
            None
        }
    });
    match first_problem {
        Some(message) => Err(GpuError::InvalidInput(message)),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three descriptors (Bar1D, Triangle2D, Quad2D) with distinct scales,
    /// shared by the dispatch tests below.
    fn sample_elements() -> Vec<GpuElementDescriptor> {
        vec![
            GpuElementDescriptor {
                kind: FemElementKind::Bar1D,
                stiffness_scale: 1.0e6,
                mass_scale: 2.0,
            },
            GpuElementDescriptor {
                kind: FemElementKind::Triangle2D,
                stiffness_scale: 5.0e5,
                mass_scale: 1.0,
            },
            GpuElementDescriptor {
                kind: FemElementKind::Quad2D,
                stiffness_scale: 7.0e5,
                mass_scale: 0.5,
            },
        ]
    }

    /// Asserts `actual` matches `expected` to within a tight relative tolerance.
    #[cfg(feature = "gpu")]
    fn assert_close(actual: f64, expected: f64) {
        let tol = 1e-12 * expected.abs().max(1.0);
        assert!(
            (actual - expected).abs() <= tol,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn dofs_match_element_kind() {
        assert_eq!(FemElementKind::Bar1D.dofs(), 2);
        assert_eq!(FemElementKind::Beam1D.dofs(), 4);
        assert_eq!(FemElementKind::Triangle2D.dofs(), 6);
        assert_eq!(FemElementKind::Quad2D.dofs(), 8);
    }

    #[test]
    fn default_descriptor_is_unit_bar() {
        assert_eq!(
            GpuElementDescriptor::default(),
            GpuElementDescriptor {
                kind: FemElementKind::Bar1D,
                stiffness_scale: 1.0,
                mass_scale: 1.0,
            }
        );
    }

    #[test]
    fn dispatcher_availability_matches_feature() {
        let d = StressAssemblyDispatcher::new();
        #[cfg(feature = "gpu")]
        assert!(d.is_available());
        #[cfg(not(feature = "gpu"))]
        assert!(!d.is_available());
    }

    /// Without `gpu` the dispatch is refused. With it, every element yields a
    /// contribution whose traces equal `scale * dofs`.
    #[test]
    fn stiffness_assembly_matches_feature() {
        let d = StressAssemblyDispatcher::new();
        let elements = sample_elements();
        let result = d.dispatch_stiffness_assembly(&elements);
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(result, Err(GpuError::BackendUnavailable)));
        #[cfg(feature = "gpu")]
        {
            let contribs = result.expect("dispatch should succeed under gpu feature");
            // (dofs, stiffness_scale * dofs, mass_scale * dofs) per sample element.
            let expected: [(usize, f64, f64); 3] =
                [(2, 2.0e6, 4.0), (6, 3.0e6, 6.0), (8, 5.6e6, 4.0)];
            assert_eq!(contribs.len(), expected.len());
            for (idx, (c, &(dofs, stiffness, mass))) in
                contribs.iter().zip(&expected).enumerate()
            {
                assert_eq!(c.element_index, idx);
                assert_eq!(c.dofs, dofs);
                assert_close(c.stiffness_trace, stiffness);
                assert_close(c.mass_trace, mass);
            }
        }
    }

    #[test]
    fn mass_assembly_matches_feature() {
        let d = StressAssemblyDispatcher::new();
        let elements = sample_elements();
        let result = d.dispatch_mass_assembly(&elements);
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(result, Err(GpuError::BackendUnavailable)));
        #[cfg(feature = "gpu")]
        {
            let traces = result.expect("dispatch should succeed under gpu feature");
            let expected = [4.0, 6.0, 4.0];
            assert_eq!(traces.len(), expected.len());
            for (&actual, &want) in traces.iter().zip(&expected) {
                assert_close(actual, want);
            }
        }
    }

    #[test]
    fn empty_input_only_depends_on_backend() {
        let d = StressAssemblyDispatcher::new();
        let result = d.dispatch_stiffness_assembly(&[]);
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(result, Err(GpuError::BackendUnavailable)));
        #[cfg(feature = "gpu")]
        assert!(result
            .expect("empty dispatch should succeed under gpu feature")
            .is_empty());
    }

    #[test]
    fn invalid_input_caught_eagerly() {
        let d = StressAssemblyDispatcher::new();
        let bad = vec![GpuElementDescriptor {
            kind: FemElementKind::Bar1D,
            stiffness_scale: f64::NAN,
            mass_scale: 1.0,
        }];
        let result = d.dispatch_stiffness_assembly(&bad);
        assert!(matches!(result, Err(GpuError::InvalidInput(_))));
    }

    /// The `mass_scale` branch of validation is reported with the offending index.
    #[test]
    fn non_finite_mass_scale_reports_element_index() {
        let d = StressAssemblyDispatcher::new();
        let mut elements = sample_elements();
        elements[1].mass_scale = f64::INFINITY;
        let result = d.dispatch_stiffness_assembly(&elements);
        assert!(matches!(
            result,
            Err(GpuError::InvalidInput(ref msg)) if msg == "element 1: mass_scale is not finite"
        ));
    }

    #[test]
    fn element_dof_layout_matches_dofs() {
        let d = StressAssemblyDispatcher::new();
        let elements = sample_elements();
        let result = d.element_dof_layout(&elements);
        #[cfg(not(feature = "gpu"))]
        assert!(matches!(result, Err(GpuError::BackendUnavailable)));
        #[cfg(feature = "gpu")]
        {
            let dofs = result.expect("dispatch should succeed under gpu feature");
            assert_eq!(dofs, vec![2, 6, 8]);
        }
    }
}