Skip to main content

caustic/tooling/core/algos/
flow_map.rs

//! Flow-map Lagrangian representation.
//! Stores the forward maps X(t;q) and V(t;q) on a Lagrangian grid.
//! Reconstructs f via Liouville's theorem: f(x,v,t) = f₀(q,p), where (q,p) is the pre-image of (x,v).
//!
//! Unlike `SheetTracker`, which models cold (delta-function) distributions, `FlowMapRepr`
//! samples f₀(q,p) at each Lagrangian point and carries a finite mass per tracer.
//! The distribution function remains smooth even when f develops fine filaments,
//! because the map coordinates X(t;q) and V(t;q) stay smooth.
9
10use rayon::prelude::*;
11
12use super::super::{
13    init::domain::{Domain, SpatialBoundType},
14    phasespace::PhaseSpaceRepr,
15    types::*,
16};
17use std::any::Any;
18use std::sync::Arc;
19use std::sync::atomic::{AtomicU64, Ordering};
20
/// Flow-map Lagrangian representation of the 6D distribution function.
///
/// Holds `n_lag³ * nv_lag³` Lagrangian tracer points. For each tracer the struct stores:
/// - Current position X(t;q,p) and velocity V(t;q,p)
/// - Initial distribution value f₀(q,p) and particle mass
///
/// The initial coordinates (q,p) are not kept as separate fields: `positions` and
/// `velocities` start on the Lagrangian grid nodes and are updated in place by the
/// drift/kick sub-steps.
///
/// Density is recovered via CIC deposition, following the same
/// pattern as `SheetTracker`.
pub struct FlowMapRepr {
    /// Lagrangian grid positions X(t;q) — flat [x0,y0,z0, x1,y1,z1, ...]
    pub positions: Vec<f64>,
    /// Lagrangian grid velocities V(t;q) — flat [vx0,vy0,vz0, ...]
    pub velocities: Vec<f64>,
    /// Initial distribution function value f₀(q,p) at each Lagrangian point
    pub f0_values: Vec<f64>,
    /// Mass of each Lagrangian tracer (= f₀ * dq³ * dp³)
    pub masses: Vec<f64>,
    /// Number of Lagrangian points per spatial dimension
    pub n_lag: usize,
    /// Number of Lagrangian points per velocity dimension
    pub nv_lag: usize,
    /// Spatial grid dimensions for density deposition [nx, ny, nz]
    pub spatial_shape: [usize; 3],
    /// Domain specification
    pub domain: Domain,
    /// Cached total mass (constant by Liouville); written by `from_snapshot`, 0.0 after `new`
    total_mass_cached: f64,
    /// Cached entropy (constant by Liouville); written by `from_snapshot`, 0.0 after `new`
    entropy_cached: f64,
    // Cached domain values, copied from `domain` at construction so the hot loops
    // do not have to query it again.
    /// Spatial cell size per axis (`domain.dx()`)
    cached_dx: [f64; 3],
    /// Spatial half-extent per axis (`domain.lx()`); the box spans [-lx, lx]
    cached_lx: [f64; 3],
    /// Velocity half-extent per axis (`domain.lv()`)
    cached_lv: [f64; 3],
    /// True when the spatial boundary condition is `SpatialBoundType::Periodic`
    cached_is_periodic: bool,
    /// Optional progress reporter
    progress: Option<Arc<super::super::progress::StepProgress>>,
}
59
60impl FlowMapRepr {
61    /// Create a new FlowMapRepr with all tracers at their initial positions and zero f₀.
62    ///
63    /// The Lagrangian grid spans the full domain: `n_lag³` points in spatial dimensions,
64    /// `nv_lag³` points in velocity dimensions. Each dimension is uniformly spaced.
65    pub fn new(domain: &Domain, n_lag: usize, nv_lag: usize) -> Self {
66        let dx = domain.dx();
67        let lx = domain.lx();
68        let lv = domain.lv();
69        let dv = domain.dv();
70
71        let spatial_shape = [
72            domain.spatial_res.x1 as usize,
73            domain.spatial_res.x2 as usize,
74            domain.spatial_res.x3 as usize,
75        ];
76
77        let n_total = n_lag * n_lag * n_lag * nv_lag * nv_lag * nv_lag;
78        let dq = [
79            2.0 * lx[0] / n_lag as f64,
80            2.0 * lx[1] / n_lag as f64,
81            2.0 * lx[2] / n_lag as f64,
82        ];
83        let dp = [
84            2.0 * lv[0] / nv_lag as f64,
85            2.0 * lv[1] / nv_lag as f64,
86            2.0 * lv[2] / nv_lag as f64,
87        ];
88
89        let mut positions = Vec::with_capacity(n_total * 3);
90        let mut velocities = Vec::with_capacity(n_total * 3);
91
92        for ix in 0..n_lag {
93            for iy in 0..n_lag {
94                for iz in 0..n_lag {
95                    for iv0 in 0..nv_lag {
96                        for iv1 in 0..nv_lag {
97                            for iv2 in 0..nv_lag {
98                                let x = -lx[0] + (ix as f64 + 0.5) * dq[0];
99                                let y = -lx[1] + (iy as f64 + 0.5) * dq[1];
100                                let z = -lx[2] + (iz as f64 + 0.5) * dq[2];
101                                let vx = -lv[0] + (iv0 as f64 + 0.5) * dp[0];
102                                let vy = -lv[1] + (iv1 as f64 + 0.5) * dp[1];
103                                let vz = -lv[2] + (iv2 as f64 + 0.5) * dp[2];
104                                positions.extend_from_slice(&[x, y, z]);
105                                velocities.extend_from_slice(&[vx, vy, vz]);
106                            }
107                        }
108                    }
109                }
110            }
111        }
112
113        let f0_values = vec![0.0; n_total];
114        let masses = vec![0.0; n_total];
115
116        let is_periodic = matches!(domain.spatial_bc, SpatialBoundType::Periodic);
117
118        Self {
119            positions,
120            velocities,
121            f0_values,
122            masses,
123            n_lag,
124            nv_lag,
125            spatial_shape,
126            domain: domain.clone(),
127            total_mass_cached: 0.0,
128            entropy_cached: 0.0,
129            cached_dx: dx,
130            cached_lx: lx,
131            cached_lv: lv,
132            cached_is_periodic: is_periodic,
133            progress: None,
134        }
135    }
136
137    /// Initialize from a `PhaseSpaceSnapshot` by sampling f₀ at Lagrangian grid points.
138    ///
139    /// Places `n_lag³ * nv_lag³` tracers on a regular grid in (x,v) space.
140    /// Records f₀ at each point via trilinear interpolation of the snapshot data,
141    /// and computes particle masses as `m_i = f₀(q_i, p_i) * dq³ * dp³`.
142    pub fn from_snapshot(
143        snap: &PhaseSpaceSnapshot,
144        domain: &Domain,
145        n_lag: usize,
146        nv_lag: usize,
147    ) -> Self {
148        let mut repr = Self::new(domain, n_lag, nv_lag);
149        let n_total = repr.num_tracers();
150
151        let lx = domain.lx();
152        let lv = domain.lv();
153        let dq = [
154            2.0 * lx[0] / n_lag as f64,
155            2.0 * lx[1] / n_lag as f64,
156            2.0 * lx[2] / n_lag as f64,
157        ];
158        let dp = [
159            2.0 * lv[0] / nv_lag as f64,
160            2.0 * lv[1] / nv_lag as f64,
161            2.0 * lv[2] / nv_lag as f64,
162        ];
163        let phase_vol = dq[0] * dq[1] * dq[2] * dp[0] * dp[1] * dp[2];
164
165        // Snapshot grid parameters
166        let [nx0, nx1, nx2, nv0, nv1, nv2] = snap.shape;
167        let snap_dx = domain.dx();
168        let snap_dv = domain.dv();
169
170        let mut total_mass = 0.0;
171        let mut entropy = 0.0;
172
173        for i in 0..n_total {
174            let pos = [
175                repr.positions[3 * i],
176                repr.positions[3 * i + 1],
177                repr.positions[3 * i + 2],
178            ];
179            let vel = [
180                repr.velocities[3 * i],
181                repr.velocities[3 * i + 1],
182                repr.velocities[3 * i + 2],
183            ];
184
185            // Interpolate f₀ from the snapshot at this (x, v) point
186            let f0 = Self::interpolate_6d(
187                &snap.data, snap.shape, &pos, &vel, &snap_dx, &snap_dv, &lx, &lv,
188            );
189
190            let f0 = f0.max(0.0);
191            repr.f0_values[i] = f0;
192            repr.masses[i] = f0 * phase_vol;
193            total_mass += repr.masses[i];
194
195            if f0 > 0.0 {
196                entropy -= f0 * f0.ln() * phase_vol;
197            }
198        }
199
200        repr.total_mass_cached = total_mass;
201        repr.entropy_cached = entropy;
202        repr
203    }
204
205    /// Total number of Lagrangian tracer points.
206    #[inline]
207    pub fn num_tracers(&self) -> usize {
208        self.n_lag * self.n_lag * self.n_lag * self.nv_lag * self.nv_lag * self.nv_lag
209    }
210
211    /// Trilinear interpolation of a 6D field at (x, v).
212    ///
213    /// The field is stored as a flat array with shape [nx0, nx1, nx2, nv0, nv1, nv2]
214    /// in row-major order.
215    fn interpolate_6d(
216        data: &[f64],
217        shape: [usize; 6],
218        pos: &[f64; 3],
219        vel: &[f64; 3],
220        dx: &[f64; 3],
221        dv: &[f64; 3],
222        lx: &[f64; 3],
223        lv: &[f64; 3],
224    ) -> f64 {
225        let [nx0, nx1, nx2, nv0, nv1, nv2] = shape;
226        let ns_x = [nx0, nx1, nx2];
227        let ns_v = [nv0, nv1, nv2];
228
229        // Compute spatial CIC indices and fractions
230        let mut x_ci = [0isize; 3];
231        let mut x_frac = [0.0f64; 3];
232        for k in 0..3 {
233            let s = (pos[k] + lx[k]) / dx[k] - 0.5;
234            x_ci[k] = s.floor() as isize;
235            x_frac[k] = s - x_ci[k] as f64;
236        }
237
238        // Compute velocity CIC indices and fractions
239        let mut v_ci = [0isize; 3];
240        let mut v_frac = [0.0f64; 3];
241        for k in 0..3 {
242            let s = (vel[k] + lv[k]) / dv[k] - 0.5;
243            v_ci[k] = s.floor() as isize;
244            v_frac[k] = s - v_ci[k] as f64;
245        }
246
247        // Strides for row-major 6D: [nx0, nx1, nx2, nv0, nv1, nv2]
248        let sv3 = 1usize;
249        let sv2 = nv2;
250        let sv1 = nv1 * nv2;
251        let sx3 = nv0 * sv1;
252        let sx2 = nx2 * sx3;
253        let sx1 = nx1 * sx2;
254
255        let mut result = 0.0;
256
257        for dix in 0..2isize {
258            let wx0 = if dix == 0 { 1.0 - x_frac[0] } else { x_frac[0] };
259            let ix0 = x_ci[0] + dix;
260            if ix0 < 0 || ix0 >= ns_x[0] as isize {
261                continue;
262            }
263            for diy in 0..2isize {
264                let wx1 = if diy == 0 { 1.0 - x_frac[1] } else { x_frac[1] };
265                let ix1 = x_ci[1] + diy;
266                if ix1 < 0 || ix1 >= ns_x[1] as isize {
267                    continue;
268                }
269                for diz in 0..2isize {
270                    let wx2 = if diz == 0 { 1.0 - x_frac[2] } else { x_frac[2] };
271                    let ix2 = x_ci[2] + diz;
272                    if ix2 < 0 || ix2 >= ns_x[2] as isize {
273                        continue;
274                    }
275                    let wx = wx0 * wx1 * wx2;
276
277                    for div0 in 0..2isize {
278                        let wv0 = if div0 == 0 {
279                            1.0 - v_frac[0]
280                        } else {
281                            v_frac[0]
282                        };
283                        let iv0 = v_ci[0] + div0;
284                        if iv0 < 0 || iv0 >= ns_v[0] as isize {
285                            continue;
286                        }
287                        for div1 in 0..2isize {
288                            let wv1 = if div1 == 0 {
289                                1.0 - v_frac[1]
290                            } else {
291                                v_frac[1]
292                            };
293                            let iv1 = v_ci[1] + div1;
294                            if iv1 < 0 || iv1 >= ns_v[1] as isize {
295                                continue;
296                            }
297                            for div2 in 0..2isize {
298                                let wv2 = if div2 == 0 {
299                                    1.0 - v_frac[2]
300                                } else {
301                                    v_frac[2]
302                                };
303                                let iv2 = v_ci[2] + div2;
304                                if iv2 < 0 || iv2 >= ns_v[2] as isize {
305                                    continue;
306                                }
307                                let wv = wv0 * wv1 * wv2;
308                                let flat = ix0 as usize * sx1
309                                    + ix1 as usize * sx2
310                                    + ix2 as usize * sx3
311                                    + iv0 as usize * sv1
312                                    + iv1 as usize * sv2
313                                    + iv2 as usize * sv3;
314                                result += wx * wv * data[flat];
315                            }
316                        }
317                    }
318                }
319            }
320        }
321
322        result
323    }
324
325    /// Trilinear interpolation of a 3D vector field at an arbitrary position.
326    ///
327    /// Used by `advect_v` to obtain the acceleration at each particle's position.
328    /// Same algorithm as `SheetTracker::interpolate_vec_field`.
329    fn interpolate_vec_field(
330        field_x: &[f64],
331        field_y: &[f64],
332        field_z: &[f64],
333        shape: [usize; 3],
334        pos: &[f64; 3],
335        dx: &[f64; 3],
336        lx: &[f64; 3],
337        periodic: bool,
338    ) -> [f64; 3] {
339        let [nx, ny, nz] = shape;
340
341        // Grid node at index i is at coordinate -L + (i + 0.5) * dx.
342        // Find the nearest lower node and fractional offset.
343        let mut ci = [0isize; 3];
344        let mut frac = [0.0f64; 3];
345        for k in 0..3 {
346            let s = (pos[k] + lx[k]) / dx[k] - 0.5;
347            ci[k] = s.floor() as isize;
348            frac[k] = s - ci[k] as f64;
349        }
350
351        let mut result = [0.0f64; 3];
352
353        for di in 0..2isize {
354            let wx = if di == 0 { 1.0 - frac[0] } else { frac[0] };
355            for dj in 0..2isize {
356                let wy = if dj == 0 { 1.0 - frac[1] } else { frac[1] };
357                for dk in 0..2isize {
358                    let wz = if dk == 0 { 1.0 - frac[2] } else { frac[2] };
359                    let w = wx * wy * wz;
360
361                    let mut ii = ci[0] + di;
362                    let mut jj = ci[1] + dj;
363                    let mut kk = ci[2] + dk;
364
365                    if periodic {
366                        ii = ii.rem_euclid(nx as isize);
367                        jj = jj.rem_euclid(ny as isize);
368                        kk = kk.rem_euclid(nz as isize);
369                    } else if ii < 0
370                        || ii >= nx as isize
371                        || jj < 0
372                        || jj >= ny as isize
373                        || kk < 0
374                        || kk >= nz as isize
375                    {
376                        continue;
377                    }
378
379                    let flat = ii as usize * ny * nz + jj as usize * nz + kk as usize;
380                    result[0] += w * field_x[flat];
381                    result[1] += w * field_y[flat];
382                    result[2] += w * field_z[flat];
383                }
384            }
385        }
386
387        result
388    }
389
390    /// Find the flat spatial cell index for a position, or None if outside domain.
391    fn cell_index(&self, pos: &[f64; 3]) -> Option<usize> {
392        let dx = self.cached_dx;
393        let lx = self.cached_lx;
394        let [nx, ny, nz] = self.spatial_shape;
395
396        let mut ci = [0usize; 3];
397        let ns = [nx, ny, nz];
398        for k in 0..3 {
399            let idx = ((pos[k] + lx[k]) / dx[k]).floor() as isize;
400            if self.cached_is_periodic {
401                ci[k] = idx.rem_euclid(ns[k] as isize) as usize;
402            } else if idx < 0 || idx >= ns[k] as isize {
403                return None;
404            } else {
405                ci[k] = idx as usize;
406            }
407        }
408
409        Some(ci[0] * ny * nz + ci[1] * nz + ci[2])
410    }
411
412    /// Collect indices of all tracers that lie in the same spatial cell as `position`.
413    fn tracers_in_cell(&self, position: &[f64; 3]) -> Vec<usize> {
414        let target = match self.cell_index(position) {
415            Some(c) => c,
416            None => return Vec::new(),
417        };
418
419        let n = self.num_tracers();
420        let mut result = Vec::new();
421        for i in 0..n {
422            let pos = [
423                self.positions[3 * i],
424                self.positions[3 * i + 1],
425                self.positions[3 * i + 2],
426            ];
427            if let Some(c) = self.cell_index(&pos)
428                && c == target
429            {
430                result.push(i);
431            }
432        }
433        result
434    }
435}
436
437impl PhaseSpaceRepr for FlowMapRepr {
438    /// Register a progress reporter for intra-step progress updates.
439    fn set_progress(&mut self, p: Arc<super::super::progress::StepProgress>) {
440        self.progress = Some(p);
441    }
442
443    /// CIC deposition of tracer masses onto the spatial grid.
444    ///
445    /// For each Lagrangian tracer, finds the 8 nearest spatial cells and distributes
446    /// its mass using trilinear weights, then divides by cell volume to get density.
447    /// Uses rayon parallelism with fold/reduce (same pattern as `SheetTracker`).
448    fn compute_density(&self) -> DensityField {
449        let [nx, ny, nz] = self.spatial_shape;
450        let n_cells = nx * ny * nz;
451        let dx = self.cached_dx;
452        let lx = self.cached_lx;
453        let cell_vol = self.domain.cell_volume_3d();
454        let is_periodic = self.cached_is_periodic;
455        let n_tracers = self.num_tracers();
456
457        let n_tracers_u64 = n_tracers as u64;
458        if let Some(ref p) = self.progress {
459            p.set_intra_progress(0, n_tracers_u64);
460        }
461
462        // Build a slice of (position_slice, mass) pairs for parallel iteration
463        let density_data: Vec<f64> = (0..n_tracers)
464            .into_par_iter()
465            .fold(
466                || vec![0.0f64; n_cells],
467                |mut local, i| {
468                    let mass = self.masses[i];
469                    if mass <= 0.0 {
470                        return local;
471                    }
472                    let pos = [
473                        self.positions[3 * i],
474                        self.positions[3 * i + 1],
475                        self.positions[3 * i + 2],
476                    ];
477
478                    let mut ci = [0isize; 3];
479                    let mut frac = [0.0f64; 3];
480                    for k in 0..3 {
481                        let s = (pos[k] + lx[k]) / dx[k] - 0.5;
482                        ci[k] = s.floor() as isize;
483                        frac[k] = s - ci[k] as f64;
484                    }
485                    for di in 0..2isize {
486                        let wx = if di == 0 { 1.0 - frac[0] } else { frac[0] };
487                        for dj in 0..2isize {
488                            let wy = if dj == 0 { 1.0 - frac[1] } else { frac[1] };
489                            for dk in 0..2isize {
490                                let wz = if dk == 0 { 1.0 - frac[2] } else { frac[2] };
491                                let w = wx * wy * wz;
492                                let mut ii = ci[0] + di;
493                                let mut jj = ci[1] + dj;
494                                let mut kk = ci[2] + dk;
495                                if is_periodic {
496                                    ii = ii.rem_euclid(nx as isize);
497                                    jj = jj.rem_euclid(ny as isize);
498                                    kk = kk.rem_euclid(nz as isize);
499                                } else if ii < 0
500                                    || ii >= nx as isize
501                                    || jj < 0
502                                    || jj >= ny as isize
503                                    || kk < 0
504                                    || kk >= nz as isize
505                                {
506                                    continue;
507                                }
508                                let flat = ii as usize * ny * nz + jj as usize * nz + kk as usize;
509                                local[flat] += mass * w;
510                            }
511                        }
512                    }
513                    local
514                },
515            )
516            .reduce(
517                || vec![0.0f64; n_cells],
518                |mut a, b| {
519                    for i in 0..n_cells {
520                        a[i] += b[i];
521                    }
522                    a
523                },
524            );
525
526        // Convert mass per cell to mass density
527        let mut density = density_data;
528        for d in &mut density {
529            *d /= cell_vol;
530        }
531
532        DensityField {
533            data: density,
534            shape: [nx, ny, nz],
535        }
536    }
537
538    /// Drift sub-step: `X[i] += V[i] * dt` for each tracer. Exact -- no interpolation.
539    ///
540    /// For periodic domains, positions are wrapped into [-L, L].
541    fn advect_x(&mut self, _displacement: &DisplacementField, dt: f64) {
542        let is_periodic = self.cached_is_periodic;
543        let lx = self.cached_lx;
544        let n_tracers = self.num_tracers();
545        let progress = self.progress.clone();
546        let n_tracers_u64 = n_tracers as u64;
547        let counter = AtomicU64::new(0);
548        let report_interval = (n_tracers_u64 / 100).max(1);
549
550        if let Some(ref p) = self.progress {
551            p.set_intra_progress(0, n_tracers_u64);
552        }
553
554        // Process positions and velocities as chunks of 3 (x,y,z)
555        let positions = &mut self.positions;
556        let velocities = &self.velocities;
557
558        // Use par_chunks_mut for parallel update
559        positions
560            .par_chunks_mut(3)
561            .enumerate()
562            .for_each(|(i, pos_chunk)| {
563                let vel_base = 3 * i;
564                for k in 0..3 {
565                    pos_chunk[k] += velocities[vel_base + k] * dt;
566                    if is_periodic {
567                        let two_l = 2.0 * lx[k];
568                        pos_chunk[k] = ((pos_chunk[k] + lx[k]).rem_euclid(two_l)) - lx[k];
569                    }
570                }
571                if let Some(ref prog) = progress {
572                    let c = counter.fetch_add(1, Ordering::Relaxed);
573                    if c.is_multiple_of(report_interval) {
574                        prog.set_intra_progress(c, n_tracers_u64);
575                    }
576                }
577            });
578    }
579
580    /// Kick sub-step: `V[i] += g(X[i]) * dt` for each tracer.
581    ///
582    /// The acceleration at each tracer's current position is obtained by trilinear
583    /// interpolation of the acceleration field.
584    fn advect_v(&mut self, acceleration: &AccelerationField, dt: f64) {
585        let dx = self.cached_dx;
586        let lx = self.cached_lx;
587        let is_periodic = self.cached_is_periodic;
588        let n_tracers = self.num_tracers();
589        let progress = self.progress.clone();
590        let n_tracers_u64 = n_tracers as u64;
591        let counter = AtomicU64::new(0);
592        let report_interval = (n_tracers_u64 / 100).max(1);
593
594        if let Some(ref p) = self.progress {
595            p.set_intra_progress(0, n_tracers_u64);
596        }
597
598        let positions = &self.positions;
599        let velocities = &mut self.velocities;
600
601        velocities
602            .par_chunks_mut(3)
603            .enumerate()
604            .for_each(|(i, vel_chunk)| {
605                let pos = [positions[3 * i], positions[3 * i + 1], positions[3 * i + 2]];
606                let a = Self::interpolate_vec_field(
607                    &acceleration.gx,
608                    &acceleration.gy,
609                    &acceleration.gz,
610                    acceleration.shape,
611                    &pos,
612                    &dx,
613                    &lx,
614                    is_periodic,
615                );
616                for k in 0..3 {
617                    vel_chunk[k] += a[k] * dt;
618                }
619                if let Some(ref prog) = progress {
620                    let c = counter.fetch_add(1, Ordering::Relaxed);
621                    if c.is_multiple_of(report_interval) {
622                        prog.set_intra_progress(c, n_tracers_u64);
623                    }
624                }
625            });
626    }
627
628    /// Velocity moment at a given spatial position.
629    ///
630    /// Finds all tracers in the same spatial cell and computes the requested moment
631    /// from their velocities and masses.
632    fn moment(&self, position: &[f64; 3], order: usize) -> Tensor {
633        let indices = self.tracers_in_cell(position);
634        let cell_vol = self.domain.cell_volume_3d();
635
636        match order {
637            0 => {
638                // Zeroth moment: density = sum(masses) / cell_volume
639                let rho: f64 = indices.iter().map(|&i| self.masses[i]).sum::<f64>() / cell_vol;
640                Tensor {
641                    data: vec![rho],
642                    rank: 0,
643                    shape: vec![],
644                }
645            }
646            1 => {
647                // First moment: mass-weighted mean velocity
648                let mut mean_v = [0.0f64; 3];
649                let total_mass: f64 = indices.iter().map(|&i| self.masses[i]).sum();
650                if total_mass > 0.0 {
651                    for &i in &indices {
652                        let m = self.masses[i];
653                        for k in 0..3 {
654                            mean_v[k] += m * self.velocities[3 * i + k];
655                        }
656                    }
657                    for k in 0..3 {
658                        mean_v[k] /= total_mass;
659                    }
660                }
661                Tensor {
662                    data: mean_v.to_vec(),
663                    rank: 1,
664                    shape: vec![3],
665                }
666            }
667            2 => {
668                // Second moment: mass-weighted velocity dispersion tensor
669                let mut mean_v = [0.0f64; 3];
670                let mut tensor = [0.0f64; 9];
671                let total_mass: f64 = indices.iter().map(|&i| self.masses[i]).sum();
672                if total_mass > 0.0 {
673                    for &i in &indices {
674                        let m = self.masses[i];
675                        for k in 0..3 {
676                            mean_v[k] += m * self.velocities[3 * i + k];
677                        }
678                    }
679                    for k in 0..3 {
680                        mean_v[k] /= total_mass;
681                    }
682
683                    for &i in &indices {
684                        let m = self.masses[i];
685                        for a in 0..3 {
686                            for b in 0..3 {
687                                let dv_a = self.velocities[3 * i + a] - mean_v[a];
688                                let dv_b = self.velocities[3 * i + b] - mean_v[b];
689                                tensor[a * 3 + b] += m * dv_a * dv_b;
690                            }
691                        }
692                    }
693                    for val in &mut tensor {
694                        *val /= total_mass;
695                    }
696                }
697                Tensor {
698                    data: tensor.to_vec(),
699                    rank: 2,
700                    shape: vec![3, 3],
701                }
702            }
703            _ => {
704                let dim = 3usize.pow(order as u32);
705                Tensor {
706                    data: vec![0.0; dim],
707                    rank: order,
708                    shape: vec![3; order],
709                }
710            }
711        }
712    }
713
    /// Total mass. Constant by Liouville's theorem — returns the cached value.
    ///
    /// The cache is written by `from_snapshot` (sum of tracer masses) and is 0.0 after
    /// `new`; it is not recomputed from `masses` here, so direct edits to `masses`
    /// are not reflected.
    fn total_mass(&self) -> f64 {
        self.total_mass_cached
    }
718
719    /// Casimir invariant C₂ = integral of f² over phase space.
720    ///
721    /// Approximated by depositing onto the spatial grid and computing integral of rho^2.
722    /// This is a spatial-only approximation; the true C₂ involves the 6D distribution.
723    fn casimir_c2(&self) -> f64 {
724        let density = self.compute_density();
725        let cell_vol = self.domain.cell_volume_3d();
726        density.data.iter().map(|&rho| rho * rho).sum::<f64>() * cell_vol
727    }
728
    /// Entropy S = -integral of f ln f over phase space.
    /// Constant by Liouville's theorem — returns the cached value.
    ///
    /// The cache is written by `from_snapshot` and is 0.0 after `new`; nothing is
    /// recomputed here.
    fn entropy(&self) -> f64 {
        self.entropy_cached
    }
734
735    /// Number of distinct velocity streams at each spatial point.
736    ///
737    /// Counts the number of distinct Lagrangian tracers per spatial cell.
738    fn stream_count(&self) -> StreamCountField {
739        let [nx, ny, nz] = self.spatial_shape;
740        let n_cells = nx * ny * nz;
741        let n_tracers = self.num_tracers();
742        let dx = self.cached_dx;
743        let lx = self.cached_lx;
744        let is_periodic = self.cached_is_periodic;
745
746        let counts: Vec<u32> = (0..n_tracers)
747            .into_par_iter()
748            .fold(
749                || vec![0u32; n_cells],
750                |mut local, i| {
751                    let mass = self.masses[i];
752                    if mass <= 0.0 {
753                        return local;
754                    }
755                    let pos = [
756                        self.positions[3 * i],
757                        self.positions[3 * i + 1],
758                        self.positions[3 * i + 2],
759                    ];
760                    let mut skip = false;
761                    let mut ci = [0usize; 3];
762                    let ns = [nx, ny, nz];
763                    for k in 0..3 {
764                        let idx = ((pos[k] + lx[k]) / dx[k]).floor() as isize;
765                        if is_periodic {
766                            ci[k] = idx.rem_euclid(ns[k] as isize) as usize;
767                        } else if idx < 0 || idx >= ns[k] as isize {
768                            skip = true;
769                            break;
770                        } else {
771                            ci[k] = idx as usize;
772                        }
773                    }
774                    if !skip {
775                        let flat = ci[0] * ny * nz + ci[1] * nz + ci[2];
776                        local[flat] += 1;
777                    }
778                    local
779                },
780            )
781            .reduce(
782                || vec![0u32; n_cells],
783                |mut a, b| {
784                    for i in 0..n_cells {
785                        a[i] += b[i];
786                    }
787                    a
788                },
789            );
790
791        StreamCountField {
792            data: counts,
793            shape: [nx, ny, nz],
794        }
795    }
796
797    /// Local velocity distribution at a given spatial position.
798    ///
799    /// Collects the speed |v| of all tracers in the same cell as the given position.
800    fn velocity_distribution(&self, position: &[f64; 3]) -> Vec<f64> {
801        let indices = self.tracers_in_cell(position);
802        indices
803            .iter()
804            .map(|&i| {
805                let vx = self.velocities[3 * i];
806                let vy = self.velocities[3 * i + 1];
807                let vz = self.velocities[3 * i + 2];
808                (vx * vx + vy * vy + vz * vz).sqrt()
809            })
810            .collect()
811    }
812
813    /// Total kinetic energy `T = sum of 0.5 * mass[i] * |V[i]|^2`.
814    fn total_kinetic_energy(&self) -> Option<f64> {
815        let n = self.num_tracers();
816        Some(
817            (0..n)
818                .into_par_iter()
819                .map(|i| {
820                    let vx = self.velocities[3 * i];
821                    let vy = self.velocities[3 * i + 1];
822                    let vz = self.velocities[3 * i + 2];
823                    0.5 * self.masses[i] * (vx * vx + vy * vy + vz * vz)
824                })
825                .sum(),
826        )
827    }
828
829    /// Extract a full 6D snapshot by CIC deposition of all tracers.
830    ///
831    /// Each tracer is deposited onto 2^3 x 2^3 = 64 surrounding cells in the 6D grid.
832    /// This is expensive for large grids and should only be used for checkpoints.
833    fn to_snapshot(&self, time: f64) -> Option<PhaseSpaceSnapshot> {
834        let d = &self.domain;
835        let nx = [
836            d.spatial_res.x1 as usize,
837            d.spatial_res.x2 as usize,
838            d.spatial_res.x3 as usize,
839        ];
840        let nv = [
841            d.velocity_res.v1 as usize,
842            d.velocity_res.v2 as usize,
843            d.velocity_res.v3 as usize,
844        ];
845        let dx = d.dx();
846        let dv = d.dv();
847        let lx = d.lx();
848        let lv = d.lv();
849
850        let total_6d = nx[0] * nx[1] * nx[2] * nv[0] * nv[1] * nv[2];
851        let mut data = vec![0.0f64; total_6d];
852
853        let cell_vol_6d = d.cell_volume_6d();
854        let is_periodic = self.cached_is_periodic;
855
856        // Strides for row-major 6D: x1, x2, x3, v1, v2, v3
857        let sv3 = 1;
858        let sv2 = nv[2];
859        let sv1 = nv[1] * nv[2];
860        let sx3 = nv[0] * sv1;
861        let sx2 = nx[2] * sx3;
862        let sx1 = nx[1] * sx2;
863
864        let n_tracers = self.num_tracers();
865
866        for i in 0..n_tracers {
867            let mass = self.masses[i];
868            if mass <= 0.0 {
869                continue;
870            }
871            let pos = [
872                self.positions[3 * i],
873                self.positions[3 * i + 1],
874                self.positions[3 * i + 2],
875            ];
876            let vel = [
877                self.velocities[3 * i],
878                self.velocities[3 * i + 1],
879                self.velocities[3 * i + 2],
880            ];
881
882            // Spatial CIC indices
883            let mut x_ci = [0isize; 3];
884            let mut x_frac = [0.0f64; 3];
885            for k in 0..3 {
886                let s = (pos[k] + lx[k]) / dx[k] - 0.5;
887                x_ci[k] = s.floor() as isize;
888                x_frac[k] = s - x_ci[k] as f64;
889            }
890
891            // Velocity CIC indices
892            let mut v_ci = [0isize; 3];
893            let mut v_frac = [0.0f64; 3];
894            for k in 0..3 {
895                let s = (vel[k] + lv[k]) / dv[k] - 0.5;
896                v_ci[k] = s.floor() as isize;
897                v_frac[k] = s - v_ci[k] as f64;
898            }
899
900            // Deposit to 2^3 x 2^3 = 64 surrounding 6D cells
901            for dix in 0..2isize {
902                let wx0 = if dix == 0 { 1.0 - x_frac[0] } else { x_frac[0] };
903                let ix0 = x_ci[0] + dix;
904                if is_periodic {
905                    // handled below
906                } else if ix0 < 0 || ix0 >= nx[0] as isize {
907                    continue;
908                }
909                let ix0_w = if is_periodic {
910                    ix0.rem_euclid(nx[0] as isize) as usize
911                } else {
912                    ix0 as usize
913                };
914
915                for diy in 0..2isize {
916                    let wx1 = if diy == 0 { 1.0 - x_frac[1] } else { x_frac[1] };
917                    let ix1 = x_ci[1] + diy;
918                    if is_periodic {
919                        // handled below
920                    } else if ix1 < 0 || ix1 >= nx[1] as isize {
921                        continue;
922                    }
923                    let ix1_w = if is_periodic {
924                        ix1.rem_euclid(nx[1] as isize) as usize
925                    } else {
926                        ix1 as usize
927                    };
928
929                    for diz in 0..2isize {
930                        let wx2 = if diz == 0 { 1.0 - x_frac[2] } else { x_frac[2] };
931                        let ix2 = x_ci[2] + diz;
932                        if is_periodic {
933                            // handled below
934                        } else if ix2 < 0 || ix2 >= nx[2] as isize {
935                            continue;
936                        }
937                        let ix2_w = if is_periodic {
938                            ix2.rem_euclid(nx[2] as isize) as usize
939                        } else {
940                            ix2 as usize
941                        };
942                        let wx = wx0 * wx1 * wx2;
943
944                        for div0 in 0..2isize {
945                            let wv0 = if div0 == 0 {
946                                1.0 - v_frac[0]
947                            } else {
948                                v_frac[0]
949                            };
950                            let iv0 = v_ci[0] + div0;
951                            if iv0 < 0 || iv0 >= nv[0] as isize {
952                                continue;
953                            }
954                            for div1 in 0..2isize {
955                                let wv1 = if div1 == 0 {
956                                    1.0 - v_frac[1]
957                                } else {
958                                    v_frac[1]
959                                };
960                                let iv1 = v_ci[1] + div1;
961                                if iv1 < 0 || iv1 >= nv[1] as isize {
962                                    continue;
963                                }
964                                for div2 in 0..2isize {
965                                    let wv2 = if div2 == 0 {
966                                        1.0 - v_frac[2]
967                                    } else {
968                                        v_frac[2]
969                                    };
970                                    let iv2 = v_ci[2] + div2;
971                                    if iv2 < 0 || iv2 >= nv[2] as isize {
972                                        continue;
973                                    }
974                                    let wv = wv0 * wv1 * wv2;
975
976                                    let flat = ix0_w * sx1
977                                        + ix1_w * sx2
978                                        + ix2_w * sx3
979                                        + iv0 as usize * sv1
980                                        + iv1 as usize * sv2
981                                        + iv2 as usize * sv3;
982
983                                    data[flat] += mass * wx * wv / cell_vol_6d;
984                                }
985                            }
986                        }
987                    }
988                }
989            }
990        }
991
992        Some(PhaseSpaceSnapshot {
993            data,
994            shape: [nx[0], nx[1], nx[2], nv[0], nv[1], nv[2]],
995            time,
996        })
997    }
998
999    /// Downcast to `&dyn Any` for runtime type queries.
1000    fn as_any(&self) -> &dyn Any {
1001        self
1002    }
1003
1004    /// Downcast to `&mut dyn Any` for runtime type queries.
1005    fn as_any_mut(&mut self) -> &mut dyn Any {
1006        self
1007    }
1008
1009    /// Heap memory used by position, velocity, f0, and mass arrays.
1010    fn memory_bytes(&self) -> usize {
1011        let n = self.num_tracers();
1012        // positions (3*n f64) + velocities (3*n f64) + f0_values (n f64) + masses (n f64)
1013        (3 * n + 3 * n + n + n) * std::mem::size_of::<f64>()
1014    }
1015}
1016
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tooling::core::init::domain::{Domain, SpatialBoundType, VelocityBoundType};

    /// Number of cells in the 8 x 8 x 8 spatial mesh of [`test_domain`].
    ///
    /// Must be kept in step with the `spatial_resolution(8)` call there and
    /// with the `[8, 8, 8]` shapes used by the mesh fixtures below.
    const MESH_CELLS: usize = 8 * 8 * 8;

    /// Small test domain: spatial and velocity extent 1.0, resolution 8 on
    /// both grids, periodic in space, open in velocity, `t_final = 1.0`.
    fn test_domain() -> Domain {
        Domain::builder()
            .spatial_extent(1.0)
            .velocity_extent(1.0)
            .spatial_resolution(8)
            .velocity_resolution(8)
            .t_final(1.0)
            .spatial_bc(SpatialBoundType::Periodic)
            .velocity_bc(VelocityBoundType::Open)
            .build()
            .unwrap()
    }

    /// All-zero displacement field on the 8 x 8 x 8 mesh of [`test_domain`].
    fn zero_displacement() -> DisplacementField {
        DisplacementField {
            dx: vec![0.0; MESH_CELLS],
            dy: vec![0.0; MESH_CELLS],
            dz: vec![0.0; MESH_CELLS],
            shape: [8, 8, 8],
        }
    }

    /// Create a `FlowMapRepr` with a Gaussian IC in both position and velocity
    /// (`sigma = 0.3` in each).
    ///
    /// Each tracer gets `f0 = exp(-r²/2σx²) · exp(-v²/2σv²)` evaluated at its
    /// position and velocity as laid out by `FlowMapRepr::new`, and a mass of
    /// `f0` times the Lagrangian phase-space cell volume. Only the
    /// representation is returned; the summed mass and the entropy
    /// `-Σ f0 ln f0 · dV` are stored in its `total_mass_cached` and
    /// `entropy_cached` fields.
    fn gaussian_flow_map(domain: &Domain, n_lag: usize, nv_lag: usize) -> FlowMapRepr {
        let lx = domain.lx();
        let lv = domain.lv();
        // Lagrangian cell widths: the full box width 2L split into n cells.
        let dq = [
            2.0 * lx[0] / n_lag as f64,
            2.0 * lx[1] / n_lag as f64,
            2.0 * lx[2] / n_lag as f64,
        ];
        let dp = [
            2.0 * lv[0] / nv_lag as f64,
            2.0 * lv[1] / nv_lag as f64,
            2.0 * lv[2] / nv_lag as f64,
        ];
        let phase_vol = dq[0] * dq[1] * dq[2] * dp[0] * dp[1] * dp[2];

        let mut repr = FlowMapRepr::new(domain, n_lag, nv_lag);
        let n = repr.num_tracers();
        let sigma_x = 0.3;
        let sigma_v = 0.3;

        let mut total_mass = 0.0;
        let mut entropy = 0.0;

        for i in 0..n {
            let x = repr.positions[3 * i];
            let y = repr.positions[3 * i + 1];
            let z = repr.positions[3 * i + 2];
            let vx = repr.velocities[3 * i];
            let vy = repr.velocities[3 * i + 1];
            let vz = repr.velocities[3 * i + 2];

            let r2 = x * x + y * y + z * z;
            let v2 = vx * vx + vy * vy + vz * vz;
            let f0 =
                (-r2 / (2.0 * sigma_x * sigma_x)).exp() * (-v2 / (2.0 * sigma_v * sigma_v)).exp();

            repr.f0_values[i] = f0;
            repr.masses[i] = f0 * phase_vol;
            total_mass += repr.masses[i];
            // Guard against ln(0) if the Gaussian underflows to zero.
            if f0 > 0.0 {
                entropy -= f0 * f0.ln() * phase_vol;
            }
        }

        repr.total_mass_cached = total_mass;
        repr.entropy_cached = entropy;
        repr
    }

    /// A drift step must move each tracer by `v * dt`, modulo the periodic box.
    #[test]
    fn test_flow_map_free_streaming() {
        let domain = test_domain();
        let n_lag = 4;
        let nv_lag = 4;
        let mut repr = gaussian_flow_map(&domain, n_lag, nv_lag);
        let n = repr.num_tracers();

        // Record the initial state to compare against after the drift.
        let x0: Vec<f64> = repr.positions.clone();
        let v0: Vec<f64> = repr.velocities.clone();

        let dt = 0.1;
        repr.advect_x(&zero_displacement(), dt);

        let lx = domain.lx();
        for i in 0..n {
            for k in 0..3 {
                let expected = x0[3 * i + k] + v0[3 * i + k] * dt;
                // With periodic wrapping
                let two_l = 2.0 * lx[k];
                let wrapped = ((expected + lx[k]).rem_euclid(two_l)) - lx[k];
                assert!(
                    (repr.positions[3 * i + k] - wrapped).abs() < 1e-12,
                    "tracer {i} dim {k}: expected {wrapped}, got {}",
                    repr.positions[3 * i + k]
                );
            }
        }
    }

    /// `total_mass()` must be unchanged by a drift and by a kick.
    #[test]
    fn test_flow_map_mass_conservation() {
        let domain = test_domain();
        let mut repr = gaussian_flow_map(&domain, 4, 4);

        let mass_before = repr.total_mass();
        assert!(mass_before > 0.0, "initial mass must be positive");

        // Advect in x
        repr.advect_x(&zero_displacement(), 0.1);

        let mass_after = repr.total_mass();
        assert!(
            (mass_after - mass_before).abs() < 1e-14,
            "mass must be conserved: before={mass_before}, after={mass_after}"
        );

        // Advect in v with a uniform acceleration field
        let acc = AccelerationField {
            gx: vec![0.1; MESH_CELLS],
            gy: vec![0.0; MESH_CELLS],
            gz: vec![0.0; MESH_CELLS],
            shape: [8, 8, 8],
        };
        repr.advect_v(&acc, 0.05);

        let mass_after_kick = repr.total_mass();
        assert!(
            (mass_after_kick - mass_before).abs() < 1e-14,
            "mass must be conserved after kick: before={mass_before}, after={mass_after_kick}"
        );
    }

    /// The deposited density must integrate to the total mass (within 5%),
    /// be non-negative everywhere, and be positive at the box centre.
    #[test]
    fn test_flow_map_density_recovery() {
        let domain = test_domain();
        let repr = gaussian_flow_map(&domain, 6, 6);

        let density = repr.compute_density();
        let cell_vol = domain.cell_volume_3d();

        // Total mass from density field should match total_mass()
        let mass_from_density: f64 = density.data.iter().sum::<f64>() * cell_vol;
        let mass_from_repr = repr.total_mass();

        assert!(
            (mass_from_density - mass_from_repr).abs() / mass_from_repr.max(1e-15) < 0.05,
            "density-integrated mass ({mass_from_density}) should match total_mass ({mass_from_repr})"
        );

        // Density should be positive or zero everywhere
        for (i, &rho) in density.data.iter().enumerate() {
            assert!(
                rho >= 0.0,
                "density must be non-negative at cell {i}, got {rho}"
            );
        }

        // Peak density should be near the center (Gaussian)
        let [nx, ny, nz] = density.shape;
        let center = (nx / 2) * ny * nz + (ny / 2) * nz + nz / 2;
        let center_rho = density.data[center];
        assert!(
            center_rho > 0.0,
            "center density should be positive for a Gaussian IC"
        );
    }

    /// `total_kinetic_energy()` must agree with a plain serial sum of
    /// `0.5 * m * |v|^2` over all tracers.
    #[test]
    fn test_flow_map_kinetic_energy() {
        let domain = test_domain();
        let repr = gaussian_flow_map(&domain, 4, 4);
        let n = repr.num_tracers();

        // Compute expected KE manually
        let mut expected_ke = 0.0;
        for i in 0..n {
            let vx = repr.velocities[3 * i];
            let vy = repr.velocities[3 * i + 1];
            let vz = repr.velocities[3 * i + 2];
            expected_ke += 0.5 * repr.masses[i] * (vx * vx + vy * vy + vz * vz);
        }

        let ke = repr.total_kinetic_energy().unwrap();
        assert!(
            (ke - expected_ke).abs() < 1e-14,
            "kinetic energy mismatch: expected {expected_ke}, got {ke}"
        );
        assert!(ke >= 0.0, "kinetic energy must be non-negative");
    }
}