use scirs2_core::{IntoParallelRefIterator, ParallelIterator};
use crate::fem::element::TriangularElement;
use crate::fem::mesh::Mesh2D;
/// Sparse matrix stored in coordinate (COO / triplet) format.
///
/// Stored entry `k` is the triplet `(rows[k], cols[k], vals[k])`. Duplicate `(row, col)` pairs
/// are allowed and are summed whenever the matrix is applied (see [`SparseMatrix::matvec`]),
/// so finite-element assembly can simply append every element's local contributions.
#[derive(Debug, Clone)]
pub struct SparseMatrix {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Row index of each stored entry.
    pub rows: Vec<usize>,
    /// Column index of each stored entry (parallel to `rows`).
    pub cols: Vec<usize>,
    /// Value of each stored entry (parallel to `rows`).
    pub vals: Vec<f64>,
}

impl SparseMatrix {
    /// Creates an empty `nrows × ncols` matrix with no stored entries.
    pub fn new(nrows: usize, ncols: usize) -> Self {
        Self::with_capacity(nrows, ncols, 0)
    }

    /// Creates an empty `nrows × ncols` matrix with room for `capacity` triplets, avoiding
    /// repeated reallocation when the number of entries is known up front.
    pub fn with_capacity(nrows: usize, ncols: usize, capacity: usize) -> Self {
        Self {
            nrows,
            ncols,
            rows: Vec::with_capacity(capacity),
            cols: Vec::with_capacity(capacity),
            vals: Vec::with_capacity(capacity),
        }
    }

    /// Number of stored triplets (duplicate `(row, col)` pairs are counted separately).
    pub fn nnz(&self) -> usize {
        self.vals.len()
    }

    /// Appends the triplet `(i, j, val)`; an existing entry at `(i, j)` is not overwritten,
    /// the two values are summed when the matrix is applied.
    ///
    /// # Panics
    ///
    /// In debug builds, if `i >= nrows` or `j >= ncols`. Catching this here points at the
    /// offending insertion instead of a later index panic inside [`SparseMatrix::matvec`].
    pub fn add_entry(&mut self, i: usize, j: usize, val: f64) {
        debug_assert!(
            i < self.nrows && j < self.ncols,
            "entry ({i}, {j}) lies outside a {}x{} matrix",
            self.nrows,
            self.ncols
        );
        self.rows.push(i);
        self.cols.push(j);
        self.vals.push(val);
    }

    /// Computes `y = A·x`, summing duplicate entries in insertion order.
    ///
    /// # Panics
    ///
    /// Panics if a stored row index is `>= nrows` or a stored column index is out of bounds
    /// for `x`. In debug builds, also panics if `x.len() != ncols`.
    pub fn matvec(&self, x: &[f64]) -> Vec<f64> {
        debug_assert_eq!(x.len(), self.ncols, "matvec: x must have one entry per column");
        let mut y = vec![0.0; self.nrows];
        for ((&i, &j), &val) in self.rows.iter().zip(&self.cols).zip(&self.vals) {
            y[i] += val * x[j];
        }
        y
    }

    /// Computes component `row` of `A·x`; returns `0.0` when the row has no stored entries.
    ///
    /// Entries are not sorted by row, so every call scans all `nnz` triplets. Prefer
    /// [`SparseMatrix::matvec`] when more than a few rows are needed.
    ///
    /// # Panics
    ///
    /// Panics if a matching entry's column index is out of bounds for `x`. In debug builds,
    /// also panics if `x.len() != ncols`.
    pub fn matvec_row(&self, x: &[f64], row: usize) -> f64 {
        debug_assert_eq!(x.len(), self.ncols, "matvec_row: x must have one entry per column");
        self.rows
            .iter()
            .zip(&self.cols)
            .zip(&self.vals)
            .filter(|&((&i, _), _)| i == row)
            // Explicit `fold(0.0, ..)` (not `sum`) keeps the exact accumulation order and the
            // `+0.0` result for empty rows.
            .fold(0.0, |acc, ((_, &j), &val)| acc + val * x[j])
    }
}
/// Number of `(row, col, value)` triplets one linear triangle contributes: its full 3×3 local
/// matrix, zero-valued entries included.
const ENTRIES_PER_ELEMENT: usize = 9;

/// One element's local contribution as global `(row, col, value)` triplets, ordered row-major
/// over the element's local node numbering: `(0,0), (0,1), (0,2), (1,0), …, (2,2)`.
type ElementEntries = [(usize, usize, f64); ENTRIES_PER_ELEMENT];

