use nalgebra::SVector;
use sprs::{CsMat, TriMat};
use crate::extrinsic::ExtrinsicOperators;
use crate::mesh::Mesh;
use cartan_core::Manifold;
/// Dimension of the point type required by the `Manifold` bounds in this file
/// (`Point = SVector<f64, D>`).
const D: usize = 3;
/// Iterative Stokes solver that alternates a penalised velocity solve with a
/// pressure update (the "AL" in the name presumably stands for augmented Lagrangian).
///
/// All matrices and the rigid-motion basis are assembled once in `new`; `solve`
/// can then be called repeatedly with different force vectors.
pub struct StokesSolverAL {
/// Discrete operators obtained from `ExtrinsicOperators::from_mesh`; this file uses
/// its `div`, `viscosity_lap`, `n_vertices`, `apply_grad` and `apply_div` members.
ops: ExtrinsicOperators,
/// Velocity system matrix: `-viscosity_lap + penalty * divᵀ·div + 1e-8·I`.
augmented: CsMat<f64>,
/// Six vectors of length `3 * n_vertices` (three translations, three rotations about
/// the coordinate axes), orthogonalised against each other; removed from the velocity
/// after every inner solve.
killing_basis: Vec<Vec<f64>>,
/// Weight of the `divᵀ·div` term in `augmented` and step size of the pressure update.
penalty: f64,
/// Stopping threshold for `‖div u‖₂ / ‖force‖₂` in `solve`.
tolerance: f64,
/// Upper bound on outer (pressure-update) iterations in `solve`.
max_al_iterations: usize,
/// Upper bound on conjugate-gradient iterations per inner velocity solve.
max_cg_iterations: usize,
/// Vertex count copied from `ops.n_vertices`; velocity vectors have `3 * n_vertices`
/// entries and pressure vectors `n_vertices`.
n_vertices: usize,
}
/// Output of `StokesSolverAL::solve`: the last iterate together with convergence data.
#[derive(Debug, Clone)]
pub struct StokesResult {
/// Velocity unknowns, `3 * n_vertices` entries (presumably three consecutive
/// components per vertex, matching the layout of the Killing basis — confirm against
/// `ExtrinsicOperators`).
pub velocity: Vec<f64>,
/// Pressure unknowns, one per vertex, after the final pressure update.
pub pressure: Vec<f64>,
/// Euclidean norm of `apply_div(velocity)` at the last iteration (not divided by the
/// force norm). Stays at `f64::MAX` if no iteration was run.
pub div_residual: f64,
/// Number of outer iterations performed: the 1-based index of the iteration that met
/// the tolerance, or `max_al_iterations` if none did.
pub al_iterations: usize,
}
impl StokesSolverAL {
/// Builds a solver for `mesh`, assembling everything `solve` needs up front.
///
/// The stored velocity matrix is `-viscosity_lap + penalty * divᵀ·div + 1e-8·I`, with
/// `div` and `viscosity_lap` taken from [`ExtrinsicOperators::from_mesh`]. The small
/// diagonal shift is presumably there to keep the system definite in the presence of
/// rigid-motion null modes — confirm with the operator definitions.
///
/// # Parameters
/// * `mesh` – triangle mesh whose vertices are points in 3-space.
/// * `penalty` – weight of the `divᵀ·div` term and step size of the pressure update.
/// * `tolerance` – threshold on `‖div u‖₂ / ‖force‖₂` at which `solve` stops.
/// * `max_al_iterations` – cap on outer iterations in `solve`.
/// * `max_cg_iterations` – cap on conjugate-gradient iterations per inner solve.
pub fn new<M: Manifold<Point = SVector<f64, D>>>(
    mesh: &Mesh<M, 3, 2>,
    penalty: f64,
    tolerance: f64,
    max_al_iterations: usize,
    max_cg_iterations: usize,
) -> Self {
    // Diagonal shift added to the augmented operator.
    const DIAGONAL_SHIFT: f64 = 1e-8;

    let ops = ExtrinsicOperators::from_mesh(mesh);
    let nv = ops.n_vertices;
    let n = 3 * nv;

    // -L + penalty * DᵀD
    let dtd = &ops.div.transpose_view().to_csc() * &ops.div;
    let neg_l = ops.viscosity_lap.map(|&v| -v);
    let augmented_base = &neg_l + &dtd.map(|&v| v * penalty);

    // DIAGONAL_SHIFT * I, assembled in triplet form and converted to CSC.
    let mut reg = TriMat::new((n, n));
    for i in 0..n {
        reg.add_triplet(i, i, DIAGONAL_SHIFT);
    }
    let augmented = &augmented_base + &reg.to_csc();

    let killing_basis = compute_killing_basis(mesh);

    Self {
        ops,
        augmented,
        killing_basis,
        penalty,
        tolerance,
        max_al_iterations,
        max_cg_iterations,
        n_vertices: nv,
    }
}
/// Runs the outer pressure/velocity iteration for the nodal load `force`.
///
/// Velocity and pressure start at zero. Every outer step solves
/// `augmented * u = force - grad(p)` by conjugate gradients (warm-started from the
/// previous velocity), strips the stored Killing-basis components from `u`, and then
/// updates `p += penalty * div(u)`. Iteration ends once
/// `‖div(u)‖₂ / max(‖force‖₂, 1e-15)` is below `tolerance`, or after
/// `max_al_iterations` steps.
///
/// The returned [`StokesResult`] always carries the last iterate, whether or not the
/// tolerance was reached.
///
/// # Panics
/// Panics if `force.len() != 3 * n_vertices`.
pub fn solve(&self, force: &[f64]) -> StokesResult {
    let dof_count = 3 * self.n_vertices;
    assert_eq!(force.len(), dof_count);

    // Floored so that an all-zero force cannot produce a division by zero below.
    let force_scale = force.iter().map(|f| f * f).sum::<f64>().sqrt().max(1e-15);

    let mut velocity = vec![0.0; dof_count];
    let mut pressure = vec![0.0; self.n_vertices];
    let mut div_residual = f64::MAX;
    // Overwritten with the 1-based step index if the tolerance is met.
    let mut al_iterations = self.max_al_iterations;

    for step in 1..=self.max_al_iterations {
        let pressure_gradient = self.ops.apply_grad(&pressure);
        let rhs: Vec<f64> = force
            .iter()
            .zip(&pressure_gradient)
            .map(|(f, g)| f - g)
            .collect();

        velocity = self.cg_solve(&rhs, &velocity);
        self.project_out_killing(&mut velocity);

        let divergence = self.ops.apply_div(&velocity);
        // The pressure is updated before the convergence test, so the returned
        // pressure already includes the final correction.
        for (i, p) in pressure.iter_mut().enumerate() {
            *p += self.penalty * divergence[i];
        }

        div_residual = divergence.iter().map(|d| d * d).sum::<f64>().sqrt();
        if div_residual / force_scale < self.tolerance {
            al_iterations = step;
            break;
        }
    }

    StokesResult {
        velocity,
        pressure,
        div_residual,
        al_iterations,
    }
}
/// Conjugate-gradient solve of `augmented * x = b`, starting from `x0`.
///
/// Returns a copy of `x0` when the initial residual norm is below `1e-15`. Otherwise
/// runs at most `max_cg_iterations` iterations, stopping early when the residual norm
/// falls below `1e-12` or when `|pᵀ·A·p| < 1e-30`. Both thresholds are absolute, not
/// scaled by `‖b‖`. The current iterate is returned in every case; no convergence
/// flag is reported.
fn cg_solve(&self, b: &[f64], x0: &[f64]) -> Vec<f64> {
    fn squared_norm(v: &[f64]) -> f64 {
        v.iter().map(|vi| vi * vi).sum()
    }

    let n = b.len();
    let mut solution = x0.to_vec();

    let a_x0 = sparse_matvec_real(&self.augmented, &solution);
    let mut residual: Vec<f64> = b.iter().zip(&a_x0).map(|(bi, ai)| bi - ai).collect();
    let mut residual_sq = squared_norm(&residual);
    if residual_sq.sqrt() < 1e-15 {
        return solution;
    }
    let mut direction = residual.clone();

    for _ in 0..self.max_cg_iterations {
        let a_dir = sparse_matvec_real(&self.augmented, &direction);
        let curvature: f64 = direction.iter().zip(&a_dir).map(|(d, ad)| d * ad).sum();
        // A vanishing curvature would make the step length blow up.
        if curvature.abs() < 1e-30 {
            break;
        }

        let alpha = residual_sq / curvature;
        for (xi, di) in solution[..n].iter_mut().zip(&direction[..n]) {
            *xi += alpha * di;
        }
        for (ri, adi) in residual[..n].iter_mut().zip(&a_dir[..n]) {
            *ri -= alpha * adi;
        }

        let next_residual_sq = squared_norm(&residual);
        if next_residual_sq.sqrt() < 1e-12 {
            break;
        }

        let beta = next_residual_sq / residual_sq;
        for (di, ri) in direction[..n].iter_mut().zip(&residual[..n]) {
            *di = ri + beta * *di;
        }
        residual_sq = next_residual_sq;
    }

    solution
}
/// Subtracts from `u` its component along each vector in `killing_basis`.
///
/// Basis vectors are handled one after another, each against the already-updated
/// `u`; a vector whose squared norm is not greater than `1e-30` is skipped. Removing
/// the whole span this way relies on the basis vectors being mutually orthogonal,
/// which `compute_killing_basis` arranges via Gram–Schmidt.
fn project_out_killing(&self, u: &mut [f64]) {
    let len = u.len();
    for mode in &self.killing_basis {
        let overlap: f64 = u.iter().zip(mode).map(|(ui, mi)| ui * mi).sum();
        let mode_norm_sq: f64 = mode.iter().map(|mi| mi * mi).sum();
        // Written as `>` so that a NaN norm also skips the mode.
        if mode_norm_sq > 1e-30 {
            let coeff = overlap / mode_norm_sq;
            for (ui, mi) in u.iter_mut().zip(&mode[..len]) {
                *ui -= coeff * mi;
            }
        }
    }
}
}
/// Builds six rigid-motion vector fields sampled at the mesh vertices and
/// orthogonalises them against each other.
///
/// Each returned vector has `3 * n_vertices` entries, three consecutive components per
/// vertex. The first three are unit translations along x, y and z; the last three are
/// rotations `e_axis × r` about the coordinate axes through the origin, with `r` the
/// vertex position. Gram–Schmidt is then applied in that order. Vectors are not
/// normalised, and a vector that becomes (near) zero is still returned.
fn compute_killing_basis<M: Manifold<Point = SVector<f64, D>>>(
    mesh: &Mesh<M, 3, 2>,
) -> Vec<Vec<f64>> {
    let nv = mesh.n_vertices();
    let n = 3 * nv;
    let mut basis: Vec<Vec<f64>> = Vec::with_capacity(6);

    // Translations: component `axis` is 1 at every vertex.
    for axis in 0..3 {
        let mut field = vec![0.0; n];
        for vertex_dofs in field.chunks_exact_mut(3) {
            vertex_dofs[axis] = 1.0;
        }
        basis.push(field);
    }

    // Rotations: e_axis × r at every vertex.
    for axis in 0..3 {
        let mut field = vec![0.0; n];
        for (v, vertex_dofs) in field.chunks_exact_mut(3).enumerate() {
            let r = mesh.vertices[v];
            let cross = match axis {
                0 => [0.0, -r[2], r[1]],
                1 => [r[2], 0.0, -r[0]],
                2 => [-r[1], r[0], 0.0],
                _ => unreachable!(),
            };
            vertex_dofs.copy_from_slice(&cross);
        }
        basis.push(field);
    }

    // Gram–Schmidt: subtract from vector `i` its projection onto every earlier vector.
    // `split_at_mut` yields the earlier vectors and vector `i` as disjoint borrows, so
    // no earlier vector has to be cloned.
    for i in 1..basis.len() {
        let (earlier, rest) = basis.split_at_mut(i);
        let current = &mut rest[0];
        for prev in earlier.iter() {
            let dot: f64 = current.iter().zip(prev).map(|(a, b)| a * b).sum();
            let norm_sq: f64 = prev.iter().map(|b| b * b).sum();
            // Skip directions that have collapsed to (numerically) zero.
            if norm_sq > 1e-30 {
                let coeff = dot / norm_sq;
                for (c, p) in current.iter_mut().zip(prev) {
                    *c -= coeff * p;
                }
            }
        }
    }

    basis
}
/// Computes the sparse matrix–vector product `mat * x`.
///
/// Handles both storage orders: for CSR the outer dimension is the row, so each output
/// entry is a row dot product; for CSC the outer dimension is the column, so each
/// column is scaled by `x[col]` and scattered into the result. Treating a CSR matrix
/// as CSC would silently produce `matᵀ * x`.
///
/// In the CSC branch, columns whose `x` entry has magnitude below `1e-30` are skipped.
///
/// # Panics
/// Panics if `x` is shorter than an index it is accessed at (normally `mat.cols()`).
fn sparse_matvec_real(mat: &CsMat<f64>, x: &[f64]) -> Vec<f64> {
    let mut y = vec![0.0; mat.rows()];

    if mat.is_csr() {
        for (row, row_view) in mat.outer_iterator().enumerate() {
            let mut acc = 0.0;
            for (col, &val) in row_view.iter() {
                acc += val * x[col];
            }
            y[row] = acc;
        }
    } else {
        for (col, col_view) in mat.outer_iterator().enumerate() {
            let xc = x[col];
            // A (numerically) zero entry contributes nothing; skip the whole column.
            if xc.abs() < 1e-30 {
                continue;
            }
            for (row, &val) in col_view.iter() {
                y[row] += val * xc;
            }
        }
    }

    y
}