use super::assembly::{CsrMatrix, GlobalSystem};
use super::mesh::Mesh;
use super::{FemError, FemResult};
use std::collections::HashMap;
/// Essential (Dirichlet) boundary condition: prescribes the value of one
/// degree of freedom at one mesh node.
#[derive(Clone, Debug)]
pub struct DirichletBc {
    /// Index of the node the condition is attached to.
    pub node_id: usize,
    /// Index of the degree of freedom within that node.
    pub local_dof: usize,
    /// Value the degree of freedom is fixed to.
    pub value: f64,
}

impl DirichletBc {
    /// Creates a condition fixing `local_dof` of node `node_id` to `value`.
    pub fn new(node_id: usize, local_dof: usize, value: f64) -> Self {
        DirichletBc {
            node_id,
            local_dof,
            value,
        }
    }

    /// Creates a condition on the first (index 0) degree of freedom of
    /// `node_id`, the natural choice for single-DOF problems.
    pub fn scalar(node_id: usize, value: f64) -> Self {
        DirichletBc {
            node_id,
            local_dof: 0,
            value,
        }
    }
}
/// Natural (Neumann) boundary condition expressed as a nodal load on one
/// degree of freedom of one mesh node.
#[derive(Clone, Debug)]
pub struct NeumannBc {
    /// Index of the node the load acts on.
    pub node_id: usize,
    /// Index of the degree of freedom within that node.
    pub local_dof: usize,
    /// Load magnitude contributed to that degree of freedom.
    pub value: f64,
}

impl NeumannBc {
    /// Creates a load of `value` on `local_dof` of node `node_id`.
    pub fn new(node_id: usize, local_dof: usize, value: f64) -> Self {
        NeumannBc {
            node_id,
            local_dof,
            value,
        }
    }

    /// Creates a load on the first (index 0) degree of freedom of `node_id`.
    pub fn scalar(node_id: usize, value: f64) -> Self {
        NeumannBc {
            node_id,
            local_dof: 0,
            value,
        }
    }
}
/// Builder-style collection of node-level boundary conditions.
///
/// Conditions are stored exactly as supplied; nothing is validated or
/// de-duplicated when they are added.
#[derive(Clone, Debug, Default)]
pub struct BoundaryConditionSet {
    /// Dirichlet conditions in insertion order.
    pub dirichlet: Vec<DirichletBc>,
    /// Neumann conditions in insertion order.
    pub neumann: Vec<NeumannBc>,
    /// Penalty factor, if one has been set.
    pub penalty_factor: Option<f64>,
}

impl BoundaryConditionSet {
    /// Creates an empty set with no penalty factor.
    pub fn new() -> Self {
        Self {
            dirichlet: Vec::new(),
            neumann: Vec::new(),
            penalty_factor: None,
        }
    }

    /// Appends a Dirichlet condition and returns `self` for chaining.
    pub fn add_dirichlet(&mut self, bc: DirichletBc) -> &mut Self {
        self.dirichlet.push(bc);
        self
    }

    /// Appends a Neumann condition and returns `self` for chaining.
    pub fn add_neumann(&mut self, bc: NeumannBc) -> &mut Self {
        self.neumann.push(bc);
        self
    }

    /// Stores `factor` as the penalty factor and returns `self` for chaining.
    pub fn set_penalty_factor(&mut self, factor: f64) -> &mut Self {
        self.penalty_factor = Some(factor);
        self
    }

    /// Appends one scalar (DOF 0) Dirichlet condition per `(node_id, value)`
    /// pair, preserving the slice order.
    pub fn add_scalar_dirichlet(&mut self, conditions: &[(usize, f64)]) -> &mut Self {
        let new_conditions = conditions
            .iter()
            .map(|&(node_id, value)| DirichletBc::scalar(node_id, value));
        self.dirichlet.extend(new_conditions);
        self
    }

    /// Appends a Neumann condition of magnitude `force` on DOF `dof` of
    /// node `node_id`.
    pub fn add_point_force(&mut self, node_id: usize, dof: usize, force: f64) -> &mut Self {
        self.add_neumann(NeumannBc::new(node_id, dof, force))
    }

    /// Number of stored Dirichlet conditions.
    pub fn num_dirichlet(&self) -> usize {
        self.dirichlet.len()
    }

    /// Number of stored Neumann conditions.
    pub fn num_neumann(&self) -> usize {
        self.neumann.len()
    }

