use super::{FemError, FemResult};
use std::collections::{HashMap, HashSet};
/// A mesh node: an identifier, its spatial coordinates, and a list of degree-of-freedom
/// indices.
#[derive(Clone, Debug, PartialEq)]
pub struct Node {
/// Identifier. `Mesh::add_node` sets it to the node's index in `Mesh::nodes`.
pub id: usize,
/// Spatial coordinates; the number of entries is the node's dimension.
pub coords: Vec<f64>,
/// Degree-of-freedom indices. Empty when the node is constructed and not filled anywhere
/// in this file — NOTE(review): presumably populated by a DOF-numbering step elsewhere.
pub dofs: Vec<usize>,
}
impl Node {
pub fn new(id: usize, coords: Vec<f64>) -> Self {
Self {
id,
coords,
dofs: Vec::new(),
}
}
pub fn dimension(&self) -> usize {
self.coords.len()
}
pub fn distance_to(&self, other: &Node) -> f64 {
self.coords
.iter()
.zip(other.coords.iter())
.map(|(a, b)| (a - b) * (a - b))
.sum::<f64>()
.sqrt()
}
}
/// Supported element shapes.
///
/// The enum carries no data, so it derives `Copy` and can be passed around by value
/// without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ElementKind {
    /// Two-node line segment.
    Line2,
    /// Three-node triangle.
    Triangle3,
    /// Four-node quadrilateral.
    Quad4,
}
impl ElementKind {
pub fn num_nodes(&self) -> usize {
match self {
ElementKind::Line2 => 2,
ElementKind::Triangle3 => 3,
ElementKind::Quad4 => 4,
}
}
pub fn dimension(&self) -> usize {
match self {
ElementKind::Line2 => 1,
ElementKind::Triangle3 | ElementKind::Quad4 => 2,
}
}
}
/// A single finite element: its shape, the ids of the nodes it connects, and a material id.
#[derive(Clone, Debug)]
pub struct Element {
/// Identifier. `Mesh::add_element` sets it to the element's index in `Mesh::elements`.
pub id: usize,
/// Element shape; it fixes how many node ids `nodes` must hold.
pub kind: ElementKind,
/// Node ids (indices into `Mesh::nodes`) in element-local order. The order defines the
/// element's edges: consecutive ids, closing back to the first id for triangles and quads.
pub nodes: Vec<usize>,
/// Material identifier; `Element::new` initialises it to 0.
pub material_id: usize,
}
impl Element {
pub fn new(id: usize, kind: ElementKind, nodes: Vec<usize>) -> FemResult<Self> {
if nodes.len() != kind.num_nodes() {
return Err(FemError::ElementError(format!(
"Element type {:?} requires {} nodes, got {}",
kind,
kind.num_nodes(),
nodes.len()
)));
}
Ok(Self {
id,
kind,
nodes,
material_id: 0,
})
}
pub fn num_nodes(&self) -> usize {
self.nodes.len()
}
}
/// Summary statistics produced by `Mesh::quality`.
///
/// Per-element quality values are normalised so that 1 is the best value and 0 marks a
/// degenerate element.
#[derive(Clone, Debug)]
pub struct MeshQuality {
/// Smallest per-element quality value.
pub min_quality: f64,
/// Largest per-element quality value.
pub max_quality: f64,
/// Arithmetic mean of the per-element quality values.
pub avg_quality: f64,
/// Number of elements whose quality is below 0.1.
pub num_poor_elements: usize,
/// Smallest per-element size, where an element's size is its shortest edge length.
pub min_size: f64,
/// Largest per-element size (again, each element's shortest edge length).
pub max_size: f64,
/// `(min, max)` over elements of the longest-to-shortest edge-length ratio. May contain
/// `f64::INFINITY` when an element has a (near-)zero-length edge.
pub aspect_ratio_range: (f64, f64),
}
/// A mesh of `Node`s and `Element`s together with derived connectivity information.
#[derive(Clone, Debug)]
pub struct Mesh {
/// All nodes. When built through `add_node`, a node's id equals its index here.
pub nodes: Vec<Node>,
/// All elements. When built through `add_element`, an element's id equals its index here.
pub elements: Vec<Element>,
/// For each node id, the ids of the elements that reference it. Maintained incrementally
/// by `add_node` and `add_element`.
pub node_to_elements: HashMap<usize, HashSet<usize>>,
/// For each element id, the ids of other elements sharing at least one node with it.
/// Only up to date after `build_adjacency` has been called.
pub element_neighbors: HashMap<usize, HashSet<usize>>,
/// Sorted ids of boundary nodes. Only up to date after `identify_boundary_nodes` has been
/// called.
pub boundary_nodes: Vec<usize>,
/// Spatial dimension of node coordinates: 1, 2, or 3 (enforced by `Mesh::new`).
pub dimension: usize,
}
impl Mesh {
    /// Elements whose quality falls below this value are counted as poor by
    /// [`Mesh::quality`].
    const POOR_QUALITY_THRESHOLD: f64 = 0.1;

