#![allow(clippy::needless_range_loop)]
use rustfft::{FftPlanner, num_complex::Complex};
use volterra_core::ActiveNematicParams3D;
use volterra_fields::{QField3D, VelocityField3D, PressureField3D};
pub fn stokes_solve_3d(q: &QField3D, p: &ActiveNematicParams3D) -> (VelocityField3D, PressureField3D) {
let nx = q.nx;
let ny = q.ny;
let nz = q.nz;
let n = nx * ny * nz;
let dx = q.dx;
let inv_2dx = 1.0 / (2.0 * dx);
let mut f = vec![[0.0f64; 3]; n];
for i in 0..nx {
for j in 0..ny {
for l in 0..nz {
let k = q.idx(i, j, l);
let get_q = |ki: usize, row: usize, col: usize| -> f64 {
let [q11, q12, q13, q22, q23] = q.q[ki];
let q33 = -(q11 + q22);
match (row, col) {
(0, 0) => q11,
(0, 1) | (1, 0) => q12,
(0, 2) | (2, 0) => q13,
(1, 1) => q22,
(1, 2) | (2, 1) => q23,
(2, 2) => q33,
_ => 0.0,
}
};
let ip = q.idx((i + 1) % nx, j, l);
let im = q.idx((i + nx - 1) % nx, j, l);
let jp = q.idx(i, (j + 1) % ny, l);
let jm = q.idx(i, (j + ny - 1) % ny, l);
let lp = q.idx(i, j, (l + 1) % nz);
let lm = q.idx(i, j, (l + nz - 1) % nz);
for alpha in 0..3usize {
let div_q_alpha =
(get_q(ip, alpha, 0) - get_q(im, alpha, 0)) * inv_2dx
+ (get_q(jp, alpha, 1) - get_q(jm, alpha, 1)) * inv_2dx
+ (get_q(lp, alpha, 2) - get_q(lm, alpha, 2)) * inv_2dx;
f[k][alpha] = -p.zeta_eff * div_q_alpha;
}
}
}
}
let mut planner = FftPlanner::<f64>::new();
let mut f_hat: Vec<[Complex<f64>; 3]> = f
.iter()
.map(|fi| {
[
Complex::new(fi[0], 0.0),
Complex::new(fi[1], 0.0),
Complex::new(fi[2], 0.0),
]
})
.collect();
let fft_x = planner.plan_fft_forward(nx);
let fft_y = planner.plan_fft_forward(ny);
let fft_z = planner.plan_fft_forward(nz);
for comp in 0..3 {
for i in 0..nx {
for j in 0..ny {
let mut row: Vec<Complex<f64>> =
(0..nz).map(|l| f_hat[q.idx(i, j, l)][comp]).collect();
fft_z.process(&mut row);
for l in 0..nz {
f_hat[q.idx(i, j, l)][comp] = row[l];
}
}
}
for i in 0..nx {
for l in 0..nz {
let mut row: Vec<Complex<f64>> =
(0..ny).map(|j| f_hat[q.idx(i, j, l)][comp]).collect();
fft_y.process(&mut row);
for j in 0..ny {
f_hat[q.idx(i, j, l)][comp] = row[j];
}
}
}
for j in 0..ny {
for l in 0..nz {
let mut row: Vec<Complex<f64>> =
(0..nx).map(|i| f_hat[q.idx(i, j, l)][comp]).collect();
fft_x.process(&mut row);
for i in 0..nx {
f_hat[q.idx(i, j, l)][comp] = row[i];
}
}
}
}
let mut u_hat: Vec<[Complex<f64>; 3]> = vec![[Complex::new(0.0, 0.0); 3]; n];
for i in 0..nx {
for j in 0..ny {
for l in 0..nz {
let k = q.idx(i, j, l);
let kx = wavenumber(i, nx, dx);
let ky = wavenumber(j, ny, dx);
let kz = wavenumber(l, nz, dx);
let k2 = kx * kx + ky * ky + kz * kz;
if k2 < 1e-14 {
continue;
}
let kv = [kx, ky, kz];
let k_dot_f: Complex<f64> = kv
.iter()
.zip(f_hat[k].iter())
.map(|(&ki, &fi)| Complex::new(ki, 0.0) * fi)
.sum();
let inv_eta_k2 = 1.0 / (p.eta * k2);
for a in 0..3 {
u_hat[k][a] = (f_hat[k][a]
- Complex::new(kv[a] / k2, 0.0) * k_dot_f)
* inv_eta_k2;
}
}
}
}
let ifft_x = planner.plan_fft_inverse(nx);
let ifft_y = planner.plan_fft_inverse(ny);
let ifft_z = planner.plan_fft_inverse(nz);
for comp in 0..3 {
for i in 0..nx {
for j in 0..ny {
let mut row: Vec<Complex<f64>> =
(0..nz).map(|l| u_hat[q.idx(i, j, l)][comp]).collect();
ifft_z.process(&mut row);
for l in 0..nz {
u_hat[q.idx(i, j, l)][comp] = row[l];
}
}
}
for i in 0..nx {
for l in 0..nz {
let mut row: Vec<Complex<f64>> =
(0..ny).map(|j| u_hat[q.idx(i, j, l)][comp]).collect();
ifft_y.process(&mut row);
for j in 0..ny {
u_hat[q.idx(i, j, l)][comp] = row[j];
}
}
}
for j in 0..ny {
for l in 0..nz {
let mut row: Vec<Complex<f64>> =
(0..nx).map(|i| u_hat[q.idx(i, j, l)][comp]).collect();
ifft_x.process(&mut row);
for i in 0..nx {
u_hat[q.idx(i, j, l)][comp] = row[i];
}
}
}
}
let norm = 1.0 / (n as f64);
let mut u = VelocityField3D::zeros(nx, ny, nz, dx);
for k in 0..n {
for a in 0..3 {
u.u[k][a] = u_hat[k][a].re * norm;
}
}
(u, PressureField3D::zeros(nx, ny, nz, dx))
}
/// Returns `sin(2*pi*m / n) / dx` for grid index `idx` on an axis of `n` points with
/// spacing `dx`, where `m` is the signed mode number: `idx` itself when
/// `idx <= n / 2`, otherwise `idx - n`.
#[inline]
fn wavenumber(idx: usize, n: usize, dx: f64) -> f64 {
    use std::f64::consts::PI;

    // Map the upper half of the index range onto negative mode numbers.
    let signed_mode = match idx {
        m if m > n / 2 => m as f64 - n as f64,
        m => m as f64,
    };
    let phase = 2.0 * PI * signed_mode / n as f64;
    phase.sin() / dx
}
#[cfg(test)]
mod tests {
    use super::*;
    use volterra_core::ActiveNematicParams3D;
    use volterra_fields::QField3D;

