use crate::error::IntegrateError;
use crate::gpu_fem::stiffness::{
assemble_stiffness_cpu, Element2D, MeshElement2D, StiffnessMatrix,
};
use scirs2_core::ndarray::Array2;
/// Settings that steer the automatic CPU/GPU dispatch of FEM stiffness assembly.
///
/// The type is comparable (`PartialEq`/`Eq`) so configurations can be checked
/// directly in tests and by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FemAssemblyConfig {
    /// Minimum element count at which `assemble_stiffness_auto` attempts the
    /// GPU backend (the comparison is `elements.len() >= gpu_threshold`).
    pub gpu_threshold: usize,
    /// Master switch for the GPU path: when `false` the GPU backend is never
    /// attempted, whatever the element count.
    pub use_gpu: bool,
}

impl Default for FemAssemblyConfig {
    /// Returns a configuration with the GPU path enabled for inputs of at
    /// least 10 000 elements.
    fn default() -> Self {
        Self {
            gpu_threshold: 10_000,
            use_gpu: true,
        }
    }
}
/// Error conditions associated with the GPU assembly path.
///
/// NOTE(review): nothing in this file constructs these variants; presumably
/// the GPU backend does — confirm against its implementation.
#[derive(Debug, Clone)]
pub enum GpuFemError {
    /// Rendered as `GPU not available`.
    GpuNotAvailable,
    /// Rendered as `GPU assembly failed: <message>`, where the payload is the
    /// message text.
    AssemblyFailed(String),
}

impl std::fmt::Display for GpuFemError {
    /// Writes the human-readable message for the error variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Constant text needs no formatting machinery.
            Self::GpuNotAvailable => f.write_str("GPU not available"),
            Self::AssemblyFailed(reason) => write!(f, "GPU assembly failed: {reason}"),
        }
    }
}