    /// Edges no longer than this are treated as degenerate (zero length) by
    /// [`Mesh::quality`].
    const MIN_EDGE_LENGTH: f64 = 1e-15;

    /// Creates an empty mesh whose nodes live in `dimension`-dimensional space.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] unless `dimension` is 1, 2, or 3.
    pub fn new(dimension: usize) -> FemResult<Self> {
        if dimension == 0 || dimension > 3 {
            return Err(FemError::MeshError(format!(
                "Spatial dimension must be 1, 2, or 3, got {}",
                dimension
            )));
        }
        Ok(Self {
            nodes: Vec::new(),
            elements: Vec::new(),
            node_to_elements: HashMap::new(),
            element_neighbors: HashMap::new(),
            boundary_nodes: Vec::new(),
            dimension,
        })
    }

    /// Appends a node and returns its id, which is its index in [`Mesh::nodes`].
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if `coords` does not have exactly `self.dimension`
    /// entries, or if any coordinate is NaN or infinite.
    pub fn add_node(&mut self, coords: Vec<f64>) -> FemResult<usize> {
        if coords.len() != self.dimension {
            return Err(FemError::MeshError(format!(
                "Node coordinates dimension {} does not match mesh dimension {}",
                coords.len(),
                self.dimension
            )));
        }
        // One NaN/inf coordinate would silently poison every distance, quality and
        // refinement result computed later, so reject it here.
        if let Some(bad) = coords.iter().find(|c| !c.is_finite()) {
            return Err(FemError::MeshError(format!(
                "Node coordinates must be finite, got {}",
                bad
            )));
        }
        let id = self.nodes.len();
        self.nodes.push(Node::new(id, coords));
        self.node_to_elements.insert(id, HashSet::new());
        Ok(id)
    }

    /// Appends an element of `kind` over `node_ids` and returns its id (its index in
    /// [`Mesh::elements`]). The new element gets `material_id` 0.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if a node id is out of range, or if the element's
    /// topological dimension exceeds the mesh's spatial dimension (e.g. a triangle in a
    /// 1D mesh). Propagates [`FemError::ElementError`] from [`Element::new`] when the node
    /// count does not match `kind`.
    pub fn add_element(&mut self, kind: ElementKind, node_ids: Vec<usize>) -> FemResult<usize> {
        if let Some(&nid) = node_ids.iter().find(|&&nid| nid >= self.nodes.len()) {
            return Err(FemError::MeshError(format!(
                "Node index {} out of range (mesh has {} nodes)",
                nid,
                self.nodes.len()
            )));
        }
        if kind.dimension() > self.dimension {
            return Err(FemError::MeshError(format!(
                "Element type {:?} (dimension {}) cannot be added to a {}D mesh",
                kind,
                kind.dimension(),
                self.dimension
            )));
        }
        let eid = self.elements.len();
        let element = Element::new(eid, kind, node_ids)?;
        for &nid in &element.nodes {
            // `entry` keeps the reverse map correct even for nodes pushed directly into
            // the public `nodes` vector without going through `add_node`.
            self.node_to_elements.entry(nid).or_default().insert(eid);
        }
        self.elements.push(element);
        Ok(eid)
    }

    /// Number of nodes in the mesh.
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Number of elements in the mesh.
    pub fn num_elements(&self) -> usize {
        self.elements.len()
    }

    /// Borrows the coordinates of node `node_id`.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if `node_id` is out of range.
    pub fn node_coords(&self, node_id: usize) -> FemResult<&[f64]> {
        self.nodes
            .get(node_id)
            .map(|node| node.coords.as_slice())
            .ok_or_else(|| FemError::MeshError(format!("Node index {} out of range", node_id)))
    }

    /// Returns owned copies of the coordinates of each node of element `element_id`, in
    /// element-local order.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if `element_id` is out of range.
    pub fn element_coords(&self, element_id: usize) -> FemResult<Vec<Vec<f64>>> {
        let elem = self.elements.get(element_id).ok_or_else(|| {
            FemError::MeshError(format!("Element index {} out of range", element_id))
        })?;
        Ok(elem
            .nodes
            .iter()
            .map(|&nid| self.nodes[nid].coords.clone())
            .collect())
    }

