use crate::assembly::{HelmholtzMatrix, HelmholtzProblem};
use crate::mesh::Mesh;
use num_complex::Complex64;
use std::collections::HashSet;
/// A Dirichlet boundary condition: prescribes the solution value on every node
/// of every mesh boundary whose marker equals `tag`.
///
/// The prescribed value at a node is `value_fn(x, y, z)` evaluated at that
/// node's coordinates.
#[derive(Clone)]
pub struct DirichletBC {
    /// Boundary marker selecting which mesh boundaries this condition applies to.
    pub tag: usize,
    /// Prescribed value as a function of node coordinates. Held in an `Rc` so
    /// that cloning a condition shares the function instead of discarding it.
    value_fn: std::rc::Rc<dyn Fn(f64, f64, f64) -> Complex64>,
}

impl std::fmt::Debug for DirichletBC {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The closure is not printable; only the tag is shown.
        f.debug_struct("DirichletBC")
            .field("tag", &self.tag)
            .finish()
    }
}

impl DirichletBC {
    /// Creates a condition on boundaries marked `tag` with prescribed values
    /// given by `value_fn(x, y, z)`.
    pub fn new<F>(tag: usize, value_fn: F) -> Self
    where
        F: Fn(f64, f64, f64) -> Complex64 + 'static,
    {
        Self {
            tag,
            value_fn: std::rc::Rc::new(value_fn),
        }
    }

    /// Evaluates the prescribed value at the point `(x, y, z)`.
    pub fn value(&self, x: f64, y: f64, z: f64) -> Complex64 {
        (self.value_fn)(x, y, z)
    }

    /// Returns the set of node indices lying on any boundary of `mesh` whose
    /// marker equals this condition's `tag`.
    pub fn boundary_nodes(&self, mesh: &Mesh) -> HashSet<usize> {
        // Markers are compared as i32. A tag above i32::MAX has no i32
        // counterpart; a plain `as` cast would wrap it onto an unrelated
        // (negative) marker, so such a tag matches nothing.
        if self.tag > i32::MAX as usize {
            return HashSet::new();
        }
        let marker = self.tag as i32;
        mesh.boundaries
            .iter()
            .filter(|boundary| boundary.marker == marker)
            .flat_map(|boundary| &boundary.nodes)
            .copied()
            .collect()
    }
}
/// Imposes Dirichlet boundary conditions on an assembled Helmholtz system by
/// symmetric elimination.
///
/// For every node `d` lying on a boundary matched by one of `dirichlet_bcs`,
/// with prescribed value `g_d`:
/// * in every unconstrained row `i`, each stored entry `A[i][d]` is removed
///   and `A[i][d] * g_d` is subtracted from `rhs[i]` (every stored COO entry
///   contributes separately, so duplicate entries act as summed);
/// * row `d` is replaced by a single unit diagonal entry and `rhs[d]` is set
///   to `g_d`.
///
/// If a node belongs to several conditions, the first condition in
/// `dirichlet_bcs` containing it determines its value. Appended diagonal
/// entries are emitted in ascending node order, so the output is deterministic.
///
/// # Panics
/// Panics if a constrained node index is out of range for `mesh.nodes` or
/// `problem.rhs`, or if a matrix row index is out of range for the matrix
/// dimension.
pub fn apply_dirichlet(problem: &mut HelmholtzProblem, mesh: &Mesh, dirichlet_bcs: &[DirichletBC]) {
    use std::collections::HashMap;

    let dirichlet_nodes = collect_dirichlet_values(mesh, dirichlet_bcs);
    if dirichlet_nodes.is_empty() {
        // Nothing is constrained: the system is already in its final form.
        return;
    }
    // O(1) lookup of a node's prescribed value (avoids a linear scan of all
    // constrained nodes for every matrix entry).
    let prescribed: HashMap<usize, Complex64> = dirichlet_nodes.iter().copied().collect();

    let matrix = &problem.matrix;
    let n = matrix.dim;
    let nnz = matrix.nnz();
    let one = Complex64::new(1.0, 0.0);

    let mut rhs_correction = vec![Complex64::new(0.0, 0.0); n];
    let capacity = nnz + dirichlet_nodes.len();
    let mut new_rows = Vec::with_capacity(capacity);
    let mut new_cols = Vec::with_capacity(capacity);
    let mut new_values = Vec::with_capacity(capacity);
    let mut diagonal_written: HashSet<usize> = HashSet::with_capacity(dirichlet_nodes.len());

    for k in 0..nnz {
        let row = matrix.rows[k];
        let col = matrix.cols[k];
        if prescribed.contains_key(&row) {
            // Constrained row: keep exactly one unit diagonal, drop the rest.
            if row == col && diagonal_written.insert(row) {
                new_rows.push(row);
                new_cols.push(col);
                new_values.push(one);
            }
        } else if let Some(&g) = prescribed.get(&col) {
            // Unconstrained row, constrained column: move the known
            // contribution A[row][col] * g to the right-hand side.
            rhs_correction[row] += matrix.values[k] * g;
        } else {
            new_rows.push(row);
            new_cols.push(col);
            new_values.push(matrix.values[k]);
        }
    }

    // Constrained rows without a stored diagonal still need their unit entry.
    for &(node, _) in &dirichlet_nodes {
        if diagonal_written.insert(node) {
            new_rows.push(node);
            new_cols.push(node);
            new_values.push(one);
        }
    }
    let wavenumber = matrix.wavenumber;

    for (rhs_i, corr) in problem.rhs.iter_mut().zip(rhs_correction) {
        *rhs_i -= corr;
    }
    for &(node, value) in &dirichlet_nodes {
        problem.rhs[node] = value;
    }

    problem.matrix = HelmholtzMatrix {
        rows: new_rows,
        cols: new_cols,
        values: new_values,
        dim: n,
        wavenumber,
    };
}

