ballistics_engine/
api_client.rs

1//! HTTP client for Flask API communication
2//!
3//! This module provides an HTTP client for routing trajectory calculations
4//! through the Flask API instead of local computation, giving CLI users
5//! access to ML-enhanced predictions.
6
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9
10/// Request structure for trajectory calculation via Flask API
11#[derive(Debug, Clone, Serialize)]
12pub struct TrajectoryRequest {
13    /// Ballistic coefficient value
14    pub bc_value: f64,
15    /// BC type: "G1" or "G7"
16    pub bc_type: String,
17    /// Bullet mass in grams
18    pub bullet_mass: f64,
19    /// Muzzle velocity in m/s
20    pub muzzle_velocity: f64,
21    /// Target distance in meters
22    pub target_distance: f64,
23    /// Zero range in meters (optional)
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub zero_range: Option<f64>,
26    /// Wind speed in m/s (optional)
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub wind_speed: Option<f64>,
29    /// Wind angle in degrees (optional)
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub wind_angle: Option<f64>,
32    /// Temperature in Celsius (optional)
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub temperature: Option<f64>,
35    /// Pressure in hPa/mbar (optional)
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub pressure: Option<f64>,
38    /// Humidity percentage 0-100 (optional)
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub humidity: Option<f64>,
41    /// Altitude in meters (optional)
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub altitude: Option<f64>,
44    /// Latitude for Coriolis calculations (optional)
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub latitude: Option<f64>,
47    /// Shooting angle in degrees (optional)
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub shooting_angle: Option<f64>,
50    /// Barrel twist rate in inches per turn (optional)
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub twist_rate: Option<f64>,
53    /// Bullet diameter in meters (optional)
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub bullet_diameter: Option<f64>,
56    /// Bullet length in meters (optional)
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub bullet_length: Option<f64>,
59}
60
61/// Response structure from Flask API trajectory calculation
62#[derive(Debug, Clone, Deserialize)]
63pub struct TrajectoryResponse {
64    /// Array of trajectory points
65    pub trajectory: Vec<ApiTrajectoryPoint>,
66    /// Zero angle in radians
67    pub zero_angle: f64,
68    /// Total time of flight in seconds
69    pub time_of_flight: f64,
70    /// BC confidence score (0-1) if available
71    #[serde(default)]
72    pub bc_confidence: Option<f64>,
73    /// List of ML corrections applied
74    #[serde(default)]
75    pub ml_corrections_applied: Option<Vec<String>>,
76    /// Maximum ordinate (height) in meters
77    #[serde(default)]
78    pub max_ordinate: Option<f64>,
79    /// Impact velocity in m/s
80    #[serde(default)]
81    pub impact_velocity: Option<f64>,
82    /// Impact energy in Joules
83    #[serde(default)]
84    pub impact_energy: Option<f64>,
85}
86
87/// A single point in the trajectory from API response
88#[derive(Debug, Clone, Deserialize)]
89pub struct ApiTrajectoryPoint {
90    /// Range/distance in meters
91    pub range: f64,
92    /// Drop below line of sight in meters (negative = below)
93    pub drop: f64,
94    /// Wind drift in meters
95    pub drift: f64,
96    /// Velocity at this point in m/s
97    pub velocity: f64,
98    /// Kinetic energy at this point in Joules
99    pub energy: f64,
100    /// Time of flight to this point in seconds
101    pub time: f64,
102}
103
/// Errors produced while communicating with the Flask API.
#[derive(Debug)]
pub enum ApiError {
    /// The network request failed for a reason other than a timeout
    NetworkError(String),
    /// The request did not complete within the allowed time
    Timeout,
    /// The response body could not be read or parsed
    InvalidResponse(String),
    /// The server replied with an HTTP error status (status code, response body)
    ServerError(u16, String),
    /// The request could not be built
    RequestError(String),
}