    /// Rebuilds [`Mesh::element_neighbors`]: two elements are neighbours when they share at
    /// least one node. Every element gets an entry, possibly empty.
    pub fn build_adjacency(&mut self) {
        self.element_neighbors.clear();
        for eid in 0..self.elements.len() {
            self.element_neighbors.insert(eid, HashSet::new());
        }
        for elem_set in self.node_to_elements.values() {
            let elems: Vec<usize> = elem_set.iter().copied().collect();
            for i in 0..elems.len() {
                for j in (i + 1)..elems.len() {
                    if let Some(neighbors) = self.element_neighbors.get_mut(&elems[i]) {
                        neighbors.insert(elems[j]);
                    }
                    if let Some(neighbors) = self.element_neighbors.get_mut(&elems[j]) {
                        neighbors.insert(elems[i]);
                    }
                }
            }
        }
    }

    /// Recomputes [`Mesh::boundary_nodes`] (sorted ascending).
    ///
    /// In a 1D mesh, a node is on the boundary when at most one element references it.
    /// This includes nodes no element references. Otherwise a node is on the boundary when
    /// it lies on an edge used by exactly one element.
    pub fn identify_boundary_nodes(&mut self) {
        let mut boundary: Vec<usize> = if self.dimension == 1 {
            self.node_to_elements
                .iter()
                .filter(|(_, elems)| elems.len() <= 1)
                .map(|(&nid, _)| nid)
                .collect()
        } else {
            let mut edge_count: HashMap<(usize, usize), usize> = HashMap::new();
            for elem in &self.elements {
                for (n1, n2) in element_edges(&elem.kind, &elem.nodes) {
                    // Orientation-independent key, so a shared edge is counted twice.
                    let key = if n1 < n2 { (n1, n2) } else { (n2, n1) };
                    *edge_count.entry(key).or_insert(0) += 1;
                }
            }
            let on_boundary: HashSet<usize> = edge_count
                .into_iter()
                .filter(|&(_, count)| count == 1)
                .flat_map(|((n1, n2), _)| [n1, n2])
                .collect();
            on_boundary.into_iter().collect()
        };
        boundary.sort_unstable();
        self.boundary_nodes = boundary;
    }

    /// Fails with a [`FemError::MeshError`] naming `what` if any of `values` is NaN or
    /// infinite. NaN would otherwise slip through the `<=` ordering checks below.
    fn ensure_finite(values: &[f64], what: &str) -> FemResult<()> {
        if values.iter().all(|v| v.is_finite()) {
            Ok(())
        } else {
            Err(FemError::MeshError(format!("{} must be finite", what)))
        }
    }

    /// Adds an `(nx + 1) x (ny + 1)` lattice of nodes over `[x0, x1] x [y0, y1]`, row by
    /// row with x varying fastest, so node `(i, j)` gets id `j * (nx + 1) + i` in a fresh
    /// mesh. The last column/row is placed exactly on `x1`/`y1`; `x0 + nx * hx` can miss it
    /// through rounding.
    fn add_grid_nodes(
        &mut self,
        (x0, x1): (f64, f64),
        (y0, y1): (f64, f64),
        nx: usize,
        ny: usize,
    ) -> FemResult<()> {
        let hx = (x1 - x0) / nx as f64;
        let hy = (y1 - y0) / ny as f64;
        for j in 0..=ny {
            let y = if j == ny { y1 } else { y0 + j as f64 * hy };
            for i in 0..=nx {
                let x = if i == nx { x1 } else { x0 + i as f64 * hx };
                self.add_node(vec![x, y])?;
            }
        }
        Ok(())
    }

    /// Yields the corner node ids `(bottom-left, bottom-right, top-right, top-left)` of each
    /// cell of the lattice laid out by [`Mesh::add_grid_nodes`], row by row. The corners
    /// are counter-clockwise.
    fn grid_cells(nx: usize, ny: usize) -> impl Iterator<Item = (usize, usize, usize, usize)> {
        let stride = nx + 1;
        (0..ny).flat_map(move |j| {
            (0..nx).map(move |i| {
                let n0 = j * stride + i;
                (n0, n0 + 1, n0 + 1 + stride, n0 + stride)
            })
        })
    }

