1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Event marker attached to a resampled trajectory row.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The row's `drop_m` crosses, or lies within tolerance of, zero.
    ZeroCrossing,
    /// The row is the nearest one to an entry of the solver's transonic distances.
    MachTransition,
    /// The row with the smallest `drop_m` after the muzzle row.
    Apex,
}

impl TrajectoryFlag {
    /// Stable snake_case name of the flag, e.g. `"zero_crossing"`.
    ///
    /// This is the text produced by `Display` / `to_string()`.
    pub fn as_str(&self) -> &'static str {
        match self {
            TrajectoryFlag::ZeroCrossing => "zero_crossing",
            TrajectoryFlag::MachTransition => "mach_transition",
            TrajectoryFlag::Apex => "apex",
        }
    }
}

/// Formats the flag as its [`TrajectoryFlag::as_str`] name. Width, fill and alignment
/// specifiers are honoured. `to_string()` comes from the blanket `ToString` impl, so
/// existing callers keep working.
impl std::fmt::Display for TrajectoryFlag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
21
22#[derive(Debug, Clone)]
24pub struct TrajectorySample {
25 pub distance_m: f64,
26 pub drop_m: f64,
27 pub wind_drift_m: f64,
28 pub velocity_mps: f64,
29 pub energy_j: f64,
30 pub time_s: f64,
31 pub flags: Vec<TrajectoryFlag>,
32}
33
/// Raw trajectory from the solver, stored as parallel per-step arrays.
///
/// `sample_trajectory` interpolates every channel against the x component of
/// `positions`. Because the interpolation bisects on x, positions must be sorted by
/// increasing x. `times` and `velocities` must have the same length as `positions`;
/// a channel whose length differs interpolates to 0.0.
#[derive(Debug, Clone)]
pub struct TrajectoryData {
    /// Solver time at each step.
    pub times: Vec<f64>,
    /// Position at each step. `sample_trajectory` uses x as the downrange axis,
    /// y as height, and z as lateral (wind) drift.
    pub positions: Vec<Vector3<f64>>,
    /// Velocity at each step; only its magnitude (`norm`) is used by `sample_trajectory`.
    pub velocities: Vec<Vector3<f64>>,
    /// Downrange distances used to place `TrajectoryFlag::MachTransition` flags on the
    /// nearest resampled rows.
    pub transonic_distances: Vec<f64>, }
