caustic/tooling/core/init/
mergers.rs1use super::super::types::PhaseSpaceSnapshot;
10use super::domain::Domain;
11use super::isolated::IsolatedEquilibrium;
12use rayon::prelude::*;
13use rust_decimal::Decimal;
14use rust_decimal::prelude::ToPrimitive;
15
16pub struct MergerIC {
21 pub body1: Box<dyn IsolatedEquilibrium>,
23 pub mass1: Decimal,
25 pub body2: Box<dyn IsolatedEquilibrium>,
27 pub mass2: Decimal,
29 pub separation: [f64; 3],
31 pub relative_velocity: [f64; 3],
33 pub impact_parameter: Decimal,
35 mass1_f64: f64,
37 mass2_f64: f64,
38 impact_parameter_f64: f64,
39}
40
41impl MergerIC {
42 pub fn new(
44 body1: Box<dyn IsolatedEquilibrium>,
45 mass1: f64,
46 body2: Box<dyn IsolatedEquilibrium>,
47 mass2: f64,
48 separation: [f64; 3],
49 relative_velocity: [f64; 3],
50 impact_parameter: f64,
51 ) -> Self {
52 Self {
53 body1,
54 mass1: Decimal::from_f64_retain(mass1).unwrap_or(Decimal::ZERO),
55 body2,
56 mass2: Decimal::from_f64_retain(mass2).unwrap_or(Decimal::ZERO),
57 separation,
58 relative_velocity,
59 impact_parameter: Decimal::from_f64_retain(impact_parameter).unwrap_or(Decimal::ZERO),
60 mass1_f64: mass1,
61 mass2_f64: mass2,
62 impact_parameter_f64: impact_parameter,
63 }
64 }
65
66 pub fn new_decimal(
68 body1: Box<dyn IsolatedEquilibrium>,
69 mass1: Decimal,
70 body2: Box<dyn IsolatedEquilibrium>,
71 mass2: Decimal,
72 separation: [f64; 3],
73 relative_velocity: [f64; 3],
74 impact_parameter: Decimal,
75 ) -> Self {
76 Self {
77 body1,
78 mass1_f64: mass1.to_f64().unwrap_or(0.0),
79 body2,
80 mass2_f64: mass2.to_f64().unwrap_or(0.0),
81 separation,
82 relative_velocity,
83 impact_parameter_f64: impact_parameter.to_f64().unwrap_or(0.0),
84 mass1,
85 mass2,
86 impact_parameter,
87 }
88 }
89
90 pub fn sample_on_grid(
94 &self,
95 domain: &Domain,
96 progress: Option<&crate::tooling::core::progress::StepProgress>,
97 ) -> PhaseSpaceSnapshot {
98 let nx1 = domain.spatial_res.x1 as usize;
99 let nx2 = domain.spatial_res.x2 as usize;
100 let nx3 = domain.spatial_res.x3 as usize;
101 let nv1 = domain.velocity_res.v1 as usize;
102 let nv2 = domain.velocity_res.v2 as usize;
103 let nv3 = domain.velocity_res.v3 as usize;
104
105 let dx = domain.dx();
106 let dv = domain.dv();
107 let lx = domain.lx();
108 let lv = domain.lv();
109
110 let x1_offset = [
112 -self.separation[0] / 2.0,
113 -self.separation[1] / 2.0,
114 -self.separation[2] / 2.0,
115 ];
116 let x2_offset = [
117 self.separation[0] / 2.0,
118 self.separation[1] / 2.0,
119 self.separation[2] / 2.0,
120 ];
121 let v1_offset = [
122 -self.relative_velocity[0] / 2.0,
123 -self.relative_velocity[1] / 2.0,
124 -self.relative_velocity[2] / 2.0,
125 ];
126 let v2_offset = [
127 self.relative_velocity[0] / 2.0,
128 self.relative_velocity[1] / 2.0,
129 self.relative_velocity[2] / 2.0,
130 ];
131
132 let s_v3 = 1usize;
133 let s_v2 = nv3;
134 let s_v1 = nv2 * nv3;
135 let s_x3 = nv1 * s_v1;
136 let s_x2 = nx3 * s_x3;
137 let s_x1 = nx2 * s_x2;
138
139 let total = nx1 * nx2 * nx3 * nv1 * nv2 * nv3;
140 let mut data = vec![0.0f64; total];
141
142 let counter = std::sync::atomic::AtomicU64::new(0);
143 let report_interval = (nx1 / 100).max(1) as u64;
144
145 if let Some(p) = progress {
147 p.set_intra_progress(0, nx1 as u64);
148 }
149
150 data.par_chunks_mut(s_x1)
151 .enumerate()
152 .for_each(|(ix1, chunk)| {
153 let x1 = -lx[0] + (ix1 as f64 + 0.5) * dx[0];
154 for ix2 in 0..nx2 {
155 let x2 = -lx[1] + (ix2 as f64 + 0.5) * dx[1];
156 for ix3 in 0..nx3 {
157 let x3 = -lx[2] + (ix3 as f64 + 0.5) * dx[2];
158 let base = ix2 * s_x2 + ix3 * s_x3;
159
160 let dx1 = x1 - x1_offset[0];
162 let dy1 = x2 - x1_offset[1];
163 let dz1 = x3 - x1_offset[2];
164 let r1 = (dx1 * dx1 + dy1 * dy1 + dz1 * dz1).sqrt();
165 let phi1 = self.body1.potential(r1);
166
167 let dx2 = x1 - x2_offset[0];
169 let dy2 = x2 - x2_offset[1];
170 let dz2 = x3 - x2_offset[2];
171 let r2 = (dx2 * dx2 + dy2 * dy2 + dz2 * dz2).sqrt();
172 let phi2 = self.body2.potential(r2);
173
174 for iv1 in 0..nv1 {
175 let v1 = -lv[0] + (iv1 as f64 + 0.5) * dv[0];
176 for iv2 in 0..nv2 {
177 let v2 = -lv[1] + (iv2 as f64 + 0.5) * dv[1];
178 for iv3 in 0..nv3 {
179 let v3 = -lv[2] + (iv3 as f64 + 0.5) * dv[2];
180 let idx = base + iv1 * s_v1 + iv2 * s_v2 + iv3 * s_v3;
181
182 let dv1_1 = v1 - v1_offset[0];
184 let dv1_2 = v2 - v1_offset[1];
185 let dv1_3 = v3 - v1_offset[2];
186 let e1 = 0.5 * (dv1_1 * dv1_1 + dv1_2 * dv1_2 + dv1_3 * dv1_3)
187 + phi1;
188 let f1 = self.body1.distribution_function(e1, 0.0).max(0.0);
189
190 let dv2_1 = v1 - v2_offset[0];
192 let dv2_2 = v2 - v2_offset[1];
193 let dv2_3 = v3 - v2_offset[2];
194 let e2 = 0.5 * (dv2_1 * dv2_1 + dv2_2 * dv2_2 + dv2_3 * dv2_3)
195 + phi2;
196 let f2 = self.body2.distribution_function(e2, 0.0).max(0.0);
197
198 chunk[idx] = f1 + f2;
199 }
200 }
201 }
202 }
203 }
204
205 if let Some(p) = progress {
206 let c = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
207 if c.is_multiple_of(report_interval) {
208 p.set_intra_progress(c, nx1 as u64);
209 }
210 }
211 });
212
213 PhaseSpaceSnapshot {
214 data,
215 shape: [nx1, nx2, nx3, nv1, nv2, nv3],
216 time: 0.0,
217 }
218 }
219
220 pub fn validate(&self, domain: &Domain) -> anyhow::Result<()> {
222 let lx = domain.lx()[0];
223 let sep_max = self
224 .separation
225 .iter()
226 .map(|s| s.abs())
227 .fold(0.0_f64, f64::max);
228 if sep_max / 2.0 > lx * 0.9 {
229 anyhow::bail!(
230 "Merger separation {:.2} exceeds 90% of domain half-extent {:.2}",
231 sep_max,
232 lx
233 );
234 }
235 Ok(())
236 }
237}