impl std::fmt::Display for ApiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout => f.write_str("Request timed out"),
            Self::NetworkError(detail) => write!(f, "Network error: {}", detail),
            Self::InvalidResponse(detail) => write!(f, "Invalid response: {}", detail),
            Self::ServerError(status, body) => write!(f, "Server error {}: {}", status, body),
            Self::RequestError(detail) => write!(f, "Request error: {}", detail),
        }
    }
}

impl std::error::Error for ApiError {}
132
/// HTTP client for Flask API communication.
///
/// Stores the API base URL (with trailing slashes removed) and the timeout applied
/// to trajectory requests. The networking methods are only compiled when the
/// `online` feature is enabled.
#[derive(Debug, Clone)]
pub struct ApiClient {
    base_url: String,
    timeout: Duration,
}

impl ApiClient {
    /// Meters per yard.
    #[cfg(feature = "online")]
    const METERS_PER_YARD: f64 = 0.9144;
    /// Meters per foot.
    #[cfg(feature = "online")]
    const METERS_PER_FOOT: f64 = 0.3048;
    /// Meters per inch.
    #[cfg(feature = "online")]
    const METERS_PER_INCH: f64 = 0.0254;
    /// Grams per grain.
    #[cfg(feature = "online")]
    const GRAMS_PER_GRAIN: f64 = 0.0647989;
    /// Miles per hour in one meter per second.
    #[cfg(feature = "online")]
    const MPH_PER_MPS: f64 = 2.23694;
    /// Hectopascals per inch of mercury.
    #[cfg(feature = "online")]
    const HPA_PER_INHG: f64 = 33.8639;
    /// Joules per foot-pound.
    #[cfg(feature = "online")]
    const JOULES_PER_FT_LB: f64 = 1.35582;
    /// Fixed timeout for health checks (independent of the calculation timeout).
    #[cfg(feature = "online")]
    const HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(5);

    /// Create a new API client
    ///
    /// # Arguments
    /// * `base_url` - Base URL of the Flask API (e.g., "https://api.ballistics.7.62x51mm.sh").
    ///   Trailing slashes are removed so endpoint paths can be appended directly.
    /// * `timeout_secs` - Timeout in seconds for trajectory calculation requests
    pub fn new(base_url: &str, timeout_secs: u64) -> Self {
        Self {
            base_url: base_url.trim_end_matches('/').to_string(),
            timeout: Duration::from_secs(timeout_secs),
        }
    }