    /// `true` when the set holds neither Dirichlet nor Neumann conditions.
    /// The penalty factor is not taken into account.
    pub fn is_empty(&self) -> bool {
        self.num_dirichlet() == 0 && self.num_neumann() == 0
    }
}
/// Boundary conditions resolved from node-level descriptions to global
/// degree-of-freedom indices.
#[derive(Clone, Debug)]
pub struct BoundaryCondition {
    /// Prescribed value for every constrained global DOF.
    pub constrained_dofs: HashMap<usize, f64>,
    /// Accumulated load for every loaded global DOF.
    pub neumann_loads: HashMap<usize, f64>,
    /// Penalty factor copied from the originating set, if any.
    pub penalty_factor: Option<f64>,
}

impl BoundaryCondition {
    /// Resolves every condition in `bc_set` to a global DOF index
    /// (`node_id * dofs_per_node + local_dof`).
    ///
    /// Several Dirichlet conditions on the same DOF keep the value of the
    /// last one; several Neumann conditions on the same DOF are summed.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::BoundaryError`] when a condition references a node
    /// index `>= mesh.num_nodes()`, a local DOF `>= dofs_per_node`
    /// (including the degenerate `dofs_per_node == 0`), or when the global
    /// DOF index does not fit in `usize`.
    pub fn from_set(
        bc_set: &BoundaryConditionSet,
        mesh: &Mesh,
        dofs_per_node: usize,
    ) -> FemResult<Self> {
        let num_nodes = mesh.num_nodes();
        let mut constrained_dofs: HashMap<usize, f64> =
            HashMap::with_capacity(bc_set.dirichlet.len());
        let mut neumann_loads: HashMap<usize, f64> =
            HashMap::with_capacity(bc_set.neumann.len());

        for bc in &bc_set.dirichlet {
            let global_dof = Self::resolve_global_dof(
                "Dirichlet",
                bc.node_id,
                bc.local_dof,
                num_nodes,
                dofs_per_node,
            )?;
            constrained_dofs.insert(global_dof, bc.value);
        }

        for bc in &bc_set.neumann {
            let global_dof = Self::resolve_global_dof(
                "Neumann",
                bc.node_id,
                bc.local_dof,
                num_nodes,
                dofs_per_node,
            )?;
            *neumann_loads.entry(global_dof).or_insert(0.0) += bc.value;
        }

        Ok(Self {
            constrained_dofs,
            neumann_loads,
            penalty_factor: bc_set.penalty_factor,
        })
    }

