use nalgebra::DVector;
use cartan_core::Manifold;
use crate::mesh::{FlatMesh, Mesh};
/// Applies a first-order upwind discretization of the scalar advection
/// operator `(u · ∇) f` at every vertex of `mesh`.
///
/// For each vertex `v` and each other vertex `w` listed in a boundary face
/// incident to `v`, the edge tangent `e = log_v(w)` is evaluated on `manifold`
/// and the velocity is projected onto it: `u_e = <u_v, e>_v / |e|`.
/// Only upstream neighbours (`u_e < 0`, i.e. the flow arrives at `v` from `w`)
/// contribute:
///
/// ```text
/// result[v] += u_e * (f[w] - f[v]) / |e|      (u_e < 0)
/// ```
///
/// which equals `|u_e| * (f[v] - f[w]) / |e|`, the one-sided difference taken
/// against the flow. Contributions are summed over all incident edges without
/// any per-vertex normalization.
///
/// # Arguments
///
/// * `mesh` - mesh providing `vertices`, `vertex_boundaries` and `boundaries`.
/// * `manifold` - supplies `log`, `norm` and `inner` at each vertex.
/// * `f` - scalar field, one value per vertex.
/// * `u` - velocity field, one tangent vector per vertex.
///
/// # Returns
///
/// A vector with one entry per vertex holding the accumulated upwind sum.
///
/// # Panics
///
/// Panics if `f.len()` or `u.len()` differs from `mesh.n_vertices()`.
pub fn apply_scalar_advection_generic<M: Manifold, const K: usize, const B: usize>(
    mesh: &Mesh<M, K, B>,
    manifold: &M,
    f: &DVector<f64>,
    u: &[M::Tangent],
) -> DVector<f64> {
    // Edges shorter than this are treated as degenerate and skipped so the
    // divisions by `len` below stay finite.
    const MIN_EDGE_LENGTH: f64 = 1e-30;

    let nv = mesh.n_vertices();
    assert_eq!(f.len(), nv, "advection: f must have n_v entries");
    assert_eq!(u.len(), nv, "advection: u must have n_v tangent vectors");

    let mut result = DVector::<f64>::zeros(nv);
    for v in 0..nv {
        let pv = &mesh.vertices[v];
        let uv = &u[v];
        let fv = f[v];
        let mut upwind_sum = 0.0;

        for &b in &mesh.vertex_boundaries[v] {
            let boundary = &mesh.boundaries[b];
            for &other in boundary {
                if other == v {
                    continue;
                }

                // Best effort: an edge whose logarithm cannot be evaluated is
                // left out of the stencil instead of aborting the whole pass.
                let edge_tangent = match manifold.log(pv, &mesh.vertices[other]) {
                    Ok(t) => t,
                    Err(_) => continue,
                };
                let len = manifold.norm(pv, &edge_tangent);
                if len < MIN_EDGE_LENGTH {
                    continue;
                }

                // Velocity component along the unit edge direction v -> other.
                let u_proj = manifold.inner(pv, uv, &edge_tangent) / len;

                // `other` is downstream (or the edge is tangential to the
                // flow): an upwind scheme must not sample it. Written as `>=`
                // so that a NaN projection still reaches the sum and stays
                // visible in the output.
                if u_proj >= 0.0 {
                    continue;
                }

                // Upstream neighbour: directional derivative along the edge
                // times the (negative) projected velocity.
                upwind_sum += u_proj * (f[other] - fv) / len;
            }
        }

        result[v] = upwind_sum;
    }
    result
}
/// Scalar advection on a flat 2-D mesh, using a packed velocity vector.
///
/// `u` stores the velocity as two consecutive blocks of `n_v` entries: entry
/// `v` is the first component at vertex `v` and entry `n_v + v` is the second.
/// The blocks are re-packed into one 2-vector per vertex and the work is
/// delegated to [`apply_scalar_advection_generic`] with `Euclidean::<2>` as
/// the manifold.
///
/// # Panics
///
/// Panics if `f.len() != n_v` or `u.len() != 2 * n_v`, where `n_v` is
/// `mesh.n_vertices()`.
pub fn apply_scalar_advection(mesh: &FlatMesh, f: &DVector<f64>, u: &DVector<f64>) -> DVector<f64> {
    let nv = mesh.n_vertices();
    assert_eq!(f.len(), nv, "advection: f must have n_v entries");
    assert_eq!(u.len(), 2 * nv, "advection: u must have 2*n_v entries");

    // Interleave the two component blocks into per-vertex tangent vectors.
    let mut velocities: Vec<nalgebra::SVector<f64, 2>> = Vec::with_capacity(nv);
    for vertex in 0..nv {
        let first = u[vertex];
        let second = u[nv + vertex];
        velocities.push(nalgebra::SVector::<f64, 2>::new(first, second));
    }

    let plane = cartan_manifolds::euclidean::Euclidean::<2>;
    apply_scalar_advection_generic(mesh, &plane, f, &velocities)
}
/// Advects a packed two-component field `q` by the packed velocity `u` on a
/// flat mesh.
///
/// Both `q` and `u` hold two consecutive blocks of `n_v` entries (first
/// component block, then second component block). Each block of `q` is passed
/// on its own through [`apply_scalar_advection`], and the two results are
/// written back into the matching blocks of the output.
///
/// # Panics
///
/// Panics if `q.len()` or `u.len()` differs from `2 * mesh.n_vertices()`.
pub fn apply_vector_advection(mesh: &FlatMesh, q: &DVector<f64>, u: &DVector<f64>) -> DVector<f64> {
    let nv = mesh.n_vertices();
    let expected_len = 2 * nv;
    assert_eq!(
        q.len(),
        expected_len,
        "vector_advection: q must have 2*n_v entries"
    );
    assert_eq!(
        u.len(),
        expected_len,
        "vector_advection: u must have 2*n_v entries"
    );

    let mut result = DVector::<f64>::zeros(expected_len);
    // Component 0 occupies rows [0, nv), component 1 rows [nv, 2*nv).
    for component in 0..2 {
        let offset = component * nv;
        let q_component = q.rows(offset, nv).into_owned();
        let advected = apply_scalar_advection(mesh, &q_component, u);
        result.rows_mut(offset, nv).copy_from(&advected);
    }
    result
}