    /// Calculate trajectory via Flask API
    ///
    /// Sends `GET {base_url}/v1/calculate` with the request converted from metric to
    /// the imperial units the API expects. The JSON reply is then converted back to
    /// metric.
    ///
    /// # Arguments
    /// * `request` - Trajectory calculation request parameters (metric units)
    ///
    /// # Errors
    /// * [`ApiError::ServerError`] - the server answered with an HTTP error status
    /// * [`ApiError::Timeout`] / [`ApiError::NetworkError`] - transport failure
    /// * [`ApiError::InvalidResponse`] - the body could not be read, is not valid
    ///   JSON, or has no `trajectory` array
    #[cfg(feature = "online")]
    pub fn calculate_trajectory(
        &self,
        request: &TrajectoryRequest,
    ) -> Result<TrajectoryResponse, ApiError> {
        let url = format!("{}/v1/calculate", self.base_url);

        // Required parameters, converted from metric to imperial.
        let velocity_fps = request.muzzle_velocity / Self::METERS_PER_FOOT;
        let mass_grains = request.bullet_mass / Self::GRAMS_PER_GRAIN;
        let distance_yards = request.target_distance / Self::METERS_PER_YARD;

        // NOTE(review): hard-coded client version; consider deriving it from the
        // crate version so it cannot drift.
        let mut req = ureq::get(&url)
            .set("Accept", "application/json")
            .set("User-Agent", "ballistics-cli/0.13.19")
            .timeout(self.timeout)
            .query("bc_value", &request.bc_value.to_string())
            .query("bc_type", &request.bc_type)
            .query("bullet_mass", &format!("{:.1}", mass_grains))
            .query("muzzle_velocity", &format!("{:.1}", velocity_fps))
            .query("target_distance", &format!("{:.1}", distance_yards));

        // Optional parameters as (query name, value in API units, decimal places).
        // Unset (`None`) values are omitted from the query string.
        let optional: [(&str, Option<f64>, usize); 12] = [
            ("zero_distance", request.zero_range.map(|m| m / Self::METERS_PER_YARD), 1),
            ("wind_speed", request.wind_speed.map(|mps| mps * Self::MPH_PER_MPS), 1),
            ("wind_angle", request.wind_angle, 1),
            ("temperature", request.temperature.map(|c| c * 9.0 / 5.0 + 32.0), 1),
            ("pressure", request.pressure.map(|hpa| hpa / Self::HPA_PER_INHG), 2),
            ("humidity", request.humidity, 1),
            ("altitude", request.altitude.map(|m| m / Self::METERS_PER_FOOT), 1),
            ("shooting_angle", request.shooting_angle, 1),
            ("latitude", request.latitude, 2),
            ("twist_rate", request.twist_rate, 1),
            ("bullet_diameter", request.bullet_diameter.map(|m| m / Self::METERS_PER_INCH), 3),
            // Previously accepted but silently dropped. NOTE(review): assumes the API
            // takes `bullet_length` in inches, like `bullet_diameter` — confirm.
            ("bullet_length", request.bullet_length.map(|m| m / Self::METERS_PER_INCH), 3),
        ];
        for &(name, value, decimals) in &optional {
            if let Some(value) = value {
                req = req.query(name, &format!("{:.*}", decimals, value));
            }
        }

        let response = req.call().map_err(Self::map_ureq_error)?;

        let body = response
            .into_string()
            .map_err(|e| ApiError::InvalidResponse(e.to_string()))?;

        let api_response: serde_json::Value = serde_json::from_str(&body)
            .map_err(|e| ApiError::InvalidResponse(format!("JSON parse error: {}", e)))?;

        self.convert_api_response(&api_response)
    }

    /// Translate a `ureq` error into an [`ApiError`].
    ///
    /// HTTP error statuses become [`ApiError::ServerError`] carrying the response
    /// body. Transport errors become [`ApiError::Timeout`] when their message
    /// mentions a timeout, and [`ApiError::NetworkError`] otherwise.
    #[cfg(feature = "online")]
    fn map_ureq_error(error: ureq::Error) -> ApiError {
        match error {
            ureq::Error::Status(code, response) => {
                let body = response.into_string().unwrap_or_default();
                ApiError::ServerError(code, body)
            }
            ureq::Error::Transport(transport) => {
                // Timeouts are recognised by their message text.
                let msg = transport.to_string();
                if msg.contains("timed out") || msg.contains("timeout") {
                    ApiError::Timeout
                } else {
                    ApiError::NetworkError(msg)
                }
            }
        }
    }

    /// Helper to extract a value from a nested {value: x, unit: y} structure or plain number
    #[cfg(feature = "online")]
    fn extract_value(val: &serde_json::Value) -> Option<f64> {
        // Try nested {value: x} first, then plain number
        val.get("value")
            .and_then(|v| v.as_f64())
            .or_else(|| val.as_f64())
    }

    /// Convert one API trajectory point (imperial units, plain or nested
    /// `{value, unit}` numbers) to metric.
    ///
    /// Returns `None` when `distance` or `velocity` is missing. Missing `drop`,
    /// `wind_drift`, `energy` and `time` default to `0.0`.
    #[cfg(feature = "online")]
    fn convert_point(point: &serde_json::Value) -> Option<ApiTrajectoryPoint> {
        let field = |key: &str| point.get(key).and_then(Self::extract_value);

        let range_yards = field("distance")?;
        let velocity_fps = field("velocity")?;

        Some(ApiTrajectoryPoint {
            range: range_yards * Self::METERS_PER_YARD,
            drop: field("drop").unwrap_or(0.0) * Self::METERS_PER_INCH,
            drift: field("wind_drift").unwrap_or(0.0) * Self::METERS_PER_INCH,
            velocity: velocity_fps * Self::METERS_PER_FOOT,
            energy: field("energy").unwrap_or(0.0) * Self::JOULES_PER_FT_LB,
            time: field("time").unwrap_or(0.0),
        })
    }