/// Local stiffness contribution of the triangle whose vertices are mesh nodes `ids`.
///
/// Entry `(a, b)` is `area · (∇φ_a · ∇φ_b)` using `TriangularElement::shape_gradients`.
/// The area is applied once, i.e. the gradients are treated as constant over the element.
fn element_stiffness_entries(mesh: &Mesh2D, ids: [usize; 3]) -> ElementEntries {
    let tri_elem = TriangularElement::new(ids.map(|id| mesh.nodes[id].position));
    let gradients = tri_elem.shape_gradients();
    let area = tri_elem.area();
    std::array::from_fn(|k| {
        let (a, b) = (k / 3, k % 3);
        let dot = gradients[a][0] * gradients[b][0] + gradients[a][1] * gradients[b][1];
        (ids[a], ids[b], area * dot)
    })
}

/// Local consistent-mass contribution of the triangle whose vertices are mesh nodes `ids`.
///
/// Diagonal entries are `area / 6` and off-diagonal entries `area / 12`, the exact integrals
/// of `φ_a · φ_b` for linear shape functions on a triangle.
fn element_mass_entries(mesh: &Mesh2D, ids: [usize; 3]) -> ElementEntries {
    let area = TriangularElement::new(ids.map(|id| mesh.nodes[id].position)).area();
    std::array::from_fn(|k| {
        let (a, b) = (k / 3, k % 3);
        let value = if a == b { area / 6.0 } else { area / 12.0 };
        (ids[a], ids[b], value)
    })
}

/// Concatenates per-element contributions, in iteration order, into an `n × n` COO matrix.
///
/// Storage is reserved once from the iterator's lower size bound, so the triplet vectors do
/// not grow repeatedly during assembly. Duplicate `(row, col)` pairs are kept as separate
/// triplets, exactly as `SparseMatrix::add_entry` stores them.
fn collect_into_matrix<I>(n: usize, contributions: I) -> SparseMatrix
where
    I: IntoIterator<Item = ElementEntries>,
{
    let contributions = contributions.into_iter();
    let capacity = contributions
        .size_hint()
        .0
        .saturating_mul(ENTRIES_PER_ELEMENT);
    let mut matrix = SparseMatrix::new(n, n);
    matrix.rows.reserve(capacity);
    matrix.cols.reserve(capacity);
    matrix.vals.reserve(capacity);
    for (i, j, value) in contributions.flatten() {
        matrix.add_entry(i, j, value);
    }
    matrix
}

/// Assembles the global stiffness matrix `K_ij = Σ_e area_e · (∇φ_i · ∇φ_j)` over all mesh
/// elements, as an unsummed COO matrix of size `n_nodes × n_nodes`.
pub fn assemble_stiffness_matrix(mesh: &Mesh2D) -> SparseMatrix {
    let contributions = mesh.elements.iter().map(|elem| {
        let ids = [elem.nodes[0], elem.nodes[1], elem.nodes[2]];
        element_stiffness_entries(mesh, ids)
    });
    collect_into_matrix(mesh.n_nodes(), contributions)
}

/// Parallel version of [`assemble_stiffness_matrix`].
///
/// Element kernels are evaluated in parallel. The ordered `collect` keeps the result
/// identical, triplet for triplet, to the serial assembly.
pub fn assemble_stiffness_matrix_parallel(mesh: &Mesh2D) -> SparseMatrix {
    let contributions: Vec<ElementEntries> = mesh
        .elements
        .par_iter()
        .map(|elem| {
            let ids = [elem.nodes[0], elem.nodes[1], elem.nodes[2]];
            element_stiffness_entries(mesh, ids)
        })
        .collect();
    collect_into_matrix(mesh.n_nodes(), contributions)
}

/// Assembles the global consistent mass matrix over all mesh elements, as an unsummed COO
/// matrix of size `n_nodes × n_nodes`.
pub fn assemble_mass_matrix(mesh: &Mesh2D) -> SparseMatrix {
    let contributions = mesh.elements.iter().map(|elem| {
        let ids = [elem.nodes[0], elem.nodes[1], elem.nodes[2]];
        element_mass_entries(mesh, ids)
    });
    collect_into_matrix(mesh.n_nodes(), contributions)
}

