use crate::allocators::DimAllocator;
use crate::assembly::buffers::{BasisFunctionBuffer, QuadratureBuffer};
use crate::assembly::local::{ElementConnectivityAssembler, ElementMatrixAssembler, QuadratureTable};
use crate::element::{ReferenceFiniteElement, VolumetricFiniteElement};
use crate::nalgebra::{DMatrixViewMut, DefaultAllocator, DimName, OPoint};
use crate::space::{ElementInSpace, FiniteElementConnectivity, VolumetricFiniteElementSpace};
use crate::util::clone_upper_to_lower;
use crate::Real;
use davenport::{define_thread_local_workspace, with_thread_local_workspace};
use itertools::izip;
use nalgebra::Scalar;
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
/// Newtype wrapper for a density value, used as the per-quadrature-point data consumed by
/// [`ElementMassAssembler`].
///
/// `#[repr(transparent)]` gives `Density<T>` the same layout as `T`. The slice helpers
/// [`Density::as_inner_slice`] and [`Density::from_inner_slice`] rely on this to reinterpret
/// slices without copying.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Serialize, Deserialize)]
#[repr(transparent)]
pub struct Density<T>(pub T);
impl<T> Density<T> {
    /// Reinterprets a slice of densities as a slice of the wrapped values, without copying.
    pub fn as_inner_slice<'a>(slice: &'a [Density<T>]) -> &'a [T] {
        // SAFETY: `Density<T>` is `#[repr(transparent)]` over `T`, so both types have identical
        // size and alignment and `slice.len()` densities occupy exactly `slice.len()` values of
        // `T`. The pointer comes from a valid shared slice, and the output borrows the input for
        // the same lifetime `'a`, so validity and aliasing are preserved. Unlike transmuting the
        // slice reference itself, this does not depend on the (unspecified) fat-pointer layout.
        unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<T>(), slice.len()) }
    }

    /// Reinterprets a slice of raw values as a slice of densities, without copying.
    pub fn from_inner_slice<'a>(slice: &'a [T]) -> &'a [Density<T>] {
        // SAFETY: same argument as in `as_inner_slice`, in the opposite direction: the
        // `#[repr(transparent)]` wrapper has exactly the layout of `T`, and the returned shared
        // borrow has the same lifetime and length as the input.
        unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<Density<T>>(), slice.len()) }
    }
}
impl<T: Real> Default for Density<T> {
    /// A default-constructed density is zero.
    fn default() -> Self {
        let zero = T::zero();
        Self(zero)
    }
}
impl<T: Display> Display for Density<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Density({})", self.0)
}
}
/// Assembles element mass matrices for the elements of a finite element space.
///
/// Construction is staged. [`ElementMassAssembler::with_solution_dim`] creates an assembler whose
/// space and quadrature table are `()` placeholders. These are then supplied with `with_space`
/// and `with_quadrature_table`. For element matrix assembly, the quadrature table must provide a
/// [`Density`] value at every quadrature point (see the `ElementMatrixAssembler` impl).
#[derive(Debug, Clone)]
pub struct ElementMassAssembler<'a, Space, QTable> {
/// The space providing elements and connectivity.
space: &'a Space,
/// Quadrature rule per element, with a density attached to each quadrature point.
qtable: &'a QTable,
/// Number of solution components per node; each pair of nodes contributes a
/// `solution_dim x solution_dim` block to the element matrix.
solution_dim: usize,
}
impl<'a> ElementMassAssembler<'a, (), ()> {
pub fn with_solution_dim(solution_dim: usize) -> Self {
Self {
space: &(),
qtable: &(),
solution_dim,
}
}
}
impl<'a, QTable> ElementMassAssembler<'a, (), QTable> {
    /// Attaches the finite element space, keeping the quadrature table and solution dimension.
    pub fn with_space<Space>(self, space: &'a Space) -> ElementMassAssembler<'a, Space, QTable> {
        let Self { qtable, solution_dim, .. } = self;
        ElementMassAssembler {
            space,
            qtable,
            solution_dim,
        }
    }
}
impl<'a, Space> ElementMassAssembler<'a, Space, ()> {
    /// Attaches the quadrature table, keeping the space and solution dimension.
    pub fn with_quadrature_table<QTable>(self, table: &'a QTable) -> ElementMassAssembler<'a, Space, QTable> {
        let Self { space, solution_dim, .. } = self;
        ElementMassAssembler {
            space,
            qtable: table,
            solution_dim,
        }
    }
}
// Thread-local scratch storage. `assemble_element_matrix_into` borrows a
// `MassAssemblerWorkspace` from it for every element it assembles.
define_thread_local_workspace!(WORKSPACE);
impl<'a, Space, QTable> ElementConnectivityAssembler for ElementMassAssembler<'a, Space, QTable>
where
    Space: FiniteElementConnectivity,
{
    // The assembler stores only the solution dimension itself; every other connectivity
    // query is answered by the underlying space.
    fn solution_dim(&self) -> usize {
        self.solution_dim
    }

    fn num_elements(&self) -> usize {
        Space::num_elements(self.space)
    }

    fn num_nodes(&self) -> usize {
        Space::num_nodes(self.space)
    }

    fn element_node_count(&self, element_index: usize) -> usize {
        Space::element_node_count(self.space, element_index)
    }

    fn populate_element_nodes(&self, output: &mut [usize], element_index: usize) {
        Space::populate_element_nodes(self.space, output, element_index)
    }
}
/// Scratch buffers used by [`ElementMassAssembler`] while assembling a single element matrix.
/// One instance is kept per thread (via `WORKSPACE`), so the buffers are reused across elements.
#[derive(Debug)]
struct MassAssemblerWorkspace<T: Scalar, D: DimName>
where
DefaultAllocator: DimAllocator<T, D>,
{
/// Quadrature weights, points and per-point `Density` data for the current element.
quadrature_buffer: QuadratureBuffer<T, D, Density<T>>,
/// Basis-function buffer for the current element; resized and populated per element.
basis_buffer: BasisFunctionBuffer<T>,
}
impl<T: Real, D: DimName> Default for MassAssemblerWorkspace<T, D>
where
DefaultAllocator: DimAllocator<T, D>,
{
fn default() -> Self {
Self {
quadrature_buffer: Default::default(),
basis_buffer: Default::default(),
}
}
}
impl<'a, T, Space, QTable> ElementMatrixAssembler<T> for ElementMassAssembler<'a, Space, QTable>
where
T: Real,
Space: VolumetricFiniteElementSpace<T>,
QTable: QuadratureTable<T, Space::GeometryDim, Data = Density<T>>,
DefaultAllocator: DimAllocator<T, Space::GeometryDim>,
{
fn assemble_element_matrix_into(&self, element_index: usize, output: DMatrixViewMut<T>) -> eyre::Result<()> {
with_thread_local_workspace(&WORKSPACE, |ws: &mut MassAssemblerWorkspace<T, Space::GeometryDim>| {
let element = ElementInSpace::from_space_and_element_index(self.space, element_index);
ws.basis_buffer
.resize(element.num_nodes(), Space::ReferenceDim::dim());
ws.basis_buffer
.populate_element_nodes_from_space(element_index, self.space);
ws.quadrature_buffer
.populate_element_quadrature_from_table(element_index, self.qtable);
assemble_element_mass_matrix(
output,
&element,
ws.quadrature_buffer.weights(),
ws.quadrature_buffer.points(),
Density::as_inner_slice(ws.quadrature_buffer.data()),
self.solution_dim,
ws.basis_buffer.element_basis_values_mut(),
)
})
}
}
/// Assembles the mass matrix of a single volumetric element into `output`.
///
/// `output` is zeroed first. Then, for every quadrature point `x_q` with weight `w_q` and
/// density `rho_q`, the scalar
/// `w_q * |det(element.reference_jacobian(x_q))| * rho_q * phi_I(x_q) * phi_J(x_q)`
/// is added to the diagonal of the `solution_dim x solution_dim` block coupling element nodes
/// `I` and `J`. Here `phi` holds the basis values written by `element.populate_basis`. Only the
/// upper block triangle is accumulated; it is then mirrored into the lower triangle with
/// `clone_upper_to_lower`.
///
/// # Arguments
///
/// * `output` - square matrix view with `solution_dim * n` rows and columns, where `n` is
///   `element.num_nodes()`. Its previous contents are overwritten.
/// * `quadrature_weights`, `quadrature_points`, `quadrature_density` - per-quadrature-point
///   data in reference coordinates. All three slices must have the same length.
/// * `solution_dim` - number of solution components per node.
/// * `basis_values_buffer` - scratch space of length `n`. It is overwritten with basis values
///   at every quadrature point.
///
/// # Errors
///
/// The current implementation always returns `Ok(())`.
///
/// # Panics
///
/// Panics if:
/// * the three quadrature slices differ in length,
/// * `basis_values_buffer.len() != element.num_nodes()`, or
/// * `output` is not `(solution_dim * n) x (solution_dim * n)`.
pub fn assemble_element_mass_matrix<'a, T, Element>(
    output: impl Into<DMatrixViewMut<'a, T>>,
    element: &Element,
    quadrature_weights: &[T],
    quadrature_points: &[OPoint<T, Element::ReferenceDim>],
    quadrature_density: &[T],
    solution_dim: usize,
    basis_values_buffer: &mut [T],
) -> eyre::Result<()>
where
    T: Real,
    Element: VolumetricFiniteElement<T>,
    DefaultAllocator: DimAllocator<T, Element::GeometryDim>,
{
    // Convert once here so that the implementation is monomorphized only over `T` and
    // `Element`, not over every possible `impl Into<DMatrixViewMut>` argument type.
    assemble_element_mass_matrix_(
        output.into(),
        element,
        quadrature_weights,
        quadrature_points,
        quadrature_density,
        solution_dim,
        basis_values_buffer,
    )
}
/// Implementation of [`assemble_element_mass_matrix`] on a concrete mutable matrix view.
///
/// The element mass matrix only couples equal solution components: the contribution between
/// nodes `I` and `J` lands on the diagonal of their `s x s` block. It is also symmetric in
/// `(I, J)`, so only blocks with `J >= I` are accumulated here. The lower triangle is then
/// filled by `clone_upper_to_lower`.
// Uppercase `I`/`J` (and `m_IJ`) follow the conventional notation for node indices in the
// finite element literature, hence the lint suppression.
#[allow(non_snake_case)]
fn assemble_element_mass_matrix_<T, Element>(
    mut output: DMatrixViewMut<T>,
    element: &Element,
    quadrature_weights: &[T],
    quadrature_points: &[OPoint<T, Element::ReferenceDim>],
    quadrature_density: &[T],
    solution_dim: usize,
    basis_values_buffer: &mut [T],
) -> eyre::Result<()>
where
    T: Real,
    Element: VolumetricFiniteElement<T>,
    DefaultAllocator: DimAllocator<T, Element::GeometryDim>,
{
    assert_eq!(quadrature_weights.len(), quadrature_points.len());
    assert_eq!(quadrature_points.len(), quadrature_density.len());
    assert_eq!(basis_values_buffer.len(), element.num_nodes());
    let s = solution_dim;
    let n = element.num_nodes();
    assert_eq!(output.nrows(), s * n, "Output matrix dimension mismatch");
    assert_eq!(output.ncols(), s * n, "Output matrix dimension mismatch");

    output.fill(T::zero());
    let phi = basis_values_buffer;
    for (&weight, point, &density) in izip!(quadrature_weights, quadrature_points, quadrature_density) {
        let j_det = element.reference_jacobian(point).determinant();
        // `point` is already a reference; pass it directly (no `&&OPoint` needless borrow).
        element.populate_basis(phi, point);
        let scale = weight * j_det.abs() * density;
        for I in 0..n {
            for J in I..n {
                let m_IJ = scale * phi[I] * phi[J];
                // Only the diagonal of block (I, J) is touched, so index it directly instead of
                // constructing an s x s sub-view for every node pair.
                for i in 0..s {
                    output[(s * I + i, s * J + i)] += m_IJ;
                }
            }
        }
    }

    clone_upper_to_lower(&mut output);
    Ok(())
}