    /// Generates a uniform 1D mesh of `n_elements` Line2 elements on `[a, b]`.
    ///
    /// Boundary nodes and adjacency are computed before returning.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if `n_elements` is 0, an endpoint is not finite, or
    /// `b <= a`.
    pub fn generate_1d(a: f64, b: f64, n_elements: usize) -> FemResult<Self> {
        if n_elements == 0 {
            return Err(FemError::MeshError(
                "Number of elements must be positive".to_string(),
            ));
        }
        Self::ensure_finite(&[a, b], "Domain endpoints")?;
        if b <= a {
            return Err(FemError::MeshError(format!(
                "Right endpoint {} must be greater than left endpoint {}",
                b, a
            )));
        }
        let mut mesh = Mesh::new(1)?;
        let h = (b - a) / n_elements as f64;
        for i in 0..=n_elements {
            // Pin the last node to `b` exactly; `a + n * h` can miss it by rounding.
            let x = if i == n_elements { b } else { a + i as f64 * h };
            mesh.add_node(vec![x])?;
        }
        for i in 0..n_elements {
            mesh.add_element(ElementKind::Line2, vec![i, i + 1])?;
        }
        mesh.identify_boundary_nodes();
        mesh.build_adjacency();
        Ok(mesh)
    }

    /// Generates an `nx` x `ny` grid of counter-clockwise Quad4 elements on
    /// `[x0, x1] x [y0, y1]`.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if `nx` or `ny` is 0, an extent is not finite, or
    /// the domain has non-positive length in either direction.
    pub fn generate_2d_rectangular(
        x0: f64,
        x1: f64,
        y0: f64,
        y1: f64,
        nx: usize,
        ny: usize,
    ) -> FemResult<Self> {
        if nx == 0 || ny == 0 {
            return Err(FemError::MeshError(
                "Number of elements in each direction must be positive".to_string(),
            ));
        }
        Self::ensure_finite(&[x0, x1, y0, y1], "Domain extents")?;
        if x1 <= x0 || y1 <= y0 {
            return Err(FemError::MeshError(
                "Domain extents must have positive length".to_string(),
            ));
        }
        let mut mesh = Mesh::new(2)?;
        mesh.add_grid_nodes((x0, x1), (y0, y1), nx, ny)?;
        for (n0, n1, n2, n3) in Self::grid_cells(nx, ny) {
            mesh.add_element(ElementKind::Quad4, vec![n0, n1, n2, n3])?;
        }
        mesh.identify_boundary_nodes();
        mesh.build_adjacency();
        Ok(mesh)
    }

    /// Generates an `nx` x `ny` grid of cells on `[x0, x1] x [y0, y1]`. Each cell is split
    /// along its bottom-left to top-right diagonal into two counter-clockwise Triangle3
    /// elements.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if `nx` or `ny` is 0, an extent is not finite, or
    /// the domain has non-positive length in either direction.
    pub fn generate_2d_triangular(
        x0: f64,
        x1: f64,
        y0: f64,
        y1: f64,
        nx: usize,
        ny: usize,
    ) -> FemResult<Self> {
        if nx == 0 || ny == 0 {
            return Err(FemError::MeshError(
                "Number of cells in each direction must be positive".to_string(),
            ));
        }
        Self::ensure_finite(&[x0, x1, y0, y1], "Domain extents")?;
        if x1 <= x0 || y1 <= y0 {
            return Err(FemError::MeshError(
                "Domain extents must have positive length".to_string(),
            ));
        }
        let mut mesh = Mesh::new(2)?;
        mesh.add_grid_nodes((x0, x1), (y0, y1), nx, ny)?;
        for (n0, n1, n2, n3) in Self::grid_cells(nx, ny) {
            mesh.add_element(ElementKind::Triangle3, vec![n0, n1, n2])?;
            mesh.add_element(ElementKind::Triangle3, vec![n0, n2, n3])?;
        }
        mesh.identify_boundary_nodes();
        mesh.build_adjacency();
        Ok(mesh)
    }