    /// The solver output for a randomly perturbed Q field must have a divergence
    /// (as reported by `VelocityField3D::divergence`) below `1e-8` at every cell.
    #[test]
    fn test_stokes_3d_incompressible() {
        const DIV_TOL: f64 = 1e-8;

        let params = ActiveNematicParams3D::default_test();
        let q_field = QField3D::random_perturbation(8, 8, 8, 1.0, 0.1, 42);
        let (velocity, _pressure) = stokes_solve_3d(&q_field, &params);

        let divergence = velocity.divergence();
        for d in &divergence.phi {
            assert!(
                d.abs() < DIV_TOL,
                "Stokes output must be divergence-free, got divergence={}",
                d
            );
        }
    }

    /// An identically zero Q field produces no active force, so every velocity
    /// component must vanish (to within `1e-12`).
    #[test]
    fn test_stokes_3d_zero_q_gives_zero_u() {
        const ZERO_TOL: f64 = 1e-12;

        let params = ActiveNematicParams3D::default_test();
        let q_field = QField3D::zeros(8, 8, 8, 1.0);
        let (velocity, _pressure) = stokes_solve_3d(&q_field, &params);

        for cell in &velocity.u {
            for (a, component) in cell.iter().enumerate() {
                assert!(
                    component.abs() < ZERO_TOL,
                    "zero Q must give zero velocity, got u[{}]={}",
                    a,
                    component
                );
            }
        }
    }
}