42
/// Scalar results of a trajectory solve.
///
/// `sample_trajectory` reads only `target_distance_horiz_m` and
/// `target_vertical_height_m`; the other fields are not used by it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target; the resampled grid stops at this distance
    /// (plus a 0.1 m rounding allowance).
    pub target_distance_horiz_m: f64,
    /// Target height. The reference line used for `drop_m` runs from the muzzle height
    /// to this value at `target_distance_horiz_m`.
    pub target_vertical_height_m: f64,
    /// Time of flight (per the field name; not read by `sample_trajectory`).
    pub time_of_flight_s: f64,
    /// Horizontal distance of the maximum ordinate (per the field name; not read by
    /// `sample_trajectory`).
    pub max_ord_dist_horiz_m: f64,
}
51
52pub fn sample_trajectory(
54 trajectory_data: &TrajectoryData,
55 outputs: &TrajectoryOutputs,
56 step_m: f64,
57 mass_kg: f64,
58) -> Vec<TrajectorySample> {
59 let step_size = if step_m <= 0.0 {
60 return Vec::new();
61 } else if step_m < 0.1 {
62 0.1
63 } else {
64 step_m
65 };
66
67 let max_dist = outputs.target_distance_horiz_m;
69 if max_dist < 1e-9 {
70 return Vec::new();
71 }
72
73 let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
75 let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
76 let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
77
78 let speeds: Vec<f64> = trajectory_data.velocities.iter()
80 .map(|v| v.norm())
81 .collect();
82 let energies: Vec<f64> = speeds.iter()
83 .map(|&speed| 0.5 * mass_kg * speed * speed)
84 .collect();
85
86 let num_steps = (max_dist / step_size).ceil() as usize + 1;
89 let distances: Vec<f64> = (0..num_steps)
90 .map(|i| i as f64 * step_size)
91 .filter(|&d| d <= max_dist + 0.1) .collect();
93
94 let mut samples = Vec::with_capacity(distances.len());
96
97 let muzzle_y = if !y_vals.is_empty() { y_vals[0] } else { 0.0 };
99
100 for &distance in &distances {
101 let y_interp = interpolate(&x_vals, &y_vals, distance); let wind_drift = interpolate(&x_vals, &z_vals, distance); let velocity = interpolate(&x_vals, &speeds, distance); let time = interpolate(&x_vals, &trajectory_data.times, distance); let energy = interpolate(&x_vals, &energies, distance); let los_y = muzzle_y + (outputs.target_vertical_height_m - muzzle_y) * distance / max_dist;
116 let drop = los_y - y_interp; samples.push(TrajectorySample {
119 distance_m: distance,
120 drop_m: drop,
121 wind_drift_m: wind_drift,
122 velocity_mps: velocity,
123 energy_j: energy,
124 time_s: time,
125 flags: Vec::new(), });
127 }
128
129 add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
131
132 samples
133}
134
/// Piecewise-linear interpolation of `y_vals` at `x`, using `x_vals` as the knots.
///
/// `x_vals` must be sorted ascending because it is bisected. Outside the knot range,
/// the value at the nearest end is returned (no extrapolation). Returns `0.0` when
/// either slice is empty or the two lengths differ.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let n = x_vals.len();
    if n == 0 || y_vals.is_empty() || n != y_vals.len() {
        return 0.0;
    }

    let last = n - 1;
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[last] {
        return y_vals[last];
    }

    // Bisect for the bracketing knots; for sorted knots and finite x this ends with
    // x_vals[lo] <= x < x_vals[hi] and hi == lo + 1.
    let (mut lo, mut hi) = (0, last);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x_lo, x_hi) = (x_vals[lo], x_vals[hi]);
    let (y_lo, y_hi) = (y_vals[lo], y_vals[hi]);
    let span = x_hi - x_lo;
    // Coincident knots: avoid dividing by (almost) zero.
    if span.abs() < f64::EPSILON {
        return y_lo;
    }
    y_lo + (y_hi - y_lo) * (x - x_lo) / span
}
178
179fn add_trajectory_flags(
181 samples: &mut [TrajectorySample],
182 transonic_distances: &[f64],
183 target_distance_input_m: f64,
184) {
185 let tolerance = 1e-6;
186
187 detect_zero_crossings(samples, tolerance);
189
190 for &transonic_dist in transonic_distances {
192 if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
193 samples[idx].flags.push(TrajectoryFlag::MachTransition);
194 }
195 }
196
197 if samples.len() > 2 {
201 let target_distance_m = target_distance_input_m;
203
204 let mut min_drop = f64::INFINITY;
207 let mut apex_idx = 1;
208
209 for i in 1..samples.len() {
211 if samples[i].distance_m > target_distance_m {
213 break;
214 }
215
216 if samples[i].drop_m < min_drop {
217 min_drop = samples[i].drop_m;
218 apex_idx = i;
219 }
220 }
221
222 samples[apex_idx].flags.push(TrajectoryFlag::Apex);
224 }
225}
226
227fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
229 if samples.len() < 2 {
230 return;
231 }
232
233 let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
234
235 for i in 0..(drops.len() - 1) {
237 let current = drops[i];
238 let next = drops[i + 1];
239
240 let crosses_zero = (current < -tolerance && next >= -tolerance) ||
242 (current > tolerance && next <= tolerance);
243
244 if crosses_zero {
245 samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
246 }
247 }
248
249 for (i, &drop) in drops.iter().enumerate() {
251 if drop.abs() <= tolerance {
252 samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
253 }
254 }
255
256 for sample in samples.iter_mut() {
258 let mut unique_flags = Vec::new();
259 let mut seen = HashSet::new();
260
261 for flag in &sample.flags {
262 if seen.insert(flag.clone()) {
263 unique_flags.push(flag.clone());
264 }
265 }
266 sample.flags = unique_flags;
267 }
268}
269
270fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
272 if samples.is_empty() {
273 return None;
274 }
275
276 let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
278
279 let mut left = 0;
280 let mut right = distances.len();
281
282 while left < right {
283 let mid = (left + right) / 2;
284 if distances[mid] < target_distance {
285 left = mid + 1;
286 } else {
287 right = mid;
288 }
289 }
290
291 let mut best_idx = left.min(distances.len() - 1);
293
294 if left > 0 {
295 let left_dist = (distances[left - 1] - target_distance).abs();
296 let right_dist = (distances[best_idx] - target_distance).abs();
297
298 if left_dist <= right_dist {
300 best_idx = left - 1;
301 }
302 }
303
304 Some(best_idx)
305}
306
307pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
309 samples.iter().map(|sample| {
310 TrajectoryDict {
311 distance_m: sample.distance_m,
312 drop_m: sample.drop_m,
313 wind_drift_m: sample.wind_drift_m,
314 velocity_mps: sample.velocity_mps,
315 energy_j: sample.energy_j,
316 time_s: sample.time_s,
317 flags: sample.flags.iter().map(|f| f.to_string()).collect(),
318 }
319 }).collect()
320}
321
/// Plain-data mirror of a `TrajectorySample` with flags rendered as strings
/// (built by `trajectory_samples_to_dicts`).
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryDict {
    /// Downrange grid distance.
    pub distance_m: f64,
    /// Drop below the muzzle-to-target reference line (positive = below).
    pub drop_m: f64,
    /// Lateral (z) offset.
    pub wind_drift_m: f64,
    /// Speed.
    pub velocity_mps: f64,
    /// Kinetic energy.
    pub energy_j: f64,
    /// Time.
    pub time_s: f64,
    /// Flag names, e.g. `"zero_crossing"`, `"mach_transition"`, `"apex"`.
    pub flags: Vec<String>,
}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row at `distance_m` with the given drop; the remaining channels hold
    /// placeholder values that the helpers under test do not read.
    fn sample(distance_m: f64, drop_m: f64) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m: 0.0,
            velocity_mps: 100.0,
            energy_j: 1000.0,
            time_s: 0.0,
            flags: Vec::new(),
        }
    }

    /// Three-step trajectory rising to y = 10 at x = 100 and falling to y = 5 at the
    /// 200 m target, with one transonic distance at 150 m.
    fn basic_inputs() -> (TrajectoryData, TrajectoryOutputs) {
        let trajectory_data = TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),
                Vector3::new(100.0, 10.0, 1.0),
                Vector3::new(200.0, 5.0, 2.0),
            ],
            velocities: vec![
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(1.0, 5.0, 95.0),
                Vector3::new(1.0, 0.0, 90.0),
            ],
            transonic_distances: vec![150.0],
        };
        let outputs = TrajectoryOutputs {
            target_distance_horiz_m: 200.0,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
        };
        (trajectory_data, outputs)
    }

    #[test]
    fn test_interpolate() {
        let x_vals = vec![0.0, 1.0, 2.0, 3.0];
        let y_vals = vec![0.0, 10.0, 20.0, 30.0];

        assert_eq!(interpolate(&x_vals, &y_vals, 0.5), 5.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 1.5), 15.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 2.5), 25.0);

        // Outside the knot range the end values are held, not extrapolated.
        assert_eq!(interpolate(&x_vals, &y_vals, -1.0), 0.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 4.0), 30.0);
    }

    #[test]
    fn test_interpolate_degenerate_inputs() {
        let empty: Vec<f64> = Vec::new();
        assert_eq!(interpolate(&empty, &empty, 1.0), 0.0);
        // Mismatched lengths fall back to 0.0 rather than panicking.
        assert_eq!(interpolate(&[0.0, 1.0], &[5.0], 0.5), 0.0);
        assert_eq!(interpolate(&[2.0], &[7.0], 9.0), 7.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = vec![sample(0.0, 0.0), sample(10.0, -1.0), sample(20.0, -4.0)];

        // Exact midpoint ties resolve to the earlier row.
        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
        // Targets outside the covered span resolve to the nearest end row.
        assert_eq!(find_closest_sample_index(&samples, -3.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 99.0), Some(2));

        let none: Vec<TrajectorySample> = Vec::new();
        assert_eq!(find_closest_sample_index(&none, 1.0), None);
    }

    #[test]
    fn test_detect_zero_crossings() {
        let mut samples = vec![sample(0.0, 1.0), sample(10.0, -0.5), sample(20.0, -2.0)];

        detect_zero_crossings(&mut samples, 1e-6);

        assert!(!samples[0].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(samples[1].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(!samples[2].flags.contains(&TrajectoryFlag::ZeroCrossing));
    }

    #[test]
    fn test_zero_crossing_flag_is_not_duplicated() {
        // Row 1 both ends a sign change and sits exactly on zero; it is flagged once.
        let mut samples = vec![sample(0.0, 1.0), sample(10.0, 0.0), sample(20.0, -1.0)];

        detect_zero_crossings(&mut samples, 1e-6);

        assert!(samples[0].flags.is_empty());
        assert_eq!(samples[1].flags, vec![TrajectoryFlag::ZeroCrossing]);
        assert!(samples[2].flags.is_empty());
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let (trajectory_data, outputs) = basic_inputs();

        let samples = sample_trajectory(&trajectory_data, &outputs, 50.0, 0.1);

        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);

        assert!(samples[1].velocity_mps > 90.0 && samples[1].velocity_mps < 100.0);

        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex));
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition));
    }

    #[test]
    fn test_sample_trajectory_degenerate_inputs() {
        let (trajectory_data, outputs) = basic_inputs();

        assert!(sample_trajectory(&trajectory_data, &outputs, 0.0, 0.1).is_empty());
        assert!(sample_trajectory(&trajectory_data, &outputs, -5.0, 0.1).is_empty());
        assert!(sample_trajectory(&trajectory_data, &outputs, f64::NAN, 0.1).is_empty());

        let zero_range = TrajectoryOutputs {
            target_distance_horiz_m: 0.0,
            ..outputs
        };
        assert!(sample_trajectory(&trajectory_data, &zero_range, 50.0, 0.1).is_empty());
    }

    #[test]
    fn test_sample_trajectory_clamps_small_steps() {
        let (trajectory_data, mut outputs) = basic_inputs();
        outputs.target_distance_horiz_m = 1.0;

        // A 5 cm request is raised to the 0.1 m minimum: 0.0, 0.1, ..., 1.0.
        let samples = sample_trajectory(&trajectory_data, &outputs, 0.05, 0.1);

        assert_eq!(samples.len(), 11);
        assert_eq!(samples[1].distance_m, 0.1);
    }

    #[test]
    fn test_trajectory_samples_to_dicts() {
        let mut apex = sample(100.0, -10.0);
        apex.flags = vec![TrajectoryFlag::ZeroCrossing, TrajectoryFlag::Apex];

        let dicts = trajectory_samples_to_dicts(&[sample(0.0, 0.0), apex]);

        assert_eq!(dicts.len(), 2);
        assert!(dicts[0].flags.is_empty());
        assert_eq!(dicts[1].distance_m, 100.0);
        assert_eq!(dicts[1].drop_m, -10.0);
        assert_eq!(
            dicts[1].flags,
            vec!["zero_crossing".to_string(), "apex".to_string()]
        );
    }
}