    /// Computes per-element quality, size and aspect-ratio statistics.
    ///
    /// Quality is 1 for a non-degenerate line, [`triangle_quality`] for triangles and
    /// [`quad_quality`] for quads. A zero-length line scores 0. Size is the element's
    /// shortest edge. Aspect ratio is longest/shortest edge, or infinity for a degenerate
    /// edge.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if the mesh has no elements.
    pub fn quality(&self) -> FemResult<MeshQuality> {
        if self.elements.is_empty() {
            return Err(FemError::MeshError("Mesh has no elements".to_string()));
        }
        let mut min_quality = f64::INFINITY;
        let mut max_quality = 0.0_f64;
        let mut quality_sum = 0.0_f64;
        let mut num_poor = 0;
        let mut min_size = f64::INFINITY;
        let mut max_size = 0.0_f64;
        let mut min_aspect = f64::INFINITY;
        let mut max_aspect = 0.0_f64;
        for elem in &self.elements {
            let coords: Vec<&[f64]> = elem
                .nodes
                .iter()
                .map(|&nid| self.nodes[nid].coords.as_slice())
                .collect();
            let (min_edge, max_edge) = element_edges(&elem.kind, &elem.nodes)
                .into_iter()
                .map(|(a, b)| euclidean_distance(&self.nodes[a].coords, &self.nodes[b].coords))
                .fold((f64::INFINITY, 0.0_f64), |(lo, hi), len| {
                    (lo.min(len), hi.max(len))
                });
            let non_degenerate = min_edge > Self::MIN_EDGE_LENGTH;
            let aspect = if non_degenerate {
                max_edge / min_edge
            } else {
                f64::INFINITY
            };
            let quality = match elem.kind {
                // A zero-length line is degenerate and must not score as perfect.
                ElementKind::Line2 => {
                    if non_degenerate {
                        1.0
                    } else {
                        0.0
                    }
                }
                ElementKind::Triangle3 => triangle_quality(coords[0], coords[1], coords[2]),
                ElementKind::Quad4 => quad_quality(coords[0], coords[1], coords[2], coords[3]),
            };
            min_quality = min_quality.min(quality);
            max_quality = max_quality.max(quality);
            quality_sum += quality;
            if quality < Self::POOR_QUALITY_THRESHOLD {
                num_poor += 1;
            }
            min_size = min_size.min(min_edge);
            max_size = max_size.max(min_edge);
            min_aspect = min_aspect.min(aspect);
            max_aspect = max_aspect.max(aspect);
        }
        Ok(MeshQuality {
            min_quality,
            max_quality,
            avg_quality: quality_sum / self.elements.len() as f64,
            num_poor_elements: num_poor,
            min_size,
            max_size,
            aspect_ratio_range: (min_aspect, max_aspect),
        })
    }

    /// Uniformly refines the mesh once and returns the refined copy.
    ///
    /// Every edge is split at its midpoint. Midpoints are shared between the elements on
    /// an edge, so the result stays conforming. Lines become 2 lines, triangles become 4
    /// triangles, and quads become 4 quads (with one extra node at the average of the
    /// corners). Original nodes keep their ids and new nodes are appended. Child elements
    /// inherit their parent's `material_id`. Works for 1D, 2D and 3D (surface) meshes,
    /// including Line2 elements embedded in 2D/3D meshes.
    ///
    /// # Errors
    ///
    /// Returns [`FemError::MeshError`] if `dimension` is outside 1..=3, or if a child
    /// element cannot be added (e.g. a surface element inside a 1D mesh).
    pub fn refine(&self) -> FemResult<Mesh> {
        match self.dimension {
            1..=3 => self.refine_uniform(),
            _ => Err(FemError::MeshError(format!(
                "Refinement not implemented for dimension {}",
                self.dimension
            ))),
        }
    }

    /// Implementation of [`Mesh::refine`] for any supported dimension.
    fn refine_uniform(&self) -> FemResult<Mesh> {
        let mut new_mesh = Mesh::new(self.dimension)?;
        for node in &self.nodes {
            new_mesh.add_node(node.coords.clone())?;
        }
        let mut edge_midpoints: HashMap<(usize, usize), usize> = HashMap::new();
        for elem in &self.elements {
            let material = elem.material_id;
            match elem.kind {
                ElementKind::Line2 => {
                    let (n0, n1) = (elem.nodes[0], elem.nodes[1]);
                    let m01 = new_mesh.edge_midpoint(&mut edge_midpoints, n0, n1)?;
                    new_mesh.add_child_element(ElementKind::Line2, vec![n0, m01], material)?;
                    new_mesh.add_child_element(ElementKind::Line2, vec![m01, n1], material)?;
                }
                ElementKind::Triangle3 => {
                    let (n0, n1, n2) = (elem.nodes[0], elem.nodes[1], elem.nodes[2]);
                    let m01 = new_mesh.edge_midpoint(&mut edge_midpoints, n0, n1)?;
                    let m12 = new_mesh.edge_midpoint(&mut edge_midpoints, n1, n2)?;
                    let m20 = new_mesh.edge_midpoint(&mut edge_midpoints, n2, n0)?;
                    for corners in [
                        vec![n0, m01, m20],
                        vec![m01, n1, m12],
                        vec![m20, m12, n2],
                        vec![m01, m12, m20],
                    ] {
                        new_mesh.add_child_element(ElementKind::Triangle3, corners, material)?;
                    }
                }
                ElementKind::Quad4 => {
                    let (n0, n1, n2, n3) =
                        (elem.nodes[0], elem.nodes[1], elem.nodes[2], elem.nodes[3]);
                    let m01 = new_mesh.edge_midpoint(&mut edge_midpoints, n0, n1)?;
                    let m12 = new_mesh.edge_midpoint(&mut edge_midpoints, n1, n2)?;
                    let m23 = new_mesh.edge_midpoint(&mut edge_midpoints, n2, n3)?;
                    let m30 = new_mesh.edge_midpoint(&mut edge_midpoints, n3, n0)?;
                    let center_coords: Vec<f64> = (0..self.dimension)
                        .map(|d| {
                            (self.nodes[n0].coords[d]
                                + self.nodes[n1].coords[d]
                                + self.nodes[n2].coords[d]
                                + self.nodes[n3].coords[d])
                                / 4.0
                        })
                        .collect();
                    let center = new_mesh.add_node(center_coords)?;
                    for corners in [
                        vec![n0, m01, center, m30],
                        vec![m01, n1, m12, center],
                        vec![center, m12, n2, m23],
                        vec![m30, center, m23, n3],
                    ] {
                        new_mesh.add_child_element(ElementKind::Quad4, corners, material)?;
                    }
                }
            }
        }
        new_mesh.identify_boundary_nodes();
        new_mesh.build_adjacency();
        Ok(new_mesh)
    }