/// Parallel version of [`assemble_mass_matrix`].
///
/// Element kernels are evaluated in parallel. The ordered `collect` keeps the result
/// identical, triplet for triplet, to the serial assembly.
pub fn assemble_mass_matrix_parallel(mesh: &Mesh2D) -> SparseMatrix {
    let contributions: Vec<ElementEntries> = mesh
        .elements
        .par_iter()
        .map(|elem| {
            let ids = [elem.nodes[0], elem.nodes[1], elem.nodes[2]];
            element_mass_entries(mesh, ids)
        })
        .collect();
    collect_into_matrix(mesh.n_nodes(), contributions)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mesh of the unit square with target element size `h`.
    fn unit_square(h: f64) -> Mesh2D {
        Mesh2D::rectangle(1.0, 1.0, h).expect("rectangle mesh creation should succeed")
    }

    /// Asserts two matrices have the same shape and exactly the same triplets, in order.
    ///
    /// Serial and parallel assembly perform identical per-element arithmetic, and the parallel
    /// `collect` preserves element order, so bitwise equality is expected.
    fn assert_same_triplets(expected: &SparseMatrix, actual: &SparseMatrix) {
        assert_eq!(expected.nrows, actual.nrows);
        assert_eq!(expected.ncols, actual.ncols);
        assert_eq!(expected.rows, actual.rows);
        assert_eq!(expected.cols, actual.cols);
        assert_eq!(expected.vals, actual.vals);
    }

    #[test]
    fn test_matvec_sums_duplicate_entries() {
        let mut a = SparseMatrix::new(2, 3);
        a.add_entry(0, 0, 1.0);
        a.add_entry(0, 2, 2.0);
        a.add_entry(1, 1, 3.0);
        a.add_entry(0, 0, 4.0);
        let x = [1.0, 10.0, 100.0];
        assert_eq!(a.matvec(&x), vec![205.0, 30.0]);
        assert_eq!(a.matvec_row(&x, 0), 205.0);
        assert_eq!(a.matvec_row(&x, 1), 30.0);
    }

    #[test]
    fn test_assemble_stiffness() {
        let mesh = unit_square(0.5);
        let k = assemble_stiffness_matrix(&mesh);
        assert_eq!(k.nrows, mesh.n_nodes());
        assert_eq!(k.ncols, mesh.n_nodes());
        assert!(!k.vals.is_empty());
        // Linear shape functions sum to one, so their gradients sum to zero: every row of the
        // stiffness matrix must annihilate a constant vector.
        let ones = vec![1.0; mesh.n_nodes()];
        for (row, value) in k.matvec(&ones).iter().enumerate() {
            assert!(value.abs() < 1e-10, "row {row} of K*1 is {value}, expected 0");
        }
    }

    #[test]
    fn test_assemble_mass() {
        let mesh = unit_square(0.5);
        let m = assemble_mass_matrix(&mesh);
        assert_eq!(m.nrows, mesh.n_nodes());
        assert_eq!(m.ncols, mesh.n_nodes());
        assert!(!m.vals.is_empty());
        // The entries of the consistent mass matrix sum to the integral of 1 over the domain,
        // i.e. the area of the unit square.
        let total: f64 = m.vals.iter().sum();
        assert!((total - 1.0).abs() < 1e-12, "total mass is {total}, expected 1");
    }

    #[test]
    fn test_parallel_stiffness_assembly() {
        let mesh = unit_square(0.25);
        assert_same_triplets(
            &assemble_stiffness_matrix(&mesh),
            &assemble_stiffness_matrix_parallel(&mesh),
        );
    }

    #[test]
    fn test_parallel_mass_assembly() {
        let mesh = unit_square(0.25);
        assert_same_triplets(
            &assemble_mass_matrix(&mesh),
            &assemble_mass_matrix_parallel(&mesh),
        );
    }

    #[test]
    fn test_parallel_assembly_works() {
        let mesh = Mesh2D::rectangle(100e-9, 50e-9, 10e-9)
            .expect("nanoscale mesh creation should succeed");
        let k_parallel = assemble_stiffness_matrix_parallel(&mesh);
        let m_parallel = assemble_mass_matrix_parallel(&mesh);
        assert_eq!(k_parallel.nrows, mesh.n_nodes());
        assert_eq!(k_parallel.ncols, mesh.n_nodes());
        assert_eq!(m_parallel.nrows, mesh.n_nodes());
        assert_eq!(m_parallel.ncols, mesh.n_nodes());
        assert!(!k_parallel.vals.is_empty());
        assert!(!m_parallel.vals.is_empty());
        assert_same_triplets(&assemble_stiffness_matrix(&mesh), &k_parallel);
        assert_same_triplets(&assemble_mass_matrix(&mesh), &m_parallel);
    }
}