// ballistics_engine/trajectory_sampling.rs

1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Notable events that can be attached to a [`TrajectorySample`].
///
/// The `Display` form (and therefore `to_string()`) yields the snake_case
/// names `"zero_crossing"`, `"mach_transition"` and `"apex"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The drop relative to the line of sight crosses or touches zero here.
    ZeroCrossing,
    /// Sample nearest to a reported Mach (transonic) transition distance.
    MachTransition,
    /// Sample with the most negative drop, i.e. highest above the line of sight.
    Apex,
}

impl TrajectoryFlag {
    /// Returns the stable snake_case name of this flag without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            TrajectoryFlag::ZeroCrossing => "zero_crossing",
            TrajectoryFlag::MachTransition => "mach_transition",
            TrajectoryFlag::Apex => "apex",
        }
    }
}

// Implementing `Display` instead of an inherent `to_string` keeps
// `flag.to_string()` working (via the blanket `ToString` impl) and also
// enables `format!("{flag}")`.
impl std::fmt::Display for TrajectoryFlag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
21
22/// Single trajectory sample point
23#[derive(Debug, Clone)]
24pub struct TrajectorySample {
25    pub distance_m: f64,
26    pub drop_m: f64,
27    pub wind_drift_m: f64,
28    pub velocity_mps: f64,
29    pub energy_j: f64,
30    pub time_s: f64,
31    pub flags: Vec<TrajectoryFlag>,
32}
33
34/// Trajectory solution data for sampling
35#[derive(Debug, Clone)]
36pub struct TrajectoryData {
37    pub times: Vec<f64>,
38    pub positions: Vec<Vector3<f64>>,  // [x, y, z] positions
39    pub velocities: Vec<Vector3<f64>>, // [vx, vy, vz] velocities
40    pub transonic_distances: Vec<f64>, // Distances where mach transitions occur
41}
42
/// Scalar results of a trajectory solve used to parameterise sampling.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target, in metres. `sample_trajectory` stops
    /// sampling here, and the line of sight ends here.
    pub target_distance_horiz_m: f64,
    /// Target height, in metres; the line of sight ends at this height.
    pub target_vertical_height_m: f64,
    /// Total time of flight, in seconds (not read by `sample_trajectory`).
    pub time_of_flight_s: f64,
    /// Horizontal distance of the maximum ordinate, in metres
    /// (not read by `sample_trajectory`).
    pub max_ord_dist_horiz_m: f64,
}
51
52/// Sample trajectory at regular distance intervals with vectorized operations
53pub fn sample_trajectory(
54    trajectory_data: &TrajectoryData,
55    outputs: &TrajectoryOutputs,
56    step_m: f64,
57    mass_kg: f64,
58) -> Vec<TrajectorySample> {
59    let step_size = if step_m <= 0.0 {
60        return Vec::new();
61    } else if step_m < 0.1 {
62        0.1
63    } else {
64        step_m
65    };
66
67    // Use the input target distance as the limit for sampling
68    let max_dist = outputs.target_distance_horiz_m;
69    if max_dist < 1e-9 {
70        return Vec::new();
71    }
72
73    // Extract trajectory arrays for vectorized operations
74    let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
75    let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
76    let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
77
78    // Calculate speeds and energies
79    let speeds: Vec<f64> = trajectory_data
80        .velocities
81        .iter()
82        .map(|v| v.norm())
83        .collect();
84    let energies: Vec<f64> = speeds
85        .iter()
86        .map(|&speed| 0.5 * mass_kg * speed * speed)
87        .collect();
88
89    // Generate sampling distances
90    // Calculate number of steps to reach target without exceeding it
91    let num_steps = (max_dist / step_size).ceil() as usize + 1;
92    let distances: Vec<f64> = (0..num_steps)
93        .map(|i| i as f64 * step_size)
94        .filter(|&d| d <= max_dist + 0.1) // Stop exactly at target (with tiny tolerance for rounding)
95        .collect();
96
97    // Vectorized interpolation for all trajectory data
98    let mut samples = Vec::with_capacity(distances.len());
99
100    // Get initial height (muzzle height) for proper LOS calculation
101    let muzzle_y = if !y_vals.is_empty() { y_vals[0] } else { 0.0 };
102
103    for &distance in &distances {
104        // Interpolate using z (downrange) as the independent variable
105        // Coordinate system: x=lateral (wind drift), y=vertical, z=downrange
106        let y_interp = interpolate(&z_vals, &y_vals, distance); // vertical at downrange distance
107        let wind_drift = interpolate(&z_vals, &x_vals, distance); // lateral drift at downrange distance
108        let velocity = interpolate(&z_vals, &speeds, distance); // velocity at downrange distance
109        let time = interpolate(&z_vals, &trajectory_data.times, distance); // time at downrange distance
110        let energy = interpolate(&z_vals, &energies, distance); // energy at downrange distance
111
112        // Calculate line-of-sight y-coordinate and drop
113        // The LOS is the straight line from initial position to target
114        // For coordinate shots: goes from muzzle_y to target_vertical_height_m
115        // Drop convention:
116        // - Positive drop means bullet is below LOS (has dropped)
117        // - Negative drop means bullet is above LOS (has risen)
118        // Therefore: drop = LOS - actual (not actual - LOS)
119        let los_y = muzzle_y + (outputs.target_vertical_height_m - muzzle_y) * distance / max_dist;
120        let drop = los_y - y_interp; // LOS - actual: positive when bullet is below LOS
121
122        samples.push(TrajectorySample {
123            distance_m: distance,
124            drop_m: drop,
125            wind_drift_m: wind_drift,
126            velocity_mps: velocity,
127            energy_j: energy,
128            time_s: time,
129            flags: Vec::new(), // Flags will be added later
130        });
131    }
132
133    // Add flags using vectorized detection
134    add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
135
136    samples
137}
138
/// Linearly interpolates `y_vals` at abscissa `x`.
///
/// `x_vals` must be sorted in ascending order and have the same length as
/// `y_vals`. Queries at or outside `[x_vals[0], x_vals[last]]` are clamped to
/// the first/last `y` value (no extrapolation).
///
/// Returns `0.0` when the slices are empty or their lengths differ.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    if x_vals.is_empty() || x_vals.len() != y_vals.len() {
        return 0.0;
    }

    let last = x_vals.len() - 1;
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[last] {
        return y_vals[last];
    }
    if last == 0 {
        // Single point reached only via NaN comparisons: nothing to interpolate.
        return y_vals[0];
    }

    // Index of the first abscissa strictly greater than `x`. For sorted input this
    // already lies in 1..=last; the clamp keeps indices valid for NaN/unsorted data.
    let hi = x_vals.partition_point(|&v| v <= x).clamp(1, last);
    let lo = hi - 1;

    let (x1, x2) = (x_vals[lo], x_vals[hi]);
    let (y1, y2) = (y_vals[lo], y_vals[hi]);

    // Avoid dividing by (near) zero on duplicate abscissae.
    if (x2 - x1).abs() < f64::EPSILON {
        return y1;
    }

    y1 + (y2 - y1) * (x - x1) / (x2 - x1)
}
182
183/// Add trajectory flags using vectorized detection algorithms
184fn add_trajectory_flags(
185    samples: &mut [TrajectorySample],
186    transonic_distances: &[f64],
187    target_distance_input_m: f64,
188) {
189    let tolerance = 1e-6;
190
191    // 1. Zero crossings - vectorized detection
192    detect_zero_crossings(samples, tolerance);
193
194    // 2. Mach transitions
195    for &transonic_dist in transonic_distances {
196        if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
197            samples[idx].flags.push(TrajectoryFlag::MachTransition);
198        }
199    }
200
201    // 3. Apex - find the point with maximum height between muzzle and target
202    // Since drop is positive when bullet is below LOS and negative when above,
203    // the apex is where drop is minimum (most negative)
204    if samples.len() > 2 {
205        // Use the target distance passed as parameter
206        let target_distance_m = target_distance_input_m;
207
208        // Find the index of maximum height (minimum drop, most negative) within target distance
209        // Exclude first point (always 0 for auto-zeroing)
210        let mut min_drop = f64::INFINITY;
211        let mut apex_idx = 1;
212
213        // Search from index 1, but stop at target distance
214        for i in 1..samples.len() {
215            // Only consider points up to target distance
216            if samples[i].distance_m > target_distance_m {
217                break;
218            }
219
220            if samples[i].drop_m < min_drop {
221                min_drop = samples[i].drop_m;
222                apex_idx = i;
223            }
224        }
225
226        // Mark the apex
227        samples[apex_idx].flags.push(TrajectoryFlag::Apex);
228    }
229}
230
231/// Detect zero crossings in trajectory drop values using vectorized operations
232fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
233    if samples.len() < 2 {
234        return;
235    }
236
237    let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
238
239    // Find crossing indices where drop changes sign
240    for i in 0..(drops.len() - 1) {
241        let current = drops[i];
242        let next = drops[i + 1];
243
244        // Check for sign change crossings
245        let crosses_zero = (current < -tolerance && next >= -tolerance)
246            || (current > tolerance && next <= tolerance);
247
248        if crosses_zero {
249            samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
250        }
251    }
252
253    // Find points very close to zero
254    for (i, &drop) in drops.iter().enumerate() {
255        if drop.abs() <= tolerance {
256            samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
257        }
258    }
259
260    // Remove duplicate zero crossing flags
261    for sample in samples.iter_mut() {
262        let mut unique_flags = Vec::new();
263        let mut seen = HashSet::new();
264
265        for flag in &sample.flags {
266            if seen.insert(flag.clone()) {
267                unique_flags.push(flag.clone());
268            }
269        }
270        sample.flags = unique_flags;
271    }
272}
273
274/// Find the closest sample index to a given distance
275fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
276    if samples.is_empty() {
277        return None;
278    }
279
280    // Binary search for the closest distance
281    let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
282
283    let mut left = 0;
284    let mut right = distances.len();
285
286    while left < right {
287        let mid = (left + right) / 2;
288        if distances[mid] < target_distance {
289            left = mid + 1;
290        } else {
291            right = mid;
292        }
293    }
294
295    // Find the closest point (could be left-1 or left)
296    let mut best_idx = left.min(distances.len() - 1);
297
298    if left > 0 {
299        let left_dist = (distances[left - 1] - target_distance).abs();
300        let right_dist = (distances[best_idx] - target_distance).abs();
301
302        // Prefer earlier index in case of tie
303        if left_dist <= right_dist {
304            best_idx = left - 1;
305        }
306    }
307
308    Some(best_idx)
309}
310
311/// Convert trajectory samples to Python-compatible format
312pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
313    samples
314        .iter()
315        .map(|sample| TrajectoryDict {
316            distance_m: sample.distance_m,
317            drop_m: sample.drop_m,
318            wind_drift_m: sample.wind_drift_m,
319            velocity_mps: sample.velocity_mps,
320            energy_j: sample.energy_j,
321            time_s: sample.time_s,
322            flags: sample.flags.iter().map(|f| f.to_string()).collect(),
323        })
324        .collect()
325}
326
/// Python-compatible mirror of [`TrajectorySample`], with flags stored as strings.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryDict {
    /// Downrange distance, in metres.
    pub distance_m: f64,
    /// Offset below the line of sight, in metres (positive = below the LOS).
    pub drop_m: f64,
    /// Lateral offset (wind drift), in metres.
    pub wind_drift_m: f64,
    /// Speed, in metres per second.
    pub velocity_mps: f64,
    /// Kinetic energy, in joules.
    pub energy_j: f64,
    /// Time of flight, in seconds.
    pub time_s: f64,
    /// Flag names, e.g. `"zero_crossing"`, `"mach_transition"`, `"apex"`.
    pub flags: Vec<String>,
}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unflagged sample; keeps the fixtures below compact.
    fn sample(
        distance_m: f64,
        drop_m: f64,
        wind_drift_m: f64,
        velocity_mps: f64,
        energy_j: f64,
        time_s: f64,
    ) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m,
            velocity_mps,
            energy_j,
            time_s,
            flags: Vec::new(),
        }
    }

    /// Three-point trajectory peaking at y = 10 m at 100 m downrange, ending at 200 m.
    /// Coordinate system: x = lateral (wind drift), y = vertical, z = downrange.
    fn simple_trajectory() -> TrajectoryData {
        TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(2.0, 5.0, 200.0),
            ],
            velocities: vec![
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(1.0, 5.0, 95.0),
                Vector3::new(1.0, 0.0, 90.0),
            ],
            transonic_distances: vec![150.0],
        }
    }

    /// Outputs for a target at the given distance and height 0.
    fn outputs_for(target_distance_m: f64) -> TrajectoryOutputs {
        TrajectoryOutputs {
            target_distance_horiz_m: target_distance_m,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
        }
    }

    #[test]
    fn test_interpolate() {
        let x_vals = [0.0, 1.0, 2.0, 3.0];
        let y_vals = [0.0, 10.0, 20.0, 30.0];

        assert_eq!(interpolate(&x_vals, &y_vals, 0.5), 5.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 1.5), 15.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 2.5), 25.0);

        // Queries outside the data are clamped to the end values.
        assert_eq!(interpolate(&x_vals, &y_vals, -1.0), 0.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 4.0), 30.0);
    }

    #[test]
    fn test_interpolate_degenerate_inputs() {
        let empty: [f64; 0] = [];
        // Empty or mismatched inputs fall back to 0.0.
        assert_eq!(interpolate(&empty, &empty, 1.0), 0.0);
        assert_eq!(interpolate(&[0.0, 1.0], &[5.0], 0.5), 0.0);
        // A single point is returned unchanged.
        assert_eq!(interpolate(&[2.0], &[7.0], 5.0), 7.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = [
            sample(0.0, 0.0, 0.0, 100.0, 1000.0, 0.0),
            sample(10.0, -1.0, 0.1, 95.0, 950.0, 0.1),
            sample(20.0, -4.0, 0.2, 90.0, 900.0, 0.2),
        ];

        // 5.0 is equidistant from 0 and 10: the earlier index wins.
        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
        // Out-of-range targets map to the nearest end.
        assert_eq!(find_closest_sample_index(&samples, -3.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 100.0), Some(2));
        assert_eq!(find_closest_sample_index(&[], 5.0), None);
    }

    #[test]
    fn test_detect_zero_crossings() {
        let mut samples = [
            sample(0.0, 1.0, 0.0, 100.0, 1000.0, 0.0), // positive
            sample(10.0, -0.5, 0.1, 95.0, 950.0, 0.1), // negative: crossing here
            sample(20.0, -2.0, 0.2, 90.0, 900.0, 0.2), // still negative
        ];

        detect_zero_crossings(&mut samples, 1e-6);

        assert!(!samples[0].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(samples[1].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(!samples[2].flags.contains(&TrajectoryFlag::ZeroCrossing));
    }

    #[test]
    fn test_detect_zero_crossings_flags_exact_zero_once() {
        let mut samples = [
            sample(0.0, 0.0, 0.0, 100.0, 1000.0, 0.0),
            sample(10.0, 1.0, 0.1, 95.0, 950.0, 0.1),
            sample(20.0, 0.0, 0.2, 90.0, 900.0, 0.2),
        ];

        detect_zero_crossings(&mut samples, 1e-6);

        assert_eq!(samples[0].flags, vec![TrajectoryFlag::ZeroCrossing]);
        assert!(samples[1].flags.is_empty());
        // Matched both as a crossing and as a near-zero point, but flagged once.
        assert_eq!(samples[2].flags, vec![TrajectoryFlag::ZeroCrossing]);
    }

    #[test]
    fn test_flag_names() {
        assert_eq!(TrajectoryFlag::ZeroCrossing.to_string(), "zero_crossing");
        assert_eq!(TrajectoryFlag::MachTransition.to_string(), "mach_transition");
        assert_eq!(TrajectoryFlag::Apex.to_string(), "apex");
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let samples = sample_trajectory(&simple_trajectory(), &outputs_for(200.0), 50.0, 0.1);

        // Samples at 0, 50, 100, 150, 200 meters.
        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);

        // Interpolation is working.
        assert!(samples[1].velocity_mps > 90.0 && samples[1].velocity_mps < 100.0);

        // Flags.
        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex)); // at apex distance
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition)); // at transonic distance
    }

    #[test]
    fn test_sample_trajectory_rejects_degenerate_inputs() {
        let data = simple_trajectory();
        assert!(sample_trajectory(&data, &outputs_for(200.0), 0.0, 0.1).is_empty());
        assert!(sample_trajectory(&data, &outputs_for(200.0), -5.0, 0.1).is_empty());
        assert!(sample_trajectory(&data, &outputs_for(200.0), f64::NAN, 0.1).is_empty());
        assert!(sample_trajectory(&data, &outputs_for(0.0), 50.0, 0.1).is_empty());
    }

    #[test]
    fn test_sample_trajectory_clamps_small_steps() {
        // Steps below 0.1 m are raised to 0.1 m: samples at 0.0, 0.1, ..., 1.0.
        let samples = sample_trajectory(&simple_trajectory(), &outputs_for(1.0), 0.01, 0.1);
        assert_eq!(samples.len(), 11);
        assert!((samples[1].distance_m - 0.1).abs() < 1e-12);
    }
}