    /// Returns the id of the node at the midpoint of edge `(a, b)`, creating it on first
    /// use. `midpoints` is keyed on the sorted id pair, so both elements sharing an edge
    /// reuse the same node. Assumes `a` and `b` are ids already present in `self.nodes`.
    fn edge_midpoint(
        &mut self,
        midpoints: &mut HashMap<(usize, usize), usize>,
        a: usize,
        b: usize,
    ) -> FemResult<usize> {
        let key = if a < b { (a, b) } else { (b, a) };
        if let Some(&mid) = midpoints.get(&key) {
            return Ok(mid);
        }
        let coords: Vec<f64> = self.nodes[a]
            .coords
            .iter()
            .zip(&self.nodes[b].coords)
            .map(|(p, q)| (p + q) / 2.0)
            .collect();
        let mid = self.add_node(coords)?;
        midpoints.insert(key, mid);
        Ok(mid)
    }

    /// Adds an element via [`Mesh::add_element`] and assigns it `material_id`, so that
    /// refinement keeps the parent's material assignment.
    fn add_child_element(
        &mut self,
        kind: ElementKind,
        nodes: Vec<usize>,
        material_id: usize,
    ) -> FemResult<usize> {
        let eid = self.add_element(kind, nodes)?;
        self.elements[eid].material_id = material_id;
        Ok(eid)
    }
}
/// Lists the edges of an element as `(from, to)` node-id pairs, in element-local order.
///
/// A line has its single edge. Triangles and quads yield each consecutive pair of corners,
/// closing back to the first corner. Only the first `kind.num_nodes()` entries of `nodes`
/// are read; a shorter slice panics on indexing.
fn element_edges(kind: &ElementKind, nodes: &[usize]) -> Vec<(usize, usize)> {
    let corners = match kind {
        ElementKind::Line2 => return vec![(nodes[0], nodes[1])],
        ElementKind::Triangle3 => 3,
        ElementKind::Quad4 => 4,
    };
    (0..corners)
        .map(|k| (nodes[k], nodes[(k + 1) % corners]))
        .collect()
}
/// Euclidean distance between two points given as coordinate slices.
///
/// Components are paired with `zip`, so when the slices differ in length only their
/// common leading components contribute.
fn euclidean_distance(p1: &[f64], p2: &[f64]) -> f64 {
    let sum_of_squares: f64 = p1
        .iter()
        .zip(p2)
        .map(|(a, b)| {
            let diff = a - b;
            diff * diff
        })
        .sum();
    sum_of_squares.sqrt()
}
/// Shape quality of a triangle: `4 * sqrt(3) * area / (a^2 + b^2 + c^2)`, where `a`,
/// `b`, `c` are the edge lengths.
///
/// The value is 1 for an equilateral triangle and tends to 0 as the triangle degenerates.
/// Collinear or coincident vertices (and NaN input) give exactly 0. Only edge lengths are
/// used, so points may have any dimension.
///
/// The area uses Kahan's rearrangement of Heron's formula, with edges sorted so that
/// `a >= b >= c`. Plain Heron suffers catastrophic cancellation in `s - a` for
/// needle-shaped triangles; Kahan's form stays accurate there.
fn triangle_quality(p0: &[f64], p1: &[f64], p2: &[f64]) -> f64 {
    let mut sides = [
        euclidean_distance(p0, p1),
        euclidean_distance(p1, p2),
        euclidean_distance(p2, p0),
    ];
    let denom: f64 = sides.iter().map(|s| s * s).sum();
    if denom < 1e-30 {
        return 0.0;
    }
    sides.sort_by(|x, y| y.total_cmp(x));
    let [a, b, c] = sides;
    // 16 * area^2. The parenthesisation is essential to Kahan's accuracy guarantee; the
    // second factor goes negative only when rounding breaks the triangle inequality.
    let sixteen_area_sq = (a + (b + c)) * (c - (a - b)) * (c + (a - b)) * (a + (b - c));
    if sixteen_area_sq.is_nan() || sixteen_area_sq <= 0.0 {
        return 0.0;
    }
    let area = 0.25 * sixteen_area_sq.sqrt();
    4.0 * 3.0_f64.sqrt() * area / denom
}
/// Shape quality of a quadrilateral with corners `p0 -> p1 -> p2 -> p3` in boundary order.
///
/// For a convex quad the value is the ratio of the shorter to the longer diagonal. It is 1
/// for squares and rectangles (stretching is not penalised; the aspect ratio covers that)
/// and drops as the quad shears. Quads that are self-intersecting ("bow-tie"), non-convex,
/// or have a degenerate corner score 0. Their diagonals can still be equal in length, so
/// the diagonal ratio alone would wrongly rate them as perfect.
///
/// Convexity is checked for points with at most three coordinates. Higher-dimensional
/// input falls back to the plain diagonal ratio.
fn quad_quality(p0: &[f64], p1: &[f64], p2: &[f64], p3: &[f64]) -> f64 {
    let d1 = euclidean_distance(p0, p2);
    let d2 = euclidean_distance(p1, p3);
    let min_d = d1.min(d2);
    let max_d = d1.max(d2);
    if max_d < 1e-30 {
        return 0.0;
    }
    if quad_is_convex(&[p0, p1, p2, p3]) == Some(false) {
        return 0.0;
    }
    min_d / max_d
}

