1use std::collections::HashMap;
4
5use phyz_math::{GRAVITY, Mat3, Vec3};
6
7use crate::particle::Particle;
8
9type GridIndex = (i32, i32, i32);
11
12#[derive(Debug, Clone, Default)]
14struct GridNode {
15 mass: f64,
16 momentum: Vec3,
17 velocity: Vec3,
18 force: Vec3,
19}
20
21pub struct MpmSolver {
23 pub h: f64,
25 pub dt: f64,
27 pub alpha: f64,
29 pub bounds: (Vec3, Vec3),
31 grid: HashMap<GridIndex, GridNode>,
33}
34
35impl MpmSolver {
36 pub fn new(h: f64, dt: f64, bounds: (Vec3, Vec3)) -> Self {
38 Self {
39 h,
40 dt,
41 alpha: 0.95, bounds,
43 grid: HashMap::new(),
44 }
45 }
46
47 pub fn step(&mut self, particles: &mut [Particle]) {
49 self.grid.clear();
51
52 self.particle_to_grid(particles);
54
55 self.grid_update();
57
58 self.grid_to_particle(particles);
60 }
61
62 fn position_to_index(&self, x: &Vec3) -> GridIndex {
64 (
65 (x.x / self.h).floor() as i32,
66 (x.y / self.h).floor() as i32,
67 (x.z / self.h).floor() as i32,
68 )
69 }
70
71 fn index_to_position(&self, idx: GridIndex) -> Vec3 {
73 Vec3::new(
74 idx.0 as f64 * self.h,
75 idx.1 as f64 * self.h,
76 idx.2 as f64 * self.h,
77 )
78 }
79
80 fn weight(&self, x: &Vec3, xi: &Vec3) -> f64 {
82 let dx = (x - xi) / self.h;
83 self.n(dx.x) * self.n(dx.y) * self.n(dx.z)
84 }
85
86 fn n(&self, x: f64) -> f64 {
88 let x = x.abs();
89 if x < 1.0 {
90 0.5 * x * x * x - x * x + 2.0 / 3.0
91 } else if x < 2.0 {
92 -(x - 2.0).powi(3) / 6.0
93 } else {
94 0.0
95 }
96 }
97
98 fn weight_gradient(&self, x: &Vec3, xi: &Vec3) -> Vec3 {
100 let dx = (x - xi) / self.h;
101 Vec3::new(
102 self.dn(dx.x) * self.n(dx.y) * self.n(dx.z) / self.h,
103 self.n(dx.x) * self.dn(dx.y) * self.n(dx.z) / self.h,
104 self.n(dx.x) * self.n(dx.y) * self.dn(dx.z) / self.h,
105 )
106 }
107
108 fn dn(&self, x: f64) -> f64 {
110 let sign = x.signum();
111 let x = x.abs();
112 if x < 1.0 {
113 sign * (1.5 * x * x - 2.0 * x)
114 } else if x < 2.0 {
115 sign * (-0.5 * (x - 2.0).powi(2))
116 } else {
117 0.0
118 }
119 }
120
121 fn get_neighbors(&self, idx: GridIndex) -> Vec<GridIndex> {
123 let mut neighbors = Vec::new();
124 for di in -1..=1 {
125 for dj in -1..=1 {
126 for dk in -1..=1 {
127 neighbors.push((idx.0 + di, idx.1 + dj, idx.2 + dk));
128 }
129 }
130 }
131 neighbors
132 }
133
134 fn particle_to_grid(&mut self, particles: &[Particle]) {
136 for p in particles {
137 let base_idx = self.position_to_index(&p.x);
138 let neighbors = self.get_neighbors(base_idx);
139
140 let stress = p.material.compute_stress(&p.f, p.j);
142
143 for &idx in &neighbors {
144 let xi = self.index_to_position(idx);
145 let w = self.weight(&p.x, &xi);
146
147 if w > 1e-12 {
148 let grad_w = self.weight_gradient(&p.x, &xi);
150 let force = -p.volume * stress * grad_w;
151
152 let node = self.grid.entry(idx).or_default();
154
155 node.mass += w * p.mass;
157 node.momentum += w * p.mass * (p.v + p.c * (xi - p.x));
158
159 node.force += force;
161 }
162 }
163 }
164
165 for node in self.grid.values_mut() {
167 if node.mass > 1e-12 {
168 node.velocity = node.momentum / node.mass;
169 }
170 }
171 }
172
173 fn grid_update(&mut self) {
175 let indices: Vec<GridIndex> = self.grid.keys().copied().collect();
177
178 for idx in indices {
179 let xi = self.index_to_position(idx);
181 let epsilon = 0.01; let bounds = self.bounds;
183
184 let node = self.grid.get_mut(&idx).unwrap();
185
186 if node.mass < 1e-12 {
187 continue;
188 }
189
190 node.force += node.mass * Vec3::new(0.0, -GRAVITY, 0.0);
192
193 let acc = node.force / node.mass;
195 node.velocity += acc * self.dt;
196
197 if xi.y < bounds.0.y + epsilon && node.velocity.y < 0.0 {
200 node.velocity.y = 0.0;
201 }
202 if xi.y > bounds.1.y - epsilon && node.velocity.y > 0.0 {
204 node.velocity.y = 0.0;
205 }
206 if xi.x < bounds.0.x + epsilon && node.velocity.x < 0.0 {
208 node.velocity.x = 0.0;
209 }
210 if xi.x > bounds.1.x - epsilon && node.velocity.x > 0.0 {
212 node.velocity.x = 0.0;
213 }
214 if xi.z < bounds.0.z + epsilon && node.velocity.z < 0.0 {
216 node.velocity.z = 0.0;
217 }
218 if xi.z > bounds.1.z - epsilon && node.velocity.z > 0.0 {
220 node.velocity.z = 0.0;
221 }
222 }
223 }
224
225 fn grid_to_particle(&self, particles: &mut [Particle]) {
227 for p in particles.iter_mut() {
228 let base_idx = self.position_to_index(&p.x);
229 let neighbors = self.get_neighbors(base_idx);
230
231 let mut v_pic = Vec3::zeros();
232 let mut v_old_grid = Vec3::zeros();
233 let mut grad_v = Mat3::zeros();
234 let mut total_w = 0.0;
235
236 for &idx in &neighbors {
237 if let Some(node) = self.grid.get(&idx) {
238 if node.mass < 1e-12 {
239 continue;
240 }
241
242 let xi = self.index_to_position(idx);
243 let w = self.weight(&p.x, &xi);
244
245 if w > 1e-12 {
246 v_pic += w * node.velocity;
248
249 let v_old = node.velocity - (node.force / node.mass) * self.dt;
251 v_old_grid += w * v_old;
252
253 let grad_w = self.weight_gradient(&p.x, &xi);
255 grad_v += node.velocity * grad_w.transpose();
256
257 total_w += w;
258 }
259 }
260 }
261
262 if total_w > 1e-12 {
263 let v_flip = p.v + (v_pic - v_old_grid);
265
266 let v_new = (1.0 - self.alpha) * v_pic + self.alpha * v_flip;
268
269 if v_new.x.is_finite() && v_new.y.is_finite() && v_new.z.is_finite() {
271 p.v = v_new;
272 }
273
274 p.x += p.v * self.dt;
276
277 p.x.x = p.x.x.clamp(self.bounds.0.x, self.bounds.1.x);
279 p.x.y = p.x.y.clamp(self.bounds.0.y, self.bounds.1.y);
280 p.x.z = p.x.z.clamp(self.bounds.0.z, self.bounds.1.z);
281
282 p.update_deformation(&grad_v, self.dt);
284
285 p.c = grad_v;
287 }
288 }
289 }
290
291 pub fn total_grid_mass(&self) -> f64 {
293 self.grid.values().map(|n| n.mass).sum()
294 }
295
296 pub fn total_momentum(particles: &[Particle]) -> Vec3 {
298 particles.iter().map(|p| p.mass * p.v).sum()
299 }
300
301 pub fn kinetic_energy(particles: &[Particle]) -> f64 {
303 particles
304 .iter()
305 .map(|p| 0.5 * p.mass * p.v.norm_squared())
306 .sum()
307 }
308}
309
310#[cfg(test)]
311mod tests {
312 use super::*;
313 use crate::material::Material;
314
315 #[test]
316 fn test_weight_function() {
317 let solver = MpmSolver::new(1.0, 0.01, (Vec3::zeros(), Vec3::new(10.0, 10.0, 10.0)));
318
319 let x = Vec3::new(5.0, 5.0, 5.0);
321 let w = solver.weight(&x, &x);
322 assert!(w > 0.2);
324
325 let x2 = Vec3::new(10.0, 10.0, 10.0);
327 let w2 = solver.weight(&x, &x2);
328 assert!(w2.abs() < 1e-12);
329 }
330
331 #[test]
332 fn test_mass_conservation() {
333 let mut solver = MpmSolver::new(0.1, 0.01, (Vec3::zeros(), Vec3::new(1.0, 1.0, 1.0)));
334
335 let mat = Material::Elastic { e: 1e6, nu: 0.3 };
336 let mut particles = vec![
337 Particle::new(Vec3::new(0.5, 0.5, 0.5), Vec3::zeros(), 1.0, 0.01, mat),
338 Particle::new(Vec3::new(0.6, 0.5, 0.5), Vec3::zeros(), 1.0, 0.01, mat),
339 ];
340
341 let initial_mass: f64 = particles.iter().map(|p| p.mass).sum();
342
343 solver.step(&mut particles);
344
345 let final_mass: f64 = particles.iter().map(|p| p.mass).sum();
346
347 assert!((initial_mass - final_mass).abs() < 1e-10);
348 }
349
350 #[test]
351 fn test_free_fall() {
352 let mut solver = MpmSolver::new(0.1, 0.01, (Vec3::zeros(), Vec3::new(1.0, 2.0, 1.0)));
353
354 let mat = Material::Elastic { e: 1e6, nu: 0.3 };
355 let mut particles = vec![Particle::new(
356 Vec3::new(0.5, 1.0, 0.5),
357 Vec3::zeros(),
358 1.0,
359 0.01,
360 mat,
361 )];
362
363 let initial_y = particles[0].x.y;
364
365 for _ in 0..50 {
367 solver.step(&mut particles);
368 }
369
370 assert!(
374 particles[0].x.y < initial_y - 0.1,
375 "Particle should have fallen, but y went from {} to {}",
376 initial_y,
377 particles[0].x.y
378 );
379
380 }
383
384 #[test]
385 fn test_momentum_conservation() {
386 use crate::material::EquationOfState;
387
388 let mut solver = MpmSolver::new(0.1, 0.001, (Vec3::zeros(), Vec3::new(2.0, 2.0, 2.0)));
389
390 let mat = Material::Fluid {
392 viscosity: 1e-3,
393 eos: EquationOfState::IdealGas {
394 rho0: 1000.0,
395 cs: 10.0, },
397 };
398
399 let mut particles = vec![
400 Particle::new(
401 Vec3::new(0.8, 1.0, 1.0),
402 Vec3::new(0.1, 0.0, 0.0),
403 0.1,
404 0.001,
405 mat,
406 ),
407 Particle::new(
408 Vec3::new(1.2, 1.0, 1.0),
409 Vec3::new(-0.1, 0.0, 0.0),
410 0.1,
411 0.001,
412 mat,
413 ),
414 ];
415
416 let initial_momentum = MpmSolver::total_momentum(&particles);
417 let total_mass: f64 = particles.iter().map(|p| p.mass).sum();
418
419 let n_steps = 5;
421 for _ in 0..n_steps {
422 solver.step(&mut particles);
423 }
424
425 let final_momentum = MpmSolver::total_momentum(&particles);
426
427 assert!(
430 (initial_momentum.x - final_momentum.x).abs() < 0.5,
431 "X momentum changed from {} to {}",
432 initial_momentum.x,
433 final_momentum.x
434 );
435 assert!((initial_momentum.z - final_momentum.z).abs() < 0.5);
436
437 let expected_dp_y = -total_mass * GRAVITY * n_steps as f64 * solver.dt;
440 let actual_dp_y = final_momentum.y - initial_momentum.y;
441 assert!(
442 (expected_dp_y - actual_dp_y).abs() < 1.0,
443 "Vertical momentum change {} vs expected {}",
444 actual_dp_y,
445 expected_dp_y
446 );
447 }
448}