ballistics_engine/
trajectory_sampling.rs

1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Length conversion factor: multiplying a distance in yards by this value
/// yields metres (1 yd = 0.9144 m).
const YARDS_TO_METERS: f64 = 0.9144;
6
/// Notable events that can be attached to a trajectory sample.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The drop relative to the line of sight changes sign, or is within
    /// tolerance of zero, at this sample.
    ZeroCrossing,
    /// The sample closest to a recorded Mach-transition distance.
    MachTransition,
    /// The sample closest to the apex (maximum-ordinate) distance.
    Apex,
}

impl TrajectoryFlag {
    /// Returns the snake_case label for this flag (e.g. `"zero_crossing"`).
    ///
    /// This is the exact text produced by `Display` / `to_string()`, and so the
    /// text written into exported flag lists.
    pub fn as_str(&self) -> &'static str {
        match self {
            TrajectoryFlag::ZeroCrossing => "zero_crossing",
            TrajectoryFlag::MachTransition => "mach_transition",
            TrajectoryFlag::Apex => "apex",
        }
    }
}

// Implementing `Display` (rather than an inherent `to_string`) gives `{}`
// formatting and still provides `.to_string()` through the blanket `ToString`
// impl, so existing callers keep working.
impl std::fmt::Display for TrajectoryFlag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
24
25/// Single trajectory sample point
26#[derive(Debug, Clone)]
27pub struct TrajectorySample {
28    pub distance_m: f64,
29    pub drop_m: f64,
30    pub wind_drift_m: f64,
31    pub velocity_mps: f64,
32    pub energy_j: f64,
33    pub time_s: f64,
34    pub flags: Vec<TrajectoryFlag>,
35}
36
37/// Trajectory solution data for sampling
38#[derive(Debug, Clone)]
39pub struct TrajectoryData {
40    pub times: Vec<f64>,
41    pub positions: Vec<Vector3<f64>>,  // [x, y, z] positions
42    pub velocities: Vec<Vector3<f64>>, // [vx, vy, vz] velocities
43    pub transonic_distances: Vec<f64>, // Distances where mach transitions occur
44}
45
/// Scalar results of a trajectory solution that parameterise the sampling.
#[derive(Debug, Clone)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target; also the far end of the sampling grid.
    pub target_distance_horiz_m: f64,
    /// Target height; with the origin it defines the straight line of sight.
    pub target_vertical_height_m: f64,
    /// Total time of flight, in seconds.
    pub time_of_flight_s: f64,
    /// Horizontal distance of the apex (maximum ordinate).
    pub max_ord_dist_horiz_m: f64,
}
54
55/// Sample trajectory at regular distance intervals with vectorized operations
56pub fn sample_trajectory(
57    trajectory_data: &TrajectoryData,
58    outputs: &TrajectoryOutputs,
59    step_m: f64,
60    mass_kg: f64,
61) -> Vec<TrajectorySample> {
62    let step_size = if step_m <= 0.0 {
63        return Vec::new();
64    } else if step_m < 0.1 {
65        0.1
66    } else {
67        step_m
68    };
69    
70    let max_dist = outputs.target_distance_horiz_m;
71    if max_dist < 1e-9 {
72        return Vec::new();
73    }
74    
75    // Extract trajectory arrays for vectorized operations
76    let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
77    let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
78    let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
79    
80    // Calculate speeds and energies
81    let speeds: Vec<f64> = trajectory_data.velocities.iter()
82        .map(|v| v.norm())
83        .collect();
84    let energies: Vec<f64> = speeds.iter()
85        .map(|&speed| 0.5 * mass_kg * speed * speed)
86        .collect();
87    
88    // Generate sampling distances
89    let num_steps = ((max_dist / step_size) + 0.5) as usize + 1;
90    let distances: Vec<f64> = (0..num_steps)
91        .map(|i| i as f64 * step_size)
92        .take_while(|&d| d <= max_dist + step_size * 0.5)
93        .collect();
94    
95    // Vectorized interpolation for all trajectory data
96    let mut samples = Vec::with_capacity(distances.len());
97    
98    for &distance in &distances {
99        let y_interp = interpolate(&x_vals, &y_vals, distance);
100        let wind_drift = interpolate(&x_vals, &z_vals, distance);
101        let velocity = interpolate(&x_vals, &speeds, distance);
102        let time = interpolate(&x_vals, &trajectory_data.times, distance);
103        let energy = interpolate(&x_vals, &energies, distance);
104        
105        // Calculate line-of-sight y-coordinate and drop
106        let los_y = outputs.target_vertical_height_m * distance / max_dist;
107        let drop = los_y - y_interp;
108        
109        samples.push(TrajectorySample {
110            distance_m: distance,
111            drop_m: drop,
112            wind_drift_m: wind_drift,
113            velocity_mps: velocity,
114            energy_j: energy,
115            time_s: time,
116            flags: Vec::new(), // Flags will be added later
117        });
118    }
119    
120    // Add flags using vectorized detection
121    add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, outputs.max_ord_dist_horiz_m);
122    
123    samples
124}
125
/// Piecewise-linear interpolation of `y_vals` at `x` over the abscissae `x_vals`.
///
/// Returns:
/// - `0.0` if either slice is empty or their lengths differ;
/// - the first/last `y` value when `x` is at or beyond the first/last `x`;
/// - the left-hand `y` value when the bracketing interval is degenerate
///   (width below `f64::EPSILON`).
///
/// The bracketing interval is found by bisection, so `x_vals` is expected to be
/// sorted ascending.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let usable = !x_vals.is_empty() && !y_vals.is_empty() && x_vals.len() == y_vals.len();
    if !usable {
        return 0.0;
    }

    let last = x_vals.len() - 1;
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[last] {
        return y_vals[last];
    }

    // Bisect until `lo` and `hi` are adjacent with x_vals[lo] <= x < x_vals[hi].
    let (mut lo, mut hi) = (0usize, last);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x_lo, x_hi) = (x_vals[lo], x_vals[hi]);
    let (y_lo, y_hi) = (y_vals[lo], y_vals[hi]);
    let width = x_hi - x_lo;
    if width.abs() < f64::EPSILON {
        y_lo
    } else {
        y_lo + (y_hi - y_lo) * (x - x_lo) / width
    }
}
169
170/// Add trajectory flags using vectorized detection algorithms
171fn add_trajectory_flags(
172    samples: &mut [TrajectorySample],
173    transonic_distances: &[f64],
174    apex_distance: f64,
175) {
176    let tolerance = 1e-6;
177    
178    // 1. Zero crossings - vectorized detection
179    detect_zero_crossings(samples, tolerance);
180    
181    // 2. Mach transitions
182    for &transonic_dist in transonic_distances {
183        if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
184            samples[idx].flags.push(TrajectoryFlag::MachTransition);
185        }
186    }
187    
188    // 3. Apex
189    if apex_distance > 0.0 {
190        if let Some(idx) = find_closest_sample_index(samples, apex_distance) {
191            samples[idx].flags.push(TrajectoryFlag::Apex);
192        }
193    }
194}
195
196/// Detect zero crossings in trajectory drop values using vectorized operations
197fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
198    if samples.len() < 2 {
199        return;
200    }
201    
202    let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
203    
204    // Find crossing indices where drop changes sign
205    for i in 0..(drops.len() - 1) {
206        let current = drops[i];
207        let next = drops[i + 1];
208        
209        // Check for sign change crossings
210        let crosses_zero = (current < -tolerance && next >= -tolerance) ||
211                          (current > tolerance && next <= tolerance);
212        
213        if crosses_zero {
214            samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
215        }
216    }
217    
218    // Find points very close to zero
219    for (i, &drop) in drops.iter().enumerate() {
220        if drop.abs() <= tolerance {
221            samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
222        }
223    }
224    
225    // Remove duplicate zero crossing flags
226    for sample in samples.iter_mut() {
227        let mut unique_flags = Vec::new();
228        let mut seen = HashSet::new();
229        
230        for flag in &sample.flags {
231            if seen.insert(flag.clone()) {
232                unique_flags.push(flag.clone());
233            }
234        }
235        sample.flags = unique_flags;
236    }
237}
238
239/// Find the closest sample index to a given distance
240fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
241    if samples.is_empty() {
242        return None;
243    }
244    
245    // Binary search for the closest distance
246    let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
247    
248    let mut left = 0;
249    let mut right = distances.len();
250    
251    while left < right {
252        let mid = (left + right) / 2;
253        if distances[mid] < target_distance {
254            left = mid + 1;
255        } else {
256            right = mid;
257        }
258    }
259    
260    // Find the closest point (could be left-1 or left)
261    let mut best_idx = left.min(distances.len() - 1);
262    
263    if left > 0 {
264        let left_dist = (distances[left - 1] - target_distance).abs();
265        let right_dist = (distances[best_idx] - target_distance).abs();
266        
267        if left_dist < right_dist {
268            best_idx = left - 1;
269        }
270    }
271    
272    Some(best_idx)
273}
274
275/// Convert trajectory samples to Python-compatible format
276pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
277    samples.iter().map(|sample| {
278        TrajectoryDict {
279            distance_m: sample.distance_m,
280            drop_m: sample.drop_m,
281            wind_drift_m: sample.wind_drift_m,
282            velocity_mps: sample.velocity_mps,
283            energy_j: sample.energy_j,
284            time_s: sample.time_s,
285            flags: sample.flags.iter().map(|f| f.to_string()).collect(),
286        }
287    }).collect()
288}
289
/// Flat, Python-friendly mirror of a trajectory sample in which flags are
/// plain strings.
#[derive(Debug, Clone)]
pub struct TrajectoryDict {
    /// Downrange distance, in metres.
    pub distance_m: f64,
    /// Drop below the line of sight, in metres.
    pub drop_m: f64,
    /// Lateral drift, in metres.
    pub wind_drift_m: f64,
    /// Speed, in metres per second.
    pub velocity_mps: f64,
    /// Kinetic energy, in joules.
    pub energy_j: f64,
    /// Time of flight, in seconds.
    pub time_s: f64,
    /// Flag labels such as `"apex"`.
    pub flags: Vec<String>,
}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sample at `distance_m` with the given `drop_m`. The remaining
    /// fields are filler that the functions under test do not read.
    fn sample_at(distance_m: f64, drop_m: f64) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m: 0.0,
            velocity_mps: 100.0,
            energy_j: 1000.0,
            time_s: 0.0,
            flags: Vec::new(),
        }
    }

    /// A small three-step trajectory 200 m long with its apex at 100 m.
    fn basic_trajectory() -> (TrajectoryData, TrajectoryOutputs) {
        let trajectory_data = TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),
                Vector3::new(100.0, 10.0, 1.0),
                Vector3::new(200.0, 5.0, 2.0),
            ],
            velocities: vec![
                Vector3::new(100.0, 10.0, 1.0),
                Vector3::new(95.0, 5.0, 1.0),
                Vector3::new(90.0, 0.0, 1.0),
            ],
            transonic_distances: vec![150.0],
        };
        let outputs = TrajectoryOutputs {
            target_distance_horiz_m: 200.0,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
        };
        (trajectory_data, outputs)
    }

    #[test]
    fn test_interpolate() {
        let x_vals = vec![0.0, 1.0, 2.0, 3.0];
        let y_vals = vec![0.0, 10.0, 20.0, 30.0];

        assert_eq!(interpolate(&x_vals, &y_vals, 0.5), 5.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 1.5), 15.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 2.5), 25.0);
        // Knot values are reproduced exactly.
        assert_eq!(interpolate(&x_vals, &y_vals, 2.0), 20.0);

        // Out-of-range inputs clamp to the end values.
        assert_eq!(interpolate(&x_vals, &y_vals, -1.0), 0.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 4.0), 30.0);
    }

    #[test]
    fn test_interpolate_degenerate_inputs() {
        let empty: [f64; 0] = [];
        assert_eq!(interpolate(&empty, &empty, 1.0), 0.0);
        // Mismatched lengths yield 0.0 rather than panicking.
        assert_eq!(interpolate(&[0.0, 1.0], &[5.0], 0.5), 0.0);
        // A single point clamps to its value.
        assert_eq!(interpolate(&[2.0], &[7.0], 5.0), 7.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = vec![
            sample_at(0.0, 0.0),
            sample_at(10.0, -1.0),
            sample_at(20.0, -4.0),
        ];

        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
        // An exact tie resolves to the later sample.
        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(1));
        // Targets outside the sampled range clamp to the end samples.
        assert_eq!(find_closest_sample_index(&samples, -3.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 100.0), Some(2));

        let none: Vec<TrajectorySample> = Vec::new();
        assert_eq!(find_closest_sample_index(&none, 5.0), None);
    }

    #[test]
    fn test_detect_zero_crossings() {
        let mut samples = vec![
            sample_at(0.0, 1.0),   // positive
            sample_at(10.0, -0.5), // negative: crossing recorded here
            sample_at(20.0, -2.0), // still negative
        ];

        detect_zero_crossings(&mut samples, 1e-6);

        assert!(!samples[0].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(samples[1].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(!samples[2].flags.contains(&TrajectoryFlag::ZeroCrossing));
    }

    #[test]
    fn test_detect_zero_crossings_exact_zero_is_flagged_once() {
        // Sample 1 is both a sign-change crossing and within tolerance of zero;
        // it must still carry a single flag.
        let mut samples = vec![
            sample_at(0.0, -1.0),
            sample_at(10.0, 0.0),
            sample_at(20.0, 1.0),
        ];

        detect_zero_crossings(&mut samples, 1e-6);

        assert!(samples[0].flags.is_empty());
        assert_eq!(samples[1].flags, vec![TrajectoryFlag::ZeroCrossing]);
        assert!(samples[2].flags.is_empty());
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let (trajectory_data, outputs) = basic_trajectory();

        let samples = sample_trajectory(&trajectory_data, &outputs, 50.0, 0.1);

        // Samples at 0, 50, 100, 150 and 200 metres.
        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);

        // Interpolated speed lies between the bracketing solver speeds.
        assert!(samples[1].velocity_mps > 90.0 && samples[1].velocity_mps < 100.0);

        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex));
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition));
    }

    #[test]
    fn test_sample_trajectory_rejects_degenerate_inputs() {
        let (trajectory_data, outputs) = basic_trajectory();

        assert!(sample_trajectory(&trajectory_data, &outputs, 0.0, 0.1).is_empty());
        assert!(sample_trajectory(&trajectory_data, &outputs, -5.0, 0.1).is_empty());
        assert!(sample_trajectory(&trajectory_data, &outputs, f64::NAN, 0.1).is_empty());

        let no_range = TrajectoryOutputs {
            target_distance_horiz_m: 0.0,
            ..outputs
        };
        assert!(sample_trajectory(&trajectory_data, &no_range, 50.0, 0.1).is_empty());
    }

    #[test]
    fn test_trajectory_samples_to_dicts() {
        let mut sample = sample_at(50.0, 0.25);
        sample.flags = vec![TrajectoryFlag::ZeroCrossing, TrajectoryFlag::Apex];

        let dicts = trajectory_samples_to_dicts(&[sample]);

        assert_eq!(dicts.len(), 1);
        assert_eq!(dicts[0].distance_m, 50.0);
        assert_eq!(dicts[0].drop_m, 0.25);
        assert_eq!(
            dicts[0].flags,
            vec!["zero_crossing".to_string(), "apex".to_string()]
        );
        assert_eq!(TrajectoryFlag::MachTransition.to_string(), "mach_transition");
    }
}