    /// Validates one `(node_id, local_dof)` pair and maps it to its global
    /// DOF index; `kind` ("Dirichlet" / "Neumann") only labels the error text.
    fn resolve_global_dof(
        kind: &str,
        node_id: usize,
        local_dof: usize,
        num_nodes: usize,
        dofs_per_node: usize,
    ) -> FemResult<usize> {
        if node_id >= num_nodes {
            return Err(FemError::BoundaryError(format!(
                "{} BC node {} out of range (mesh has {} nodes)",
                kind, node_id, num_nodes
            )));
        }
        if local_dof >= dofs_per_node {
            // `dofs_per_node - 1` would underflow for `dofs_per_node == 0`
            // (panic in debug builds, nonsense value in release builds).
            let bound = match dofs_per_node.checked_sub(1) {
                Some(max) => format!("max {}", max),
                None => "no DOFs per node".to_string(),
            };
            return Err(FemError::BoundaryError(format!(
                "{} BC local DOF {} out of range ({})",
                kind, local_dof, bound
            )));
        }
        node_id
            .checked_mul(dofs_per_node)
            .and_then(|base| base.checked_add(local_dof))
            .ok_or_else(|| {
                FemError::BoundaryError(format!(
                    "{} BC node {} DOF {} overflows the global DOF index",
                    kind, node_id, local_dof
                ))
            })
    }
}
pub fn apply_boundary_conditions(
system: &GlobalSystem,
bc: &BoundaryCondition,
) -> FemResult<(CsrMatrix, Vec<f64>)> {
let n = system.stiffness.n_rows;
let mut load = system.load.clone();
for (&dof, &force) in &bc.neumann_loads {
if dof < n {
load[dof] += force;
}
}
if let Some(penalty) = bc.penalty_factor {
apply_dirichlet_penalty(system, &bc.constrained_dofs, &mut load, penalty)
} else {
apply_dirichlet_elimination(system, &bc.constrained_dofs, &load)
}
}
fn apply_dirichlet_elimination(
system: &GlobalSystem,
constrained: &HashMap<usize, f64>,
load: &[f64],
) -> FemResult<(CsrMatrix, Vec<f64>)> {
let n = system.stiffness.n_rows;
let mut f = load.to_vec();
for (&constrained_dof, &prescribed_val) in constrained {
for j in 0..n {
if !constrained.contains_key(&j) {
let k_ji = system.stiffness.get(j, constrained_dof);
f[j] -= k_ji * prescribed_val;
}
}
}
let mut triplets: Vec<(usize, usize, f64)> = Vec::new();
for i in 0..n {
if constrained.contains_key(&i) {
triplets.push((i, i, 1.0));
f[i] = constrained.get(&i).copied().unwrap_or(0.0);
} else {
let start = system.stiffness.row_ptr[i];
let end = system.stiffness.row_ptr[i + 1];
for idx in start..end {
let j = system.stiffness.col_idx[idx];
if !constrained.contains_key(&j) {
triplets.push((i, j, system.stiffness.values[idx]));
}
}
}
}
let k_mod = CsrMatrix::from_triplets(n, n, &triplets);
Ok((k_mod, f))
}
fn apply_dirichlet_penalty(
system: &GlobalSystem,
constrained: &HashMap<usize, f64>,
load: &mut [f64],
penalty: f64,
) -> FemResult<(CsrMatrix, Vec<f64>)> {
let mut triplets: Vec<(usize, usize, f64)> = Vec::new();
for i in 0..system.stiffness.n_rows {
let start = system.stiffness.row_ptr[i];
let end = system.stiffness.row_ptr[i + 1];
for idx in start..end {
triplets.push((
i,
system.stiffness.col_idx[idx],
system.stiffness.values[idx],
));
}
}
for (&dof, &value) in constrained {
triplets.push((dof, dof, penalty));
load[dof] += penalty * value;
}
let k_mod =
CsrMatrix::from_triplets(system.stiffness.n_rows, system.stiffness.n_cols, &triplets);
Ok((k_mod, load.to_vec()))
}
/// No-op hook for natural boundary conditions.
///
/// The argument is ignored and `Ok(())` is always returned.
// NOTE(review): presumably a placeholder kept for API symmetry; confirm
// whether callers expect it to do any work.
pub fn apply_natural_bc(_system: &GlobalSystem) -> FemResult<()> {
Ok(())
}
/// Checks a boundary-condition set for consistency against a mesh.
///
/// Checks are performed in this order, stopping at the first failure:
/// 1. every Dirichlet condition references an existing node and a local DOF
///    below `dofs_per_node`;
/// 2. the same for every Neumann condition;
/// 3. no `(node, local DOF)` pair carries both a Dirichlet and a Neumann
///    condition;
/// 4. the penalty factor, when set, is a positive finite number.
///
/// # Errors
///
/// Returns [`FemError::BoundaryError`] describing the first violated rule.
/// `dofs_per_node == 0` is reported as an error for any condition instead of
/// underflowing while formatting the message.
pub fn validate_boundary_conditions(
    bc_set: &BoundaryConditionSet,
    mesh: &Mesh,
    dofs_per_node: usize,
) -> FemResult<()> {
    let num_nodes = mesh.num_nodes();

    // Shared range check; `kind` only labels the error text.
    let check_location = |kind: &str, node_id: usize, local_dof: usize| -> FemResult<()> {
        if node_id >= num_nodes {
            return Err(FemError::BoundaryError(format!(
                "{} BC references non-existent node {}",
                kind, node_id
            )));
        }
        if local_dof >= dofs_per_node {
            // `dofs_per_node - 1` would underflow for `dofs_per_node == 0`.
            let bound = match dofs_per_node.checked_sub(1) {
                Some(max) => format!("max {}", max),
                None => "no DOFs per node".to_string(),
            };
            return Err(FemError::BoundaryError(format!(
                "{} BC references invalid DOF {} ({})",
                kind, local_dof, bound
            )));
        }
        Ok(())
    };

    for bc in &bc_set.dirichlet {
        check_location("Dirichlet", bc.node_id, bc.local_dof)?;
    }
    for bc in &bc_set.neumann {
        check_location("Neumann", bc.node_id, bc.local_dof)?;
    }

    let dirichlet_dofs: std::collections::HashSet<(usize, usize)> = bc_set
        .dirichlet
        .iter()
        .map(|bc| (bc.node_id, bc.local_dof))
        .collect();
    let conflict = bc_set
        .neumann
        .iter()
        .find(|bc| dirichlet_dofs.contains(&(bc.node_id, bc.local_dof)));
    if let Some(bc) = conflict {
        return Err(FemError::BoundaryError(format!(
            "Conflicting BCs: node {} DOF {} has both Dirichlet and Neumann conditions",
            bc.node_id, bc.local_dof
        )));
    }

    if let Some(penalty) = bc_set.penalty_factor {
        // NaN fails every ordered comparison, so `penalty <= 0.0` alone
        // would let it through.
        if penalty.is_nan() || penalty <= 0.0 {
            return Err(FemError::BoundaryError(
                "Penalty factor must be positive".to_string(),
            ));
        }
        if penalty.is_infinite() {
            return Err(FemError::BoundaryError(
                "Penalty factor must be finite".to_string(),
            ));
        }
    }

    Ok(())
}
// Unit tests for boundary-condition construction, resolution to global DOFs,
// application to an assembled system, and validation.
#[cfg(test)]
mod tests {
use super::super::assembly::assemble_global_system;
use super::super::mesh::Mesh;
use super::*;
// `scalar` must target local DOF 0; `new` must store all three fields as given.
#[test]
fn test_dirichlet_bc_creation() {
let bc = DirichletBc::scalar(5, 100.0);
assert_eq!(bc.node_id, 5);
assert_eq!(bc.local_dof, 0);
assert!((bc.value - 100.0).abs() < 1e-12);
let bc2 = DirichletBc::new(3, 1, -2.5);
assert_eq!(bc2.node_id, 3);
assert_eq!(bc2.local_dof, 1);
assert!((bc2.value - (-2.5)).abs() < 1e-12);
}
// Same constructor checks for the Neumann variant.
#[test]
fn test_neumann_bc_creation() {
let bc = NeumannBc::new(3, 1, -50.0);
assert_eq!(bc.node_id, 3);
assert_eq!(bc.local_dof, 1);
assert!((bc.value - (-50.0)).abs() < 1e-12);
let bc2 = NeumannBc::scalar(0, 10.0);
assert_eq!(bc2.node_id, 0);
assert_eq!(bc2.local_dof, 0);
}
// Builder methods can be chained and the counters reflect what was added.
#[test]
fn test_bc_set_builder() {
let mut bc_set = BoundaryConditionSet::new();
assert!(bc_set.is_empty());
bc_set
.add_scalar_dirichlet(&[(0, 0.0), (10, 1.0)])
.add_point_force(5, 0, 100.0);
assert_eq!(bc_set.num_dirichlet(), 2);
assert_eq!(bc_set.num_neumann(), 1);
assert!(!bc_set.is_empty());
}
// With one DOF per node the global DOF index equals the node index.
// NOTE(review): relies on the generated mesh having a node with index 5,
// presumably 5 elements giving 6 nodes; confirm against `Mesh::generate_1d`.
#[test]
fn test_bc_processing() {
let mesh = Mesh::generate_1d(0.0, 1.0, 5).expect("mesh gen should succeed");
let mut bc_set = BoundaryConditionSet::new();
bc_set.add_scalar_dirichlet(&[(0, 0.0), (5, 1.0)]);
let bc =
BoundaryCondition::from_set(&bc_set, &mesh, 1).expect("BC processing should succeed");
assert_eq!(bc.constrained_dofs.len(), 2);
assert!((bc.constrained_dofs[&0] - 0.0).abs() < 1e-12);
assert!((bc.constrained_dofs[&5] - 1.0).abs() < 1e-12);
}
// A node index far beyond the mesh must be rejected by `from_set`.
#[test]
fn test_bc_processing_invalid_node() {
let mesh = Mesh::generate_1d(0.0, 1.0, 5).expect("mesh gen should succeed");
let mut bc_set = BoundaryConditionSet::new();
bc_set.add_scalar_dirichlet(&[(100, 0.0)]);
let result = BoundaryCondition::from_set(&bc_set, &mesh, 1);
assert!(result.is_err());
}
// Without a penalty factor the elimination path is used: constrained rows
// become identity rows and the load holds the prescribed values.
#[test]
fn test_apply_dirichlet_elimination() {
let mesh = Mesh::generate_1d(0.0, 1.0, 2).expect("mesh gen should succeed");
let system =
assemble_global_system(&mesh, 1.0, &|_| 0.0, 2, 1).expect("assembly should succeed");
let mut bc_set = BoundaryConditionSet::new();
bc_set.add_scalar_dirichlet(&[(0, 0.0), (2, 1.0)]);
let bc =
BoundaryCondition::from_set(&bc_set, &mesh, 1).expect("BC processing should succeed");
let (k_mod, f_mod) =
apply_boundary_conditions(&system, &bc).expect("BC application should succeed");
assert!((k_mod.get(0, 0) - 1.0).abs() < 1e-12);
assert!((k_mod.get(2, 2) - 1.0).abs() < 1e-12);
assert!(k_mod.get(0, 1).abs() < 1e-12);
assert!((f_mod[0] - 0.0).abs() < 1e-12);
assert!((f_mod[2] - 1.0).abs() < 1e-12);
}
// A point load on an unconstrained DOF must show up in the load vector.
// NOTE(review): the exact value 10.0 assumes the assembled load at DOF 4 is
// zero and that the constraint at DOF 0 does not couple into it; confirm.
#[test]
fn test_apply_neumann_bc() {
let mesh = Mesh::generate_1d(0.0, 1.0, 4).expect("mesh gen should succeed");
let system =
assemble_global_system(&mesh, 1.0, &|_| 0.0, 2, 1).expect("assembly should succeed");
let mut bc_set = BoundaryConditionSet::new();
bc_set
.add_dirichlet(DirichletBc::scalar(0, 0.0))
.add_neumann(NeumannBc::scalar(4, 10.0));
let bc =
BoundaryCondition::from_set(&bc_set, &mesh, 1).expect("BC processing should succeed");
let (_k_mod, f_mod) =
apply_boundary_conditions(&system, &bc).expect("BC application should succeed");
assert!((f_mod[4] - 10.0).abs() < 1e-10);
}
// With a penalty factor set, the constrained diagonal and the load entry of
// a non-zero constraint are dominated by the penalty.
#[test]
fn test_apply_penalty_method() {
let mesh = Mesh::generate_1d(0.0, 1.0, 2).expect("mesh gen should succeed");
let system =
assemble_global_system(&mesh, 1.0, &|_| 0.0, 2, 1).expect("assembly should succeed");
let mut bc_set = BoundaryConditionSet::new();
bc_set
.add_scalar_dirichlet(&[(0, 0.0), (2, 1.0)])
.set_penalty_factor(1e10);
let bc =
BoundaryCondition::from_set(&bc_set, &mesh, 1).expect("BC processing should succeed");
let (k_mod, f_mod) =
apply_boundary_conditions(&system, &bc).expect("BC application should succeed");
let k00 = k_mod.get(0, 0);
assert!(k00 > 1e9);
assert!(f_mod[2] > 1e9);
}
// Validation accepts a consistent set, rejects an out-of-range node, and
// rejects a DOF that carries both a Dirichlet and a Neumann condition.
#[test]
fn test_validate_boundary_conditions() {
let mesh = Mesh::generate_1d(0.0, 1.0, 5).expect("mesh gen should succeed");
let mut bc_set = BoundaryConditionSet::new();
bc_set.add_scalar_dirichlet(&[(0, 0.0), (5, 1.0)]);
assert!(validate_boundary_conditions(&bc_set, &mesh, 1).is_ok());
let mut bad_bc = BoundaryConditionSet::new();
bad_bc.add_scalar_dirichlet(&[(100, 0.0)]);
assert!(validate_boundary_conditions(&bad_bc, &mesh, 1).is_err());
let mut conflict = BoundaryConditionSet::new();
conflict
.add_dirichlet(DirichletBc::scalar(0, 0.0))
.add_neumann(NeumannBc::scalar(0, 1.0));
assert!(validate_boundary_conditions(&conflict, &mesh, 1).is_err());
}
// A negative penalty factor must fail validation.
#[test]
fn test_validate_penalty_factor() {
let mesh = Mesh::generate_1d(0.0, 1.0, 3).expect("mesh gen should succeed");
let mut bc_set = BoundaryConditionSet::new();
bc_set.add_scalar_dirichlet(&[(0, 0.0)]);
bc_set.set_penalty_factor(-1.0);
assert!(validate_boundary_conditions(&bc_set, &mesh, 1).is_err());
}
}