1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Marker attached to a [`TrajectorySample`] to call out a notable event.
///
/// Each variant has a stable snake_case label, available without allocation
/// through [`TrajectoryFlag::as_str`] and through the `Display` impl (and
/// therefore through `to_string()`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The sample's drop is within tolerance of zero, or the drop reached or
    /// crossed zero between the previous sample and this one.
    ZeroCrossing,
    /// The sample lies closest to one of the recorded transonic distances.
    MachTransition,
    /// The sample has the smallest drop value among those up to the target
    /// distance.
    Apex,
}

impl TrajectoryFlag {
    /// Returns the snake_case label of this flag as a static string slice.
    pub fn as_str(&self) -> &'static str {
        match self {
            TrajectoryFlag::ZeroCrossing => "zero_crossing",
            TrajectoryFlag::MachTransition => "mach_transition",
            TrajectoryFlag::Apex => "apex",
        }
    }
}

/// Formats the flag as its snake_case label.
///
/// Providing `Display` instead of an inherent `to_string` keeps
/// `flag.to_string()` working (via the blanket `ToString` impl) while also
/// making the flag usable with `format!`, `write!` and friends.
impl std::fmt::Display for TrajectoryFlag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/alignment specifiers such as `{:>16}`.
        f.pad(self.as_str())
    }
}
21
/// One row of the resampled trajectory table built by `sample_trajectory`.
///
/// Units are those suggested by the field-name suffixes (`_m`, `_mps`, `_j`,
/// `_s`); nothing in this file converts or validates them.
#[derive(Debug, Clone)]
pub struct TrajectorySample {
    /// Horizontal distance at which the trajectory was sampled.
    pub distance_m: f64,
    /// Line-of-sight height minus the interpolated trajectory height at this
    /// distance (so a trajectory above the line of sight gives a negative value).
    pub drop_m: f64,
    /// Interpolated `z` component of the trajectory position.
    pub wind_drift_m: f64,
    /// Interpolated speed, i.e. the norm of the velocity vector.
    pub velocity_mps: f64,
    /// Kinetic energy, interpolated from the per-point values
    /// `0.5 * mass * speed^2`.
    pub energy_j: f64,
    /// Interpolated time value at this distance.
    pub time_s: f64,
    /// Event markers attached to this sample; empty until flagging runs.
    pub flags: Vec<TrajectoryFlag>,
}
33
/// Raw trajectory points consumed by `sample_trajectory`.
///
/// `times`, `positions` and `velocities` are used as parallel arrays indexed
/// by trajectory point. If their lengths differ, the affected quantity is
/// reported as `0.0` by the interpolation helper rather than raising an error.
#[derive(Debug, Clone)]
pub struct TrajectoryData {
    /// Time value of each trajectory point.
    pub times: Vec<f64>,
    /// Position of each trajectory point. `x` is used as the horizontal
    /// distance (the interpolation abscissa), `y` as the height and `z` as the
    /// wind drift.
    pub positions: Vec<Vector3<f64>>,
    /// Velocity of each trajectory point; only its norm (speed) is used here.
    pub velocities: Vec<Vector3<f64>>,
    /// Distances at which a `MachTransition` flag is attached to the nearest
    /// sample.
    // NOTE(review): presumably the downrange distances where the projectile
    // crosses Mach 1 — confirm with the producer of this data.
    pub transonic_distances: Vec<f64>,
}
42
/// Summary values of a solved trajectory that steer the resampling.
#[derive(Debug, Clone)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target. `sample_trajectory` samples up to
    /// this distance and uses it as the far end of the line of sight.
    pub target_distance_horiz_m: f64,
    /// Height of the line of sight at the target distance.
    pub target_vertical_height_m: f64,
    /// Not read anywhere in this file.
    // NOTE(review): presumably the total time of flight to the target — confirm.
    pub time_of_flight_s: f64,
    /// Not read anywhere in this file.
    // NOTE(review): presumably the horizontal distance of the maximum ordinate
    // (trajectory apex) — confirm.
    pub max_ord_dist_horiz_m: f64,
}
51
52pub fn sample_trajectory(
54 trajectory_data: &TrajectoryData,
55 outputs: &TrajectoryOutputs,
56 step_m: f64,
57 mass_kg: f64,
58) -> Vec<TrajectorySample> {
59 let step_size = if step_m <= 0.0 {
60 return Vec::new();
61 } else if step_m < 0.1 {
62 0.1
63 } else {
64 step_m
65 };
66
67 let max_dist = outputs.target_distance_horiz_m;
69 if max_dist < 1e-9 {
70 return Vec::new();
71 }
72
73 let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
75 let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
76 let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
77
78 let speeds: Vec<f64> = trajectory_data
80 .velocities
81 .iter()
82 .map(|v| v.norm())
83 .collect();
84 let energies: Vec<f64> = speeds
85 .iter()
86 .map(|&speed| 0.5 * mass_kg * speed * speed)
87 .collect();
88
89 let num_steps = (max_dist / step_size).ceil() as usize + 1;
92 let distances: Vec<f64> = (0..num_steps)
93 .map(|i| i as f64 * step_size)
94 .filter(|&d| d <= max_dist + 0.1) .collect();
96
97 let mut samples = Vec::with_capacity(distances.len());
99
100 let muzzle_y = if !y_vals.is_empty() { y_vals[0] } else { 0.0 };
102
103 for &distance in &distances {
104 let y_interp = interpolate(&x_vals, &y_vals, distance); let wind_drift = interpolate(&x_vals, &z_vals, distance); let velocity = interpolate(&x_vals, &speeds, distance); let time = interpolate(&x_vals, &trajectory_data.times, distance); let energy = interpolate(&x_vals, &energies, distance); let los_y = muzzle_y + (outputs.target_vertical_height_m - muzzle_y) * distance / max_dist;
119 let drop = los_y - y_interp; samples.push(TrajectorySample {
122 distance_m: distance,
123 drop_m: drop,
124 wind_drift_m: wind_drift,
125 velocity_mps: velocity,
126 energy_j: energy,
127 time_s: time,
128 flags: Vec::new(), });
130 }
131
132 add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
134
135 samples
136}
137
/// Piecewise-linear lookup of `y` at `x` over the points `(x_vals[i], y_vals[i])`.
///
/// `x_vals` is searched by bisection, so it is expected to be in ascending
/// order. Queries at or beyond either end clamp to the first/last `y` value.
/// Empty inputs, or inputs of different lengths, yield `0.0`.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let n = x_vals.len();
    // Covers "either slice empty" as well as "lengths differ".
    if n == 0 || n != y_vals.len() {
        return 0.0;
    }

    let last = n - 1;
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[last] {
        return y_vals[last];
    }

    // Narrow [lo, hi] until the two indices are adjacent and bracket `x`.
    let (mut lo, mut hi) = (0usize, last);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x_lo, x_hi) = (x_vals[lo], x_vals[hi]);
    let (y_lo, y_hi) = (y_vals[lo], y_vals[hi]);

    let span = x_hi - x_lo;
    if span.abs() < f64::EPSILON {
        // Degenerate segment: avoid dividing by (almost) zero.
        y_lo
    } else {
        y_lo + (y_hi - y_lo) * (x - x_lo) / span
    }
}
181
182fn add_trajectory_flags(
184 samples: &mut [TrajectorySample],
185 transonic_distances: &[f64],
186 target_distance_input_m: f64,
187) {
188 let tolerance = 1e-6;
189
190 detect_zero_crossings(samples, tolerance);
192
193 for &transonic_dist in transonic_distances {
195 if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
196 samples[idx].flags.push(TrajectoryFlag::MachTransition);
197 }
198 }
199
200 if samples.len() > 2 {
204 let target_distance_m = target_distance_input_m;
206
207 let mut min_drop = f64::INFINITY;
210 let mut apex_idx = 1;
211
212 for i in 1..samples.len() {
214 if samples[i].distance_m > target_distance_m {
216 break;
217 }
218
219 if samples[i].drop_m < min_drop {
220 min_drop = samples[i].drop_m;
221 apex_idx = i;
222 }
223 }
224
225 samples[apex_idx].flags.push(TrajectoryFlag::Apex);
227 }
228}
229
230fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
232 if samples.len() < 2 {
233 return;
234 }
235
236 let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
237
238 for i in 0..(drops.len() - 1) {
240 let current = drops[i];
241 let next = drops[i + 1];
242
243 let crosses_zero = (current < -tolerance && next >= -tolerance)
245 || (current > tolerance && next <= tolerance);
246
247 if crosses_zero {
248 samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
249 }
250 }
251
252 for (i, &drop) in drops.iter().enumerate() {
254 if drop.abs() <= tolerance {
255 samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
256 }
257 }
258
259 for sample in samples.iter_mut() {
261 let mut unique_flags = Vec::new();
262 let mut seen = HashSet::new();
263
264 for flag in &sample.flags {
265 if seen.insert(flag.clone()) {
266 unique_flags.push(flag.clone());
267 }
268 }
269 sample.flags = unique_flags;
270 }
271}
272
273fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
275 if samples.is_empty() {
276 return None;
277 }
278
279 let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
281
282 let mut left = 0;
283 let mut right = distances.len();
284
285 while left < right {
286 let mid = (left + right) / 2;
287 if distances[mid] < target_distance {
288 left = mid + 1;
289 } else {
290 right = mid;
291 }
292 }
293
294 let mut best_idx = left.min(distances.len() - 1);
296
297 if left > 0 {
298 let left_dist = (distances[left - 1] - target_distance).abs();
299 let right_dist = (distances[best_idx] - target_distance).abs();
300
301 if left_dist <= right_dist {
303 best_idx = left - 1;
304 }
305 }
306
307 Some(best_idx)
308}
309
310pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
312 samples
313 .iter()
314 .map(|sample| TrajectoryDict {
315 distance_m: sample.distance_m,
316 drop_m: sample.drop_m,
317 wind_drift_m: sample.wind_drift_m,
318 velocity_mps: sample.velocity_mps,
319 energy_j: sample.energy_j,
320 time_s: sample.time_s,
321 flags: sample.flags.iter().map(|f| f.to_string()).collect(),
322 })
323 .collect()
324}
325
/// Flattened form of a `TrajectorySample` in which the flags are carried as
/// strings instead of `TrajectoryFlag` values.
///
/// Produced by `trajectory_samples_to_dicts`; the numeric fields are copied
/// from the source sample unchanged.
#[derive(Debug, Clone)]
pub struct TrajectoryDict {
    /// Copied from `TrajectorySample::distance_m`.
    pub distance_m: f64,
    /// Copied from `TrajectorySample::drop_m`.
    pub drop_m: f64,
    /// Copied from `TrajectorySample::wind_drift_m`.
    pub wind_drift_m: f64,
    /// Copied from `TrajectorySample::velocity_mps`.
    pub velocity_mps: f64,
    /// Copied from `TrajectorySample::energy_j`.
    pub energy_j: f64,
    /// Copied from `TrajectorySample::time_s`.
    pub time_s: f64,
    /// String form of each flag on the source sample, in the same order.
    pub flags: Vec<String>,
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unflagged sample from its six numeric columns.
    fn sample(
        distance_m: f64,
        drop_m: f64,
        wind_drift_m: f64,
        velocity_mps: f64,
        energy_j: f64,
        time_s: f64,
    ) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m,
            velocity_mps,
            energy_j,
            time_s,
            flags: Vec::new(),
        }
    }

    /// Three samples at 0, 10 and 20 with the given drop values.
    fn three_samples(drops: [f64; 3]) -> Vec<TrajectorySample> {
        vec![
            sample(0.0, drops[0], 0.0, 100.0, 1000.0, 0.0),
            sample(10.0, drops[1], 0.1, 95.0, 950.0, 0.1),
            sample(20.0, drops[2], 0.2, 90.0, 900.0, 0.2),
        ]
    }

    #[test]
    fn test_interpolate() {
        let x_vals = vec![0.0, 1.0, 2.0, 3.0];
        let y_vals = vec![0.0, 10.0, 20.0, 30.0];

        // (query, expected): three interior points, then one query past each
        // end, which clamps to the boundary value.
        let cases = [
            (0.5, 5.0),
            (1.5, 15.0),
            (2.5, 25.0),
            (-1.0, 0.0),
            (4.0, 30.0),
        ];

        for &(x, expected) in cases.iter() {
            assert_eq!(interpolate(&x_vals, &y_vals, x), expected);
        }
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = three_samples([0.0, -1.0, -4.0]);

        // 5.0 is equidistant from the first two samples; the earlier one wins.
        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
    }

    #[test]
    fn test_detect_zero_crossings() {
        // The drop changes sign between the first and second sample.
        let mut samples = three_samples([1.0, -0.5, -2.0]);

        detect_zero_crossings(&mut samples, 1e-6);

        let crossing = TrajectoryFlag::ZeroCrossing;
        assert!(!samples[0].flags.contains(&crossing));
        assert!(samples[1].flags.contains(&crossing));
        assert!(!samples[2].flags.contains(&crossing));
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let trajectory_data = TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),
                Vector3::new(100.0, 10.0, 1.0),
                Vector3::new(200.0, 5.0, 2.0),
            ],
            velocities: vec![
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(1.0, 5.0, 95.0),
                Vector3::new(1.0, 0.0, 90.0),
            ],
            transonic_distances: vec![150.0],
        };

        let outputs = TrajectoryOutputs {
            target_distance_horiz_m: 200.0,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
        };

        let samples = sample_trajectory(&trajectory_data, &outputs, 50.0, 0.1);

        let expected_distances = [0.0, 50.0, 100.0, 150.0, 200.0];
        assert_eq!(samples.len(), 5);
        for (sampled, &expected) in samples.iter().zip(expected_distances.iter()) {
            assert_eq!(sampled.distance_m, expected);
        }

        let speed_at_50 = samples[1].velocity_mps;
        assert!(speed_at_50 > 90.0 && speed_at_50 < 100.0);

        // Highest point is at 100; the transonic distance 150 maps to index 3.
        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex));
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition));
    }
}