1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Event markers attached to individual [`TrajectorySample`]s.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The drop relative to the line of sight reaches (or passes through) zero.
    ZeroCrossing,
    /// The sample nearest to a distance listed in `TrajectoryData::transonic_distances`.
    MachTransition,
    /// The sample with the smallest drop, i.e. highest relative to the line of sight.
    Apex,
}

impl TrajectoryFlag {
    /// Stable snake_case identifier for this flag (e.g. `"zero_crossing"`).
    ///
    /// This is also what the `Display` impl (and therefore `to_string()`) produces.
    pub fn as_str(&self) -> &'static str {
        match self {
            TrajectoryFlag::ZeroCrossing => "zero_crossing",
            TrajectoryFlag::MachTransition => "mach_transition",
            TrajectoryFlag::Apex => "apex",
        }
    }
}

// Implementing `Display` instead of an inherent `to_string` lets the flag be used with
// `{}` formatting; existing `flag.to_string()` calls keep working via `ToString`.
impl std::fmt::Display for TrajectoryFlag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/alignment specifiers such as `{:>16}`.
        f.pad(self.as_str())
    }
}
21
22#[derive(Debug, Clone)]
24pub struct TrajectorySample {
25 pub distance_m: f64,
26 pub drop_m: f64,
27 pub wind_drift_m: f64,
28 pub velocity_mps: f64,
29 pub energy_j: f64,
30 pub time_s: f64,
31 pub flags: Vec<TrajectoryFlag>,
32}
33
34#[derive(Debug, Clone)]
36pub struct TrajectoryData {
37 pub times: Vec<f64>,
38 pub positions: Vec<Vector3<f64>>, pub velocities: Vec<Vector3<f64>>, pub transonic_distances: Vec<f64>, }
42
/// Summary results of a trajectory solution that drive the sampling grid.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target, in metres; sampling stops at this distance.
    pub target_distance_horiz_m: f64,
    /// Height of the target, in metres; the line of sight ends at this height.
    pub target_vertical_height_m: f64,
    /// Total time of flight, in seconds (not read by the sampling code).
    pub time_of_flight_s: f64,
    /// Presumably the horizontal distance of the maximum ordinate, in metres — confirm with
    /// the solver (not read by the sampling code).
    pub max_ord_dist_horiz_m: f64,
}
51
52pub fn sample_trajectory(
54 trajectory_data: &TrajectoryData,
55 outputs: &TrajectoryOutputs,
56 step_m: f64,
57 mass_kg: f64,
58) -> Vec<TrajectorySample> {
59 let step_size = if step_m <= 0.0 {
60 return Vec::new();
61 } else if step_m < 0.1 {
62 0.1
63 } else {
64 step_m
65 };
66
67 let max_dist = outputs.target_distance_horiz_m;
69 if max_dist < 1e-9 {
70 return Vec::new();
71 }
72
73 let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
75 let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
76 let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
77
78 let speeds: Vec<f64> = trajectory_data
80 .velocities
81 .iter()
82 .map(|v| v.norm())
83 .collect();
84 let energies: Vec<f64> = speeds
85 .iter()
86 .map(|&speed| 0.5 * mass_kg * speed * speed)
87 .collect();
88
89 let num_steps = (max_dist / step_size).ceil() as usize + 1;
92 let distances: Vec<f64> = (0..num_steps)
93 .map(|i| i as f64 * step_size)
94 .filter(|&d| d <= max_dist + 0.1) .collect();
96
97 let mut samples = Vec::with_capacity(distances.len());
99
100 let muzzle_y = if !y_vals.is_empty() { y_vals[0] } else { 0.0 };
102
103 for &distance in &distances {
104 let y_interp = interpolate(&z_vals, &y_vals, distance); let wind_drift = interpolate(&z_vals, &x_vals, distance); let velocity = interpolate(&z_vals, &speeds, distance); let time = interpolate(&z_vals, &trajectory_data.times, distance); let energy = interpolate(&z_vals, &energies, distance); let los_y = muzzle_y + (outputs.target_vertical_height_m - muzzle_y) * distance / max_dist;
120 let drop = los_y - y_interp; samples.push(TrajectorySample {
123 distance_m: distance,
124 drop_m: drop,
125 wind_drift_m: wind_drift,
126 velocity_mps: velocity,
127 energy_j: energy,
128 time_s: time,
129 flags: Vec::new(), });
131 }
132
133 add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
135
136 samples
137}
138
/// Piecewise-linear interpolation of `y_vals` over the abscissae `x_vals`, evaluated at `x`.
///
/// `x_vals` is expected to be sorted ascending. Queries at or beyond either end clamp to
/// the corresponding end value.
///
/// Returns `0.0` when either slice is empty or the two slices differ in length. If the
/// bracketing abscissae are (nearly) identical, the left ordinate is returned instead of
/// dividing by ~0.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let n = x_vals.len();
    if n == 0 || y_vals.is_empty() || n != y_vals.len() {
        return 0.0;
    }

    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[n - 1] {
        return y_vals[n - 1];
    }

    // Bisect until `lo` and `hi` are adjacent with x_vals[lo] <= x < x_vals[hi].
    let (mut lo, mut hi) = (0usize, n - 1);
    while hi - lo > 1 {
        let mid = (lo + hi) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x0, x1) = (x_vals[lo], x_vals[hi]);
    let (y0, y1) = (y_vals[lo], y_vals[hi]);
    let dx = x1 - x0;

    if dx.abs() < f64::EPSILON {
        y0
    } else {
        y0 + (y1 - y0) * (x - x0) / dx
    }
}
182
183fn add_trajectory_flags(
185 samples: &mut [TrajectorySample],
186 transonic_distances: &[f64],
187 target_distance_input_m: f64,
188) {
189 let tolerance = 1e-6;
190
191 detect_zero_crossings(samples, tolerance);
193
194 for &transonic_dist in transonic_distances {
196 if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
197 samples[idx].flags.push(TrajectoryFlag::MachTransition);
198 }
199 }
200
201 if samples.len() > 2 {
205 let target_distance_m = target_distance_input_m;
207
208 let mut min_drop = f64::INFINITY;
211 let mut apex_idx = 1;
212
213 for i in 1..samples.len() {
215 if samples[i].distance_m > target_distance_m {
217 break;
218 }
219
220 if samples[i].drop_m < min_drop {
221 min_drop = samples[i].drop_m;
222 apex_idx = i;
223 }
224 }
225
226 samples[apex_idx].flags.push(TrajectoryFlag::Apex);
228 }
229}
230
231fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
233 if samples.len() < 2 {
234 return;
235 }
236
237 let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
238
239 for i in 0..(drops.len() - 1) {
241 let current = drops[i];
242 let next = drops[i + 1];
243
244 let crosses_zero = (current < -tolerance && next >= -tolerance)
246 || (current > tolerance && next <= tolerance);
247
248 if crosses_zero {
249 samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
250 }
251 }
252
253 for (i, &drop) in drops.iter().enumerate() {
255 if drop.abs() <= tolerance {
256 samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
257 }
258 }
259
260 for sample in samples.iter_mut() {
262 let mut unique_flags = Vec::new();
263 let mut seen = HashSet::new();
264
265 for flag in &sample.flags {
266 if seen.insert(flag.clone()) {
267 unique_flags.push(flag.clone());
268 }
269 }
270 sample.flags = unique_flags;
271 }
272}
273
274fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
276 if samples.is_empty() {
277 return None;
278 }
279
280 let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
282
283 let mut left = 0;
284 let mut right = distances.len();
285
286 while left < right {
287 let mid = (left + right) / 2;
288 if distances[mid] < target_distance {
289 left = mid + 1;
290 } else {
291 right = mid;
292 }
293 }
294
295 let mut best_idx = left.min(distances.len() - 1);
297
298 if left > 0 {
299 let left_dist = (distances[left - 1] - target_distance).abs();
300 let right_dist = (distances[best_idx] - target_distance).abs();
301
302 if left_dist <= right_dist {
304 best_idx = left - 1;
305 }
306 }
307
308 Some(best_idx)
309}
310
311pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
313 samples
314 .iter()
315 .map(|sample| TrajectoryDict {
316 distance_m: sample.distance_m,
317 drop_m: sample.drop_m,
318 wind_drift_m: sample.wind_drift_m,
319 velocity_mps: sample.velocity_mps,
320 energy_j: sample.energy_j,
321 time_s: sample.time_s,
322 flags: sample.flags.iter().map(|f| f.to_string()).collect(),
323 })
324 .collect()
325}
326
/// Flattened, serialisation-friendly copy of a [`TrajectorySample`].
///
/// Each flag is stored as its snake_case name (e.g. `"zero_crossing"`).
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryDict {
    /// Downrange (horizontal) distance, in metres.
    pub distance_m: f64,
    /// Distance below the line of sight, in metres (negative = above it).
    pub drop_m: f64,
    /// Lateral drift, in metres.
    pub wind_drift_m: f64,
    /// Speed, in m/s.
    pub velocity_mps: f64,
    /// Kinetic energy, in joules.
    pub energy_j: f64,
    /// Time of flight, in seconds.
    pub time_s: f64,
    /// Flag names attached to this sample.
    pub flags: Vec<String>,
}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds un-flagged samples from
    /// `(distance_m, drop_m, wind_drift_m, velocity_mps, energy_j, time_s)` rows.
    fn samples_from_rows(rows: &[(f64, f64, f64, f64, f64, f64)]) -> Vec<TrajectorySample> {
        rows.iter()
            .map(
                |&(distance_m, drop_m, wind_drift_m, velocity_mps, energy_j, time_s)| {
                    TrajectorySample {
                        distance_m,
                        drop_m,
                        wind_drift_m,
                        velocity_mps,
                        energy_j,
                        time_s,
                        flags: Vec::new(),
                    }
                },
            )
            .collect()
    }

    #[test]
    fn test_interpolate() {
        let xs = [0.0, 1.0, 2.0, 3.0];
        let ys = [0.0, 10.0, 20.0, 30.0];

        // Interior points lie on the straight line y = 10x.
        for (x, expected) in [(0.5, 5.0), (1.5, 15.0), (2.5, 25.0)] {
            assert_eq!(interpolate(&xs, &ys, x), expected);
        }

        // Queries outside [0, 3] clamp to the end values.
        assert_eq!(interpolate(&xs, &ys, -1.0), 0.0);
        assert_eq!(interpolate(&xs, &ys, 4.0), 30.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = samples_from_rows(&[
            (0.0, 0.0, 0.0, 100.0, 1000.0, 0.0),
            (10.0, -1.0, 0.1, 95.0, 950.0, 0.1),
            (20.0, -4.0, 0.2, 90.0, 900.0, 0.2),
        ]);

        for (query, expected) in [(5.0, 0), (12.0, 1), (18.0, 2)] {
            assert_eq!(find_closest_sample_index(&samples, query), Some(expected));
        }
    }

    #[test]
    fn test_detect_zero_crossings() {
        // The drop changes sign between the first and second samples only.
        let mut samples = samples_from_rows(&[
            (0.0, 1.0, 0.0, 100.0, 1000.0, 0.0),
            (10.0, -0.5, 0.1, 95.0, 950.0, 0.1),
            (20.0, -2.0, 0.2, 90.0, 900.0, 0.2),
        ]);

        detect_zero_crossings(&mut samples, 1e-6);

        let crossed: Vec<bool> = samples
            .iter()
            .map(|s| s.flags.contains(&TrajectoryFlag::ZeroCrossing))
            .collect();
        assert_eq!(crossed, vec![false, true, false]);
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let trajectory_data = TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(2.0, 5.0, 200.0),
            ],
            velocities: vec![
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(1.0, 5.0, 95.0),
                Vector3::new(1.0, 0.0, 90.0),
            ],
            transonic_distances: vec![150.0],
        };

        let outputs = TrajectoryOutputs {
            target_distance_horiz_m: 200.0,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
        };

        let samples = sample_trajectory(&trajectory_data, &outputs, 50.0, 0.1);

        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);

        let speed_at_50 = samples[1].velocity_mps;
        assert!(speed_at_50 > 90.0 && speed_at_50 < 100.0);

        // The line of sight is flat at height 0, so the highest point (100 m) is the apex.
        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex));
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition));
    }
}