use nalgebra::{Matrix3, SMatrix, SVector};
use sprs::{CsMat, TriMat};
use crate::mesh::Mesh;
use cartan_core::Manifold;
/// Ambient (embedding) dimension: mesh vertex positions are `SVector<f64, D>` points in R^D.
const D: usize = 3;
/// Per-face geometric quantities of a triangle mesh embedded in R^3.
///
/// Every vector holds one entry per face, indexed like `mesh.simplices`.
#[derive(Debug, Clone)]
pub struct FaceData {
/// Unit face normals; the zero vector for (near-)degenerate faces.
pub normals: Vec<SVector<f64, D>>,
/// Face areas.
pub areas: Vec<f64>,
/// Tangential projectors `I - n nᵀ` built from the face normal.
pub projectors: Vec<Matrix3<f64>>,
/// Gradients of the three piecewise-linear (P1) basis functions on the face, in local
/// vertex order; all zero for (near-)degenerate faces.
pub fem_grads: Vec<[SVector<f64, D>; 3]>,
}
impl FaceData {
    /// Builds per-face geometric data for every triangle of `mesh`.
    ///
    /// For face `f` with vertex positions `(p0, p1, p2)`:
    /// * `normals[f]` is `(p1 - p0) × (p2 - p0)` normalised, or the zero vector when that
    ///   cross product has norm `<= 1e-30`;
    /// * `areas[f]` is half the norm of that cross product;
    /// * `projectors[f]` is `I - n nᵀ`;
    /// * `fem_grads[f][k]` is `(n × e_k) / (2 · area)`, where `e_k` is the edge opposite
    ///   local vertex `k` (`p1→p2`, `p2→p0`, `p0→p1`); all zero when `area <= 1e-30`.
    pub fn from_mesh<M: Manifold<Point = SVector<f64, D>>>(mesh: &Mesh<M, 3, 2>) -> Self {
        let face_count = mesh.n_simplices();
        let mut out = Self {
            normals: Vec::with_capacity(face_count),
            areas: Vec::with_capacity(face_count),
            projectors: Vec::with_capacity(face_count),
            fem_grads: Vec::with_capacity(face_count),
        };

        for face in 0..face_count {
            let [a, b, c] = mesh.simplices[face];
            let (p0, p1, p2) = (mesh.vertices[a], mesh.vertices[b], mesh.vertices[c]);

            let edge_01 = p1 - p0;
            let area_vector = edge_01.cross(&(p2 - p0));
            let area_vector_len = area_vector.norm();
            let area = 0.5 * area_vector_len;

            let normal = if area_vector_len > 1e-30 {
                area_vector / area_vector_len
            } else {
                SVector::<f64, D>::zeros()
            };

            // P1 hat-function gradients: each opposite edge rotated by 90° inside the face
            // plane (n × e) and divided by twice the face area.
            let grad_scale = if area > 1e-30 { 1.0 / (2.0 * area) } else { 0.0 };
            let grads = [p2 - p1, p0 - p2, edge_01].map(|edge| grad_scale * normal.cross(&edge));

            out.normals.push(normal);
            out.areas.push(area);
            out.projectors.push(Matrix3::identity() - normal * normal.transpose());
            out.fem_grads.push(grads);
        }

        out
    }
}
/// Sparse discrete operators acting on vertex velocities expressed in ambient (x, y, z)
/// coordinates.
///
/// Velocity vectors are interleaved: component `i` of vertex `v` is stored at `3 * v + i`.
#[derive(Debug, Clone)]
pub struct ExtrinsicOperators {
/// Number of mesh vertices (`nv`).
pub n_vertices: usize,
/// Number of mesh faces.
pub n_faces: usize,
/// Divergence, `nv × 3nv`: interleaved velocities to per-vertex scalars.
pub div: CsMat<f64>,
/// Gradient, `3nv × nv`: the negative transpose of `div`.
pub grad: CsMat<f64>,
/// Viscous operator, `3nv × 3nv`, on interleaved velocities.
pub viscosity_lap: CsMat<f64>,
/// Per-face areas, indexed like `mesh.simplices`.
pub face_areas: Vec<f64>,
}
impl ExtrinsicOperators {
pub fn from_mesh<M: Manifold<Point = SVector<f64, D>>>(mesh: &Mesh<M, 3, 2>) -> Self {
let face_data = FaceData::from_mesh(mesh);
let nv = mesh.n_vertices();
let nf = mesh.n_simplices();
let mut div_triplets = TriMat::new((nv, 3 * nv));
let mut lap_triplets = TriMat::new((3 * nv, 3 * nv));
for f in 0..nf {
let simplex = &mesh.simplices[f];
let proj = &face_data.projectors[f];
let grads = &face_data.fem_grads[f];
let area = face_data.areas[f];
if area < 1e-30 {
continue;
}
let mut k_f = SMatrix::<f64, 6, 9>::zeros();
for (local_v, grad) in grads.iter().enumerate().take(3) {
let col_offset = local_v * 3;
for i in 0..3 {
for k in 0..3 {
k_f[(k, col_offset + i)] += proj[(i, k)] * grad[k];
}
k_f[(3, col_offset + i)] +=
0.5 * (proj[(i, 0)] * grad[1] + proj[(i, 1)] * grad[0]);
k_f[(4, col_offset + i)] +=
0.5 * (proj[(i, 0)] * grad[2] + proj[(i, 2)] * grad[0]);
k_f[(5, col_offset + i)] +=
0.5 * (proj[(i, 1)] * grad[2] + proj[(i, 2)] * grad[1]);
}
}
let ktk = k_f.transpose() * k_f * (1.0 / area);
for local_a in 0..3 {
let va = simplex[local_a];
for local_b in 0..3 {
let vb = simplex[local_b];
for ia in 0..3 {
for ib in 0..3 {
let val = ktk[(local_a * 3 + ia, local_b * 3 + ib)];
if val.abs() > 1e-30 {
lap_triplets.add_triplet(va * 3 + ia, vb * 3 + ib, val);
}
}
}
}
}
for &v in simplex.iter().take(3) {
for (local_l, &vl) in simplex.iter().enumerate().take(3) {
for i in 0..3 {
let col = local_l * 3 + i;
let trace_val = k_f[(0, col)] + k_f[(1, col)] + k_f[(2, col)];
if trace_val.abs() > 1e-30 {
div_triplets.add_triplet(v, vl * 3 + i, trace_val * area / 3.0);
}
}
}
}
}
let div = div_triplets.to_csc();
let viscosity_lap = lap_triplets.to_csc();
let grad = div.transpose_view().to_csc().map(|&v| -v);
Self {
n_vertices: nv,
n_faces: nf,
div,
grad,
viscosity_lap,
face_areas: face_data.areas,
}
}
pub fn apply_div(&self, velocity: &[f64]) -> Vec<f64> {
assert_eq!(velocity.len(), 3 * self.n_vertices);
sparse_matvec(&self.div, velocity)
}
pub fn apply_grad(&self, pressure: &[f64]) -> Vec<f64> {
assert_eq!(pressure.len(), self.n_vertices);
sparse_matvec(&self.grad, pressure)
}
pub fn apply_viscosity_lap(&self, velocity: &[f64]) -> Vec<f64> {
assert_eq!(velocity.len(), 3 * self.n_vertices);
sparse_matvec(&self.viscosity_lap, velocity)
}
}
/// Computes `y = mat · x` for a sparse matrix stored in either CSC or CSR order.
///
/// The storage order is checked explicitly; iterating a CSR matrix as if it were CSC would
/// silently compute `matᵀ · x`.
///
/// # Panics
/// Panics if `x.len() != mat.cols()`.
fn sparse_matvec(mat: &CsMat<f64>, x: &[f64]) -> Vec<f64> {
    assert_eq!(
        x.len(),
        mat.cols(),
        "sparse_matvec: vector length does not match matrix column count"
    );
    let mut y = vec![0.0; mat.rows()];
    if mat.is_csc() {
        // Column-major: scatter each column scaled by its input entry. Only exact zeros are
        // skipped so that tiny-but-valid inputs still contribute.
        for (col, col_view) in mat.outer_iterator().enumerate() {
            let xc = x[col];
            if xc == 0.0 {
                continue;
            }
            for (row, &val) in col_view.iter() {
                y[row] += val * xc;
            }
        }
    } else {
        // Row-major: each output entry is a sparse dot product with `x`.
        for (yr, row_view) in y.iter_mut().zip(mat.outer_iterator()) {
            *yr = row_view.iter().map(|(col, &val)| val * x[col]).sum();
        }
    }
    y
}