Skip to main content

dsfb_starship/
lib.rs

1pub mod config;
2pub mod estimators;
3pub mod output;
4pub mod physics;
5pub mod sensors;
6
7use std::fs;
8use std::path::{Path, PathBuf};
9
10use anyhow::Context;
11use chrono::Utc;
12use nalgebra::Vector3;
13use pyo3::exceptions::PyRuntimeError;
14use pyo3::prelude::*;
15use pyo3::types::PyModule;
16use rand::{Rng, SeedableRng};
17use rand_chacha::ChaCha8Rng;
18use rand_distr::StandardNormal;
19
20use crate::config::SimConfig;
21use crate::estimators::{mean_measurement, DsfbFusionLayer, NavState, SimpleEkf};
22use crate::output::{make_plots, write_csv, write_summary, MethodMetrics, OutputFiles, SimRecord, Summary};
23use crate::physics::{initial_truth_state, truth_step, ReentryEventState, VehicleParams};
24use crate::sensors::ImuArray;
25
26pub fn run_simulation(cfg: &SimConfig, output_dir: &Path) -> anyhow::Result<Summary> {
27    cfg.validate()?;
28    let output_base_dir = resolve_output_base_dir(output_dir);
29    let output_dir = create_timestamped_run_dir(&output_base_dir)?;
30
31    let vehicle = VehicleParams::default();
32    let mut truth = initial_truth_state(cfg, &vehicle);
33    let mut events = ReentryEventState::default();
34    let mut imu_array = ImuArray::new(cfg.seed, cfg.imu_count);
35
36    let mut inertial = NavState::from_truth_with_seed_error(&truth, 1.00);
37    let mut ekf = SimpleEkf::new(NavState::from_truth_with_seed_error(&truth, 1.12));
38    let mut dsfb_nav = NavState::from_truth_with_seed_error(&truth, 0.86);
39    let mut dsfb_fusion = DsfbFusionLayer::new(cfg);
40
41    let mut gnss_rng = ChaCha8Rng::seed_from_u64(cfg.seed ^ 0xCAB00D1E_u64);
42
43    let mut records = Vec::with_capacity(cfg.steps());
44
45    let mut blackout_start: Option<f64> = None;
46    let mut blackout_end: Option<f64> = None;
47
48    for step_idx in 0..cfg.steps() {
49        let t_s = step_idx as f64 * cfg.dt;
50
51        let truth_sample = truth_step(&mut truth, &vehicle, cfg, t_s, cfg.dt, &mut events);
52        let imu_measurements = imu_array.measure(
53            truth_sample.aero.specific_force_b_mps2,
54            truth.omega_b_rps,
55            truth.heat_shield_temp_k,
56            t_s,
57            &events,
58        );
59
60        // Pure inertial baseline: first IMU only.
61        if let Some(primary) = imu_measurements.first() {
62            inertial.propagate(primary.accel_b_mps2, primary.gyro_b_rps, cfg.dt);
63        }
64
65        // Simple EKF baseline: average IMU propagation + GNSS update when not in blackout.
66        let mean_imu = mean_measurement(&imu_measurements);
67        ekf.propagate(mean_imu.accel_b_mps2, mean_imu.gyro_b_rps, cfg.dt);
68
69        // DSFB fusion over redundant IMUs.
70        let dsfb_out = dsfb_fusion.fuse(&imu_measurements, cfg.dt);
71        dsfb_nav.propagate(dsfb_out.fused_accel_b_mps2, dsfb_out.fused_gyro_b_rps, cfg.dt);
72
73        if !finite_nav(&truth.pos_n_m, &truth.vel_n_mps)
74            || !finite_nav(&inertial.pos_n_m, &inertial.vel_n_mps)
75            || !finite_nav(&ekf.nav.pos_n_m, &ekf.nav.vel_n_mps)
76            || !finite_nav(&dsfb_nav.pos_n_m, &dsfb_nav.vel_n_mps)
77        {
78            break;
79        }
80
81        let is_blackout = truth_sample.blackout;
82        if is_blackout {
83            if blackout_start.is_none() {
84                blackout_start = Some(t_s);
85            }
86        } else if blackout_start.is_some() && blackout_end.is_none() {
87            blackout_end = Some(t_s);
88        }
89
90        // GNSS aiding outside blackout at 1 Hz.
91        if !is_blackout && step_idx % (1.0 / cfg.dt).round().max(1.0) as usize == 0 {
92            let gnss_pos = truth.pos_n_m
93                + Vector3::new(
94                    gaussian(&mut gnss_rng, 5.5),
95                    gaussian(&mut gnss_rng, 5.5),
96                    gaussian(&mut gnss_rng, 7.0),
97                );
98            let gnss_vel = truth.vel_n_mps
99                + Vector3::new(
100                    gaussian(&mut gnss_rng, 0.75),
101                    gaussian(&mut gnss_rng, 0.75),
102                    gaussian(&mut gnss_rng, 0.90),
103                );
104
105            ekf.update_gnss(gnss_pos, gnss_vel);
106
107            dsfb_nav.pos_n_m = dsfb_nav.pos_n_m * 0.75 + gnss_pos * 0.25;
108            dsfb_nav.vel_n_mps = dsfb_nav.vel_n_mps * 0.70 + gnss_vel * 0.30;
109        }
110
111        let trust_imu0 = *dsfb_out.trust_weights.first().unwrap_or(&0.0);
112        let trust_imu1 = *dsfb_out.trust_weights.get(1).unwrap_or(&0.0);
113        let trust_imu2 = *dsfb_out.trust_weights.get(2).unwrap_or(&0.0);
114
115        let resid_imu0 = *dsfb_out.residual_increments.first().unwrap_or(&0.0);
116        let resid_imu1 = *dsfb_out.residual_increments.get(1).unwrap_or(&0.0);
117        let resid_imu2 = *dsfb_out.residual_increments.get(2).unwrap_or(&0.0);
118
119        records.push(SimRecord {
120            time_s: t_s,
121            altitude_m: truth.altitude_m(),
122            speed_mps: truth.vel_n_mps.norm(),
123            mach: truth_sample.aero.mach,
124            dynamic_pressure_pa: truth_sample.aero.dynamic_pressure_pa,
125            heat_flux_w_m2: truth_sample.heat_flux_w_m2,
126            heat_shield_temp_k: truth.heat_shield_temp_k,
127            blackout: is_blackout,
128
129            truth_x_km: truth.pos_n_m.x / 1_000.0,
130            truth_y_km: truth.pos_n_m.y / 1_000.0,
131            truth_z_km: truth.pos_n_m.z / 1_000.0,
132
133            inertial_x_km: inertial.pos_n_m.x / 1_000.0,
134            inertial_y_km: inertial.pos_n_m.y / 1_000.0,
135            inertial_z_km: inertial.pos_n_m.z / 1_000.0,
136            ekf_x_km: ekf.nav.pos_n_m.x / 1_000.0,
137            ekf_y_km: ekf.nav.pos_n_m.y / 1_000.0,
138            ekf_z_km: ekf.nav.pos_n_m.z / 1_000.0,
139            dsfb_x_km: dsfb_nav.pos_n_m.x / 1_000.0,
140            dsfb_y_km: dsfb_nav.pos_n_m.y / 1_000.0,
141            dsfb_z_km: dsfb_nav.pos_n_m.z / 1_000.0,
142
143            inertial_pos_err_m: inertial.position_error_m(&truth),
144            inertial_vel_err_mps: inertial.velocity_error_mps(&truth),
145            inertial_att_err_deg: inertial.attitude_error_deg(&truth),
146            ekf_pos_err_m: ekf.nav.position_error_m(&truth),
147            ekf_vel_err_mps: ekf.nav.velocity_error_mps(&truth),
148            ekf_att_err_deg: ekf.nav.attitude_error_deg(&truth),
149            dsfb_pos_err_m: dsfb_nav.position_error_m(&truth),
150            dsfb_vel_err_mps: dsfb_nav.velocity_error_mps(&truth),
151            dsfb_att_err_deg: dsfb_nav.attitude_error_deg(&truth),
152
153            dsfb_trust_imu0: trust_imu0,
154            dsfb_trust_imu1: trust_imu1,
155            dsfb_trust_imu2: trust_imu2,
156            dsfb_resid_inc_imu0: resid_imu0,
157            dsfb_resid_inc_imu1: resid_imu1,
158            dsfb_resid_inc_imu2: resid_imu2,
159        });
160
161        if truth.altitude_m() <= 18_000.0 {
162            break;
163        }
164    }
165
166    let blackout_duration_s = if let (Some(start), Some(end)) = (blackout_start, blackout_end) {
167        (end - start).max(0.0)
168    } else {
169        0.0
170    };
171
172    let files = OutputFiles {
173        output_dir: output_dir.clone(),
174        csv_path: output_dir.join("starship_timeseries.csv"),
175        summary_path: output_dir.join("starship_summary.json"),
176        plot_altitude_path: output_dir.join("plot_altitude.png"),
177        plot_error_path: output_dir.join("plot_position_error_log.png"),
178        plot_trust_path: output_dir.join("plot_dsfb_trust.png"),
179    };
180
181    let inertial_metrics = compute_metrics(
182        &records,
183        |r| r.inertial_pos_err_m,
184        |r| r.inertial_vel_err_mps,
185        |r| r.inertial_att_err_deg,
186    );
187    let ekf_metrics = compute_metrics(
188        &records,
189        |r| r.ekf_pos_err_m,
190        |r| r.ekf_vel_err_mps,
191        |r| r.ekf_att_err_deg,
192    );
193    let dsfb_metrics = compute_metrics(
194        &records,
195        |r| r.dsfb_pos_err_m,
196        |r| r.dsfb_vel_err_mps,
197        |r| r.dsfb_att_err_deg,
198    );
199
200    let summary = Summary {
201        config: cfg.clone(),
202        samples: records.len(),
203        blackout_start_s: blackout_start,
204        blackout_end_s: blackout_end,
205        blackout_duration_s,
206        inertial: inertial_metrics,
207        ekf: ekf_metrics,
208        dsfb: dsfb_metrics,
209        outputs: files.clone(),
210    };
211
212    write_csv(&files.csv_path, &records)?;
213    write_summary(&files.summary_path, &summary)?;
214    make_plots(&records, &files)?;
215
216    Ok(summary)
217}
218
219fn compute_metrics(
220    records: &[SimRecord],
221    pos_fn: impl Fn(&SimRecord) -> f64,
222    vel_fn: impl Fn(&SimRecord) -> f64,
223    att_fn: impl Fn(&SimRecord) -> f64,
224) -> MethodMetrics {
225    let mut pos_sq = 0.0;
226    let mut vel_sq = 0.0;
227    let mut att_sq = 0.0;
228    let mut max_pos = 0.0_f64;
229    let mut count = 0.0_f64;
230
231    for r in records {
232        let p = pos_fn(r);
233        let v = vel_fn(r);
234        let a = att_fn(r);
235        if !(p.is_finite() && v.is_finite() && a.is_finite()) {
236            continue;
237        }
238        pos_sq += p * p;
239        vel_sq += v * v;
240        att_sq += a * a;
241        max_pos = max_pos.max(p);
242        count += 1.0;
243    }
244
245    let final_pos = records
246        .iter()
247        .rev()
248        .find_map(|r| {
249            let p = pos_fn(r);
250            if p.is_finite() {
251                Some(p)
252            } else {
253                None
254            }
255        })
256        .unwrap_or(0.0);
257    let n = count.max(1.0);
258
259    MethodMetrics {
260        rmse_position_m: (pos_sq / n).sqrt(),
261        rmse_velocity_mps: (vel_sq / n).sqrt(),
262        rmse_attitude_deg: (att_sq / n).sqrt(),
263        final_position_error_m: final_pos,
264        max_position_error_m: max_pos,
265    }
266}
267
268fn gaussian(rng: &mut ChaCha8Rng, sigma: f64) -> f64 {
269    let z: f64 = rng.sample(StandardNormal);
270    sigma * z
271}
272
273fn finite_nav(pos: &Vector3<f64>, vel: &Vector3<f64>) -> bool {
274    pos.iter().all(|v| v.is_finite()) && vel.iter().all(|v| v.is_finite())
275}
276
277pub fn workspace_root_dir() -> PathBuf {
278    let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
279    manifest_dir
280        .join("../..")
281        .canonicalize()
282        .unwrap_or_else(|_| manifest_dir.join("../.."))
283}
284
285pub fn default_output_base_dir() -> PathBuf {
286    workspace_root_dir().join("output-dsfb-starship")
287}
288
289fn resolve_output_base_dir(requested: &Path) -> PathBuf {
290    if requested.is_absolute() {
291        requested.to_path_buf()
292    } else {
293        workspace_root_dir().join(requested)
294    }
295}
296
297fn create_timestamped_run_dir(base_dir: &Path) -> anyhow::Result<PathBuf> {
298    fs::create_dir_all(base_dir)
299        .with_context(|| format!("failed to create output base directory {}", base_dir.display()))?;
300
301    let timestamp = Utc::now().format("%Y%m%d-%H%M%S").to_string();
302    let run_dir = base_dir.join(&timestamp);
303    if !run_dir.exists() {
304        fs::create_dir_all(&run_dir)?;
305        return Ok(run_dir);
306    }
307
308    let mut counter: usize = 1;
309    loop {
310        let candidate = base_dir.join(format!("{timestamp}-{counter:02}"));
311        if !candidate.exists() {
312            fs::create_dir_all(&candidate)?;
313            return Ok(candidate);
314        }
315        counter += 1;
316    }
317}
318
319#[pyfunction]
320#[pyo3(signature = (output_dir=None, dt=None, t_final=None, rho=None, slew_threshold=None, seed=None))]
321fn run_starship_simulation(
322    output_dir: Option<String>,
323    dt: Option<f64>,
324    t_final: Option<f64>,
325    rho: Option<f64>,
326    slew_threshold: Option<f64>,
327    seed: Option<u64>,
328) -> PyResult<String> {
329    let mut cfg = SimConfig::default();
330
331    if let Some(v) = dt {
332        cfg.dt = v;
333    }
334    if let Some(v) = t_final {
335        cfg.t_final = v;
336    }
337    if let Some(v) = rho {
338        cfg.rho = v;
339    }
340    if let Some(v) = slew_threshold {
341        cfg.slew_threshold_accel = v;
342        cfg.slew_threshold_gyro = (v * 0.055).max(0.15);
343    }
344    if let Some(v) = seed {
345        cfg.seed = v;
346    }
347
348    let out = output_dir
349        .map(PathBuf::from)
350        .unwrap_or_else(|| PathBuf::from("output-dsfb-starship"));
351
352    let summary = run_simulation(&cfg, &out)
353        .map_err(|e| PyRuntimeError::new_err(format!("simulation failed: {e:#}")))?;
354
355    serde_json::to_string_pretty(&summary)
356        .map_err(|e| PyRuntimeError::new_err(format!("summary serialization failed: {e}")))
357}
358
359#[pyfunction]
360fn default_config_json() -> PyResult<String> {
361    serde_json::to_string_pretty(&SimConfig::default())
362        .map_err(|e| PyRuntimeError::new_err(format!("config serialization failed: {e}")))
363}
364
365#[pymodule]
366fn dsfb_starship(m: &Bound<'_, PyModule>) -> PyResult<()> {
367    m.add_function(wrap_pyfunction!(run_starship_simulation, m)?)?;
368    m.add_function(wrap_pyfunction!(default_config_json, m)?)?;
369    Ok(())
370}