    /// Convert the Flask API's JSON reply (imperial units) into a metric
    /// [`TrajectoryResponse`].
    ///
    /// Points lacking `distance` or `velocity` are skipped. Summary values are read
    /// from the optional `results` object. When `time_of_flight` is absent, the
    /// time of the last trajectory point (or `0.0`) is used.
    ///
    /// # Errors
    /// Returns [`ApiError::InvalidResponse`] when the top-level `trajectory` array
    /// is missing.
    #[cfg(feature = "online")]
    fn convert_api_response(
        &self,
        api_response: &serde_json::Value,
    ) -> Result<TrajectoryResponse, ApiError> {
        let results = api_response.get("results");
        // Numeric value of `key` inside `results`, accepting plain or nested form.
        let result_value = |key: &str| {
            results
                .and_then(|r| r.get(key))
                .and_then(Self::extract_value)
        };

        let trajectory: Vec<ApiTrajectoryPoint> = api_response
            .get("trajectory")
            .and_then(|t| t.as_array())
            .ok_or_else(|| ApiError::InvalidResponse("Missing trajectory array".to_string()))?
            .iter()
            .filter_map(Self::convert_point)
            .collect();

        // The API's barrel angle is treated as degrees and stored as radians.
        let zero_angle = result_value("barrel_angle").unwrap_or(0.0).to_radians();

        let time_of_flight = result_value("time_of_flight")
            .unwrap_or_else(|| trajectory.last().map_or(0.0, |p| p.time));

        let bc_confidence = api_response
            .get("bc_confidence")
            .and_then(Self::extract_value);

        let ml_corrections_applied = api_response
            .get("ml_corrections_applied")
            .or_else(|| api_response.get("corrections_applied"))
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect()
            });

        // Inches, feet per second and foot-pounds converted to metric.
        let max_ordinate = result_value("max_height").map(|h| h * Self::METERS_PER_INCH);
        let impact_velocity = result_value("final_velocity").map(|v| v * Self::METERS_PER_FOOT);
        let impact_energy = result_value("final_energy").map(|e| e * Self::JOULES_PER_FT_LB);

        Ok(TrajectoryResponse {
            trajectory,
            zero_angle,
            time_of_flight,
            bc_confidence,
            ml_corrections_applied,
            max_ordinate,
            impact_velocity,
            impact_energy,
        })
    }

    /// Check API health by requesting `GET {base_url}/health`.
    ///
    /// Uses a fixed 5 second timeout rather than the client's calculation timeout.
    ///
    /// # Returns
    /// `Ok(true)` when the server answers with status 200. Returns `Ok(false)` for
    /// any other status that `ureq` does not report as an error.
    ///
    /// # Errors
    /// HTTP error statuses and transport failures are mapped the same way as in
    /// [`Self::calculate_trajectory`].
    #[cfg(feature = "online")]
    pub fn health_check(&self) -> Result<bool, ApiError> {
        let url = format!("{}/health", self.base_url);

        let response = ureq::get(&url)
            .timeout(Self::HEALTH_CHECK_TIMEOUT)
            .call()
            .map_err(Self::map_ureq_error)?;

        Ok(response.status() == 200)
    }
}
385
386/// Builder for TrajectoryRequest
387#[derive(Default)]
388pub struct TrajectoryRequestBuilder {
389    bc_value: Option<f64>,
390    bc_type: Option<String>,
391    bullet_mass: Option<f64>,
392    muzzle_velocity: Option<f64>,
393    target_distance: Option<f64>,
394    zero_range: Option<f64>,
395    wind_speed: Option<f64>,
396    wind_angle: Option<f64>,
397    temperature: Option<f64>,
398    pressure: Option<f64>,
399    humidity: Option<f64>,
400    altitude: Option<f64>,
401    latitude: Option<f64>,
402    shooting_angle: Option<f64>,
403    twist_rate: Option<f64>,
404    bullet_diameter: Option<f64>,
405    bullet_length: Option<f64>,
406}
407
408impl TrajectoryRequestBuilder {
409    pub fn new() -> Self {
410        Self::default()
411    }
412
413    pub fn bc_value(mut self, value: f64) -> Self {
414        self.bc_value = Some(value);
415        self
416    }
417
418    pub fn bc_type(mut self, value: &str) -> Self {
419        self.bc_type = Some(value.to_string());
420        self
421    }
422
423    pub fn bullet_mass(mut self, value: f64) -> Self {
424        self.bullet_mass = Some(value);
425        self
426    }
427
428    pub fn muzzle_velocity(mut self, value: f64) -> Self {
429        self.muzzle_velocity = Some(value);
430        self
431    }
432
433    pub fn target_distance(mut self, value: f64) -> Self {
434        self.target_distance = Some(value);
435        self
436    }
437
438    pub fn zero_range(mut self, value: f64) -> Self {
439        self.zero_range = Some(value);
440        self
441    }
442
443    pub fn wind_speed(mut self, value: f64) -> Self {
444        self.wind_speed = Some(value);
445        self
446    }
447
448    pub fn wind_angle(mut self, value: f64) -> Self {
449        self.wind_angle = Some(value);
450        self
451    }
452
453    pub fn temperature(mut self, value: f64) -> Self {
454        self.temperature = Some(value);
455        self
456    }
457
458    pub fn pressure(mut self, value: f64) -> Self {
459        self.pressure = Some(value);
460        self
461    }
462
463    pub fn humidity(mut self, value: f64) -> Self {
464        self.humidity = Some(value);
465        self
466    }
467
468    pub fn altitude(mut self, value: f64) -> Self {
469        self.altitude = Some(value);
470        self
471    }
472
473    pub fn latitude(mut self, value: f64) -> Self {
474        self.latitude = Some(value);
475        self
476    }
477
478    pub fn shooting_angle(mut self, value: f64) -> Self {
479        self.shooting_angle = Some(value);
480        self
481    }
482
483    pub fn twist_rate(mut self, value: f64) -> Self {
484        self.twist_rate = Some(value);
485        self
486    }
487
488    pub fn bullet_diameter(mut self, value: f64) -> Self {
489        self.bullet_diameter = Some(value);
490        self
491    }
492
493    pub fn bullet_length(mut self, value: f64) -> Self {
494        self.bullet_length = Some(value);
495        self
496    }
497
498    /// Build the TrajectoryRequest
499    ///
500    /// # Returns
501    /// * `Ok(TrajectoryRequest)` - Valid request
502    /// * `Err(String)` - Missing required fields
503    pub fn build(self) -> Result<TrajectoryRequest, String> {
504        let bc_value = self.bc_value.ok_or("bc_value is required")?;
505        let bc_type = self.bc_type.ok_or("bc_type is required")?;
506        let bullet_mass = self.bullet_mass.ok_or("bullet_mass is required")?;
507        let muzzle_velocity = self.muzzle_velocity.ok_or("muzzle_velocity is required")?;
508        let target_distance = self.target_distance.ok_or("target_distance is required")?;
509
510        Ok(TrajectoryRequest {
511            bc_value,
512            bc_type,
513            bullet_mass,
514            muzzle_velocity,
515            target_distance,
516            zero_range: self.zero_range,
517            wind_speed: self.wind_speed,
518            wind_angle: self.wind_angle,
519            temperature: self.temperature,
520            pressure: self.pressure,
521            humidity: self.humidity,
522            altitude: self.altitude,
523            latitude: self.latitude,
524            shooting_angle: self.shooting_angle,
525            twist_rate: self.twist_rate,
526            bullet_diameter: self.bullet_diameter,
527            bullet_length: self.bullet_length,
528        })
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder with exactly the required fields set.
    fn required_only() -> TrajectoryRequestBuilder {
        TrajectoryRequestBuilder::new()
            .bc_value(0.238)
            .bc_type("G7")
            .bullet_mass(9.07) // 140gr in grams
            .muzzle_velocity(860.0)
            .target_distance(1000.0)
    }

    #[test]
    fn test_request_builder_required_fields() {
        let request = required_only().build().expect("all required fields are set");
        assert_eq!(request.bc_value, 0.238);
        assert_eq!(request.bc_type, "G7");
        assert_eq!(request.bullet_mass, 9.07);
        assert_eq!(request.muzzle_velocity, 860.0);
        assert_eq!(request.target_distance, 1000.0);
        assert_eq!(request.zero_range, None);
        assert_eq!(request.bullet_length, None);
    }

    #[test]
    fn test_request_builder_missing_fields() {
        let result = TrajectoryRequestBuilder::new().bc_value(0.238).build();
        assert_eq!(result.unwrap_err(), "bc_type is required");
    }

    #[test]
    fn test_request_builder_reports_first_missing_field() {
        assert_eq!(
            TrajectoryRequestBuilder::new().build().unwrap_err(),
            "bc_value is required"
        );
        let err = TrajectoryRequestBuilder::new()
            .bc_value(0.3)
            .bc_type("G1")
            .bullet_mass(10.0)
            .muzzle_velocity(800.0)
            .build()
            .unwrap_err();
        assert_eq!(err, "target_distance is required");
    }

    #[test]
    fn test_request_builder_all_optional_fields() {
        let request = required_only()
            .zero_range(100.0)
            .wind_speed(5.0)
            .wind_angle(90.0)
            .temperature(15.0)
            .pressure(1013.25)
            .humidity(50.0)
            .altitude(500.0)
            .latitude(45.0)
            .shooting_angle(0.0)
            .twist_rate(10.0)
            .bullet_diameter(0.00671)
            .bullet_length(0.035)
            .build()
            .expect("all required fields are set");

        assert_eq!(request.zero_range, Some(100.0));
        assert_eq!(request.wind_speed, Some(5.0));
        assert_eq!(request.wind_angle, Some(90.0));
        assert_eq!(request.temperature, Some(15.0));
        assert_eq!(request.pressure, Some(1013.25));
        assert_eq!(request.humidity, Some(50.0));
        assert_eq!(request.altitude, Some(500.0));
        assert_eq!(request.latitude, Some(45.0));
        assert_eq!(request.shooting_angle, Some(0.0));
        assert_eq!(request.twist_rate, Some(10.0));
        assert_eq!(request.bullet_diameter, Some(0.00671));
        assert_eq!(request.bullet_length, Some(0.035));
    }

    #[test]
    fn test_api_client_url_normalization() {
        let client1 = ApiClient::new("https://api.example.com/", 10);
        assert_eq!(client1.base_url, "https://api.example.com");
        assert_eq!(client1.timeout, Duration::from_secs(10));

        let client2 = ApiClient::new("https://api.example.com", 10);
        assert_eq!(client2.base_url, "https://api.example.com");

        let client3 = ApiClient::new("https://api.example.com///", 10);
        assert_eq!(client3.base_url, "https://api.example.com");
    }

    #[test]
    fn test_api_error_display() {
        assert_eq!(
            format!("{}", ApiError::NetworkError("connection refused".to_string())),
            "Network error: connection refused"
        );
        assert_eq!(format!("{}", ApiError::Timeout), "Request timed out");
        assert_eq!(
            format!("{}", ApiError::InvalidResponse("bad json".to_string())),
            "Invalid response: bad json"
        );
        assert_eq!(
            format!("{}", ApiError::ServerError(500, "Internal error".to_string())),
            "Server error 500: Internal error"
        );
        assert_eq!(
            format!("{}", ApiError::RequestError("missing url".to_string())),
            "Request error: missing url"
        );
    }

    /// Absolute-tolerance float comparison for unit-conversion checks.
    #[cfg(feature = "online")]
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    #[cfg(feature = "online")]
    #[test]
    fn test_convert_api_response_converts_imperial_to_metric() {
        let client = ApiClient::new("http://localhost", 1);
        let payload = serde_json::json!({
            "trajectory": [
                {
                    "distance": {"value": 100.0, "unit": "yd"},
                    "drop": {"value": -2.0, "unit": "in"},
                    "wind_drift": 1.0,
                    "velocity": {"value": 2500.0, "unit": "fps"},
                    "energy": {"value": 1000.0, "unit": "ft-lbs"},
                    "time": 0.12
                },
                // No velocity: this point must be skipped.
                {"distance": {"value": 200.0, "unit": "yd"}}
            ],
            "results": {
                "barrel_angle": {"value": 0.1},
                "time_of_flight": 1.5,
                "final_velocity": {"value": 2000.0, "unit": "fps"},
                "final_energy": 800.0,
                "max_height": 3.0
            },
            "bc_confidence": 0.9,
            "corrections_applied": ["drag", "spin_drift"]
        });

        let response = client
            .convert_api_response(&payload)
            .expect("payload has a trajectory array");

        assert_eq!(response.trajectory.len(), 1);
        let point = &response.trajectory[0];
        assert!(approx_eq(point.range, 91.44));
        assert!(approx_eq(point.drop, -0.0508));
        assert!(approx_eq(point.drift, 0.0254));
        assert!(approx_eq(point.velocity, 762.0));
        assert!(approx_eq(point.energy, 1355.82));
        assert!(approx_eq(point.time, 0.12));

        assert!(approx_eq(response.zero_angle, 0.1_f64.to_radians()));
        assert!(approx_eq(response.time_of_flight, 1.5));
        assert_eq!(response.bc_confidence, Some(0.9));
        assert_eq!(
            response.ml_corrections_applied,
            Some(vec!["drag".to_string(), "spin_drift".to_string()])
        );
        assert!(response.max_ordinate.map_or(false, |h| approx_eq(h, 0.0762)));
        assert!(response.impact_velocity.map_or(false, |v| approx_eq(v, 609.6)));
        assert!(response.impact_energy.map_or(false, |e| approx_eq(e, 1084.656)));
    }

    #[cfg(feature = "online")]
    #[test]
    fn test_convert_api_response_defaults_without_results() {
        let client = ApiClient::new("http://localhost", 1);
        let payload = serde_json::json!({
            "trajectory": [
                {"distance": 50.0, "velocity": 2600.0, "time": 0.05},
                {"distance": 100.0, "velocity": 2500.0, "time": 0.11}
            ]
        });

        let response = client
            .convert_api_response(&payload)
            .expect("payload has a trajectory array");

        assert_eq!(response.trajectory.len(), 2);
        assert_eq!(response.trajectory[0].drop, 0.0);
        assert!(approx_eq(response.time_of_flight, 0.11));
        assert_eq!(response.zero_angle, 0.0);
        assert_eq!(response.bc_confidence, None);
        assert_eq!(response.ml_corrections_applied, None);
        assert_eq!(response.max_ordinate, None);
        assert_eq!(response.impact_velocity, None);
        assert_eq!(response.impact_energy, None);
    }

    #[cfg(feature = "online")]
    #[test]
    fn test_convert_api_response_requires_trajectory() {
        let client = ApiClient::new("http://localhost", 1);
        let result = client.convert_api_response(&serde_json::json!({"results": {}}));
        assert!(matches!(result, Err(ApiError::InvalidResponse(_))));
    }
}