Skip to main content

caustic/tooling/core/init/
mergers.rs

1//! Two-body merger / interaction initial conditions.
2//!
3//! Sets up a binary merger by superposing two isolated equilibria (e.g. Plummer + Plummer,
4//! King + Hernquist) with offset positions and relative velocities in the centre-of-mass
5//! frame. The combined distribution function is f(x,v) = f1(x-x1, v-v1) + f2(x-x2, v-v2),
6//! which is exact for collisionless systems before the halos begin to overlap. An impact
7//! parameter controls whether the encounter is head-on or off-axis.
8
9use super::super::types::PhaseSpaceSnapshot;
10use super::domain::Domain;
11use super::isolated::IsolatedEquilibrium;
12use rayon::prelude::*;
13use rust_decimal::Decimal;
14use rust_decimal::prelude::ToPrimitive;
15
16/// Binary merger initial conditions: two displaced and boosted equilibria.
17///
18/// The combined distribution is f(x,v) = f1(x-x1, v-v1) + f2(x-x2, v-v2),
19/// exact for collisionless systems before the halos begin to interact.
20pub struct MergerIC {
21    /// First body's equilibrium distribution (e.g. Plummer, King, NFW).
22    pub body1: Box<dyn IsolatedEquilibrium>,
23    /// Total mass of the first body.
24    pub mass1: Decimal,
25    /// Second body's equilibrium distribution.
26    pub body2: Box<dyn IsolatedEquilibrium>,
27    /// Total mass of the second body.
28    pub mass2: Decimal,
29    /// Initial separation vector between the two centres (body2 - body1).
30    pub separation: [f64; 3],
31    /// Relative velocity of body 2 with respect to body 1.
32    pub relative_velocity: [f64; 3],
33    /// Transverse offset for off-axis encounters; zero gives a head-on merger.
34    pub impact_parameter: Decimal,
35    // Cached f64 values for hot-path computation
36    mass1_f64: f64,
37    mass2_f64: f64,
38    impact_parameter_f64: f64,
39}
40
41impl MergerIC {
42    /// Create a MergerIC from f64 parameters (backward-compatible).
43    pub fn new(
44        body1: Box<dyn IsolatedEquilibrium>,
45        mass1: f64,
46        body2: Box<dyn IsolatedEquilibrium>,
47        mass2: f64,
48        separation: [f64; 3],
49        relative_velocity: [f64; 3],
50        impact_parameter: f64,
51    ) -> Self {
52        Self {
53            body1,
54            mass1: Decimal::from_f64_retain(mass1).unwrap_or(Decimal::ZERO),
55            body2,
56            mass2: Decimal::from_f64_retain(mass2).unwrap_or(Decimal::ZERO),
57            separation,
58            relative_velocity,
59            impact_parameter: Decimal::from_f64_retain(impact_parameter).unwrap_or(Decimal::ZERO),
60            mass1_f64: mass1,
61            mass2_f64: mass2,
62            impact_parameter_f64: impact_parameter,
63        }
64    }
65
66    /// Create a MergerIC from Decimal parameters (exact config).
67    pub fn new_decimal(
68        body1: Box<dyn IsolatedEquilibrium>,
69        mass1: Decimal,
70        body2: Box<dyn IsolatedEquilibrium>,
71        mass2: Decimal,
72        separation: [f64; 3],
73        relative_velocity: [f64; 3],
74        impact_parameter: Decimal,
75    ) -> Self {
76        Self {
77            body1,
78            mass1_f64: mass1.to_f64().unwrap_or(0.0),
79            body2,
80            mass2_f64: mass2.to_f64().unwrap_or(0.0),
81            separation,
82            relative_velocity,
83            impact_parameter_f64: impact_parameter.to_f64().unwrap_or(0.0),
84            mass1,
85            mass2,
86            impact_parameter,
87        }
88    }
89
90    /// Sample both components on the grid and sum.
91    /// Body 1 is centred at (-sep/2, 0, 0) with velocity (-v_rel/2, 0, 0).
92    /// Body 2 is centred at (+sep/2, 0, 0) with velocity (+v_rel/2, 0, 0).
93    pub fn sample_on_grid(
94        &self,
95        domain: &Domain,
96        progress: Option<&crate::tooling::core::progress::StepProgress>,
97    ) -> PhaseSpaceSnapshot {
98        let nx1 = domain.spatial_res.x1 as usize;
99        let nx2 = domain.spatial_res.x2 as usize;
100        let nx3 = domain.spatial_res.x3 as usize;
101        let nv1 = domain.velocity_res.v1 as usize;
102        let nv2 = domain.velocity_res.v2 as usize;
103        let nv3 = domain.velocity_res.v3 as usize;
104
105        let dx = domain.dx();
106        let dv = domain.dv();
107        let lx = domain.lx();
108        let lv = domain.lv();
109
110        // Centre-of-mass frame: body 1 at -sep/2, body 2 at +sep/2
111        let x1_offset = [
112            -self.separation[0] / 2.0,
113            -self.separation[1] / 2.0,
114            -self.separation[2] / 2.0,
115        ];
116        let x2_offset = [
117            self.separation[0] / 2.0,
118            self.separation[1] / 2.0,
119            self.separation[2] / 2.0,
120        ];
121        let v1_offset = [
122            -self.relative_velocity[0] / 2.0,
123            -self.relative_velocity[1] / 2.0,
124            -self.relative_velocity[2] / 2.0,
125        ];
126        let v2_offset = [
127            self.relative_velocity[0] / 2.0,
128            self.relative_velocity[1] / 2.0,
129            self.relative_velocity[2] / 2.0,
130        ];
131
132        let s_v3 = 1usize;
133        let s_v2 = nv3;
134        let s_v1 = nv2 * nv3;
135        let s_x3 = nv1 * s_v1;
136        let s_x2 = nx3 * s_x3;
137        let s_x1 = nx2 * s_x2;
138
139        let total = nx1 * nx2 * nx3 * nv1 * nv2 * nv3;
140        let mut data = vec![0.0f64; total];
141
142        let counter = std::sync::atomic::AtomicU64::new(0);
143        let report_interval = (nx1 / 100).max(1) as u64;
144
145        // Establish 0% baseline so the TUI doesn't jump to a non-zero first value
146        if let Some(p) = progress {
147            p.set_intra_progress(0, nx1 as u64);
148        }
149
150        data.par_chunks_mut(s_x1)
151            .enumerate()
152            .for_each(|(ix1, chunk)| {
153                let x1 = -lx[0] + (ix1 as f64 + 0.5) * dx[0];
154                for ix2 in 0..nx2 {
155                    let x2 = -lx[1] + (ix2 as f64 + 0.5) * dx[1];
156                    for ix3 in 0..nx3 {
157                        let x3 = -lx[2] + (ix3 as f64 + 0.5) * dx[2];
158                        let base = ix2 * s_x2 + ix3 * s_x3;
159
160                        // Radius from body 1 centre
161                        let dx1 = x1 - x1_offset[0];
162                        let dy1 = x2 - x1_offset[1];
163                        let dz1 = x3 - x1_offset[2];
164                        let r1 = (dx1 * dx1 + dy1 * dy1 + dz1 * dz1).sqrt();
165                        let phi1 = self.body1.potential(r1);
166
167                        // Radius from body 2 centre
168                        let dx2 = x1 - x2_offset[0];
169                        let dy2 = x2 - x2_offset[1];
170                        let dz2 = x3 - x2_offset[2];
171                        let r2 = (dx2 * dx2 + dy2 * dy2 + dz2 * dz2).sqrt();
172                        let phi2 = self.body2.potential(r2);
173
174                        for iv1 in 0..nv1 {
175                            let v1 = -lv[0] + (iv1 as f64 + 0.5) * dv[0];
176                            for iv2 in 0..nv2 {
177                                let v2 = -lv[1] + (iv2 as f64 + 0.5) * dv[1];
178                                for iv3 in 0..nv3 {
179                                    let v3 = -lv[2] + (iv3 as f64 + 0.5) * dv[2];
180                                    let idx = base + iv1 * s_v1 + iv2 * s_v2 + iv3 * s_v3;
181
182                                    // f₁(x−x₁, v−v₁): energy in body 1's rest frame
183                                    let dv1_1 = v1 - v1_offset[0];
184                                    let dv1_2 = v2 - v1_offset[1];
185                                    let dv1_3 = v3 - v1_offset[2];
186                                    let e1 = 0.5 * (dv1_1 * dv1_1 + dv1_2 * dv1_2 + dv1_3 * dv1_3)
187                                        + phi1;
188                                    let f1 = self.body1.distribution_function(e1, 0.0).max(0.0);
189
190                                    // f₂(x−x₂, v−v₂): energy in body 2's rest frame
191                                    let dv2_1 = v1 - v2_offset[0];
192                                    let dv2_2 = v2 - v2_offset[1];
193                                    let dv2_3 = v3 - v2_offset[2];
194                                    let e2 = 0.5 * (dv2_1 * dv2_1 + dv2_2 * dv2_2 + dv2_3 * dv2_3)
195                                        + phi2;
196                                    let f2 = self.body2.distribution_function(e2, 0.0).max(0.0);
197
198                                    chunk[idx] = f1 + f2;
199                                }
200                            }
201                        }
202                    }
203                }
204
205                if let Some(p) = progress {
206                    let c = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
207                    if c.is_multiple_of(report_interval) {
208                        p.set_intra_progress(c, nx1 as u64);
209                    }
210                }
211            });
212
213        PhaseSpaceSnapshot {
214            data,
215            shape: [nx1, nx2, nx3, nv1, nv2, nv3],
216            time: 0.0,
217        }
218    }
219
220    /// Check that both components fit within the domain.
221    pub fn validate(&self, domain: &Domain) -> anyhow::Result<()> {
222        let lx = domain.lx()[0];
223        let sep_max = self
224            .separation
225            .iter()
226            .map(|s| s.abs())
227            .fold(0.0_f64, f64::max);
228        if sep_max / 2.0 > lx * 0.9 {
229            anyhow::bail!(
230                "Merger separation {:.2} exceeds 90% of domain half-extent {:.2}",
231                sep_max,
232                lx
233            );
234        }
235        Ok(())
236    }
237}