// ballistics_engine/trajectory_sampling.rs

1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Notable events that can be attached to a [`TrajectorySample`].
///
/// The enum is fieldless, so it is `Copy`: flags can be passed and compared by
/// value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The drop relative to the line of sight crosses, or sits on, zero at this sample.
    ZeroCrossing,
    /// This sample is the one nearest a reported Mach (transonic) transition distance.
    MachTransition,
    /// The sample highest above the line of sight (smallest drop) up to the target.
    Apex,
}
11
12impl TrajectoryFlag {
13    pub fn to_string(&self) -> String {
14        match self {
15            TrajectoryFlag::ZeroCrossing => "zero_crossing".to_string(),
16            TrajectoryFlag::MachTransition => "mach_transition".to_string(),
17            TrajectoryFlag::Apex => "apex".to_string(),
18        }
19    }
20}
21
22/// Single trajectory sample point
23#[derive(Debug, Clone)]
24pub struct TrajectorySample {
25    pub distance_m: f64,
26    pub drop_m: f64,
27    pub wind_drift_m: f64,
28    pub velocity_mps: f64,
29    pub energy_j: f64,
30    pub time_s: f64,
31    pub flags: Vec<TrajectoryFlag>,
32}
33
34/// Trajectory solution data for sampling
35#[derive(Debug, Clone)]
36pub struct TrajectoryData {
37    pub times: Vec<f64>,
38    pub positions: Vec<Vector3<f64>>,  // [x, y, z] positions
39    pub velocities: Vec<Vector3<f64>>, // [vx, vy, vz] velocities
40    pub transonic_distances: Vec<f64>, // Distances where mach transitions occur
41}
42
/// Summary values of a solved shot that parameterise trajectory sampling.
///
/// All fields are plain numbers, so the struct is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target; sampling stops here.
    pub target_distance_horiz_m: f64,
    /// Height of the target; the line of sight runs from the muzzle to this height.
    pub target_vertical_height_m: f64,
    /// Total time of flight (not read by the sampling code in this module).
    pub time_of_flight_s: f64,
    /// Presumably the horizontal distance of the maximum ordinate — not read by
    /// the sampling code in this module; confirm against the solver.
    pub max_ord_dist_horiz_m: f64,
}
51
52/// Sample trajectory at regular distance intervals with vectorized operations
53pub fn sample_trajectory(
54    trajectory_data: &TrajectoryData,
55    outputs: &TrajectoryOutputs,
56    step_m: f64,
57    mass_kg: f64,
58) -> Vec<TrajectorySample> {
59    let step_size = if step_m <= 0.0 {
60        return Vec::new();
61    } else if step_m < 0.1 {
62        0.1
63    } else {
64        step_m
65    };
66
67    // Use the input target distance as the limit for sampling
68    let max_dist = outputs.target_distance_horiz_m;
69    if max_dist < 1e-9 {
70        return Vec::new();
71    }
72
73    // Extract trajectory arrays for vectorized operations
74    let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
75    let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
76    let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
77
78    // Calculate speeds and energies
79    let speeds: Vec<f64> = trajectory_data
80        .velocities
81        .iter()
82        .map(|v| v.norm())
83        .collect();
84    let energies: Vec<f64> = speeds
85        .iter()
86        .map(|&speed| 0.5 * mass_kg * speed * speed)
87        .collect();
88
89    // Generate sampling distances
90    // Calculate number of steps to reach target without exceeding it
91    let num_steps = (max_dist / step_size).ceil() as usize + 1;
92    let distances: Vec<f64> = (0..num_steps)
93        .map(|i| i as f64 * step_size)
94        .filter(|&d| d <= max_dist + 0.1) // Stop exactly at target (with tiny tolerance for rounding)
95        .collect();
96
97    // Vectorized interpolation for all trajectory data
98    let mut samples = Vec::with_capacity(distances.len());
99
100    // Get initial height (muzzle height) for proper LOS calculation
101    let muzzle_y = if !y_vals.is_empty() { y_vals[0] } else { 0.0 };
102
103    for &distance in &distances {
104        // Interpolate using x (downrange) as the independent variable - STANDARD BALLISTICS CONVENTION
105        let y_interp = interpolate(&x_vals, &y_vals, distance); // vertical at downrange distance
106        let wind_drift = interpolate(&x_vals, &z_vals, distance); // lateral drift at downrange distance
107        let velocity = interpolate(&x_vals, &speeds, distance); // velocity at downrange distance
108        let time = interpolate(&x_vals, &trajectory_data.times, distance); // time at downrange distance
109        let energy = interpolate(&x_vals, &energies, distance); // energy at downrange distance
110
111        // Calculate line-of-sight y-coordinate and drop
112        // The LOS is the straight line from initial position to target
113        // For coordinate shots: goes from muzzle_y to target_vertical_height_m
114        // Drop convention:
115        // - Positive drop means bullet is below LOS (has dropped)
116        // - Negative drop means bullet is above LOS (has risen)
117        // Therefore: drop = LOS - actual (not actual - LOS)
118        let los_y = muzzle_y + (outputs.target_vertical_height_m - muzzle_y) * distance / max_dist;
119        let drop = los_y - y_interp; // LOS - actual: positive when bullet is below LOS
120
121        samples.push(TrajectorySample {
122            distance_m: distance,
123            drop_m: drop,
124            wind_drift_m: wind_drift,
125            velocity_mps: velocity,
126            energy_j: energy,
127            time_s: time,
128            flags: Vec::new(), // Flags will be added later
129        });
130    }
131
132    // Add flags using vectorized detection
133    add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
134
135    samples
136}
137
/// Piecewise-linear interpolation of `y_vals` at the abscissa `x`.
///
/// `x_vals` should be sorted in ascending order, because the bracketing
/// interval is located by bisection. It should also be the same length as
/// `y_vals`. Outside the sampled range the nearest endpoint value is returned,
/// so there is no extrapolation.
///
/// Returns `0.0` when either slice is empty or their lengths differ. When the
/// bracketing interval has (near-)zero width, the left-hand value is returned.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let len = x_vals.len();
    // Degenerate input: there is nothing consistent to interpolate from.
    if len == 0 || y_vals.is_empty() || len != y_vals.len() {
        return 0.0;
    }

    let last = len - 1;
    // Clamp to the end values instead of extrapolating.
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[last] {
        return y_vals[last];
    }

    // Bisect for the bracketing pair; for sorted input x_vals[lo] <= x < x_vals[hi].
    let (mut lo, mut hi) = (0usize, last);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x_lo, x_hi) = (x_vals[lo], x_vals[hi]);
    let (y_lo, y_hi) = (y_vals[lo], y_vals[hi]);
    let width = x_hi - x_lo;
    if width.abs() < f64::EPSILON {
        y_lo
    } else {
        y_lo + (y_hi - y_lo) * (x - x_lo) / width
    }
}
181
182/// Add trajectory flags using vectorized detection algorithms
183fn add_trajectory_flags(
184    samples: &mut [TrajectorySample],
185    transonic_distances: &[f64],
186    target_distance_input_m: f64,
187) {
188    let tolerance = 1e-6;
189
190    // 1. Zero crossings - vectorized detection
191    detect_zero_crossings(samples, tolerance);
192
193    // 2. Mach transitions
194    for &transonic_dist in transonic_distances {
195        if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
196            samples[idx].flags.push(TrajectoryFlag::MachTransition);
197        }
198    }
199
200    // 3. Apex - find the point with maximum height between muzzle and target
201    // Since drop is positive when bullet is below LOS and negative when above,
202    // the apex is where drop is minimum (most negative)
203    if samples.len() > 2 {
204        // Use the target distance passed as parameter
205        let target_distance_m = target_distance_input_m;
206
207        // Find the index of maximum height (minimum drop, most negative) within target distance
208        // Exclude first point (always 0 for auto-zeroing)
209        let mut min_drop = f64::INFINITY;
210        let mut apex_idx = 1;
211
212        // Search from index 1, but stop at target distance
213        for i in 1..samples.len() {
214            // Only consider points up to target distance
215            if samples[i].distance_m > target_distance_m {
216                break;
217            }
218
219            if samples[i].drop_m < min_drop {
220                min_drop = samples[i].drop_m;
221                apex_idx = i;
222            }
223        }
224
225        // Mark the apex
226        samples[apex_idx].flags.push(TrajectoryFlag::Apex);
227    }
228}
229
230/// Detect zero crossings in trajectory drop values using vectorized operations
231fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
232    if samples.len() < 2 {
233        return;
234    }
235
236    let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
237
238    // Find crossing indices where drop changes sign
239    for i in 0..(drops.len() - 1) {
240        let current = drops[i];
241        let next = drops[i + 1];
242
243        // Check for sign change crossings
244        let crosses_zero = (current < -tolerance && next >= -tolerance)
245            || (current > tolerance && next <= tolerance);
246
247        if crosses_zero {
248            samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
249        }
250    }
251
252    // Find points very close to zero
253    for (i, &drop) in drops.iter().enumerate() {
254        if drop.abs() <= tolerance {
255            samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
256        }
257    }
258
259    // Remove duplicate zero crossing flags
260    for sample in samples.iter_mut() {
261        let mut unique_flags = Vec::new();
262        let mut seen = HashSet::new();
263
264        for flag in &sample.flags {
265            if seen.insert(flag.clone()) {
266                unique_flags.push(flag.clone());
267            }
268        }
269        sample.flags = unique_flags;
270    }
271}
272
273/// Find the closest sample index to a given distance
274fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
275    if samples.is_empty() {
276        return None;
277    }
278
279    // Binary search for the closest distance
280    let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
281
282    let mut left = 0;
283    let mut right = distances.len();
284
285    while left < right {
286        let mid = (left + right) / 2;
287        if distances[mid] < target_distance {
288            left = mid + 1;
289        } else {
290            right = mid;
291        }
292    }
293
294    // Find the closest point (could be left-1 or left)
295    let mut best_idx = left.min(distances.len() - 1);
296
297    if left > 0 {
298        let left_dist = (distances[left - 1] - target_distance).abs();
299        let right_dist = (distances[best_idx] - target_distance).abs();
300
301        // Prefer earlier index in case of tie
302        if left_dist <= right_dist {
303            best_idx = left - 1;
304        }
305    }
306
307    Some(best_idx)
308}
309
310/// Convert trajectory samples to Python-compatible format
311pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
312    samples
313        .iter()
314        .map(|sample| TrajectoryDict {
315            distance_m: sample.distance_m,
316            drop_m: sample.drop_m,
317            wind_drift_m: sample.wind_drift_m,
318            velocity_mps: sample.velocity_mps,
319            energy_j: sample.energy_j,
320            time_s: sample.time_s,
321            flags: sample.flags.iter().map(|f| f.to_string()).collect(),
322        })
323        .collect()
324}
325
/// Python-compatible mirror of a trajectory sample.
///
/// The numeric fields match the sample field for field. `flags` holds each
/// flag's snake_case name (e.g. `"zero_crossing"`) instead of the enum value.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryDict {
    /// Downrange distance of this sample.
    pub distance_m: f64,
    /// Offset below the line of sight (positive = below, negative = above).
    pub drop_m: f64,
    /// Lateral offset.
    pub wind_drift_m: f64,
    /// Speed.
    pub velocity_mps: f64,
    /// Kinetic energy.
    pub energy_j: f64,
    /// Time of flight to this point.
    pub time_s: f64,
    /// Flag names attached to this sample.
    pub flags: Vec<String>,
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a sample at `distance_m` with the given drop; other fields are filler.
    fn sample_at(distance_m: f64, drop_m: f64) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m: 0.0,
            velocity_mps: 100.0,
            energy_j: 1000.0,
            time_s: distance_m / 100.0,
            flags: Vec::new(),
        }
    }

    /// Three-point trajectory rising to y = 10 at x = 100 and falling to y = 5 at x = 200.
    /// Axes follow the module convention: x = downrange, y = vertical, z = lateral.
    fn simple_trajectory() -> TrajectoryData {
        TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),
                Vector3::new(100.0, 10.0, 1.0),
                Vector3::new(200.0, 5.0, 2.0),
            ],
            // Downrange (x) is the dominant velocity component, matching the position axes.
            velocities: vec![
                Vector3::new(100.0, 10.0, 1.0),
                Vector3::new(95.0, 5.0, 1.0),
                Vector3::new(90.0, 0.0, 1.0),
            ],
            transonic_distances: vec![150.0],
        }
    }

    /// Outputs for a level line of sight to a target at `target_m`.
    fn level_outputs(target_m: f64) -> TrajectoryOutputs {
        TrajectoryOutputs {
            target_distance_horiz_m: target_m,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
        }
    }

    #[test]
    fn test_interpolate() {
        let x_vals = vec![0.0, 1.0, 2.0, 3.0];
        let y_vals = vec![0.0, 10.0, 20.0, 30.0];

        assert_eq!(interpolate(&x_vals, &y_vals, 0.5), 5.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 1.5), 15.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 2.5), 25.0);

        // Boundary conditions: values are clamped, not extrapolated.
        assert_eq!(interpolate(&x_vals, &y_vals, -1.0), 0.0); // Below range
        assert_eq!(interpolate(&x_vals, &y_vals, 4.0), 30.0); // Above range
    }

    #[test]
    fn test_interpolate_degenerate_inputs() {
        assert_eq!(interpolate(&[], &[], 1.0), 0.0);
        assert_eq!(interpolate(&[0.0, 1.0], &[5.0], 0.5), 0.0); // length mismatch
        assert_eq!(interpolate(&[2.0], &[7.0], 10.0), 7.0); // single point
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = vec![sample_at(0.0, 0.0), sample_at(10.0, -1.0), sample_at(20.0, -4.0)];

        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(0)); // tie -> earlier
        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
        assert_eq!(find_closest_sample_index(&samples, 25.0), Some(2)); // beyond range
        assert_eq!(find_closest_sample_index(&[], 5.0), None);
    }

    #[test]
    fn test_detect_zero_crossings() {
        // Drop goes from positive (1.0) to negative (-0.5): crossing at index 1.
        let mut samples = vec![sample_at(0.0, 1.0), sample_at(10.0, -0.5), sample_at(20.0, -2.0)];

        detect_zero_crossings(&mut samples, 1e-6);

        assert!(!samples[0].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(samples[1].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(!samples[2].flags.contains(&TrajectoryFlag::ZeroCrossing));
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let samples = sample_trajectory(&simple_trajectory(), &level_outputs(200.0), 50.0, 0.1);

        // Should have samples at 0, 50, 100, 150, 200 meters.
        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);

        // Interpolation: halfway to the apex the bullet is 5 m above a level LOS.
        assert!(samples[1].velocity_mps > 90.0 && samples[1].velocity_mps < 100.0);
        assert!((samples[1].drop_m + 5.0).abs() < 1e-12);
        assert!((samples[1].wind_drift_m - 0.5).abs() < 1e-12);

        // Flags.
        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex)); // At apex distance
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition)); // At transonic distance
    }

    #[test]
    fn test_sample_trajectory_rejects_degenerate_inputs() {
        let data = simple_trajectory();
        assert!(sample_trajectory(&data, &level_outputs(200.0), 0.0, 0.1).is_empty());
        assert!(sample_trajectory(&data, &level_outputs(200.0), -5.0, 0.1).is_empty());
        assert!(sample_trajectory(&data, &level_outputs(0.0), 50.0, 0.1).is_empty());
    }

    #[test]
    fn test_sample_trajectory_clamps_tiny_step() {
        // A 0.01 m request is raised to 0.1 m: 0.0, 0.1, ..., 1.0.
        let samples = sample_trajectory(&simple_trajectory(), &level_outputs(1.0), 0.01, 0.1);
        assert_eq!(samples.len(), 11);
        assert!((samples[10].distance_m - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_flag_names_and_dict_conversion() {
        let mut sample = sample_at(50.0, -0.25);
        sample.flags = vec![
            TrajectoryFlag::ZeroCrossing,
            TrajectoryFlag::MachTransition,
            TrajectoryFlag::Apex,
        ];

        let dicts = trajectory_samples_to_dicts(&[sample]);

        assert_eq!(dicts.len(), 1);
        assert_eq!(dicts[0].distance_m, 50.0);
        assert_eq!(dicts[0].drop_m, -0.25);
        assert_eq!(dicts[0].flags, vec!["zero_crossing", "mach_transition", "apex"]);
    }
}