1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Notable events attached to a [`TrajectorySample`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// `drop_m` crosses or touches zero, i.e. the trajectory meets the line of sight.
    ZeroCrossing,
    /// Sample nearest to one of the trajectory's transonic distances.
    MachTransition,
    /// Sample with the smallest drop (highest point relative to the line of sight).
    Apex,
}

impl TrajectoryFlag {
    /// Stable snake_case name of the flag: `"zero_crossing"`, `"mach_transition"`
    /// or `"apex"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            TrajectoryFlag::ZeroCrossing => "zero_crossing",
            TrajectoryFlag::MachTransition => "mach_transition",
            TrajectoryFlag::Apex => "apex",
        }
    }
}

/// Renders the flag as its [`TrajectoryFlag::as_str`] name.
///
/// Implementing `Display` (instead of an inherent `to_string`, which shadows the
/// standard trait) also provides `to_string()` through the blanket [`ToString`]
/// impl, so existing `flag.to_string()` calls keep returning the same strings.
impl std::fmt::Display for TrajectoryFlag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
21
/// One point of a trajectory resampled onto an evenly spaced distance grid,
/// as produced by `sample_trajectory`.
#[derive(Debug, Clone)]
pub struct TrajectorySample {
    /// Grid distance of this sample; `sample_trajectory` interpolates every
    /// other field against the `z` component of the trajectory positions at
    /// this value.
    pub distance_m: f64,
    /// Line-of-sight height minus the interpolated trajectory height (`y`);
    /// positive when the trajectory `y` is below the line-of-sight height.
    pub drop_m: f64,
    /// Interpolated `x` component of the trajectory position.
    pub wind_drift_m: f64,
    /// Interpolated speed (magnitude of the velocity vector).
    pub velocity_mps: f64,
    /// Interpolated kinetic energy; `0.5 * mass_kg * speed^2` is computed at
    /// each source point and then interpolated.
    pub energy_j: f64,
    /// Interpolated time.
    pub time_s: f64,
    /// Zero-crossing / Mach-transition / apex annotations for this sample.
    pub flags: Vec<TrajectoryFlag>,
}
33
/// Raw per-point trajectory data consumed by `sample_trajectory`.
///
/// NOTE(review): `times`, `positions` and `velocities` are interpolated against
/// each other index-by-index, so they are presumably parallel arrays of equal
/// length — confirm with the producer. `interpolate` silently returns `0.0`
/// when lengths differ.
#[derive(Debug, Clone)]
pub struct TrajectoryData {
    /// Time stamp of each trajectory point.
    pub times: Vec<f64>,
    /// Position of each point; `sample_trajectory` uses `z` as the downrange
    /// axis, `y` as height and `x` as wind drift.
    pub positions: Vec<Vector3<f64>>,
    /// Velocity of each point; only its magnitude (`norm`) is used.
    pub velocities: Vec<Vector3<f64>>,
    /// Distances flagged as `MachTransition` on the nearest sample
    /// (presumably where the projectile crosses Mach 1 — TODO confirm).
    pub transonic_distances: Vec<f64>,
}
42
/// Scalar solve results that define the sampling range and the line of sight.
#[derive(Debug, Clone)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target; upper bound of the sampling grid.
    pub target_distance_horiz_m: f64,
    /// Line-of-sight height at `target_distance_horiz_m`; the line of sight is
    /// straight from `sight_height_m` at distance 0 to this value.
    pub target_vertical_height_m: f64,
    /// Time of flight (not read by `sample_trajectory`).
    pub time_of_flight_s: f64,
    /// Horizontal distance of the maximum ordinate (not read by
    /// `sample_trajectory`; the apex flag is derived from the samples instead).
    pub max_ord_dist_horiz_m: f64,
    /// Line-of-sight height at distance 0.
    pub sight_height_m: f64,
}
54
55pub fn sample_trajectory(
57 trajectory_data: &TrajectoryData,
58 outputs: &TrajectoryOutputs,
59 step_m: f64,
60 mass_kg: f64,
61) -> Vec<TrajectorySample> {
62 let step_size = if step_m <= 0.0 {
63 return Vec::new();
64 } else if step_m < 0.1 {
65 0.1
66 } else {
67 step_m
68 };
69
70 let max_dist = outputs.target_distance_horiz_m;
72 if max_dist < 1e-9 {
73 return Vec::new();
74 }
75
76 let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
78 let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
79 let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
80
81 let speeds: Vec<f64> = trajectory_data
83 .velocities
84 .iter()
85 .map(|v| v.norm())
86 .collect();
87 let energies: Vec<f64> = speeds
88 .iter()
89 .map(|&speed| 0.5 * mass_kg * speed * speed)
90 .collect();
91
92 let num_steps = (max_dist / step_size).ceil() as usize + 1;
95 let distances: Vec<f64> = (0..num_steps)
96 .map(|i| i as f64 * step_size)
97 .filter(|&d| d <= max_dist + 0.1) .collect();
99
100 let mut samples = Vec::with_capacity(distances.len());
102
103 for &distance in &distances {
104 let y_interp = interpolate(&z_vals, &y_vals, distance); let wind_drift = interpolate(&z_vals, &x_vals, distance); let velocity = interpolate(&z_vals, &speeds, distance); let time = interpolate(&z_vals, &trajectory_data.times, distance); let energy = interpolate(&z_vals, &energies, distance); let los_y = outputs.sight_height_m
127 + (outputs.target_vertical_height_m - outputs.sight_height_m) * distance / max_dist;
128 let drop = los_y - y_interp; samples.push(TrajectorySample {
131 distance_m: distance,
132 drop_m: drop,
133 wind_drift_m: wind_drift,
134 velocity_mps: velocity,
135 energy_j: energy,
136 time_s: time,
137 flags: Vec::new(), });
139 }
140
141 add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
143
144 samples
145}
146
/// Piecewise-linear interpolation of `y_vals` over `x_vals` at `x`.
///
/// `x_vals` is expected to be ascending; the bracketing interval is found by
/// bisection. `x` at or beyond either end yields the first/last `y` value.
/// If either slice is empty or their lengths differ, `0.0` is returned. When
/// the bracketing interval is narrower than `f64::EPSILON`, its left `y` value
/// is returned instead of dividing by (almost) zero.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let n = x_vals.len();
    if n == 0 || y_vals.is_empty() || n != y_vals.len() {
        return 0.0;
    }

    // Clamp to the end values outside the tabulated range.
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[n - 1] {
        return y_vals[n - 1];
    }

    // Invariant while searching: x_vals[lo] <= x, and hi is the candidate upper end.
    let (mut lo, mut hi) = (0usize, n - 1);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x0, x1) = (x_vals[lo], x_vals[hi]);
    let (y0, y1) = (y_vals[lo], y_vals[hi]);
    let span = x1 - x0;

    if span.abs() < f64::EPSILON {
        y0
    } else {
        y0 + (y1 - y0) * (x - x0) / span
    }
}
190
191fn add_trajectory_flags(
193 samples: &mut [TrajectorySample],
194 transonic_distances: &[f64],
195 target_distance_input_m: f64,
196) {
197 let tolerance = 1e-6;
198
199 detect_zero_crossings(samples, tolerance);
201
202 for &transonic_dist in transonic_distances {
204 if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
205 samples[idx].flags.push(TrajectoryFlag::MachTransition);
206 }
207 }
208
209 if samples.len() > 2 {
213 let target_distance_m = target_distance_input_m;
215
216 let mut min_drop = f64::INFINITY;
219 let mut apex_idx = 1;
220
221 for i in 1..samples.len() {
223 if samples[i].distance_m > target_distance_m {
225 break;
226 }
227
228 if samples[i].drop_m < min_drop {
229 min_drop = samples[i].drop_m;
230 apex_idx = i;
231 }
232 }
233
234 samples[apex_idx].flags.push(TrajectoryFlag::Apex);
236 }
237}
238
239fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
241 if samples.len() < 2 {
242 return;
243 }
244
245 let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
246
247 for i in 0..(drops.len() - 1) {
249 let current = drops[i];
250 let next = drops[i + 1];
251
252 let crosses_zero = (current < -tolerance && next >= -tolerance)
254 || (current > tolerance && next <= tolerance);
255
256 if crosses_zero {
257 samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
258 }
259 }
260
261 for (i, &drop) in drops.iter().enumerate() {
263 if drop.abs() <= tolerance {
264 samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
265 }
266 }
267
268 for sample in samples.iter_mut() {
270 let mut unique_flags = Vec::new();
271 let mut seen = HashSet::new();
272
273 for flag in &sample.flags {
274 if seen.insert(flag.clone()) {
275 unique_flags.push(flag.clone());
276 }
277 }
278 sample.flags = unique_flags;
279 }
280}
281
282fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
284 if samples.is_empty() {
285 return None;
286 }
287
288 let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
290
291 let mut left = 0;
292 let mut right = distances.len();
293
294 while left < right {
295 let mid = (left + right) / 2;
296 if distances[mid] < target_distance {
297 left = mid + 1;
298 } else {
299 right = mid;
300 }
301 }
302
303 let mut best_idx = left.min(distances.len() - 1);
305
306 if left > 0 {
307 let left_dist = (distances[left - 1] - target_distance).abs();
308 let right_dist = (distances[best_idx] - target_distance).abs();
309
310 if left_dist <= right_dist {
312 best_idx = left - 1;
313 }
314 }
315
316 Some(best_idx)
317}
318
319pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
321 samples
322 .iter()
323 .map(|sample| TrajectoryDict {
324 distance_m: sample.distance_m,
325 drop_m: sample.drop_m,
326 wind_drift_m: sample.wind_drift_m,
327 velocity_mps: sample.velocity_mps,
328 energy_j: sample.energy_j,
329 time_s: sample.time_s,
330 flags: sample.flags.iter().map(|f| f.to_string()).collect(),
331 })
332 .collect()
333}
334
/// Flattened view of a [`TrajectorySample`] with flags rendered as strings,
/// built by `trajectory_samples_to_dicts` (presumably for serialization or
/// language bindings — TODO confirm).
#[derive(Debug, Clone)]
pub struct TrajectoryDict {
    /// Copied from `TrajectorySample::distance_m`.
    pub distance_m: f64,
    /// Copied from `TrajectorySample::drop_m`.
    pub drop_m: f64,
    /// Copied from `TrajectorySample::wind_drift_m`.
    pub wind_drift_m: f64,
    /// Copied from `TrajectorySample::velocity_mps`.
    pub velocity_mps: f64,
    /// Copied from `TrajectorySample::energy_j`.
    pub energy_j: f64,
    /// Copied from `TrajectorySample::time_s`.
    pub time_s: f64,
    /// String form of each flag: `"zero_crossing"`, `"mach_transition"` or `"apex"`.
    pub flags: Vec<String>,
}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sample with empty flags; fields not under test get simple
    /// values derived from `distance_m`.
    fn sample(distance_m: f64, drop_m: f64) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m: distance_m / 100.0,
            velocity_mps: 100.0 - distance_m / 2.0,
            energy_j: 1000.0 - distance_m * 5.0,
            time_s: distance_m / 100.0,
            flags: Vec::new(),
        }
    }

    /// Three-point trajectory: y rises to 10 at z = 100 and falls to 5 at z = 200.
    fn basic_trajectory() -> TrajectoryData {
        TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(2.0, 5.0, 200.0),
            ],
            velocities: vec![
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(1.0, 5.0, 95.0),
                Vector3::new(1.0, 0.0, 90.0),
            ],
            transonic_distances: vec![150.0],
        }
    }

    /// Flat line of sight at height 0 towards a target at `target_m`.
    fn flat_outputs(target_m: f64) -> TrajectoryOutputs {
        TrajectoryOutputs {
            target_distance_horiz_m: target_m,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
            sight_height_m: 0.0,
        }
    }

    #[test]
    fn test_interpolate() {
        let x_vals = [0.0, 1.0, 2.0, 3.0];
        let y_vals = [0.0, 10.0, 20.0, 30.0];

        assert_eq!(interpolate(&x_vals, &y_vals, 0.5), 5.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 1.5), 15.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 2.5), 25.0);

        // Out-of-range inputs clamp to the end values.
        assert_eq!(interpolate(&x_vals, &y_vals, -1.0), 0.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 4.0), 30.0);
    }

    #[test]
    fn test_interpolate_degenerate_inputs() {
        assert_eq!(interpolate(&[], &[], 1.0), 0.0);
        assert_eq!(interpolate(&[0.0, 1.0], &[5.0], 0.5), 0.0);
        assert_eq!(interpolate(&[2.0], &[7.0], 1.0), 7.0);
        assert_eq!(interpolate(&[2.0], &[7.0], 3.0), 7.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = [sample(0.0, 0.0), sample(10.0, -1.0), sample(20.0, -4.0)];

        // 5.0 is equidistant from 0 and 10: the lower index wins.
        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
        assert_eq!(find_closest_sample_index(&samples, -3.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 99.0), Some(2));
        assert_eq!(find_closest_sample_index(&[], 5.0), None);
    }

    #[test]
    fn test_detect_zero_crossings() {
        let mut samples = [sample(0.0, 1.0), sample(10.0, -0.5), sample(20.0, -2.0)];

        detect_zero_crossings(&mut samples, 1e-6);

        assert!(!samples[0].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(samples[1].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(!samples[2].flags.contains(&TrajectoryFlag::ZeroCrossing));
    }

    #[test]
    fn test_detect_zero_crossings_flags_each_sample_once() {
        // Sample 2 is both a crossing target and exactly zero: it must be flagged once.
        let mut samples = [sample(0.0, 0.0), sample(10.0, -1.0), sample(20.0, 0.0)];

        detect_zero_crossings(&mut samples, 1e-6);

        assert_eq!(samples[0].flags, vec![TrajectoryFlag::ZeroCrossing]);
        assert!(samples[1].flags.is_empty());
        assert_eq!(samples[2].flags, vec![TrajectoryFlag::ZeroCrossing]);
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let samples = sample_trajectory(&basic_trajectory(), &flat_outputs(200.0), 50.0, 0.1);

        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);

        assert!(samples[1].velocity_mps > 90.0 && samples[1].velocity_mps < 100.0);

        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex));
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition));
    }

    #[test]
    fn test_sample_trajectory_rejects_degenerate_inputs() {
        let data = basic_trajectory();

        assert!(sample_trajectory(&data, &flat_outputs(200.0), 0.0, 0.1).is_empty());
        assert!(sample_trajectory(&data, &flat_outputs(200.0), -1.0, 0.1).is_empty());
        assert!(sample_trajectory(&data, &flat_outputs(0.0), 50.0, 0.1).is_empty());
    }

    #[test]
    fn test_sample_trajectory_clamps_small_step() {
        // A 0.01 step is clamped to 0.1, giving 0.0, 0.1, ..., 0.5.
        let samples = sample_trajectory(&basic_trajectory(), &flat_outputs(0.5), 0.01, 0.1);

        assert_eq!(samples.len(), 6);
        assert_eq!(samples[1].distance_m, 0.1);
    }

    #[test]
    fn test_trajectory_samples_to_dicts_renders_flags() {
        let mut flagged = sample(10.0, -1.0);
        flagged.flags = vec![TrajectoryFlag::ZeroCrossing, TrajectoryFlag::Apex];

        let dicts = trajectory_samples_to_dicts(&[flagged]);

        assert_eq!(dicts.len(), 1);
        assert_eq!(dicts[0].distance_m, 10.0);
        assert_eq!(dicts[0].flags, vec!["zero_crossing".to_string(), "apex".to_string()]);
    }
}