/// Maps every node constrained by `dirichlet_bcs` to its prescribed value,
/// sorted by node index. When a node is shared by several conditions, the
/// first condition in `dirichlet_bcs` wins.
fn collect_dirichlet_values(mesh: &Mesh, dirichlet_bcs: &[DirichletBC]) -> Vec<(usize, Complex64)> {
    use std::collections::BTreeMap;

    let mut values: BTreeMap<usize, Complex64> = BTreeMap::new();
    for bc in dirichlet_bcs {
        for node in bc.boundary_nodes(mesh) {
            values.entry(node).or_insert_with(|| {
                let point = &mesh.nodes[node];
                bc.value(point.x, point.y, point.z)
            });
        }
    }
    values.into_iter().collect()
}
/// Imposes a zero Dirichlet condition on every boundary whose marker is listed
/// in `tags`, delegating the elimination to `apply_dirichlet`.
pub fn apply_homogeneous_dirichlet(problem: &mut HelmholtzProblem, mesh: &Mesh, tags: &[usize]) {
    let mut zero_conditions = Vec::with_capacity(tags.len());
    for &marker in tags {
        zero_conditions.push(DirichletBC::new(marker, |_, _, _| Complex64::new(0.0, 0.0)));
    }
    apply_dirichlet(problem, mesh, &zero_conditions);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assembly::HelmholtzProblem;
    use crate::basis::PolynomialDegree;
    use crate::mesh::unit_square_triangles;

    /// Assembles the P1 Helmholtz system with wavenumber 1 and a unit source.
    fn assemble_unit_source(mesh: &Mesh) -> HelmholtzProblem {
        let k = Complex64::new(1.0, 0.0);
        HelmholtzProblem::assemble(mesh, PolynomialDegree::P1, k, |_, _, _| {
            Complex64::new(1.0, 0.0)
        })
    }

    /// Union of the boundary nodes carrying any of `tags`.
    fn nodes_with_tags(mesh: &Mesh, tags: &[usize]) -> HashSet<usize> {
        let mut nodes = HashSet::new();
        for &tag in tags {
            let bc = DirichletBC::new(tag, |_, _, _| Complex64::new(0.0, 0.0));
            nodes.extend(bc.boundary_nodes(mesh));
        }
        nodes
    }

    #[test]
    fn test_dirichlet_bc_creation() {
        let bc = DirichletBC::new(1, |x, y, _z| Complex64::new(x + y, 0.0));
        assert_eq!(bc.tag, 1);
        assert_eq!(bc.value(1.0, 2.0, 0.0), Complex64::new(3.0, 0.0));
    }

    #[test]
    fn test_apply_homogeneous_dirichlet() {
        let mesh = unit_square_triangles(4);
        let mut problem = assemble_unit_source(&mesh);
        apply_homogeneous_dirichlet(&mut problem, &mesh, &[1, 2]);
        for node in nodes_with_tags(&mesh, &[1, 2]) {
            assert!(
                problem.rhs[node].norm() < 1e-10,
                "Node {} RHS should be 0",
                node
            );
        }
    }

    #[test]
    fn test_constrained_rows_become_identity() {
        let mesh = unit_square_triangles(4);
        let mut problem = assemble_unit_source(&mesh);
        apply_homogeneous_dirichlet(&mut problem, &mesh, &[1, 2]);
        let constrained = nodes_with_tags(&mesh, &[1, 2]);

        let matrix = &problem.matrix;
        let mut entries_per_row = std::collections::HashMap::new();
        for ((&row, &col), &value) in matrix.rows.iter().zip(&matrix.cols).zip(&matrix.values) {
            if constrained.contains(&row) {
                assert_eq!(col, row, "constrained row {} has an off-diagonal entry", row);
                assert_eq!(value, Complex64::new(1.0, 0.0));
                *entries_per_row.entry(row).or_insert(0usize) += 1;
            } else {
                assert!(
                    !constrained.contains(&col),
                    "row {} still couples to constrained node {}",
                    row,
                    col
                );
            }
        }
        for node in &constrained {
            assert_eq!(
                entries_per_row.get(node),
                Some(&1),
                "constrained row {} must hold exactly one entry",
                node
            );
        }
    }

    #[test]
    fn test_apply_dirichlet_prescribes_values() {
        let mesh = unit_square_triangles(4);
        let mut problem = assemble_unit_source(&mesh);
        let bc = DirichletBC::new(1, |x, y, _z| Complex64::new(x + y, 0.0));
        apply_dirichlet(&mut problem, &mesh, std::slice::from_ref(&bc));
        for node in bc.boundary_nodes(&mesh) {
            let point = &mesh.nodes[node];
            let expected = Complex64::new(point.x + point.y, 0.0);
            assert!(
                (problem.rhs[node] - expected).norm() < 1e-12,
                "Node {} RHS does not match the prescribed value",
                node
            );
        }
    }
}