impl std::error::Error for GpuFemError {}
/// Assembles the global stiffness matrix, picking the GPU or CPU backend
/// automatically.
///
/// The GPU backend is attempted only when all of the following hold:
/// the crate is compiled with the `gpu_fem` feature, `config.use_gpu` is
/// `true`, and `elements.len() >= config.gpu_threshold`. In every other case,
/// including a failed GPU attempt, the work is delegated to
/// [`assemble_stiffness_cpu`].
///
/// # Arguments
///
/// * `elements` - elements to assemble; forwarded unchanged to the backend.
/// * `d_matrix` - material matrix; forwarded unchanged to the backend.
/// * `n_nodes` - node count; forwarded unchanged to the backend.
/// * `config` - dispatch settings (GPU switch and element-count threshold).
///
/// # Errors
///
/// Returns the error reported by [`assemble_stiffness_cpu`]. Errors from the
/// GPU backend are never surfaced: they trigger the CPU fallback instead.
pub fn assemble_stiffness_auto(
    elements: &[Element2D],
    d_matrix: &Array2<f64>,
    n_nodes: usize,
    config: &FemAssemblyConfig,
) -> Result<StiffnessMatrix, IntegrateError> {
    #[cfg(feature = "gpu_fem")]
    {
        let gpu_eligible = config.use_gpu && elements.len() >= config.gpu_threshold;
        if gpu_eligible {
            // Best-effort: a GPU failure is deliberately not propagated; we
            // simply fall through to the CPU assembler below.
            if let Ok(matrix) = crate::gpu_fem::wgpu_backend::assemble_stiffness_gpu(
                elements, d_matrix, n_nodes,
            ) {
                return Ok(matrix);
            }
        }
    }

    // Without the GPU backend compiled in, `config` has no other reader;
    // consume it explicitly so the build stays free of unused-variable warnings.
    #[cfg(not(feature = "gpu_fem"))]
    let _ = config;

    assemble_stiffness_cpu(elements, d_matrix, n_nodes)
}
/// Assembles the global stiffness matrix for mesh elements.
///
/// Despite the `_auto` suffix, no GPU dispatch is performed here: the call is
/// always delegated to `crate::gpu_fem::stiffness::assemble_stiffness_mesh`.
/// The `config` parameter is accepted so the signature mirrors
/// [`assemble_stiffness_auto`], but it currently does not influence the result.
///
/// # Arguments
///
/// * `mesh_elements` - mesh elements to assemble; forwarded unchanged.
/// * `d_matrix` - material matrix; forwarded unchanged.
/// * `n_nodes` - node count; forwarded unchanged.
/// * `config` - dispatch settings; currently ignored.
///
/// # Errors
///
/// Returns the error reported by the underlying mesh assembler.
pub fn assemble_stiffness_mesh_auto(
    mesh_elements: &[MeshElement2D],
    d_matrix: &Array2<f64>,
    n_nodes: usize,
    config: &FemAssemblyConfig,
) -> Result<StiffnessMatrix, IntegrateError> {
    // The former feature-gated branch was a no-op placeholder. Consuming
    // `config` unconditionally keeps the same behaviour, removes the
    // misleading dead branch, and avoids an unused-variable warning when the
    // `gpu_fem` feature is disabled.
    let _ = config;

    crate::gpu_fem::stiffness::assemble_stiffness_mesh(mesh_elements, d_matrix, n_nodes)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds the 3x3 isotropic material matrix used by every test
    /// (unit modulus, ratio 0.3, common factor `e / (1 - nu^2)`).
    fn isotropic_d() -> Array2<f64> {
        let (young, poisson) = (1.0_f64, 0.3_f64);
        let scale = young / (1.0 - poisson * poisson);
        let coupling = scale * poisson;
        let shear = scale * (1.0 - poisson) / 2.0;
        array![
            [scale, coupling, 0.0],
            [coupling, scale, 0.0],
            [0.0, 0.0, shear],
        ]
    }

    /// Returns a mesh consisting of one right triangle with unit legs.
    fn single_triangle() -> Vec<Element2D> {
        let unit_triangle = Element2D {
            nodes: [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
            material_id: 0,
        };
        vec![unit_triangle]
    }

    /// Shorthand for building a dispatch configuration.
    fn dispatch_config(gpu_threshold: usize, use_gpu: bool) -> FemAssemblyConfig {
        FemAssemblyConfig {
            gpu_threshold,
            use_gpu,
        }
    }

    /// One element is far below the 10 000 threshold, so the CPU path runs.
    #[test]
    fn test_auto_dispatch_routes_to_cpu_below_threshold() {
        let material = isotropic_d();
        let mesh = single_triangle();
        let cfg = dispatch_config(10_000, true);

        let stiffness =
            assemble_stiffness_auto(&mesh, &material, 3, &cfg).expect("auto dispatch failed");

        assert!(!stiffness.vals.is_empty());
    }

    /// With `use_gpu` off, even a zero threshold must still yield a matrix.
    #[test]
    fn test_auto_dispatch_use_gpu_false() {
        let material = isotropic_d();
        let mesh = single_triangle();
        let cfg = dispatch_config(0, false);

        let stiffness =
            assemble_stiffness_auto(&mesh, &material, 3, &cfg).expect("auto dispatch failed");

        assert!(!stiffness.vals.is_empty());
    }

    /// Ten elements exceed the threshold of five, yet `use_gpu` is off.
    #[test]
    fn test_auto_dispatch_large_mesh_cpu_fallback() {
        let material = isotropic_d();
        let cfg = dispatch_config(5, false);
        let n_elems = 10_usize;

        // A strip of triangles, each shifted one unit along x.
        let mesh: Vec<Element2D> = (0..n_elems)
            .map(|k| {
                let x0 = k as f64;
                Element2D {
                    nodes: [[x0, 0.0], [x0 + 1.0, 0.0], [x0 + 0.5, 1.0]],
                    material_id: 0,
                }
            })
            .collect();

        let stiffness = assemble_stiffness_auto(&mesh, &material, 3 * n_elems, &cfg)
            .expect("auto dispatch failed");

        assert!(!stiffness.vals.is_empty());
    }

    /// Two elements assembled over six nodes must fill two separate 6x6
    /// diagonal blocks of the dense matrix and leave the coupling block empty.
    #[test]
    fn test_cpu_two_element_assembly_distinct_dofs() {
        let material = isotropic_d();
        let cfg = dispatch_config(10_000, false);
        let lower_left = Element2D {
            nodes: [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
            material_id: 0,
        };
        let upper_right = Element2D {
            nodes: [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
            material_id: 0,
        };
        let mesh = vec![lower_left, upper_right];

        let stiffness = assemble_stiffness_auto(&mesh, &material, 6, &cfg)
            .expect("two-element assembly failed");
        let dense = stiffness.to_dense();
        let view = dense.view();

        // Sum of absolute entries over a rectangular index block, visiting
        // rows in the outer loop and columns in the inner loop.
        let abs_block_sum = |rows: std::ops::Range<usize>, cols: std::ops::Range<usize>| -> f64 {
            rows.flat_map(|i| cols.clone().map(move |j| (i, j)))
                .map(|(i, j)| view[[i, j]].abs())
                .sum()
        };

        let first_block = abs_block_sum(0..6, 0..6);
        let second_block = abs_block_sum(6..12, 6..12);
        assert!(first_block > 0.0, "Element 0 DOF block is zero");
        assert!(
            second_block > 0.0,
            "Element 1 DOF block is zero — DOFs were not distinct"
        );

        let coupling_block = abs_block_sum(0..6, 6..12);
        assert!(
            coupling_block < 1e-12,
            "Cross-block is non-zero — shared DOFs detected unexpectedly"
        );
    }
}