/// Reports whether a quad, given by its corners in boundary order, is convex and
/// non-degenerate.
///
/// Returns `Some(true)` for a convex quad. Returns `Some(false)` if the quad is inverted
/// at some corner, self-intersecting, or has a zero-area corner. Returns `None` if any
/// corner has more than three coordinates.
///
/// Each corner's turn is the cross product of its outgoing and incoming edge vectors,
/// with points zero-padded to 3-D. A convex quad turns the same way at every corner, so
/// all four cross products point the same way, in either orientation (CW or CCW).
fn quad_is_convex(corners: &[&[f64]; 4]) -> Option<bool> {
    let mut pts = [[0.0_f64; 3]; 4];
    for (dst, src) in pts.iter_mut().zip(corners.iter()) {
        if src.len() > 3 {
            return None;
        }
        dst[..src.len()].copy_from_slice(src);
    }
    let sub = |a: [f64; 3], b: [f64; 3]| [a[0] - b[0], a[1] - b[1], a[2] - b[2]];
    let cross = |u: [f64; 3], v: [f64; 3]| {
        [
            u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0],
        ]
    };
    let mut normals = [[0.0_f64; 3]; 4];
    for (k, normal) in normals.iter_mut().enumerate() {
        let here = pts[k];
        *normal = cross(sub(pts[(k + 1) % 4], here), sub(pts[(k + 3) % 4], here));
    }
    let reference = normals[0];
    // A zero reference normal (degenerate first corner) fails its own dot product.
    Some(
        normals
            .iter()
            .all(|n| n[0] * reference[0] + n[1] * reference[1] + n[2] * reference[2] > 0.0),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating-point comparisons in this module.
    const EPS: f64 = 1e-12;

    /// True when `actual` is within [`EPS`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_node_creation() {
        let node = Node::new(0, vec![1.0, 2.0]);
        assert_eq!(node.id, 0);
        assert_eq!(node.dimension(), 2);
        assert_eq!(node.coords, [1.0, 2.0]);
    }

    #[test]
    fn test_node_distance() {
        let origin = Node::new(0, vec![0.0, 0.0]);
        let target = Node::new(1, vec![3.0, 4.0]);
        assert!(close(origin.distance_to(&target), 5.0));
    }

    #[test]
    fn test_element_kind() {
        let expected_nodes = [
            (ElementKind::Line2, 2),
            (ElementKind::Triangle3, 3),
            (ElementKind::Quad4, 4),
        ];
        for (kind, count) in &expected_nodes {
            assert_eq!(kind.num_nodes(), *count);
        }
        assert_eq!(ElementKind::Line2.dimension(), 1);
        assert_eq!(ElementKind::Triangle3.dimension(), 2);
    }

    #[test]
    fn test_element_creation() {
        assert!(Element::new(0, ElementKind::Triangle3, vec![0, 1, 2]).is_ok());
        assert!(Element::new(0, ElementKind::Triangle3, vec![0, 1]).is_err());
    }

    #[test]
    fn test_mesh_1d_generation() {
        let mesh = Mesh::generate_1d(0.0, 1.0, 10).expect("mesh generation should succeed");
        assert_eq!(
            (mesh.num_nodes(), mesh.num_elements(), mesh.dimension),
            (11, 10, 1)
        );
        for end in &[0usize, 10] {
            assert!(mesh.boundary_nodes.contains(end));
        }
        let first = mesh.node_coords(0).expect("node 0 should exist");
        assert!(close(first[0], 0.0));
        let last = mesh.node_coords(10).expect("node 10 should exist");
        assert!(close(last[0], 1.0));
    }

    #[test]
    fn test_mesh_2d_rectangular() {
        let mesh = Mesh::generate_2d_rectangular(0.0, 1.0, 0.0, 1.0, 3, 3)
            .expect("mesh generation should succeed");
        assert_eq!(mesh.num_nodes(), 16); // (3 + 1) x (3 + 1) lattice
        assert_eq!(mesh.num_elements(), 9); // 3 x 3 quads
        assert_eq!(mesh.dimension, 2);
        assert!(!mesh.boundary_nodes.is_empty());
    }

    #[test]
    fn test_mesh_2d_triangular() {
        let mesh = Mesh::generate_2d_triangular(0.0, 1.0, 0.0, 1.0, 2, 2)
            .expect("mesh generation should succeed");
        assert_eq!(mesh.num_nodes(), 9); // 3 x 3 lattice
        assert_eq!(mesh.num_elements(), 8); // two triangles per cell
        assert_eq!(mesh.dimension, 2);
    }

    #[test]
    fn test_mesh_quality_1d() {
        let quality = Mesh::generate_1d(0.0, 1.0, 5)
            .expect("mesh generation should succeed")
            .quality()
            .expect("quality computation should succeed");
        assert!(close(quality.min_quality, 1.0));
        assert!(close(quality.avg_quality, 1.0));
        assert_eq!(quality.num_poor_elements, 0);
    }

    #[test]
    fn test_mesh_quality_2d_rectangular() {
        let quality = Mesh::generate_2d_rectangular(0.0, 1.0, 0.0, 1.0, 4, 4)
            .expect("mesh generation should succeed")
            .quality()
            .expect("quality computation should succeed");
        assert!(close(quality.min_quality, 1.0));
    }

    #[test]
    fn test_mesh_refinement_1d() {
        let refined = Mesh::generate_1d(0.0, 1.0, 4)
            .expect("mesh generation should succeed")
            .refine()
            .expect("refinement should succeed");
        assert_eq!(refined.num_nodes(), 9); // 5 original nodes + 4 midpoints
        assert_eq!(refined.num_elements(), 8); // each line split in two
    }

    #[test]
    fn test_mesh_refinement_2d_triangular() {
        let mesh = Mesh::generate_2d_triangular(0.0, 1.0, 0.0, 1.0, 1, 1)
            .expect("mesh generation should succeed");
        assert_eq!(mesh.num_elements(), 2);
        let refined = mesh.refine().expect("refinement should succeed");
        assert_eq!(refined.num_elements(), 8); // each triangle split into four
    }

    #[test]
    fn test_mesh_adjacency() {
        let mesh = Mesh::generate_1d(0.0, 1.0, 3).expect("mesh generation should succeed");
        let first_touches_second = mesh
            .element_neighbors
            .get(&0)
            .map(|neighbors| neighbors.contains(&1))
            .expect("element 0 should have neighbors");
        assert!(first_touches_second);
    }

    #[test]
    fn test_invalid_mesh_generation() {
        let attempts = [
            Mesh::generate_1d(1.0, 0.0, 5),
            Mesh::generate_1d(0.0, 1.0, 0),
            Mesh::generate_2d_rectangular(0.0, 1.0, 0.0, 1.0, 0, 5),
        ];
        assert!(attempts.iter().all(Result::is_err));
    }
}